Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding AQA (GenerateAnswer). #169

Merged
merged 10 commits into from
Jan 31, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
.DS_Store
__pycache__
*.iml
/venv310/
199 changes: 199 additions & 0 deletions google/generativeai/answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable
import itertools
from typing import Iterable, Union, Mapping, Optional, Any

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_generative_client
from google.generativeai import string_utils
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types
from google.generativeai.types import content_types
from google.generativeai.types import answer_types

# Default model for grounded question answering (Attributed QA).
DEFAULT_ANSWER_MODEL = "models/aqa"

# Re-export of the proto enum for convenience.
AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle

# Anything `to_answer_style` can coerce: the enum member itself, its int value,
# or its (full or short) name as a string.
AnswerStyleOptions = Union[int, str, AnswerStyle]

# Maps every accepted spelling of an answer style (enum member, int value,
# full lowercase name, short lowercase name) to the canonical enum member.
_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
    AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
    1: AnswerStyle.ABSTRACTIVE,
    "answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
    "abstractive": AnswerStyle.ABSTRACTIVE,
    AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
    2: AnswerStyle.EXTRACTIVE,
    "answer_style_extractive": AnswerStyle.EXTRACTIVE,
    "extractive": AnswerStyle.EXTRACTIVE,
    AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
    3: AnswerStyle.VERBOSE,
    "answer_style_verbose": AnswerStyle.VERBOSE,
    "verbose": AnswerStyle.VERBOSE,
}


def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
    """Coerce an enum member, int value, or name string into an `AnswerStyle`."""
    key = x.lower() if isinstance(x, str) else x
    return _ANSWER_STYLES[key]


# A single grounding passage: either a ready proto, an (id, content) pair,
# or bare content (an id is assigned from the enumeration index).
# NOTE: previously this alias was accidentally wrapped in a 1-tuple
# (`(Union[...],)`) by a trailing comma, breaking the type alias.
GroundingPassageOptions = Union[
    glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
]

# Everything `_make_grounding_passages` accepts: a ready proto, an iterable of
# single-passage options, or a mapping of {id: content}.
GroundingPassagesOptions = Union[
    glm.GroundingPassages,
    Iterable[GroundingPassageOptions],
    Mapping[str, content_types.ContentType],
]


def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages:
    """
    Converts the `source` into a `glm.GroundingPassages`. A `GroundingPassages` contains a list of
    `glm.GroundingPassage` objects, which each contain a `glm.Content` and a string `id`.

    Args:
        source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages.

    Returns:
        `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`.

    Raises:
        TypeError: If `source` is not a `glm.GroundingPassages` and not iterable.
    """
    if isinstance(source, glm.GroundingPassages):
        return source

    if not isinstance(source, Iterable):
        raise TypeError(
            f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`."
        )

    passages = []
    if isinstance(source, Mapping):
        # A mapping supplies explicit ids: iterate as (id, content) pairs.
        source = source.items()

    for n, data in enumerate(source):
        if isinstance(data, glm.GroundingPassage):
            passages.append(data)
        elif isinstance(data, tuple):
            # Tuple form must have exactly 2 items: (id, content).
            passage_id, content = data
            passages.append({"id": passage_id, "content": content_types.to_content(content)})
        else:
            # Bare content: fall back to the enumeration index as the id.
            passages.append({"id": str(n), "content": content_types.to_content(data)})

    return glm.GroundingPassages(passages=passages)


def _make_generate_answer_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    grounding_source: GroundingPassagesOptions,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
) -> glm.GenerateAnswerRequest:
    """
    Constructs a `glm.GenerateAnswerRequest` (this does not call the API).

    Args:
        model: Name of the model used to generate the grounded response.
        contents: Content of the current conversation with the model. For single-turn query, this is a
            single question to answer. For multi-turn queries, this is a repeated field that contains
            conversation history and the last `Content` in the list containing the question.
        grounding_source: Sources in which to ground the answer.
        answer_style: Style for grounded answers.
        safety_settings: Safety settings for generated output.
        temperature: The temperature for randomness in the output.

    Returns:
        A `glm.GenerateAnswerRequest` ready to be passed to the service client.
    """
    model = model_types.make_model_name(model)

    contents = content_types.to_contents(contents)

    if safety_settings:
        # NOTE(review): "new" anticipates the server-side switch of HarmCategory
        # sets; until that switch completes, the API may raise
        # "AQA does not yet support custom harassment safety thresholds".
        # Confirm against the API before release.
        safety_settings = safety_types.normalize_safety_settings(
            safety_settings, harm_category_set="new"
        )

    grounding_source = _make_grounding_passages(grounding_source)

    if answer_style:
        answer_style = to_answer_style(answer_style)

    return glm.GenerateAnswerRequest(
        model=model,
        contents=contents,
        inline_passages=grounding_source,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )


def generate_answer(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions,
    answer_style: AnswerStyle | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
):
    """
    Calls the API and returns the model's grounded-answer response as a dict.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
            provided source; may also include conversation history.
        inline_passages: Source passages in which the answer should be grounded.
        answer_style: Style in which the grounded answer should be returned.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: The temperature for randomness in the output. Defaults to None.
        client: If you're not relying on a default client, you pass a
            `glm.GenerativeServiceClient` instead.

    Returns:
        A dict representation of the `GenerateAnswerResponse` containing the
        model's text answer.
    """
    if client is None:
        client = get_default_generative_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        grounding_source=inline_passages,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = client.generate_answer(request)
    # NOTE(review): other GenerativeService APIs return the proto as-is;
    # converting to a dict here diverges from that convention — consider
    # keeping these in sync.
    response = type(response).to_dict(response)

    return response
105 changes: 105 additions & 0 deletions google/generativeai/types/answer_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
MarkDaoust marked this conversation as resolved.
Show resolved Hide resolved

import abc
import dataclasses
from typing import Any, Dict, List, TypedDict, Optional, Union

import google.ai.generativelanguage as glm

from google.generativeai import string_utils
from google.generativeai.types import safety_types
from google.generativeai.types import citation_types
from google.generativeai.types import content_types

__all__ = ["Answer"]

""" BlockReason = glm.InputFeedback.BlockReason
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved

BlockReasonOptions = Union[int, str, BlockReason]

_BLOCK_REASONS: dict[BlockReasonOptions, BlockReason] = {
BlockReason.BLOCK_REASON_UNSPECIFIED: BlockReason.BLOCK_REASON_UNSPECIFIED,
0: BlockReason.BLOCK_REASON_UNSPECIFIED,
"block_reason_unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED,
"unspecified": BlockReason.BLOCK_REASON_UNSPECIFIED,
BlockReason.SAFETY: BlockReason.SAFETY,
1: BlockReason.SAFETY,
"block_reason_safety": BlockReason.SAFETY,
"safety": BlockReason.SAFETY,
BlockReason.OTHER: BlockReason.OTHER,
2: BlockReason.OTHER,
"block_reason_other": BlockReason.OTHER,
"other": BlockReason.OTHER,
} """

# Re-export of the proto enum for convenience.
FinishReason = glm.Candidate.FinishReason

# Anything `to_finish_reason` can coerce: the enum member itself, its int value,
# or its (full or short) name as a string.
FinishReasonOptions = Union[int, str, FinishReason]

# Maps every accepted spelling of a finish reason (enum member, int value,
# full lowercase name, short lowercase name) to the canonical enum member.
_FINISH_REASONS: dict[FinishReasonOptions, FinishReason] = {
    FinishReason.FINISH_REASON_UNSPECIFIED: FinishReason.FINISH_REASON_UNSPECIFIED,
    0: FinishReason.FINISH_REASON_UNSPECIFIED,
    "finish_reason_unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
    "unspecified": FinishReason.FINISH_REASON_UNSPECIFIED,
    FinishReason.STOP: FinishReason.STOP,
    1: FinishReason.STOP,
    "finish_reason_stop": FinishReason.STOP,
    "stop": FinishReason.STOP,
    FinishReason.MAX_TOKENS: FinishReason.MAX_TOKENS,
    2: FinishReason.MAX_TOKENS,
    "finish_reason_max_tokens": FinishReason.MAX_TOKENS,
    "max_tokens": FinishReason.MAX_TOKENS,
    FinishReason.SAFETY: FinishReason.SAFETY,
    3: FinishReason.SAFETY,
    "finish_reason_safety": FinishReason.SAFETY,
    "safety": FinishReason.SAFETY,
    FinishReason.RECITATION: FinishReason.RECITATION,
    4: FinishReason.RECITATION,
    "finish_reason_recitation": FinishReason.RECITATION,
    "recitation": FinishReason.RECITATION,
    FinishReason.OTHER: FinishReason.OTHER,
    5: FinishReason.OTHER,
    "finish_reason_other": FinishReason.OTHER,
    "other": FinishReason.OTHER,
}


def to_finish_reason(x: FinishReasonOptions) -> FinishReason:
    """Coerce an enum member, int value, or name string into a `FinishReason`."""
    key = x.lower() if isinstance(x, str) else x
    return _FINISH_REASONS[key]


class AttributionSourceId(TypedDict):
    """Identifies the grounding passage (and the part within it) that an attribution cites."""

    # Id of the grounding passage the attributed content came from.
    passage_id: str
    # Index of the part within that passage.
    part_index: int


class GroundingAttribution(TypedDict):
    """Links generated content back to the source passage it was grounded in."""

    source_id: AttributionSourceId  # which passage/part the content is attributed to
    content: content_types.ContentType  # the attributed content itself


class Candidate(TypedDict):
    """One candidate answer in a generate-answer response.

    NOTE(review): fields appear to mirror `glm.Candidate` — confirm against the
    proto definition; semantics below are inferred from field names.
    """

    index: Optional[int]  # presumably this candidate's position in the response list
    content: content_types.ContentType  # the generated answer content
    finish_reason: Optional[glm.Candidate.FinishReason]  # why generation stopped, if reported
    finish_message: Optional[str]  # presumably human-readable detail for finish_reason
    safety_ratings: List[safety_types.SafetyRatingDict | None]
    citation_metadata: citation_types.CitationMetadataDict | None
    token_count: int  # presumably token count of this candidate — verify against proto
    grounding_attribution: list[GroundingAttribution]  # passage attributions for the answer
Loading
Loading