Skip to content

Commit

Permalink
Semantic retriever
Browse files Browse the repository at this point in the history
Change-Id: I5c4f35238f3bc0bbc798abd72cf824ebe1103152
  • Loading branch information
shilpakancharla committed Jan 12, 2024
1 parent 7abbdf3 commit 2824827
Show file tree
Hide file tree
Showing 7 changed files with 1,575 additions and 13 deletions.
8 changes: 7 additions & 1 deletion google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ def get_default_operations_client(self) -> operations_v1.OperationsClient:
model_client = self.get_default_client("Model")
client = model_client._transport.operations_client
self.clients["operations"] = client

return client


# Module-level manager instance that the `get_default_*_client` helpers in this
# file delegate to.
_client_manager = _ClientManager()


def configure(
*,
api_key: str | None = None,
Expand Down Expand Up @@ -244,3 +246,7 @@ def get_default_operations_client() -> operations_v1.OperationsClient:

def get_default_model_client() -> glm.ModelServiceAsyncClient:
    """Return the client that the module-level manager provides for "model".

    NOTE(review): the return annotation names the async client class; confirm
    that the manager's "model" entry really is of that type.
    """
    model_client = _client_manager.get_default_client("model")
    return model_client


def get_default_retriever_client() -> glm.RetrieverServiceClient:
    """Return the client that the module-level manager provides for "retriever".

    The return annotation uses `glm.RetrieverServiceClient`, the same type the
    retriever functions use for their `client` parameter.
    """
    return _client_manager.get_default_client("retriever")
13 changes: 1 addition & 12 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import operation
from google.api_core import protobuf_helpers
from google.protobuf import field_mask_pb2
from google.generativeai.utils import _flatten_update_paths


def get_model(
Expand Down Expand Up @@ -379,18 +380,6 @@ def update_tuned_model(
return model_types.decode_tuned_model(result)


def _flatten_update_paths(updates):
new_updates = {}
for key, value in updates.items():
if isinstance(value, dict):
for sub_key, sub_value in _flatten_update_paths(value).items():
new_updates[f"{key}.{sub_key}"] = sub_value
else:
new_updates[key] = value

return new_updates


def _apply_update(thing, path, value):
parts = path.split(".")
for part in parts[:-1]:
Expand Down
156 changes: 156 additions & 0 deletions google/generativeai/retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import string
import dataclasses
from collections.abc import Iterable, Sequence
from typing import Iterable, overload, TypeVar, Union, Mapping, Optional

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_retriever_client
from google.generativeai import string_utils
from google.generativeai.types import retriever_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types
from google.generativeai.types.model_types import idecode_time

_CORPORA_NAME_REGEX = re.compile(r"^corpora/([^/]+?)")
_REMOVE = string.punctuation
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Corpus(retriever_types.Corpus):
    """A `retriever_types.Corpus` whose attributes are set from keyword arguments.

    After construction, `result` holds the corpus `name` when that name is
    truthy and `None` otherwise.
    """

    def __init__(self, **kwargs):
        # Copy every supplied field onto the instance as-is.
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

        self.result = self.name if self.name else None


def create_corpus(
    name: Optional[str] = None,
    display_name: Optional[str] = None,
    client: glm.RetrieverServiceClient | None = None,
) -> Corpus:
    """
    Create a Corpus object. Users can specify either a name or display_name. Users can
    create up to 5 corpora.
    Args:
        name: The corpus resource name (ID). The name must be alphanumeric and fewer
            than 40 characters. A name without the "corpora/" prefix has its
            punctuation (except hyphens) removed and the prefix added.
        display_name: The human readable display name. The display name must be fewer
            than 128 characters. All characters, including alphanumeric, spaces, and
            dashes are supported.
        client: The retriever client to use. The default retriever client is used
            when omitted.
    Return:
        Corpus object with specified name or display name.
    Raises:
        ValueError: When neither the name nor the display name is specified, or when
            the name contains "corpora/" but does not start with it.
    """
    if client is None:
        client = get_default_retriever_client()

    if not name and not display_name:
        raise ValueError("Either the corpus name or display name must be specified.")

    if not name:
        # Only a display name was supplied. Still send a Corpus carrying it, so the
        # display name is not silently dropped from the request.
        corpus = glm.Corpus(display_name=display_name)
    elif re.match(_CORPORA_NAME_REGEX, name):
        # Already a resource name: pass it through untouched.
        corpus = glm.Corpus(name=name, display_name=display_name)
    elif "corpora/" not in name:
        # Bare ID: strip disallowed punctuation and add the resource prefix.
        corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
        corpus = glm.Corpus(name=corpus_name, display_name=display_name)
    else:
        raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")

    request = glm.CreateCorpusRequest(corpus=corpus)
    response = client.create_corpus(request)
    response = type(response).to_dict(response)
    # Convert the timestamp fields of the response dict in place.
    idecode_time(response, "create_time")
    idecode_time(response, "update_time")
    return Corpus(**response)


def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus:
    """
    Fetch a single `Corpus` by name.
    Args:
        name: The `Corpus` name.
        client: The retriever client to use. The default retriever client is used
            when omitted.
    Return:
        The requested `Corpus`, built from the fields of the service response.
    """
    retriever = get_default_retriever_client() if client is None else client

    fetched = retriever.get_corpus(glm.GetCorpusRequest(name=name))
    fields = type(fetched).to_dict(fetched)
    # Convert both timestamp fields of the response dict in place.
    for stamp in ("create_time", "update_time"):
        idecode_time(fields, stamp)
    return Corpus(**fields)


def delete_corpus(
    name: str,
    force: bool = False,
    client: glm.RetrieverServiceClient | None = None,
) -> None:
    """
    Delete a `Corpus`.
    Args:
        name: The `Corpus` name.
        force: If set to true, any `Document`s and objects related to this `Corpus` will
            also be deleted. Defaults to False.
        client: The retriever client to use. The default retriever client is used
            when omitted.
    """
    if client is None:
        client = get_default_retriever_client()

    request = glm.DeleteCorpusRequest(name=name, force=force)
    # The result of the delete call is not used, so it is not bound to a name.
    client.delete_corpus(request)


def list_corpora(
    *,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
    client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
    """
    List the available `Corpora`.
    Args:
        page_size: Maximum number of `Corpora` to request.
        page_token: A page token, received from a previous ListCorpora call.
        client: The retriever client to use. The default retriever client is used
            when omitted.
    Return:
        The result of the client's `list_corpora` call, passed back unmodified.
        NOTE(review): the annotation promises `list[Corpus]`, but no conversion to
        `Corpus` happens here; confirm which of the two is intended.
    """
    retriever = get_default_retriever_client() if client is None else client

    listing_request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token)
    return retriever.list_corpora(listing_request)
31 changes: 31 additions & 0 deletions google/generativeai/types/embedding_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
import dataclasses
from typing import Any, Dict, List, TypedDict

from google.generativeai import string_utils
from google.generativeai.types import safety_types
from google.generativeai.types import citation_types


class EmbeddingDict(TypedDict):
    """Typed dict with a single `embedding` key holding a list of floats."""

    embedding: list[float]


class BatchEmbeddingDict(TypedDict):
    """Typed dict with a single `embedding` key holding a list of float lists."""

    # Presumably one inner list per embedded input; confirm against the producer.
    embedding: list[list[float]]
Loading

0 comments on commit 2824827

Please sign in to comment.