diff --git a/src/streamsync/__init__.py b/src/streamsync/__init__.py index e81861765..1177be510 100644 --- a/src/streamsync/__init__.py +++ b/src/streamsync/__init__.py @@ -1,11 +1,11 @@ import importlib.metadata from typing import Union, Optional, Dict, Any from streamsync.core import Readable, FileWrapper, BytesWrapper, Config -from streamsync.core import initial_state, component_manager, session_manager, session_verifier +from streamsync.core import initial_state, base_component_tree, session_manager, session_verifier VERSION = importlib.metadata.version("streamsync") -component_manager +base_component_tree session_manager Config session_verifier diff --git a/src/streamsync/app_runner.py b/src/streamsync/app_runner.py index e817093e3..e0c316264 100644 --- a/src/streamsync/app_runner.py +++ b/src/streamsync/app_runner.py @@ -69,7 +69,7 @@ def __init__(self, app_path: str, mode: str, run_code: str, - components: Dict, + bmc_components: Dict, is_app_process_server_ready: multiprocessing.synchronize.Event, is_app_process_server_failed: multiprocessing.synchronize.Event): super().__init__(name="AppProcess") @@ -78,7 +78,7 @@ def __init__(self, self.app_path = app_path self.mode = mode self.run_code = run_code - self.components = components + self.bmc_components = bmc_components self.is_app_process_server_ready = is_app_process_server_ready self.is_app_process_server_failed = is_app_process_server_failed self.logger = logging.getLogger("app") @@ -149,7 +149,7 @@ def _handle_session_init(self, payload: InitSessionRequestPayload) -> InitSessio userState=user_state, sessionId=session.session_id, mail=session.session_state.mail, - components=streamsync.component_manager.to_dict(), + components=session.session_component_tree.to_dict(), userFunctions=self._get_user_functions() ) @@ -207,10 +207,10 @@ def _handle_state_enquiry(self, session: StreamsyncSession) -> StateEnquiryRespo session.session_state.clear_mail() return res_payload - + def _handle_component_update(self, payload: ComponentUpdateRequestPayload) -> None: import streamsync - streamsync.component_manager.ingest(payload.components) + streamsync.base_component_tree.ingest(payload.components) def _handle_message(self, session_id: str, request: AppProcessServerRequest) -> AppProcessServerResponse: """ @@ -337,7 +337,7 @@ def _main(self) -> None: terminate_early = True try: - streamsync.component_manager.ingest(self.components) + streamsync.base_component_tree.ingest(self.bmc_components) except BaseException: streamsync.initial_state.add_log_entry( "error", "UI Components Error", "Couldn't load components. An exception was raised.", tb.format_exc()) @@ -535,7 +535,7 @@ def __init__(self, app_path: str, mode: str): self.client_conn: Optional[multiprocessing.connection.Connection] = None self.app_process: Optional[AppProcess] = None self.run_code: Optional[str] = None - self.components: Optional[Dict] = None + self.bmc_components: Optional[Dict] = None self.is_app_process_server_ready = multiprocessing.Event() self.is_app_process_server_failed = multiprocessing.Event() self.app_process_listener: Optional[AppProcessListener] = None @@ -589,7 +589,7 @@ def signal_handler(sig, frame): pass self.run_code = self._load_persisted_script() - self.components = self._load_persisted_components() + self.bmc_components = self._load_persisted_components() if self.mode == "edit": self._set_observer() @@ -676,7 +676,7 @@ async def update_components(self, session_id: str, payload: ComponentUpdateReque if self.mode != "edit": raise PermissionError( "Cannot update components in non-update mode.") - self.components = payload.components + self.bmc_components = payload.components file_contents = { "metadata": { "streamsync_version": VERSION @@ -744,7 +744,7 @@ def shut_down(self) -> None: def _start_app_process(self) -> None: if self.run_code is None: raise ValueError("Cannot start app process. Code hasn't been set.") - if self.components is None: + if self.bmc_components is None: raise ValueError( "Cannot start app process. Components haven't been set.") self.is_app_process_server_ready.clear() @@ -758,7 +758,7 @@ def _start_app_process(self) -> None: app_path=self.app_path, mode=self.mode, run_code=self.run_code, - components=self.components, + bmc_components=self.bmc_components, is_app_process_server_ready=self.is_app_process_server_ready, is_app_process_server_failed=self.is_app_process_server_failed) self.app_process.start() diff --git a/src/streamsync/core.py b/src/streamsync/core.py index e13521487..793d570d0 100644 --- a/src/streamsync/core.py +++ b/src/streamsync/core.py @@ -504,7 +504,7 @@ def to_dict(self) -> Dict: return c_dict -class ComponentManager: +class ComponentTree: def __init__(self) -> None: self.counter: int = 0 @@ -512,6 +512,9 @@ def __init__(self) -> None: root_component = Component("root", "root", {}) self.attach(root_component) + def get_component(self, component_id: str) -> Optional[Component]: + return self.components.get(component_id) + def get_descendents(self, parent_id: str) -> List[Component]: children = list(filter(lambda c: c.parentId == parent_id, self.components.values())) @@ -547,6 +550,25 @@ def to_dict(self) -> Dict: for id, component in self.components.items(): active_components[id] = component.to_dict() return active_components + + +class SessionComponentTree(ComponentTree): + + def __init__(self, base_component_tree: ComponentTree): + super().__init__() + self.base_component_tree = base_component_tree + + def get_component(self, component_id: str) -> Optional[Component]: + base_component = self.base_component_tree.get_component(component_id) + if base_component: + return base_component + return self.components.get(component_id) + + def to_dict(self) -> Dict: + active_components = {} + for id, component in {**self.components, **self.base_component_tree.components}.items(): + active_components[id] = component.to_dict() + return active_components class EventDeserialiser: @@ -558,8 +580,8 @@ class EventDeserialiser: Its main goal is to deserialise incoming content in a controlled and predictable way, applying sanitisation of inputs where relevant.""" - def __init__(self, session_state: StreamsyncState): - self.evaluator = Evaluator(session_state) + def __init__(self, session_state: StreamsyncState, session_component_tree: SessionComponentTree): + self.evaluator = Evaluator(session_state, session_component_tree) def transform(self, ev: StreamsyncEvent) -> None: # Events without payloads are safe @@ -723,8 +745,9 @@ class Evaluator: template_regex = re.compile(r"[\\]?@{([\w\s.]*)}") - def __init__(self, session_state: StreamsyncState): + def __init__(self, session_state: StreamsyncState, session_component_tree: ComponentTree): self.ss = session_state + self.ct = session_component_tree def evaluate_field(self, instance_path: InstancePath, field_key: str, as_json=False, default_field_value="") -> Any: def replacer(matched): @@ -745,22 +768,26 @@ def replacer(matched): return str(serialised_value) component_id = instance_path[-1]["componentId"] - component = component_manager.components[component_id] - field_value = component.content.get(field_key) or default_field_value - replaced = self.template_regex.sub(replacer, field_value) + component = self.ct.get_component(component_id) + if component: + field_value = component.content.get(field_key) or default_field_value + replaced = self.template_regex.sub(replacer, field_value) - if as_json: - return json.loads(replaced) + if as_json: + return json.loads(replaced) + else: + return replaced else: - return replaced + raise ValueError(f"Couldn't acquire a component by ID '{component_id}'") def get_context_data(self, instance_path: InstancePath) -> Dict[str, Any]: context: Dict[str, Any] = {} - for i in range(len(instance_path)): path_item = instance_path[i] component_id = path_item["componentId"] - component = component_manager.components[component_id] + component = self.ct.get_component(component_id) + if not component: + continue if component.type != "repeater": continue if i + 1 >= len(instance_path): @@ -780,7 +807,7 @@ def get_context_data(self, instance_path: InstancePath) -> Dict[str, Any]: repeater_items = list(repeater_object.items()) elif isinstance(repeater_object, list): repeater_items = [(k, v) - for (k, v) in enumerate(repeater_object)] + for (k, v) in enumerate(repeater_object)] else: raise ValueError( "Cannot produce context. Repeater object must evaluate to a dictionary.") @@ -878,6 +905,7 @@ def __init__(self, session_id: str, cookies: Optional[Dict[str, str]], headers: new_state = StreamsyncState.get_new() new_state.user_state.mutated = set() self.session_state = new_state + self.session_component_tree = SessionComponentTree(base_component_tree) self.event_handler = EventHandler(self) def update_last_active_timestamp(self) -> None: @@ -977,8 +1005,9 @@ class EventHandler: def __init__(self, session: StreamsyncSession) -> None: self.session = session self.session_state = session.session_state - self.deser = EventDeserialiser(self.session_state) - self.evaluator = Evaluator(self.session_state) + self.session_component_tree = session.session_component_tree + self.deser = EventDeserialiser(self.session_state, self.session_component_tree) + self.evaluator = Evaluator(self.session_state, self.session_component_tree) def _handle_binding(self, event_type, target_component, instance_path, payload) -> None: @@ -1075,7 +1104,7 @@ def handle(self, ev: StreamsyncEvent) -> StreamsyncEventResult: try: instance_path = ev.instancePath target_id = instance_path[-1]["componentId"] - target_component = component_manager.components[target_id] + target_component = self.session_component_tree.get_component(target_id) self._handle_binding(ev.type, target_component, instance_path, ev.payload) result = self._call_handler_callable( @@ -1091,8 +1120,8 @@ def handle(self, ev: StreamsyncEvent) -> StreamsyncEventResult: state_serialiser = StateSerialiser() -component_manager = ComponentManager() initial_state = StreamsyncState() +base_component_tree = ComponentTree() session_manager = SessionManager() diff --git a/tests/test_core.py b/tests/test_core.py index 3e8957952..f1d1783c3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,7 +3,7 @@ from typing import Dict import numpy as np -from streamsync.core import (BytesWrapper, ComponentManager, Evaluator, EventDeserialiser, +from streamsync.core import (BytesWrapper, ComponentTree, Evaluator, EventDeserialiser, FileWrapper, SessionManager, StateProxy, StateSerialiser, StateSerialiserException, StreamsyncState) import streamsync as ss from streamsync.ss_types import StreamsyncEvent @@ -48,7 +48,8 @@ ss.Config.is_mail_enabled_for_log = True ss.init_state(raw_state_dict) -ss.component_manager.ingest(sc) +session = ss.session_manager.get_new_session() +session.session_component_tree.ingest(sc) class TestStateProxy: @@ -215,18 +216,18 @@ def test_unpickable_members(self) -> None: json.dumps(cloned.mail) -class TestComponentManager: +class TestComponentTree: - cm = ComponentManager() + ct = ComponentTree() def test_ingest(self) -> None: - self.cm.ingest(sc) - d = self.cm.to_dict() + self.ct.ingest(sc) + d = self.ct.to_dict() assert d.get( "84378aea-b64c-49a3-9539-f854532279ee").get("type") == "header" def test_descendents(self) -> None: - desc = self.cm.get_descendents("root") + desc = self.ct.get_descendents("root") desc_ids = list(map(lambda x: x.id, desc)) assert "84378aea-b64c-49a3-9539-f854532279ee" in desc_ids assert "bb4d0e86-619e-4367-a180-be28ab6059f4" in desc_ids @@ -237,7 +238,8 @@ class TestEventDeserialiser: root_instance_path = [{"componentId": "root", "instanceNumber": 0}] session_state = StreamsyncState(raw_state_dict) - ed = EventDeserialiser(session_state) + component_tree = session.session_component_tree + ed = EventDeserialiser(session_state, component_tree) def test_unknown_no_payload(self) -> None: ev = StreamsyncEvent( @@ -598,7 +600,8 @@ def test_evaluate_field_simple(self) -> None: st = StreamsyncState({ "counter": 8 }) - e = Evaluator(st) + ct = session.session_component_tree + e = Evaluator(st, ct) evaluated = e.evaluate_field(instance_path, "text") assert evaluated == "The counter is 8" @@ -622,7 +625,8 @@ def test_evaluate_field_repeater(self) -> None: "ts": "TypeScript" } }) - e = Evaluator(st) + ct = session.session_component_tree + e = Evaluator(st, ct) assert e.evaluate_field( instance_path_0, "text") == "The id is c and the name is C" assert e.evaluate_field( @@ -633,7 +637,8 @@ def test_set_state(self) -> None: {"componentId": "root", "instanceNumber": 0} ] st = StreamsyncState(raw_state_dict) - e = Evaluator(st) + ct = session.session_component_tree + e = Evaluator(st, ct) e.set_state("name", instance_path, "Roger") e.set_state("dynamic_prop", instance_path, "height") e.set_state("features[dynamic_prop]", instance_path, "toddler height") @@ -647,7 +652,8 @@ def test_evaluate_expression(self) -> None: {"componentId": "root", "instanceNumber": 0} ] st = StreamsyncState(raw_state_dict) - e = Evaluator(st) + ct = session.session_component_tree + e = Evaluator(st, ct) assert e.evaluate_expression("features.eyes", instance_path) == "green" assert e.evaluate_expression("best_feature", instance_path) == "eyes" assert e.evaluate_expression("features[best_feature]", instance_path) == "green"