# Project     OBS Indications on Corsair Keyboard
# @author     Roman Zhigunov [reejk.xyz]
# @license    GPLv3 - Copyright (c) 2022 Roman Zhigunov

# This script requires the cuesdk package. To install it, run:
# > python -m pip install cuesdk==1.1.0

from enum import Enum
from typing import List
from cuesdk import CueSdk, CorsairDeviceType, CorsairLedId
import obspython as obs

class BindingWhat(Enum):
    """Kind of OBS object a binding watches.

    The integer values are what the "What" combo box stores in the script
    settings, so they must stay stable.
    """

    Source = 0
    Scene = 1
    Filter = 2

class BindingAction(Enum):
    """Which state of a source a binding reflects.

    The integer values are what the "Action" combo box stores in the script
    settings, so they must stay stable.
    """

    Active = 0
    Showing = 1
    # NOTE: set_all_active_leds() lights the key while the source is NOT muted.
    Mute = 2

class Binding:
    """One configured link between an OBS object and a keyboard LED.

    Attributes:
        active: whether the binding is enabled (the group's checkbox).
        led: the keyboard LED to colour.
        what: kind of OBS object being watched.
        name: name of the watched object; for filters this is
            "<source name>\\n<filter name>".
        action: which source state is reflected (used for source bindings).
        off_color: colour shown while the binding is inactive.
        on_color: colour shown while the binding is active.
        subscribed: the action whose signals are currently connected, or
            None when no signal handlers are connected for this binding.
    """

    def __init__(
        self,
        active: bool,
        led: CorsairLedId,
        what: BindingWhat,
        name: str,
        action: BindingAction,
        off_color,
        on_color,
    ):
        # What to watch.
        self.what = what
        self.name = name
        self.action = action
        # How to show it.
        self.led = led
        self.off_color = off_color
        self.on_color = on_color
        # Runtime state.
        self.active = active
        self.subscribed = None


# Set to True by cb_frontend_event() once the first scene-change event
# has been handled.
loaded = False
# Device index handed to the CUE SDK when LED colours are written.
keyboard_index = 0
# Bindings rebuilt from the script settings on every script_update().
bindings: List[Binding] = []


def list_keyboards():
    """Return the keyboards known to the CUE SDK.

    Returns:
        A dict mapping the device index (as a string) to the device model,
        containing only devices whose type is ``CorsairDeviceType.Keyboard``.
    """
    devices = (
        (index, sdk.get_device_info(index))
        for index in range(sdk.get_device_count())
    )
    return {
        str(index): info.model
        for index, info in devices
        if info.type == CorsairDeviceType.Keyboard
    }


def list_scenes():
    """Return the sorted names of all frontend scenes, preceded by "None"."""
    scenes = obs.obs_frontend_get_scenes()
    names = [obs.obs_source_get_name(scene) for scene in scenes]
    # The scene list holds references that must be handed back to OBS.
    obs.source_list_release(scenes)

    return ["None"] + sorted(names)

def list_sources():
    """Return the sorted names of all enumerated sources, preceded by "None"."""
    sources = obs.obs_enum_sources()
    names = [obs.obs_source_get_name(source) for source in sources]
    # The source list holds references that must be handed back to OBS.
    obs.source_list_release(sources)

    return ["None"] + sorted(names)

def list_filters():
    """Return every filter of every enumerated source, preceded by "None".

    Each entry has the form "<source name>\\n<filter name>"; the newline is
    the separator the rest of the script splits on. Entries are sorted.
    """
    results = []
    sources = obs.obs_enum_sources()
    for source in sources:
        source_name = obs.obs_source_get_name(source)
        filters = obs.obs_source_backup_filters(source)
        filter_count = obs.obs_source_filter_count(source)
        for i in range(filter_count):
            # obs_data_array_item() hands back an incremented reference,
            # so every item must be released or it leaks.
            item = obs.obs_data_array_item(filters, i)
            filter_name = obs.obs_data_get_string(item, "name")
            obs.obs_data_release(item)
            results.append(source_name + "\n" + filter_name)

        obs.obs_data_array_release(filters)

    obs.source_list_release(sources)
    results.sort()
    results.insert(0, "None")
    return results


def script_description():
    """Return the HTML description of this script."""
    return "<b>Indications on Corsair Keyboard [iCUE]</b>"


def cb_change_what_visibility(props, prop, settings):
    """Show only the selectors that match the chosen "What" value.

    Registered as the modified-callback of a "what<N>" property. The
    binding suffix <N> is recovered by stripping the 4-character "what"
    prefix from the property name.

    Returns:
        Always True.
    """
    changed = obs.obs_property_name(prop)
    suffix = changed[4:]
    selected = BindingWhat(obs.obs_data_get_int(settings, changed))

    # (property prefix, kind for which that property is visible)
    visible_for = (
        ("source", BindingWhat.Source),
        ("action", BindingWhat.Source),
        ("scene", BindingWhat.Scene),
        ("filter", BindingWhat.Filter),
    )
    for prefix, kind in visible_for:
        target = obs.obs_properties_get(props, prefix + suffix)
        obs.obs_property_set_visible(target, selected == kind)

    return True


def script_item_properties(index, scenes, sources, filters, leds):
    """Build the property group for the binding at position ``index``.

    Args:
        index: position of the binding in the global ``bindings`` list; also
            used as the suffix of every property name in the group.
        scenes: scene names offered in the "Scene" combo.
        sources: source names offered in the "Source" combo.
        filters: "<source>\\n<filter>" entries offered in the "Filter" combo.
        leds: key names offered in the "Key" combo.

    Returns:
        The newly created obs properties object.
    """
    global bindings

    group = obs.obs_properties_create()

    suffix = str(index)
    kind = bindings[index].what

    def add_combo(prefix, label, value_format):
        # Every selector in the group is a plain drop-down list.
        return obs.obs_properties_add_list(
            group, prefix + suffix, label, obs.OBS_COMBO_TYPE_LIST, value_format
        )

    def add_int_choices(combo, members):
        for label, member in members:
            obs.obs_property_list_add_int(combo, label, int(member.value))

    # What kind of object the binding watches; toggles the selectors below.
    what_combo = add_combo("what", "What", obs.OBS_COMBO_FORMAT_INT)
    add_int_choices(
        what_combo,
        (
            ("Source", BindingWhat.Source),
            ("Scene", BindingWhat.Scene),
            ("Filter", BindingWhat.Filter),
        ),
    )
    obs.obs_property_set_modified_callback(what_combo, cb_change_what_visibility)

    source_combo = add_combo("source", "Source", obs.OBS_COMBO_FORMAT_STRING)
    for entry in sources:
        obs.obs_property_list_add_string(source_combo, entry, entry)
    obs.obs_property_set_visible(source_combo, kind == BindingWhat.Source)

    action_combo = add_combo("action", "Action", obs.OBS_COMBO_FORMAT_INT)
    add_int_choices(
        action_combo,
        (
            ("Active", BindingAction.Active),
            ("Showing", BindingAction.Showing),
            ("Mute", BindingAction.Mute),
        ),
    )
    obs.obs_property_set_visible(action_combo, kind == BindingWhat.Source)

    scene_combo = add_combo("scene", "Scene", obs.OBS_COMBO_FORMAT_STRING)
    for entry in scenes:
        obs.obs_property_list_add_string(scene_combo, entry, entry)
    obs.obs_property_set_visible(scene_combo, kind == BindingWhat.Scene)

    filter_combo = add_combo("filter", "Filter", obs.OBS_COMBO_FORMAT_STRING)
    for entry in filters:
        # Display "source: filter" but store the newline-separated value.
        obs.obs_property_list_add_string(filter_combo, entry.replace("\n", ": "), entry)
    obs.obs_property_set_visible(filter_combo, kind == BindingWhat.Filter)

    key_combo = add_combo("key", "Key", obs.OBS_COMBO_FORMAT_STRING)
    for entry in leds:
        obs.obs_property_list_add_string(key_combo, entry, entry)

    obs.obs_properties_add_color(group, "off_color" + suffix, "Inactive Color")
    obs.obs_properties_add_color(group, "on_color" + suffix, "Active Color")

    return group


def script_properties():
    """Build the script's property page.

    The page holds a keyboard selector, the number of bindings, and one
    checkable group per entry in the global ``bindings`` list.

    Returns:
        The newly created obs properties object.
    """
    global bindings
    page = obs.obs_properties_create()

    scenes = list_scenes()
    sources = list_sources()
    filters = list_filters()
    keyboards = list_keyboards()
    # Key names are the CorsairLedId attributes starting with "K_",
    # offered without that prefix.
    leds = sorted(
        attr[2:]
        for attr in CorsairLedId.__dir__(CorsairLedId())
        if attr.startswith("K_")
    )

    keyboard_combo = obs.obs_properties_add_list(
        page, "keyboard", "Keyboard", obs.OBS_COMBO_TYPE_LIST, obs.OBS_COMBO_FORMAT_STRING
    )
    for device_index, model in keyboards.items():
        obs.obs_property_list_add_string(keyboard_combo, model, device_index)

    obs.obs_properties_add_int(
        page, "bindings", "Bindings <i>(reload script after change)</i>", 0, 100, 1
    )

    for position, _ in enumerate(bindings):
        obs.obs_properties_add_group(
            page,
            "binding" + str(position),
            "Binding",
            obs.OBS_GROUP_CHECKABLE,
            script_item_properties(position, scenes, sources, filters, leds),
        )

    return page


def script_defaults(settings):
    """Do nothing; this script registers no default settings."""



def int_to_led_color(s : int):
    """Return the three lowest-order bytes of ``s``, least significant first.

    ``s`` must fit in an unsigned 32-bit integer; the highest byte is
    discarded. Presumably the result is [R, G, B] for an OBS colour value --
    confirm against the OBS colour property format.

    Raises:
        OverflowError: if ``s`` is negative or needs more than 4 bytes.
    """
    return list(s.to_bytes(4, "little")[:3])

def load_binding(settings, id) -> Binding:
    """Read the binding with suffix ``id`` from the script settings.

    Args:
        settings: obs data object holding the script settings.
        id: string suffix appended to every property name of the binding.

    Returns:
        A Binding whose ``name`` is the scene, source or filter entry that
        matches its ``what`` kind. An unrecognised key name falls back to
        ``CorsairLedId.K_0``.
    """
    def key(prefix):
        return prefix + id

    enabled = obs.obs_data_get_bool(settings, key("binding"))
    led_name = obs.obs_data_get_string(settings, key("key"))
    what_value = obs.obs_data_get_int(settings, key("what"))
    source_name = obs.obs_data_get_string(settings, key("source"))
    scene_name = obs.obs_data_get_string(settings, key("scene"))
    filter_name = obs.obs_data_get_string(settings, key("filter"))
    action_value = obs.obs_data_get_int(settings, key("action"))
    off_value = obs.obs_data_get_int(settings, key("off_color"))
    on_value = obs.obs_data_get_int(settings, key("on_color"))

    kind = BindingWhat(what_value)
    action = BindingAction(action_value)

    # Key names are stored without the "K_" prefix of the SDK constants.
    led = CorsairLedId.__dict__.get("K_" + led_name, CorsairLedId.K_0)

    off_color = int_to_led_color(off_value)
    on_color = int_to_led_color(on_value)

    # Anything that is neither a scene nor a source is treated as a filter.
    name_by_kind = {
        BindingWhat.Scene: scene_name,
        BindingWhat.Source: source_name,
    }
    target = name_by_kind.get(kind, filter_name)
    return Binding(enabled, led, kind, target, action, off_color, on_color)

def script_update(settings):
    """Reload the keyboard index and all bindings from the script settings.

    Existing signal subscriptions are dropped first, then the bindings are
    rebuilt, the LEDs are refreshed and the new bindings are subscribed.

    Args:
        settings: obs data object holding the script settings.
    """
    global bindings, keyboard_index

    # Disconnect using the old bindings before they are replaced.
    unsubscribe_all()

    try:
        keyboard_index = int(obs.obs_data_get_string(settings, "keyboard"))
    except (TypeError, ValueError):
        # No keyboard chosen yet (empty or non-numeric value): use device 0.
        keyboard_index = 0

    count = obs.obs_data_get_int(settings, "bindings")
    bindings = [load_binding(settings, str(i)) for i in range(count)]

    set_all_active_leds()
    subscribe_all()


def cb_frontend_event(event):
    """Refresh the LEDs whenever the frontend reports a scene change.

    On the first scene change the signal subscriptions are also rebuilt and
    ``loaded`` is set; later scene changes only refresh the LEDs. All other
    events are ignored.
    """
    global loaded

    if event != obs.OBS_FRONTEND_EVENT_SCENE_CHANGED:
        return

    if loaded:
        set_all_active_leds()
        return

    # First scene change: (re)subscribe around the LED refresh.
    unsubscribe_all()
    set_all_active_leds()
    subscribe_all()
    loaded = True


def script_load(settings):
    """Create the CUE SDK client and register the frontend event callback.

    A failed SDK handshake is only printed; the frontend callback is
    registered either way.
    """
    global sdk
    sdk = CueSdk()

    connected = sdk.connect()
    if not connected:
        err = sdk.get_last_error()
        print(f"CUE Handshake failed: {err}")

    obs.obs_frontend_add_event_callback(cb_frontend_event)

def script_unload():
    """Undo script_load(): drop subscriptions, the callback and the SDK."""
    global sdk

    # Disconnect every source/filter signal handler first.
    unsubscribe_all()
    # Then stop listening for frontend events.
    obs.obs_frontend_remove_event_callback(cb_frontend_event)
    # Finally drop the reference to the SDK client.
    sdk = None


def cb_mute_active(calldata):
    """Signal callback for every subscribed source/filter signal.

    The signal payload is ignored; all LEDs are simply recomputed.
    """
    set_all_active_leds()

def subscribe_all():
    """Connect cb_mute_active to the signals of every enabled binding.

    Source bindings subscribe to the signals matching their action; filter
    bindings subscribe to the filter's "enable" signal. Scene bindings need
    no subscription. ``bind.subscribed`` records the action for every
    binding that was connected.
    """
    global bindings

    signals_by_action = {
        BindingAction.Active: ("activate", "deactivate"),
        BindingAction.Showing: ("show", "hide"),
        BindingAction.Mute: ("mute",),
    }

    for bind in bindings:
        if not bind.active:
            continue

        if bind.what == BindingWhat.Source:
            source = obs.obs_get_source_by_name(bind.name)
            if source is None:
                continue

            handler = obs.obs_source_get_signal_handler(source)
            for signal in signals_by_action.get(bind.action, ()):
                obs.signal_handler_connect(handler, signal, cb_mute_active)

            obs.obs_source_release(source)
            bind.subscribed = bind.action
            continue

        # Filter names are stored as "<source>\n<filter>".
        if bind.what != BindingWhat.Filter or "\n" not in bind.name:
            continue

        source_name, filter_name = bind.name.split("\n")
        source = obs.obs_get_source_by_name(source_name)
        if source is None:
            continue

        target = obs.obs_source_get_filter_by_name(source, filter_name)
        if target is not None:
            handler = obs.obs_source_get_signal_handler(target)
            obs.signal_handler_connect(handler, "enable", cb_mute_active)
            obs.obs_source_release(target)
            bind.subscribed = bind.action

        obs.obs_source_release(source)

def unsubscribe_all():
    """Disconnect cb_mute_active from every binding that was subscribed.

    The signals to disconnect are chosen from ``bind.subscribed`` (the
    action recorded by subscribe_all). ``bind.subscribed`` is reset to None
    for every processed binding, including those whose source or filter can
    no longer be found, so no binding keeps stale subscription state.
    """
    global bindings

    signals_by_action = {
        BindingAction.Active: ("activate", "deactivate"),
        BindingAction.Showing: ("show", "hide"),
        BindingAction.Mute: ("mute",),
    }

    for bind in bindings:
        if bind.subscribed is None:
            continue

        if bind.what == BindingWhat.Source:
            source = obs.obs_get_source_by_name(bind.name)
            if source is not None:
                handler = obs.obs_source_get_signal_handler(source)
                for signal in signals_by_action.get(bind.subscribed, ()):
                    obs.signal_handler_disconnect(handler, signal, cb_mute_active)

                obs.obs_source_release(source)

        elif bind.what == BindingWhat.Filter and "\n" in bind.name:
            # Filter names are stored as "<source>\n<filter>".
            (source_name, filter_name) = bind.name.split("\n")
            source = obs.obs_get_source_by_name(source_name)
            if source is not None:
                target = obs.obs_source_get_filter_by_name(source, filter_name)
                if target is not None:
                    handler = obs.obs_source_get_signal_handler(target)
                    obs.signal_handler_disconnect(handler, "enable", cb_mute_active)
                    obs.obs_source_release(target)

                obs.obs_source_release(source)

        # Reset even when the source was not found: nothing is left to
        # disconnect for this binding.
        bind.subscribed = None


def set_all_active_leds():
    """Recompute the colour of every bound LED and push it to the keyboard.

    For each enabled binding the watched state is evaluated:
      * scene: its name equals the current scene's name;
      * source: active / showing / not muted, depending on the action;
      * filter: the filter is enabled.
    When several bindings share an LED, an "on" colour always wins; an
    "off" colour is only used if nothing was stored for that LED yet.
    """
    global sdk, bindings, keyboard_index

    scene = obs.obs_frontend_get_current_scene()
    scene_name = "" if scene is None else obs.obs_source_get_name(scene)
    obs.obs_source_release(scene)

    colors = {}
    for bind in bindings:
        if not bind.active:
            continue

        lit = False
        if bind.what == BindingWhat.Scene:
            lit = bind.name == scene_name

        elif bind.what == BindingWhat.Source:
            source = obs.obs_get_source_by_name(bind.name)
            if source is not None:
                if bind.action == BindingAction.Active:
                    lit = obs.obs_source_active(source)
                elif bind.action == BindingAction.Showing:
                    lit = obs.obs_source_showing(source)
                elif bind.action == BindingAction.Mute:
                    # The key is lit while the source is audible.
                    lit = not obs.obs_source_muted(source)

                obs.obs_source_release(source)

        elif bind.what == BindingWhat.Filter and "\n" in bind.name:
            # Filter names are stored as "<source>\n<filter>".
            source_name, filter_name = bind.name.split("\n")
            source = obs.obs_get_source_by_name(source_name)
            if source is not None:
                target = obs.obs_source_get_filter_by_name(source, filter_name)
                if target is not None:
                    lit = obs.obs_source_enabled(target)
                    obs.obs_source_release(target)

                obs.obs_source_release(source)

        if lit:
            colors[bind.led] = bind.on_color
        else:
            # Never overwrite a colour an earlier binding stored for this LED.
            colors.setdefault(bind.led, bind.off_color)

    sdk.set_led_colors_buffer_by_device_index(keyboard_index, colors)
    sdk.set_led_colors_flush_buffer()