Skip to content

Commit

Permalink
fix(product-assistant): checkpoint blob queries must not rely on chec…
Browse files Browse the repository at this point in the history
…kpoint (#27048)
  • Loading branch information
skoob13 authored Dec 20, 2024
1 parent 5b7a31d commit f55dca1
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 4 deletions.
5 changes: 4 additions & 1 deletion ee/hogai/django_checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def _get_checkpoint_channel_values(
query = Q()
for channel, version in loaded_checkpoint["channel_versions"].items():
query |= Q(channel=channel, version=version)
return checkpoint.blobs.filter(query)
return ConversationCheckpointBlob.objects.filter(
Q(thread_id=checkpoint.thread_id, checkpoint_ns=checkpoint.checkpoint_ns) & query
)

def list(
self,
Expand Down Expand Up @@ -238,6 +240,7 @@ def put(
blobs.append(
ConversationCheckpointBlob(
checkpoint=updated_checkpoint,
thread_id=thread_id,
channel=channel,
version=str(version),
type=type,
Expand Down
153 changes: 152 additions & 1 deletion ee/hogai/django_checkpoint/test/test_checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore

from typing import Any, TypedDict
import operator
from typing import Annotated, Any, Optional, TypedDict

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand All @@ -13,6 +14,7 @@
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import BaseModel, Field

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import (
Expand Down Expand Up @@ -272,3 +274,152 @@ def test_resuming(self):
self.assertEqual(res, {"val": 3})
snapshot = graph.get_state(config)
self.assertFalse(snapshot.next)

def test_checkpoint_blobs_are_bound_to_thread(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]
string: Optional[str]

graph = StateGraph(State)

def handle_node1(state: State):
return

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_checkpoint_can_save_and_load_pydantic_state(self):
class State(BaseModel):
messages: Annotated[list[str], operator.add]
string: Optional[str]

class PartialState(BaseModel):
messages: Optional[list[str]] = Field(default=None)
string: Optional[str] = Field(default=None)

graph = StateGraph(State)

def handle_node1(state: State):
return PartialState()

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_saved_blobs(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]

graph = StateGraph(State)

def handle_node1(state: State):
return {"messages": ["world"]}

graph.add_node("node1", handle_node1)

graph.add_edge(START, "node1")
graph.add_edge("node1", END)

checkpointer = DjangoCheckpointer()
compiled = graph.compile(checkpointer=checkpointer)

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"]}, config=config)

snapshot = compiled.get_state(config)
self.assertFalse(snapshot.next)
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello", "world"])

blobs = list(ConversationCheckpointBlob.objects.filter(thread=thread))
self.assertEqual(len(blobs), 7)

# Set initial state
self.assertEqual(blobs[0].channel, "__start__")
self.assertEqual(blobs[0].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[0].type, blobs[0].blob)),
{"messages": ["hello"]},
)

# Set first node
self.assertEqual(blobs[1].channel, "__start__")
self.assertEqual(blobs[1].type, "empty")
self.assertIsNone(blobs[1].blob)

# Set value channels before start
self.assertEqual(blobs[2].channel, "messages")
self.assertEqual(blobs[2].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[2].type, blobs[2].blob)),
["hello"],
)

# Transition to node1
self.assertEqual(blobs[3].channel, "start:node1")
self.assertEqual(blobs[3].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[3].type, blobs[3].blob)),
"__start__",
)

# Set new state for messages
self.assertEqual(blobs[4].channel, "messages")
self.assertEqual(blobs[4].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[4].type, blobs[4].blob)),
["hello", "world"],
)

# After setting a state
self.assertEqual(blobs[5].channel, "start:node1")
self.assertEqual(blobs[5].type, "empty")
self.assertIsNone(blobs[5].blob)

# Set last step
self.assertEqual(blobs[6].channel, "node1")
self.assertEqual(blobs[6].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[6].type, blobs[6].blob)),
"node1",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Generated by Django 4.2.15 on 2024-12-19 11:00

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("ee", "0018_conversation_conversationcheckpoint_and_more"),
]

operations = [
migrations.RemoveConstraint(
model_name="conversationcheckpointblob",
name="unique_checkpoint_blob",
),
migrations.AddField(
model_name="conversationcheckpointblob",
name="checkpoint_ns",
field=models.TextField(
default="",
help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).',
),
),
migrations.AddField(
model_name="conversationcheckpointblob",
name="thread",
field=models.ForeignKey(
null=True, on_delete=django.db.models.deletion.CASCADE, related_name="blobs", to="ee.conversation"
),
),
migrations.AddConstraint(
model_name="conversationcheckpointblob",
constraint=models.UniqueConstraint(
fields=("thread_id", "checkpoint_ns", "channel", "version"), name="unique_checkpoint_blob"
),
),
]
2 changes: 1 addition & 1 deletion ee/migrations/max_migration.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0018_conversation_conversationcheckpoint_and_more
0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more
10 changes: 9 additions & 1 deletion ee/models/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def pending_writes(self) -> Iterable["ConversationCheckpointWrite"]:

class ConversationCheckpointBlob(UUIDModel):
checkpoint = models.ForeignKey(ConversationCheckpoint, on_delete=models.CASCADE, related_name="blobs")
"""
The checkpoint that created the blob. Do not use this field to query blobs.
"""
thread = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name="blobs", null=True)
checkpoint_ns = models.TextField(
default="",
help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).',
)
channel = models.TextField(
help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum."
)
Expand All @@ -56,7 +64,7 @@ class ConversationCheckpointBlob(UUIDModel):
class Meta:
constraints = [
models.UniqueConstraint(
fields=["checkpoint_id", "channel", "version"],
fields=["thread_id", "checkpoint_ns", "channel", "version"],
name="unique_checkpoint_blob",
)
]
Expand Down

0 comments on commit f55dca1

Please sign in to comment.