# Project     OBS Indications on Corsair Keyboard 1.1
# @author     Roman Zhigunov [reejk.xyz]
# @license    GPLv3 - Copyright (c) 2022-2023 Roman Zhigunov

# This script requires cuesdk; install it with this command:
# > python -m pip install cuesdk==4.0.3

from enum import Enum
from typing import List
from cuesdk import CueSdk, CorsairDeviceType, CorsairDeviceFilter, CorsairLedColor, CorsairError, CorsairLedId_Keyboard
import obspython as obs

class BindingWhat(Enum):
    """Kind of OBS object whose state a binding mirrors on the keyboard."""

    Source = 0  # a source; which state is watched is chosen by BindingAction
    Scene = 1   # a scene; "on" while it is the current frontend scene
    Filter = 2  # a filter on a source; "on" while the filter is enabled

class BindingAction(Enum):
    """Which state of a source a Source binding reflects on the keyboard."""

    Active = 0   # "on" while obs_source_active() is true
    Showing = 1  # "on" while obs_source_showing() is true
    Mute = 2     # "on" while the source is NOT muted (obs_source_muted() is false)

class Binding:
    """One user-configured link between an OBS object's state and a set of keyboard LEDs."""

    def __init__(self, active : bool, leds : List[int], what : BindingWhat, name : str, action : BindingAction, off_color : List[int], on_color : List[int]):
        # Whether the binding's group checkbox is ticked; unchecked bindings are skipped.
        self.active = active
        # LED ids to paint for this binding (one key or a whole key group).
        self.leds = leds
        # Kind of OBS object being watched.
        self.what = what
        # Scene or source name; for filters "source name\nfilter name".
        self.name = name
        # Source state to watch; evaluated for Source bindings.
        self.action = action
        # [r, g, b] used while the watched state is off.
        self.off_color = off_color
        # [r, g, b] used while the watched state is on.
        self.on_color = on_color
        # BindingAction recorded by subscribe_all when signals were connected, or None.
        self.subscribed = None


def keys_range(keys: dict[str, int], start : "CorsairLedId_Keyboard" = None, end : "CorsairLedId_Keyboard" = None):
    """Return the LED ids in ``keys`` whose numeric value lies in ``[start, end]``.

    :param keys: mapping of key name -> LED id.
    :param start: inclusive lower bound; CLK_Escape when omitted (None).
    :param end: inclusive upper bound; CLK_Fn when omitted (None).
    :returns: the matching ids, in the iteration order of ``keys``.

    The selection relies purely on the numeric ordering of the ids.
    """
    # Compare against None explicitly: a legitimate id such as CLK_Invalid may be 0
    # (falsy) and must not be silently replaced by the default bound.
    if start is None:
        start = CorsairLedId_Keyboard.CLK_Escape
    if end is None:
        end = CorsairLedId_Keyboard.CLK_Fn

    low, high = int(start), int(end)
    return [led for led in keys.values() if low <= led <= high]


# --- Module state shared by the OBS hooks -----------------------------------
sdk : CueSdk = None            # iCUE SDK handle; created in script_load, cleared in script_unload
loaded = False                 # set to True by the first scene-change event in cb_frontend_event
keyboard_id = ""               # device id chosen in the "keyboard" property
bindings : List[Binding] = []  # rebuilt from the settings on every script_update

# Key name -> LED id. The name is the CorsairLedId_Keyboard attribute name with
# its first four characters (the "CLK_" prefix) removed; public attributes are
# taken in alphabetical order.
keys = {
    attr[4:]: CorsairLedId_Keyboard.__dict__[attr]
    for attr in sorted(a for a in CorsairLedId_Keyboard.__dict__ if not a.startswith('_'))
}

# Building blocks for the named groups below. They are (mostly) ranges of LED
# ids, so they depend on the numeric order of the CorsairLedId_Keyboard values.
keys_group_numpad = keys_range(keys, CorsairLedId_Keyboard.CLK_NumLock, CorsairLedId_Keyboard.CLK_KeypadPeriodAndDelete)
keys_group_navigation = keys_range(keys, CorsairLedId_Keyboard.CLK_Insert, CorsairLedId_Keyboard.CLK_RightArrow)
keys_group_typing = (
    keys_range(keys, CorsairLedId_Keyboard.CLK_GraveAccentAndTilde, CorsairLedId_Keyboard.CLK_RightShift)
    + [int(CorsairLedId_Keyboard.CLK_Space)]
)
keys_group_functional = keys_range(keys, CorsairLedId_Keyboard.CLK_F1, CorsairLedId_Keyboard.CLK_F12)

# Named groups listed in every binding's "Key" list ahead of the single keys;
# dict insertion order is the order they are offered in.
keys_groups = {
    'Group: Full': keys_range(keys),
    'Group: Left Side': (
        keys_range(keys, CorsairLedId_Keyboard.CLK_Escape, CorsairLedId_Keyboard.CLK_RightCtrl)
        + [int(CorsairLedId_Keyboard.CLK_Fn)]
    ),
    'Group: Right Side': (
        keys_group_navigation
        + keys_group_numpad
        + keys_range(keys, CorsairLedId_Keyboard.CLK_PrintScreen, CorsairLedId_Keyboard.CLK_PauseBreak)
    ),
    'Group: Typing': keys_group_typing,
    'Group: Control': (
        [int(CorsairLedId_Keyboard.CLK_Escape), int(CorsairLedId_Keyboard.CLK_Fn)]
        + keys_range(keys, CorsairLedId_Keyboard.CLK_PrintScreen, CorsairLedId_Keyboard.CLK_PauseBreak)
        + keys_range(keys, CorsairLedId_Keyboard.CLK_LeftCtrl, CorsairLedId_Keyboard.CLK_LeftAlt)
        + keys_range(keys, CorsairLedId_Keyboard.CLK_RightAlt, CorsairLedId_Keyboard.CLK_RightCtrl)
    ),
    'Group: Function': keys_group_functional,
    'Group: Navigation': keys_group_navigation,
    'Group: Numpad': keys_group_numpad,
}


def list_keyboards():
    """Return ``{device_id: model}`` for every keyboard the iCUE SDK reports.

    An empty dict is returned when the device query fails or finds no keyboards.
    """
    keyboard_filter = CorsairDeviceFilter(device_type_mask=CorsairDeviceType.CDT_Keyboard)
    devices, err = sdk.get_devices(keyboard_filter)
    if err != CorsairError.CE_Success or not devices:
        return {}

    return {device.device_id: device.model for device in devices}


def list_scenes():
    """Return all scene names sorted, preceded by a "None" placeholder entry."""
    scene_sources = obs.obs_frontend_get_scenes()
    names = [obs.obs_source_get_name(scene) for scene in scene_sources]
    # Names are copied into Python strings, so the source list can be released now.
    obs.source_list_release(scene_sources)
    return ["None"] + sorted(names)

def list_sources():
    """Return the names of all sources from obs_enum_sources, sorted, preceded by "None"."""
    enumerated = obs.obs_enum_sources()
    names = [obs.obs_source_get_name(source) for source in enumerated]
    # Names are copied into Python strings, so the source list can be released now.
    obs.source_list_release(enumerated)
    return ["None"] + sorted(names)

def list_filters():
    """Return every filter as ``"<source name>\\n<filter name>"``, sorted, preceded by "None".

    Filter names are read from the data array produced by
    obs_source_backup_filters for each source returned by obs_enum_sources.
    """
    results = list()
    sources = obs.obs_enum_sources()
    for source in sources:
        source_name = obs.obs_source_get_name(source)
        filters = obs.obs_source_backup_filters(source)
        # Index the array we actually hold rather than a separately queried filter count.
        for i in range(obs.obs_data_array_count(filters)):
            item = obs.obs_data_array_item(filters, i)
            filter_name = obs.obs_data_get_string(item, "name")
            # obs_data_array_item returns a new reference; release it or it leaks.
            obs.obs_data_release(item)
            results.append(source_name + "\n" + filter_name)

        obs.obs_data_array_release(filters)

    obs.source_list_release(sources)
    results.sort()
    results.insert(0, "None")
    return results


def script_description():
    """Return the HTML description OBS shows for this script."""
    description = "<b>Indications on Corsair Keyboard [iCUE]</b>"
    return description


def cb_change_what_visibility(props, prop, settings):
    """Modified-callback for a ``what<N>`` list: show only the controls for the chosen kind.

    The "Source" kind shows the source and action lists, "Scene" the scene
    list and "Filter" the filter list of the same binding index. Returns True
    (which OBS treats as a request to refresh the properties view).
    """
    prop_name = obs.obs_property_name(prop)
    suffix = prop_name[4:]  # drop the leading "what", keeping the binding index
    selected = BindingWhat(obs.obs_data_get_int(settings, prop_name))

    visibility = (
        ("source", selected == BindingWhat.Source),
        ("action", selected == BindingWhat.Source),
        ("scene", selected == BindingWhat.Scene),
        ("filter", selected == BindingWhat.Filter),
    )
    for prefix, visible in visibility:
        obs.obs_property_set_visible(obs.obs_properties_get(props, prefix + suffix), visible)

    return True


def script_item_properties(index, scenes, sources, filters, leds):
    """Build the child properties for the checkable group of binding ``index``.

    Every property id is a field name followed by the binding index (for
    example "source3"), which is how load_binding reads them back. The
    source/action/scene/filter lists start visible or hidden according to the
    binding's current ``what``; cb_change_what_visibility updates them when the
    "What" list changes.

    :param index: position of the binding in ``bindings``.
    :param scenes: scene names offered in the "Scene" list.
    :param sources: source names offered in the "Source" list.
    :param filters: filter values (source name, newline, filter name) for the "Filter" list.
    :param leds: group and key names offered in the "Key" list.
    :returns: the newly created properties object.
    """
    group_props = obs.obs_properties_create()

    suffix = str(index)
    current_what = bindings[index].what

    def add_int_list(field, label, choices):
        # Integer combo box whose entries are (caption, enum member) pairs.
        prop = obs.obs_properties_add_list(group_props, field + suffix, label, obs.OBS_COMBO_TYPE_LIST, obs.OBS_COMBO_FORMAT_INT)
        for caption, member in choices:
            obs.obs_property_list_add_int(prop, caption, int(member.value))
        return prop

    def add_string_list(field, label, values, caption_of=lambda value: value):
        # String combo box; caption_of turns a stored value into its displayed text.
        prop = obs.obs_properties_add_list(group_props, field + suffix, label, obs.OBS_COMBO_TYPE_LIST, obs.OBS_COMBO_FORMAT_STRING)
        for value in values:
            obs.obs_property_list_add_string(prop, caption_of(value), value)
        return prop

    what_list = add_int_list("what", "What", [
        ("Source", BindingWhat.Source),
        ("Scene", BindingWhat.Scene),
        ("Filter", BindingWhat.Filter),
    ])
    obs.obs_property_set_modified_callback(what_list, cb_change_what_visibility)

    source_list = add_string_list("source", "Source", sources)
    obs.obs_property_set_visible(source_list, current_what == BindingWhat.Source)

    action_list = add_int_list("action", "Action", [
        ("Active", BindingAction.Active),
        ("Showing", BindingAction.Showing),
        ("Mute", BindingAction.Mute),
    ])
    obs.obs_property_set_visible(action_list, current_what == BindingWhat.Source)

    scene_list = add_string_list("scene", "Scene", scenes)
    obs.obs_property_set_visible(scene_list, current_what == BindingWhat.Scene)

    # Stored filter values contain a newline; display them as "source: filter".
    filter_list = add_string_list("filter", "Filter", filters, lambda value: value.replace("\n", ": "))
    obs.obs_property_set_visible(filter_list, current_what == BindingWhat.Filter)

    add_string_list("key", "Key", leds)

    obs.obs_properties_add_color(group_props, "off_color" + suffix, "Inactive Color")
    obs.obs_properties_add_color(group_props, "on_color" + suffix, "Active Color")

    return group_props


def script_properties():
    """Build the script's settings UI: keyboard picker, binding count and one group per binding.

    The number of groups follows the currently loaded ``bindings`` list, so a
    changed "bindings" count only shows up after the script is reloaded.
    """
    props = obs.obs_properties_create()

    scenes = list_scenes()
    sources = list_sources()
    filters = list_filters()
    keyboards = list_keyboards()
    # Named groups first, then every individual key name.
    led_choices = [*keys_groups, *keys]

    keyboard_list = obs.obs_properties_add_list(props, "keyboard", "Keyboard", obs.OBS_COMBO_TYPE_LIST, obs.OBS_COMBO_FORMAT_STRING)
    for device_id, model in keyboards.items():
        obs.obs_property_list_add_string(keyboard_list, model, device_id)

    obs.obs_properties_add_int(props, "bindings", "Bindings <i>(reload script after change)</i>", 0, 100, 1)

    for index, _binding in enumerate(bindings):
        group_children = script_item_properties(index, scenes, sources, filters, led_choices)
        obs.obs_properties_add_group(props, "binding" + str(index), "Binding", obs.OBS_GROUP_CHECKABLE, group_children)

    return props


def script_defaults(settings):
    """OBS hook for registering default setting values; this script registers none."""



def int_to_led_color(s : int):
    """Split a packed color int into its three least-significant bytes.

    :param s: color value as read with obs_data_get_int.
    :returns: ``[byte0, byte1, byte2]`` (least significant first). Callers use
        them as ``[r, g, b]``, i.e. red is expected in the lowest byte.

    The value is masked to 32 bits first, so negative (sign-extended) or
    over-wide values no longer raise OverflowError in ``int.to_bytes``.
    """
    x = list((s & 0xFFFFFFFF).to_bytes(4, "little"))
    return x[0:3]

def load_binding(settings, id) -> Binding:
    """Build a Binding from the settings entries whose names end with ``id`` (the binding index).

    The "key" entry may name a key group (looked up in ``keys_groups``) or a
    single key (looked up in ``keys``); anything else yields no LEDs. The
    binding's ``name`` is taken from the scene, source or filter entry
    depending on ``what``.
    """
    def read_str(field):
        return obs.obs_data_get_string(settings, field + id)

    def read_int(field):
        return obs.obs_data_get_int(settings, field + id)

    enabled = obs.obs_data_get_bool(settings, "binding" + id)
    led_choice = read_str("key")
    what = BindingWhat(read_int("what"))
    source_name = read_str("source")
    scene_name = read_str("scene")
    filter_name = read_str("filter")
    action = BindingAction(read_int("action"))
    off_color = int_to_led_color(read_int("off_color"))
    on_color = int_to_led_color(read_int("on_color"))

    if led_choice in keys_groups:
        led_ids = keys_groups[led_choice]
    else:
        led_ids = [keys[led_choice]] if led_choice in keys else []

    # Scene and Source pick their own entry; everything else (Filter) uses the filter entry.
    name = {BindingWhat.Scene: scene_name, BindingWhat.Source: source_name}.get(what, filter_name)
    return Binding(enabled, led_ids, what, name, action, off_color, on_color)

def script_update(settings):
    """Re-read all bindings from ``settings`` and refresh signal subscriptions and LEDs.

    Existing subscriptions belong to the old bindings list, so they are
    dropped first. Then the keyboard id and bindings are reloaded, the LEDs
    repainted and the new bindings subscribed.
    """
    global bindings, keyboard_id

    unsubscribe_all()

    try:
        # Treat a missing/None value the same as "no keyboard selected".
        keyboard_id = obs.obs_data_get_string(settings, "keyboard") or str()
    except Exception:
        # Narrowed from a bare except, which also swallowed KeyboardInterrupt/SystemExit.
        keyboard_id = str()

    count = obs.obs_data_get_int(settings, "bindings")
    bindings = [load_binding(settings, str(i)) for i in range(count)]

    set_all_active_leds()
    subscribe_all()


def cb_frontend_event(event):
    """Frontend callback: repaint the LEDs whenever the current scene changes.

    On the first scene change after loading, the signal subscriptions are
    also rebuilt and ``loaded`` is set. Presumably this is because sources may
    not exist yet when script_update first runs at OBS start-up (TODO confirm).
    """
    global loaded
    if event != obs.OBS_FRONTEND_EVENT_SCENE_CHANGED:
        return

    if loaded:
        set_all_active_leds()
        return

    unsubscribe_all()
    set_all_active_leds()
    subscribe_all()
    loaded = True


def on_cuesdk_state_changed(evt):
    """State-change callback handed to ``sdk.connect`` in script_load; events are ignored."""
    return None

def script_load(settings):
    """OBS load hook: create and connect the iCUE SDK, then listen for frontend events.

    A failed connection is only reported with a print; the frontend callback
    is registered either way.
    """
    global sdk
    sdk = CueSdk()

    connect_result = sdk.connect(on_cuesdk_state_changed)
    if connect_result != CorsairError.CE_Success:
        print("Unable to connect to iCUE: {err}".format(err=connect_result))

    obs.obs_frontend_add_event_callback(cb_frontend_event)
    obs.obs_frontend_add_event_callback(cb_frontend_event)

def script_unload():
    """OBS unload hook: drop all signal and frontend callbacks and forget the SDK handle."""
    global sdk

    unsubscribe_all()
    obs.obs_frontend_remove_event_callback(cb_frontend_event)

    # sdk.disconnect() is deliberately not called: per the original author it
    # freezes or crashes OBS. The handle is simply dropped instead.
    sdk = None


def cb_mute_active(calldata):
    """Signal callback shared by every subscribed source/filter: repaint all LEDs.

    ``calldata`` is not inspected; the whole LED state is recomputed from OBS.
    """
    set_all_active_leds()

def subscribe_all():
    """Connect cb_mute_active to the OBS signals that can change each active binding.

    Source bindings listen to the signal(s) matching their action. Filter
    bindings listen to the filter's "enable" signal. Scene bindings need no
    signal because cb_frontend_event repaints on scene changes. The connected
    action is stored in ``bind.subscribed`` for unsubscribe_all.
    """
    signals_for_action = {
        BindingAction.Active: ("activate", "deactivate"),
        BindingAction.Showing: ("show", "hide"),
        BindingAction.Mute: ("mute",),
    }

    for bind in bindings:
        if not bind.active:
            continue

        if bind.what == BindingWhat.Source:
            source = obs.obs_get_source_by_name(bind.name)
            if source is None:
                continue

            handler = obs.obs_source_get_signal_handler(source)
            for signal_name in signals_for_action.get(bind.action, ()):
                obs.signal_handler_connect(handler, signal_name, cb_mute_active)

            obs.obs_source_release(source)
            bind.subscribed = bind.action
            continue

        # Filter bindings are stored as "source name\nfilter name".
        if bind.what != BindingWhat.Filter or "\n" not in bind.name:
            continue

        source_name, filter_name = bind.name.split("\n")
        source = obs.obs_get_source_by_name(source_name)
        if source is None:
            continue

        filter_source = obs.obs_source_get_filter_by_name(source, filter_name)
        if filter_source is not None:
            handler = obs.obs_source_get_signal_handler(filter_source)
            obs.signal_handler_connect(handler, "enable", cb_mute_active)
            obs.obs_source_release(filter_source)
            bind.subscribed = bind.action

        obs.obs_source_release(source)

def unsubscribe_all():
    """Disconnect the OBS signal handlers that subscribe_all connected.

    ``bind.subscribed`` is reset to None for every subscribed binding, including
    Source bindings whose source can no longer be found by name. Previously
    those kept a stale value, unlike the Filter branch.
    """
    for bind in bindings:
        if bind.subscribed is None:
            continue

        if bind.what == BindingWhat.Source:
            source = obs.obs_get_source_by_name(bind.name)
            # NOTE(review): if the source was renamed, its connection cannot be
            # found by name here and stays connected to cb_mute_active.
            if source is not None:
                handler = obs.obs_source_get_signal_handler(source)
                if bind.subscribed == BindingAction.Active:
                    obs.signal_handler_disconnect(handler, "activate", cb_mute_active)
                    obs.signal_handler_disconnect(handler, "deactivate", cb_mute_active)
                elif bind.subscribed == BindingAction.Showing:
                    obs.signal_handler_disconnect(handler, "show", cb_mute_active)
                    obs.signal_handler_disconnect(handler, "hide", cb_mute_active)
                elif bind.subscribed == BindingAction.Mute:
                    obs.signal_handler_disconnect(handler, "mute", cb_mute_active)

                obs.obs_source_release(source)

        elif bind.what == BindingWhat.Filter and "\n" in bind.name:
            # Filter bindings are stored as "source name\nfilter name".
            (source_name, filter_name) = bind.name.split("\n")
            source = obs.obs_get_source_by_name(source_name)
            if source is not None:
                filter_source = obs.obs_source_get_filter_by_name(source, filter_name)
                if filter_source is not None:
                    handler = obs.obs_source_get_signal_handler(filter_source)
                    obs.signal_handler_disconnect(handler, "enable", cb_mute_active)
                    obs.obs_source_release(filter_source)

                obs.obs_source_release(source)

        bind.subscribed = None


def _binding_is_active(bind, current_scene_name):
    """Return whether the OBS object watched by ``bind`` is currently in its "on" state."""
    if bind.what == BindingWhat.Source:
        source = obs.obs_get_source_by_name(bind.name)
        if source is None:
            return False
        try:
            if bind.action == BindingAction.Active:
                return obs.obs_source_active(source)
            if bind.action == BindingAction.Showing:
                return obs.obs_source_showing(source)
            if bind.action == BindingAction.Mute:
                # "Mute" bindings are on while the source is audible (not muted).
                return not obs.obs_source_muted(source)
            return False
        finally:
            obs.obs_source_release(source)

    if bind.what == BindingWhat.Scene:
        return bind.name == current_scene_name

    # Filter bindings are stored as "source name\nfilter name".
    if bind.what == BindingWhat.Filter and "\n" in bind.name:
        (source_name, filter_name) = bind.name.split("\n")
        source = obs.obs_get_source_by_name(source_name)
        if source is None:
            return False
        try:
            filter_source = obs.obs_source_get_filter_by_name(source, filter_name)
            if filter_source is None:
                return False
            try:
                return obs.obs_source_enabled(filter_source)
            finally:
                obs.obs_source_release(filter_source)
        finally:
            obs.obs_source_release(source)

    return False


def set_all_active_leds():
    """Compute every bound LED's color from the current OBS state and push it to the keyboard.

    The steps are:
    - Each checked binding paints its LEDs with ``on_color`` when its watched
      state is on, or ``off_color`` otherwise.
    - An "on" color always wins.
    - An "off" color never overwrites a LED that an earlier binding already set.

    Does nothing while no SDK handle exists (before script_load / after script_unload).
    """
    if sdk is None:
        return

    current_scene = obs.obs_frontend_get_current_scene()
    current_scene_name = obs.obs_source_get_name(current_scene) if current_scene is not None else ""
    obs.obs_source_release(current_scene)

    leds = {}
    for bind in bindings:
        if not bind.active:
            continue

        is_active = _binding_is_active(bind, current_scene_name)
        color = bind.on_color if is_active else bind.off_color
        for led in bind.leds:
            if is_active or led not in leds:
                leds[led] = CorsairLedColor(id=led, a=255, r=color[0], g=color[1], b=color[2])

    # Pass a real list: the SDK call takes a sequence of colors, not a dict view.
    # NOTE(review): the returned CorsairError is not checked.
    sdk.set_led_colors(keyboard_id, list(leds.values()))