Commit ce2f778 (1 parent: dfaafe6). Showing 7 changed files with 1,375 additions and 3 deletions.
@@ -0,0 +1,63 @@
import asyncio
import os

from PIL import Image as PIL_Image
from pydantic import BaseModel, Field

from tiny_ai_client import AI, AsyncAI


class WeatherParams(BaseModel):
    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
    unit: str = Field(
        "celsius", description="Temperature unit", enum=["celsius", "fahrenheit"]
    )


def get_current_weather(weather: WeatherParams):
    """
    Get the current weather in a given location
    """
    return {
        "abc": f"Getting the current weather in {weather.location} in {weather.unit}."
    }


def get_images():
    return [PIL_Image.open("assets/kirk.jpg"), PIL_Image.open("assets/spock.jpg")]


def main():
    print("### SYNC AI ###")
    ai = AI(
        model_name="gemini-1.5-flash",
        system="You are Spock, from Star Trek.",
        max_new_tokens=128,
        tools=[get_current_weather],
    )
    response = ai("How are you?")
    print(f"{response=}")
    response = ai("Who is on the images?", images=get_images())
    print(f"{response=}")
    # print(f"{ai.chat=}")


async def async_ai_main():
    print("### ASYNC AI ###")
    ai = AsyncAI(
        model_name="gemini-1.5-flash",
        system="You are Spock, from Star Trek.",
        max_new_tokens=128,
        tools=[get_current_weather],
    )
    response = await ai("How are you?")
    print(f"{response=}")
    response = await ai("Who is on the images?", images=get_images())
    print(f"{response=}")
    # print(f"{ai.chat=}")
os.environ["GOOGLE_API_KEY"] = None | ||
main() | ||
asyncio.run(async_ai_main()) |
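A note on credentials: the Gemini wrapper added in this commit reads the key from the GOOGLE_API_KEY environment variable via genai.configure, so rather than hard-coding a value in __main__, the key can be exported beforehand and checked at startup. A minimal sketch, not part of the commit (the error message is illustrative):

import os

# Fail fast if the key is missing; the wrapper's
# genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) would otherwise raise KeyError.
if "GOOGLE_API_KEY" not in os.environ:
    raise RuntimeError("Export GOOGLE_API_KEY before running the Gemini example.")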
Large diffs are not rendered by default.
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tiny-ai-client"
-version = "0.0.3"
+version = "0.0.4"
 description = "Tiny AI client for LLMs. As simple as it gets."
 authors = ["piEsposito <[email protected]>"]
 license = "Apache 2.0"
@@ -13,6 +13,12 @@ openai = "1.31.0"
 anthropic = "0.28.0"

+[tool.poetry.group.gemini]
+optional = true
+
+[tool.poetry.group.gemini.dependencies]
+google-generativeai = "^0.6.0"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
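Because the gemini group is marked optional = true, google-generativeai is not installed by default. With Poetry 1.2+, optional groups are pulled in explicitly; a usage sketch (this command is not part of the commit):

poetry install --with gemini

Installing google-generativeai directly with pip alongside tiny-ai-client should satisfy the same import.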
@@ -0,0 +1,118 @@
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Union

import google.generativeai as genai

from tiny_ai_client.models import LLMClientWrapper, Message
from tiny_ai_client.tools import function_to_json


class GeminiClientWrapper(LLMClientWrapper):
    def __init__(self, model_name: str, tools: List[Union[Callable, Dict]]):
        self.model_name = model_name
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        self.tools = tools
        self.tools_json = [function_to_json(tool) for tool in tools]
        if len(self.tools_json) > 0:
            raise ValueError("Gemini does not support tools")
        # self.tools_obj = [
        #     genai.types.FunctionDeclaration(
        #         name=tool["function"]["name"],
        #         description=tool["function"]["description"],
        #         parameters=genai.types.content_types.strip_titles(
        #             tool["function"]["parameters"]
        #         ),
        #     )
        #     for tool in self.tools_json
        # ]

    def build_model_input(self, messages: List["Message"]) -> Any:
        history = []
        local_messages = deepcopy(messages)
        system = None

        for message in local_messages:
            if message.role == "system":
                system = message.text
                continue
            if message.role not in ["user", "assistant"]:
                raise ValueError(f"Invalid role for Gemini: {message.role}")
            # Gemini labels model turns "model" where other providers use "assistant".
            role = "user" if message.role == "user" else "model"
            parts = []
            if message.text is not None:
                parts.append(message.text)
            if message.images is not None:
                parts.extend(message.images)
            history.append(
                {
                    "role": role,
                    "parts": parts,
                }
            )

        return (system, history)

    def call_llm_provider(
        self,
        model_input: Any,
        temperature: int | None,
        max_new_tokens: int | None,
        timeout: int,
    ) -> Message:
        system, history = model_input

        generation_config_kwargs = {}
        if temperature is not None:
            generation_config_kwargs["temperature"] = temperature
        if max_new_tokens is not None:
            generation_config_kwargs["max_output_tokens"] = max_new_tokens

        generation_config = genai.GenerationConfig(**generation_config_kwargs)
        model = genai.GenerativeModel(
            self.model_name,
            system_instruction=system,
            generation_config=generation_config,
        )

        response = model.generate_content(history)
        part = response.candidates[0].content.parts[0]
        if part.function_call.name != "":
            raise ValueError("Function calls are not supported in Gemini")
        elif part.text is not None:
            return Message(role="assistant", text=part.text)
        raise ValueError("Invalid response from Gemini")

    async def async_call_llm_provider(
        self,
        model_input: Any,
        temperature: int | None,
        max_new_tokens: int | None,
        timeout: int,
    ) -> Message:
        system, history = model_input

        generation_config_kwargs = {}
        if temperature is not None:
            generation_config_kwargs["temperature"] = temperature
        if max_new_tokens is not None:
            generation_config_kwargs["max_output_tokens"] = max_new_tokens

        generation_config = genai.GenerationConfig(**generation_config_kwargs)
        model = genai.GenerativeModel(
            self.model_name,
            system_instruction=system,
            generation_config=generation_config,
        )

        response = await model.generate_content_async(history)
        part = response.candidates[0].content.parts[0]
        if part.function_call.name != "":
            raise ValueError("Function calls are not supported in Gemini")
        elif part.text is not None:
            return Message(role="assistant", text=part.text)
        raise ValueError("Invalid response from Gemini")