rawResponse -> raw_response for snake_case consistency
Ryan Holinshead committed Dec 28, 2023
1 parent 741ce9c commit 32b36ce
Showing 11 changed files with 148 additions and 282 deletions.
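
The commit renames the metadata key that carries the provider's raw response on each output from camelCase to snake_case. A minimal before/after sketch, assuming the `ExecuteResult` import path and using placeholder data (neither appears verbatim in this diff):

```
from aiconfig.schema import ExecuteResult  # import path assumed, not shown in this diff

response = {"text": "Hello!"}  # placeholder standing in for a provider SDK response object

# Before: metadata={"rawResponse": response}
# After this commit:
output = ExecuteResult(
    output_type="execute_result",
    data="Hello!",
    metadata={"raw_response": response},
)
```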
111 changes: 43 additions & 68 deletions extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -71,9 +71,7 @@ def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[
return output_list


async def construct_stream_outputs(
response: "AsyncGenerateContentResponse", options: InferenceOptions
) -> list[Output]:
async def construct_stream_outputs(response: "AsyncGenerateContentResponse", options: InferenceOptions) -> list[Output]:
"""
Construct Outputs while also streaming the response with stream callback
@@ -132,17 +130,17 @@ async def serialize(
The data passed in is the completion params a user would use to call the Gemini API directly.
If the user wanted to call the Gemini API directly, they might do something like this:
```
model = genai.GenerativeModel('gemini-pro')
completion_params = {"contents": "Hello"}
model.generate_content(**completion_params)
# Note: The above line is the same as doing this:
# Note: The above line is the same as doing this:
model.generate_content(contents="Hello")
```
* Important: The contents field is what contains the input data. In this case, prompt input would be the contents field.
Args:
prompt (str): The prompt to be serialized.
@@ -152,7 +150,7 @@ async def serialize(
str: Serialized representation of the prompt and inference settings.
Sample Usage:
1.
1.
completion_params = {"contents": "Hello"}
serialized_prompts = await ai_config.serialize("prompt", completion_params, "gemini-pro")
@@ -188,7 +186,7 @@ async def serialize(
data = copy.deepcopy(data)
contents = data.pop("contents", None)

model_name = self.model.model_name[len("models/"):]
model_name = self.model.model_name[len("models/") :]
model_metadata = ai_config.get_model_metadata(data, model_name)

prompts = []
@@ -202,7 +200,9 @@ async def serialize(
# }
contents_is_role_dict = isinstance(contents, dict) and "role" in contents and "parts"
# Multi Turn means that the contents is a list of dicts with alternating role and parts. See for more info: https://ai.google.dev/tutorials/python_quickstart#multi-turn_conversations
contents_is_multi_turn = isinstance(contents, list) and all(isinstance(item, dict) and "role" in item and "parts" in item for item in contents)
contents_is_multi_turn = isinstance(contents, list) and all(
isinstance(item, dict) and "role" in item and "parts" in item for item in contents
)

if contents is None:
raise ValueError("No contents found in data. Gemini api request requires a contents field")
@@ -226,14 +226,21 @@ async def serialize(
outputs = [
ExecuteResult(
**{
"output_type": "execute_result",
"output_type": "execute_result",
"data": model_message_parts[0],
"metadata": {"rawResponse": model_message},
"metadata": {"raw_response": model_message},
}
)
]
i += 1
prompt = Prompt(**{"name": f'{prompt_name}_{len(prompts) + 1}', "input": user_message_parts, "metadata": {"model": model_metadata}, "outputs": outputs})
prompt = Prompt(
**{
"name": f"{prompt_name}_{len(prompts) + 1}",
"input": user_message_parts,
"metadata": {"model": model_metadata},
"outputs": outputs,
}
)
prompts.append(prompt)
i += 1
else:
@@ -255,11 +262,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
Returns:
dict: Model-specific completion parameters.
"""
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_start", __name__, {"prompt": prompt, "params": params}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
@@ -270,9 +273,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
messages = self._construct_chat_history(prompt, aiconfig, params)

resolved_prompt = resolve_prompt(prompt, params, aiconfig)

messages.append({"role": "user", "parts": [{"text": resolved_prompt}]})

completion_data["contents"] = messages
else:
# If contents is already set, do not construct chat history. TODO: @Ankush-lastmile
@@ -287,15 +290,13 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# This is checking attributes and not a dict like object. in schema.py, PromptInput allows arbitrary attributes/data, and gets serialized as an attribute because it is a pydantic type
if not hasattr(prompt_input, "contents"):
# The source code show cases this more than the docs. This curl request docs similar to python sdk: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#request_body
raise ValueError("Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API")
raise ValueError(
"Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API"
)

completion_data['contents'] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)
completion_data["contents"] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)

await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_complete", __name__, {"output": completion_data}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
return completion_data

async def run_inference(
@@ -347,9 +348,7 @@ async def run_inference(
outputs = construct_regular_outputs(response)

prompt.outputs = outputs
await aiconfig.callback_manager.run_callbacks(
CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
return prompt.outputs

def get_output_text(
@@ -375,27 +374,21 @@ def get_output_text(
# We use this method in the model parser functionality itself,
# so we must error to let the user know non-string
# formats are not supported
error_message = \
"""
error_message = """
We currently only support chats where prior messages from a user are only a
single message string instance, please see docstring for more details:
https://github.com/lastmile-ai/aiconfig/blob/v1.1.8/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py#L31-L33
"""
raise ValueError(error_message)

def _construct_chat_history(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
) -> List:
def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict) -> List:
"""
Constructs the chat history for the model
"""
messages = []
# Default to always use chat context
remember_chat_context = not hasattr(
prompt.metadata, "remember_chat_context"
) or (
hasattr(prompt.metadata, "remember_chat_context")
and prompt.metadata.remember_chat_context != False
remember_chat_context = not hasattr(prompt.metadata, "remember_chat_context") or (
hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
)
if remember_chat_context:
# handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
@@ -404,29 +397,21 @@ def _construct_chat_history(
if previous_prompt.name == prompt.name:
break

previous_prompt_is_same_model = aiconfig.get_model_name(
previous_prompt
) == aiconfig.get_model_name(prompt)
previous_prompt_is_same_model = aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt)
if previous_prompt_is_same_model:
previous_prompt_template = resolve_prompt(
previous_prompt, params, aiconfig
)
previous_prompt_template = resolve_prompt(previous_prompt, params, aiconfig)
previous_prompt_output = aiconfig.get_latest_output(previous_prompt)
previous_prompt_output_text = self.get_output_text(
previous_prompt, aiconfig, previous_prompt_output
)
previous_prompt_output_text = self.get_output_text(previous_prompt, aiconfig, previous_prompt_output)

messages.append(
{"role": "user", "parts": [{"text": previous_prompt_template}]}
)
messages.append({"role": "user", "parts": [{"text": previous_prompt_template}]})
messages.append(
{
"role": "model",
"parts": [{"text": previous_prompt_output_text}],
}
)

return messages
return messages

def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str:
"""
@@ -445,26 +430,18 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st
elif isinstance(contents, list):
return " ".join(contents)
elif isinstance(contents, dict):
parts= contents["parts"]
parts = contents["parts"]
if isinstance(parts, str):
return parts
elif isinstance(parts, list):
return " ".join(parts)
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")



else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")


def refine_chat_completion_params(model_settings):
@@ -522,9 +499,7 @@ def contains_prompt_template(prompt: Prompt):
"""
Check if a prompt's input is a valid string.
"""
return isinstance(prompt.input, str) or (
hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)
)
return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str))


AIConfigRuntime.register_model_parser(GeminiModelParser("gemini-pro"), "gemini-pro")
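
With the rename in place, anything that inspects an output's metadata must use the snake_case key. A small sketch of reading it back, assuming `output.metadata` is a plain dict as constructed above; `get_latest_output` is the same method the parser uses when building chat history:

```
def raw_response_from_latest_output(aiconfig, prompt):
    """Sketch: return the provider payload stored on the newest output for `prompt`."""
    output = aiconfig.get_latest_output(prompt)
    if output is None:
        return None
    return (output.metadata or {}).get("raw_response")  # was "rawResponse" before this commit
```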
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
2 changes: 1 addition & 1 deletion extensions/HuggingFace/typescript/hf.ts
@@ -311,7 +311,7 @@ function constructOutput(response: TextGenerationOutput): Output {
output_type: "execute_result",
data: response.generated_text,
execution_count: 0,
metadata: { rawResponse: response },
metadata: { raw_response: response },
} as ExecuteResult;

return output;
18 changes: 7 additions & 11 deletions python/src/aiconfig/default_parsers/hf.py
@@ -109,14 +109,14 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

output = ExecuteResult(
**{
"output_type": "execute_result",
"data": response.generated_text or '',
"data": response.generated_text or "",
"execution_count": 0,
"metadata": metadata,
}
@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
Expand Down Expand Up @@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
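
Because this is a pure key rename, outputs serialized before the change (or callers written against the old key) still carry `rawResponse`. A hypothetical migration helper, not part of this commit, could tolerate both keys while consumers catch up:

```
from typing import Any, Dict, Optional


def get_raw_response(metadata: Optional[Dict[str, Any]]) -> Optional[Any]:
    """Hypothetical helper: prefer the new snake_case key, fall back to the old camelCase one."""
    metadata = metadata or {}
    return metadata.get("raw_response", metadata.get("rawResponse"))
```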