Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge branch 'main' into imagen #9

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion google/generativeai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
from google.generativeai.generative_models import GenerativeModel
from google.generativeai.generative_models import ChatSession

from google.generativeai.vision_models import *

from google.generativeai.models import list_models
from google.generativeai.models import list_tuned_models

Expand All @@ -77,7 +79,6 @@

__version__ = version.__version__

del embedding
del files
del generative_models
del models
Expand Down
103 changes: 93 additions & 10 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import os
import contextlib
import inspect
import collections
import dataclasses
import pathlib
import threading
from typing import Any, cast
from collections.abc import Sequence
from collections.abc import Sequence, Mapping
import httplib2
from io import IOBase

Expand All @@ -24,6 +25,11 @@
import googleapiclient.http
import googleapiclient.discovery

from google.protobuf import struct_pb2

from proto.marshal.collections import maps
from proto.marshal.collections import repeated

try:
from google.generativeai import version

Expand Down Expand Up @@ -132,6 +138,70 @@ async def create_file(self, *args, **kwargs):
)


# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
def to_value(value) -> struct_pb2.Value:
"""Return a protobuf Value object representing this value."""
if isinstance(value, struct_pb2.Value):
return value
if value is None:
return struct_pb2.Value(null_value=0)
if isinstance(value, bool):
return struct_pb2.Value(bool_value=value)
if isinstance(value, (int, float)):
return struct_pb2.Value(number_value=float(value))
if isinstance(value, str):
return struct_pb2.Value(string_value=value)
if isinstance(value, collections.abc.Sequence):
return struct_pb2.Value(list_value=to_list_value(value))
if isinstance(value, collections.abc.Mapping):
return struct_pb2.Value(struct_value=to_mapping_value(value))
raise ValueError("Unable to coerce value: %r" % value)


def to_list_value(value) -> struct_pb2.ListValue:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.ListValue):
return value
if isinstance(value, repeated.RepeatedComposite):
return struct_pb2.ListValue(values=[v for v in value.pb])

# We got a list (or something list-like); convert it.
return struct_pb2.ListValue(values=[to_value(v) for v in value])


def to_mapping_value(value) -> struct_pb2.Struct:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.Struct):
return value
if isinstance(value, maps.MapComposite):
return struct_pb2.Struct(
fields={k: v for k, v in value.pb.items()},
)

# We got a dict (or something dict-like); convert it.
return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()})


class PredictionServiceClient(glm.PredictionServiceClient):
def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return super().predict(request)


class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient):
async def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return await super().predict(request)


@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
Expand Down Expand Up @@ -222,15 +292,20 @@ def configure(
self.clients = {}

def make_client(self, name):
if name == "file":
cls = FileServiceClient
elif name == "file_async":
cls = FileServiceAsyncClient
elif name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
local_clients = {
"file": FileServiceClient,
"file_async": FileServiceAsyncClient,
"prediction": PredictionServiceClient,
"prediction_async": PredictionServiceAsyncClient,
}
cls = local_clients.get(name, None)

if cls is None:
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")

# Attempt to configure using defaults.
if not self.client_config:
Expand Down Expand Up @@ -386,3 +461,11 @@ def get_default_permission_client() -> glm.PermissionServiceClient:

def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
return _client_manager.get_default_client("permission_async")


def get_default_prediction_client() -> glm.PermissionServiceClient:
return _client_manager.get_default_client("prediction")


def get_default_prediction_async_client() -> glm.PermissionServiceAsyncClient:
return _client_manager.get_default_client("prediction_async")
2 changes: 1 addition & 1 deletion google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import Iterable
import textwrap
from typing import Any, Union, overload
from typing import Any, Literal, Union, overload
import reprlib

# pylint: disable=bad-continuation, line-too-long
Expand Down
95 changes: 4 additions & 91 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,16 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
import io
import inspect
import mimetypes
import pathlib
import typing
from typing import Any, Callable, Union
from typing_extensions import TypedDict

import pydantic

from google.generativeai.types import file_types
from google.generativeai.types.image_types import _image_types
from google.generativeai import protos

if typing.TYPE_CHECKING:
import PIL.Image
import PIL.ImageFile
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
else:
IMAGE_TYPES = ()
try:
import PIL.Image
import PIL.ImageFile

IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
except ImportError:
PIL = None

try:
import IPython.display

IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
except ImportError:
IPython = None


__all__ = [
"BlobDict",
Expand Down Expand Up @@ -94,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode:
return _MODE[x]


def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
# If the image is a local file, return a file-based blob without any modification.
# Otherwise, return a lossless WebP blob (same quality with optimized size).
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
return None
filename = str(image.filename)
if not pathlib.Path(filename).is_file():
return None

mime_type = image.get_format_mimetype()
image_bytes = pathlib.Path(filename).read_bytes()

return protos.Blob(mime_type=mime_type, data=image_bytes)

def webp_blob(image: PIL.Image.Image) -> protos.Blob:
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
image_io = io.BytesIO()
image.save(image_io, format="webp", lossless=True)
image_io.seek(0)

mime_type = "image/webp"
image_bytes = image_io.read()

return protos.Blob(mime_type=mime_type, data=image_bytes)

return file_blob(image) or webp_blob(image)


def image_to_blob(image) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return _pil_to_blob(image)

if IPython is not None:
if isinstance(image, IPython.display.Image):
name = image.filename
if name is None:
raise ValueError(
"Conversion failed. The `IPython.display.Image` can only be converted if "
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
)
mime_type, _ = mimetypes.guess_type(name)
if mime_type is None:
mime_type = "image/unknown"

return protos.Blob(mime_type=mime_type, data=image.data)

raise TypeError(
"Image conversion failed. The input was expected to be of type `Image` "
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
f"However, received an object of type: {type(image)}.\n"
f"Object Value: {image}"
)


class BlobDict(TypedDict):
mime_type: str
data: bytes
Expand Down Expand Up @@ -186,12 +104,7 @@ def is_blob_dict(d):
return "mime_type" in d and "data" in d


if typing.TYPE_CHECKING:
BlobType = Union[
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
] # Any for the images
else:
BlobType = Union[protos.Blob, BlobDict, Any]
BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images


def to_blob(blob: BlobType) -> protos.Blob:
Expand All @@ -200,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob:

if isinstance(blob, protos.Blob):
return blob
elif isinstance(blob, IMAGE_TYPES):
return image_to_blob(blob)
elif isinstance(blob, _image_types.IMAGE_TYPES):
return _image_types.image_to_blob(blob)
else:
if isinstance(blob, Mapping):
raise KeyError(
Expand Down
1 change: 1 addition & 0 deletions google/generativeai/types/image_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from google.generativeai.types.image_types._image_types import *
Loading
Loading