gemini with vision but no tools
piEsposito committed Jun 5, 2024
1 parent dfaafe6 commit ce2f778
Showing 7 changed files with 1,375 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -14,10 +14,11 @@ Features:
- Structured output
- Vision
- PyPI package `tiny-ai-client`
+- Gemini (vision, no tools)


Roadmap:
-- Gemini
+- Gemini tools

## Simple
`tiny-ai-client` is simple and intuitive:
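A minimal sketch of the new Gemini path (model name and constructor arguments taken from examples/gemini_.py in this commit; GOOGLE_API_KEY must be set in the environment):

from tiny_ai_client import AI

ai = AI(model_name="gemini-1.5-flash", system="You are Spock, from Star Trek.")
print(ai("How are you?"))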
63 changes: 63 additions & 0 deletions examples/gemini_.py
@@ -0,0 +1,63 @@
import asyncio
import os

from PIL import Image as PIL_Image
from pydantic import BaseModel, Field

from tiny_ai_client import AI, AsyncAI


class WeatherParams(BaseModel):
    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
    unit: str = Field(
        "celsius", description="Temperature unit", enum=["celsius", "fahrenheit"]
    )


# Tool definition kept for when Gemini tool support lands; not passed below.
def get_current_weather(weather: WeatherParams):
    """
    Get the current weather in a given location
    """
    return {
        "abc": f"Getting the current weather in {weather.location} in {weather.unit}."
    }


def get_images():
    return [PIL_Image.open("assets/kirk.jpg"), PIL_Image.open("assets/spock.jpg")]


def main():
    print("### SYNC AI ###")
    ai = AI(
        model_name="gemini-1.5-flash",
        system="You are Spock, from Star Trek.",
        max_new_tokens=128,
        tools=[],  # Gemini tools are not supported yet; passing any raises in the wrapper
    )
    response = ai("How are you?")
    print(f"{response=}")
    response = ai("Who is on the images?", images=get_images())
    print(f"{response=}")
    # print(f"{ai.chat=}")


async def async_ai_main():
    print("### ASYNC AI ###")
    ai = AsyncAI(
        model_name="gemini-1.5-flash",
        system="You are Spock, from Star Trek.",
        max_new_tokens=128,
        tools=[],  # Gemini tools are not supported yet; passing any raises in the wrapper
    )
    response = await ai("How are you?")
    print(f"{response=}")
    response = await ai("Who is on the images?", images=get_images())
    print(f"{response=}")
    # print(f"{ai.chat=}")


if __name__ == "__main__":
    # os.environ values must be strings; export the key in your shell instead.
    assert os.environ.get("GOOGLE_API_KEY"), "Set GOOGLE_API_KEY before running"
    main()
    asyncio.run(async_ai_main())
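To try the example, export a real GOOGLE_API_KEY in your shell rather than hardcoding it, then run it from the repository root (assumed invocation): `python examples/gemini_.py`. The images it opens, assets/kirk.jpg and assets/spock.jpg, must exist at those paths.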
1,179 changes: 1,179 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tiny-ai-client"
-version = "0.0.3"
+version = "0.0.4"
description = "Tiny AI client for LLMs. As simple as it gets."
authors = ["piEsposito <[email protected]>"]
license = "Apache 2.0"
@@ -13,6 +13,12 @@ openai = "1.31.0"
anthropic = "0.28.0"


[tool.poetry.group.gemini]
optional = true

[tool.poetry.group.gemini.dependencies]
google-generativeai = "^0.6.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
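Because the gemini group is marked optional, google-generativeai is not installed by default; with Poetry's dependency-group syntax the install command would presumably be `poetry install --with gemini`.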
118 changes: 118 additions & 0 deletions tiny_ai_client/gemini_.py
@@ -0,0 +1,118 @@
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Union

import google.generativeai as genai

from tiny_ai_client.models import LLMClientWrapper, Message
from tiny_ai_client.tools import function_to_json


class GeminiClientWrapper(LLMClientWrapper):
    def __init__(self, model_name: str, tools: List[Union[Callable, Dict]]):
        self.model_name = model_name
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
        self.tools = tools
        # Tool schemas are built eagerly, but Gemini tool calling is not wired up yet.
        self.tools_json = [function_to_json(tool) for tool in tools]
        if len(self.tools_json) > 0:
            raise ValueError("Gemini does not support tools")
        # self.tools_obj = [
        #     genai.types.FunctionDeclaration(
        #         name=tool["function"]["name"],
        #         description=tool["function"]["description"],
        #         parameters=genai.types.content_types.strip_titles(
        #             tool["function"]["parameters"]
        #         ),
        #     )
        #     for tool in self.tools_json
        # ]

    def build_model_input(self, messages: List["Message"]) -> Any:
        # Gemini takes the system prompt separately, so split it out of the history.
        history = []
        local_messages = deepcopy(messages)
        system = None

        for message in local_messages:
            if message.role == "system":
                system = message.text
                continue
            if message.role not in ["user", "assistant"]:
                raise ValueError(f"Invalid role for Gemini: {message.role}")
            role = "user" if message.role == "user" else "model"
            parts = []
            if message.text is not None:
                parts.append(message.text)
            if message.images is not None:
                parts.extend(message.images)
            history.append(
                {
                    "role": role,
                    "parts": parts,
                }
            )

        return (system, history)

    def call_llm_provider(
        self,
        model_input: Any,
        temperature: float | None,
        max_new_tokens: int | None,
        timeout: int,  # currently unused for Gemini
    ) -> "Message":
        system, history = model_input

        generation_config_kwargs = {}
        if temperature is not None:
            generation_config_kwargs["temperature"] = temperature
        if max_new_tokens is not None:
            generation_config_kwargs["max_output_tokens"] = max_new_tokens

        generation_config = genai.GenerationConfig(**generation_config_kwargs)
        model = genai.GenerativeModel(
            self.model_name,
            system_instruction=system,
            generation_config=generation_config,
        )

        response = model.generate_content(history)
        part = response.candidates[0].content.parts[0]
        if part.function_call.name != "":
            raise ValueError("Function calls are not supported in Gemini")
        elif part.text is not None:
            return Message(role="assistant", text=part.text)
        raise ValueError("Invalid response from Gemini")

    async def async_call_llm_provider(
        self,
        model_input: Any,
        temperature: float | None,
        max_new_tokens: int | None,
        timeout: int,  # currently unused for Gemini
    ) -> "Message":
        system, history = model_input

        generation_config_kwargs = {}
        if temperature is not None:
            generation_config_kwargs["temperature"] = temperature
        if max_new_tokens is not None:
            generation_config_kwargs["max_output_tokens"] = max_new_tokens

        generation_config = genai.GenerationConfig(**generation_config_kwargs)
        model = genai.GenerativeModel(
            self.model_name,
            system_instruction=system,
            generation_config=generation_config,
        )

        response = await model.generate_content_async(history)
        part = response.candidates[0].content.parts[0]
        if part.function_call.name != "":
            raise ValueError("Function calls are not supported in Gemini")
        elif part.text is not None:
            return Message(role="assistant", text=part.text)
        raise ValueError("Invalid response from Gemini")
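For reference, a sketch of the (system, history) pair that build_model_input produces, using hypothetical messages (Message fields as used above; constructing the wrapper requires GOOGLE_API_KEY since __init__ calls genai.configure):

wrapper = GeminiClientWrapper("gemini-1.5-flash", tools=[])
system, history = wrapper.build_model_input(
    [
        Message(role="system", text="You are Spock."),
        Message(role="user", text="How are you?"),
    ]
)
# system == "You are Spock."
# history == [{"role": "user", "parts": ["How are you?"]}]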
5 changes: 5 additions & 0 deletions tiny_ai_client/models.py
@@ -75,6 +75,11 @@ def get_llm_client_wrapper(

            return AnthropicClientWrapper(model_name, tools)

        if "gemini" in model_name:
            from tiny_ai_client.gemini_ import GeminiClientWrapper

            return GeminiClientWrapper(model_name, tools)

        raise NotImplementedError(f"{model_name=} not supported")

    @property
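Dispatch is a substring match on the model name, so any model id containing "gemini" gets the new wrapper; e.g. `AI(model_name="gemini-1.5-flash", ...)` from the example above ends up constructing a GeminiClientWrapper.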
2 changes: 1 addition & 1 deletion tiny_ai_client/tools.py
@@ -24,7 +24,7 @@ def function_to_json(func, parameters_key="parameters") -> dict:
            parameters_key: {
                "type": "object",
                "properties": model_schema["properties"],
-                "required": model_schema.get("required", []),
+                "required": model_schema.get("required", True),
            },
        },
    }
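The changed default means that when a model's JSON schema carries no "required" key, function_to_json now reports True instead of an empty list. A hypothetical sketch of the output for the WeatherParams tool from the example above (keys beyond those visible in this hunk are assumptions):

function_to_json(get_current_weather)
# ...
# "parameters": {
#     "type": "object",
#     "properties": {"location": {...}, "unit": {...}},
#     "required": ["location"],  # True would appear only if the schema omitted "required"
# }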
