rawResponse -> raw_response for snake_case consistency #657

Merged 1 commit on Dec 29, 2023
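Summary of the change: every model parser touched here (the Gemini Python extension, the HuggingFace Python parsers, and the HuggingFace TypeScript extension) now stores the provider's raw response under the output-metadata key `raw_response` instead of `rawResponse`, so output metadata keys consistently use snake_case. A minimal sketch of what that means for code that reads output metadata — the helper below and its camelCase fallback are illustrative only, not part of this PR:

```
# Illustrative sketch (not part of the diff): the behavioral change for
# consumers is the metadata key that carries the provider's raw response.
# The payloads below are stand-in dicts, not real model responses.

def read_raw_response(output_metadata: dict):
    """Prefer the new snake_case key, falling back to the old camelCase one."""
    if "raw_response" in output_metadata:
        return output_metadata["raw_response"]
    return output_metadata.get("rawResponse")  # key used before this PR

old_style = {"rawResponse": {"text": "Hello from the model"}}   # pre-#657 output
new_style = {"raw_response": {"text": "Hello from the model"}}  # post-#657 output

assert read_raw_response(old_style) == read_raw_response(new_style)
```

Based on the diff below, only the sites that construct outputs change; a fallback like the one sketched above would only matter for code that still reads outputs written before the rename.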
111 changes: 43 additions & 68 deletions extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -71,9 +71,7 @@ def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[
return output_list


async def construct_stream_outputs(
response: "AsyncGenerateContentResponse", options: InferenceOptions
) -> list[Output]:
async def construct_stream_outputs(response: "AsyncGenerateContentResponse", options: InferenceOptions) -> list[Output]:
"""
Construct Outputs while also streaming the response with stream callback

@@ -132,17 +130,17 @@ async def serialize(

The data passed in is the completion params a user would use to call the Gemini API directly.
If the user wanted to call the Gemini API directly, they might do something like this:

```
model = genai.GenerativeModel('gemini-pro')
completion_params = {"contents": "Hello"}

model.generate_content(**completion_params)
# Note: The above line is the same as doing this:
# Note: The above line is the same as doing this:
model.generate_content(contents="Hello")
```
* Important: The contents field is what contains the input data. In this case, prompt input would be the contents field.


Args:
prompt (str): The prompt to be serialized.
@@ -152,7 +150,7 @@ async def serialize(
str: Serialized representation of the prompt and inference settings.

Sample Usage:
1.
1.
completion_params = {"contents": "Hello"}
serialized_prompts = await ai_config.serialize("prompt", completion_params, "gemini-pro")

@@ -188,7 +186,7 @@ async def serialize(
data = copy.deepcopy(data)
contents = data.pop("contents", None)

model_name = self.model.model_name[len("models/"):]
model_name = self.model.model_name[len("models/") :]
model_metadata = ai_config.get_model_metadata(data, model_name)

prompts = []
@@ -202,7 +200,9 @@ async def serialize(
# }
contents_is_role_dict = isinstance(contents, dict) and "role" in contents and "parts"
# Multi Turn means that the contents is a list of dicts with alternating role and parts. See for more info: https://ai.google.dev/tutorials/python_quickstart#multi-turn_conversations
contents_is_multi_turn = isinstance(contents, list) and all(isinstance(item, dict) and "role" in item and "parts" in item for item in contents)
contents_is_multi_turn = isinstance(contents, list) and all(
isinstance(item, dict) and "role" in item and "parts" in item for item in contents
)

if contents is None:
raise ValueError("No contents found in data. Gemini api request requires a contents field")
@@ -226,14 +226,21 @@ async def serialize(
outputs = [
ExecuteResult(
**{
"output_type": "execute_result",
"output_type": "execute_result",
"data": model_message_parts[0],
"metadata": {"rawResponse": model_message},
"metadata": {"raw_response": model_message},
}
)
]
i += 1
prompt = Prompt(**{"name": f'{prompt_name}_{len(prompts) + 1}', "input": user_message_parts, "metadata": {"model": model_metadata}, "outputs": outputs})
prompt = Prompt(
**{
"name": f"{prompt_name}_{len(prompts) + 1}",
"input": user_message_parts,
"metadata": {"model": model_metadata},
"outputs": outputs,
}
)
prompts.append(prompt)
i += 1
else:
@@ -255,11 +262,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
Returns:
dict: Model-specific completion parameters.
"""
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_start", __name__, {"prompt": prompt, "params": params}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
@@ -270,9 +273,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
messages = self._construct_chat_history(prompt, aiconfig, params)

resolved_prompt = resolve_prompt(prompt, params, aiconfig)

messages.append({"role": "user", "parts": [{"text": resolved_prompt}]})

completion_data["contents"] = messages
else:
# If contents is already set, do not construct chat history. TODO: @Ankush-lastmile
@@ -287,15 +290,13 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# This is checking attributes and not a dict like object. in schema.py, PromptInput allows arbitrary attributes/data, and gets serialized as an attribute because it is a pydantic type
if not hasattr(prompt_input, "contents"):
# The source code show cases this more than the docs. This curl request docs similar to python sdk: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#request_body
raise ValueError("Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API")
raise ValueError(
"Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API"
)

completion_data['contents'] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)
completion_data["contents"] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)

await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_complete", __name__, {"output": completion_data}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
return completion_data

async def run_inference(
@@ -347,9 +348,7 @@ async def run_inference(
outputs = construct_regular_outputs(response)

prompt.outputs = outputs
await aiconfig.callback_manager.run_callbacks(
CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
return prompt.outputs

def get_output_text(
@@ -375,27 +374,21 @@ def get_output_text(
# We use this method in the model parser functionality itself,
# so we must error to let the user know non-string
# formats are not supported
error_message = \
"""
error_message = """
We currently only support chats where prior messages from a user are only a
single message string instance, please see docstring for more details:
https://github.com/lastmile-ai/aiconfig/blob/v1.1.8/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py#L31-L33
"""
raise ValueError(error_message)

def _construct_chat_history(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
) -> List:
def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict) -> List:
"""
Constructs the chat history for the model
"""
messages = []
# Default to always use chat context
remember_chat_context = not hasattr(
prompt.metadata, "remember_chat_context"
) or (
hasattr(prompt.metadata, "remember_chat_context")
and prompt.metadata.remember_chat_context != False
remember_chat_context = not hasattr(prompt.metadata, "remember_chat_context") or (
hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
)
if remember_chat_context:
# handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
@@ -404,29 +397,21 @@ def _construct_chat_history(
if previous_prompt.name == prompt.name:
break

previous_prompt_is_same_model = aiconfig.get_model_name(
previous_prompt
) == aiconfig.get_model_name(prompt)
previous_prompt_is_same_model = aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt)
if previous_prompt_is_same_model:
previous_prompt_template = resolve_prompt(
previous_prompt, params, aiconfig
)
previous_prompt_template = resolve_prompt(previous_prompt, params, aiconfig)
previous_prompt_output = aiconfig.get_latest_output(previous_prompt)
previous_prompt_output_text = self.get_output_text(
previous_prompt, aiconfig, previous_prompt_output
)
previous_prompt_output_text = self.get_output_text(previous_prompt, aiconfig, previous_prompt_output)

messages.append(
{"role": "user", "parts": [{"text": previous_prompt_template}]}
)
messages.append({"role": "user", "parts": [{"text": previous_prompt_template}]})
messages.append(
{
"role": "model",
"parts": [{"text": previous_prompt_output_text}],
}
)

return messages
return messages

def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str:
"""
@@ -445,26 +430,18 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st
elif isinstance(contents, list):
return " ".join(contents)
elif isinstance(contents, dict):
parts= contents["parts"]
parts = contents["parts"]
if isinstance(parts, str):
return parts
elif isinstance(parts, list):
return " ".join(parts)
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")



else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")


def refine_chat_completion_params(model_settings):
@@ -522,9 +499,7 @@ def contains_prompt_template(prompt: Prompt):
"""
Check if a prompt's input is a valid string.
"""
return isinstance(prompt.input, str) or (
hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)
)
return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str))


AIConfigRuntime.register_model_parser(GeminiModelParser("gemini-pro"), "gemini-pro")
@@ -109,7 +109,7 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
2 changes: 1 addition & 1 deletion extensions/HuggingFace/typescript/hf.ts
@@ -311,7 +311,7 @@ function constructOutput(response: TextGenerationOutput): Output {
output_type: "execute_result",
data: response.generated_text,
execution_count: 0,
metadata: { rawResponse: response },
metadata: { raw_response: response },
} as ExecuteResult;

return output;
18 changes: 7 additions & 11 deletions python/src/aiconfig/default_parsers/hf.py
@@ -109,14 +109,14 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

output = ExecuteResult(
**{
"output_type": "execute_result",
"data": response.generated_text or '',
"data": response.generated_text or "",
"execution_count": 0,
"metadata": metadata,
}
@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
Expand Down Expand Up @@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.