Commit

remove retriever.Corpus class in favor of retriever_types.Corpus class
MarkDaoust committed Jan 30, 2024
1 parent 6fe8c12 commit 77581c2
Showing 5 changed files with 38 additions and 37 deletions.
6 changes: 4 additions & 2 deletions google/generativeai/embedding.py
@@ -93,7 +93,8 @@ def embed_content(
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
client: glm.GenerativeServiceClient | None = None,
) -> text_types.EmbeddingDict: ...
) -> text_types.EmbeddingDict:
...


@overload
@@ -103,7 +104,8 @@ def embed_content(
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
client: glm.GenerativeServiceClient | None = None,
) -> text_types.BatchEmbeddingDict: ...
) -> text_types.BatchEmbeddingDict:
...


def embed_content(
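For readers skimming the hunk above: the only change in this file is that the `...` bodies of the `@overload` stubs move onto their own line. A minimal, generic sketch (not code from this repository; the `embed` name is made up) showing why the two stub styles are interchangeable: `...` is just a placeholder body for the type checker, and only the undecorated definition runs.

# Generic illustration, not code from this repository: both @overload stub
# styles seen in the diff are equivalent. The decorated stubs exist only for
# type checkers; the undecorated definition below is what actually executes.
from typing import Sequence, overload


@overload
def embed(text: str) -> list[float]: ...  # stub body on the same line


@overload
def embed(text: Sequence[str]) -> list[list[float]]:
    ...  # same stub, body on its own line


def embed(text):
    # Real implementation dispatches on the argument type.
    if isinstance(text, str):
        return [0.0]
    return [[0.0] for _ in text]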
6 changes: 4 additions & 2 deletions google/generativeai/notebook/lib/llmfn_outputs.py
@@ -109,10 +109,12 @@ def __len__(self) -> int:

# Needed for Sequence[LLMFnOutputEntry].
@overload
def __getitem__(self, x: int) -> LLMFnOutputEntry: ...
def __getitem__(self, x: int) -> LLMFnOutputEntry:
...

@overload
def __getitem__(self, x: slice) -> Sequence[LLMFnOutputEntry]: ...
def __getitem__(self, x: slice) -> Sequence[LLMFnOutputEntry]:
...

def __getitem__(self, x: int | slice) -> LLMFnOutputEntry | Sequence[LLMFnOutputEntry]:
return self._outputs.__getitem__(x)
4 changes: 3 additions & 1 deletion google/generativeai/notebook/magics_engine.py
@@ -61,7 +61,9 @@ def parse_line(
) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
return cmd_line_parser.CmdLineParser().parse_line(line, placeholders)

def _get_handler(self, line: str, placeholders: AbstractSet[str]) -> tuple[
def _get_handler(
self, line: str, placeholders: AbstractSet[str]
) -> tuple[
command.Command,
parsed_args_lib.ParsedArgs,
Sequence[post_process_utils.ParsedPostProcessExpr],
53 changes: 23 additions & 30 deletions google/generativeai/retriever.py
@@ -36,25 +36,18 @@
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Corpus(retriever_types.Corpus):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

self.result = None
if self.name:
self.result = self.name
retriever_types.Corpus


def create_corpus(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> Corpus:
) -> retriever_types.Corpus:
"""
Create a Corpus object. Users can specify either a name or display_name.
Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance.
Users can specify either a name or display_name.
Args:
name: The corpus resource name (ID). The name must be alphanumeric and fewer
@@ -64,7 +57,7 @@ def create_corpus(
dashes are supported.
Return:
Corpus object with specified name or display name.
`retriever_types.Corpus` object with specified name or display name.
Raises:
ValueError: When the name is not specified or formatted incorrectly.
@@ -90,16 +83,16 @@ def create_corpus(
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
response = retriever_types.Corpus(**response)
return response


async def create_corpus_async(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
) -> Corpus:
"""This is the async version of `create_corpus`."""
) -> retriever_types.Corpus:
"""This is the async version of `retriever.create_corpus`."""
if client is None:
client = get_default_retriever_async_client()

@@ -121,19 +114,19 @@ async def create_corpus_async(
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
response = retriever_types.Corpus(**response)
return response


def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus: # fmt: skip
def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> retriever_types.Corpus: # fmt: skip
"""
Get information about a specific `Corpus`.
Fetch a specific `Corpus` from the retriever service.
Args:
name: The `Corpus` name.
Return:
`Corpus` of interest.
a `retriever_types.Corpus` of interest.
"""
if client is None:
client = get_default_retriever_client()
@@ -143,12 +136,12 @@ def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> C
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
response = retriever_types.Corpus(**response)
return response


async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> Corpus: # fmt: skip
"""This is the async version of `get_corpus`."""
async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> retriever_types.Corpus: # fmt: skip
"""This is the async version of `retriever.get_corpus`."""
if client is None:
client = get_default_retriever_async_client()

@@ -157,13 +150,13 @@ async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient |
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
response = retriever_types.Corpus(**response)
return response


def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip
"""
Delete a `Corpus`.
Delete a `Corpus` from the service.
Args:
name: The `Corpus` name.
@@ -177,7 +170,7 @@ def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | N


async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
"""This is the async version of `delete_corpus`."""
"""This is the async version of `retriever.delete_corpus`."""
if client is None:
client = get_default_retriever_async_client()

@@ -190,9 +183,9 @@ def list_corpora(
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
) -> list[retriever_types.Corpus]:
"""
List `Corpus`.
List the Corpuses you own in the service.
Args:
page_size: Maximum number of `Corpora` to request.
@@ -214,8 +207,8 @@ async def list_corpora_async(
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
"""This is the async version of `list_corpora`."""
) -> list[retriever_types.Corpus]:
"""This is the async version of `retriever.list_corpora`."""
if client is None:
client = get_default_retriever_async_client()

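Taken together, the retriever.py hunks delete the module-level `Corpus` wrapper (which copied keyword arguments onto itself and set a `result` attribute) and have every helper return `retriever_types.Corpus` directly. A hedged caller-side sketch of the effect follows; the `retriever_types` import path is an assumption, since only `google/generativeai/retriever.py` appears in this diff.

# Hypothetical usage sketch. The retriever_types import location is assumed;
# it is referenced in the diff above but its module path is not shown there.
from google.generativeai import retriever
from google.generativeai.types import retriever_types  # assumed path

corpus = retriever.create_corpus(display_name="My notes")

# create_corpus/get_corpus/list_corpora now return plain retriever_types.Corpus
# objects, so code that relied on the removed retriever.Corpus subclass (for
# example isinstance checks, or its extra `result` attribute) should target the
# shared type instead.
assert isinstance(corpus, retriever_types.Corpus)
print(corpus.name)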
6 changes: 4 additions & 2 deletions google/generativeai/text.py
@@ -263,15 +263,17 @@ def generate_embeddings(
model: model_types.BaseModelNameOptions,
text: str,
client: glm.TextServiceClient = None,
) -> text_types.EmbeddingDict: ...
) -> text_types.EmbeddingDict:
...


@overload
def generate_embeddings(
model: model_types.BaseModelNameOptions,
text: Sequence[str],
client: glm.TextServiceClient = None,
) -> text_types.BatchEmbeddingDict: ...
) -> text_types.BatchEmbeddingDict:
...


def generate_embeddings(
