[python] Use OutputDataWithValue type for function calls
This extends #605, where I converted model-specific response types --> pure strings. Now I'm also supporting function calls --> OutputData types.

Same as the last diff (#610), except now for Python files instead of TypeScript.

We don't need to do this for the image outputs since I already did that in #608.

As with TypeScript, the only model parser we need to support for function calling is `openai.py`.
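
For context, here is a minimal sketch of the pattern this commit applies across the Python model parsers' `get_output_text` methods: string `data` is returned as-is, while `OutputDataWithValue` payloads fall back to their `value` field. `output_to_text` is an illustrative helper name (not part of the library); `ExecuteResult`, `OutputDataWithValue`, and their fields follow the diff below.

import json

from aiconfig.schema import ExecuteResult, OutputDataWithValue


def output_to_text(output: ExecuteResult) -> str:
    # Mirrors the get_output_text logic in this commit (illustrative helper, not library API).
    if output.output_type != "execute_result":
        return ""
    output_data = output.data
    if isinstance(output_data, str):
        return output_data
    if isinstance(output_data, OutputDataWithValue):
        if isinstance(output_data.value, str):
            return output_data.value
        # Non-string values (e.g. function/tool calls) are serialized as JSON
        return json.dumps(output_data.value, indent=2)
    return ""
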
Rossdan Craig [email protected] committed Dec 28, 2023
1 parent f35d6a4 commit f1248b5
Showing 9 changed files with 347 additions and 155 deletions.
44 changes: 31 additions & 13 deletions extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -1,17 +1,27 @@
# Define a Model Parser for Gemini
from typing import TYPE_CHECKING, Dict, List, Optional, Any
import copy
import json

import google.generativeai as genai
import copy
from google.generativeai.types import content_types
from google.protobuf.json_format import MessageToDict

from aiconfig import (
AIConfigRuntime,
CallbackEvent,
get_api_key_from_environment,
)
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import ExecuteResult, Output, Prompt
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptInput,
)
from aiconfig.util.params import resolve_prompt, resolve_prompt_string
from aiconfig import CallbackEvent, get_api_key_from_environment, AIConfigRuntime, PromptMetadata, PromptInput
from google.generativeai.types import content_types
from google.protobuf.json_format import MessageToDict

# Circular Dependency Type Hints
if TYPE_CHECKING:
@@ -208,7 +218,6 @@ async def serialize(
user_message_parts = user_message["parts"]
outputs = []
if i + 1 < len(contents):
# TODO (rossdanlm): Support function calls
model_message = contents[i + 1]
model_message_parts = model_message["parts"]
# Gemini API currently only supports one candidate, aka one output. The model should only be returning one part in the response.
@@ -218,8 +227,8 @@
ExecuteResult(
**{
"output_type": "execute_result",
"data": model_message_parts[0],
"metadata": {"rawResponse": model_message}
"data": model_message_parts[0],
"metadata": {"rawResponse": model_message},
}
)
]
@@ -359,11 +368,20 @@ def get_output_text(
output_data = output.data
if isinstance(output_data, str):
return output_data
else:
raise ValueError("Not Implemented")

else:
return ""
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value

# We use this method in the model parser functionality itself,
# so we must error to let the user know non-string
# formats are not supported
error_message = \
"""
We currently only support chats where each prior user message is a
single string instance; please see the docstring for more details:
https://github.com/lastmile-ai/aiconfig/blob/v1.1.8/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py#L31-L33
"""
raise ValueError(error_message)

def _construct_chat_history(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
@@ -1,11 +1,23 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import AutoTokenizer, Pipeline, pipeline, TextIteratorStreamer
import json
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import (
AutoTokenizer,
Pipeline,
pipeline,
TextIteratorStreamer,
)

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
@@ -101,19 +113,21 @@ def construct_stream_output(
Constructs the output for a stream response.
Args:
streamer (TextIteratorStreamer): Streams the output. See https://huggingface.co/docs/transformers/v4.35.2/en/internal/generation_utils#transformers.TextIteratorStreamer
options (InferenceOptions): The inference options. Used to determine the stream callback.
streamer (TextIteratorStreamer): Streams the output. See:
https://huggingface.co/docs/transformers/v4.35.2/en/internal/generation_utils#transformers.TextIteratorStreamer
options (InferenceOptions): The inference options. Used to determine
the stream callback.
"""
accumulated_message = ""
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": accumulated_message,
"data": "", # We update this below
"execution_count": 0, #Multiple outputs are not supported for streaming
"metadata": {},
}
)
accumulated_message = ""
for new_text in streamer:
accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
@@ -283,7 +297,13 @@ def get_output_text(
# TODO (rossdanlm): Handle multiple outputs in list
# https://github.com/lastmile-ai/aiconfig/issues/467
if output.output_type == "execute_result":
if isinstance(output.data, str):
return output.data
else:
return ""
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# HuggingFace Text generation does not support function
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""
@@ -1,9 +1,6 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.util.config_utils import get_api_key_from_environment
import json
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

# HuggingFace API imports
from huggingface_hub import InferenceClient
@@ -12,13 +9,19 @@
TextGenerationStreamResponse,
)

from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig import CallbackEvent

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt

# ModelParser Utils
# Type hint imports

# Circular Dependency Type Hints
if TYPE_CHECKING:
@@ -63,7 +66,7 @@ def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, A


def construct_stream_output(
response: TextGenerationStreamResponse,
response: Union[Iterable[TextGenerationStreamResponse], Iterable[str]],
response_includes_details: bool,
options: InferenceOptions,
) -> Output:
@@ -79,26 +82,24 @@
accumulated_message = ""
for iteration in response:
metadata = {}
data = iteration
# If response_includes_details is false, `iteration` will be a string,
# otherwise, `iteration` is a TextGenerationStreamResponse
new_text = iteration
if response_includes_details:
iteration: TextGenerationStreamResponse
data = iteration.token.text
new_text = iteration.token.text
metadata = {"token": iteration.token, "details": iteration.details}
else:
data: str

# Reduce
accumulated_message += data
accumulated_message += new_text

index = 0 # HF Text Generation api doesn't support multiple outputs
delta = data
if options and options.stream_callback:
options.stream_callback(delta, accumulated_message, index)

options.stream_callback(new_text, accumulated_message, index)
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": copy.deepcopy(accumulated_message),
"data": accumulated_message,
"execution_count": index,
"metadata": metadata,
}
@@ -316,7 +317,13 @@ def get_output_text(
return ""

if output.output_type == "execute_result":
if isinstance(output.data, str):
return output.data
else:
return ""
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# HuggingFace Text generation does not support function
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""
@@ -1,26 +1,29 @@
# Define a Model Parser for LLama-Guard

"""Define a Model Parser for LLama-Guard"""

import copy
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import AutoTokenizer

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from aiconfig import CallbackEvent
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.params import resolve_prompt
from aiconfig import CallbackEvent

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Circular Dependency Type Hints
if TYPE_CHECKING:
from aiconfig.Config import AIConfigRuntime




# Step 1: define Helpers
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -242,8 +245,7 @@ async def run_inference(
output_text = self.tokenizer.decode(
response[0][prompt_len:], skip_special_tokens=True
)

Output = ExecuteResult(
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": output_text,
@@ -252,7 +254,7 @@
}
)

prompt.outputs = [Output]
prompt.outputs = [output]
return prompt.outputs

def get_output_text(
@@ -268,7 +270,13 @@ def get_output_text(
return ""

if output.output_type == "execute_result":
if isinstance(output.data, str):
return output.data
else:
return ""
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# LlamaGuard does not support function
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""
53 changes: 36 additions & 17 deletions extensions/llama/python/llama.py
@@ -1,14 +1,18 @@
import json
from typing import Any, List

from aiconfig.Config import AIConfigRuntime
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
ExecuteResult,
OutputDataWithValue,
Output,
Prompt,
)
from aiconfig.util.params import resolve_prompt
from llama_cpp import Llama

from aiconfig import Output, Prompt
from aiconfig.schema import ExecuteResult


class LlamaModelParser(ParameterizedModelParser):
def __init__(self, model_path: str) -> None:
@@ -97,23 +101,38 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]:

return ExecuteResult(
output_type="execute_result",
data=texts[-1],
# TODO: Map all text responses to multiple outputs
# This would be part of a large refactor:
# https://github.com/lastmile-ai/aiconfig/issues/630
data="\n".join(texts),
metadata={}
)

def get_output_text(
self, prompt: Prompt, aiconfig: AIConfigRuntime, output: Output | None = None
) -> str:
match output:
case ExecuteResult(data=d):
if isinstance(d, str):
return d

# Doing this to be backwards-compatible with old output format
# where we used to save a list in output.data
# See https://github.com/lastmile-ai/aiconfig/issues/630
elif isinstance(d, list):
return d
return ""
case _:
raise ValueError(f"Unexpected output type: {type(output)}")
if not output:
output = aiconfig.get_latest_output(prompt)

if not output:
return ""

if output.output_type == "execute_result":
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# llama model parser does not support function
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)

# Doing this to be backwards-compatible with old output format
# where we used to save a list in output.data
# See https://github.com/lastmile-ai/aiconfig/issues/630
if isinstance(output_data, list):
# TODO: This is an error since list is not compatible with str type
return output_data
return ""
raise ValueError(f"Output is an unexpected output type: {type(output)}")
