
Commit

Update documentation and implement CustomPickler for function serialization

Co-Authored-By: [email protected] <[email protected]>
1 parent 0d5703f commit c065d27
Showing 4 changed files with 93 additions and 9 deletions.
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.11.7
13 changes: 7 additions & 6 deletions README.md
@@ -70,7 +70,7 @@ print(poems)
```
Note that retries and caching are enabled by default.
So now if you run the same prompt again, you will get the same response, pretty much instantly.
You can delete the cache at `~/.cache/curator`.
You can delete the cache at the path specified by `os.path.join(os.path.expanduser("~"), ".cache", "curator")`.
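For example, a minimal sketch (standard library only, not part of the curator API) for clearing that cache directory:
```python
# Minimal sketch: locate and remove curator's local cache directory.
import os
import shutil

cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "curator")
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)  # deletes all cached responses
```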

#### Use LiteLLM backend for calling other models
You can use the [LiteLLM](https://docs.litellm.ai/docs/providers) backend for calling other models.
@@ -127,10 +127,11 @@ poet = curator.LLM(
Here:
* `prompt_func` takes a row of the dataset as input and returns the prompt for the LLM.
* `response_format` is the structured output class we defined above.
* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts it to a list of dictionaries. This is so that we can easily convert the output to a HuggingFace Dataset object. For best practices with type annotations:
* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts it to a list of dictionaries. This is so that we can easily convert the output to a HuggingFace Dataset object. For best practices (see the sketch after this list):
* Define `parse_func` as a module-level function rather than a lambda to ensure proper serialization
* Use the `_DictOrBaseModel` type alias for input/output types: `def parse_func(row: _DictOrBaseModel, response: _DictOrBaseModel) -> _DictOrBaseModel`
* Type annotations are now fully supported thanks to cloudpickle serialization
* Type annotations are fully supported through our CustomPickler implementation
* Function hashing is path-independent, ensuring consistent caching across different environments (e.g., Ray clusters)
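
For illustration, a module-level `parse_func` following these guidelines might look like the sketch below; the `Poems` model with a `poems_list` field of `Poem` objects is assumed from the collapsed example above.
```python
# Sketch only: `Poems`, `poems_list`, and `Poem.poem` are assumed from the example above.
from typing import Any, Dict, Union

from pydantic import BaseModel

_DictOrBaseModel = Union[Dict[str, Any], BaseModel]


def parse_func(row: _DictOrBaseModel, poems: _DictOrBaseModel) -> _DictOrBaseModel:
    """Convert one input row plus its structured response into dataset rows."""
    # One dict per poem keeps the output easy to load into a HuggingFace Dataset.
    return [{"topic": row["topic"], "poem": p.poem} for p in poems.poems_list]
```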

Now we can apply the `LLM` object to the dataset, which reads very Pythonically.
```python
@@ -145,8 +146,8 @@ print(poem.to_pandas())
```
Note that `topics` can be created with `curator.LLM` as well,
and we can scale this up to create tens of thousands of diverse poems.
You can see a more detailed example in the [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples) directory.
You can see a more detailed example in the [examples/poem-generation/poem.py](https://github.com/bespokelabsai/curator/blob/main/examples/poem-generation/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/main/examples) directory.

See the [docs](https://docs.bespokelabs.ai/) for more details as well as
for troubleshooting information.
@@ -204,4 +205,4 @@ npm -v # should print `10.9.0`
```

## Contributing
Contributions are welcome!
Contributions are welcome!
9 changes: 6 additions & 3 deletions src/bespokelabs/curator/llm/llm.py
@@ -7,11 +7,12 @@
from io import BytesIO
from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union

import cloudpickle
from datasets import Dataset
from pydantic import BaseModel
from xxhash import xxh64

from bespokelabs.curator.utils.custom_pickler import CustomPickler

from bespokelabs.curator.db import MetadataDB
from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor import (
@@ -285,13 +286,15 @@ def __call__(
def _get_function_hash(func) -> str:
"""Get a hash of a function's source code.
Uses cloudpickle to properly handle functions with type annotations and closure variables.
Uses CustomPickler to properly handle both:
1. Path normalization (from HuggingFace's Pickler)
2. Type annotations and closure variables (from cloudpickle)
"""
if func is None:
return xxh64("").hexdigest()

file = BytesIO()
file.write(cloudpickle.dumps(func))
CustomPickler(file, recurse=True).dump(func)
return xxh64(file.getvalue()).hexdigest()
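
As a rough usage sketch of the helper above (behavior inferred from the code shown, not from separate documentation):
```python
# Hypothetical sanity checks: hashing is stable for the same function,
# and a missing function maps to the hash of the empty string.
def add_one(x: int) -> int:
    return x + 1

assert _get_function_hash(add_one) == _get_function_hash(add_one)
assert _get_function_hash(None) == xxh64("").hexdigest()
```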


79 changes: 79 additions & 0 deletions src/bespokelabs/curator/utils/custom_pickler.py
@@ -0,0 +1,79 @@
"""Custom Pickler that combines HuggingFace's path normalization with type annotation support.
This module provides a CustomPickler class that extends HuggingFace's Pickler to support
both path normalization (for consistent function hashing across different environments)
and type annotations in function signatures.
"""

import os
from io import BytesIO
from typing import Any, Optional, Type, TypeVar, Union

import cloudpickle
from datasets.utils._dill import Pickler as HFPickler


class CustomPickler(HFPickler):
"""A custom pickler that combines HuggingFace's path normalization with type annotation support.
This pickler extends HuggingFace's Pickler to:
1. Preserve path normalization for consistent function hashing
2. Support type annotations in function signatures
3. Handle closure variables properly
"""

def __init__(self, file: BytesIO, recurse: bool = True):
"""Initialize the CustomPickler.
Args:
file: The file-like object to pickle to
recurse: Whether to recursively handle object attributes
"""
super().__init__(file, recurse=recurse)

def save(self, obj: Any, save_persistent_id: bool = True) -> None:
"""Save an object, handling type annotations properly.
This method attempts to use cloudpickle's type annotation handling while
preserving HuggingFace's path normalization logic.
Args:
obj: The object to pickle
save_persistent_id: Whether to save persistent IDs
"""
try:
# First try HuggingFace's pickler for path normalization
super().save(obj, save_persistent_id=save_persistent_id)
except Exception as e:
if "No default __reduce__ due to non-trivial __cinit__" in str(e):
# If HF's pickler fails with type annotation error, use cloudpickle
cloudpickle.dump(obj, self._file)
else:
# Re-raise other exceptions
raise


def dumps(obj: Any) -> bytes:
"""Pickle an object to bytes using CustomPickler.
Args:
obj: The object to pickle
Returns:
The pickled object as bytes
"""
file = BytesIO()
CustomPickler(file, recurse=True).dump(obj)
return file.getvalue()


def loads(data: bytes) -> Any:
"""Unpickle an object from bytes.
Args:
data: The pickled data
Returns:
The unpickled object
"""
return cloudpickle.loads(data)
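
A rough round-trip sketch using the `dumps`/`loads` helpers above, assuming the unpickling environment has the same libraries installed (notably `dill` via `datasets`, and `cloudpickle`):
```python
# Minimal round-trip check: an annotated function survives dumps()/loads().
def square(x: int) -> int:
    return x * x

restored = loads(dumps(square))
assert restored(5) == 25
```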
