Skip to content

Commit

Permalink
Add DynamoDB checkpoint backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mbklein committed Dec 11, 2024
1 parent ff8dd16 commit 7f9ad28
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 5 deletions.
5 changes: 2 additions & 3 deletions chat/src/agent/dynamodb_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
get_checkpoint_id,
)


import json
from typing import Any, Tuple
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

class JsonPlusSerializer(JsonPlusSerializer):
Expand Down Expand Up @@ -260,6 +257,7 @@ def put(
'type': type_,
'checkpoint': checkpoint_data,
'metadata': self.serde.dumps_typed(metadata)[1],
'timestamp': checkpoint_created_at,
}

self.table.put_item(Item=item)
Expand Down Expand Up @@ -301,5 +299,6 @@ def put_writes(
'channel': channel,
'type': type_,
'value': value_data,
'timestamp': int(time.time() * 1000),
}
batch.put_item(Item=item)
2 changes: 1 addition & 1 deletion chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ def call_model(state: MessagesState):
# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")

checkpointer = DynamoDBSaver("checkpoints", "checkpoint-writes", "us-east-1")
checkpointer = DynamoDBSaver(os.getenv("CHECKPOINT_TABLE"), os.getenv("CHECKPOINT_WRITES_TABLE"), os.getenv("AWS_REGION", "us-east-1"))

search_agent = workflow.compile(checkpointer=checkpointer)
3 changes: 2 additions & 1 deletion chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from helpers.apitoken import ApiToken
from helpers.prompts import prompt_template
from websocket import Websocket
from uuid import uuid4

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
Expand Down Expand Up @@ -88,7 +89,7 @@ def __post_init__(self):
self.prompt_text = self._get_prompt_text()
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.ref = self.payload.get("ref", uuid4().hex)
self.size = self._get_size()
self.stream_response = self.payload.get("stream_response", not self.debug_mode)
self.temperature = self._get_temperature()
Expand Down
46 changes: 46 additions & 0 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ Resources:
Environment:
Variables:
API_TOKEN_NAME: !Ref ApiTokenName
CHECKPOINT_TABLE: !Ref CheckpointTable
CHECKPOINT_WRITES_TABLE: !Ref CheckpointWritesTable
ENV_PREFIX: !Ref EnvironmentPrefix
HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
Expand All @@ -223,6 +225,22 @@ Resources:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
- Statement:
- Effect: Allow
Action:
- 'dynamodb:BatchGetItem'
- 'dynamodb:BatchWriteItem'
- 'dynamodb:ConditionCheckItem'
- 'dynamodb:DeleteItem'
- 'dynamodb:GetItem'
- 'dynamodb:PutItem'
- 'dynamodb:Query'
- 'dynamodb:Scan'
- 'dynamodb:UpdateItem'
- 'dynamodb:WriteItem'
Resource:
- !GetAtt CheckpointTable.Arn
- !GetAtt CheckpointWritesTable.Arn
- Statement:
- Effect: Allow
Action:
Expand Down Expand Up @@ -270,6 +288,34 @@ Resources:
# Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
CheckpointTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: thread_id
AttributeType: S
- AttributeName: sort_key
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: thread_id
KeyType: HASH
- AttributeName: sort_key
KeyType: RANGE
CheckpointWritesTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: thread_id
AttributeType: S
- AttributeName: sort_key
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: thread_id
KeyType: HASH
- AttributeName: sort_key
KeyType: RANGE
ChatMetricsLog:
Type: AWS::Logs::LogGroup
Properties:
Expand Down

0 comments on commit 7f9ad28

Please sign in to comment.