Semantic retriever (#168)
* Semantic retriever

Change-Id: I5c4f35238f3bc0bbc798abd72cf824ebe1103152

* Adding parameter to re.sub for create_chunk function

* attempting to fix pytype error with ChunkData and CustomMetadata

* Adding async to semantic retriever functions

* Update _flatten to flatten

Co-authored-by: Mark Daoust <[email protected]>

* Update google/generativeai/types/retriever_types.py

Co-authored-by: Mark Daoust <[email protected]>

* Update google/generativeai/client.py

Co-authored-by: Mark Daoust <[email protected]>

* Update google/generativeai/models.py

Co-authored-by: Mark Daoust <[email protected]>

* Resolving Github precheck failures

* Changed .data to .string_value

* Update _flatten_update_paths to flatten_update_paths

* Updating async test cases for retriever

* Added in client methods in async retriever test

* Added all async test cases

* Fixed all test cases locally

* Updated async retriever tests

* Fixing names in async test cases

* Fixed async method for QueryCorpus

* Added await statements

* Reformatted file

* Added async methods to test cases

* Updated regex statements and removed redundancy from elif statements

* Updates to create_chunk

* Async code test update, dataclass updates

* Skipping format check to resolve errors

* Modified gitignore

* Update to delete_document

* Formatting check

---------

Co-authored-by: Mark Daoust <[email protected]>
shilpakancharla and MarkDaoust authored Jan 30, 2024
1 parent 1e67f3d commit 6fe8c12
Showing 9 changed files with 2,855 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -7,4 +7,4 @@
*.egg-info
.DS_Store
__pycache__
*.iml
*.iml
9 changes: 8 additions & 1 deletion google/generativeai/client.py
@@ -168,7 +168,6 @@ def get_default_operations_client(self) -> operations_v1.OperationsClient:
model_client = self.get_default_client("Model")
client = model_client._transport.operations_client
self.clients["operations"] = client

return client


@@ -244,3 +243,11 @@ def get_default_operations_client() -> operations_v1.OperationsClient:

def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_client("model")


def get_default_retriever_client() -> glm.RetrieverClient:
return _client_manager.get_default_client("retriever")


def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
return _client_manager.get_default_client("retriever_async")
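As an aside, a minimal sketch of how these new helpers could be exercised; it assumes the API key has already been set up via `genai.configure`, which is outside this diff:

from google.generativeai import client as genai_client

# Both helpers delegate to the shared _client_manager, mirroring the existing
# get_default_model_client pattern, so repeated calls reuse the cached clients.
retriever_client = genai_client.get_default_retriever_client()
retriever_async_client = genai_client.get_default_retriever_async_client()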
15 changes: 2 additions & 13 deletions google/generativeai/models.py
@@ -24,6 +24,7 @@
from google.api_core import operation
from google.api_core import protobuf_helpers
from google.protobuf import field_mask_pb2
from google.generativeai.utils import flatten_update_paths


def get_model(
@@ -351,7 +352,7 @@ def update_tuned_model(
)
tuned_model = client.get_tuned_model(name=name)

updates = _flatten_update_paths(updates)
updates = flatten_update_paths(updates)
field_mask = field_mask_pb2.FieldMask()
for path in updates.keys():
field_mask.paths.append(path)
@@ -379,18 +380,6 @@ def update_tuned_model(
return model_types.decode_tuned_model(result)


def _flatten_update_paths(updates):
new_updates = {}
for key, value in updates.items():
if isinstance(value, dict):
for sub_key, sub_value in _flatten_update_paths(value).items():
new_updates[f"{key}.{sub_key}"] = sub_value
else:
new_updates[key] = value

return new_updates


def _apply_update(thing, path, value):
parts = path.split(".")
for part in parts[:-1]:
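For context, `flatten_update_paths` (relocated to `google.generativeai.utils`) turns a nested update dict into dotted field-mask paths, exactly as the removed `_flatten_update_paths` did. A small illustration, with a made-up nested update:

from google.generativeai.utils import flatten_update_paths

updates = {"tuning_task": {"hyperparameters": {"learning_rate": 0.2}}}
# Produces {"tuning_task.hyperparameters.learning_rate": 0.2}, which is then
# used to populate a FieldMask before the update request is sent.
flat_updates = flatten_update_paths(updates)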
224 changes: 224 additions & 0 deletions google/generativeai/retriever.py
@@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import string
import dataclasses
from typing import Optional

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_retriever_client
from google.generativeai.client import get_default_retriever_async_client
from google.generativeai import string_utils
from google.generativeai.types import retriever_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types
from google.generativeai.types.model_types import idecode_time

_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+")
_REMOVE = string.punctuation
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Corpus(retriever_types.Corpus):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

self.result = None
if self.name:
self.result = self.name


def create_corpus(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> Corpus:
"""
Create a Corpus object. Users can specify either a name or display_name.
Args:
name: The corpus resource name (ID). The name must be alphanumeric and fewer
than 40 characters.
display_name: The human readable display name. The display name must be fewer
than 128 characters. All characters, including alphanumeric, spaces, and
dashes are supported.
Return:
Corpus object with specified name or display name.
Raises:
ValueError: When the name is not specified or formatted incorrectly.
"""
if client is None:
client = get_default_retriever_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")

request = glm.CreateCorpusRequest(corpus=corpus)
response = client.create_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


async def create_corpus_async(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
) -> Corpus:
"""This is the async version of `create_corpus`."""
if client is None:
client = get_default_retriever_async_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")

request = glm.CreateCorpusRequest(corpus=corpus)
response = await client.create_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus: # fmt: skip
"""
Get information about a specific `Corpus`.
Args:
name: The `Corpus` name.
Return:
`Corpus` of interest.
"""
if client is None:
client = get_default_retriever_client()

request = glm.GetCorpusRequest(name=name)
response = client.get_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> Corpus: # fmt: skip
"""This is the async version of `get_corpus`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.GetCorpusRequest(name=name)
response = await client.get_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip
"""
Delete a `Corpus`.
Args:
name: The `Corpus` name.
force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted.
"""
if client is None:
client = get_default_retriever_client()

request = glm.DeleteCorpusRequest(name=name, force=force)
client.delete_corpus(request)


async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
"""This is the async version of `delete_corpus`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.DeleteCorpusRequest(name=name, force=force)
await client.delete_corpus(request)


def list_corpora(
*,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
"""
List `Corpus`.
Args:
page_size: Maximum number of `Corpora` to request.
page_token: A page token, received from a previous ListCorpora call.
Return:
Paginated list of `Corpora`.
"""
if client is None:
client = get_default_retriever_client()

request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token)
response = client.list_corpora(request)
return response


async def list_corpora_async(
*,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
"""This is the async version of `list_corpora`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token)
response = await client.list_corpora(request)
return response
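Putting the new module together, an end-to-end sketch based only on the signatures above (the corpus name and display name are illustrative):

from google.generativeai import retriever

# "demo-corpus" does not start with "corpora/", so create_corpus strips
# punctuation (keeping hyphens) and prefixes it, yielding "corpora/demo-corpus".
corpus = retriever.create_corpus(name="demo-corpus", display_name="Demo Corpus")

# Fetch the corpus back by its resource name and list corpora page by page.
same_corpus = retriever.get_corpus(name=corpus.name)
corpora_page = retriever.list_corpora(page_size=10)

# Delete the corpus along with any documents it still contains.
retriever.delete_corpus(name=corpus.name, force=True)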
27 changes: 27 additions & 0 deletions google/generativeai/types/embedding_types.py
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
import dataclasses
from typing import Any, Dict, List, TypedDict


class EmbeddingDict(TypedDict):
embedding: list[float]


class BatchEmbeddingDict(TypedDict):
embedding: list[list[float]]
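For illustration, the two TypedDicts describe single and batched embedding payloads; the vectors below are made up:

from google.generativeai.types.embedding_types import EmbeddingDict, BatchEmbeddingDict

single: EmbeddingDict = {"embedding": [0.1, -0.2, 0.3]}
batch: BatchEmbeddingDict = {"embedding": [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]}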
