From 3e7b77aae7e4d0879757b6402115d83d833d10e5 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Tue, 23 Jul 2024 09:08:40 -0700 Subject: [PATCH] Local tracker encoding fix (#281) This was in response to #273 . A user on windows wasn't able to load things because windows had a different default encoding. This fixes that and ensures the local tracker writes & reads utf-8 everywhere. * Tests adding specific encoding * Adds assumption on state to docs To make it clear what we assume. * Adds test case on utf-8 encoding for local tracker Adds test cases to: 1. enshrine behavior that developer needs to make sure things are utf-8 encodable. 2. integration test to ensure what is written can be read. --- burr/tracking/client.py | 10 +- docs/concepts/state-persistence.rst | 5 + tests/tracking/test_local_tracking_client.py | 107 +++++++++++++++++++ 3 files changed, 117 insertions(+), 5 deletions(-) diff --git a/burr/tracking/client.py b/burr/tracking/client.py index 332c2d05..5a8d4ac4 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -221,7 +221,7 @@ def _log_child_relationships( # currently we write start events, so it really won't matter # but in the future we'll write end events, but we'll parse it in a # way that allows them to be interwoven - with open(parent_children_list_path, "a") as f: + with open(parent_children_list_path, "a", errors="replace", encoding="utf-8") as f: fileno = f.fileno() try: fcntl.flock(fileno, fcntl.LOCK_EX) @@ -257,7 +257,7 @@ def app_log_exists( path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME) if not os.path.exists(path): return False - lines = open(path, "r").readlines() + lines = open(path, "r", errors="replace", encoding="utf-8").readlines() if len(lines) == 0: return False return True @@ -292,7 +292,7 @@ def load_state( path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME) if not os.path.exists(path): raise ValueError(f"No logs found for {project}/{app_id} under {storage_dir}") - with open(path, "r") as f: + with open(path, "r", errors="replace", encoding="utf-8") as f: json_lines = f.readlines() # load as JSON json_lines = [json.loads(js_line) for js_line in json_lines] @@ -369,7 +369,7 @@ def post_application_create( parent_pointer=PointerModel.from_pointer(parent_pointer), spawning_parent_pointer=PointerModel.from_pointer(spawning_parent_pointer), ).model_dump() - with open(metadata_path, "w", errors="replace") as f: + with open(metadata_path, "w", errors="replace", encoding="utf-8") as f: json.dump(metadata, f) # Append to the parents of this the pointer to this, now @@ -471,7 +471,7 @@ def load( path = os.path.join(self.storage_dir, app_id, self.LOG_FILENAME) if not os.path.exists(path): return None - with open(path, "r") as f: + with open(path, "r", errors="replace", encoding="utf-8") as f: json_lines = f.readlines() if len(json_lines) == 0: return None # in this case we have not logged anything yet diff --git a/docs/concepts/state-persistence.rst b/docs/concepts/state-persistence.rst index 623c37f9..24dabb16 100644 --- a/docs/concepts/state-persistence.rst +++ b/docs/concepts/state-persistence.rst @@ -26,6 +26,11 @@ and you want to store the state of the process after each action, and then reloa ``Burr`` provides a few simple interfaces to do this with minimal changes. Let's walk through a simple chatbot example as we're explaining concepts: +Two notable assumptions: + +1. for library provided persisters, state needs to ultimately be JSON serializable. If it's not, you can use the :doc:`serde` API to customize serialization and deserialization. +2. your general assumption should be that strings are/will be encoded as UTF-8, however this is dependent on the persister you use. + State Keys ---------- Burr `applications` are, by default, keyed on two entities: diff --git a/tests/tracking/test_local_tracking_client.py b/tests/tracking/test_local_tracking_client.py index 184e6a95..cc95ae0b 100644 --- a/tests/tracking/test_local_tracking_client.py +++ b/tests/tracking/test_local_tracking_client.py @@ -354,3 +354,110 @@ def test_application_tracks_link_from_spawning_parent(tmpdir: str): children_parsed = [ChildApplicationModel.model_validate(child) for child in children] assert set(child.child.app_id for child in children_parsed) == set(spawned_children) assert all(child.event_type == "spawn_start" for child in children_parsed) + + +def test_that_we_fail_on_non_unicode_characters(tmp_path): + """This is a test to log expected behavior. + + Right now it is on the developer to ensure that state can be encoded into UTF-8. + + This test is here to capture this assumption. + """ + + @action(reads=["test"], writes=["test"]) + def state_1(state: State) -> State: + return state.update(test="test") + + @action(reads=["test"], writes=["test"]) + def state_2(state: State) -> State: + return state.update(test="\uD800") # Invalid UTF-8 byte sequence + + tracker = LocalTrackingClient(project="test", storage_dir=tmp_path) + app: Application = ( + ApplicationBuilder() + .with_actions(state_1, state_2) + .with_transitions(("state_1", "state_2"), ("state_2", "state_1")) + .with_tracker(tracker=tracker) + .initialize_from( + initializer=tracker, + resume_at_next_action=False, + default_entrypoint="state_1", + default_state={}, + ) + .with_identifiers(app_id="3") + .build() + ) + + with pytest.raises(ValueError): + app.run(halt_after=["state_2"]) + + +def test_that_we_can_read_write_local_tracker(tmp_path): + """Integration like test to ensure we can write and then read what was written""" + + @action( + reads=[], + writes=[ + "text", + "greek", + "cyrillic", + "hebrew", + "arabic", + "hindi", + "chinese", + "japanese", + "korean", + "emoji", + ], + ) + def state_1(state: State) -> State: + text = "á, é, í, ó, ú, ñ, ü" + greek = "α, β, γ, δ" + cyrillic = "ж, ы, б, ъ" + hebrew = "א, ב, ג, ד" + arabic = "خ, د, ذ, ر" + hindi = "अ, आ, इ, ई" + chinese = "中, 国, 文" + japanese = "日, 本, 語" + korean = "한, 국, 어" + emoji = "😀, 👍, 🚀, 🌍" + return state.update( + text=text, + greek=greek, + cyrillic=cyrillic, + hebrew=hebrew, + arabic=arabic, + hindi=hindi, + chinese=chinese, + japanese=japanese, + korean=korean, + emoji=emoji, + ) + + @action(reads=["text"], writes=["text"]) + def state_2(state: State) -> State: + return state.update(text="\x9d") # encode-able UTF-8 sequence + + tracker = LocalTrackingClient( + project="test", + storage_dir=tmp_path, + ) + + for i in range(2): + # reloads from log.jsonl in the second run and errors + app: Application = ( + ApplicationBuilder() + .with_actions(state_1, state_2) + .with_transitions(("state_1", "state_2"), ("state_2", "state_1")) + .with_tracker(tracker=tracker) + .initialize_from( + initializer=tracker, + resume_at_next_action=False, + default_entrypoint="state_1", + default_state={}, + ) + .with_identifiers(app_id="3") + .build() + ) + + app.run(halt_after=["state_2"])