Skip to content

Commit

Permalink
Local tracker encoding fix (#281)
Browse files Browse the repository at this point in the history
This was in response to #273. A user on Windows
wasn't able to load things because Windows had
a different default encoding. This fixes that
and ensures the local tracker writes & reads UTF-8
everywhere.

* Tests adding specific encoding

* Adds assumption on state to docs

To make it clear what we assume.

* Adds test case on utf-8 encoding for local tracker

Adds test cases to:

1. Enshrine the behavior that the developer needs to make sure state is UTF-8 encodable.
2. Add an integration test to ensure that what is written can be read back.
  • Loading branch information
skrawcz authored Jul 23, 2024
1 parent 9f835f4 commit 3e7b77a
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 5 deletions.
10 changes: 5 additions & 5 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _log_child_relationships(
# currently we write start events, so it really won't matter
# but in the future we'll write end events, but we'll parse it in a
# way that allows them to be interwoven
with open(parent_children_list_path, "a") as f:
with open(parent_children_list_path, "a", errors="replace", encoding="utf-8") as f:
fileno = f.fileno()
try:
fcntl.flock(fileno, fcntl.LOCK_EX)
Expand Down Expand Up @@ -257,7 +257,7 @@ def app_log_exists(
path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME)
if not os.path.exists(path):
return False
lines = open(path, "r").readlines()
lines = open(path, "r", errors="replace", encoding="utf-8").readlines()
if len(lines) == 0:
return False
return True
Expand Down Expand Up @@ -292,7 +292,7 @@ def load_state(
path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME)
if not os.path.exists(path):
raise ValueError(f"No logs found for {project}/{app_id} under {storage_dir}")
with open(path, "r") as f:
with open(path, "r", errors="replace", encoding="utf-8") as f:
json_lines = f.readlines()
# load as JSON
json_lines = [json.loads(js_line) for js_line in json_lines]
Expand Down Expand Up @@ -369,7 +369,7 @@ def post_application_create(
parent_pointer=PointerModel.from_pointer(parent_pointer),
spawning_parent_pointer=PointerModel.from_pointer(spawning_parent_pointer),
).model_dump()
with open(metadata_path, "w", errors="replace") as f:
with open(metadata_path, "w", errors="replace", encoding="utf-8") as f:
json.dump(metadata, f)

# Append to the parents of this the pointer to this, now
Expand Down Expand Up @@ -471,7 +471,7 @@ def load(
path = os.path.join(self.storage_dir, app_id, self.LOG_FILENAME)
if not os.path.exists(path):
return None
with open(path, "r") as f:
with open(path, "r", errors="replace", encoding="utf-8") as f:
json_lines = f.readlines()
if len(json_lines) == 0:
return None # in this case we have not logged anything yet
Expand Down
5 changes: 5 additions & 0 deletions docs/concepts/state-persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ and you want to store the state of the process after each action, and then reloa

``Burr`` provides a few simple interfaces to do this with minimal changes. Let's walk through a simple chatbot example as we're explaining concepts:

Two notable assumptions:

1. For library-provided persisters, state must ultimately be JSON serializable. If it's not, you can use the :doc:`serde` API to customize serialization and deserialization.
2. You should generally assume that strings are (or will be) encoded as UTF-8; however, this depends on the persister you use.

State Keys
----------
Burr `applications` are, by default, keyed on two entities:
Expand Down
107 changes: 107 additions & 0 deletions tests/tracking/test_local_tracking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,110 @@ def test_application_tracks_link_from_spawning_parent(tmpdir: str):
children_parsed = [ChildApplicationModel.model_validate(child) for child in children]
assert set(child.child.app_id for child in children_parsed) == set(spawned_children)
assert all(child.event_type == "spawn_start" for child in children_parsed)


def test_that_we_fail_on_non_unicode_characters(tmp_path):
    """Pin the current contract that state must be encodable as UTF-8.

    It is presently the developer's responsibility to make sure state can be
    encoded into UTF-8; this test captures that assumption by expecting a
    ``ValueError`` when an action writes a string that cannot be encoded.
    """

    @action(reads=["test"], writes=["test"])
    def state_1(state: State) -> State:
        return state.update(test="test")

    @action(reads=["test"], writes=["test"])
    def state_2(state: State) -> State:
        # A lone surrogate code point: it has no valid UTF-8 encoding.
        return state.update(test="\uD800")

    local_tracker = LocalTrackingClient(project="test", storage_dir=tmp_path)
    builder = (
        ApplicationBuilder()
        .with_actions(state_1, state_2)
        .with_transitions(("state_1", "state_2"), ("state_2", "state_1"))
        .with_tracker(tracker=local_tracker)
        .initialize_from(
            initializer=local_tracker,
            resume_at_next_action=False,
            default_entrypoint="state_1",
            default_state={},
        )
        .with_identifiers(app_id="3")
    )
    app: Application = builder.build()

    # The run is expected to blow up once state_2 has produced the bad string.
    with pytest.raises(ValueError):
        app.run(halt_after=["state_2"])


def test_that_we_can_read_write_local_tracker(tmp_path):
    """Integration-style check that what the local tracker writes can be read back.

    The same application (same project, storage directory and app id) is built
    and run twice with the tracker acting as both tracker and initializer, so
    the second pass initializes from what the first pass logged.
    """

    @action(
        reads=[],
        writes=[
            "text",
            "greek",
            "cyrillic",
            "hebrew",
            "arabic",
            "hindi",
            "chinese",
            "japanese",
            "korean",
            "emoji",
        ],
    )
    def state_1(state: State) -> State:
        # One sample string per script so a wide range of non-ASCII text is logged.
        samples = {
            "text": "á, é, í, ó, ú, ñ, ü",
            "greek": "α, β, γ, δ",
            "cyrillic": "ж, ы, б, ъ",
            "hebrew": "א, ב, ג, ד",
            "arabic": "خ, د, ذ, ر",
            "hindi": "अ, आ, इ, ई",
            "chinese": "中, 国, 文",
            "japanese": "日, 本, 語",
            "korean": "한, 국, 어",
            "emoji": "😀, 👍, 🚀, 🌍",
        }
        return state.update(**samples)

    @action(reads=["text"], writes=["text"])
    def state_2(state: State) -> State:
        # U+009D is encodable as UTF-8 (presumably chosen because it has no
        # mapping in some platform-default codecs such as Windows cp1252).
        return state.update(text="\x9d")

    local_tracker = LocalTrackingClient(
        project="test",
        storage_dir=tmp_path,
    )

    for _ in range(2):
        # The first pass writes the log; the second pass initializes from it,
        # which must succeed without raising.
        builder = (
            ApplicationBuilder()
            .with_actions(state_1, state_2)
            .with_transitions(("state_1", "state_2"), ("state_2", "state_1"))
            .with_tracker(tracker=local_tracker)
            .initialize_from(
                initializer=local_tracker,
                resume_at_next_action=False,
                default_entrypoint="state_1",
                default_state={},
            )
            .with_identifiers(app_id="3")
        )
        app: Application = builder.build()

        app.run(halt_after=["state_2"])

0 comments on commit 3e7b77a

Please sign in to comment.