-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f19e003
commit a7e1237
Showing
8 changed files
with
153 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Generate synthetic recipes from ingrients image and cuisine using curator.""" | ||
|
||
from datasets import Dataset | ||
|
||
from bespokelabs import curator | ||
|
||
|
||
class RecipeGenerator(curator.LLM): | ||
"""A recipe generator that generates recipes for different cuisines.""" | ||
|
||
def prompt(self, input: dict) -> str: | ||
"""Generate a prompt using the template and cuisine.""" | ||
prompt = f"Create me a recipe for {input['cuisine']} cuisine and ingrients from the image." | ||
return prompt, curator.types.Image(url=input["image_url"]) | ||
|
||
def parse(self, input: dict, response: str) -> dict: | ||
"""Parse the model response along with the input to the model into the desired output format..""" | ||
return { | ||
"recipe": response, | ||
} | ||
|
||
|
||
def main(): | ||
"""Generate synthetic recipes for different cuisines.""" | ||
# List of cuisines to generate recipes for | ||
cuisines = [ | ||
{"cuisine": cuisine[0], "image_url": cuisine[1]} | ||
for cuisine in [ | ||
("Indian", "https://cdn.tasteatlas.com//images/ingredients/fcee541cd2354ed8b68b50d1aa1acad8.jpeg"), | ||
("Thai", "https://cdn.tasteatlas.com//images/dishes/da5fd425608f48b09555f5257a8d3a86.jpg"), | ||
] | ||
] | ||
cuisines = Dataset.from_list(cuisines) | ||
|
||
# Create prompter using LiteLLM backend | ||
############################################# | ||
# To use Gemini models: | ||
# 1. Go to https://aistudio.google.com/app/apikey | ||
# 2. Generate an API key | ||
# 3. Set environment variable: GEMINI_API_KEY | ||
# 4. If you are a free user, update rate limits: | ||
# max_requests_per_minute=15 | ||
# max_tokens_per_minute=1_000_000 | ||
# (Up to 1,000 requests per day) | ||
############################################# | ||
|
||
recipe_generator = RecipeGenerator( | ||
model_name="gpt-4o", | ||
backend="openai", | ||
) | ||
|
||
# Generate recipes for all cuisines | ||
recipes = recipe_generator(cuisines) | ||
|
||
# Print results | ||
print(recipes.to_pandas()) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,51 @@ | ||
# Description: Pydantic models for multimodal prompts. | ||
import typing as t | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class BaseType(BaseModel): | ||
"""A class to represent the base type for multimodal prompts.""" | ||
|
||
type: str = Field(..., description="The type of the multimodal prompt.") | ||
type: t.ClassVar[str] = Field(..., description="The type of the multimodal prompt.") | ||
|
||
|
||
class Image(BaseModel): | ||
class Image(BaseType): | ||
"""A class to represent an image for multimodal prompts.""" | ||
|
||
url: str = Field(None, description="The URL of the image.") | ||
content: str = Field(None, description="Base64-encoded image content.") | ||
type = "image" | ||
url: str = Field("", description="The URL of the image.") | ||
content: str = Field("", description="Base64-encoded image content.") | ||
type: t.ClassVar[str] = "image" | ||
|
||
def __post_init__(self): | ||
"""Post init.""" | ||
# assert url or content is provided | ||
assert self.url or self.content, "Either 'url' or 'content' must be provided." | ||
|
||
|
||
class File(BaseModel): | ||
class File(BaseType): | ||
"""A class to represent a file for multimodal prompts.""" | ||
|
||
url: str = Field(..., description="The URL of the file.") | ||
type = "file" | ||
type: t.ClassVar[str] = "file" | ||
|
||
|
||
class _MultiModalPrompt(BaseType): | ||
"""A class to represent a multimodal prompt.""" | ||
|
||
texts: str = Field(None, description="The text of the prompt.") | ||
images: Image = Field(None, description="The image of the prompt.") | ||
files: File = Field(None, description="The file of the prompt.") | ||
texts: t.List[str] = Field(default_factory=list, description="The text of the prompt.") | ||
images: t.List[Image] = Field(default_factory=list, description="The image of the prompt.") | ||
files: t.List[File] = Field(default_factory=list, description="The file of the prompt.") | ||
|
||
@classmethod | ||
def load(cls, messages): | ||
prompt = {} | ||
prompt = {"texts": [], "images": [], "files": []} | ||
for msg in messages: | ||
if isinstance(msg, BaseType): | ||
if msg.type == "image": | ||
prompt["images"] = msg | ||
prompt["images"].append(msg) | ||
elif msg.type == "file": | ||
prompt["files"] = msg | ||
prompt["files"].append(msg) | ||
else: | ||
prompt["text"] = msg | ||
prompt["texts"].append(msg) | ||
return cls(**prompt) |