diff --git a/lib/logitech_receiver/diversion.py b/lib/logitech_receiver/diversion.py index e764813d3..9c3ab8cc3 100644 --- a/lib/logitech_receiver/diversion.py +++ b/lib/logitech_receiver/diversion.py @@ -1569,62 +1569,64 @@ def process_notification(device, notification: HIDPPNotification, feature) -> No GLib.idle_add(evaluate_rules, feature, notification, device) -def save_config_rule_file() -> None: - """Saves user configured rules.""" - - # This is a trick to show str/float/int lists in-line (inspired by https://stackoverflow.com/a/14001707) - class inline_list(list): - pass - - def convert(elem): - if isinstance(elem, list): - if len(elem) == 1 and isinstance(elem[0], (int, str, float)): - # All diversion classes that expect a list of scalars also support a single scalar without a list - return elem[0] - if all(isinstance(c, (int, str, float)) for c in elem): - return inline_list([convert(c) for c in elem]) - return [convert(c) for c in elem] - if isinstance(elem, dict): - return {k: convert(v) for k, v in elem.items()} - if isinstance(elem, NamedInt): - return int(elem) - return elem - - global rules - - # Save only user-defined rules - rules_to_save = sum((r.data()["Rule"] for r in rules.components if r.source == str(RULES_CONFIG)), []) - if logger.isEnabledFor(logging.INFO): - logger.info(f"saving {len(rules_to_save)} rule(s) to {str(RULES_CONFIG)}") - dump_data = [r["Rule"] for r in rules_to_save] - try: - data = convert(dump_data) - storage.save(data) - except Exception: - logger.error("failed to save to rules config") - +class Persister: + @staticmethod + def save_config_rule_file() -> None: + """Saves user configured rules.""" + + # This is a trick to show str/float/int lists in-line (inspired by https://stackoverflow.com/a/14001707) + class inline_list(list): + pass + + def convert(elem): + if isinstance(elem, list): + if len(elem) == 1 and isinstance(elem[0], (int, str, float)): + # All diversion classes that expect a list of scalars also support a single scalar without a list + return elem[0] + if all(isinstance(c, (int, str, float)) for c in elem): + return inline_list([convert(c) for c in elem]) + return [convert(c) for c in elem] + if isinstance(elem, dict): + return {k: convert(v) for k, v in elem.items()} + if isinstance(elem, NamedInt): + return int(elem) + return elem + + global rules + + # Save only user-defined rules + rules_to_save = sum((r.data()["Rule"] for r in rules.components if r.source == str(RULES_CONFIG)), []) + if logger.isEnabledFor(logging.INFO): + logger.info(f"saving {len(rules_to_save)} rule(s) to {str(RULES_CONFIG)}") + dump_data = [r["Rule"] for r in rules_to_save] + try: + data = convert(dump_data) + storage.save(data) + except Exception: + logger.error("failed to save to rules config") -def load_rule_config() -> Rule: - """Loads user configured rules.""" - global rules + @staticmethod + def load_rule_config() -> Rule: + """Loads user configured rules.""" + global rules - loaded_rules = [] - try: - plain_rules = storage.load() - for loaded_rule in plain_rules: - rule = Rule(loaded_rule, source=str(RULES_CONFIG)) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"load rule: {rule}") - loaded_rules.append(rule) - if logger.isEnabledFor(logging.INFO): - logger.info( - f"loaded {len(loaded_rules)} rules from config file", - ) - except Exception as e: - logger.error(f"failed to load from {RULES_CONFIG}\n{e}") - user_rules = Rule(loaded_rules, source=str(RULES_CONFIG)) - rules = Rule([user_rules, built_in_rules]) - return rules + loaded_rules = [] + try: + plain_rules = storage.load() + for loaded_rule in plain_rules: + rule = Rule(loaded_rule, source=str(RULES_CONFIG)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"load rule: {rule}") + loaded_rules.append(rule) + if logger.isEnabledFor(logging.INFO): + logger.info( + f"loaded {len(loaded_rules)} rules from config file", + ) + except Exception as e: + logger.error(f"failed to load from {RULES_CONFIG}\n{e}") + user_rules = Rule(loaded_rules, source=str(RULES_CONFIG)) + rules = Rule([user_rules, built_in_rules]) + return rules -load_rule_config() +Persister.load_rule_config() diff --git a/lib/solaar/ui/diversion_rules.py b/lib/solaar/ui/diversion_rules.py index 17f8d02f6..736898247 100644 --- a/lib/solaar/ui/diversion_rules.py +++ b/lib/solaar/ui/diversion_rules.py @@ -31,6 +31,7 @@ from typing import Callable from typing import Dict from typing import Optional +from typing import Protocol from gi.repository import Gdk from gi.repository import GObject @@ -550,8 +551,14 @@ def _menu_copy(self, m, it): return menu_copy +class RulePersister(Protocol): + def load_rule_config(self) -> _DIV.Rule: ... + + def save_config_rule_file(self) -> None: ... + + class DiversionDialog: - def __init__(self, action_menu): + def __init__(self, action_menu, rule_persister: RulePersister): window = Gtk.Window() window.set_title(_("Solaar Rule Editor")) window.connect("delete-event", self._closing) @@ -568,6 +575,7 @@ def __init__(self, action_menu): populate_model_func=_populate_model, on_update=self.on_update, ) + self._ruler_persister = rule_persister self.dirty = False # if dirty, there are pending changes to be saved @@ -626,6 +634,7 @@ def _reload_yaml_file(self): self.dirty = False for c in self.selected_rule_edit_panel.get_children(): self.selected_rule_edit_panel.remove(c) + self._ruler_persister.load_rule_config() diversion.load_config_rule_file() self.model = self._create_model() self.view.set_model(self.model) @@ -633,7 +642,7 @@ def _reload_yaml_file(self): def _save_yaml_file(self): try: - diversion.save_config_rule_file() + self._ruler_persister.save_config_rule_file() self.dirty = False self.save_btn.set_sensitive(False) self.discard_btn.set_sensitive(False) @@ -1867,6 +1876,6 @@ def show_window(model: Gtk.TreeStore): global _dev_model _dev_model = model if _diversion_dialog is None: - _diversion_dialog = DiversionDialog(ActionMenu) + _diversion_dialog = DiversionDialog(action_menu=ActionMenu, rule_persister=diversion.Persister()) update_devices() _diversion_dialog.window.present() diff --git a/tests/logitech_receiver/test_diversion.py b/tests/logitech_receiver/test_diversion.py index 5218a2402..9e8ae4b22 100644 --- a/tests/logitech_receiver/test_diversion.py +++ b/tests/logitech_receiver/test_diversion.py @@ -48,7 +48,7 @@ def test_load_rule_config(rule_config): ] with mock.patch("builtins.open", new=mock_open(read_data=rule_config)): - loaded_rules = diversion.load_rule_config() + loaded_rules = diversion.Persister.load_rule_config() assert len(loaded_rules.components) == 2 # predefined and user configured rules user_configured_rules = loaded_rules.components[0]