Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CustomPickler tests and fix type annotations in test_prompt.py
Browse files Browse the repository at this point in the history
- Add dedicated test suite for CustomPickler functionality
- Test path normalization and type annotation support
- Fix return type annotations in test_prompt.py prompt functions
- Add proper type hints for test functions

Part of #229: Implement CustomPickler for function serialization

Co-Authored-By: ryan@bespokelabs.ai <ryan@bespokelabs.ai>
devin-ai-integration[bot] and ryan@bespokelabs.ai committed Jan 7, 2025
1 parent c065d27 commit dc67026
Showing 2 changed files with 71 additions and 2 deletions.
63 changes: 63 additions & 0 deletions tests/test_custom_pickler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import pytest
from io import BytesIO
from typing import List
from pydantic import BaseModel

from bespokelabs.curator.utils.custom_pickler import CustomPickler, dumps, loads

class TestModel(BaseModel):
value: str
items: List[int]

def test_custom_pickler_type_annotations():
"""Test CustomPickler handles type annotations correctly."""
def func(x: TestModel) -> List[int]:
return x.items

# Test pickling and unpickling
pickled = dumps(func)
unpickled = loads(pickled)

# Test function still works
test_input = TestModel(value="test", items=[1, 2, 3])
assert unpickled(test_input) == [1, 2, 3]

def test_custom_pickler_path_normalization():
"""Test CustomPickler normalizes paths in function source."""
def func():
path = os.path.join("/home", "user", "file.txt")
return path

# Pickle in one directory
original_dir = os.getcwd()
try:
os.chdir("/tmp")
pickled1 = dumps(func)

# Pickle in another directory
os.chdir("/home")
pickled2 = dumps(func)

# Hashes should match despite different directories
assert pickled1 == pickled2
finally:
os.chdir(original_dir)

def test_custom_pickler_hybrid_serialization():
"""Test CustomPickler falls back to cloudpickle for type annotations."""
def func(x: TestModel, items: List[int]) -> List[int]:
return [i for i in items if i > int(x.value)]

# Test pickling with both type annotations and path-dependent code
file = BytesIO()
pickler = CustomPickler(file, recurse=True)
pickler.dump(func)

# Test unpickling
file.seek(0)
unpickled = loads(file.getvalue())

# Test function works
test_input = TestModel(value="2", items=[1, 2, 3])
assert unpickled(test_input, [1, 2, 3, 4]) == [3, 4]
10 changes: 8 additions & 2 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
Returns:
A list of messages for the LLM
"""
return [
messages = [
{
"role": "user",
"content": "Write a test message",
@@ -115,7 +115,10 @@ def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
"content": "You are a helpful assistant.",
},
]
return {"messages": messages}

# Set dummy OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key"
batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,
@@ -162,7 +165,7 @@ def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
Returns:
A list of messages for the LLM
"""
return [
messages = [
{
"role": "user",
"content": "Write a test message",
@@ -172,7 +175,10 @@ def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
"content": "You are a helpful assistant.",
},
]
return {"messages": messages}

# Set dummy OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key"
non_batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,

0 comments on commit dc67026

Please sign in to comment.