[python] Use OutputDataWithValue type for function calls
This extends #605, where I converted model-specific response types --> pure strings. Now I'm also supporting function calls --> OutputDataWithValue types

Same as the last diff (#610), except now for Python files instead of TypeScript

No need to do this for the image outputs since I already did that in #608

As with TypeScript, the only model parser that needs function-calling support is `openai.py`
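
For reference, every get_output_text change below follows the same shape. A consolidated sketch of that pattern (the helper name output_to_text is illustrative; each parser inlines this logic rather than sharing a helper):

import json

from aiconfig.schema import Output, OutputDataWithValue


def output_to_text(output: Output) -> str:
    # Sketch of the branching each parser's get_output_text now performs.
    if output.output_type != "execute_result":
        return ""
    output_data = output.data
    if isinstance(output_data, str):
        return output_data
    if isinstance(output_data, OutputDataWithValue):
        if isinstance(output_data.value, str):
            return output_data.value
        # Non-string values (function-call payloads from openai.py)
        # are rendered as readable JSON.
        return json.dumps(output_data.value, indent=2)
    raise ValueError("Not Implemented")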
Rossdan Craig [email protected] committed Dec 27, 2023
1 parent f35d6a4 commit 77ce9ec
Showing 9 changed files with 329 additions and 141 deletions.
37 changes: 25 additions & 12 deletions extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -1,17 +1,27 @@
 # Define a Model Parser for LLama-Guard
 from typing import TYPE_CHECKING, Dict, List, Optional, Any
+import copy
+import json

 import google.generativeai as genai
-import copy
+from google.generativeai.types import content_types
+from google.protobuf.json_format import MessageToDict

+from aiconfig import (
+    AIConfigRuntime,
+    CallbackEvent,
+    get_api_key_from_environment,
+)
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
-from aiconfig.schema import ExecuteResult, Output, Prompt
+from aiconfig.schema import (
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
+    PromptInput,
+)
 from aiconfig.util.params import resolve_prompt, resolve_prompt_string
-from aiconfig import CallbackEvent, get_api_key_from_environment, AIConfigRuntime, PromptMetadata, PromptInput
-from google.generativeai.types import content_types
-from google.protobuf.json_format import MessageToDict

 # Circuluar Dependency Type Hints
 if TYPE_CHECKING:
@@ -218,8 +228,8 @@ async def serialize(
                 ExecuteResult(
                     **{
                         "output_type": "execute_result",
-                        "data": model_message_parts[0],
-                        "metadata": {"rawResponse": model_message}
+                        "data": model_message_parts[0],
+                        "metadata": {"rawResponse": model_message},
                     }
                 )
             ]
@@ -359,11 +369,14 @@ def get_output_text(
             output_data = output.data
             if isinstance(output_data, str):
                 return output_data
-            else:
-                raise ValueError("Not Implemented")
-
-        else:
-            return ""
+            if isinstance(output_data, OutputDataWithValue):
+                if isinstance(output_data.value, str):
+                    return output_data.value
+                # Gemini does not support function calls so shouldn't
+                # get here, but just being safe
+                return json.dumps(output_data.value, indent=2)
+            raise ValueError("Not Implemented")
+        return ""

     def _construct_chat_history(
         self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
@@ -1,11 +1,23 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-from transformers import AutoTokenizer, Pipeline, pipeline, TextIteratorStreamer
+import json
+import threading
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from transformers import (
+    AutoTokenizer,
+    Pipeline,
+    pipeline,
+    TextIteratorStreamer,
+)

 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
-from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
+from aiconfig.schema import (
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
+    PromptMetadata,
+)
 from aiconfig.util.params import resolve_prompt

 # Circuluar Dependency Type Hints
@@ -105,19 +117,20 @@ def construct_stream_output(
         options (InferenceOptions): The inference options. Used to determine the stream callback.
     """
-    accumulated_message = ""
     output = ExecuteResult(
         **{
             "output_type": "execute_result",
-            "data": accumulated_message,
+            "data": "",  # We update this below
             "execution_count": 0,  # Multiple outputs are not supported for streaming
             "metadata": {},
         }
     )
+    accumulated_message = ""
     for new_text in streamer:
-        accumulated_message += new_text
-        options.stream_callback(new_text, accumulated_message, 0)
-        output.data = accumulated_message
+        if isinstance(new_text, str):
+            accumulated_message += new_text
+            options.stream_callback(new_text, accumulated_message, 0)
+            output.data = accumulated_message
     return output

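For context, here is a minimal consumer of the streaming callback above. It assumes InferenceOptions accepts a stream_callback with the (new_text, accumulated_message, index) signature used in the loop; print_stream is a hypothetical helper:

from aiconfig.model_parser import InferenceOptions


def print_stream(new_text: str, accumulated_message: str, index: int) -> None:
    # index is always 0 here since these parsers stream a single output.
    print(new_text, end="", flush=True)


options = InferenceOptions(stream_callback=print_stream)
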
@@ -283,7 +296,14 @@ def get_output_text(
         # TODO (rossdanlm): Handle multiple outputs in list
         # https://github.com/lastmile-ai/aiconfig/issues/467
         if output.output_type == "execute_result":
-            if isinstance(output.data, str):
-                return output.data
-            else:
-                return ""
+            output_data = output.data
+            if isinstance(output_data, str):
+                return output_data
+            if isinstance(output_data, OutputDataWithValue):
+                if isinstance(output_data.value, str):
+                    return output_data.value
+                # HuggingFace Text generation does not support function
+                # calls so shouldn't get here, but just being safe
+                return json.dumps(output_data.value, indent=2)
+            raise ValueError("Not Implemented")
+        return ""
@@ -1,9 +1,6 @@
-import copy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-
-from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
-from aiconfig.model_parser import InferenceOptions
-from aiconfig.util.config_utils import get_api_key_from_environment
+import json
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional

 # HuggingFace API imports
 from huggingface_hub import InferenceClient
@@ -12,13 +9,19 @@
     TextGenerationStreamResponse,
 )

-from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
 from aiconfig import CallbackEvent

+from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
+from aiconfig.model_parser import InferenceOptions
+from aiconfig.schema import (
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
+    PromptMetadata,
+)
+from aiconfig.util.config_utils import get_api_key_from_environment
 from aiconfig.util.params import resolve_prompt

-# ModelParser Utils
-# Type hint imports
-
 # Circuluar Dependency Type Hints
 if TYPE_CHECKING:
@@ -63,7 +66,7 @@ def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, A


 def construct_stream_output(
-    response: TextGenerationStreamResponse,
+    response: Iterable[TextGenerationStreamResponse],
     response_includes_details: bool,
     options: InferenceOptions,
 ) -> Output:
@@ -79,26 +82,25 @@ def construct_stream_output(
     accumulated_message = ""
     for iteration in response:
         metadata = {}
-        data = iteration
+        # If response_includes_details is false, `iteration` will be a string,
+        # otherwise, `iteration` is a TextGenerationStreamResponse
+        new_text = iteration
         if response_includes_details:
-            iteration: TextGenerationStreamResponse
-            data = iteration.token.text
+            new_text = iteration.token.text
             metadata = {"token": iteration.token, "details": iteration.details}
-        else:
-            data: str

         # Reduce
-        accumulated_message += data
+        accumulated_message += new_text

         index = 0  # HF Text Generation api doesn't support multiple outputs
-        delta = data
         if options and options.stream_callback:
-            options.stream_callback(delta, accumulated_message, index)
+            options.stream_callback(new_text, accumulated_message, index)

     output = ExecuteResult(
         **{
             "output_type": "execute_result",
-            "data": copy.deepcopy(accumulated_message),
+            "data": accumulated_message,
             "execution_count": index,
             "metadata": metadata,
         }
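
For context on where response comes from: huggingface_hub's InferenceClient.text_generation with stream=True yields plain strings when details=False, and TextGenerationStreamResponse items when details=True, which is what the response_includes_details branches above handle. A sketch, with an illustrative model id and prompt:

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
# stream=True returns an iterator; details=True wraps each token in a
# TextGenerationStreamResponse instead of a bare string.
stream = client.text_generation("Tell me a joke.", stream=True, details=True)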
@@ -316,7 +318,14 @@ def get_output_text(
             return ""

         if output.output_type == "execute_result":
-            if isinstance(output.data, str):
-                return output.data
-            else:
-                return ""
+            output_data = output.data
+            if isinstance(output_data, str):
+                return output_data
+            if isinstance(output_data, OutputDataWithValue):
+                if isinstance(output_data.value, str):
+                    return output_data.value
+                # HuggingFace Text generation does not support function
+                # calls so shouldn't get here, but just being safe
+                return json.dumps(output_data.value, indent=2)
+            raise ValueError("Not Implemented")
+        return ""
@@ -1,26 +1,29 @@
-# Define a Model Parser for LLama-Guard
-
+"""Define a Model Parser for LLama-Guard"""
+
 import copy
+import json
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
-from transformers import AutoTokenizer

+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+from aiconfig import CallbackEvent
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
-from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
+from aiconfig.schema import (
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
+    PromptMetadata,
+)
 from aiconfig.util.params import resolve_prompt
-from aiconfig import CallbackEvent
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch

 # Circuluar Dependency Type Hints
 if TYPE_CHECKING:
     from aiconfig.Config import AIConfigRuntime


-
-
 # Step 1: define Helpers
 def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
     """
@@ -242,17 +245,21 @@ async def run_inference(
         output_text = self.tokenizer.decode(
             response[0][prompt_len:], skip_special_tokens=True
         )
-
-        Output = ExecuteResult(
+        output_data_content: str = ""
+        if isinstance(output_text, str):
+            output_data_content = output_text
+        else:
+            raise ValueError(f"Output {output_text} needs to be of type 'str' but is of type: {type(output_text)}")
+        output = ExecuteResult(
             **{
                 "output_type": "execute_result",
-                "data": output_text,
+                "data": output_data_content,
                 "execution_count": 0,
                 "metadata": {},
             }
         )

-        prompt.outputs = [Output]
+        prompt.outputs = [output]
         return prompt.outputs

     def get_output_text(
@@ -268,7 +275,13 @@ def get_output_text(
             return ""

         if output.output_type == "execute_result":
-            if isinstance(output.data, str):
-                return output.data
-            else:
-                return ""
+            output_data = output.data
+            if isinstance(output_data, str):
+                return output_data
+            if isinstance(output_data, OutputDataWithValue):
+                if isinstance(output_data.value, str):
+                    return output_data.value
+                # LlamaGuard does not support function
+                # calls so shouldn't get here, but just being safe
+                return json.dumps(output_data.value, indent=2)
+        return ""
