Skip to content

Commit

Permalink
Add delete_checkpoints to the s3 checkpointer
Browse files Browse the repository at this point in the history
Slightly change key pattern of checkpoints in s3
  • Loading branch information
mbklein committed Dec 12, 2024
1 parent a075cfa commit e61928d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
70 changes: 58 additions & 12 deletions chat/src/agent/s3_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,38 @@ def object_hook(dct):
obj = json.loads(json_str, object_hook=object_hook)
return obj

def _namespace(val):
return "__default__" if val == "" else val

def _namespace_val(namespace):
return "" if namespace == "__default__" else namespace

def _make_s3_thread_prefix(thread_id: str) -> str:
return f"checkpoints/{thread_id}"

def _make_s3_namespace_prefix(thread_id: str, checkpoint_ns: str) -> str:
    """
    Build the S3 key prefix for one checkpoint namespace of a thread.

    :param thread_id: The thread_id value
    :param checkpoint_ns: The checkpoint namespace; it is passed through
        _namespace() before being used as a key segment
    :return: "<thread prefix>/<namespace segment>" (no trailing slash)
    """
    thread_prefix = _make_s3_thread_prefix(thread_id)
    namespace_segment = _namespace(checkpoint_ns)
    return "{}/{}".format(thread_prefix, namespace_segment)

def _make_s3_checkpoint_prefix(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> str:
    """
    Build the S3 key prefix for a single checkpoint.

    :param thread_id: The thread_id value
    :param checkpoint_ns: The checkpoint namespace
    :param checkpoint_id: The checkpoint identifier
    :return: "<namespace prefix>/<checkpoint_id>" (no trailing slash)
    """
    namespace_prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns)
    return "{}/{}".format(namespace_prefix, checkpoint_id)

def _make_s3_checkpoint_key(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> str:
    """
    Build the S3 object key under which a checkpoint document is stored.

    The key is "<checkpoint prefix>/checkpoint.json", where the checkpoint
    prefix comes from _make_s3_checkpoint_prefix(). This is the layout that
    _parse_s3_checkpoint_key() accepts: it requires the fifth path segment
    to be "checkpoint.json".

    :param thread_id: The thread_id value
    :param checkpoint_ns: The checkpoint namespace
    :param checkpoint_id: The checkpoint identifier
    :return: The full S3 key of the checkpoint object
    """
    # The former "checkpoints/{thread_id}/{checkpoint_ns}/{checkpoint_id}.json"
    # early return is dropped: it made the prefix-based key unreachable and
    # produced keys the parser rejects.
    prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id)
    return f"{prefix}/checkpoint.json"

def _make_s3_write_key(thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, idx: int) -> str:
    """
    Build the S3 object key for one pending write of a checkpoint.

    The key is "<checkpoint prefix>/writes/<task_id>/<idx>.json", where the
    checkpoint prefix comes from _make_s3_checkpoint_prefix().

    :param thread_id: The thread_id value
    :param checkpoint_ns: The checkpoint namespace
    :param checkpoint_id: The checkpoint identifier
    :param task_id: The task the write belongs to
    :param idx: The index of the write within the task
    :return: The full S3 key of the write object
    """
    # The former hand-built key (early return) is dropped: it skipped the
    # namespace mapping done by the prefix helpers, so keys for the empty
    # namespace did not share the prefix used when listing writes.
    prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id)
    return f"{prefix}/writes/{task_id}/{idx}.json"

def _parse_s3_checkpoint_key(key: str) -> Dict[str, str]:
parts = key.split("/")
if len(parts) < 4:
if len(parts) < 5 or parts[4] != "checkpoint.json":
raise ValueError("Invalid checkpoint key format")
thread_id = parts[1]
checkpoint_ns = parts[2]
filename = parts[3]
checkpoint_id = filename[:-5] # remove ".json"
checkpoint_ns = _namespace_val(parts[2])
checkpoint_id = parts[3]
return {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
Expand Down Expand Up @@ -239,10 +256,10 @@ def list(

thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/"
prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns)

paginator = self.s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=f"{prefix}/")

keys = []
for page in pages:
Expand Down Expand Up @@ -312,9 +329,9 @@ def list(
)

def _get_latest_checkpoint_id(self, thread_id: str, checkpoint_ns: str) -> Optional[str]:
prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/"
prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns)
paginator = self.s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=f"{prefix}/")
keys = []
for page in pages:
for c in page.get("Contents", []):
Expand All @@ -331,7 +348,7 @@ def _get_latest_checkpoint_id(self, thread_id: str, checkpoint_ns: str) -> Optio
return latest_id

def _load_pending_writes(self, thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> List[PendingWrite]:
prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/{checkpoint_id}/writes/"
prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id) + "/writes/"
paginator = self.s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)

Expand All @@ -351,4 +368,33 @@ def _load_pending_writes(self, thread_id: str, checkpoint_ns: str, checkpoint_id
value = self.serde.loads_typed((value_type, value_data))
writes.append((task_id, channel, value))

return writes
return writes

def delete_checkpoints(bucket_name, thread_id, region_name="us-east-1"):
    """
    Deletes all items with the specified thread_id from the checkpoint
    bucket.

    Objects under the "checkpoints/<thread_id>/" prefix are listed page by
    page and removed in batches of at most 1000 keys per delete request.

    :param bucket_name: The name of the S3 checkpoint bucket
    :param thread_id: The thread_id value to delete.
    :param region_name: The S3 region the bucket is in
    """
    # S3 DeleteObjects accepts at most 1000 keys per request.
    max_keys_per_request = 1000

    session = boto3.Session(region_name=region_name)
    client = session.client("s3")

    def delete_batch(keys):
        # Skip the request for an empty batch (e.g. nothing matched the
        # prefix, or the last full batch was already flushed).
        if keys:
            client.delete_objects(Bucket=bucket_name, Delete={"Objects": keys})

    paginator = client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket_name, Prefix=f"checkpoints/{thread_id}/")

    batch = []
    # search("Contents") yields None for a page that has no "Contents" key.
    for item in pages.search("Contents"):
        if item is None:
            continue
        batch.append({"Key": item["Key"]})
        if len(batch) >= max_keys_per_request:
            delete_batch(batch)
            # Start a fresh batch; without this reset every later request
            # would re-send the already-deleted keys and exceed the
            # per-request key limit.
            batch = []

    # Flush whatever is left over after the last full batch.
    delete_batch(batch)
7 changes: 4 additions & 3 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime
from event_config import EventConfig
# from honeybadger import honeybadger
from agent.s3_saver import delete_checkpoints
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage
from agent.agent_handler import AgentHandler
Expand Down Expand Up @@ -55,8 +56,8 @@ def handler(event, context):
config.socket.send({"type": "error", "message": "Unauthorized"})
return {"statusCode": 401, "body": "Unauthorized"}

# if config.forget:
# delete_checkpoint(config.ref)
if config.forget:
delete_checkpoints(os.getenv("CHECKPOINT_BUCKET_NAME"), config.ref)

if config.question is None or config.question == "":
config.socket.send({"type": "error", "message": "Question cannot be blank"})
Expand All @@ -76,7 +77,7 @@ def handler(event, context):
search_agent.invoke(
{"messages": [HumanMessage(content=config.question)]},
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
debug=True
debug=False
)
except Exception as e:
print(f"Error: {e}")
Expand Down

0 comments on commit e61928d

Please sign in to comment.