diff --git a/lib/logitech_receiver/diversion.py b/lib/logitech_receiver/diversion.py index 5ac595dca..e764813d3 100644 --- a/lib/logitech_receiver/diversion.py +++ b/lib/logitech_receiver/diversion.py @@ -29,16 +29,18 @@ import time import typing +from pathlib import Path from typing import Any from typing import Dict from typing import Tuple import gi import psutil -import yaml from keysyms import keysymdef +from . import rule_storage + # There is no evdev on macOS or Windows. Diversion will not work without # it but other Solaar functionality is available. if platform.system() in ("Darwin", "Windows"): @@ -58,6 +60,14 @@ logger = logging.getLogger(__name__) +if os.environ.get("XDG_CONFIG_HOME"): + xdg_config_home = Path(os.environ.get("XDG_CONFIG_HOME")) +else: + xdg_config_home = Path("~/.config").expanduser() + +RULES_CONFIG = xdg_config_home / "solaar" / "rules.yaml" + + # # See docs/rules.md for documentation # @@ -146,6 +156,17 @@ _dbus_interface = None +class AbstractRepository(typing.Protocol): + def save(self, rules: Dict[str, str]) -> None: + ... + + def load(self) -> list: + ... + + +storage: AbstractRepository = rule_storage.YmlRuleStorage(RULES_CONFIG) + + class XkbDisplay(ctypes.Structure): """opaque struct""" @@ -1548,22 +1569,13 @@ def process_notification(device, notification: HIDPPNotification, feature) -> No GLib.idle_add(evaluate_rules, feature, notification, device) -_XDG_CONFIG_HOME = os.environ.get("XDG_CONFIG_HOME") or os.path.expanduser(os.path.join("~", ".config")) -_file_path = os.path.join(_XDG_CONFIG_HOME, "solaar", "rules.yaml") - -rules = built_in_rules +def save_config_rule_file() -> None: + """Saves user configured rules.""" - -def _save_config_rule_file(file_name: str = _file_path): # This is a trick to show str/float/int lists in-line (inspired by https://stackoverflow.com/a/14001707) class inline_list(list): pass - def blockseq_rep(dumper, data): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - - yaml.add_representer(inline_list, blockseq_rep) - def convert(elem): if isinstance(elem, list): if len(elem) == 1 and isinstance(elem[0], (int, str, float)): @@ -1578,53 +1590,41 @@ def convert(elem): return int(elem) return elem - # YAML format settings - dump_settings = { - "encoding": "utf-8", - "explicit_start": True, - "explicit_end": True, - "default_flow_style": False, - # 'version': (1, 3), # it would be printed for every rule - } + global rules + # Save only user-defined rules - rules_to_save = sum((r.data()["Rule"] for r in rules.components if r.source == file_name), []) + rules_to_save = sum((r.data()["Rule"] for r in rules.components if r.source == str(RULES_CONFIG)), []) if logger.isEnabledFor(logging.INFO): - logger.info("saving %d rule(s) to %s", len(rules_to_save), file_name) + logger.info(f"saving {len(rules_to_save)} rule(s) to {str(RULES_CONFIG)}") + dump_data = [r["Rule"] for r in rules_to_save] try: - with open(file_name, "w") as f: - if rules_to_save: - f.write("%YAML 1.3\n") # Write version manually - dump_data = [r["Rule"] for r in rules_to_save] - yaml.dump_all(convert(dump_data), f, **dump_settings) - except Exception as e: - logger.error("failed to save to %s\n%s", file_name, e) - return False - return True + data = convert(dump_data) + storage.save(data) + except Exception: + logger.error("failed to save to rules config") -def load_config_rule_file(): +def load_rule_config() -> Rule: """Loads user configured rules.""" global rules - if os.path.isfile(_file_path): - rules = _load_rule_config(_file_path) - - -def _load_rule_config(file_path: str) -> Rule: loaded_rules = [] try: - with open(file_path) as config_file: - loaded_rules = [] - for loaded_rule in yaml.safe_load_all(config_file): - rule = Rule(loaded_rule, source=file_path) - if logger.isEnabledFor(logging.DEBUG): - logger.debug("load rule: %s", rule) - loaded_rules.append(rule) - if logger.isEnabledFor(logging.INFO): - logger.info("loaded %d rules from %s", len(loaded_rules), config_file.name) + plain_rules = storage.load() + for loaded_rule in plain_rules: + rule = Rule(loaded_rule, source=str(RULES_CONFIG)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"load rule: {rule}") + loaded_rules.append(rule) + if logger.isEnabledFor(logging.INFO): + logger.info( + f"loaded {len(loaded_rules)} rules from config file", + ) except Exception as e: - logger.error("failed to load from %s\n%s", file_path, e) - return Rule([Rule(loaded_rules, source=file_path), built_in_rules]) + logger.error(f"failed to load from {RULES_CONFIG}\n{e}") + user_rules = Rule(loaded_rules, source=str(RULES_CONFIG)) + rules = Rule([user_rules, built_in_rules]) + return rules -load_config_rule_file() +load_rule_config() diff --git a/lib/logitech_receiver/rule_storage.py b/lib/logitech_receiver/rule_storage.py new file mode 100644 index 000000000..1c38e4c5c --- /dev/null +++ b/lib/logitech_receiver/rule_storage.py @@ -0,0 +1,47 @@ +from pathlib import Path +from typing import Dict + +import yaml + + +class YmlRuleStorage: + def __init__(self, path: Path): + self._config_path = path + + def save(self, rules: Dict[str, str]) -> None: + # This is a trick to show str/float/int lists in-line (inspired by https://stackoverflow.com/a/14001707) + class inline_list(list): + pass + + def blockseq_rep(dumper, data): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + + yaml.add_representer(inline_list, blockseq_rep) + format_settings = { + "encoding": "utf-8", + "explicit_start": True, + "explicit_end": True, + "default_flow_style": False, + } + with open(self._config_path, "w") as f: + f.write("%YAML 1.3\n") # Write version manually + yaml.dump_all(rules, f, **format_settings) + + def load(self) -> list: + with open(self._config_path) as config_file: + plain_rules = list(yaml.safe_load_all(config_file)) + return plain_rules + + +class FakeRuleStorage: + def __init__(self, rules=None): + if rules is None: + self._rules = {} + else: + self._rules = rules + + def save(self, rules: dict) -> None: + self._rules = rules + + def load(self) -> dict: + return self._rules diff --git a/lib/solaar/ui/diversion_rules.py b/lib/solaar/ui/diversion_rules.py index 788eab241..17f8d02f6 100644 --- a/lib/solaar/ui/diversion_rules.py +++ b/lib/solaar/ui/diversion_rules.py @@ -347,7 +347,7 @@ def _menu_do_insert(self, _mitem, m, it, new_c, below=False): else: idx = parent_c.components.index(c) if isinstance(new_c, diversion.Rule) and wrapped.level == 1: - new_c.source = diversion._file_path # new rules will be saved to the YAML file + new_c.source = str(diversion.RULES_CONFIG) # new rules will be saved to the YAML file idx += int(below) parent_c.components.insert(idx, new_c) self._populate_model_func(m, parent_it, new_c, level=wrapped.level, pos=idx) @@ -632,10 +632,13 @@ def _reload_yaml_file(self): self.view.expand_all() def _save_yaml_file(self): - if diversion._save_config_rule_file(): + try: + diversion.save_config_rule_file() self.dirty = False self.save_btn.set_sensitive(False) self.discard_btn.set_sensitive(False) + except Exception: + pass def _create_top_panel(self): sw = Gtk.ScrolledWindow() diff --git a/tests/logitech_receiver/test_diversion.py b/tests/logitech_receiver/test_diversion.py index 70aba163f..5218a2402 100644 --- a/tests/logitech_receiver/test_diversion.py +++ b/tests/logitech_receiver/test_diversion.py @@ -48,7 +48,7 @@ def test_load_rule_config(rule_config): ] with mock.patch("builtins.open", new=mock_open(read_data=rule_config)): - loaded_rules = diversion._load_rule_config(file_path=mock.Mock()) + loaded_rules = diversion.load_rule_config() assert len(loaded_rules.components) == 2 # predefined and user configured rules user_configured_rules = loaded_rules.components[0]