From 0e5972cb7dc3d5abb6f4de28a5a51d2759af3cd3 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Mon, 30 Oct 2023 13:48:36 -0700 Subject: [PATCH 1/9] Adding AQA (GenerateAnswer). Change-Id: Ia95d8b2c48506f843f765c0546e848b46f6d8d32 --- google/generativeai/answer.py | 223 ++++++++++++++++++++++ google/generativeai/types/answer_types.py | 110 +++++++++++ tests/test_answer.py | 202 ++++++++++++++++++++ 3 files changed, 535 insertions(+) create mode 100644 google/generativeai/answer.py create mode 100644 google/generativeai/types/answer_types.py create mode 100644 tests/test_answer.py diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py new file mode 100644 index 000000000..355a41f06 --- /dev/null +++ b/google/generativeai/answer.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import dataclasses +from collections.abc import Iterable +import itertools +from typing import Iterable, Union, Mapping, Optional, Any + +import google.ai.generativelanguage as glm + +from google.generativeai.client import get_default_generative_client +from google.generativeai import string_utils +from google.generativeai.types import model_types +from google.generativeai import models +from google.generativeai.types import safety_types +from google.generativeai.types import content_types +from google.generativeai.types import answer_types + +DEFAULT_ANSWER_MODEL = "models/aqa" + +GroundingPassageOptions = ( + Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], +) + +GroundingPassagesOptions = Union[ + glm.GroundingPassages, + Iterable[GroundingPassageOptions], + Mapping[str, content_types.ContentType], +] + +AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle + +AnswerStyleOptions = Union[int, str, AnswerStyle] + +_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = { + AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + 0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED, + AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE, + 1: AnswerStyle.ABSTRACTIVE, + "answer_style_abstractive": AnswerStyle.ABSTRACTIVE, + "abstractive": AnswerStyle.ABSTRACTIVE, + AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE, + 2: AnswerStyle.EXTRACTIVE, + "answer_style_extractive": AnswerStyle.EXTRACTIVE, + "extractive": AnswerStyle.EXTRACTIVE, + AnswerStyle.VERBOSE: AnswerStyle.VERBOSE, + 3: AnswerStyle.VERBOSE, + "answer_style_verbose": AnswerStyle.VERBOSE, + "verbose": AnswerStyle.VERBOSE, +} + + +def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: + if isinstance(x, str): + x = x.lower() + return _ANSWER_STYLES[x] + + +def _make_grounding_passages(source: 
GroundingPassagesOptions) -> glm.GroundingPassages: + """ + Creates a list of `glm.Content` wrapped in `glm.GroundingPassages`. The `glm.GroundingPassage` + object contains an id of the content, and the actual content itself. + + Args: + source: `Content` or an iterable `Content` that will be converted to glm.GroundingPassages. + + Return: + glm.GroundingPassages to be passed into glm.GenerateAnswer. + """ + if isinstance(source, glm.GroundingPassages): + return source + + passages = [] + if isinstance(source, Iterable): + if isinstance(source, Mapping): + source = source.items() + for n, data in enumerate(source): + if isinstance(data, glm.GroundingPassage): + passages.append(data) + elif isinstance(data, tuple): + if isinstance(data[0], str): + passages.append({"id": data[0], "content": content_types.to_content(data[1])}) + else: + passages.append({"id": str(n), "content": content_types.to_content(data[1])}) + else: + passages.append({"id": str(n), "content": content_types.to_content(data)}) + return glm.GroundingPassages(passages=passages) + else: + raise TypeError("`source` must be a valid `Content` type object.") + + +def _make_generate_answer_request( + *, + model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL, + contents: content_types.ContentsType, + grounding_source: GroundingPassagesOptions, + answer_style: AnswerStyle | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + temperature: float | None = None, +) -> glm.GenerateAnswerRequest: + """ + Calls the API to generate a grounded answer from the model. + + Args: + model: Name of the model used to generate the grounded response. + contents: Content of the current conversation with the model. For single-turn query, this is a + single question to answer. For multi-turn queries, this is a repeated field that contains + conversation history and the last `Content` in the list containing the question. + grounding_source: Sources in which to grounding the answer. 
+ answer_style: Style for grounded answers. + safety_settings: Safety settings for generated output. Defaults to None. + temperature: The temperature for randomness in the output. Defaults to None. + + Returns: + Call for glm.GenerateAnswerRequest(). + """ + model = model_types.make_model_name(model) + + contents = content_types.to_contents(contents) + + if safety_settings: + safety_settings = safety_types.normalize_safety_settings(safety_settings) + + grounding_source = _make_grounding_passages(grounding_source) + + if answer_style: + answer_style = to_answer_style(answer_style) + + return glm.GenerateAnswerRequest( + model=model, + contents=contents, + inline_passages=grounding_source, + safety_settings=safety_settings, + temperature=temperature, + answer_style=answer_style, + ) + + +def generate_answer( + *, + model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL, + contents: content_types.ContentsType, + inline_passages: GroundingPassagesOptions, + answer_style: AnswerStyle | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + temperature: float | None = None, + client: glm.GenerativeServiceClient | None = None, +) -> answer_types.Answer: + """ + Calls the API and returns a `types.Answer` containing the answer. + + Args: + model: Which model to call, as a string or a `types.Model`. + question: The question to be answered by the model, grounded in the + provided source. + grounding_source: Source indicating the passages in which the answer should be grounded. + answer_style: Style in which the grounded answer should be returned. + safety_settings: Safety settings for generated output. Defaults to None. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + + Returns: + A `types.Answer` containing the model's text answer response. 
+ """ + request = _make_generate_answer_request( + model=model, + contents=contents, + grounding_source=inline_passages, + safety_settings=safety_settings, + temperature=temperature, + answer_style=answer_style, + ) + + return _generate_answer_response(client=client, request=request) + + +@string_utils.prettyprint +@dataclasses.dataclass(init=False) +class Answer(answer_types.Answer): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + self.result = None + if self.answer: + self.result = self.answer["content"]["parts"] + + +def _generate_answer_response( + request: glm.GenerateAnswerRequest, client: glm.GenerativeServiceClient | None = None +) -> Answer: + """ + Generates a response using the provided `glm.GenerateAnswerRequest` and client. + + Args: + request: The answer generation request. + client: The client used for text answer generation. Defaults to None, in which + case the default generative client is used. + + Returns: + `Answer`: An `Answer` object with the generated text and response information. + """ + if client is None: + client = get_default_generative_client() + + response = client.generate_answer(request) + response = type(response).to_dict(response) + + return Answer(_client=client, **response) diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py new file mode 100644 index 000000000..53489be0a --- /dev/null +++ b/google/generativeai/types/answer_types.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import abc +import dataclasses +from typing import Any, Dict, List, TypedDict, Optional, Union + +import google.ai.generativelanguage as glm + +from google.generativeai import string_utils +from google.generativeai.types import safety_types +from google.generativeai.types import citation_types +from google.generativeai.types import content_types + +__all__ = ["Answer"] + +FinishReason = glm.Candidate.FinishReason + +FinishReasonOptions = Union[int, str, FinishReason] + +_FINISH_REASONS: dict[FinishReasonOptions, FinishReason] = { + FinishReason.FINISH_REASON_UNSPECIFIED: FinishReason.FINISH_REASON_UNSPECIFIED, + 0: FinishReason.FINISH_REASON_UNSPECIFIED, + "finish_reason_unspecified": FinishReason.FINISH_REASON_UNSPECIFIED, + "unspecified": FinishReason.FINISH_REASON_UNSPECIFIED, + FinishReason.STOP: FinishReason.STOP, + 1: FinishReason.STOP, + "finish_reason_stop": FinishReason.STOP, + "stop": FinishReason.STOP, + FinishReason.MAX_TOKENS: FinishReason.MAX_TOKENS, + 2: FinishReason.MAX_TOKENS, + "finish_reason_max_tokens": FinishReason.MAX_TOKENS, + "max_tokens": FinishReason.MAX_TOKENS, + FinishReason.SAFETY: FinishReason.SAFETY, + 3: FinishReason.SAFETY, + "finish_reason_safety": FinishReason.SAFETY, + "safety": FinishReason.SAFETY, + FinishReason.RECITATION: FinishReason.RECITATION, + 4: FinishReason.RECITATION, + "finish_reason_recitation": FinishReason.RECITATION, + "recitation": FinishReason.RECITATION, + FinishReason.OTHER: FinishReason.OTHER, + 5: FinishReason.OTHER, + "finish_reason_other": FinishReason.OTHER, 
+ "other": FinishReason.OTHER, +} + + +def to_finish_reason(x: FinishReasonOptions) -> FinishReason: + if isinstance(x, str): + x = x.lower() + return _FINISH_REASONS[x] + + +class AttributionSourceId(TypedDict): + passage_id: str + part_index: int + + +class GroundingAttribution(TypedDict): + source_id: AttributionSourceId + content: content_types.ContentType + + +class Candidate(TypedDict): + index: Optional[int] + content: content_types.ContentType + finish_reason: Optional[glm.Candidate.FinishReason] + finish_message: Optional[str] + safety_ratings: List[safety_types.SafetyRatingDict | None] + citation_metadata: citation_types.CitationMetadataDict | None + token_count: int + grounding_attribution: list[GroundingAttribution] + + +@string_utils.prettyprint +@dataclasses.dataclass(init=False) +class Answer(abc.ABC): + """The result returned by `generativeai.generate_answer`. + + Use `GenerateAnswerResponse.answer` to access all the candidates used to create the answer. + + Attributes: + answer: Answer grounded in the requested passages. + answerable_probability: Indicates which safety settings blocked content in this result. + + """ + + answer: Candidate + answerable_probability: float + + def to_dict(self) -> dict[str, Any]: + result = { + "answer": self.answer, + "answerable_probability": self.answerable_probability, + } + return result diff --git a/tests/test_answer.py b/tests/test_answer.py new file mode 100644 index 000000000..e04337ee9 --- /dev/null +++ b/tests/test_answer.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import math +import unittest +import unittest.mock as mock + +import google.ai.generativelanguage as glm + +from google.generativeai import answer +from google.generativeai import client +from absl.testing import absltest +from absl.testing import parameterized + +from google.generativeai.types import content_types + +DEFAULT_ANSWER_MODEL = "models/aqa" + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["generative"] = self.client + client._client_manager.clients["model"] = self.client + + self.observed_requests = [] + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def generate_answer( + request: glm.GenerateAnswerRequest, + ) -> glm.GenerateAnswerResponse: + self.observed_requests.append(request) + return glm.GenerateAnswerResponse( + answer=glm.Candidate( + index=1, + content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + ), + answerable_probability=0.500, + ) + + @parameterized.named_parameters( + [ + dict( + testcase_name="grounding_passage", + inline_passages=glm.GroundingPassages( + passages=[ + { + "id": "0", + "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + }, + {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ), + ), + dict( + testcase_name="content_object", + inline_passages=[ + glm.Content(parts=[glm.Part(text="I am a chicken")]), + 
glm.Content(parts=[glm.Part(text="I am a bird.")]), + glm.Content(parts=[glm.Part(text="I can fly!")]), + ], + ), + dict( + testcase_name="list_of_strings", + inline_passages=["I am a chicken", "I am a bird.", "I can fly!"], + ), + dict( + testcase_name="dict_of_strings", + inline_passages={4: "I am a chicken", 5: "I am a bird.", 6: "I can fly!"}, + ), + dict( + testcase_name="tuple_of_strings", + inline_passages=[(4, "I am a chicken"), (5, "I am a bird."), (6, "I can fly!")], + ), + dict( + testcase_name="mixed_types", + inline_passages=[ + "I am a chicken", + glm.Content(parts=[glm.Part(text="I am a bird.")]), + glm.Content(parts=[glm.Part(text="I can fly!")]), + ], + ), + dict( + testcase_name="list_of_grounding_passages", + inline_passages=[ + glm.GroundingPassage( + id="0", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + ), + glm.GroundingPassage( + id="1", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + ), + glm.GroundingPassage( + id="2", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + ), + ], + ), + ] + ) + def test_make_grounding_passages(self, inline_passages): + x = answer._make_grounding_passages(inline_passages) + self.assertIsInstance(x, glm.GroundingPassages) + self.assertEqual( + glm.GroundingPassages( + passages=[ + {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, + {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ), + x, + ) + + def test_make_grounding_passages_key_strings(self): + inline_passages = { + "first": "I am a chicken", + "second": "I am a bird.", + "third": "I can fly!", + } + + x = answer._make_grounding_passages(inline_passages) + self.assertIsInstance(x, glm.GroundingPassages) + self.assertEqual( + glm.GroundingPassages( + passages=[ + { + "id": "first", + "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + }, + {"id": "second", "content": 
glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "third", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ), + x, + ) + + def test_generate_answer_request(self): + # Should be a list of contents to use to_contents() function. + contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + + inline_passages = ["I am a chicken", "I am a bird.", "I can fly!"] + grounding_passages = glm.GroundingPassages( + passages=[ + {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, + {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ) + + x = answer._make_generate_answer_request( + model=DEFAULT_ANSWER_MODEL, contents=contents, grounding_source=inline_passages + ) + + self.assertEqual( + glm.GenerateAnswerRequest( + model=DEFAULT_ANSWER_MODEL, contents=contents, inline_passages=grounding_passages + ), + x, + ) + + def test_generate_answer(self): + # Test handling return value of generate_answer(). 
+ contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + + grounding_passages = glm.GroundingPassages( + passages=[ + {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, + {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ) + request = glm.GenerateAnswerRequest( + model="models/aqa", + contents=contents, + inline_passages=grounding_passages, + answer_style="ABSTRACTIVE", + ) + response = answer._generate_answer_response(request) + self.assertIsInstance(response, answer.Answer) + self.assertEqual(response.result[0]["text"], "Demo answer.") + + +if __name__ == "__main__": + absltest.main() From 1d8b37eaa91eb39c91ddf7dbf68c8614d5e92db5 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Fri, 12 Jan 2024 14:19:24 -0500 Subject: [PATCH 2/9] Updating safety_settings parameter in answer.py --- google/generativeai/answer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 355a41f06..562c498bd 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -134,7 +134,7 @@ def _make_generate_answer_request( contents = content_types.to_contents(contents) if safety_settings: - safety_settings = safety_types.normalize_safety_settings(safety_settings) + safety_settings = safety_types.normalize_safety_settings(safety_settings, harm_category_set="new") grounding_source = _make_grounding_passages(grounding_source) From 18433e148b7128d21361d7956ad93f8553c66b3b Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Fri, 12 Jan 2024 14:21:54 -0500 Subject: [PATCH 3/9] Fix formatting with black . 
--- google/generativeai/answer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 562c498bd..10c8ce764 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -134,7 +134,9 @@ def _make_generate_answer_request( contents = content_types.to_contents(contents) if safety_settings: - safety_settings = safety_types.normalize_safety_settings(safety_settings, harm_category_set="new") + safety_settings = safety_types.normalize_safety_settings( + safety_settings, harm_category_set="new" + ) grounding_source = _make_grounding_passages(grounding_source) From 3674baf64dafab6d6f8e69c609c60e3239f0e052 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 25 Jan 2024 12:30:57 -0800 Subject: [PATCH 4/9] Update GroundingPassage error message Co-authored-by: Mark Daoust --- google/generativeai/answer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 10c8ce764..327ace469 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -101,7 +101,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP passages.append({"id": str(n), "content": content_types.to_content(data)}) return glm.GroundingPassages(passages=passages) else: - raise TypeError("`source` must be a valid `Content` type object.") + raise TypeError(f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`.") def _make_generate_answer_request( From a8cb6739151af51f55c03b3d66455be6a6b2b7e8 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 25 Jan 2024 13:48:03 -0800 Subject: [PATCH 5/9] Removed separate Answer class, updated test cases --- .gitignore | 1 + google/generativeai/answer.py | 103 ++++++++-------------- google/generativeai/types/answer_types.py | 24 ----- tests/test_answer.py | 95 +++++++++++++------- 4 files changed, 99 
insertions(+), 124 deletions(-) diff --git a/.gitignore b/.gitignore index f6e6bdba8..c9f130875 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ .DS_Store __pycache__ *.iml +/venv310/ \ No newline at end of file diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 327ace469..694a6a851 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -31,16 +31,6 @@ DEFAULT_ANSWER_MODEL = "models/aqa" -GroundingPassageOptions = ( - Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], -) - -GroundingPassagesOptions = Union[ - glm.GroundingPassages, - Iterable[GroundingPassageOptions], - Mapping[str, content_types.ContentType], -] - AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle AnswerStyleOptions = Union[int, str, AnswerStyle] @@ -71,37 +61,50 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: return _ANSWER_STYLES[x] +GroundingPassageOptions = ( + Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], +) + +GroundingPassagesOptions = Union[ + glm.GroundingPassages, + Iterable[GroundingPassageOptions], + Mapping[str, content_types.ContentType], +] + + def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: """ - Creates a list of `glm.Content` wrapped in `glm.GroundingPassages`. The `glm.GroundingPassage` - object contains an id of the content, and the actual content itself. + Converts the `source` into a `glm.GroundingPassage`. A `GroundingPassages` contains a list of + `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. Args: - source: `Content` or an iterable `Content` that will be converted to glm.GroundingPassages. + source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages. Return: - glm.GroundingPassages to be passed into glm.GenerateAnswer. + `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. 
""" + if not isinstance(source, Iterable): + raise TypeError( + f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`." + ) + if isinstance(source, glm.GroundingPassages): return source passages = [] - if isinstance(source, Iterable): - if isinstance(source, Mapping): - source = source.items() - for n, data in enumerate(source): - if isinstance(data, glm.GroundingPassage): - passages.append(data) - elif isinstance(data, tuple): - if isinstance(data[0], str): - passages.append({"id": data[0], "content": content_types.to_content(data[1])}) - else: - passages.append({"id": str(n), "content": content_types.to_content(data[1])}) - else: - passages.append({"id": str(n), "content": content_types.to_content(data)}) - return glm.GroundingPassages(passages=passages) - else: - raise TypeError(f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`.") + if isinstance(source, Mapping): + source = source.items() + + for n, data in enumerate(source): + if isinstance(data, glm.GroundingPassage): + passages.append(data) + elif isinstance(data, tuple): + id, content = data # tuple must have exactly 2 items. + passages.append({'id':id, 'content': content_types.to_content(content)}) + else: + passages.append({"id": str(n), "content": content_types.to_content(data)}) + + return glm.GroundingPassages(passages=passages) def _make_generate_answer_request( @@ -123,8 +126,8 @@ def _make_generate_answer_request( conversation history and the last `Content` in the list containing the question. grounding_source: Sources in which to grounding the answer. answer_style: Style for grounded answers. - safety_settings: Safety settings for generated output. Defaults to None. - temperature: The temperature for randomness in the output. Defaults to None. + safety_settings: Safety settings for generated output. + temperature: The temperature for randomness in the output. Returns: Call for glm.GenerateAnswerRequest(). 
@@ -162,7 +165,7 @@ def generate_answer( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, -) -> answer_types.Answer: +): """ Calls the API and returns a `types.Answer` containing the answer. @@ -187,39 +190,7 @@ def generate_answer( answer_style=answer_style, ) - return _generate_answer_response(client=client, request=request) - - -@string_utils.prettyprint -@dataclasses.dataclass(init=False) -class Answer(answer_types.Answer): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - self.result = None - if self.answer: - self.result = self.answer["content"]["parts"] - - -def _generate_answer_response( - request: glm.GenerateAnswerRequest, client: glm.GenerativeServiceClient | None = None -) -> Answer: - """ - Generates a response using the provided `glm.GenerateAnswerRequest` and client. - - Args: - request: The answer generation request. - client: The client used for text answer generation. Defaults to None, in which - case the default generative client is used. - - Returns: - `Answer`: An `Answer` object with the generated text and response information. - """ - if client is None: - client = get_default_generative_client() - response = client.generate_answer(request) response = type(response).to_dict(response) - return Answer(_client=client, **response) + return response diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 53489be0a..e60b45070 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -84,27 +84,3 @@ class Candidate(TypedDict): citation_metadata: citation_types.CitationMetadataDict | None token_count: int grounding_attribution: list[GroundingAttribution] - - -@string_utils.prettyprint -@dataclasses.dataclass(init=False) -class Answer(abc.ABC): - """The result returned by `generativeai.generate_answer`. 
- - Use `GenerateAnswerResponse.answer` to access all the candidates used to create the answer. - - Attributes: - answer: Answer grounded in the requested passages. - answerable_probability: Indicates which safety settings blocked content in this result. - - """ - - answer: Candidate - answerable_probability: float - - def to_dict(self) -> dict[str, Any]: - result = { - "answer": self.answer, - "answerable_probability": self.answerable_probability, - } - return result diff --git a/tests/test_answer.py b/tests/test_answer.py index e04337ee9..0aa4835e2 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -55,6 +55,25 @@ def generate_answer( ), answerable_probability=0.500, ) + + def test_make_grounding_passages_mixed_types(self): + inline_passages=[ + "I am a chicken", + glm.Content(parts=[glm.Part(text="I am a bird.")]), + glm.Content(parts=[glm.Part(text="I can fly!")]), + ] + x = answer._make_grounding_passages(inline_passages) + self.assertIsInstance(x, glm.GroundingPassages) + self.assertEqual( + glm.GroundingPassages( + passages=[ + {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, + {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ), + x, + ) @parameterized.named_parameters( [ @@ -83,36 +102,6 @@ def generate_answer( testcase_name="list_of_strings", inline_passages=["I am a chicken", "I am a bird.", "I can fly!"], ), - dict( - testcase_name="dict_of_strings", - inline_passages={4: "I am a chicken", 5: "I am a bird.", 6: "I can fly!"}, - ), - dict( - testcase_name="tuple_of_strings", - inline_passages=[(4, "I am a chicken"), (5, "I am a bird."), (6, "I can fly!")], - ), - dict( - testcase_name="mixed_types", - inline_passages=[ - "I am a chicken", - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), - ], - ), - dict( - testcase_name="list_of_grounding_passages", - 
inline_passages=[ - glm.GroundingPassage( - id="0", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) - ), - glm.GroundingPassage( - id="1", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) - ), - glm.GroundingPassage( - id="2", content=glm.Content(parts=[glm.Part(text="I can fly!")]) - ), - ], - ), ] ) def test_make_grounding_passages(self, inline_passages): @@ -129,6 +118,44 @@ def test_make_grounding_passages(self, inline_passages): x, ) + @parameterized.named_parameters( + dict( + testcase_name="dict_of_strings", + inline_passages={"4": "I am a chicken", "5": "I am a bird.", "6": "I can fly!"}, + ), + dict( + testcase_name="tuple_of_strings", + inline_passages=[("4", "I am a chicken"), ("5", "I am a bird."), ("6", "I can fly!")], + ), + dict( + testcase_name="list_of_grounding_passages", + inline_passages=[ + glm.GroundingPassage( + id="4", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + ), + glm.GroundingPassage( + id="5", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + ), + glm.GroundingPassage( + id="6", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + ), + ], + ), + ) + def test_make_grounding_passages_different_id(self, inline_passages): + x = answer._make_grounding_passages(inline_passages) + self.assertIsInstance(x, glm.GroundingPassages) + self.assertEqual( + glm.GroundingPassages( + passages=[ + {"id": "4", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, + {"id": "5", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, + {"id": "6", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + ] + ), + x, + ) + def test_make_grounding_passages_key_strings(self): inline_passages = { "first": "I am a chicken", @@ -187,15 +214,15 @@ def test_generate_answer(self): {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ) - request = glm.GenerateAnswerRequest( + print(type(grounding_passages)) + a = answer.generate_answer( model="models/aqa", 
contents=contents, inline_passages=grounding_passages, answer_style="ABSTRACTIVE", ) - response = answer._generate_answer_response(request) - self.assertIsInstance(response, answer.Answer) - self.assertEqual(response.result[0]["text"], "Demo answer.") + self.assertIsInstance(a, glm.GenerateAnswerResponse) + self.assertEqual(a.result[0]["text"], "Demo answer.") if __name__ == "__main__": From 15bb21129923689055fd69e8a6750aa555538408 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Fri, 26 Jan 2024 14:00:06 -0800 Subject: [PATCH 6/9] Fixed test cases for answer.py --- google/generativeai/answer.py | 21 ++++++++++++--------- google/generativeai/types/answer_types.py | 19 +++++++++++++++++++ tests/test_answer.py | 8 ++++---- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 694a6a851..62f1690b3 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -74,7 +74,7 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: """ - Converts the `source` into a `glm.GroundingPassage`. A `GroundingPassages` contains a list of + Converts the `source` into a `glm.GroundingPassage`. A `GroundingPassages` contains a list of `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. Args: @@ -83,14 +83,14 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP Return: `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. """ + if isinstance(source, glm.GroundingPassages): + return source + if not isinstance(source, Iterable): raise TypeError( f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`." 
) - - if isinstance(source, glm.GroundingPassages): - return source - + passages = [] if isinstance(source, Mapping): source = source.items() @@ -99,8 +99,8 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP if isinstance(data, glm.GroundingPassage): passages.append(data) elif isinstance(data, tuple): - id, content = data # tuple must have exactly 2 items. - passages.append({'id':id, 'content': content_types.to_content(content)}) + id, content = data # tuple must have exactly 2 items. + passages.append({"id": id, "content": content_types.to_content(content)}) else: passages.append({"id": str(n), "content": content_types.to_content(data)}) @@ -126,8 +126,8 @@ def _make_generate_answer_request( conversation history and the last `Content` in the list containing the question. grounding_source: Sources in which to grounding the answer. answer_style: Style for grounded answers. - safety_settings: Safety settings for generated output. - temperature: The temperature for randomness in the output. + safety_settings: Safety settings for generated output. + temperature: The temperature for randomness in the output. Returns: Call for glm.GenerateAnswerRequest(). @@ -181,6 +181,9 @@ def generate_answer( Returns: A `types.Answer` containing the model's text answer response. 
""" + if client is None: + client = get_default_generative_client() + request = _make_generate_answer_request( model=model, contents=contents, diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index e60b45070..84a2cbe11 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -27,6 +27,25 @@ __all__ = ["Answer"] +""" BlockReason = glm.InputFeedback.BlockReason + +BlockReasonOptions = Union[int, str, BlockReason] + +_BLOCK_REASONS: dict[BlockReasonOptions, BlockReason] = { + BlockReason.BLOCK_REASON_UNSPECIFIED: BlockReason.BLOCK_REASON_UNSPECIFIED, + 0: BlockReason.BLOCK_REASON_UNSPECIFIED, + "block_reason_unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED, + "unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED, + BlockReason.SAFETY: BlockReason.SAFETY, + 1: BlockReason.SAFETY, + "block_reason_safety": BlockReason.SAFETY, + "safety": BlockReason.SAFETY, + BlockReason.OTHER: BlockReason.OTHER, + 2: BlockReason.OTHER, + "block_reason_other": BlockReason.OTHER, + "other": BlockReason.OTHER, +} """ + FinishReason = glm.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] diff --git a/tests/test_answer.py b/tests/test_answer.py index 0aa4835e2..c2bf3ab67 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -55,9 +55,9 @@ def generate_answer( ), answerable_probability=0.500, ) - + def test_make_grounding_passages_mixed_types(self): - inline_passages=[ + inline_passages = [ "I am a chicken", glm.Content(parts=[glm.Part(text="I am a bird.")]), glm.Content(parts=[glm.Part(text="I can fly!")]), @@ -221,8 +221,8 @@ def test_generate_answer(self): inline_passages=grounding_passages, answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, glm.GenerateAnswerResponse) - self.assertEqual(a.result[0]["text"], "Demo answer.") + self.assertIsInstance(a, dict) + self.assertEqual(a["answer"]["content"]["parts"][0]["text"], "Demo answer.") if __name__ == 
"__main__": From 9d672aa618cf5be717463d94f9062e920ada195a Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Wed, 31 Jan 2024 13:31:40 -0800 Subject: [PATCH 7/9] Removed unnecessary enums and classes --- google/generativeai/types/answer_types.py | 40 ----------------------- 1 file changed, 40 deletions(-) diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 84a2cbe11..fe9abaed1 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -27,25 +27,6 @@ __all__ = ["Answer"] -""" BlockReason = glm.InputFeedback.BlockReason - -BlockReasonOptions = Union[int, str, BlockReason] - -_BLOCK_REASONS: dict[BlockReasonOptions, BlockReason] = { - BlockReason.BLOCK_REASON_UNSPECIFIED: BlockReason.BLOCK_REASON_UNSPECIFIED, - 0: BlockReason.BLOCK_REASON_UNSPECIFIED, - "block_reason_unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED, - "unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED, - BlockReason.SAFETY: BlockReason.SAFETY, - 1: BlockReason.SAFETY, - "block_reason_safety": BlockReason.SAFETY, - "safety": BlockReason.SAFETY, - BlockReason.OTHER: BlockReason.OTHER, - 2: BlockReason.OTHER, - "block_reason_other": BlockReason.OTHER, - "other": BlockReason.OTHER, -} """ - FinishReason = glm.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] @@ -82,24 +63,3 @@ def to_finish_reason(x: FinishReasonOptions) -> FinishReason: if isinstance(x, str): x = x.lower() return _FINISH_REASONS[x] - - -class AttributionSourceId(TypedDict): - passage_id: str - part_index: int - - -class GroundingAttribution(TypedDict): - source_id: AttributionSourceId - content: content_types.ContentType - - -class Candidate(TypedDict): - index: Optional[int] - content: content_types.ContentType - finish_reason: Optional[glm.Candidate.FinishReason] - finish_message: Optional[str] - safety_ratings: List[safety_types.SafetyRatingDict | None] - citation_metadata: 
citation_types.CitationMetadataDict | None - token_count: int - grounding_attribution: list[GroundingAttribution] From 673a12ffcea4b3fd61d93b89920cb37151e086d3 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Wed, 31 Jan 2024 13:41:03 -0800 Subject: [PATCH 8/9] Leave response as proto and return --- google/generativeai/answer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 62f1690b3..21e77e032 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -85,12 +85,12 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP """ if isinstance(source, glm.GroundingPassages): return source - + if not isinstance(source, Iterable): raise TypeError( f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`." ) - + passages = [] if isinstance(source, Mapping): source = source.items() @@ -183,7 +183,7 @@ def generate_answer( """ if client is None: client = get_default_generative_client() - + request = _make_generate_answer_request( model=model, contents=contents, @@ -194,6 +194,5 @@ def generate_answer( ) response = client.generate_answer(request) - response = type(response).to_dict(response) return response From f2ed9e1fe860893e16b029818b33dfa1c787cc1c Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Wed, 31 Jan 2024 13:51:15 -0800 Subject: [PATCH 9/9] Updated test case for returning proto --- tests/test_answer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_answer.py b/tests/test_answer.py index c2bf3ab67..d916d24e3 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -221,8 +221,17 @@ def test_generate_answer(self): inline_passages=grounding_passages, answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, dict) - self.assertEqual(a["answer"]["content"]["parts"][0]["text"], "Demo answer.") + self.assertIsInstance(a, 
glm.GenerateAnswerResponse) + self.assertEqual( + a, + glm.GenerateAnswerResponse( + answer=glm.Candidate( + index=1, + content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + ), + answerable_probability=0.500, + ), + ) if __name__ == "__main__":