From c8155a885ab16bdbef867d8c87709602f306de4f Mon Sep 17 00:00:00 2001 From: Jonathan Lessinger Date: Tue, 23 Jan 2024 17:08:03 -0500 Subject: [PATCH] [RFC][wip] make room for batch eval This moves the existing eval library to "test_suite_eval" and starts the equivalent for batch runs. Also makes the interface a little clearer. Essentially, the differences are: - each metric runs on a _list_ of inputs, not just one - each input can be paired with a reference. This is possible in the "test suite" setup, but it is clunkier. --- .vscode/settings.json | 4 +- .../Basic-Prompt-Routing/assistant_app.py | 12 +- .../Basic-Prompt-Routing/create_config.py | 11 +- cookbooks/Cli-Mate/cli-mate.py | 23 +- cookbooks/Gradio/hf_model_parsers.py | 31 +- cookbooks/HuggingFace/hf.py | 47 +- cookbooks/HuggingFace/python/hf.py | 43 +- cookbooks/Wizard-GPT/wizard-gpt.py | 4 +- cookbooks/llama/python/ask_llama.py | 12 +- .../src/aiconfig_extension_gemini/Gemini.py | 194 +++++-- .../__init__.py | 20 +- .../automatic_speech_recognition.py | 124 ++++- .../local_inference/image_2_text.py | 104 +++- .../local_inference/text_2_image.py | 82 ++- .../local_inference/text_2_speech.py | 76 ++- .../local_inference/text_generation.py | 62 ++- .../local_inference/text_summarization.py | 58 +- .../local_inference/text_translation.py | 58 +- .../local_inference/util.py | 12 +- .../text_generation.py | 75 ++- .../LLamaGuard.py | 46 +- extensions/llama/python/llama.py | 42 +- python/demo/function_calling_demo.py | 4 +- python/src/aiconfig/ChatCompletion.py | 76 ++- python/src/aiconfig/Config.py | 116 +++- python/src/aiconfig/__init__.py | 13 +- python/src/aiconfig/callback.py | 32 +- .../default_parsers/anyscale_endpoint.py | 55 +- python/src/aiconfig/default_parsers/dalle.py | 73 ++- python/src/aiconfig/default_parsers/hf.py | 75 ++- python/src/aiconfig/default_parsers/openai.py | 178 ++++-- python/src/aiconfig/default_parsers/palm.py | 120 ++++- .../parameterized_model_parser.py | 38 +- python/src/aiconfig/editor/server/server.py | 278 +++++++--- .../aiconfig/editor/server/server_utils.py | 142 +++-- python/src/aiconfig/eval/api/__init__.py | 16 +- python/src/aiconfig/eval/batch_common.py | 31 ++ python/src/aiconfig/eval/batch_lib.py | 402 ++++++++++++++ python/src/aiconfig/eval/batch_metrics.py | 80 +++ python/src/aiconfig/eval/common.py | 173 +++--- python/src/aiconfig/eval/openai.py | 48 +- python/src/aiconfig/eval/test_suite_common.py | 109 ++++ .../travel_aiconfig_test_suite_settings.json | 0 .../travel/travel_eval.ipynb | 506 +++++++++--------- .../travel/travel_parametrized.aiconfig.json | 0 .../travel/travel_promptfoo_config.yaml | 0 .../eval/{lib.py => test_suite_lib.py} | 217 +++++--- .../{metrics.py => test_suite_metrics.py} | 190 +++++-- python/src/aiconfig/model_parser.py | 31 +- python/src/aiconfig/registry.py | 26 +- python/src/aiconfig/schema.py | 203 +++++-- python/src/aiconfig/scripts/aiconfig_cli.py | 78 ++- python/src/aiconfig/scripts/rage/rage.py | 83 ++- python/src/aiconfig/scripts/run_aiconfig.py | 9 +- python/src/aiconfig/util/config_utils.py | 26 +- python/src/aiconfig/util/params.py | 58 +- python/tests/mocks.py | 14 +- python/tests/parsers/test_dalle_parser.py | 10 +- python/tests/parsers/test_openai_util.py | 78 ++- python/tests/parsers/test_parser.py | 31 +- python/tests/test_library_helpers.py | 4 +- python/tests/test_load_config.py | 31 +- python/tests/test_parameter_api.py | 21 +- ...est_programmatically_create_an_AIConfig.py | 233 ++++++-- python/tests/test_registry.py | 65 ++- python/tests/test_resolve.py | 11 +- python/tests/test_run_config.py | 10 +- .../{test_eval.py => test_test_suite_eval.py} | 162 ++++-- ...st_test_suite_eval_model_graded_openai.py} | 48 +- python/tests/test_util/test_params.py | 38 +- python/tests/util/file_path_utils.py | 4 +- 71 files changed, 4014 insertions(+), 1372 deletions(-) create mode 100644 python/src/aiconfig/eval/batch_common.py create mode 100644 python/src/aiconfig/eval/batch_lib.py create mode 100644 python/src/aiconfig/eval/batch_metrics.py create mode 100644 python/src/aiconfig/eval/test_suite_common.py rename python/src/aiconfig/eval/{examples => test_suite_examples}/travel/travel_aiconfig_test_suite_settings.json (100%) rename python/src/aiconfig/eval/{examples => test_suite_examples}/travel/travel_eval.ipynb (78%) rename python/src/aiconfig/eval/{examples => test_suite_examples}/travel/travel_parametrized.aiconfig.json (100%) rename python/src/aiconfig/eval/{examples => test_suite_examples}/travel/travel_promptfoo_config.yaml (100%) rename python/src/aiconfig/eval/{lib.py => test_suite_lib.py} (59%) rename python/src/aiconfig/eval/{metrics.py => test_suite_metrics.py} (64%) rename python/tests/{test_eval.py => test_test_suite_eval.py} (67%) rename python/tests/{test_eval_model_graded_openai.py => test_test_suite_eval_model_graded_openai.py} (66%) diff --git a/.vscode/settings.json b/.vscode/settings.json index 2b2401be1..42a2f995c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -25,9 +25,9 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true, - "editor.rulers": [150] + "editor.rulers": [79] }, - "black-formatter.args": ["--line-length=150"], + "black-formatter.args": ["--line-length=79"], // example: "--disable=C0114,C0115,C0116" "pylint.args": [] } diff --git a/cookbooks/Basic-Prompt-Routing/assistant_app.py b/cookbooks/Basic-Prompt-Routing/assistant_app.py index db4d4b562..cfcb1ac83 100644 --- a/cookbooks/Basic-Prompt-Routing/assistant_app.py +++ b/cookbooks/Basic-Prompt-Routing/assistant_app.py @@ -32,8 +32,12 @@ async def assistant_response(prompt): # Streamlit Setup st.title("AI Teaching Assistant") -st.markdown("Ask a math, physics, or general question. Based on your question, an AI math prof, physics prof, or general assistant will respond.") -st.markdown("**This is a simple demo of prompt routing - based on your question, an LLM decides which AI teacher responds.**") +st.markdown( + "Ask a math, physics, or general question. Based on your question, an AI math prof, physics prof, or general assistant will respond." +) +st.markdown( + "**This is a simple demo of prompt routing - based on your question, an LLM decides which AI teacher responds.**" +) # Chat setup if "messages" not in st.session_state: @@ -54,4 +58,6 @@ async def assistant_response(prompt): with st.chat_message("assistant"): st.markdown(response) - st.session_state.messages.append({"role": "assistant", "content": response}) + st.session_state.messages.append( + {"role": "assistant", "content": response} + ) diff --git a/cookbooks/Basic-Prompt-Routing/create_config.py b/cookbooks/Basic-Prompt-Routing/create_config.py index d72a8bcfa..1bf30bdd1 100644 --- a/cookbooks/Basic-Prompt-Routing/create_config.py +++ b/cookbooks/Basic-Prompt-Routing/create_config.py @@ -1,10 +1,17 @@ from aiconfig import AIConfigRuntime, Prompt -aiconfig = AIConfigRuntime.create("assistant_config", "teaching assistant config") +aiconfig = AIConfigRuntime.create( + "assistant_config", "teaching assistant config" +) # Set GPT-4 as default model from Teaching Assistant prompts model_name = "gpt-4" -model_settings = {"top_k": 40, "top_p": 1, "model": "gpt-4", "temperature": 0.0} +model_settings = { + "top_k": 40, + "top_p": 1, + "model": "gpt-4", + "temperature": 0.0, +} aiconfig.add_model(model_name, model_settings) diff --git a/cookbooks/Cli-Mate/cli-mate.py b/cookbooks/Cli-Mate/cli-mate.py index c7f9e0b35..b700b1185 100644 --- a/cookbooks/Cli-Mate/cli-mate.py +++ b/cookbooks/Cli-Mate/cli-mate.py @@ -53,13 +53,20 @@ async def query(aiconfig_path: str, question: str) -> list[ExecuteResult]: return result -async def get_mod_result(aiconfig_path: str, source_code: str, question: str) -> list[ExecuteResult]: +async def get_mod_result( + aiconfig_path: str, source_code: str, question: str +) -> list[ExecuteResult]: question_about_code = f"QUERY ABOUT SOURCE CODE:\n{question}\nSOURCE CODE:\n```{source_code}\n```" return await query(aiconfig_path, question_about_code) -async def mod_code(aiconfig_path: str, source_code_file: str, question: str, update_file: bool = False): +async def mod_code( + aiconfig_path: str, + source_code_file: str, + question: str, + update_file: bool = False, +): # read source code from file with open(source_code_file, "r", encoding="utf8") as file: source_code = file.read() @@ -93,7 +100,9 @@ def signal_handler(_: int, __: FrameType | None): i = 0 while True: try: - user_input = await event_loop.run_in_executor(None, session.prompt, "Query: [ctrl-D to exit] ") + user_input = await event_loop.run_in_executor( + None, session.prompt, "Query: [ctrl-D to exit] " + ) except KeyboardInterrupt: continue except EOFError: @@ -113,7 +122,9 @@ def signal_handler(_: int, __: FrameType | None): prompt = user_input # Dynamically generate the prompt name and prompt object - new_prompt_name = f"prompt{len(runtime.prompts)+1}" # Prompt{number of prompts} + new_prompt_name = ( + f"prompt{len(runtime.prompts)+1}" # Prompt{number of prompts} + ) new_prompt = Prompt(name=new_prompt_name, input=prompt) # Add the new prompt and run the model @@ -144,7 +155,9 @@ async def main(): subparsers = parser.add_subparsers(dest="command") loop_parser = subparsers.add_parser("loop") - loop_parser.add_argument("-scf", "--source-code-file", help="Specify a source code file.") + loop_parser.add_argument( + "-scf", "--source-code-file", help="Specify a source code file." + ) args = parser.parse_args() diff --git a/cookbooks/Gradio/hf_model_parsers.py b/cookbooks/Gradio/hf_model_parsers.py index 1046b9e3d..d0d253328 100644 --- a/cookbooks/Gradio/hf_model_parsers.py +++ b/cookbooks/Gradio/hf_model_parsers.py @@ -1,22 +1,27 @@ from aiconfig_extension_hugging_face import ( HuggingFaceAutomaticSpeechRecognitionTransformer, HuggingFaceImage2TextTransformer, - HuggingFaceTextSummarizationTransformer, HuggingFaceText2ImageDiffusor, HuggingFaceText2SpeechTransformer, HuggingFaceTextGenerationTransformer, + HuggingFaceTextSummarizationTransformer, HuggingFaceTextTranslationTransformer, ) - -from aiconfig_extension_hugging_face.remote_inference_client.text_generation import HuggingFaceTextGenerationParser +from aiconfig_extension_hugging_face.remote_inference_client.text_generation import ( + HuggingFaceTextGenerationParser, +) from aiconfig import AIConfigRuntime def register_model_parsers() -> None: """Register model parsers for HuggingFace models.""" - automatic_speech_recognition = HuggingFaceAutomaticSpeechRecognitionTransformer() - AIConfigRuntime.register_model_parser(automatic_speech_recognition, automatic_speech_recognition.id()) + automatic_speech_recognition = ( + HuggingFaceAutomaticSpeechRecognitionTransformer() + ) + AIConfigRuntime.register_model_parser( + automatic_speech_recognition, automatic_speech_recognition.id() + ) image_to_text = HuggingFaceImage2TextTransformer() AIConfigRuntime.register_model_parser(image_to_text, image_to_text.id()) @@ -28,12 +33,20 @@ def register_model_parsers() -> None: AIConfigRuntime.register_model_parser(text_to_speech, text_to_speech.id()) text_generation = HuggingFaceTextGenerationTransformer() - AIConfigRuntime.register_model_parser(text_generation, text_generation.id()) + AIConfigRuntime.register_model_parser( + text_generation, text_generation.id() + ) text_summarization = HuggingFaceTextSummarizationTransformer() - AIConfigRuntime.register_model_parser(text_summarization, text_summarization.id()) + AIConfigRuntime.register_model_parser( + text_summarization, text_summarization.id() + ) text_translation = HuggingFaceTextTranslationTransformer() - AIConfigRuntime.register_model_parser(text_translation, text_translation.id()) + AIConfigRuntime.register_model_parser( + text_translation, text_translation.id() + ) # Register remote inference client for text generation text_generation_remote = HuggingFaceTextGenerationParser() - AIConfigRuntime.register_model_parser(text_generation_remote, text_generation_remote.id()) + AIConfigRuntime.register_model_parser( + text_generation_remote, text_generation_remote.id() + ) diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py index 20405a6e6..c2ac3c9ef 100644 --- a/cookbooks/HuggingFace/hf.py +++ b/cookbooks/HuggingFace/hf.py @@ -3,7 +3,10 @@ # HuggingFace API imports from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse +from huggingface_hub.inference._text_generation import ( + TextGenerationResponse, + TextGenerationStreamResponse, +) # ModelParser Utils # Type hint imports @@ -104,7 +107,9 @@ def construct_stream_output( return output -def construct_regular_output(response, response_includes_details: bool) -> Output: +def construct_regular_output( + response, response_includes_details: bool +) -> Output: metadata = {} data = response if response_includes_details: @@ -155,7 +160,9 @@ def __init__(self, model_id: str = None, use_api_token=True): if use_api_token: # You are allowed to use Hugging Face for a bit before you get # rate limited, in which case you will receive a clear error - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN", required=False).unwrap() + token = get_api_key_from_environment( + "HUGGING_FACE_API_TOKEN", required=False + ).unwrap() self.client = InferenceClient(model_id, token=token) @@ -165,7 +172,14 @@ def id(self) -> str: """ return "HuggingFaceTextParser" - def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]: + def serialize( + self, + prompt_name: str, + data: Any, + ai_config: "AIConfigRuntime", + parameters: Optional[Dict] = None, + **kwargs + ) -> List[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -188,7 +202,9 @@ def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", p prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] @@ -209,7 +225,9 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ - resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) + resolved_prompt = resolve_prompt( + prompt, params if params is not None else {}, aiconfig + ) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) @@ -220,7 +238,9 @@ async def deserialize( return completion_data - async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]: + async def run_inference( + self, prompt: Prompt, aiconfig, options, parameters + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -232,10 +252,15 @@ async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. - stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + stream = (options.stream if options else False) and ( + not "stream" in completion_data + or completion_data.get("stream") != False + ) response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) @@ -248,7 +273,9 @@ async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> outputs.append(output) else: # Handles stream callback - output = construct_stream_output(response, response_is_detailed, options) + output = construct_stream_output( + response, response_is_detailed, options + ) outputs.append(output) prompt.outputs = outputs diff --git a/cookbooks/HuggingFace/python/hf.py b/cookbooks/HuggingFace/python/hf.py index 3958e5c4e..5a865954a 100644 --- a/cookbooks/HuggingFace/python/hf.py +++ b/cookbooks/HuggingFace/python/hf.py @@ -3,7 +3,10 @@ # HuggingFace API imports from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse +from huggingface_hub.inference._text_generation import ( + TextGenerationResponse, + TextGenerationStreamResponse, +) # ModelParser Utils # Type hint imports @@ -104,7 +107,9 @@ def construct_stream_output( return output -def construct_regular_output(response, response_includes_details: bool) -> Output: +def construct_regular_output( + response, response_includes_details: bool +) -> Output: metadata = {} data = response if response_includes_details: @@ -155,7 +160,9 @@ def __init__(self, model_id: str = None, use_api_token=True): if use_api_token: # You are allowed to use Hugging Face for a bit before you get # rate limited, in which case you will receive a clear error - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN", required=False).unwrap() + token = get_api_key_from_environment( + "HUGGING_FACE_API_TOKEN", required=False + ).unwrap() self.client = InferenceClient(model_id, token=token) @@ -165,7 +172,14 @@ def id(self) -> str: """ return "HuggingFaceTextParser" - def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]: + def serialize( + self, + prompt_name: str, + data: Any, + ai_config: "AIConfigRuntime", + parameters: Optional[Dict] = None, + **kwargs + ) -> List[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -188,7 +202,9 @@ def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", p prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] @@ -220,7 +236,9 @@ async def deserialize( return completion_data - async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]: + async def run_inference( + self, prompt: Prompt, aiconfig, options, parameters + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -232,10 +250,15 @@ async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. - stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + stream = (options.stream if options else False) and ( + not "stream" in completion_data + or completion_data.get("stream") != False + ) response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) @@ -248,7 +271,9 @@ async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> outputs.append(output) else: # Handles stream callback - output = construct_stream_output(response, response_is_detailed, options) + output = construct_stream_output( + response, response_is_detailed, options + ) outputs.append(output) prompt.outputs = outputs diff --git a/cookbooks/Wizard-GPT/wizard-gpt.py b/cookbooks/Wizard-GPT/wizard-gpt.py index 976de85ef..224ac175a 100644 --- a/cookbooks/Wizard-GPT/wizard-gpt.py +++ b/cookbooks/Wizard-GPT/wizard-gpt.py @@ -20,7 +20,9 @@ async def main(): break # Dynamically generate the prompt name and prompt object - new_prompt_name = f"prompt{len(config.prompts)+1}" # Prompt{number of prompts} + new_prompt_name = ( + f"prompt{len(config.prompts)+1}" # Prompt{number of prompts} + ) new_prompt = Prompt(name=new_prompt_name, input=user_input) # Add the new prompt and run the model diff --git a/cookbooks/llama/python/ask_llama.py b/cookbooks/llama/python/ask_llama.py index f10d56902..3327c3116 100644 --- a/cookbooks/llama/python/ask_llama.py +++ b/cookbooks/llama/python/ask_llama.py @@ -17,7 +17,9 @@ async def main(): ) # 4. Register the model parser with the model name (see file path). - AIConfigRuntime.register_model_parser(llama_model_parser, "llama-2-7b-chat") + AIConfigRuntime.register_model_parser( + llama_model_parser, "llama-2-7b-chat" + ) # 5. Use the AIConfigRuntime API to load and run your prompt(s). config = AIConfigRuntime.load(aiconfig_path) @@ -41,7 +43,9 @@ def stream_callback(data, accumulated_message, index): ) # 4. Register the model parser with the model name (see file path). - AIConfigRuntime.register_model_parser(llama_model_parser_13b, "llama-2-13b-chat") + AIConfigRuntime.register_model_parser( + llama_model_parser_13b, "llama-2-13b-chat" + ) print("\n\nRunning prompt13b...") await config.run("prompt13b", params={}, options=inference_options) @@ -53,7 +57,9 @@ def stream_callback(data, accumulated_message, index): llama_model_parser_code, "codeup-llama-2-13b-chat-hf" ) print("\n\nRunning prompt13b_code...") - code_res = await config.run("prompt13b_code", params={}, options=inference_options) + code_res = await config.run( + "prompt13b_code", params={}, options=inference_options + ) print(f"\n\n\n\nCode response:\n{code_res}") diff --git a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py index 922f9d0dd..ddcc51352 100644 --- a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py +++ b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py @@ -1,8 +1,13 @@ # Define a Model Parser for LLama-Guard -from typing import TYPE_CHECKING, Dict, List, Optional, Any import copy +from typing import TYPE_CHECKING, Any, Dict, List, Optional import google.generativeai as genai +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt, resolve_prompt_string from google.protobuf.json_format import MessageToDict from aiconfig import ( @@ -10,8 +15,6 @@ CallbackEvent, get_api_key_from_environment, ) -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( ExecuteResult, Output, @@ -19,7 +22,6 @@ Prompt, PromptInput, ) -from aiconfig.util.params import resolve_prompt, resolve_prompt_string # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -48,7 +50,9 @@ """ -def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[Output]: +def construct_regular_outputs( + response: "AsyncGenerateContentResponse", +) -> list[Output]: """ Construct regular output per response result, without streaming enabled """ @@ -69,7 +73,9 @@ def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[ return output_list -async def construct_stream_outputs(response: "AsyncGenerateContentResponse", options: InferenceOptions) -> list[Output]: +async def construct_stream_outputs( + response: "AsyncGenerateContentResponse", options: InferenceOptions +) -> list[Output]: """ Construct Outputs while also streaming the response with stream callback @@ -111,7 +117,9 @@ def __init__(self, id: str = "gemini-pro"): # as an env var genai.configure() will pick up the env var # `GOOGLE_API_KEY`, it's just that we prefer not to call # `get_api_key_from_environment` multiple times if we don't need to - self.api_key = get_api_key_from_environment("GOOGLE_API_KEY", required=False).unwrap() + self.api_key = get_api_key_from_environment( + "GOOGLE_API_KEY", required=False + ).unwrap() def id(self) -> str: """ @@ -196,23 +204,42 @@ async def serialize( prompts = [] contents_is_str = isinstance(contents, str) - contents_is_list_of_strings = all(isinstance(item, str) for item in contents) if isinstance(contents, list) else False + contents_is_list_of_strings = ( + all(isinstance(item, str) for item in contents) + if isinstance(contents, list) + else False + ) # Role Dict looks like this: # {'role':'user', # 'parts': ["Briefly explain how a computer works to a young child."] # } - contents_is_role_dict = isinstance(contents, dict) and "role" in contents and "parts" + contents_is_role_dict = ( + isinstance(contents, dict) and "role" in contents and "parts" + ) # Multi Turn means that the contents is a list of dicts with alternating role and parts. See for more info: https://ai.google.dev/tutorials/python_quickstart#multi-turn_conversations contents_is_multi_turn = isinstance(contents, list) and all( - isinstance(item, dict) and "role" in item and "parts" in item for item in contents + isinstance(item, dict) and "role" in item and "parts" in item + for item in contents ) if contents is None: - raise ValueError("No contents found in data. Gemini api request requires a contents field") - if contents_is_str or contents_is_list_of_strings or contents_is_role_dict: + raise ValueError( + "No contents found in data. Gemini api request requires a contents field" + ) + if ( + contents_is_str + or contents_is_list_of_strings + or contents_is_role_dict + ): # Just one string. Assume it's a single-turn prompt - prompt = Prompt(**{"name": prompt_name, "input": {"contents": contents}, "metadata": {"model": model_metadata}}) + prompt = Prompt( + **{ + "name": prompt_name, + "input": {"contents": contents}, + "metadata": {"model": model_metadata}, + } + ) prompts.append(prompt) elif contents_is_multi_turn: # Assume it's a multi-turn prompt. Each item in the list is a dict with role and parts @@ -248,14 +275,23 @@ async def serialize( prompts.append(prompt) i += 1 else: - raise ValueError("Unable to parse Data into prompts. Contents data is either invalid or contains unsupported objects like protobufs.") + raise ValueError( + "Unable to parse Data into prompts. Contents data is either invalid or contains unsupported objects like protobufs." + ) - event = CallbackEvent("on_serialize_complete", __name__, {"result": prompts}) + event = CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) await ai_config.callback_manager.run_callbacks(event) return prompts - async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = None) -> Dict: + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + params: Optional[Dict] = None, + ) -> Dict: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -266,7 +302,13 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) @@ -278,7 +320,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: resolved_prompt = resolve_prompt(prompt, params, aiconfig) - messages.append({"role": "user", "parts": [{"text": resolved_prompt}]}) + messages.append( + {"role": "user", "parts": [{"text": resolved_prompt}]} + ) completion_data["contents"] = messages else: @@ -298,9 +342,19 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: "Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API" ) - completion_data["contents"] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params) + completion_data[ + "contents" + ] = parameterize_supported_gemini_input_data( + prompt_input.contents, prompt, aiconfig, params + ) - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data async def run_inference( @@ -325,12 +379,18 @@ async def run_inference( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) if not self.api_key: - self.api_key = get_api_key_from_environment("GOOGLE_API_KEY", required=True).unwrap() + self.api_key = get_api_key_from_environment( + "GOOGLE_API_KEY", required=True + ).unwrap() genai.configure(api_key=self.api_key) # TODO: check and handle api key here @@ -352,7 +412,11 @@ async def run_inference( outputs = construct_regular_outputs(response) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return prompt.outputs def get_output_text( @@ -385,14 +449,19 @@ def get_output_text( """ raise ValueError(error_message) - def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict) -> List: + def _construct_chat_history( + self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict + ) -> List: """ Constructs the chat history for the model """ messages = [] # Default to always use chat context - remember_chat_context = not hasattr(prompt.metadata, "remember_chat_context") or ( - hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False + remember_chat_context = not hasattr( + prompt.metadata, "remember_chat_context" + ) or ( + hasattr(prompt.metadata, "remember_chat_context") + and prompt.metadata.remember_chat_context != False ) if remember_chat_context: # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output @@ -401,13 +470,26 @@ def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", p if previous_prompt.name == prompt.name: break - previous_prompt_is_same_model = aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt) + previous_prompt_is_same_model = aiconfig.get_model_name( + previous_prompt + ) == aiconfig.get_model_name(prompt) if previous_prompt_is_same_model: - previous_prompt_template = resolve_prompt(previous_prompt, params, aiconfig) - previous_prompt_output = aiconfig.get_latest_output(previous_prompt) - previous_prompt_output_text = self.get_output_text(previous_prompt, aiconfig, previous_prompt_output) + previous_prompt_template = resolve_prompt( + previous_prompt, params, aiconfig + ) + previous_prompt_output = aiconfig.get_latest_output( + previous_prompt + ) + previous_prompt_output_text = self.get_output_text( + previous_prompt, aiconfig, previous_prompt_output + ) - messages.append({"role": "user", "parts": [{"text": previous_prompt_template}]}) + messages.append( + { + "role": "user", + "parts": [{"text": previous_prompt_template}], + } + ) messages.append( { "role": "model", @@ -417,7 +499,9 @@ def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", p return messages - def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str: + def get_prompt_template( + self, prompt: Prompt, aiConfig: "AIConfigRuntime" + ) -> str: """ This method is overriden from the ParameterizedModelParser class. Its intended to be used only when collecting prompt references, nothing else. @@ -440,12 +524,18 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st elif isinstance(parts, list): return " ".join(parts) else: - raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}") + raise Exception( + f"Cannot get prompt template string from prompt input: {prompt.input}" + ) else: - raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}") + raise Exception( + f"Cannot get prompt template string from prompt input: {prompt.input}" + ) else: - raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}") + raise Exception( + f"Cannot get prompt template string from prompt input: {prompt.input}" + ) def refine_chat_completion_params(model_settings): @@ -465,7 +555,12 @@ def refine_chat_completion_params(model_settings): return completion_data -def parameterize_supported_gemini_input_data(part: Any, prompt: Prompt, aiconfig: "AIConfigRuntime", input_params: dict[str, Any]): +def parameterize_supported_gemini_input_data( + part: Any, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + input_params: dict[str, Any], +): """ Parameterizes the input for the Gemini API based on the type of the input part. This function specifically handles string-based types in the context of Gemini API. @@ -486,24 +581,39 @@ def parameterize_supported_gemini_input_data(part: Any, prompt: Prompt, aiconfig return resolve_prompt_string(prompt, input_params, aiconfig, part) elif isinstance(part, list): # This is expecting a list of strings. If its anything else, this will probably fail. - return [parameterize_supported_gemini_input_data(item, prompt, aiconfig, input_params) for item in part] + return [ + parameterize_supported_gemini_input_data( + item, prompt, aiconfig, input_params + ) + for item in part + ] elif isinstance(part, dict): # Expect "parts" key to be present in role dict if "parts" in part: part = copy.deepcopy(part) - part["parts"] = parameterize_supported_gemini_input_data(part["parts"], prompt, aiconfig, input_params) + part["parts"] = parameterize_supported_gemini_input_data( + part["parts"], prompt, aiconfig, input_params + ) return part else: - raise ValueError(f"Input Dictionary to Gemini Model Parser must contain a 'parts' key. Input provided: {part}") + raise ValueError( + f"Input Dictionary to Gemini Model Parser must contain a 'parts' key. Input provided: {part}" + ) else: - raise ValueError(f"Unable to parameterize part. Unsupported type: {type(part)} with value: {part}") + raise ValueError( + f"Unable to parameterize part. Unsupported type: {type(part)} with value: {part}" + ) def contains_prompt_template(prompt: Prompt): """ Check if a prompt's input is a valid string. """ - return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)) + return isinstance(prompt.input, str) or ( + hasattr(prompt.input, "data") and isinstance(prompt.input.data, str) + ) -AIConfigRuntime.register_model_parser(GeminiModelParser("gemini-pro"), "gemini-pro") +AIConfigRuntime.register_model_parser( + GeminiModelParser("gemini-pro"), "gemini-pro" +) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py index a6e28cab8..45d401dd0 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py @@ -1,12 +1,22 @@ -from .local_inference.automatic_speech_recognition import HuggingFaceAutomaticSpeechRecognitionTransformer +from .local_inference.automatic_speech_recognition import ( + HuggingFaceAutomaticSpeechRecognitionTransformer, +) from .local_inference.image_2_text import HuggingFaceImage2TextTransformer from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer -from .local_inference.text_generation import HuggingFaceTextGenerationTransformer -from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer -from .local_inference.text_translation import HuggingFaceTextTranslationTransformer -from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser +from .local_inference.text_generation import ( + HuggingFaceTextGenerationTransformer, +) +from .local_inference.text_summarization import ( + HuggingFaceTextSummarizationTransformer, +) +from .local_inference.text_translation import ( + HuggingFaceTextTranslationTransformer, +) from .local_inference.util import get_hf_model +from .remote_inference_client.text_generation import ( + HuggingFaceTextGenerationParser, +) UTILS = [get_hf_model] diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index 1c5c2f3d5..c0cddcaf7 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -1,11 +1,19 @@ -from typing import Any, Dict, Optional, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch -from transformers import pipeline, Pipeline -from aiconfig_extension_hugging_face.local_inference.util import get_hf_model -from aiconfig import ModelParser, InferenceOptions from aiconfig.callback import CallbackEvent -from aiconfig.schema import AttachmentDataWithStringValue, Prompt, Output, ExecuteResult, Attachment, PromptInput +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model +from transformers import Pipeline, pipeline + +from aiconfig import InferenceOptions, ModelParser +from aiconfig.schema import ( + Attachment, + AttachmentDataWithStringValue, + ExecuteResult, + Output, + Prompt, + PromptInput, +) if TYPE_CHECKING: from aiconfig import AIConfigRuntime @@ -57,7 +65,9 @@ async def serialize( str: Serialized representation of the prompt and inference settings. """ # TODO: See https://github.com/lastmile-ai/aiconfig/issues/822 - raise NotImplementedError("serialize is not implemented for HuggingFaceAutomaticSpeechRecognitionTransformer") + raise NotImplementedError( + "serialize is not implemented for HuggingFaceAutomaticSpeechRecognitionTransformer" + ) async def deserialize( self, @@ -65,12 +75,23 @@ async def deserialize( aiconfig: "AIConfigRuntime", params: Optional[Dict[str, Any]] = {}, ) -> Dict[str, Any]: - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - [_pipeline_creation_params, unfiltered_completion_params] = refine_pipeline_creation_params(model_settings) - completion_data = refine_asr_completion_params(unfiltered_completion_params) + [ + _pipeline_creation_params, + unfiltered_completion_params, + ] = refine_pipeline_creation_params(model_settings) + completion_data = refine_asr_completion_params( + unfiltered_completion_params + ) # ASR Pipeline supports input types of bytes, file path, and a dict containing raw sampled audio. Also supports multiple input # For now, support multiple or single uri's as input @@ -82,20 +103,39 @@ async def deserialize( completion_data["inputs"] = inputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data - async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]: + async def run( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + **kwargs, + ) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) model_settings = self.get_model_settings(prompt, aiconfig) - [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings) + [pipeline_creation_data, _] = refine_pipeline_creation_params( + model_settings + ) model_name = get_hf_model(aiconfig, prompt, self) key = model_name if model_name is not None else "__default__" @@ -103,7 +143,11 @@ async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: Infere device = self._get_device() if pipeline_creation_data.get("device", None) is None: pipeline_creation_data["device"] = device - self.pipelines[key] = pipeline(task="automatic-speech-recognition", model=model_name, **pipeline_creation_data) + self.pipelines[key] = pipeline( + task="automatic-speech-recognition", + model=model_name, + **pipeline_creation_data, + ) asr_pipeline = self.pipelines[key] completion_data = await self.deserialize(prompt, aiconfig, parameters) @@ -114,7 +158,11 @@ async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: Infere outputs = construct_outputs(response) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return prompt.outputs def _get_device(self) -> str: @@ -146,10 +194,14 @@ def get_output_text( def validate_attachment_type_is_audio(attachment: Attachment): if not hasattr(attachment, "mime_type"): - raise ValueError(f"Attachment has no mime type. Specify the audio mimetype in the aiconfig") + raise ValueError( + f"Attachment has no mime type. Specify the audio mimetype in the aiconfig" + ) if not attachment.mime_type.startswith("audio/"): - raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype.") + raise ValueError( + f"Invalid attachment mimetype {attachment.mime_type}. Expected audio mimetype." + ) def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]: @@ -164,10 +216,14 @@ def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]: """ if not isinstance(prompt.input, PromptInput): - raise ValueError(f"Prompt input is of type {type(prompt.input) }. Please specify a PromptInput with attachments for prompt {prompt.name}.") - + raise ValueError( + f"Prompt input is of type {type(prompt.input) }. Please specify a PromptInput with attachments for prompt {prompt.name}." + ) + if prompt.input.attachments is None or len(prompt.input.attachments) == 0: - raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input.") + raise ValueError( + f"No attachments found in input for prompt {prompt.name}. Please add an audio attachment to the prompt input." + ) audio_inputs: list[str] = [] @@ -175,15 +231,19 @@ def validate_and_retrieve_audio_from_attachments(prompt: Prompt) -> list[str]: validate_attachment_type_is_audio(attachment) if not isinstance(attachment.data, AttachmentDataWithStringValue): - raise ValueError(f"""Attachment data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field. - Please specify a uri for the audio attachment in prompt {prompt.name}.""") + raise ValueError( + f"""Attachment data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field. + Please specify a uri for the audio attachment in prompt {prompt.name}.""" + ) audio_inputs.append(attachment.data.value) return audio_inputs -def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]: +def refine_pipeline_creation_params( + model_settings: Dict[str, Any] +) -> List[Dict[str, Any]]: """ Refines the pipeline creation params for the HF text2Image generation api. Defers unsupported params as completion params, where they can get processed in @@ -214,7 +274,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict if key.lower() in supported_keys: pipeline_creation_params[key.lower()] = model_settings[key] else: - if key.lower() == "kwargs" and isinstance(model_settings[key], Dict): + if key.lower() == "kwargs" and isinstance( + model_settings[key], Dict + ): completion_params.update(model_settings[key]) else: completion_params[key.lower()] = model_settings[key] @@ -222,7 +284,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict return [pipeline_creation_params, completion_params] -def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]: +def refine_asr_completion_params( + unfiltered_completion_params: Dict[str, Any] +) -> Dict[str, Any]: """ Refines the ASR params for the HF asr generation api after a pipeline has been created via `refine_pipeline_creation_params`. Removes any @@ -272,13 +336,19 @@ def construct_outputs(response: list[Any]) -> list[Output]: for i, result in enumerate(response): # response is expected to be a dict containing the text output and timestamps if specified. Could not find docs for this. result: dict[str, Any] - text_output = result.get("text") if "text" in result and isinstance(result, dict) else result + text_output = ( + result.get("text") + if "text" in result and isinstance(result, dict) + else result + ) output = ExecuteResult( **{ "output_type": "execute_result", "data": text_output, "execution_count": i, - "metadata": {"result": result} if result.get("chunks", False) else {}, # may contain timestamps and chunks, for now pass result + "metadata": {"result": result} + if result.get("chunks", False) + else {}, # may contain timestamps and chunks, for now pass result } ) outputs.append(output) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py index e5579e5f8..4aedf0500 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py @@ -1,24 +1,21 @@ import base64 import json from io import BytesIO -from typing import Any, Dict, Optional, List, TYPE_CHECKING, Union -from PIL import Image as img_module -from PIL.Image import Image as ImageType -from transformers import ( - Pipeline, - pipeline, -) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from aiconfig.callback import CallbackEvent from aiconfig_extension_hugging_face.local_inference.util import get_hf_model +from PIL import Image as img_module +from PIL.Image import Image as ImageType +from transformers import Pipeline, pipeline -from aiconfig import ModelParser, InferenceOptions -from aiconfig.callback import CallbackEvent +from aiconfig import InferenceOptions, ModelParser from aiconfig.schema import ( Attachment, + AttachmentDataWithStringValue, ExecuteResult, Output, Prompt, - AttachmentDataWithStringValue ) # Circular Dependency Type Hints @@ -82,9 +79,13 @@ async def serialize( prompts = [] if not isinstance(data, dict): - raise ValueError("Invalid data type. Expected dict when serializing prompt data to aiconfig.") + raise ValueError( + "Invalid data type. Expected dict when serializing prompt data to aiconfig." + ) if data.get("inputs", None) is None: - raise ValueError("Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field.") + raise ValueError( + "Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field." + ) prompt = Prompt( **{ @@ -97,7 +98,11 @@ async def serialize( prompts.append(prompt) - await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts})) + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) + ) return prompts async def deserialize( @@ -106,7 +111,13 @@ async def deserialize( aiconfig: "AIConfigRuntime", params: Optional[Dict[str, Any]] = {}, ) -> Dict[str, Any]: - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) @@ -116,15 +127,32 @@ async def deserialize( inputs = validate_and_retrieve_images_from_attachments(prompt) completion_params["inputs"] = inputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_params}, + ) + ) return completion_params - async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any], **kwargs) -> list[Output]: + async def run( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + **kwargs, + ) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -135,7 +163,9 @@ async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: Infere key = model_name if model_name is not None else "__default__" if key not in self.pipelines: - self.pipelines[key] = pipeline(task="image-to-text", model=model_name) + self.pipelines[key] = pipeline( + task="image-to-text", model=model_name + ) captioner = self.pipelines[key] outputs: List[Output] = [] @@ -198,7 +228,9 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: # Helper methods -def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: +def construct_regular_output( + result: Dict[str, str], execution_count: int +) -> Output: """ Construct regular output per response result, without streaming enabled """ @@ -225,13 +257,19 @@ def validate_attachment_type_is_image( image format. Raises ValueError if there's an issue. """ if not hasattr(attachment, "mime_type"): - raise ValueError(f"Attachment has no mime type for prompt '{prompt_name}'. Please specify the image mimetype in the AIConfig") + raise ValueError( + f"Attachment has no mime type for prompt '{prompt_name}'. Please specify the image mimetype in the AIConfig" + ) if not attachment.mime_type.startswith("image/"): - raise ValueError(f"Invalid attachment mimetype {attachment.mime_type} for prompt '{prompt_name}'. Please use a mimetype that starts with 'image/'.") + raise ValueError( + f"Invalid attachment mimetype {attachment.mime_type} for prompt '{prompt_name}'. Please use a mimetype that starts with 'image/'." + ) -def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[str, ImageType]]: +def validate_and_retrieve_images_from_attachments( + prompt: Prompt, +) -> list[Union[str, ImageType]]: """ Retrieves the image uri's from each attachment in the prompt input. @@ -242,24 +280,32 @@ def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[ - operation fails for any reason """ - if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0: - raise ValueError(f"No attachments found in input for prompt '{prompt.name}'. Please add an image attachment to the prompt input.") + if ( + not hasattr(prompt.input, "attachments") + or len(prompt.input.attachments) == 0 + ): + raise ValueError( + f"No attachments found in input for prompt '{prompt.name}'. Please add an image attachment to the prompt input." + ) images: list[Union[str, ImageType]] = [] for i, attachment in enumerate(prompt.input.attachments): validate_attachment_type_is_image(prompt.name, attachment) - if not isinstance(attachment.data, AttachmentDataWithStringValue): # See todo above, but for now only support uris and base64 - raise ValueError(f"""Attachment #{i} data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field. - Please specify a uri or base64 encoded string for the image attachment in prompt '{prompt.name}'.""") + raise ValueError( + f"""Attachment #{i} data must be of type `AttachmentDataWithStringValue` with a `kind` and `value` field. + Please specify a uri or base64 encoded string for the image attachment in prompt '{prompt.name}'.""" + ) input_data = attachment.data.value if attachment.data.kind == "base64": - pil_image: ImageType = img_module.open(BytesIO(base64.b64decode(input_data))) + pil_image: ImageType = img_module.open( + BytesIO(base64.b64decode(input_data)) + ) images.append(pil_image) else: - images.append(input_data) # expect a uri + images.append(input_data) # expect a uri return images diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py index 360ef9dc5..0ef873589 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py @@ -3,18 +3,23 @@ import io import itertools import json -import torch from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from diffusers import AutoPipelineForText2Image from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import ( + StableDiffusionXLPipelineOutput, +) from PIL import Image from transformers import Pipeline -from aiconfig_extension_hugging_face.local_inference.util import get_hf_model - -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( ExecuteResult, Output, @@ -22,7 +27,6 @@ Prompt, PromptMetadata, ) -from aiconfig.util.params import resolve_prompt # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -30,7 +34,9 @@ # Step 1: define Helpers -def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]: +def refine_pipeline_creation_params( + model_settings: Dict[str, Any] +) -> List[Dict[str, Any]]: """ Refines the pipeline creation params for the HF text2Image generation api. Defers unsupported params as completion params, where they can get processed in @@ -70,7 +76,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict if key.lower() in supported_keys: pipeline_creation_params[key.lower()] = model_settings[key] else: - if key.lower() == "kwargs" and isinstance(model_settings[key], Dict): + if key.lower() == "kwargs" and isinstance( + model_settings[key], Dict + ): completion_params.update(model_settings[key]) else: completion_params[key.lower()] = model_settings[key] @@ -78,7 +86,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict return [pipeline_creation_params, completion_params] -def refine_image_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]: +def refine_image_completion_params( + unfiltered_completion_params: Dict[str, Any] +) -> Dict[str, Any]: """ Refines the image creation params for the HF text2Image generation api after a pipeline has been created via `refine_pipeline_creation_params`. Removes any @@ -164,7 +174,9 @@ def pillow_image_to_base64_string(img: Image.Image): "output_type": "execute_result", "data": data, "execution_count": execution_count, - "metadata": {"nsfw_content_detected": image_data.nsfw_content_detected}, + "metadata": { + "nsfw_content_detected": image_data.nsfw_content_detected + }, "mime_type": "image/png", } ) @@ -233,7 +245,11 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **completion_params), + metadata=PromptMetadata( + model=model_metadata, + parameters=parameters, + **completion_params, + ), ) return [prompt] @@ -256,15 +272,26 @@ async def deserialize( """ # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - [_pipeline_creation_params, unfiltered_completion_params] = refine_pipeline_creation_params(model_settings) - completion_data = refine_image_completion_params(unfiltered_completion_params) + [ + _pipeline_creation_params, + unfiltered_completion_params, + ] = refine_pipeline_creation_params(model_settings) + completion_data = refine_image_completion_params( + unfiltered_completion_params + ) # Add resolved prompt resolved_prompt = resolve_prompt(prompt, params, aiconfig) completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -277,7 +304,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio InferenceResponse: The response from the model. """ model_settings = self.get_model_settings(prompt, aiconfig) - [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings) + [pipeline_creation_data, _] = refine_pipeline_creation_params( + model_settings + ) if not pipeline_creation_data.get("requires_safety_checker", True): pipeline_creation_data["safety_checker"] = None @@ -300,7 +329,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio # https://huggingface.co/docs/diffusers/using-diffusers/loading#checkpoint-variants if key not in self.generators: device = self._get_device() - self.generators[key] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(device) + self.generators[key] = AutoPipelineForText2Image.from_pretrained( + pretrained_model_or_path=model_name, **pipeline_creation_data + ).to(device) generator = self.generators[key] disclaimer_long_response_print_message = """\n @@ -316,8 +347,12 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio """ print(disclaimer_long_response_print_message) - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) - response: Union[StableDiffusionPipelineOutput, StableDiffusionXLPipelineOutput] = generator(**completion_data) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) + response: Union[ + StableDiffusionPipelineOutput, StableDiffusionXLPipelineOutput + ] = generator(**completion_data) nsfw_content_detected = [] if hasattr(response, "nsfw_content_detected"): # StableDiffusionPipelineOutput has "nsfw_content_detected" field but @@ -328,7 +363,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio # TODO (rossdanlm): Check if "image" field is present for other image # diffusers other than StableDiffusion and StableDiffusionXL # https://github.com/lastmile-ai/aiconfig/issues/471 - refined_responses = _refine_responses(response.images or [], nsfw_content_detected) + refined_responses = _refine_responses( + response.images or [], nsfw_content_detected + ) for count, image_data in enumerate(refined_responses): # TODO (rossdanlm): It's possible for image to be of type np.ndarray # Update `construct_output` to process this type. @@ -399,5 +436,8 @@ def _refine_responses( # Use zip.longest because nsfw_content_detected can be empty itertools.zip_longest(response_images, nsfw_content_detected) ) - image_data_objects: List[ImageData] = [ImageData(image=image, nsfw_content_detected=has_nsfw) for (image, has_nsfw) in merged_responses] + image_data_objects: List[ImageData] = [ + ImageData(image=image, nsfw_content_detected=has_nsfw) + for (image, has_nsfw) in merged_responses + ] return image_data_objects diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py index a68dff207..db107b38d 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py @@ -2,15 +2,18 @@ import copy import io import json -import numpy as np from typing import TYPE_CHECKING, Any, Dict, List, Optional -from transformers import Pipeline, pipeline -from scipy.io.wavfile import write as write_wav +import numpy as np +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt from aiconfig_extension_hugging_face.local_inference.util import get_hf_model +from scipy.io.wavfile import write as write_wav +from transformers import Pipeline, pipeline -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( ExecuteResult, Output, @@ -18,7 +21,6 @@ Prompt, PromptMetadata, ) -from aiconfig.util.params import resolve_prompt # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -26,7 +28,9 @@ # Step 1: define Helpers -def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]: +def refine_pipeline_creation_params( + model_settings: Dict[str, Any] +) -> List[Dict[str, Any]]: # These are from the transformers Github repo: # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534 supported_keys = { @@ -56,7 +60,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict if key.lower() in supported_keys: pipeline_creation_params[key.lower()] = model_settings[key] else: - if key.lower() == "kwargs" and isinstance(model_settings[key], Dict): + if key.lower() == "kwargs" and isinstance( + model_settings[key], Dict + ): completion_params.update(model_settings[key]) else: completion_params[key.lower()] = model_settings[key] @@ -64,7 +70,9 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict return [pipeline_creation_params, completion_params] -def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]: +def refine_completion_params( + unfiltered_completion_params: Dict[str, Any] +) -> Dict[str, Any]: # Note: There seems to be no public API docs on what completion # params are supported for text to speech: # https://huggingface.co/docs/transformers/tasks/text-to-speech#inference @@ -83,7 +91,9 @@ def construct_output(audio, execution_count: int) -> Output: def _b64_encode_bytes(byte_array: bytes) -> str: return base64.b64encode(byte_array).decode("utf-8") - def _audio_ndarray_to_wav_bytes(audio: np.ndarray, sampling_rate: int) -> bytes: + def _audio_ndarray_to_wav_bytes( + audio: np.ndarray, sampling_rate: int + ) -> bytes: buffered = io.BytesIO() write_wav(buffered, sampling_rate, audio) @@ -91,11 +101,18 @@ def _audio_ndarray_to_wav_bytes(audio: np.ndarray, sampling_rate: int) -> bytes: byte_array = buffered.getvalue() return byte_array - def _audio_ndarray_to_b64_str(audio: np.ndarray, sampling_rate: int) -> str: + def _audio_ndarray_to_b64_str( + audio: np.ndarray, sampling_rate: int + ) -> str: byte_array = _audio_ndarray_to_wav_bytes(audio, sampling_rate) return _b64_encode_bytes(byte_array) - data = dict(kind="base64", value=_audio_ndarray_to_b64_str(np.squeeze(audio["audio"]), audio["sampling_rate"])) + data = dict( + kind="base64", + value=_audio_ndarray_to_b64_str( + np.squeeze(audio["audio"]), audio["sampling_rate"] + ), + ) output = ExecuteResult( **{ "output_type": "execute_result", @@ -151,7 +168,11 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **completion_params), + metadata=PromptMetadata( + model=model_metadata, + parameters=parameters, + **completion_params, + ), ) return [prompt] @@ -174,15 +195,26 @@ async def deserialize( """ # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - [_pipeline_creation_params, unfiltered_completion_params] = refine_pipeline_creation_params(model_settings) - completion_data = refine_completion_params(unfiltered_completion_params) + [ + _pipeline_creation_params, + unfiltered_completion_params, + ] = refine_pipeline_creation_params(model_settings) + completion_data = refine_completion_params( + unfiltered_completion_params + ) # Add resolved prompt resolved_prompt = resolve_prompt(prompt, params, aiconfig) completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -195,15 +227,21 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio InferenceResponse: The response from the model. """ model_settings = self.get_model_settings(prompt, aiconfig) - [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings) + [pipeline_creation_data, _] = refine_pipeline_creation_params( + model_settings + ) model_name = get_hf_model(aiconfig, prompt, self) key = model_name if model_name is not None else "__default__" if key not in self.synthesizers: - self.synthesizers[key] = pipeline("text-to-speech", model=model_name) + self.synthesizers[key] = pipeline( + "text-to-speech", model=model_name + ) synthesizer = self.synthesizers[key] - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) inputs = completion_data.pop("prompt", None) response = synthesizer(inputs, **completion_data) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py index 7ad79cf97..4090281d3 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py @@ -2,24 +2,21 @@ import json import threading from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from transformers import ( AutoTokenizer, Pipeline, - pipeline, TextIteratorStreamer, + pipeline, ) -from aiconfig_extension_hugging_face.local_inference.util import get_hf_model - -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - Output, - Prompt, - PromptMetadata, -) -from aiconfig.util.params import resolve_prompt +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -92,7 +89,9 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: return completion_data -def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: +def construct_regular_output( + result: Dict[str, str], execution_count: int +) -> Output: """ Construct regular output per response result, without streaming enabled """ @@ -193,7 +192,9 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] @@ -223,7 +224,13 @@ async def deserialize( completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -235,20 +242,29 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) completion_data["text_inputs"] = completion_data.pop("prompt", None) model_name = get_hf_model(aiconfig, prompt, self) key = model_name if model_name is not None else "__default__" if key not in self.generators: - self.generators[key] = pipeline("text-generation", model=model_name) + self.generators[key] = pipeline( + "text-generation", model=model_name + ) generator = self.generators[key] # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + should_stream = (options.stream if options else False) and ( + not "stream" in completion_data + or completion_data.get("stream") != False + ) if should_stream: - tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( + model_name + ) streamer = TextIteratorStreamer(tokenizer) completion_data["streamer"] = streamer @@ -261,9 +277,13 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio outputs.append(output) else: if completion_data.get("num_return_sequences", 1) > 1: - raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1") + raise ValueError( + "Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1" + ) if not streamer: - raise ValueError("Stream option is selected but streamer is not initialized") + raise ValueError( + "Stream option is selected but streamer is not initialized" + ) # For streaming, cannot call `generator` directly otherwise response will be blocking thread = threading.Thread(target=generator, kwargs=completion_data) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py index 2c0125a28..894e1f4b3 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py @@ -2,24 +2,21 @@ import json import threading from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from transformers import ( AutoTokenizer, Pipeline, - pipeline, TextIteratorStreamer, + pipeline, ) -from aiconfig_extension_hugging_face.local_inference.util import get_hf_model - -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - Output, - Prompt, - PromptMetadata, -) -from aiconfig.util.params import resolve_prompt +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -92,7 +89,9 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: return completion_data -def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: +def construct_regular_output( + result: Dict[str, str], execution_count: int +) -> Output: """ Construct regular output per response result, without streaming enabled """ @@ -199,7 +198,9 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] @@ -229,7 +230,13 @@ async def deserialize( completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -241,7 +248,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) inputs = completion_data.pop("prompt", None) model_name = get_hf_model(aiconfig, prompt, self) @@ -252,9 +261,14 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + should_stream = (options.stream if options else False) and ( + not "stream" in completion_data + or completion_data.get("stream") != False + ) if should_stream: - tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( + model_name + ) streamer = TextIteratorStreamer(tokenizer) completion_data["streamer"] = streamer @@ -269,9 +283,13 @@ def _summarize(): outputs.append(output) else: if completion_data.get("num_return_sequences", 1) > 1: - raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1") + raise ValueError( + "Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1" + ) if not streamer: - raise ValueError("Stream option is selected but streamer is not initialized") + raise ValueError( + "Stream option is selected but streamer is not initialized" + ) # For streaming, cannot call `summarizer` directly otherwise response will be blocking thread = threading.Thread(target=_summarize) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py index abdc5a625..f42d70e3f 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py @@ -2,24 +2,21 @@ import json import threading from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from transformers import ( AutoTokenizer, Pipeline, - pipeline, TextIteratorStreamer, + pipeline, ) -from aiconfig_extension_hugging_face.local_inference.util import get_hf_model - -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - Output, - Prompt, - PromptMetadata, -) -from aiconfig.util.params import resolve_prompt +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -92,7 +89,9 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: return completion_data -def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: +def construct_regular_output( + result: Dict[str, str], execution_count: int +) -> Output: """ Construct regular output per response result, without streaming enabled """ @@ -201,7 +200,9 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] @@ -231,7 +232,13 @@ async def deserialize( completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: Dict[str, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -243,7 +250,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + completion_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) inputs = completion_data.pop("prompt", None) model_name = get_hf_model(aiconfig, prompt, self) @@ -254,9 +263,14 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + should_stream = (options.stream if options else False) and ( + not "stream" in completion_data + or completion_data.get("stream") != False + ) if should_stream: - tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( + model_name + ) streamer = TextIteratorStreamer(tokenizer) completion_data["streamer"] = streamer @@ -272,9 +286,13 @@ def _translate(): outputs.append(output) else: if completion_data.get("num_return_sequences", 1) > 1: - raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1") + raise ValueError( + "Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1" + ) if not streamer: - raise ValueError("Stream option is selected but streamer is not initialized") + raise ValueError( + "Stream option is selected but streamer is not initialized" + ) # For streaming, cannot call `translator` directly otherwise response will be blocking thread = threading.Thread(target=_translate) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py index a7ff6d307..f69cd616f 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py @@ -1,20 +1,26 @@ from typing import TYPE_CHECKING + from aiconfig import ParameterizedModelParser from aiconfig.schema import Prompt - # Circular Dependency Type Hints if TYPE_CHECKING: from aiconfig import AIConfigRuntime -def get_hf_model(aiconfig: "AIConfigRuntime", prompt: Prompt, model_parser: ParameterizedModelParser) -> str | None: +def get_hf_model( + aiconfig: "AIConfigRuntime", + prompt: Prompt, + model_parser: ParameterizedModelParser, +) -> str | None: """ Returns the HuggingFace model to use for the given prompt and model parser. """ model_name: str | None = aiconfig.get_model_name(prompt) model_settings = model_parser.get_model_settings(prompt, aiconfig) - hf_model = model_settings.get("model") or None # Replace "" with None value + hf_model = ( + model_settings.get("model") or None + ) # Replace "" with None value if hf_model is not None and isinstance(hf_model, str): # If the model property is set in the model settings, use that. diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py index 02bc0e514..4dba9c042 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py @@ -2,6 +2,13 @@ import json from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.params import resolve_prompt + # HuggingFace API imports from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import ( @@ -10,17 +17,7 @@ ) from aiconfig import CallbackEvent -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - Output, - Prompt, - PromptMetadata, -) -from aiconfig.util.config_utils import get_api_key_from_environment -from aiconfig.util.params import resolve_prompt - +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -105,7 +102,9 @@ def construct_stream_output( return output -def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output: +def construct_regular_output( + response: TextGenerationResponse, response_includes_details: bool +) -> Output: metadata = {"raw_response": response} if response_includes_details: metadata["details"] = response.details @@ -153,7 +152,9 @@ def __init__(self, model_id: str = None, use_api_token=False): if use_api_token: # You are allowed to use Hugging Face for a bit before you get # rate limited, in which case you will receive a clear error - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN", required=False).unwrap() + token = get_api_key_from_environment( + "HUGGING_FACE_API_TOKEN", required=False + ).unwrap() self.client = InferenceClient(model_id, token=token) @@ -208,12 +209,18 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) prompts.append(prompt) - await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts})) + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) + ) return prompts @@ -233,7 +240,13 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) resolved_prompt = resolve_prompt(prompt, params, aiconfig) @@ -244,11 +257,23 @@ async def deserialize( completion_data["prompt"] = resolved_prompt - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: dict[Any, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -264,7 +289,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -290,12 +319,16 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio outputs.append(output) else: # Handles stream callback - output = construct_stream_output(response, response_is_detailed, options) + output = construct_stream_output( + response, response_is_detailed, options + ) outputs.append(output) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent("on_run_complete", __name__, {"result": outputs}) + ) return outputs diff --git a/extensions/LLama-Guard/python/src/aiconfig_extension_llama_guard/LLamaGuard.py b/extensions/LLama-Guard/python/src/aiconfig_extension_llama_guard/LLamaGuard.py index 34d8233ef..4131fd418 100644 --- a/extensions/LLama-Guard/python/src/aiconfig_extension_llama_guard/LLamaGuard.py +++ b/extensions/LLama-Guard/python/src/aiconfig_extension_llama_guard/LLamaGuard.py @@ -4,12 +4,15 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Optional -from transformers import AutoTokenizer, AutoModelForCausalLM import torch +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.params import resolve_prompt +from transformers import AutoModelForCausalLM, AutoTokenizer from aiconfig import CallbackEvent -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( ExecuteResult, Output, @@ -17,7 +20,6 @@ Prompt, PromptMetadata, ) -from aiconfig.util.params import resolve_prompt # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -25,7 +27,9 @@ # Step 1: define Helpers -def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: +def refine_chat_completion_params( + model_settings: Dict[str, Any] +) -> Dict[str, Any]: """ Refines the completion params for the HF text generation api. Removes any unsupported params. The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()` @@ -90,7 +94,9 @@ def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, A return completion_data -def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: +def construct_regular_output( + result: Dict[str, str], execution_count: int +) -> Output: """ Construct regular output per response result, without streaming enabled """ @@ -105,7 +111,6 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou return output - # This model parser doesn't support streaming. TODO: Implement streaming # This Model Parser doesn't support n-outputs. class LLamaGuardParser(ParameterizedModelParser): @@ -126,7 +131,9 @@ def __init__(self): """ super().__init__() model_id = "meta-llama/LlamaGuard-7b" - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) print("device: ", self.device) dtype = torch.bfloat16 self.tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -201,13 +208,20 @@ async def deserialize( resolved_prompt = resolve_prompt(prompt, params, aiconfig) # Tokenized Prompt - inputs = self.tokenizer([resolved_prompt], return_tensors="pt").to(self.device) + inputs = self.tokenizer([resolved_prompt], return_tensors="pt").to( + self.device + ) - deserialize_output = {"tokenized_input": inputs, "gen_params": completion_data} + deserialize_output = { + "tokenized_input": inputs, + "gen_params": completion_data, + } await aiconfig.callback_manager.run_callbacks( CallbackEvent( - "on_deserialize_complete", __name__, {"text_prompt": resolved_prompt, "output": deserialize_output} + "on_deserialize_complete", + __name__, + {"text_prompt": resolved_prompt, "output": deserialize_output}, ) ) @@ -232,7 +246,9 @@ async def run_inference( InferenceResponse: The response from the model. """ - resolved_data = await self.deserialize(prompt, aiconfig, options, parameters) + resolved_data = await self.deserialize( + prompt, aiconfig, options, parameters + ) # Move to GPU if applicable, self.device is set in __init__). Otherwise this is a no-op tokenized_input_cuda = resolved_data["tokenized_input"].to(self.device) @@ -245,11 +261,13 @@ async def run_inference( output_text = self.tokenizer.decode( response[0][prompt_len:], skip_special_tokens=True ) - output_data_content: str = '' + output_data_content: str = "" if isinstance(output_text, str): output_data_content = output_text else: - raise ValueError(f"Output {output_text} needs to be of type 'str' but is of type: {type(output_text)}") + raise ValueError( + f"Output {output_text} needs to be of type 'str' but is of type: {type(output_text)}" + ) output = ExecuteResult( **{ "output_type": "execute_result", diff --git a/extensions/llama/python/llama.py b/extensions/llama/python/llama.py index fc41d6ff7..908215e40 100644 --- a/extensions/llama/python/llama.py +++ b/extensions/llama/python/llama.py @@ -2,17 +2,15 @@ from typing import Any, List from aiconfig.Config import AIConfigRuntime -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - OutputDataWithValue, - Output, - Prompt, +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, ) +from aiconfig.model_parser import InferenceOptions from aiconfig.util.params import resolve_prompt from llama_cpp import Llama +from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt + class LlamaModelParser(ParameterizedModelParser): def __init__(self, model_path: str) -> None: @@ -33,7 +31,10 @@ async def serialize( return [out] async def deserialize( - self, prompt: Prompt, aiconfig: AIConfigRuntime, params: dict | None = None + self, + prompt: Prompt, + aiconfig: AIConfigRuntime, + params: dict | None = None, ) -> dict: resolved = resolve_prompt(prompt, params, aiconfig) @@ -74,7 +75,9 @@ async def run_inference( return [result] - async def _run_inference_helper(self, model_input, options) -> List[Output]: + async def _run_inference_helper( + self, model_input, options + ) -> List[Output]: llm = Llama(self.model_path) acc = "" stream = options.stream if options else True @@ -88,15 +91,17 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]: options.stream_callback(data, acc, index) print(flush=True) - output_data_value: str = '' + output_data_value: str = "" if isinstance(acc, str): output_data_value = acc else: - raise ValueError(f"Output {acc} needs to be of type 'str' but is of type: {type(acc)}") + raise ValueError( + f"Output {acc} needs to be of type 'str' but is of type: {type(acc)}" + ) return ExecuteResult( output_type="execute_result", data=output_data_value, - metadata={} + metadata={}, ) else: response = llm(model_input) @@ -108,14 +113,17 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]: return ExecuteResult( output_type="execute_result", # TODO: Map all text responses to multiple outputs - # This would be part of a large refactor: + # This would be part of a large refactor: # https://github.com/lastmile-ai/aiconfig/issues/630 data="\n".join(texts), - metadata={} + metadata={}, ) def get_output_text( - self, prompt: Prompt, aiconfig: AIConfigRuntime, output: Output | None = None + self, + prompt: Prompt, + aiconfig: AIConfigRuntime, + output: Output | None = None, ) -> str: if not output: output = aiconfig.get_latest_output(prompt) @@ -141,4 +149,6 @@ def get_output_text( # TODO: This is an error since list is not compatible with str type return output_data return "" - raise ValueError(f"Output is an unexpected output type: {type(output)}") + raise ValueError( + f"Output is an unexpected output type: {type(output)}" + ) diff --git a/python/demo/function_calling_demo.py b/python/demo/function_calling_demo.py index 22fb47e72..c417a9b19 100644 --- a/python/demo/function_calling_demo.py +++ b/python/demo/function_calling_demo.py @@ -33,7 +33,9 @@ async def function_calling(): while True: model_output = await config.run(promptToRun, params, inference_options) - output = model_output[0] if isinstance(model_output, list) else model_output + output = ( + model_output[0] if isinstance(model_output, list) else model_output + ) if output.output_type == "error": print(f"Error during inference: {output.ename}: {output.evalue}") diff --git a/python/src/aiconfig/ChatCompletion.py b/python/src/aiconfig/ChatCompletion.py index 8bcbd2d1c..c6a0be421 100644 --- a/python/src/aiconfig/ChatCompletion.py +++ b/python/src/aiconfig/ChatCompletion.py @@ -17,7 +17,9 @@ from aiconfig.schema import ExecuteResult, Output, Prompt -def validate_and_add_prompts_to_config(prompts: List[Prompt], aiconfig) -> None: +def validate_and_add_prompts_to_config( + prompts: List[Prompt], aiconfig +) -> None: """ Validates and adds new prompts to the AI configuration, ensuring no duplicates and updating outputs if necessary. @@ -29,7 +31,10 @@ def validate_and_add_prompts_to_config(prompts: List[Prompt], aiconfig) -> None: in_config = False for config_prompt in aiconfig.prompts: # check for duplicates (same input and settings.) - if config_prompt.input == new_prompt.input and new_prompt.metadata == config_prompt.metadata: + if ( + config_prompt.input == new_prompt.input + and new_prompt.metadata == config_prompt.metadata + ): in_config = True # update outputs if different if config_prompt.outputs != new_prompt.outputs: @@ -55,9 +60,15 @@ def extract_outputs_from_response(response) -> List[Output]: response = response.model_dump(exclude_none=True) - response_without_choices = {key: copy.deepcopy(value) for key, value in response.items() if key != "choices"} + response_without_choices = { + key: copy.deepcopy(value) + for key, value in response.items() + if key != "choices" + } for i, choice in enumerate(response.get("choices")): - response_without_choices.update({"finish_reason": choice.get("finish_reason")}) + response_without_choices.update( + {"finish_reason": choice.get("finish_reason")} + ) output = ExecuteResult( **{ "output_type": "execute_result", @@ -83,7 +94,9 @@ def async_run_serialize_helper( serialized_prompts = None async def run_and_await_serialize(): - result = await aiconfig.serialize(request_kwargs.get("model"), request_kwargs, "prompt") + result = await aiconfig.serialize( + request_kwargs.get("model"), request_kwargs, "prompt" + ) return result # serialize prompts from ChatCompletion kwargs @@ -118,21 +131,33 @@ def _get_aiconfig_runtime(output_aiconfig_path: str) -> AIConfigRuntime: except IOError: return AIConfigRuntime.create(**(aiconfig_settings or {})) - output_aiconfig = output_aiconfig_ref if isinstance(output_aiconfig_ref, AIConfigRuntime) else _get_aiconfig_runtime(output_aiconfig_ref) + output_aiconfig = ( + output_aiconfig_ref + if isinstance(output_aiconfig_ref, AIConfigRuntime) + else _get_aiconfig_runtime(output_aiconfig_ref) + ) - output_config_file_path = output_aiconfig_ref if isinstance(output_aiconfig_ref, str) else output_aiconfig_ref.file_path + output_config_file_path = ( + output_aiconfig_ref + if isinstance(output_aiconfig_ref, str) + else output_aiconfig_ref.file_path + ) # TODO: openai makes it hard to statically annotate. def _create_chat_completion_with_config_saving(*args, **kwargs) -> Any: # type: ignore response = openai_api.chat.completions.create(*args, **kwargs) - serialized_prompts = async_run_serialize_helper(output_aiconfig, kwargs) + serialized_prompts = async_run_serialize_helper( + output_aiconfig, kwargs + ) # serialize output from response outputs = [] # Check if response is a stream - stream = kwargs.get("stream", False) is True and isinstance(response, openai.Stream) + stream = kwargs.get("stream", False) is True and isinstance( + response, openai.Stream + ) # Convert Response to output for last prompt if not stream: @@ -141,7 +166,9 @@ def _create_chat_completion_with_config_saving(*args, **kwargs) -> Any: # type: # Add outputs to last prompt serialized_prompts[-1].outputs = outputs - validate_and_add_prompts_to_config(serialized_prompts, output_aiconfig) + validate_and_add_prompts_to_config( + serialized_prompts, output_aiconfig + ) # Save config to file output_aiconfig.save(output_config_file_path, include_outputs=True) @@ -159,28 +186,41 @@ def generate_streamed_response() -> Generator[Any, None, None]: chunk_dict = chunk.model_dump(exclude_none=True) # type: ignore [fixme] # streaming only returns one chunk, one choice at a time. The order in which the choices are returned is not guaranteed. - messages = multi_choice_message_reducer(messages, chunk_dict) + messages = multi_choice_message_reducer( + messages, chunk_dict + ) for choice in chunk_dict["choices"]: index = choice.get("index") - accumulated_message_for_choice = messages.get(index, {}) + accumulated_message_for_choice = messages.get( + index, {} + ) output = ExecuteResult( output_type="execute_result", data=copy.deepcopy(accumulated_message_for_choice), execution_count=index, - metadata={"finish_reason": choice.get("finish_reason")}, + metadata={ + "finish_reason": choice.get("finish_reason") + }, ) stream_outputs[index] = output yield chunk - stream_outputs = [stream_outputs[i] for i in sorted(list(stream_outputs.keys()))] + stream_outputs = [ + stream_outputs[i] + for i in sorted(list(stream_outputs.keys())) + ] # Add outputs to last prompt serialized_prompts[-1].outputs = stream_outputs - validate_and_add_prompts_to_config(serialized_prompts, output_aiconfig) + validate_and_add_prompts_to_config( + serialized_prompts, output_aiconfig + ) # Save config to file - output_aiconfig.save(output_config_file_path, include_outputs=True) + output_aiconfig.save( + output_config_file_path, include_outputs=True + ) return generate_streamed_response() @@ -215,6 +255,8 @@ def get_completion_create_wrapped_openai_client( openai_api=api, aiconfig_settings=aiconfig_settings, ) - client_mocked = core_utils.make_wrap_object(api, "chat.completions.create", wrapped) + client_mocked = core_utils.make_wrap_object( + api, "chat.completions.create", wrapped + ) return cast(openai.OpenAI, client_mocked) diff --git a/python/src/aiconfig/Config.py b/python/src/aiconfig/Config.py index 31da31185..de24cbc3b 100644 --- a/python/src/aiconfig/Config.py +++ b/python/src/aiconfig/Config.py @@ -5,7 +5,9 @@ import requests import yaml from aiconfig.callback import CallbackEvent, CallbackManager -from aiconfig.default_parsers.anyscale_endpoint import DefaultAnyscaleEndpointParser +from aiconfig.default_parsers.anyscale_endpoint import ( + DefaultAnyscaleEndpointParser, +) from aiconfig.default_parsers.openai import DefaultOpenAIParser from aiconfig.default_parsers.palm import PaLMChatParser, PaLMTextParser from aiconfig.model_parser import InferenceOptions, ModelParser @@ -14,7 +16,10 @@ from .default_parsers.dalle import DalleImageGenerationParser from .default_parsers.hf import HuggingFaceTextGenerationParser -from .registry import ModelParserRegistry, update_model_parser_registry_with_config_runtime +from .registry import ( + ModelParserRegistry, + update_model_parser_registry_with_config_runtime, +) from .schema import AIConfig, Prompt from .util.config_utils import is_yaml_ext @@ -33,7 +38,9 @@ ] for model in gpt_models: ModelParserRegistry.register_model_parser(DefaultOpenAIParser(model)) -ModelParserRegistry.register_model_parser(DefaultAnyscaleEndpointParser("AnyscaleEndpoint")) +ModelParserRegistry.register_model_parser( + DefaultAnyscaleEndpointParser("AnyscaleEndpoint") +) ModelParserRegistry.register_model_parser(PaLMChatParser()) ModelParserRegistry.register_model_parser(PaLMTextParser()) ModelParserRegistry.register_model_parser(HuggingFaceTextGenerationParser()) @@ -42,7 +49,9 @@ "dall-e-3", ] for model in dalle_image_generation_models: - ModelParserRegistry.register_model_parser(DalleImageGenerationParser(model)) + ModelParserRegistry.register_model_parser( + DalleImageGenerationParser(model) + ) class AIConfigRuntime(AIConfig): @@ -127,14 +136,18 @@ def load_from_workbook(cls, workbook_id: str) -> "AIConfigRuntime": lastmileapi_token = os.environ.get("LASTMILE_API_TOKEN") if not lastmileapi_token: - raise ValueError("LASTMILE_API_TOKEN environment variable is not set.") + raise ValueError( + "LASTMILE_API_TOKEN environment variable is not set." + ) headers = {"Authorization": "Bearer " + lastmileapi_token} url = f"{API_ENDPOINT}/workbooks/aiconfig?id={workbook_id}" resp = requests.get(url, headers=headers) if resp.status_code != 200: - raise Exception(f"Failed to load workbook. Status code: {resp.status_code}") + raise Exception( + f"Failed to load workbook. Status code: {resp.status_code}" + ) data = resp.json() @@ -177,11 +190,15 @@ async def serialize( model_parser = ModelParserRegistry.get_model_parser(model_name) if not model_parser: - raise ValueError(f"Unable to serialize data: `{data}`\n Model Parser for model {model_name} does not exist.") + raise ValueError( + f"Unable to serialize data: `{data}`\n Model Parser for model {model_name} does not exist." + ) prompts = await model_parser.serialize(prompt_name, data, self, params) - event = CallbackEvent("on_serialize_complete", __name__, {"result": prompts}) + event = CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) await self.callback_manager.run_callbacks(event) return prompts @@ -201,14 +218,20 @@ async def resolve( Returns: str: The resolved prompt. """ - event = CallbackEvent("on_resolve_start", __file__, {"prompt_name": prompt_name, "params": params}) + event = CallbackEvent( + "on_resolve_start", + __file__, + {"prompt_name": prompt_name, "params": params}, + ) await self.callback_manager.run_callbacks(event) if not params: params = {} if prompt_name not in self.prompt_index: - raise IndexError(f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}") + raise IndexError( + f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}" + ) prompt_data = self.prompt_index[prompt_name] model_name = self.get_model_name(prompt_data) @@ -216,7 +239,9 @@ async def resolve( response = await model_provider.deserialize(prompt_data, self, params) - event = CallbackEvent("on_resolve_complete", __name__, {"result": response}) + event = CallbackEvent( + "on_resolve_complete", __name__, {"result": response} + ) await self.callback_manager.run_callbacks(event) return response @@ -253,7 +278,9 @@ async def run( params = {} if prompt_name not in self.prompt_index: - raise IndexError(f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}") + raise IndexError( + f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}" + ) prompt_data = self.prompt_index[prompt_name] model_name = self.get_model_name(prompt_data) @@ -267,10 +294,12 @@ async def run( self, options, params, - **kwargs, # TODO: We should remove and make argument explicit + **kwargs, # TODO: We should remove and make argument explicit ) - event = CallbackEvent("on_run_complete", __name__, {"result": response}) + event = CallbackEvent( + "on_run_complete", __name__, {"result": response} + ) await self.callback_manager.run_callbacks(event) return response @@ -316,7 +345,9 @@ async def run_batch( # Check if the provided prompt name is available in the list of prompts if prompt_name not in self.prompt_index: - raise IndexError(f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}") + raise IndexError( + f"Prompt '{prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}" + ) # Retrieve model and respective provider prompt_data = self.prompt_index[prompt_name] @@ -343,11 +374,25 @@ async def run_batch( parameters_dict_used = parameters_list[i] aiconfig_execute_results = aiconfig.get_prompt(prompt_name).outputs - prompt_data_resolved = await aiconfig.resolve(prompt_name, parameters_dict_used) + prompt_data_resolved = await aiconfig.resolve( + prompt_name, parameters_dict_used + ) - batch_results_formatted.append(tuple([aiconfig_execute_results, prompt_data_resolved, parameters_dict_used])) + batch_results_formatted.append( + tuple( + [ + aiconfig_execute_results, + prompt_data_resolved, + parameters_dict_used, + ] + ) + ) - event = CallbackEvent("on_run_batch_complete", __name__, {"result": batch_results_formatted}) + event = CallbackEvent( + "on_run_batch_complete", + __name__, + {"result": batch_results_formatted}, + ) await self.callback_manager.run_callbacks(event) return batch_results_formatted @@ -359,10 +404,17 @@ async def run_and_get_output_text( options: Optional[InferenceOptions] = None, **kwargs, ) -> str: - result: Any = await self.run(prompt_name, params, options=options, **kwargs) + result: Any = await self.run( + prompt_name, params, options=options, **kwargs + ) return self.get_output_text(prompt_name, result[0]) - def save(self, config_filepath: str | None = None, include_outputs: bool = True, mode: Literal["json", "yaml"] | None = None): + def save( + self, + config_filepath: str | None = None, + include_outputs: bool = True, + mode: Literal["json", "yaml"] | None = None, + ): """ Save the AI Configuration to a file. @@ -382,7 +434,9 @@ def save(self, config_filepath: str | None = None, include_outputs: bool = True, if not include_outputs: exclude_options["prompts"] = {"__all__": {"outputs"}} - default_filepath = "aiconfig.yaml" if mode == "yaml" else "aiconfig.json" + default_filepath = ( + "aiconfig.yaml" if mode == "yaml" else "aiconfig.json" + ) if not config_filepath: config_filepath = self.file_path or default_filepath @@ -410,14 +464,18 @@ def save(self, config_filepath: str | None = None, include_outputs: bool = True, ) else: # Save AIConfig as JSON to the file, with the schema specified - json_data["$schema"] = "https://json.schemastore.org/aiconfig-1.0" + json_data[ + "$schema" + ] = "https://json.schemastore.org/aiconfig-1.0" json.dump( json_data, file, indent=2, ) - def get_output_text(self, prompt: str | Prompt, output: Optional[dict] = None) -> str: + def get_output_text( + self, prompt: str | Prompt, output: Optional[dict] = None + ) -> str: """ Get the string representing the output from a prompt (if any) @@ -430,7 +488,9 @@ def get_output_text(self, prompt: str | Prompt, output: Optional[dict] = None) - """ if isinstance(prompt, str): prompt = self.get_prompt(prompt) - model_parser = ModelParserRegistry.get_model_parser_for_prompt(prompt, self) + model_parser = ModelParserRegistry.get_model_parser_for_prompt( + prompt, self + ) return model_parser.get_output_text(prompt, self, output) @staticmethod @@ -455,10 +515,14 @@ def get_model_parser(model_id: str) -> ModelParser: ModelParser: The model parser corresponding to the given identifier. """ if model_id not in ModelParserRegistry.parser_ids(): - raise IndexError(f"Model parser '{model_id}' not found in registry, available model parsers are:\n {ModelParserRegistry.parser_ids()}") + raise IndexError( + f"Model parser '{model_id}' not found in registry, available model parsers are:\n {ModelParserRegistry.parser_ids()}" + ) return ModelParserRegistry.get_model_parser(model_id) def set_callback_manager(self, callback_manager: CallbackManager): if callback_manager is None: - raise ValueError("callback_manager cannot be None. Create a new CallbackManager with No callbacks instead.") + raise ValueError( + "callback_manager cannot be None. Create a new CallbackManager with No callbacks instead." + ) self.callback_manager = callback_manager diff --git a/python/src/aiconfig/__init__.py b/python/src/aiconfig/__init__.py index e225f40e3..1009b71f5 100644 --- a/python/src/aiconfig/__init__.py +++ b/python/src/aiconfig/__init__.py @@ -1,6 +1,11 @@ # Core Data Classes # Callback Utilities -from .callback import Callback, CallbackEvent, CallbackManager, create_logging_callback +from .callback import ( + Callback, + CallbackEvent, + CallbackManager, + create_logging_callback, +) # The AIConfigRuntime class. This is the main class that you will use to run your AIConfig. from .Config import AIConfigRuntime @@ -8,16 +13,18 @@ # Model Parsers from .default_parsers.palm import PaLMChatParser, PaLMTextParser -from .default_parsers.parameterized_model_parser import ParameterizedModelParser +from .default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) # ModelParser Utilities from .model_parser import InferenceOptions, ModelParser from .registry import ModelParserRegistry from .schema import ( AIConfig, + AttachmentDataWithStringValue, ConfigMetadata, ExecuteResult, - AttachmentDataWithStringValue, JSONObject, ModelMetadata, Output, diff --git a/python/src/aiconfig/callback.py b/python/src/aiconfig/callback.py index d38895391..cf598d307 100644 --- a/python/src/aiconfig/callback.py +++ b/python/src/aiconfig/callback.py @@ -2,7 +2,17 @@ import asyncio import logging import time -from typing import Any, Awaitable, Callable, Coroutine, Final, List, Sequence, TypeAlias, Union +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Final, + List, + Sequence, + TypeAlias, + Union, +) from pydantic import BaseModel, ConfigDict @@ -18,7 +28,9 @@ class CallbackEvent: Represents an event with data to be passed to a callback. """ - def __init__(self, name: str, file: str, data: Any, ts_ns: int = time.time_ns()): + def __init__( + self, name: str, file: str, data: Any, ts_ns: int = time.time_ns() + ): self.name = name # The name of the file that triggered the event. self.file = file @@ -33,7 +45,9 @@ def __init__(self, name: str, file: str, data: Any, ts_ns: int = time.time_ns()) Result: TypeAlias = Union[Ok[Any], Err[Any]] -async def execute_coroutine_with_timeout(coroutine: Coroutine[Any, Any, Any], timeout: int) -> Result: +async def execute_coroutine_with_timeout( + coroutine: Coroutine[Any, Any, Any], timeout: int +) -> Result: """ Execute a coroutine with a timeout, return an Ok result or an Err on Exception @@ -73,7 +87,9 @@ class CallbackManager: Manages a sequence of callbacks to be executed in response to Events """ - def __init__(self, callbacks: Sequence[Callback], timeout: int = None) -> None: + def __init__( + self, callbacks: Sequence[Callback], timeout: int = None + ) -> None: if timeout is None: timeout = DEFAULT_TIMEOUT self.callbacks: Final[Sequence[Callback]] = callbacks @@ -84,7 +100,9 @@ async def run_callbacks(self, event: CallbackEvent) -> None: event = CallbackEventModel(**event.__dict__) tasks = [] for callback in self.callbacks: - task = execute_coroutine_with_timeout(callback(event), self.timeout) + task = execute_coroutine_with_timeout( + callback(event), self.timeout + ) tasks.append(task) self.results = await asyncio.gather(*tasks) @@ -111,7 +129,9 @@ def setup_logger(): name = "my-logger" log_file = "aiconfig.log" - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) handler = logging.FileHandler(log_file) handler.setFormatter(formatter) diff --git a/python/src/aiconfig/default_parsers/anyscale_endpoint.py b/python/src/aiconfig/default_parsers/anyscale_endpoint.py index a0fdd150e..ab0724d73 100644 --- a/python/src/aiconfig/default_parsers/anyscale_endpoint.py +++ b/python/src/aiconfig/default_parsers/anyscale_endpoint.py @@ -7,7 +7,15 @@ from openai import OpenAI from openai.types.chat import ChatCompletionMessage -from aiconfig.schema import ExecuteResult, FunctionCallData, Output, OutputDataWithToolCallsValue, OutputDataWithValue, Prompt, ToolCallData +from aiconfig.schema import ( + ExecuteResult, + FunctionCallData, + Output, + OutputDataWithToolCallsValue, + OutputDataWithValue, + Prompt, + ToolCallData, +) from .openai import OpenAIInference @@ -38,7 +46,11 @@ async def run_inference( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -55,7 +67,9 @@ async def run_inference( else: api_key = os.environ[anyscale_api_key_name] - client = OpenAI(api_key=api_key, base_url="https://api.endpoints.anyscale.com/v1") + client = OpenAI( + api_key=api_key, base_url="https://api.endpoints.anyscale.com/v1" + ) completion_data = await self.deserialize(prompt, aiconfig, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. @@ -75,13 +89,22 @@ async def run_inference( # # OpenAI>1.0.0 uses pydantic models for response response = response.model_dump(exclude_none=True) - response_without_choices = {key: copy.deepcopy(value) for key, value in response.items() if key != "choices"} + response_without_choices = { + key: copy.deepcopy(value) + for key, value in response.items() + if key != "choices" + } for i, choice in enumerate(response.get("choices")): output_message = choice["message"] output_data = build_output_data(output_message) - response_without_choices.update({"finish_reason": choice.get("finish_reason")}) - metadata = {"raw_response": output_message, **response_without_choices} + response_without_choices.update( + {"finish_reason": choice.get("finish_reason")} + ) + metadata = { + "raw_response": output_message, + **response_without_choices, + } if output_message.get("role", None) is not None: metadata["role"] = output_message.get("role") @@ -101,7 +124,11 @@ async def run_inference( for chunk in response: # OpenAI>1.0.0 uses pydantic models. Chunk is of type ChatCompletionChunk; type is not directly importable from openai Library, will require some diffing chunk = chunk.model_dump(exclude_none=True) - chunk_without_choices = {key: copy.deepcopy(value) for key, value in chunk.items() if key != "choices"} + chunk_without_choices = { + key: copy.deepcopy(value) + for key, value in chunk.items() + if key != "choices" + } # streaming only returns one chunk, one choice at a time (before 1.0.0). The order in which the choices are returned is not guaranteed. messages = multi_choice_message_reducer(messages, chunk) @@ -111,7 +138,9 @@ async def run_inference( delta = choice.get("delta") if options and options.stream_callback: - options.stream_callback(delta, accumulated_message_for_choice, index) + options.stream_callback( + delta, accumulated_message_for_choice, index + ) output = ExecuteResult( **{ @@ -139,7 +168,11 @@ async def run_inference( # rewrite or extend list of outputs? prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return prompt.outputs @@ -223,7 +256,9 @@ def reduce(acc, delta): return acc -def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk: dict) -> Dict[int, dict]: +def multi_choice_message_reducer( + messages: Union[Dict[int, dict], None], chunk: dict +) -> Dict[int, dict]: if messages is None: messages = {} diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py index feb59fa2c..427b60bd1 100644 --- a/python/src/aiconfig/default_parsers/dalle.py +++ b/python/src/aiconfig/default_parsers/dalle.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import openai -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) from aiconfig.util.config_utils import get_api_key_from_environment from aiconfig.util.params import resolve_prompt from openai import OpenAI @@ -10,7 +12,13 @@ # Dall-E API imports from openai.types import Image, ImagesResponse -from aiconfig.schema import ExecuteResult, Output, OutputDataWithStringValue, Prompt, PromptMetadata +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithStringValue, + Prompt, + PromptMetadata, +) # ModelParser Utils # Type hint imports @@ -30,7 +38,14 @@ def refine_image_completion_params(model_settings, aiconfig, prompt): Refines the completion params for the Dall-E request API. Removes any unsupported params. The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python` """ - supported_keys = {"model", "n", "quality", "response_format", "size", "style"} + supported_keys = { + "model", + "n", + "quality", + "response_format", + "size", + "style", + } completion_data = {} for key in model_settings: @@ -48,11 +63,17 @@ def refine_image_completion_params(model_settings, aiconfig, prompt): def construct_output(image_data: Image, execution_count: int) -> Output: data = None if image_data.b64_json is not None: - data = OutputDataWithStringValue(kind="base64", value=str(image_data.b64_json)) + data = OutputDataWithStringValue( + kind="base64", value=str(image_data.b64_json) + ) elif image_data.url is not None: - data = OutputDataWithStringValue(kind="file_uri", value=str(image_data.url)) + data = OutputDataWithStringValue( + kind="file_uri", value=str(image_data.url) + ) else: - raise ValueError(f"Did not receive a valid image type from image_data: {image_data}") + raise ValueError( + f"Did not receive a valid image type from image_data: {image_data}" + ) output = ExecuteResult( **{ "output_type": "execute_result", @@ -88,7 +109,10 @@ def __init__(self, model_id: str = "dall-e-3"): "dall-e-3", } if model_id.lower() not in supported_models: - raise ValueError("{model_id}" + " is not a valid model ID for Dall-E image generation. Supported models: {supported_models}.") + raise ValueError( + "{model_id}" + + " is not a valid model ID for Dall-E image generation. Supported models: {supported_models}." + ) self.model_id = model_id self.client = None @@ -129,12 +153,19 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) return [prompt] # TODO (rossdanlm): Update documentation for args - async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict: + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + params: Optional[Dict] = {}, + ) -> Dict: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -146,15 +177,21 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: dict: Model-specific completion parameters. """ # Get inputs from aiconfig - resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) + resolved_prompt = resolve_prompt( + prompt, params if params is not None else {}, aiconfig + ) model_settings = self.get_model_settings(prompt, aiconfig) # Build Completion data - completion_data = refine_image_completion_params(model_settings, aiconfig, prompt) + completion_data = refine_image_completion_params( + model_settings, aiconfig, prompt + ) completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference(self, prompt: Prompt, aiconfig, _options, parameters) -> List[Output]: + async def run_inference( + self, prompt: Prompt, aiconfig, _options, parameters + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -168,14 +205,20 @@ async def run_inference(self, prompt: Prompt, aiconfig, _options, parameters) -> """ # If needed, certify the API key and initialize the OpenAI client if not openai.api_key: - openai.api_key = get_api_key_from_environment("OPENAI_API_KEY").unwrap() + openai.api_key = get_api_key_from_environment( + "OPENAI_API_KEY" + ).unwrap() if not self.client: self.client = OpenAI(api_key=openai.api_key) completion_data = await self.deserialize(prompt, aiconfig, parameters) - print("Calling image generation. This can take several seconds, please hold on...") - response: ImagesResponse = self.client.images.generate(**completion_data) + print( + "Calling image generation. This can take several seconds, please hold on..." + ) + response: ImagesResponse = self.client.images.generate( + **completion_data + ) outputs = [] # ImageResponse object also contains a "created" field for timestamp, should I store that somewhere? diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py index 02bc0e514..4dba9c042 100644 --- a/python/src/aiconfig/default_parsers/hf.py +++ b/python/src/aiconfig/default_parsers/hf.py @@ -2,6 +2,13 @@ import json from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.config_utils import get_api_key_from_environment +from aiconfig.util.params import resolve_prompt + # HuggingFace API imports from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import ( @@ -10,17 +17,7 @@ ) from aiconfig import CallbackEvent -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ( - ExecuteResult, - Output, - Prompt, - PromptMetadata, -) -from aiconfig.util.config_utils import get_api_key_from_environment -from aiconfig.util.params import resolve_prompt - +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -105,7 +102,9 @@ def construct_stream_output( return output -def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output: +def construct_regular_output( + response: TextGenerationResponse, response_includes_details: bool +) -> Output: metadata = {"raw_response": response} if response_includes_details: metadata["details"] = response.details @@ -153,7 +152,9 @@ def __init__(self, model_id: str = None, use_api_token=False): if use_api_token: # You are allowed to use Hugging Face for a bit before you get # rate limited, in which case you will receive a clear error - token = get_api_key_from_environment("HUGGING_FACE_API_TOKEN", required=False).unwrap() + token = get_api_key_from_environment( + "HUGGING_FACE_API_TOKEN", required=False + ).unwrap() self.client = InferenceClient(model_id, token=token) @@ -208,12 +209,18 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + metadata=PromptMetadata( + model=model_metadata, parameters=parameters, **kwargs + ), ) prompts.append(prompt) - await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts})) + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) + ) return prompts @@ -233,7 +240,13 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) resolved_prompt = resolve_prompt(prompt, params, aiconfig) @@ -244,11 +257,23 @@ async def deserialize( completion_data["prompt"] = resolved_prompt - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]: + async def run_inference( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + options: InferenceOptions, + parameters: dict[Any, Any], + ) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -264,7 +289,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -290,12 +319,16 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio outputs.append(output) else: # Handles stream callback - output = construct_stream_output(response, response_is_detailed, options) + output = construct_stream_output( + response, response_is_detailed, options + ) outputs.append(output) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent("on_run_complete", __name__, {"result": outputs}) + ) return outputs diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py index c6a59ab14..ca3760456 100644 --- a/python/src/aiconfig/default_parsers/openai.py +++ b/python/src/aiconfig/default_parsers/openai.py @@ -4,10 +4,16 @@ import openai from aiconfig.callback import CallbackEvent -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) from aiconfig.model_parser import InferenceOptions from aiconfig.util.config_utils import get_api_key_from_environment -from aiconfig.util.params import resolve_prompt, resolve_prompt_string, resolve_system_prompt +from aiconfig.util.params import ( + resolve_prompt, + resolve_prompt_string, + resolve_system_prompt, +) from openai.types.chat import ChatCompletionMessage from aiconfig.schema import ( @@ -72,7 +78,9 @@ async def serialize( conversation_data = {**data} if not "messages" in conversation_data: - raise ValueError("Data must have `messages` array to match openai api spec") + raise ValueError( + "Data must have `messages` array to match openai api spec" + ) # Find first system prompt. Every prompt in the config will bet set to use this system prompt. system_prompt = None @@ -83,9 +91,15 @@ async def serialize( break # Get the global settings for the model - model_name = conversation_data["model"] if "model" in conversation_data else self.id() + model_name = ( + conversation_data["model"] + if "model" in conversation_data + else self.id() + ) - model_metadata = ai_config.get_model_metadata(conversation_data, model_name) + model_metadata = ai_config.get_model_metadata( + conversation_data, model_name + ) # Remove messages array from model metadata. Handled separately model_metadata.settings.pop("messages", None) @@ -107,7 +121,11 @@ async def serialize( i += 1 new_prompt_name = f"{prompt_name}_{len(prompts) + 1}" - input = messsage["content"] if role == "user" else PromptInput(**messsage) + input = ( + messsage["content"] + if role == "user" + else PromptInput(**messsage) + ) assistant_output = [] if assistant_response is not None: @@ -141,11 +159,18 @@ async def serialize( if prompts: prompts[len(prompts) - 1].name = prompt_name - event = CallbackEvent("on_serialize_complete", __name__, {"result": prompts}) + event = CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) await ai_config.callback_manager.run_callbacks(event) return prompts - async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict: + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + params: Optional[Dict] = {}, + ) -> Dict: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -156,11 +181,19 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) # Build Completion params model_settings = self.get_model_settings(prompt, aiconfig) - completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt) + completion_params = refine_chat_completion_params( + model_settings, aiconfig, prompt + ) # In the case thhat the messages array weren't saves as part of the model settings, build it here. Messages array is used for conversation history. if not completion_params.get("messages"): @@ -172,12 +205,17 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: if isinstance(system_prompt, dict): # If system prompt is an object, then it should have content and role attributes system_prompt = system_prompt["content"] - resolved_system_prompt = resolve_system_prompt(prompt, system_prompt, params, aiconfig) - completion_params["messages"].append({"content": resolved_system_prompt, "role": "system"}) + resolved_system_prompt = resolve_system_prompt( + prompt, system_prompt, params, aiconfig + ) + completion_params["messages"].append( + {"content": resolved_system_prompt, "role": "system"} + ) # Default to always use chat context if not hasattr(prompt.metadata, "remember_chat_context") or ( - hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False + hasattr(prompt.metadata, "remember_chat_context") + and prompt.metadata.remember_chat_context != False ): # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output for i, previous_prompt in enumerate(aiconfig.prompts): @@ -185,7 +223,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: if previous_prompt.name == prompt.name: break - if aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt): + if aiconfig.get_model_name( + previous_prompt + ) == aiconfig.get_model_name(prompt): # Add prompt and its output to completion data. Constructing this prompt will take into account available parameters. add_prompt_as_message( previous_prompt, @@ -196,7 +236,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: else: # If messages are already specified in the model settings, then just resolve each message with the given parameters and append the latest message for i in range(len(completion_params.get("messages"))): - completion_params["messages"][i]["content"] = resolve_prompt_string( + completion_params["messages"][i][ + "content" + ] = resolve_prompt_string( prompt, params, aiconfig, @@ -204,8 +246,16 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: ) # Add in the latest prompt - add_prompt_as_message(prompt, aiconfig, completion_params["messages"], params) - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) + add_prompt_as_message( + prompt, aiconfig, completion_params["messages"], params + ) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_params}, + ) + ) return completion_params async def run_inference( @@ -230,12 +280,18 @@ async def run_inference( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) if not openai.api_key: - openai.api_key = get_api_key_from_environment("OPENAI_API_KEY").unwrap() + openai.api_key = get_api_key_from_environment( + "OPENAI_API_KEY" + ).unwrap() completion_data = await self.deserialize(prompt, aiconfig, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. @@ -255,13 +311,22 @@ async def run_inference( # # OpenAI>1.0.0 uses pydantic models for response response = response.model_dump(exclude_none=True) - response_without_choices = {key: copy.deepcopy(value) for key, value in response.items() if key != "choices"} + response_without_choices = { + key: copy.deepcopy(value) + for key, value in response.items() + if key != "choices" + } for i, choice in enumerate(response.get("choices")): output_message = choice["message"] output_data = build_output_data(output_message) - response_without_choices.update({"finish_reason": choice.get("finish_reason")}) - metadata = {"raw_response": output_message, **response_without_choices} + response_without_choices.update( + {"finish_reason": choice.get("finish_reason")} + ) + metadata = { + "raw_response": output_message, + **response_without_choices, + } if output_message.get("role", None) is not None: metadata["role"] = output_message.get("role") @@ -281,7 +346,11 @@ async def run_inference( for chunk in response: # OpenAI>1.0.0 uses pydantic models. Chunk is of type ChatCompletionChunk; type is not directly importable from openai Library, will require some diffing chunk = chunk.model_dump(exclude_none=True) - chunk_without_choices = {key: copy.deepcopy(value) for key, value in chunk.items() if key != "choices"} + chunk_without_choices = { + key: copy.deepcopy(value) + for key, value in chunk.items() + if key != "choices" + } # streaming only returns one chunk, one choice at a time (before 1.0.0). The order in which the choices are returned is not guaranteed. messages = multi_choice_message_reducer(messages, chunk) @@ -291,7 +360,9 @@ async def run_inference( delta = choice.get("delta") if options and options.stream_callback: - options.stream_callback(delta, accumulated_message_for_choice, index) + options.stream_callback( + delta, accumulated_message_for_choice, index + ) output = ExecuteResult( **{ @@ -319,16 +390,24 @@ async def run_inference( # rewrite or extend list of outputs? prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return prompt.outputs - def get_prompt_template(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> str: + def get_prompt_template( + self, prompt: Prompt, aiconfig: "AIConfigRuntime" + ) -> str: """ Returns a template for a prompt. """ if isinstance(prompt.input, str): return prompt.input - elif isinstance(prompt.input, PromptInput) and isinstance(prompt.input.data, str): + elif isinstance(prompt.input, PromptInput) and isinstance( + prompt.input.data, str + ): return prompt.input.data else: message = prompt.input @@ -358,7 +437,10 @@ def get_output_text( # Doing this to be backwards-compatible with old output format # where we used to save the ChatCompletionMessage in output.data if isinstance(output_data, ChatCompletionMessage): - if hasattr(output_data, "content") and output_data.content is not None: + if ( + hasattr(output_data, "content") + and output_data.content is not None + ): return output_data.content elif output_data.function_call is not None: return str(output_data.function_call) @@ -391,7 +473,9 @@ def reduce(acc, delta): return acc -def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk: dict) -> Dict[int, dict]: +def multi_choice_message_reducer( + messages: Union[Dict[int, dict], None], chunk: dict +) -> Dict[int, dict]: if messages is None: messages = {} @@ -444,7 +528,9 @@ def refine_chat_completion_params(model_settings, aiconfig, prompt): return completion_data -def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: List, params=None): +def add_prompt_as_message( + prompt: Prompt, aiconfig: "AIConfigRuntime", messages: List, params=None +): """ Converts a given prompt to a message and adds it to the specified messages list. @@ -458,11 +544,17 @@ def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: messages.append({"content": resolved_prompt, "role": "user"}) else: # Assumes Prompt input will be in the format of ChatCompletionMessageParam (with content, role, function_name, and name attributes) - resolved_prompt = resolve_prompt_string(prompt, params, aiconfig, prompt.input.content) + resolved_prompt = resolve_prompt_string( + prompt, params, aiconfig, prompt.input.content + ) prompt_input = prompt.input role = prompt_input.role if hasattr(prompt_input, "role") else "user" - fn_call = prompt_input.function_call if hasattr(prompt_input, "function_call") else None + fn_call = ( + prompt_input.function_call + if hasattr(prompt_input, "function_call") + else None + ) name = prompt_input.name if hasattr(prompt_input, "name") else None message_data = {"content": resolved_prompt, "role": role} @@ -480,7 +572,10 @@ def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: if output.output_type == "execute_result": assert isinstance(output, ExecuteResult) output_data = output.data - role = output.metadata.get("role", None) or ("raw_response" in output.metadata and output.metadata["raw_response"].get("role", None)) + role = output.metadata.get("role", None) or ( + "raw_response" in output.metadata + and output.metadata["raw_response"].get("role", None) + ) if role == "assistant": output_message = {} @@ -492,8 +587,12 @@ def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: if isinstance(output_data.value, str): content = output_data.value elif output_data.kind == "tool_calls": - assert isinstance(output_data, OutputDataWithToolCallsValue) - function_call = output_data.value[len(output_data.value) - 1].function + assert isinstance( + output_data, OutputDataWithToolCallsValue + ) + function_call = output_data.value[ + len(output_data.value) - 1 + ].function output_message["content"] = content output_message["role"] = role @@ -504,7 +603,10 @@ def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: if function_call is not None: output_message["function_call"] = function_call - name = output.metadata.get("name", None) or ("raw_response" in output.metadata and output.metadata["raw_response"].get("name", None)) + name = output.metadata.get("name", None) or ( + "raw_response" in output.metadata + and output.metadata["raw_response"].get("name", None) + ) if name is not None: output_message["name"] = name @@ -523,7 +625,9 @@ def is_prompt_template(prompt: Prompt): """ Check if a prompt's input is a valid string. """ - return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)) + return isinstance(prompt.input, str) or ( + hasattr(prompt.input, "data") and isinstance(prompt.input.data, str) + ) def build_output_data( diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py index 749c2e540..f3271a07e 100644 --- a/python/src/aiconfig/default_parsers/palm.py +++ b/python/src/aiconfig/default_parsers/palm.py @@ -3,13 +3,21 @@ import google.generativeai as palm from aiconfig.callback import CallbackEvent -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, +) from aiconfig.model_parser import InferenceOptions from aiconfig.util.params import resolve_parameters, resolve_prompt from google.generativeai.text import Completion from google.generativeai.types.discuss_types import MessageDict -from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt, PromptMetadata +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) if TYPE_CHECKING: from aiconfig.Config import AIConfigRuntime @@ -70,12 +78,19 @@ async def serialize( ) ] - event = CallbackEvent("on_serialize_complete", __name__, {"result": prompts}) + event = CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) await ai_config.callback_manager.run_callbacks(event) return prompts - async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict: + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + params: Optional[Dict] = {}, + ) -> Dict: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -86,18 +101,32 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) + completion_data = refine_chat_completion_params( + model_settings, aiconfig, prompt + ) prompt_str = resolve_prompt(prompt, params, aiconfig) # pass in the user prompt completion_data["prompt"] = prompt_str - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data async def run_inference( @@ -126,7 +155,11 @@ async def run_inference( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -151,7 +184,11 @@ async def run_inference( outputs.append(output) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return outputs def get_output_text( @@ -225,11 +262,18 @@ async def serialize( ) ] - event = CallbackEvent("on_serialize_complete", __name__, {"result": prompts}) + event = CallbackEvent( + "on_serialize_complete", __name__, {"result": prompts} + ) await ai_config.callback_manager.run_callbacks(event) return prompts - async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict: + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + params: Optional[Dict] = {}, + ) -> Dict: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -240,20 +284,29 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Returns: dict: Model-specific completion parameters. """ - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_start", + __name__, + {"prompt": prompt, "params": params}, + ) + ) resolved_prompt = resolve_prompt(prompt, params, aiconfig) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) + completion_data = refine_chat_completion_params( + model_settings, aiconfig, prompt + ) # TODO: handle if user specifies previous messages in settings completion_data["messages"] = [] # Default to always use chat contextjkl; if not hasattr(prompt.metadata, "remember_chat_context") or ( - hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False + hasattr(prompt.metadata, "remember_chat_context") + and prompt.metadata.remember_chat_context != False ): # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output for i, previous_prompt in enumerate(aiconfig.prompts): @@ -268,22 +321,39 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # check if prompt has an output. PaLM Api requires this if len(previous_prompt.outputs) > 0: - resolved_previous_prompt = resolve_parameters({}, previous_prompt, aiconfig) - completion_data["messages"].append({"content": resolved_previous_prompt, "author": "0"}) + resolved_previous_prompt = resolve_parameters( + {}, previous_prompt, aiconfig + ) + completion_data["messages"].append( + { + "content": resolved_previous_prompt, + "author": "0", + } + ) completion_data["messages"].append( { "content": aiconfig.get_output_text( previous_prompt, - aiconfig.get_latest_output(previous_prompt), + aiconfig.get_latest_output( + previous_prompt + ), ), "author": "1", } ) # pass in the user prompt - completion_data["messages"].append({"content": resolved_prompt, "author": "0"}) - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + completion_data["messages"].append( + {"content": resolved_prompt, "author": "0"} + ) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_deserialize_complete", + __name__, + {"output": completion_data}, + ) + ) return completion_data async def run_inference( @@ -308,7 +378,11 @@ async def run_inference( CallbackEvent( "on_run_start", __name__, - {"prompt": prompt, "options": options, "parameters": parameters}, + { + "prompt": prompt, + "options": options, + "parameters": parameters, + }, ) ) @@ -332,7 +406,11 @@ async def run_inference( outputs.append(output) prompt.outputs = outputs - await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_complete", __name__, {"result": prompt.outputs} + ) + ) return prompt.outputs def get_output_text( diff --git a/python/src/aiconfig/default_parsers/parameterized_model_parser.py b/python/src/aiconfig/default_parsers/parameterized_model_parser.py index 374339939..dd4ee4c95 100644 --- a/python/src/aiconfig/default_parsers/parameterized_model_parser.py +++ b/python/src/aiconfig/default_parsers/parameterized_model_parser.py @@ -45,15 +45,25 @@ async def run( aiconfig: AIConfig, options: Optional[InferenceOptions] = None, parameters: Dict = {}, - **kwargs, #TODO: We should remove and make arguments explicit + **kwargs, # TODO: We should remove and make arguments explicit ) -> List[Output]: # maybe use prompt metadata instead of kwargs? if kwargs.get("run_with_dependencies", False): - return await self.run_with_dependencies(prompt, aiconfig, options, parameters) + return await self.run_with_dependencies( + prompt, aiconfig, options, parameters + ) else: - return await self.run_inference(prompt, aiconfig, options, parameters) + return await self.run_inference( + prompt, aiconfig, options, parameters + ) - async def run_with_dependencies(self, prompt: Prompt, aiconfig: AIConfig, options=None, parameters: Dict = {}) -> List[Output]: + async def run_with_dependencies( + self, + prompt: Prompt, + aiconfig: AIConfig, + options=None, + parameters: Dict = {}, + ) -> List[Output]: """ Executes the AI model with the resolved dependencies and prompt references and returns the API response. @@ -65,7 +75,9 @@ async def run_with_dependencies(self, prompt: Prompt, aiconfig: AIConfig, option Returns: ExecuteResult: An Object containing the response from the AI model. """ - dependency_graph = get_dependency_graph(prompt, aiconfig.prompts, aiconfig.prompt_index) + dependency_graph = get_dependency_graph( + prompt, aiconfig.prompts, aiconfig.prompt_index + ) # Create a set to keep track of visited prompts visited_prompts = set() @@ -115,15 +127,23 @@ def resolve_prompt_template( Returns: str: The resolved string. """ - return resolve_prompt_string(prompt, params, ai_config, prompt_template) + return resolve_prompt_string( + prompt, params, ai_config, prompt_template + ) - def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str: + def get_prompt_template( + self, prompt: Prompt, aiConfig: "AIConfigRuntime" + ) -> str: """ An overrideable method that returns a template for a prompt. """ if isinstance(prompt.input, str): return prompt.input - elif isinstance(prompt.input, PromptInput) and isinstance(prompt.input.data, str): + elif isinstance(prompt.input, PromptInput) and isinstance( + prompt.input.data, str + ): return prompt.input.data else: - raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}") + raise Exception( + f"Cannot get prompt template string from prompt input: {prompt.input}" + ) diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py index 6ac905f9a..afac83691 100644 --- a/python/src/aiconfig/editor/server/server.py +++ b/python/src/aiconfig/editor/server/server.py @@ -5,14 +5,17 @@ import logging import threading import time -import webbrowser import uuid +import webbrowser from typing import Any, Dict, Type, Union import lastmile_utils.lib.core.api as core_utils import result from aiconfig.Config import AIConfigRuntime -from aiconfig.editor.server.queue_iterator import STOP_STREAMING_SIGNAL, QueueIterator +from aiconfig.editor.server.queue_iterator import ( + STOP_STREAMING_SIGNAL, + QueueIterator, +) from aiconfig.editor.server.server_utils import ( AIConfigRC, EditServerConfig, @@ -57,35 +60,56 @@ CORS(app, resources={r"/api/*": {"origins": "*"}}) -def run_backend_server(edit_config: EditServerConfig, aiconfigrc_path: str) -> Result[str, str]: +def run_backend_server( + edit_config: EditServerConfig, aiconfigrc_path: str +) -> Result[str, str]: LOGGER.setLevel(edit_config.log_level) LOGGER.info("Edit config: %s", edit_config.model_dump_json()) - LOGGER.info(f"Starting server on http://localhost:{edit_config.server_port}") + LOGGER.info( + f"Starting server on http://localhost:{edit_config.server_port}" + ) try: - LOGGER.info(f"Opening browser at http://localhost:{edit_config.server_port}") + LOGGER.info( + f"Opening browser at http://localhost:{edit_config.server_port}" + ) webbrowser.open(f"http://localhost:{edit_config.server_port}") except Exception as e: - LOGGER.warning(f"Failed to open browser: {e}. Please open http://localhost:{port} manually.") + LOGGER.warning( + f"Failed to open browser: {e}. Please open http://localhost:{port} manually." + ) app.server_state = ServerState() # type: ignore - res_server_state_init = init_server_state(app, edit_config, aiconfigrc_path) + res_server_state_init = init_server_state( + app, edit_config, aiconfigrc_path + ) match res_server_state_init: case Ok(_): LOGGER.info("Initialized server state") - debug = edit_config.server_mode in [ServerMode.DEBUG_BACKEND, ServerMode.DEBUG_SERVERS] + debug = edit_config.server_mode in [ + ServerMode.DEBUG_BACKEND, + ServerMode.DEBUG_SERVERS, + ] LOGGER.info(f"Running in {edit_config.server_mode} mode") - app.run(port=edit_config.server_port, debug=debug, use_reloader=debug) + app.run( + port=edit_config.server_port, debug=debug, use_reloader=debug + ) return Ok("Done") case Err(e): LOGGER.error(f"Failed to initialize server state: {e}") return Err(f"Failed to initialize server state: {e}") -def _validated_request_path(request_json: core_utils.JSONObject, allow_create: bool = False) -> Result[ValidatedPath, str]: +def _validated_request_path( + request_json: core_utils.JSONObject, allow_create: bool = False +) -> Result[ValidatedPath, str]: if "path" not in request_json or not isinstance(request_json["path"], str): - return Err("Request JSON must contain a 'path' field with a string value.") + return Err( + "Request JSON must contain a 'path' field with a string value." + ) else: - return get_validated_path(request_json["path"], allow_create=allow_create) + return get_validated_path( + request_json["path"], allow_create=allow_create + ) @app.route("/") @@ -115,7 +139,11 @@ def load_model_parser_module() -> FlaskResponse: case Ok(resp): return resp.to_flask_format() case Err(e): - return HttpResponseWithAIConfig(message=f"Failed to load model parser module: {e}", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message=f"Failed to load model parser module: {e}", + code=400, + aiconfig=None, + ).to_flask_format() @app.route("/api/load", methods=["POST"]) @@ -124,25 +152,40 @@ def load() -> FlaskResponse: request_json = request.get_json() if not request_json.keys() <= {"path"}: return HttpResponseWithAIConfig( - message="Request JSON must contain a 'path' field with a string value, or no arguments.", code=400, aiconfig=None + message="Request JSON must contain a 'path' field with a string value, or no arguments.", + code=400, + aiconfig=None, ).to_flask_format() path: str | None = request_json.get("path", None) if path is None: aiconfig = state.aiconfig if aiconfig is None: - return HttpResponseWithAIConfig(message="No AIConfig loaded", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message="No AIConfig loaded", code=400, aiconfig=None + ).to_flask_format() else: - return HttpResponseWithAIConfig(message="AIConfig already loaded. Here it is!", aiconfig=aiconfig).to_flask_format() + return HttpResponseWithAIConfig( + message="AIConfig already loaded. Here it is!", + aiconfig=aiconfig, + ).to_flask_format() else: res_path_val = get_validated_path(path) res_aiconfig = res_path_val.and_then(safe_load_from_disk) match res_aiconfig: case Ok(aiconfig): - LOGGER.warning(f"Loaded AIConfig from {res_path_val}. This may have overwritten in-memory changes.") + LOGGER.warning( + f"Loaded AIConfig from {res_path_val}. This may have overwritten in-memory changes." + ) state.aiconfig = aiconfig - return HttpResponseWithAIConfig(message="Loaded", aiconfig=aiconfig).to_flask_format() + return HttpResponseWithAIConfig( + message="Loaded", aiconfig=aiconfig + ).to_flask_format() case Err(e): - return HttpResponseWithAIConfig(message=f"Failed to load AIConfig: {res_path_val}, {e}", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message=f"Failed to load AIConfig: {res_path_val}, {e}", + code=400, + aiconfig=None, + ).to_flask_format() @app.route("/api/save", methods=["POST"]) @@ -154,32 +197,52 @@ def save() -> FlaskResponse: if path is None: if aiconfig is None: - return HttpResponseWithAIConfig(message="No AIConfig loaded", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message="No AIConfig loaded", code=400, aiconfig=None + ).to_flask_format() else: - LOGGER.info(f"No path provided, saving to original path, {aiconfig.file_path}") + LOGGER.info( + f"No path provided, saving to original path, {aiconfig.file_path}" + ) path = aiconfig.file_path res_path_val = get_validated_path(path, allow_create=True) match res_path_val: case Ok(path_ok): _op = make_op_run_method(MethodName("save")) - op_args: Result[OpArgs, str] = result.Ok(OpArgs({"config_filepath": path_ok})) - return run_aiconfig_operation_with_op_args(aiconfig, "save", _op, op_args) + op_args: Result[OpArgs, str] = result.Ok( + OpArgs({"config_filepath": path_ok}) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "save", _op, op_args + ) case Err(e): - return HttpResponseWithAIConfig(message=f"Failed to save AIConfig: {e}", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message=f"Failed to save AIConfig: {e}", + code=400, + aiconfig=None, + ).to_flask_format() @app.route("/api/create", methods=["POST"]) def create() -> FlaskResponse: state = get_server_state(app) - aiconfig = safe_run_aiconfig_static_method(MethodName("create"), OpArgs({}), AIConfigRuntime) + aiconfig = safe_run_aiconfig_static_method( + MethodName("create"), OpArgs({}), AIConfigRuntime + ) match aiconfig: case Ok(aiconfig_ok): state.aiconfig = aiconfig_ok - return HttpResponseWithAIConfig(message="Created new AIConfig", aiconfig=aiconfig_ok).to_flask_format() + return HttpResponseWithAIConfig( + message="Created new AIConfig", aiconfig=aiconfig_ok + ).to_flask_format() case Err(e): - return HttpResponseWithAIConfig(message=f"Failed to create AIConfig: {e}", code=400, aiconfig=None).to_flask_format() + return HttpResponseWithAIConfig( + message=f"Failed to create AIConfig: {e}", + code=400, + aiconfig=None, + ).to_flask_format() @app.route("/api/run", methods=["POST"]) @@ -217,7 +280,9 @@ def run() -> FlaskResponse: # Define stream callback and queue object for streaming results output_text_queue = QueueIterator() - def update_output_queue(data: str, _accumulated_data: str, _index: int) -> None: + def update_output_queue( + data: str, _accumulated_data: str, _index: int + ) -> None: should_end_stream: bool = data == STOP_STREAMING_SIGNAL output_text_queue.put(data, should_end_stream) @@ -243,11 +308,25 @@ def run_async_config_in_thread(): output_text_queue.put(STOP_STREAMING_SIGNAL) # type: ignore def create_error_payload(message: str, code: int): - aiconfig_json = aiconfig_deep_copy.model_dump(exclude=EXCLUDE_OPTIONS) if aiconfig_deep_copy is not None else None - return json.dumps({"error": {"message": message, "code": code, "data": aiconfig_json}}) + aiconfig_json = ( + aiconfig_deep_copy.model_dump(exclude=EXCLUDE_OPTIONS) + if aiconfig_deep_copy is not None + else None + ) + return json.dumps( + { + "error": { + "message": message, + "code": code, + "data": aiconfig_json, + } + } + ) def create_cancellation_payload(): - return create_error_payload(message="The task was cancelled.", code=499) + return create_error_payload( + message="The task was cancelled.", code=499 + ) def handle_cancellation(): yield "[" @@ -271,7 +350,9 @@ def kill_thread(thread_id: int | None): # Nothing to do return - response = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), ctypes.py_object(SystemExit)) + response = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(thread_id), ctypes.py_object(SystemExit) + ) if response == 0: print(f"Invalid thread id {thread_id}") @@ -298,7 +379,9 @@ def kill_thread(thread_id: int | None): # we can fix later time.sleep(0.1) wait_time_in_seconds += SLEEP_DELAY_SECONDS - print(f"Output queue is currently empty. Waiting for {wait_time_in_seconds:.1f}s...") + print( + f"Output queue is currently empty. Waiting for {wait_time_in_seconds:.1f}s..." + ) # Yield in flask is weird and you either need to send responses as a # string, or artificially wrap them around "[" and "]" @@ -311,7 +394,9 @@ def kill_thread(thread_id: int | None): return if isinstance(text, Exception): - yield from create_error_payload(message=f"Exception: {text}", code=500) + yield from create_error_payload( + message=f"Exception: {text}", code=500 + ) return elif isinstance(text, str): accumulated_output_text += text @@ -334,7 +419,9 @@ def kill_thread(thread_id: int | None): } # type: ignore ) yield "[" - yield json.dumps({"output_chunk": accumulated_output.to_json()}) + yield json.dumps( + {"output_chunk": accumulated_output.to_json()} + ) yield "]" # Ensure that the run process is complete to yield final output @@ -346,17 +433,23 @@ def kill_thread(thread_id: int | None): else: state.events.pop(cancellation_token_id, None) - aiconfig_json = aiconfig.model_dump(exclude=EXCLUDE_OPTIONS) if aiconfig is not None else None + aiconfig_json = ( + aiconfig.model_dump(exclude=EXCLUDE_OPTIONS) + if aiconfig is not None + else None + ) yield "[" yield json.dumps({"aiconfig_chunk": aiconfig_json}) yield "]" - + yield "[" yield json.dumps({"stop_streaming": None}) yield "]" try: - LOGGER.info(f"Running `aiconfig.run()` command with request: {request_json}") + LOGGER.info( + f"Running `aiconfig.run()` command with request: {request_json}" + ) # Note; We run the streaming API even for non-streaming runs so that # we can unify the way we process data on the client # Streaming based on @@ -380,7 +473,9 @@ def cancel() -> FlaskResponse: state = get_server_state(app) request_json = request.get_json() - cancellation_token_id: str | None = request_json.get("cancellation_token_id") + cancellation_token_id: str | None = request_json.get( + "cancellation_token_id" + ) if cancellation_token_id is not None: event = state.events.get(cancellation_token_id) if event is not None: @@ -388,7 +483,9 @@ def cancel() -> FlaskResponse: # Remove the event from the events dict state.events.pop(cancellation_token_id) - return FlaskResponse(({"cancellation_token_id": cancellation_token_id}, 200)) + return FlaskResponse( + ({"cancellation_token_id": cancellation_token_id}, 200) + ) else: # Return a 422 Unprocessable Entity return FlaskResponse( @@ -415,27 +512,38 @@ def cancel() -> FlaskResponse: @app.route("/api/add_prompt", methods=["POST"]) def add_prompt() -> FlaskResponse: method_name = MethodName("add_prompt") - signature: dict[str, Type[Any]] = {"prompt_name": str, "prompt_data": Prompt, "index": int} + signature: dict[str, Type[Any]] = { + "prompt_name": str, + "prompt_data": Prompt, + "index": int, + } state = get_server_state(app) aiconfig = state.aiconfig request_json = request.get_json() operation = make_op_run_method(method_name) - return run_aiconfig_operation_with_request_json(aiconfig, request_json, f"method_{method_name}", operation, signature) + return run_aiconfig_operation_with_request_json( + aiconfig, request_json, f"method_{method_name}", operation, signature + ) @app.route("/api/update_prompt", methods=["POST"]) def update_prompt() -> FlaskResponse: method_name = MethodName("update_prompt") - signature: dict[str, Type[Any]] = {"prompt_name": str, "prompt_data": Prompt} + signature: dict[str, Type[Any]] = { + "prompt_name": str, + "prompt_data": Prompt, + } state = get_server_state(app) aiconfig = state.aiconfig request_json = request.get_json() operation = make_op_run_method(method_name) - return run_aiconfig_operation_with_request_json(aiconfig, request_json, f"method_{method_name}", operation, signature) + return run_aiconfig_operation_with_request_json( + aiconfig, request_json, f"method_{method_name}", operation, signature + ) @app.route("/api/delete_prompt", methods=["POST"]) @@ -448,7 +556,9 @@ def delete_prompt() -> FlaskResponse: request_json = request.get_json() operation = make_op_run_method(method_name) - return run_aiconfig_operation_with_request_json(aiconfig, request_json, f"method_{method_name}", operation, signature) + return run_aiconfig_operation_with_request_json( + aiconfig, request_json, f"method_{method_name}", operation, signature + ) @app.route("/api/update_model", methods=["POST"]) @@ -462,8 +572,18 @@ def update_model() -> FlaskResponse: prompt_name: str | None = request_json.get("prompt_name") operation = make_op_run_method(MethodName("update_model")) - operation_args: Result[OpArgs, str] = result.Ok(OpArgs({"model_name": model_name, "settings": settings, "prompt_name": prompt_name})) - return run_aiconfig_operation_with_op_args(aiconfig, "update_model", operation, operation_args) + operation_args: Result[OpArgs, str] = result.Ok( + OpArgs( + { + "model_name": model_name, + "settings": settings, + "prompt_name": prompt_name, + } + ) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "update_model", operation, operation_args + ) @app.route("/api/set_parameter", methods=["POST"]) @@ -473,14 +593,24 @@ def set_parameter() -> FlaskResponse: request_json = request.get_json() parameter_name: str | None = request_json.get("parameter_name") - parameter_value: Union[str, Dict[str, Any]] | None = request_json.get("parameter_value") + parameter_value: Union[str, Dict[str, Any]] | None = request_json.get( + "parameter_value" + ) prompt_name: str | None = request_json.get("prompt_name") operation = make_op_run_method(MethodName("set_parameter")) operation_args: Result[OpArgs, str] = result.Ok( - OpArgs({"parameter_name": parameter_name, "parameter_value": parameter_value, "prompt_name": prompt_name}) + OpArgs( + { + "parameter_name": parameter_name, + "parameter_value": parameter_value, + "prompt_name": prompt_name, + } + ) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "set_parameter", operation, operation_args ) - return run_aiconfig_operation_with_op_args(aiconfig, "set_parameter", operation, operation_args) @app.route("/api/set_parameters", methods=["POST"]) @@ -493,8 +623,12 @@ def set_parameters() -> FlaskResponse: prompt_name: str | None = request_json.get("prompt_name") operation = make_op_run_method(MethodName("set_parameters")) - operation_args: Result[OpArgs, str] = result.Ok(OpArgs({"parameters": parameters, "prompt_name": prompt_name})) - return run_aiconfig_operation_with_op_args(aiconfig, "set_parameters", operation, operation_args) + operation_args: Result[OpArgs, str] = result.Ok( + OpArgs({"parameters": parameters, "prompt_name": prompt_name}) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "set_parameters", operation, operation_args + ) @app.route("/api/delete_parameter", methods=["POST"]) @@ -507,8 +641,12 @@ def delete_parameter() -> FlaskResponse: prompt_name: str | None = request_json.get("prompt_name") operation = make_op_run_method(MethodName("delete_parameter")) - operation_args: Result[OpArgs, str] = result.Ok(OpArgs({"parameter_name": parameter_name, "prompt_name": prompt_name})) - return run_aiconfig_operation_with_op_args(aiconfig, "delete_parameter", operation, operation_args) + operation_args: Result[OpArgs, str] = result.Ok( + OpArgs({"parameter_name": parameter_name, "prompt_name": prompt_name}) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "delete_parameter", operation, operation_args + ) @app.route("/api/set_name", methods=["POST"]) @@ -521,7 +659,9 @@ def set_name() -> FlaskResponse: operation = make_op_run_method(MethodName("set_name")) operation_args: Result[OpArgs, str] = result.Ok(OpArgs({"name": name})) - return run_aiconfig_operation_with_op_args(aiconfig, "set_name", operation, operation_args) + return run_aiconfig_operation_with_op_args( + aiconfig, "set_name", operation, operation_args + ) @app.route("/api/set_description", methods=["POST"]) @@ -533,8 +673,12 @@ def set_description() -> FlaskResponse: description: str | None = request_json.get("description") operation = make_op_run_method(MethodName("set_description")) - operation_args: Result[OpArgs, str] = result.Ok(OpArgs({"description": description})) - return run_aiconfig_operation_with_op_args(aiconfig, "set_description", operation, operation_args) + operation_args: Result[OpArgs, str] = result.Ok( + OpArgs({"description": description}) + ) + return run_aiconfig_operation_with_op_args( + aiconfig, "set_description", operation, operation_args + ) @app.route("/api/clear_outputs", methods=["POST"]) @@ -554,7 +698,9 @@ def clear_outputs() -> FlaskResponse: aiconfig=None, ).to_flask_format() - def _op(aiconfig_runtime: AIConfigRuntime, _op_args: OpArgs) -> Result[None, str]: + def _op( + aiconfig_runtime: AIConfigRuntime, _op_args: OpArgs + ) -> Result[None, str]: for prompt in aiconfig_runtime.prompts: prompt_name = prompt.name # fn name `delete_output`` is misleading. TODO: Rename to `delete_outputs`` in AIConfig API @@ -562,25 +708,31 @@ def _op(aiconfig_runtime: AIConfigRuntime, _op_args: OpArgs) -> Result[None, str return Ok(None) signature: dict[str, Type[Any]] = {} - return run_aiconfig_operation_with_request_json(aiconfig, request_json, f"method_", _op, signature) + return run_aiconfig_operation_with_request_json( + aiconfig, request_json, f"method_", _op, signature + ) @app.route("/api/get_aiconfigrc", methods=["GET"]) def get_aiconfigrc() -> FlaskResponse: state = get_server_state(app) - yaml_mapping: Result[AIConfigRC, str] = core_utils.read_text_file(state.aiconfigrc_path).and_then(AIConfigRC.from_yaml) + yaml_mapping: Result[AIConfigRC, str] = core_utils.read_text_file( + state.aiconfigrc_path + ).and_then(AIConfigRC.from_yaml) match yaml_mapping: case Ok(yaml_mapping_ok): return FlaskResponse((yaml_mapping_ok.model_dump(), 200)) case Err(e): - return FlaskResponse(({"message": f"Failed to load aiconfigrc: {e}"}, 400)) + return FlaskResponse( + ({"message": f"Failed to load aiconfigrc: {e}"}, 400) + ) @app.route("/api/set_aiconfigrc", methods=["POST"]) def set_aiconfigrc() -> FlaskResponse: - state = get_server_state(app) - request_json = request.get_json() + get_server_state(app) + request.get_json() # TODO: # We might not need to implement this at all. # diff --git a/python/src/aiconfig/editor/server/server_utils.py b/python/src/aiconfig/editor/server/server_utils.py index 19465ad9b..3e2cbaf55 100644 --- a/python/src/aiconfig/editor/server/server_utils.py +++ b/python/src/aiconfig/editor/server/server_utils.py @@ -66,7 +66,9 @@ class EditServerConfig(core_utils.Record): parsers_module_path: str = "aiconfig_model_registry.py" @field_validator("server_mode", mode="before") - def convert_to_mode(cls, value: Any) -> ServerMode: # pylint: disable=no-self-argument + def convert_to_mode( + cls, value: Any + ) -> ServerMode: # pylint: disable=no-self-argument if isinstance(value, str): try: return ServerMode[value.upper()] @@ -89,7 +91,9 @@ class Config: extra = "forbid" @classmethod - def from_yaml(cls: Type["AIConfigRC"], yaml: str) -> Result["AIConfigRC", str]: + def from_yaml( + cls: Type["AIConfigRC"], yaml: str + ) -> Result["AIConfigRC", str]: try: loaded = YAML().load(yaml) loaded_dict = dict(loaded) @@ -128,9 +132,14 @@ class HttpResponseWithAIConfig: } def to_flask_format(self) -> FlaskResponse: - out: core_utils.JSONObject = {"message": self.message, **(self.payload if self.payload is not None else {})} + out: core_utils.JSONObject = { + "message": self.message, + **(self.payload if self.payload is not None else {}), + } if self.aiconfig is not None: - out["aiconfig"] = self.aiconfig.model_dump(exclude=HttpResponseWithAIConfig.EXCLUDE_OPTIONS) + out["aiconfig"] = self.aiconfig.model_dump( + exclude=HttpResponseWithAIConfig.EXCLUDE_OPTIONS + ) return FlaskResponse((out, self.code)) @@ -143,7 +152,9 @@ def resolve_path(path: str) -> str: return os.path.abspath(os.path.expanduser(path)) -def get_validated_path(raw_path: str | None, allow_create: bool = False) -> Result[ValidatedPath, str]: +def get_validated_path( + raw_path: str | None, allow_create: bool = False +) -> Result[ValidatedPath, str]: LOGGER.debug(f"{allow_create=}") if not raw_path: return Err("No path provided") @@ -160,11 +171,15 @@ def _import_module_from_path(path_to_module: str) -> Result[ModuleType, str]: module_name = os.path.basename(resolved_path).replace(".py", "") try: - spec = importlib.util.spec_from_file_location(module_name, resolved_path) + spec = importlib.util.spec_from_file_location( + module_name, resolved_path + ) if spec is None: return Err(f"Could not import module from path: {resolved_path}") elif spec.loader is None: - return Err(f"Could not import module from path: {resolved_path} (no loader)") + return Err( + f"Could not import module from path: {resolved_path} (no loader)" + ) else: module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module @@ -174,17 +189,25 @@ def _import_module_from_path(path_to_module: str) -> Result[ModuleType, str]: return core_utils.ErrWithTraceback(e) -def _load_register_fn_from_user_module(user_module: ModuleType) -> Result[Callable[[], None], str]: +def _load_register_fn_from_user_module( + user_module: ModuleType, +) -> Result[Callable[[], None], str]: if not hasattr(user_module, "register_model_parsers"): - return Err(f"User module {user_module} does not have a register_model_parsers function.") + return Err( + f"User module {user_module} does not have a register_model_parsers function." + ) register_fn = getattr(user_module, "register_model_parsers") if not callable(register_fn): - return Err(f"User module {user_module} does not have a register_model_parsers function") + return Err( + f"User module {user_module} does not have a register_model_parsers function" + ) else: return Ok(register_fn) -def _register_user_model_parsers(user_register_fn: Callable[[], None]) -> Result[None, str]: +def _register_user_model_parsers( + user_register_fn: Callable[[], None] +) -> Result[None, str]: try: return Ok(user_register_fn()) except Exception as e: @@ -202,21 +225,31 @@ def load_user_parser_module(path_to_module: str) -> Result[None, str]: return register_result -def get_http_response_load_user_parser_module(path_to_module: str) -> HttpResponseWithAIConfig: +def get_http_response_load_user_parser_module( + path_to_module: str, +) -> HttpResponseWithAIConfig: register_result = load_user_parser_module(path_to_module) match register_result: case Ok(_): - msg = f"Successfully registered model parsers from {path_to_module}" + msg = ( + f"Successfully registered model parsers from {path_to_module}" + ) LOGGER.info(msg) return HttpResponseWithAIConfig(message=msg, aiconfig=None) case Err(e): - msg = f"Failed to register model parsers from {path_to_module}: {e}" + msg = ( + f"Failed to register model parsers from {path_to_module}: {e}" + ) LOGGER.error(msg) - return HttpResponseWithAIConfig(message=msg, code=400, aiconfig=None) + return HttpResponseWithAIConfig( + message=msg, code=400, aiconfig=None + ) def _load_user_parser_module_if_exists(parsers_module_path: str) -> None: - res = get_validated_path(parsers_module_path).and_then(load_user_parser_module) + res = get_validated_path(parsers_module_path).and_then( + load_user_parser_module + ) match res: case Ok(_): LOGGER.info(f"Loaded parsers module from {parsers_module_path}") @@ -224,7 +257,9 @@ def _load_user_parser_module_if_exists(parsers_module_path: str) -> None: LOGGER.warning(f"Failed to load parsers module: {e}") -def safe_load_from_disk(aiconfig_path: ValidatedPath) -> Result[AIConfigRuntime, str]: +def safe_load_from_disk( + aiconfig_path: ValidatedPath, +) -> Result[AIConfigRuntime, str]: try: aiconfig = AIConfigRuntime.load(aiconfig_path) # type: ignore return Ok(aiconfig) @@ -232,7 +267,9 @@ def safe_load_from_disk(aiconfig_path: ValidatedPath) -> Result[AIConfigRuntime, return core_utils.ErrWithTraceback(e) -def init_server_state(app: Flask, edit_config: EditServerConfig, aiconfigrc_path: str) -> Result[None, str]: +def init_server_state( + app: Flask, edit_config: EditServerConfig, aiconfigrc_path: str +) -> Result[None, str]: LOGGER.info("Initializing server state") _load_user_parser_module_if_exists(edit_config.parsers_module_path) state = get_server_state(app) @@ -246,32 +283,51 @@ def init_server_state(app: Flask, edit_config: EditServerConfig, aiconfigrc_path match aiconfig_runtime: case Ok(aiconfig_runtime_): state.aiconfig = aiconfig_runtime_ - LOGGER.info(f"Loaded AIConfig from {edit_config.aiconfig_path}") + LOGGER.info( + f"Loaded AIConfig from {edit_config.aiconfig_path}" + ) return Ok(None) case Err(e): - LOGGER.error(f"Failed to load AIConfig from {edit_config.aiconfig_path}: {e}") - return Err(f"Failed to load AIConfig from {edit_config.aiconfig_path}: {e}") + LOGGER.error( + f"Failed to load AIConfig from {edit_config.aiconfig_path}: {e}" + ) + return Err( + f"Failed to load AIConfig from {edit_config.aiconfig_path}: {e}" + ) else: LOGGER.info(f"Creating new AIConfig at {edit_config.aiconfig_path}") aiconfig_runtime = AIConfigRuntime.create() # type: ignore model_ids = ModelParserRegistry.parser_ids() if len(model_ids) > 0: - aiconfig_runtime.add_prompt("prompt_1", Prompt(name="prompt_1", input="", metadata=PromptMetadata(model=model_ids[0]))) + aiconfig_runtime.add_prompt( + "prompt_1", + Prompt( + name="prompt_1", + input="", + metadata=PromptMetadata(model=model_ids[0]), + ), + ) state.aiconfig = aiconfig_runtime LOGGER.info("Created new AIConfig") try: aiconfig_runtime.save(edit_config.aiconfig_path) aiconfig_runtime.file_path = edit_config.aiconfig_path # type: ignore[bug in runtime init] - LOGGER.info(f"Saved new AIConfig to {edit_config.aiconfig_path} (aiconfig path field: {aiconfig_runtime.file_path})") + LOGGER.info( + f"Saved new AIConfig to {edit_config.aiconfig_path} (aiconfig path field: {aiconfig_runtime.file_path})" + ) state.aiconfig = aiconfig_runtime return Ok(None) except Exception as e: - LOGGER.error(f"Failed to create new AIConfig at {edit_config.aiconfig_path}: {e}") + LOGGER.error( + f"Failed to create new AIConfig at {edit_config.aiconfig_path}: {e}" + ) return core_utils.ErrWithTraceback(e) -def _safe_run_aiconfig_method(aiconfig: AIConfigRuntime, method_name: MethodName, method_args: OpArgs) -> Result[None, str]: +def _safe_run_aiconfig_method( + aiconfig: AIConfigRuntime, method_name: MethodName, method_args: OpArgs +) -> Result[None, str]: # TODO: use `out` try: method = getattr(aiconfig, method_name) @@ -283,7 +339,9 @@ def _safe_run_aiconfig_method(aiconfig: AIConfigRuntime, method_name: MethodName return core_utils.ErrWithTraceback(e) -def safe_run_aiconfig_static_method(method_name: MethodName, method_args: OpArgs, output_typ: Type[T]) -> Result[T, str]: +def safe_run_aiconfig_static_method( + method_name: MethodName, method_args: OpArgs, output_typ: Type[T] +) -> Result[T, str]: try: method = getattr(AIConfigRuntime, method_name) out = method(**method_args) @@ -295,7 +353,9 @@ def safe_run_aiconfig_static_method(method_name: MethodName, method_args: OpArgs def make_op_run_method(method_name: MethodName) -> Operation[None]: - def _op(aiconfig: AIConfigRuntime, operation_args: OpArgs) -> Result[None, str]: + def _op( + aiconfig: AIConfigRuntime, operation_args: OpArgs + ) -> Result[None, str]: LOGGER.info(f"Running method: {method_name}, {operation_args=}") return _safe_run_aiconfig_method(aiconfig, method_name, operation_args) @@ -349,10 +409,16 @@ def _validated_op_args_from_request_json( signature: dict[str, Type[Any]], ) -> Result[OpArgs, str]: if signature.keys() != request_json.keys(): - LOGGER.info(f"Expected keys: {signature.keys()}, got: {request_json.keys()}") - return Err(f"Expected keys: {signature.keys()}, got: {request_json.keys()}") - - def _resolve(key: str, value: core_utils.JSONValue, signature: dict[str, Type[Any]]) -> Result[Any, str]: + LOGGER.info( + f"Expected keys: {signature.keys()}, got: {request_json.keys()}" + ) + return Err( + f"Expected keys: {signature.keys()}, got: {request_json.keys()}" + ) + + def _resolve( + key: str, value: core_utils.JSONValue, signature: dict[str, Type[Any]] + ) -> Result[Any, str]: LOGGER.info(f"Resolving: {key}, {value}, {signature[key]}") _type = signature[key] if isinstance(value, typing.Dict): @@ -360,9 +426,15 @@ def _resolve(key: str, value: core_utils.JSONValue, signature: dict[str, Type[An else: return Ok(value) - operation_args_results = {key: _resolve(key, value, signature) for key, value in request_json.items()} + operation_args_results = { + key: _resolve(key, value, signature) + for key, value in request_json.items() + } LOGGER.info(f"{operation_args_results=}") - res_op_args: Result[OpArgs, str] = cast(Result[OpArgs, str], core_utils.result_reduce_dict_all_ok(operation_args_results)) + res_op_args: Result[OpArgs, str] = cast( + Result[OpArgs, str], + core_utils.result_reduce_dict_all_ok(operation_args_results), + ) return res_op_args @@ -386,7 +458,9 @@ def run_aiconfig_operation_with_request_json( op_args = _validated_op_args_from_request_json(request_json, signature) match op_args: case Ok(op_args_ok): - return run_aiconfig_operation_with_op_args(aiconfig, operation_name, operation, Ok(op_args_ok)) + return run_aiconfig_operation_with_op_args( + aiconfig, operation_name, operation, Ok(op_args_ok) + ) case Err(e): return HttpResponseWithAIConfig( message=f"Failed to run {operation_name}: {e}", diff --git a/python/src/aiconfig/eval/api/__init__.py b/python/src/aiconfig/eval/api/__init__.py index f04ad9683..164ff6234 100644 --- a/python/src/aiconfig/eval/api/__init__.py +++ b/python/src/aiconfig/eval/api/__init__.py @@ -9,16 +9,20 @@ TestSuiteWithInputsSettings, ) """ -from .. import common, metrics +from .. import test_suite_common, test_suite_metrics # pyright: reportWildcardImportFromLibrary=false -from ..lib import TestSuiteWithInputsSettings, run_test_suite_outputs_only, run_test_suite_with_inputs -from ..metrics import Metric, brevity, substring_match +from ..test_suite_lib import ( + TestSuiteWithInputsSettings, + run_test_suite_outputs_only, + run_test_suite_with_inputs, +) +from ..test_suite_metrics import TestSuiteMetric, brevity, substring_match __all__ = [ - "Metric", - "common", - "metrics", + "TestSuiteMetric", + "test_suite_common", + "test_suite_metrics", "brevity", "substring_match", "run_test_suite_with_inputs", diff --git a/python/src/aiconfig/eval/batch_common.py b/python/src/aiconfig/eval/batch_common.py new file mode 100644 index 000000000..9122b8d05 --- /dev/null +++ b/python/src/aiconfig/eval/batch_common.py @@ -0,0 +1,31 @@ +from abc import abstractmethod +from typing import Protocol, Sequence, TypeVar + +from aiconfig.eval import batch_common, common + +T_Ref = TypeVar("T_Ref") +T_Ref_contra = TypeVar("T_Ref_contra", contravariant=True) + + +class BatchEvaluationFunctionWithReference( + Protocol[ + common.T_Evaluable, batch_common.T_Ref_contra, common.T_MetricValue_inv + ] +): + @abstractmethod + async def __call__( + self, + data: Sequence[common.T_Evaluable], + ref: Sequence[batch_common.T_Ref_contra], + ) -> list[common.T_MetricValue_inv]: + pass + + +class BatchEvaluationFunctionWithoutReference( + Protocol[common.T_Evaluable, common.T_MetricValue_inv] +): + @abstractmethod + async def __call__( + self, data: Sequence[common.T_Evaluable] + ) -> list[common.T_MetricValue_inv]: + pass diff --git a/python/src/aiconfig/eval/batch_lib.py b/python/src/aiconfig/eval/batch_lib.py new file mode 100644 index 000000000..82b247576 --- /dev/null +++ b/python/src/aiconfig/eval/batch_lib.py @@ -0,0 +1,402 @@ +import asyncio +import logging +from dataclasses import dataclass +from functools import partial +from typing import Any, Coroutine, Generic, Sequence, cast + +import lastmile_utils.lib.core.api as core_utils +import pandas as pd +import result +from aiconfig.eval import batch_common, batch_metrics, common +from result import Result + +logging.basicConfig(format=core_utils.LOGGER_FMT) +LOGGER = logging.getLogger(__name__) + +# Types + + +@dataclass(frozen=True) +class BatchEvalGeneralSettings: + eval_fn_timeout_s: int = 5 + + +@dataclass(frozen=True) +class EvaluableTableWithReference( + Generic[common.T_Evaluable, batch_common.T_Ref] +): + df: pd.DataFrame + + @staticmethod + def make( + # At this point, don't care about the type of input_data. It's display-only now. + input_data: Sequence[Any] | None, + evaluable: Sequence[common.T_Evaluable], + ref_data: Sequence[batch_common.T_Ref], + ) -> Result[ + "EvaluableTableWithReference[common.T_Evaluable, batch_common.T_Ref]", + str, + ]: + # make_df is untyped, but it's safe to cast it here because the types are annotated in this function signature. + # We can clearly see here that the output df types will match the input types, so it's safe to cast the output. + df = common.make_df( + { + "input_data": input_data, + "ref_data": ref_data, + "evaluable": evaluable, + } + ) + out: Result[ + EvaluableTableWithReference[ + common.T_Evaluable, batch_common.T_Ref + ], + str, + ] = cast( + # + Result[ + EvaluableTableWithReference[ + common.T_Evaluable, batch_common.T_Ref + ], + str, + ], + df.map(EvaluableTableWithReference), + ) + return out + + async def calculate( + self, + metric: batch_metrics.BatchMetricWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ], + ) -> Result["ResultTable[common.T_Evaluable, common.T_MetricValue]", str]: + evaluable: Sequence[common.T_Evaluable] = cast( + # + Sequence[common.T_Evaluable], + self.df.evaluable, # type: ignore[pandas] + ) + + ref_data: Sequence[batch_common.T_Ref] = cast( + # + Sequence[batch_common.T_Ref], + self.df.ref_data, # type: ignore[pandas] + ) + + @core_utils.exception_to_err_with_traceback_async + async def _run(): + return await metric.evaluation_fn(evaluable, ref_data) + + def _make( + values_ok: list[common.T_MetricValue], + ) -> Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ]: + # safe annotation, we know what's in the df. + out: Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ] = ResultTable.make(self.df, values_ok) + return out + + values = await _run() + out = values.and_then(_make) + return out + + +@dataclass(frozen=True) +class EvaluableTableWithoutRef(Generic[common.T_Evaluable]): + df: pd.DataFrame + + @staticmethod + def make( + # At this point, I don't care about the type of input_data. It's display-only now. + input_data: Sequence[Any] | None, + evaluable: Sequence[common.T_Evaluable], + ) -> Result["EvaluableTableWithoutRef[common.T_Evaluable]", str]: + # make_df is untyped, but it's safe to cast it here because the types are annotated in this function signature. + # We can clearly see here that the output df types will match the input types, so it's safe to cast the output. + df = common.make_df({"input_data": input_data, "evaluable": evaluable}) + out: Result[EvaluableTableWithoutRef[common.T_Evaluable], str] = cast( + # + Result[EvaluableTableWithoutRef[common.T_Evaluable], str], + df.map(EvaluableTableWithoutRef), + ) + return out + + async def calculate( + self, + metric: batch_metrics.BatchMetricWithoutReference[ + common.T_Evaluable, common.T_MetricValue + ], + ) -> Result["ResultTable[common.T_Evaluable, common.T_MetricValue]", str]: + evaluable: Sequence[common.T_Evaluable] = cast( + # + Sequence[common.T_Evaluable], + self.df.evaluable, # type: ignore[pandas] + ) + + @core_utils.exception_to_err_with_traceback_async + async def _run(): + return await metric.evaluation_fn(evaluable) + + def _make( + values_ok: list[common.T_MetricValue], + ) -> Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ]: + # safe annotation, we know what's in the df. + out: Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ] = ResultTable.make(self.df, values_ok) + return out + + values = await _run() + out = values.and_then(_make) + return out + + +@dataclass(frozen=True) +class ResultTable(Generic[common.T_Evaluable, common.T_MetricValue]): + df: pd.DataFrame + + @staticmethod + def make( + df_evaluable: pd.DataFrame, + metric_values: Sequence[common.T_MetricValue], + ) -> Result["ResultTable[common.T_Evaluable, common.T_MetricValue]", str]: + if len(df_evaluable) != len(metric_values): + return result.Err( + f"len(df_evaluable) != len(metric_values): {len(df_evaluable)} != {len(metric_values)}" + ) + else: + return result.Ok(ResultTable(df_evaluable.assign(metric_values=metric_values))) # type: ignore[pandas] + + @staticmethod + def concatenate_tables( + tables: Sequence[ + "ResultTable[common.T_Evaluable, common.T_MetricValue]" + ], + ) -> Result["ResultTable[common.T_Evaluable, common.T_MetricValue]", str]: + dfs = [table.df for table in tables if len(table.df) > 0] + df = pd.concat(dfs) # type: ignore[pandas] + return result.Ok(ResultTable(df)) + + +# API + + +async def run_evaluation( + # + evaluable: Sequence[str], + reference: Sequence[str] | None, + metrics: batch_metrics.BatchMetrics[str, str, common.T_MetricValue], + settings: BatchEvalGeneralSettings | None = None, +) -> pd.DataFrame: + settings_ = settings or BatchEvalGeneralSettings() + res_table = await _evaluable_to_result_table( + None, evaluable, reference, metrics, settings_ + ) + return res_table.map(_process_result_table_to_df).unwrap_or_raise( + ValueError + ) + + +async def _evaluable_to_result_table( + # Intentional any. Inputs is display-only + inputs: Sequence[Any] | None, + evaluable: Sequence[str], + reference: Sequence[str] | None, + metrics: batch_metrics.BatchMetrics[str, str, common.T_MetricValue], + settings: BatchEvalGeneralSettings, +): + match metrics: + case batch_metrics.BatchMetricsWithReference(metrics=metrics_): + if not reference: + raise ValueError( + "got BatchMetricsWithReference, reference cannot be None" + ) + else: + table = EvaluableTableWithReference.make( + inputs, evaluable, reference + ) + res = await result.do_async( + # + await _run_evaluation_helper_with_ref( + table_ok, metrics_, settings + ) + for table_ok in table + ) + return res + case batch_metrics.BatchMetricsWithoutReference(metrics=metrics_): + if reference: + raise ValueError( + "got BatchMetricsWithoutReference, reference must be None" + ) + else: + table = EvaluableTableWithoutRef.make(inputs, evaluable) + res = await result.do_async( + # + await _run_evaluation_helper_without_ref( + table_ok, metrics_, settings + ) + for table_ok in table + ) + return res + + +async def run_aiconfig_and_evaluation( + # + aiconfig_path: str, + prompt_name: str, + aiconfig_params: Sequence[common.TextBasedInputDatum], + reference: Sequence[str] | None, + metrics: batch_metrics.BatchMetrics[str, str, common.T_MetricValue], + settings: BatchEvalGeneralSettings | None = None, +) -> pd.DataFrame: + settings_ = settings or BatchEvalGeneralSettings() + evaluable = await _run_aiconfig_batch_helper( + aiconfig_path, prompt_name, aiconfig_params + ) + + res_table = await result.do_async( + await _evaluable_to_result_table( + aiconfig_params, evaluable_ok, reference, metrics, settings_ + ) + for evaluable_ok in evaluable + ) + return res_table.map(_process_result_table_to_df).unwrap_or_raise( + ValueError + ) + + +# Implementation + + +async def _run_aiconfig_batch_helper( + # + aiconfig_path: str, + prompt_name: str, + params_seq: Sequence[common.TextBasedInputDatum], +) -> result.Result[list[common.TextOutput], str]: + aiconfig = common.load_aiconfig_runtime(aiconfig_path) + + out = await result.do_async( + await common.batch_run_aiconfig_on_text_based_input( + # + aiconfig_ok, + prompt_name, + params_seq, + ) + for aiconfig_ok in aiconfig + ) + return out + + +async def _run_evaluation_helper_with_ref( + evaluable_with_ref: EvaluableTableWithReference[ + common.T_Evaluable, batch_common.T_Ref + ], + metrics: Sequence[ + batch_metrics.BatchMetricWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ] + ], + settings: BatchEvalGeneralSettings, +) -> result.Result[ResultTable[common.T_Evaluable, common.T_MetricValue], str]: + timeout_s = settings.eval_fn_timeout_s + + async def _calculate( + metric: batch_metrics.BatchMetricWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ], + ): + async def _thunk() -> Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ]: + return await evaluable_with_ref.calculate(metric) + + values = await async_thunk_with_timeout(_thunk(), timeout_s=timeout_s) + return values + + res = await core_utils.result_reduce_list_all_ok_async( + map( + partial(_calculate), + metrics, + ) + ) + + match res: + case result.Ok(res_): + list_results = core_utils.result_reduce_list_all_ok(res_) + match list_results: + case result.Ok(list_results_ok): + all_results = ResultTable.concatenate_tables( + list_results_ok + ) + return all_results + case result.Err(err): + return result.Err(err) + case result.Err(err): + return result.Err(err) + + +async def _run_evaluation_helper_without_ref( + evaluable_without_ref: EvaluableTableWithoutRef[common.T_Evaluable], + metrics: Sequence[ + batch_metrics.BatchMetricWithoutReference[ + common.T_Evaluable, common.T_MetricValue + ] + ], + settings: BatchEvalGeneralSettings, +) -> result.Result[ResultTable[common.T_Evaluable, common.T_MetricValue], str]: + timeout_s = settings.eval_fn_timeout_s + + async def _calculate( + metric: batch_metrics.BatchMetricWithoutReference[ + common.T_Evaluable, common.T_MetricValue + ], + ): + async def _thunk() -> Result[ + ResultTable[common.T_Evaluable, common.T_MetricValue], str + ]: + return await evaluable_without_ref.calculate(metric) + + values = await async_thunk_with_timeout(_thunk(), timeout_s=timeout_s) + return values + + res = await core_utils.result_reduce_list_all_ok_async( + map( + partial(_calculate), + metrics, + ) + ) + + match res: + case result.Ok(res_): + list_results = core_utils.result_reduce_list_all_ok(res_) + match list_results: + case result.Ok(list_results_ok): + all_results = ResultTable.concatenate_tables( + list_results_ok + ) + return all_results + case result.Err(err): + return result.Err(err) + case result.Err(err): + return result.Err(err) + + +async def async_thunk_with_timeout( + thunk: Coroutine[None, None, common.T_cov], timeout_s: int +) -> result.Result[common.T_cov, str]: + task = asyncio.create_task(thunk) + try: + res = await asyncio.wait_for(task, timeout=timeout_s) + return result.Ok(res) + except asyncio.TimeoutError: + task.cancel() + return result.Err( + f"async_thunk_with_timeout, {thunk.__name__} timed out after {timeout_s}s" + ) + + +def _process_result_table_to_df(eval_res: ResultTable[common.T_Evaluable, common.T_MetricValue]) -> pd.DataFrame: # type: ignore[pandas untyped] + raise NotImplementedError diff --git a/python/src/aiconfig/eval/batch_metrics.py b/python/src/aiconfig/eval/batch_metrics.py new file mode 100644 index 000000000..a8367005b --- /dev/null +++ b/python/src/aiconfig/eval/batch_metrics.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Generic, Sequence + +from aiconfig.eval import batch_common, common + + +@dataclass(frozen=True) +class BatchMetricWithReference( + Generic[common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue] +): + """See metrics.py for examples.""" + + evaluation_fn: batch_common.BatchEvaluationFunctionWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ] + metric_metadata: common.EvaluationMetricMetadata[ + common.T_Evaluable, common.T_MetricValue + ] + + async def __call__( + self, + data: Sequence[common.T_Evaluable], + ref: Sequence[batch_common.T_Ref], + ) -> list[common.T_MetricValue]: + """ + For convenience, make a Metric callable. + Similar to torch Module `forward()`. + """ + return await self.evaluation_fn(data, ref) + + +@dataclass(frozen=True) +class BatchMetricWithoutReference( + Generic[common.T_Evaluable, common.T_MetricValue] +): + """See metrics.py for examples.""" + + evaluation_fn: batch_common.BatchEvaluationFunctionWithoutReference[ + common.T_Evaluable, common.T_MetricValue + ] + metric_metadata: common.EvaluationMetricMetadata[ + common.T_Evaluable, common.T_MetricValue + ] + + async def __call__( + self, data: Sequence[common.T_Evaluable] + ) -> list[common.T_MetricValue]: + """ + For convenience, make a Metric callable. + Similar to torch Module `forward()`. + """ + return await self.evaluation_fn(data) + + +@dataclass +class BatchMetricsWithReference( + Generic[common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue] +): + metrics: Sequence[ + BatchMetricWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ] + ] + + +@dataclass +class BatchMetricsWithoutReference( + Generic[common.T_Evaluable, common.T_MetricValue] +): + metrics: Sequence[ + BatchMetricWithoutReference[common.T_Evaluable, common.T_MetricValue] + ] + + +BatchMetrics = ( + BatchMetricsWithReference[ + common.T_Evaluable, batch_common.T_Ref, common.T_MetricValue + ] + | BatchMetricsWithoutReference[common.T_Evaluable, common.T_MetricValue] +) diff --git a/python/src/aiconfig/eval/common.py b/python/src/aiconfig/eval/common.py index 1101cb185..75c498b4d 100644 --- a/python/src/aiconfig/eval/common.py +++ b/python/src/aiconfig/eval/common.py @@ -1,22 +1,25 @@ import json -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass -from typing import Any, Generic, NewType, Protocol, Type, TypeVar +from typing import Any, Generic, NewType, Sequence, TypeVar import lastmile_utils.lib.core.api as core_utils -import result +import pandas as pd from aiconfig.Config import AIConfigRuntime -from pydantic import BaseModel +from aiconfig.eval import common +from frozendict import frozendict from result import Result +T_cov = TypeVar("T_cov", covariant=True) +U_cov = TypeVar("U_cov", covariant=True) + T_InputDatum = TypeVar("T_InputDatum", contravariant=True) T_OutputDatum = TypeVar("T_OutputDatum", contravariant=True) -T_Evaluable = TypeVar("T_Evaluable", contravariant=True) - -T_BaseModel = TypeVar("T_BaseModel", bound=BaseModel) -SerializedJSON = NewType("SerializedJSON", str) +# NOTE: it's probably better to avoid NewType in the future, because it doesn't +# ... create a ... new type. For example, you can't pattern match against it. +TextOutput = NewType("TextOutput", str) @dataclass(frozen=True) @@ -33,33 +36,26 @@ class CustomMetricValue(ABC): """ -T_MetricValue = TypeVar("T_MetricValue", int, float, str, bool, CustomMetricValue, covariant=True) - - -class CompletionTextToSerializedJSON(Protocol): - @abstractmethod - def __call__(self, output_datum: str) -> Result[SerializedJSON, str]: - pass - - -@dataclass(frozen=True) -class CustomMetricPydanticObject(CustomMetricValue, Generic[T_BaseModel]): - data: T_BaseModel +T_Evaluable = TypeVar("T_Evaluable", contravariant=True) -class EvaluationFunction(Protocol, Generic[T_Evaluable, T_MetricValue]): - @abstractmethod - async def __call__(self, datum: T_Evaluable) -> T_MetricValue: - pass +T_MetricValue = TypeVar( + "T_MetricValue", int, float, str, bool, CustomMetricValue, covariant=True +) +T_MetricValue_inv = TypeVar( + "T_MetricValue_inv", int, float, str, bool, CustomMetricValue +) -class EvaluationMetricMetadata(core_utils.Record, Generic[T_Evaluable, T_MetricValue]): +class EvaluationMetricMetadata( + core_utils.Record, Generic[common.T_Evaluable, common.T_MetricValue] +): """A record to tie together metadata about an evaluation metric to ensure that numbers are interpreted as intended. - Assumptions: + Assumptions:t * If the best and worst values are not None, then the metric is assumed to be ordered. In this case (if the metric is ordered) then the comparison operators <, <=, >, and >= must be implemented (see CustomMetricValue). @@ -91,8 +87,8 @@ def _serialize_extra_metadata(self) -> str: name: str description: str - best_value: T_MetricValue | None = None - worst_value: T_MetricValue | None = None + best_value: common.T_MetricValue | None = None + worst_value: common.T_MetricValue | None = None # e.g. {"substring": "hello", "case_sensitive": False} extra_metadata: dict[str, Any] = {} @@ -104,74 +100,13 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class SampleMetricValue(Generic[T_Evaluable, T_MetricValue]): - # `None` is used to signal that there was an error during calculation. - # In this case, error information is written to stderr (see lib.py:_evaluate_for_sample()). - value: T_MetricValue | None - metric_metadata: EvaluationMetricMetadata[T_Evaluable, T_MetricValue] - - def __post_init__(self) -> None: - metric_metadata = self.metric_metadata - worst_value, best_value = ( - metric_metadata.worst_value, - metric_metadata.best_value, - ) - value = self.value - if worst_value is None and best_value is None: - # fine - return - elif worst_value is None or best_value is None: - raise ValueError( - f""" - [{metric_metadata.name}] - {metric_metadata.description} - - You must define both worst_value and best_value, or neither. - You defined worst_value = {worst_value} and best_value = {best_value}. - """ - ) - elif worst_value == best_value: - raise ValueError("best_value and worst_value cannot be equal") - elif value is not None and worst_value < best_value and not worst_value <= value <= best_value: # type: ignore[fixme] - raise ValueError( - f""" - [{metric_metadata.name}] - {metric_metadata.description} - - Value {value} is not in range [{worst_value}, {best_value}]. - You defined worst_value = {worst_value} and best_value = {best_value}, - but got value outside that range. - """ - ) - elif value is not None and worst_value > best_value and not worst_value >= value >= best_value: # type: ignore[fixme] - raise ValueError( - f""" - [{metric_metadata.name}] - {metric_metadata.description} - - Value {value} is not in range [{worst_value}, {best_value}]. - You defined worst_value = {worst_value} and best_value = {best_value}, - but got value outside that range. - """ - ) - +class TextBasedInputDatum: + value: str | frozendict[str, str] -class TextRatingsData(core_utils.Record): - conciseness_rating: int - conciseness_confidence: float - conciseness_reasoning: str - -def get_llm_structured_response( - input_text: str, - chat_completion_create: CompletionTextToSerializedJSON, - basemodel_type: Type[T_BaseModel], -) -> Result[T_BaseModel, str]: - return result.do( - core_utils.safe_model_validate_json(response_ok, basemodel_type) - # get the serialized JSON response - for response_ok in chat_completion_create(input_text) - ) +@core_utils.exception_to_err_with_traceback +def load_aiconfig_runtime(aiconfig_path: str) -> AIConfigRuntime: + return AIConfigRuntime.load(aiconfig_path) @core_utils.exception_to_err_with_traceback_async @@ -182,3 +117,53 @@ async def run_aiconfig_get_output_text( run_with_dependencies: bool, ): return await aiconfig.run_and_get_output_text(prompt_name, params, run_with_dependencies=run_with_dependencies) # type: ignore + + +async def run_aiconfig_on_text_based_input( + runtime: AIConfigRuntime, + prompt_name: str, + params: common.TextBasedInputDatum, +) -> Result[str, str]: + def _get_params_for_aiconfig( + params: common.TextBasedInputDatum, + ) -> dict[str, str]: + match params.value: + case str(input_text): + return {"the_query": input_text} + case frozendict(): + return dict(params.value) + + params_for_aiconfig = _get_params_for_aiconfig(params) + return await run_aiconfig_get_output_text( + runtime, prompt_name, params_for_aiconfig, run_with_dependencies=True + ) + + +async def batch_run_aiconfig_on_text_based_input( + aiconfig: AIConfigRuntime, + prompt_name: str, + params_seq: Sequence[common.TextBasedInputDatum], +) -> Result[list[TextOutput], str]: + async def _run( + input_datum: common.TextBasedInputDatum, + ) -> Result[TextOutput, str]: + return ( + await run_aiconfig_on_text_based_input( + aiconfig, prompt_name, input_datum + ) + ).map(TextOutput) + + # TODO: fix the race condition and then use gather + # https://github.com/lastmile-ai/aiconfig/issues/434 + res_outputs_: list[Result[TextOutput, str]] = [] + for input_datum in params_seq: + res_outputs_.append(await _run(input_datum)) + res_outputs = core_utils.result_reduce_list_all_ok(res_outputs_) + # res_outputs = await core_utils.result_reduce_list_all_ok_async(list(map(_run, all_inputs))) + + return res_outputs + + +@core_utils.exception_to_err_with_traceback +def make_df(data: Any) -> pd.DataFrame: + return pd.DataFrame(data) diff --git a/python/src/aiconfig/eval/openai.py b/python/src/aiconfig/eval/openai.py index 690940c48..1caa3da75 100644 --- a/python/src/aiconfig/eval/openai.py +++ b/python/src/aiconfig/eval/openai.py @@ -5,7 +5,7 @@ import lastmile_utils.lib.core.api as core_utils import openai import openai.types.chat as openai_types -from aiconfig.eval import common +from aiconfig.eval import test_suite_common from result import Err, Ok, Result @@ -20,11 +20,15 @@ class OpenAIChatCompletionParams: class OpenAIChatCompletionCreate(Protocol): @abstractmethod - def __call__(self, completion_params: OpenAIChatCompletionParams) -> Result[openai_types.ChatCompletion, str]: + def __call__( + self, completion_params: OpenAIChatCompletionParams + ) -> Result[openai_types.ChatCompletion, str]: pass -def default_openai_chat_completion_create(completion_params: OpenAIChatCompletionParams) -> Result[openai_types.ChatCompletion, str]: +def default_openai_chat_completion_create( + completion_params: OpenAIChatCompletionParams, +) -> Result[openai_types.ChatCompletion, str]: try: result = openai.chat.completions.create( messages=completion_params.messages, @@ -38,13 +42,19 @@ def default_openai_chat_completion_create(completion_params: OpenAIChatCompletio return core_utils.ErrWithTraceback(e) -def extract_json_from_chat_completion(chat_completion: openai_types.ChatCompletion) -> Result[common.SerializedJSON, str]: +def extract_json_from_chat_completion( + chat_completion: openai_types.ChatCompletion, +) -> Result[test_suite_common.SerializedJSON, str]: choice = chat_completion.choices[0] message = choice.message if message.tool_calls is None: return Err("No tool calls found") - return Ok(common.SerializedJSON(message.tool_calls[0].function.arguments)) + return Ok( + test_suite_common.SerializedJSON( + message.tool_calls[0].function.arguments + ) + ) def make_fn_completion_text_to_serialized_json( @@ -52,21 +62,35 @@ def make_fn_completion_text_to_serialized_json( properties: dict[str, dict[str, str]], required: list[str], openai_chat_completion_create: OpenAIChatCompletionCreate, -) -> common.CompletionTextToSerializedJSON: - def _chat_completion_create(output_datum: str) -> Result[common.SerializedJSON, str]: - openai_chat_completion_params = _make_openai_completion_params(output_datum, eval_llm_name, properties, required) - return openai_chat_completion_create(openai_chat_completion_params).and_then(extract_json_from_chat_completion) +) -> test_suite_common.CompletionTextToSerializedJSON: + def _chat_completion_create( + output_datum: str, + ) -> Result[test_suite_common.SerializedJSON, str]: + openai_chat_completion_params = _make_openai_completion_params( + output_datum, eval_llm_name, properties, required + ) + return openai_chat_completion_create( + openai_chat_completion_params + ).and_then(extract_json_from_chat_completion) - out: common.CompletionTextToSerializedJSON = _chat_completion_create + out: test_suite_common.CompletionTextToSerializedJSON = ( + _chat_completion_create + ) return out def _make_openai_completion_params( - input_text: str, eval_llm_name: str, properties: dict[str, dict[str, str]], required: list[str] + input_text: str, + eval_llm_name: str, + properties: dict[str, dict[str, str]], + required: list[str], ) -> OpenAIChatCompletionParams: return OpenAIChatCompletionParams( messages=[ - {"role": "system", "content": "Call the function with arguments based on the provided text."}, + { + "role": "system", + "content": "Call the function with arguments based on the provided text.", + }, {"role": "user", "content": input_text}, ], model=eval_llm_name, diff --git a/python/src/aiconfig/eval/test_suite_common.py b/python/src/aiconfig/eval/test_suite_common.py new file mode 100644 index 000000000..0bce3fd69 --- /dev/null +++ b/python/src/aiconfig/eval/test_suite_common.py @@ -0,0 +1,109 @@ +from abc import abstractmethod +from dataclasses import dataclass +from typing import Generic, NewType, Protocol, Type, TypeVar + +import lastmile_utils.lib.core.api as core_utils +import result +from aiconfig.eval import common +from pydantic import BaseModel +from result import Result + +T_BaseModel = TypeVar("T_BaseModel", bound=BaseModel) + +SerializedJSON = NewType("SerializedJSON", str) + + +class CompletionTextToSerializedJSON(Protocol): + @abstractmethod + def __call__(self, output_datum: str) -> Result[SerializedJSON, str]: + pass + + +@dataclass(frozen=True) +class CustomMetricPydanticObject( + common.CustomMetricValue, Generic[T_BaseModel] +): + data: T_BaseModel + + +class EvaluationFunction( + Protocol, Generic[common.T_Evaluable, common.T_MetricValue] +): + @abstractmethod + async def __call__( + self, datum: common.T_Evaluable + ) -> common.T_MetricValue: + pass + + +@dataclass(frozen=True) +class SampleMetricValue(Generic[common.T_Evaluable, common.T_MetricValue]): + # `None` is used to signal that there was an error during calculation. + # In this case, error information is written to stderr (see lib.py:_evaluate_for_sample()). + value: common.T_MetricValue | None + metric_metadata: common.EvaluationMetricMetadata[ + common.T_Evaluable, common.T_MetricValue + ] + + def __post_init__(self) -> None: + metric_metadata = self.metric_metadata + worst_value, best_value = ( + metric_metadata.worst_value, + metric_metadata.best_value, + ) + value = self.value + if worst_value is None and best_value is None: + # fine + return + elif worst_value is None or best_value is None: + raise ValueError( + f""" + [{metric_metadata.name}] + {metric_metadata.description} + + You must define both worst_value and best_value, or neither. + You defined worst_value = {worst_value} and best_value = {best_value}. + """ + ) + elif worst_value == best_value: + raise ValueError("best_value and worst_value cannot be equal") + elif value is not None and worst_value < best_value and not worst_value <= value <= best_value: # type: ignore[fixme] + raise ValueError( + f""" + [{metric_metadata.name}] + {metric_metadata.description} + + Value {value} is not in range [{worst_value}, {best_value}]. + You defined worst_value = {worst_value} and best_value = {best_value}, + but got value outside that range. + """ + ) + elif value is not None and worst_value > best_value and not worst_value >= value >= best_value: # type: ignore[fixme] + raise ValueError( + f""" + [{metric_metadata.name}] + {metric_metadata.description} + + Value {value} is not in range [{worst_value}, {best_value}]. + You defined worst_value = {worst_value} and best_value = {best_value}, + but got value outside that range. + """ + ) + + +class TextRatingsData(core_utils.Record): + conciseness_rating: int + conciseness_confidence: float + conciseness_reasoning: str + + +def get_llm_structured_response( + input_text: str, + chat_completion_create: CompletionTextToSerializedJSON, + basemodel_type: Type[T_BaseModel], +) -> Result[T_BaseModel, str]: + return result.do( + core_utils.safe_model_validate_json(response_ok, basemodel_type) + # get the serialized JSON response + for response_ok in chat_completion_create(input_text) + ) diff --git a/python/src/aiconfig/eval/examples/travel/travel_aiconfig_test_suite_settings.json b/python/src/aiconfig/eval/test_suite_examples/travel/travel_aiconfig_test_suite_settings.json similarity index 100% rename from python/src/aiconfig/eval/examples/travel/travel_aiconfig_test_suite_settings.json rename to python/src/aiconfig/eval/test_suite_examples/travel/travel_aiconfig_test_suite_settings.json diff --git a/python/src/aiconfig/eval/examples/travel/travel_eval.ipynb b/python/src/aiconfig/eval/test_suite_examples/travel/travel_eval.ipynb similarity index 78% rename from python/src/aiconfig/eval/examples/travel/travel_eval.ipynb rename to python/src/aiconfig/eval/test_suite_examples/travel/travel_eval.ipynb index 92d936fd4..846399f86 100644 --- a/python/src/aiconfig/eval/examples/travel/travel_eval.ipynb +++ b/python/src/aiconfig/eval/test_suite_examples/travel/travel_eval.ipynb @@ -39,7 +39,7 @@ "text": [ "\u001b[33mWARNING: Ignoring invalid distribution -etuptools (/opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0mCollecting lastmile-utils\n", - " Using cached lastmile_utils-0.0.13-py3-none-any.whl.metadata (901 bytes)\n", + " Using cached lastmile_utils-0.0.21-py3-none-any.whl.metadata (901 bytes)\n", "Collecting black==23.11.0 (from lastmile-utils)\n", " Using cached black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (66 kB)\n", "Collecting chardet==5.2.0 (from lastmile-utils)\n", @@ -62,12 +62,12 @@ " Using cached pytest-7.4.3-py3-none-any.whl.metadata (7.9 kB)\n", "Collecting python-dotenv==1.0.0 (from lastmile-utils)\n", " Using cached python_dotenv-1.0.0-py3-none-any.whl (19 kB)\n", - "Collecting result==0.15.0 (from lastmile-utils)\n", - " Using cached result-0.15.0-py3-none-any.whl.metadata (12 kB)\n", + "Collecting result==0.16.0 (from lastmile-utils)\n", + " Using cached result-0.16.0-py3-none-any.whl.metadata (857 bytes)\n", "Collecting autoflake==2.2.1 (from lastmile-utils)\n", " Using cached autoflake-2.2.1-py3-none-any.whl.metadata (7.3 kB)\n", "Collecting pyflakes>=3.0.0 (from autoflake==2.2.1->lastmile-utils)\n", - " Using cached pyflakes-3.1.0-py2.py3-none-any.whl.metadata (3.5 kB)\n", + " Downloading pyflakes-3.2.0-py2.py3-none-any.whl.metadata (3.5 kB)\n", "Collecting tomli>=2.0.1 (from autoflake==2.2.1->lastmile-utils)\n", " Using cached tomli-2.0.1-py3-none-any.whl (12 kB)\n", "Collecting click>=8.0.0 (from black==23.11.0->lastmile-utils)\n", @@ -86,16 +86,18 @@ " Using cached mccabe-0.7.0-py2.py3-none-any.whl (7.3 kB)\n", "Collecting pycodestyle<2.12.0,>=2.11.0 (from flake8==6.1.0->lastmile-utils)\n", " Using cached pycodestyle-2.11.1-py2.py3-none-any.whl.metadata (4.5 kB)\n", + "Collecting pyflakes>=3.0.0 (from autoflake==2.2.1->lastmile-utils)\n", + " Using cached pyflakes-3.1.0-py2.py3-none-any.whl.metadata (3.5 kB)\n", "Collecting json-spec (from jsoncomment==0.4.2->lastmile-utils)\n", " Using cached json_spec-0.11.0-py3-none-any.whl (41 kB)\n", "Collecting numpy<2,>=1.22.4 (from pandas==2.1.2->lastmile-utils)\n", - " Using cached numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl.metadata (61 kB)\n", + " Using cached numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl.metadata (61 kB)\n", "Collecting python-dateutil>=2.8.2 (from pandas==2.1.2->lastmile-utils)\n", " Using cached python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n", "Collecting pytz>=2020.1 (from pandas==2.1.2->lastmile-utils)\n", " Using cached pytz-2023.3.post1-py2.py3-none-any.whl.metadata (22 kB)\n", "Collecting tzdata>=2022.1 (from pandas==2.1.2->lastmile-utils)\n", - " Using cached tzdata-2023.3-py2.py3-none-any.whl (341 kB)\n", + " Using cached tzdata-2023.4-py2.py3-none-any.whl.metadata (1.4 kB)\n", "Collecting annotated-types>=0.4.0 (from pydantic==2.4.2->lastmile-utils)\n", " Using cached annotated_types-0.6.0-py3-none-any.whl.metadata (12 kB)\n", "Collecting pydantic-core==2.10.1 (from pydantic==2.4.2->lastmile-utils)\n", @@ -115,14 +117,14 @@ "Collecting exceptiongroup>=1.0.0rc8 (from pytest==7.4.3->lastmile-utils)\n", " Using cached exceptiongroup-1.2.0-py3-none-any.whl.metadata (6.6 kB)\n", "Collecting setuptools (from nodeenv>=1.6.0->pyright==1.1.335->lastmile-utils)\n", - " Using cached setuptools-69.0.2-py3-none-any.whl.metadata (6.3 kB)\n", + " Using cached setuptools-69.0.3-py3-none-any.whl.metadata (6.3 kB)\n", "Collecting six>=1.5 (from python-dateutil>=2.8.2->pandas==2.1.2->lastmile-utils)\n", " Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)\n", "Collecting importlib-metadata<6.0.0,>=5.0.0 (from json-spec->jsoncomment==0.4.2->lastmile-utils)\n", " Using cached importlib_metadata-5.2.0-py3-none-any.whl (21 kB)\n", "Collecting zipp>=0.5 (from importlib-metadata<6.0.0,>=5.0.0->json-spec->jsoncomment==0.4.2->lastmile-utils)\n", " Using cached zipp-3.17.0-py3-none-any.whl.metadata (3.7 kB)\n", - "Using cached lastmile_utils-0.0.13-py3-none-any.whl (14 kB)\n", + "Using cached lastmile_utils-0.0.21-py3-none-any.whl (15 kB)\n", "Using cached autoflake-2.2.1-py3-none-any.whl (32 kB)\n", "Using cached black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl (1.4 MB)\n", "Using cached chardet-5.2.0-py3-none-any.whl (199 kB)\n", @@ -132,7 +134,7 @@ "Using cached pylint-3.0.2-py3-none-any.whl (510 kB)\n", "Using cached pyright-1.1.335-py3-none-any.whl (17 kB)\n", "Using cached pytest-7.4.3-py3-none-any.whl (325 kB)\n", - "Using cached result-0.15.0-py3-none-any.whl (10 kB)\n", + "Using cached result-0.16.0-py3-none-any.whl (6.8 kB)\n", "Using cached pydantic_core-2.10.1-cp310-cp310-macosx_11_0_arm64.whl (1.7 MB)\n", "Using cached annotated_types-0.6.0-py3-none-any.whl (12 kB)\n", "Using cached astroid-3.0.2-py3-none-any.whl (275 kB)\n", @@ -140,7 +142,7 @@ "Using cached dill-0.3.7-py3-none-any.whl (115 kB)\n", "Using cached exceptiongroup-1.2.0-py3-none-any.whl (16 kB)\n", "Using cached nodeenv-1.8.0-py2.py3-none-any.whl (22 kB)\n", - "Using cached numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl (14.0 MB)\n", + "Using cached numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl (14.0 MB)\n", "Using cached packaging-23.2-py3-none-any.whl (53 kB)\n", "Using cached pathspec-0.12.1-py3-none-any.whl (31 kB)\n", "Using cached platformdirs-4.1.0-py3-none-any.whl (17 kB)\n", @@ -150,7 +152,8 @@ "Using cached pytz-2023.3.post1-py2.py3-none-any.whl (502 kB)\n", "Using cached tomlkit-0.12.3-py3-none-any.whl (37 kB)\n", "Using cached typing_extensions-4.9.0-py3-none-any.whl (32 kB)\n", - "Using cached setuptools-69.0.2-py3-none-any.whl (819 kB)\n", + "Using cached tzdata-2023.4-py2.py3-none-any.whl (346 kB)\n", + "Using cached setuptools-69.0.3-py3-none-any.whl (819 kB)\n", "Using cached zipp-3.17.0-py3-none-any.whl (7.4 kB)\n", "\u001b[33mWARNING: Ignoring invalid distribution -etuptools (/opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0mInstalling collected packages: pytz, zipp, tzdata, typing-extensions, tomlkit, tomli, six, setuptools, result, python-dotenv, pyflakes, pycodestyle, pluggy, platformdirs, pathspec, packaging, numpy, mypy-extensions, mccabe, isort, iniconfig, exceptiongroup, dill, click, chardet, annotated-types, python-dateutil, pytest, pydantic-core, nodeenv, importlib-metadata, flake8, black, autoflake, astroid, pyright, pylint, pydantic, pandas, json-spec, jsoncomment, lastmile-utils\n", @@ -163,9 +166,9 @@ " Uninstalling zipp-3.17.0:\n", " Successfully uninstalled zipp-3.17.0\n", " Attempting uninstall: tzdata\n", - " Found existing installation: tzdata 2023.3\n", - " Uninstalling tzdata-2023.3:\n", - " Successfully uninstalled tzdata-2023.3\n", + " Found existing installation: tzdata 2023.4\n", + " Uninstalling tzdata-2023.4:\n", + " Successfully uninstalled tzdata-2023.4\n", " Attempting uninstall: typing-extensions\n", " Found existing installation: typing_extensions 4.9.0\n", " Uninstalling typing_extensions-4.9.0:\n", @@ -183,13 +186,13 @@ " Uninstalling six-1.16.0:\n", " Successfully uninstalled six-1.16.0\n", " Attempting uninstall: setuptools\n", - " Found existing installation: setuptools 69.0.2\n", - " Uninstalling setuptools-69.0.2:\n", - " Successfully uninstalled setuptools-69.0.2\n", + " Found existing installation: setuptools 69.0.3\n", + " Uninstalling setuptools-69.0.3:\n", + " Successfully uninstalled setuptools-69.0.3\n", " Attempting uninstall: result\n", - " Found existing installation: result 0.15.0\n", - " Uninstalling result-0.15.0:\n", - " Successfully uninstalled result-0.15.0\n", + " Found existing installation: result 0.16.0\n", + " Uninstalling result-0.16.0:\n", + " Successfully uninstalled result-0.16.0\n", " Attempting uninstall: python-dotenv\n", " Found existing installation: python-dotenv 1.0.0\n", " Uninstalling python-dotenv-1.0.0:\n", @@ -219,9 +222,9 @@ " Uninstalling packaging-23.2:\n", " Successfully uninstalled packaging-23.2\n", " Attempting uninstall: numpy\n", - " Found existing installation: numpy 1.26.2\n", - " Uninstalling numpy-1.26.2:\n", - " Successfully uninstalled numpy-1.26.2\n", + " Found existing installation: numpy 1.26.3\n", + " Uninstalling numpy-1.26.3:\n", + " Successfully uninstalled numpy-1.26.3\n", " Attempting uninstall: mypy-extensions\n", " Found existing installation: mypy-extensions 1.0.0\n", " Uninstalling mypy-extensions-1.0.0:\n", @@ -319,14 +322,16 @@ " Uninstalling jsoncomment-0.4.2:\n", " Successfully uninstalled jsoncomment-0.4.2\n", " Attempting uninstall: lastmile-utils\n", - " Found existing installation: lastmile-utils 0.0.13\n", - " Uninstalling lastmile-utils-0.0.13:\n", - " Successfully uninstalled lastmile-utils-0.0.13\n", + " Found existing installation: lastmile_utils 0.0.21\n", + " Uninstalling lastmile_utils-0.0.21:\n", + " Successfully uninstalled lastmile_utils-0.0.21\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "python-aiconfig 1.1.7 requires lastmile-utils==0.0.10, but you have lastmile-utils 0.0.13 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed annotated-types-0.6.0 astroid-3.0.2 autoflake-2.2.1 black-23.11.0 chardet-5.2.0 click-8.1.7 dill-0.3.7 exceptiongroup-1.2.0 flake8-6.1.0 importlib-metadata-5.2.0 iniconfig-2.0.0 isort-5.12.0 json-spec-0.11.0 jsoncomment-0.4.2 lastmile-utils-0.0.13 mccabe-0.7.0 mypy-extensions-1.0.0 nodeenv-1.8.0 numpy-1.26.2 packaging-23.2 pandas-2.1.2 pathspec-0.12.1 platformdirs-4.1.0 pluggy-1.3.0 pycodestyle-2.11.1 pydantic-2.4.2 pydantic-core-2.10.1 pyflakes-3.1.0 pylint-3.0.2 pyright-1.1.335 pytest-7.4.3 python-dateutil-2.8.2 python-dotenv-1.0.0 pytz-2023.3.post1 result-0.15.0 setuptools-69.0.2 six-1.16.0 tomli-2.0.1 tomlkit-0.12.3 typing-extensions-4.9.0 tzdata-2023.3 zipp-3.17.0\n", + "fastapi 0.105.0 requires anyio<4.0.0,>=3.7.1, but you have anyio 4.2.0 which is incompatible.\n", + "langchain 0.0.339 requires anyio<4.0, but you have anyio 4.2.0 which is incompatible.\n", + "python-aiconfig 1.1.9 requires lastmile-utils==0.0.20, but you have lastmile-utils 0.0.21 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed annotated-types-0.6.0 astroid-3.0.2 autoflake-2.2.1 black-23.11.0 chardet-5.2.0 click-8.1.7 dill-0.3.7 exceptiongroup-1.2.0 flake8-6.1.0 importlib-metadata-5.2.0 iniconfig-2.0.0 isort-5.12.0 json-spec-0.11.0 jsoncomment-0.4.2 lastmile-utils-0.0.21 mccabe-0.7.0 mypy-extensions-1.0.0 nodeenv-1.8.0 numpy-1.26.3 packaging-23.2 pandas-2.1.2 pathspec-0.12.1 platformdirs-4.1.0 pluggy-1.3.0 pycodestyle-2.11.1 pydantic-2.4.2 pydantic-core-2.10.1 pyflakes-3.1.0 pylint-3.0.2 pyright-1.1.335 pytest-7.4.3 python-dateutil-2.8.2 python-dotenv-1.0.0 pytz-2023.3.post1 result-0.16.0 setuptools-69.0.3 six-1.16.0 tomli-2.0.1 tomlkit-0.12.3 typing-extensions-4.9.0 tzdata-2023.4 zipp-3.17.0\n", "\u001b[33mWARNING: Ignoring invalid distribution -etuptools (/opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n", - "\u001b[0mlastmile-utils 0.0.13\n" + "\u001b[0m" ] } ], @@ -401,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -413,21 +418,11 @@ " on our data along with some off-the-shelf metrics.\n", " \n" ] - }, - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ "from typing import Literal\n", - "from aiconfig.eval.api import common as common, metrics as metrics\n", + "from aiconfig.eval.api import test_suite_common as common, test_suite_metrics as metrics\n", "import lastmile_utils.lib.core.api as core_utils\n", "\n", "print(\n", @@ -438,11 +433,11 @@ ")\n", "\n", "# 1. Helper function to construct a Metric that counts a specific letter.\n", - "def make_letter_count_metric(letter_to_count: str) -> metrics.Metric[str, int]:\n", + "def make_letter_count_metric(letter_to_count: str) -> metrics.TestSuiteMetric[str, int]:\n", " async def letter_count_metric(datum: str):\n", " return datum.count(letter_to_count)\n", " \n", - " output_metric = metrics.Metric(\n", + " output_metric = metrics.TestSuiteMetric(\n", " evaluation_fn=letter_count_metric,\n", " metric_metadata=common.EvaluationMetricMetadata(\n", " name=\"letter_count\",\n", @@ -489,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -558,7 +553,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -570,34 +565,38 @@ "Test input:\n", " different kinds of cuisines \n", "Function:\n", - " Metric(evaluation_fn=, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc1e8940>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"brevity\",\n", " \"description\": \"Absolute text length\",\n", " \"best_value\": 1,\n", " \"worst_value\": 9223372036854775807,\n", - " \"extra_metadata\": {},\n", - " \"id\": \"24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\"\n", + " \"extra_metadata\": {\n", + " \"args\": []\n", + " },\n", + " \"id\": \"5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\"\n", "}))\n", "\n", "Test input:\n", " different kinds of cuisines \n", "Function:\n", - " Metric(evaluation_fn=._fn at 0x1428140d0>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc355360>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"substring_match\",\n", " \"description\": \"True (pass) if contains given substring\",\n", " \"best_value\": true,\n", " \"worst_value\": false,\n", " \"extra_metadata\": {\n", - " \"substring\": \"Magnolia Bakery\",\n", + " \"args\": [\n", + " \"Magnolia Bakery\"\n", + " ],\n", " \"case_sensitive\": false\n", " },\n", - " \"id\": \"0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6\"\n", + " \"id\": \"12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d\"\n", "}))\n", "\n", "Test input:\n", " different kinds of cuisines \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x132f3c430>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc1e8d30>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"text_ratings\",\n", " \"description\": \"Text ratings\",\n", " \"best_value\": null,\n", @@ -613,7 +612,7 @@ "Test input:\n", " different kinds of cuisines \n", "Function:\n", - " Metric(evaluation_fn=.letter_count_metric at 0x142815510>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=.letter_count_metric at 0x2bc1eb880>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"letter_count\",\n", " \"description\": \"Counts the number of times the given letter appears in the text\",\n", " \"best_value\": null,\n", @@ -627,7 +626,7 @@ "Test input:\n", " different kinds of cuisines \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x142817760>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc355120>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"emotional_valence\",\n", " \"description\": \"Emotional valence\",\n", " \"best_value\": null,\n", @@ -643,34 +642,38 @@ "Test input:\n", " iconic midtown skyscrapers \n", "Function:\n", - " Metric(evaluation_fn=, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc1e8940>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"brevity\",\n", " \"description\": \"Absolute text length\",\n", " \"best_value\": 1,\n", " \"worst_value\": 9223372036854775807,\n", - " \"extra_metadata\": {},\n", - " \"id\": \"24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\"\n", + " \"extra_metadata\": {\n", + " \"args\": []\n", + " },\n", + " \"id\": \"5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\"\n", "}))\n", "\n", "Test input:\n", " iconic midtown skyscrapers \n", "Function:\n", - " Metric(evaluation_fn=._fn at 0x142816b90>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc3553f0>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"substring_match\",\n", " \"description\": \"True (pass) if contains given substring\",\n", " \"best_value\": true,\n", " \"worst_value\": false,\n", " \"extra_metadata\": {\n", - " \"substring\": \"Empire State Building\",\n", + " \"args\": [\n", + " \"Empire State Building\"\n", + " ],\n", " \"case_sensitive\": false\n", " },\n", - " \"id\": \"53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f\"\n", + " \"id\": \"17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3\"\n", "}))\n", "\n", "Test input:\n", " iconic midtown skyscrapers \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x132f3c430>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc1e8d30>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"text_ratings\",\n", " \"description\": \"Text ratings\",\n", " \"best_value\": null,\n", @@ -686,7 +689,7 @@ "Test input:\n", " iconic midtown skyscrapers \n", "Function:\n", - " Metric(evaluation_fn=.letter_count_metric at 0x142815510>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=.letter_count_metric at 0x2bc1eb880>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"letter_count\",\n", " \"description\": \"Counts the number of times the given letter appears in the text\",\n", " \"best_value\": null,\n", @@ -700,7 +703,7 @@ "Test input:\n", " iconic midtown skyscrapers \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x142817760>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc355120>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"emotional_valence\",\n", " \"description\": \"Emotional valence\",\n", " \"best_value\": null,\n", @@ -724,7 +727,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -770,13 +773,11 @@ " \n", " 0\n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " 218\n", - " 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\n", + " 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy.\n", + " 155\n", + " 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\n", " brevity\n", " Absolute text length\n", " 1\n", @@ -785,13 +786,11 @@ " \n", " 1\n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", + " 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy.\n", " False\n", - " 0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6\n", + " 12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d\n", " substring_match\n", " True (pass) if contains given substring\n", " True\n", @@ -800,12 +799,10 @@ " \n", " 2\n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about the activities planned for each day.\"\\n})\n", + " 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy.\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about three different food-related experiences in New York City.\"\\n})\n", " 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7\n", " text_ratings\n", " Text ratings\n", @@ -815,12 +812,10 @@ " \n", " 3\n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " 4\n", + " 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy.\n", + " 0\n", " 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9\n", " letter_count\n", " Counts the number of times the given letter appears in the text\n", @@ -830,11 +825,9 @@ " \n", " 4\n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", + " 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy.\n", " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", " a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb\n", " emotional_valence\n", @@ -845,11 +838,11 @@ " \n", " 5\n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", - " 178\n", - " 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\n", + " Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum.\n", + " 137\n", + " 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\n", " brevity\n", " Absolute text length\n", " 1\n", @@ -858,11 +851,11 @@ " \n", " 6\n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", + " Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum.\n", " True\n", - " 53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f\n", + " 17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3\n", " substring_match\n", " True (pass) if contains given substring\n", " True\n", @@ -871,10 +864,10 @@ " \n", " 7\n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text is concise and provides clear instructions for visiting three different attractions in New York City.\"\\n})\n", + " Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum.\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a clear and concise itinerary for three days in New York City, mentioning the main attractions to visit each day.\"\\n})\n", " 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7\n", " text_ratings\n", " Text ratings\n", @@ -884,10 +877,10 @@ " \n", " 8\n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", - " 1\n", + " Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum.\n", + " 0\n", " 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9\n", " letter_count\n", " Counts the number of times the given letter appears in the text\n", @@ -897,10 +890,10 @@ " \n", " 9\n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", - " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", + " Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum.\n", + " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"neutral\",\\n \"confidence_probability\": 0.9\\n})\n", " a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb\n", " emotional_valence\n", " Emotional valence\n", @@ -924,73 +917,58 @@ "8 iconic midtown skyscrapers \n", "9 iconic midtown skyscrapers \n", "\n", - " aiconfig_output \\\n", - "0 Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " \n", - "1 Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " \n", - "2 Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " \n", - "3 Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", + " aiconfig_output \\\n", + "0 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy. \n", + "1 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy. \n", + "2 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy. \n", + "3 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy. \n", + "4 1. Explore Chelsea Market's international food stalls.\n", + "2. Guided Manhattan Chinatown food tour.\n", + "3. Experience Italian heritage and cuisine in Little Italy. \n", + "5 Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum. \n", + "6 Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum. \n", + "7 Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum. \n", + "8 Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum. \n", + "9 Day 1: Empire State Building, Skyride.\n", + "Day 2: Rockefeller Center, Top of the Rock.\n", + "Day 3: One World Trade Center, 9/11 Memorial & Museum. \n", "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " \n", - "4 Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \n", - "\n", - "Day 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \n", - "\n", - "Day 3: Evening of Spanish tapas with flamenco performances in NYC.\n", - " \n", - "5 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions. \n", - "6 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions. \n", - "7 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions. \n", - "8 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions. \n", - "9 1. Visit Empire State Building, explore exhibits. \n", - "2. Proceed to Top of the Rock, photograph city views. \n", - "3. Explore New York Public Library Schwarzman Building, see exhibitions. \n", - "\n", - " value \\\n", - "0 218 \n", - "1 False \n", - "2 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about the activities planned for each day.\"\\n}) \n", - "3 4 \n", - "4 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", - "5 178 \n", - "6 True \n", - "7 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text is concise and provides clear instructions for visiting three different attractions in New York City.\"\\n}) \n", - "8 1 \n", - "9 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + " value \\\n", + "0 155 \n", + "1 False \n", + "2 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about three different food-related experiences in New York City.\"\\n}) \n", + "3 0 \n", + "4 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + "5 137 \n", + "6 True \n", + "7 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a clear and concise itinerary for three days in New York City, mentioning the main attractions to visit each day.\"\\n}) \n", + "8 0 \n", + "9 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"neutral\",\\n \"confidence_probability\": 0.9\\n}) \n", "\n", " metric_id \\\n", - "0 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3 \n", - "1 0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6 \n", + "0 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085 \n", + "1 12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d \n", "2 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7 \n", "3 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9 \n", "4 a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb \n", - "5 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3 \n", - "6 53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f \n", + "5 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085 \n", + "6 17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3 \n", "7 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7 \n", "8 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9 \n", "9 a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb \n", @@ -1032,7 +1010,7 @@ "9 None None " ] }, - "execution_count": 37, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -1051,7 +1029,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -1102,54 +1080,54 @@ " \n", " \n", " different kinds of cuisines\n", - " Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n\n", - " 218\n", + " 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy.\n", + " 155\n", " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", - " 4\n", + " 0\n", " False\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about the activities planned for each day.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about three different food-related experiences in New York City.\"\\n})\n", " \n", " \n", " iconic midtown skyscrapers\n", - " 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions.\n", - " 178\n", - " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", - " 1\n", + " Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum.\n", + " 137\n", + " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"neutral\",\\n \"confidence_probability\": 0.9\\n})\n", + " 0\n", " True\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text is concise and provides clear instructions for visiting three different attractions in New York City.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a clear and concise itinerary for three days in New York City, mentioning the main attractions to visit each day.\"\\n})\n", " \n", " \n", "\n", "" ], "text/plain": [ - "metric_name brevity \\\n", - "input aiconfig_output \n", - "different kinds of cuisines Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n 218 \n", - "iconic midtown skyscrapers 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions. 178 \n", + "metric_name brevity \\\n", + "input aiconfig_output \n", + "different kinds of cuisines 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy. 155 \n", + "iconic midtown skyscrapers Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum. 137 \n", "\n", - "metric_name emotional_valence \\\n", - "input aiconfig_output \n", - "different kinds of cuisines Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", - "iconic midtown skyscrapers 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions. CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + "metric_name emotional_valence \\\n", + "input aiconfig_output \n", + "different kinds of cuisines 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy. CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + "iconic midtown skyscrapers Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum. CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"neutral\",\\n \"confidence_probability\": 0.9\\n}) \n", "\n", - "metric_name letter_count \\\n", - "input aiconfig_output \n", - "different kinds of cuisines Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n 4 \n", - "iconic midtown skyscrapers 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions. 1 \n", + "metric_name letter_count \\\n", + "input aiconfig_output \n", + "different kinds of cuisines 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy. 0 \n", + "iconic midtown skyscrapers Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum. 0 \n", "\n", - "metric_name substring_match \\\n", - "input aiconfig_output \n", - "different kinds of cuisines Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n False \n", - "iconic midtown skyscrapers 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions. True \n", + "metric_name substring_match \\\n", + "input aiconfig_output \n", + "different kinds of cuisines 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy. False \n", + "iconic midtown skyscrapers Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum. True \n", "\n", - "metric_name text_ratings \n", - "input aiconfig_output \n", - "different kinds of cuisines Day 1: Food tour in Manhattan's Chinatown, tasting regional Chinese dishes. \\n\\nDay 2: Brooklyn pizza-making class, uncovering NY-style pizza secrets. \\n\\nDay 3: Evening of Spanish tapas with flamenco performances in NYC.\\n CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about the activities planned for each day.\"\\n}) \n", - "iconic midtown skyscrapers 1. Visit Empire State Building, explore exhibits. \\n2. Proceed to Top of the Rock, photograph city views. \\n3. Explore New York Public Library Schwarzman Building, see exhibitions. CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text is concise and provides clear instructions for visiting three different attractions in New York City.\"\\n}) " + "metric_name text_ratings \n", + "input aiconfig_output \n", + "different kinds of cuisines 1. Explore Chelsea Market's international food stalls.\\n2. Guided Manhattan Chinatown food tour.\\n3. Experience Italian heritage and cuisine in Little Italy. CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides clear information about three different food-related experiences in New York City.\"\\n}) \n", + "iconic midtown skyscrapers Day 1: Empire State Building, Skyride.\\nDay 2: Rockefeller Center, Top of the Rock.\\nDay 3: One World Trade Center, 9/11 Memorial & Museum. CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a clear and concise itinerary for three days in New York City, mentioning the main attractions to visit each day.\"\\n}) " ] }, - "execution_count": 38, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -1169,7 +1147,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1188,7 +1166,7 @@ " run_test_suite_outputs_only,\n", ")\n", "\n", - "from aiconfig.eval.api import metrics\n", + "from aiconfig.eval.api import test_suite_metrics as metrics\n", "\n", "\n", "# This is similar to \"test_inputs_with_substrings\" above, but we have the AIConfig *outputs*\n", @@ -1221,7 +1199,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1233,34 +1211,38 @@ "Test output:\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience \n", "Function:\n", - " Metric(evaluation_fn=, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc1e8940>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"brevity\",\n", " \"description\": \"Absolute text length\",\n", " \"best_value\": 1,\n", " \"worst_value\": 9223372036854775807,\n", - " \"extra_metadata\": {},\n", - " \"id\": \"24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\"\n", + " \"extra_metadata\": {\n", + " \"args\": []\n", + " },\n", + " \"id\": \"5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\"\n", "}))\n", "\n", "Test output:\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience \n", "Function:\n", - " Metric(evaluation_fn=._fn at 0x1426c83a0>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc388790>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"substring_match\",\n", " \"description\": \"True (pass) if contains given substring\",\n", " \"best_value\": true,\n", " \"worst_value\": false,\n", " \"extra_metadata\": {\n", - " \"substring\": \"Magnolia Bakery\",\n", + " \"args\": [\n", + " \"Magnolia Bakery\"\n", + " ],\n", " \"case_sensitive\": false\n", " },\n", - " \"id\": \"0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6\"\n", + " \"id\": \"12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d\"\n", "}))\n", "\n", "Test output:\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x132f3c430>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc1e8d30>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"text_ratings\",\n", " \"description\": \"Text ratings\",\n", " \"best_value\": null,\n", @@ -1276,7 +1258,7 @@ "Test output:\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience \n", "Function:\n", - " Metric(evaluation_fn=.letter_count_metric at 0x142815510>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=.letter_count_metric at 0x2bc1eb880>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"letter_count\",\n", " \"description\": \"Counts the number of times the given letter appears in the text\",\n", " \"best_value\": null,\n", @@ -1290,7 +1272,7 @@ "Test output:\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x142817760>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc355120>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"emotional_valence\",\n", " \"description\": \"Emotional valence\",\n", " \"best_value\": null,\n", @@ -1306,34 +1288,38 @@ "Test output:\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "Function:\n", - " Metric(evaluation_fn=, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc1e8940>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"brevity\",\n", " \"description\": \"Absolute text length\",\n", " \"best_value\": 1,\n", " \"worst_value\": 9223372036854775807,\n", - " \"extra_metadata\": {},\n", - " \"id\": \"24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\"\n", + " \"extra_metadata\": {\n", + " \"args\": []\n", + " },\n", + " \"id\": \"5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\"\n", "}))\n", "\n", "Test output:\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "Function:\n", - " Metric(evaluation_fn=._fn at 0x1426c8940>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._construct..evaluation_fn at 0x2bc388550>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"substring_match\",\n", " \"description\": \"True (pass) if contains given substring\",\n", " \"best_value\": true,\n", " \"worst_value\": false,\n", " \"extra_metadata\": {\n", - " \"substring\": \"Empire State Building\",\n", + " \"args\": [\n", + " \"Empire State Building\"\n", + " ],\n", " \"case_sensitive\": false\n", " },\n", - " \"id\": \"53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f\"\n", + " \"id\": \"17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3\"\n", "}))\n", "\n", "Test output:\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x132f3c430>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc1e8d30>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"text_ratings\",\n", " \"description\": \"Text ratings\",\n", " \"best_value\": null,\n", @@ -1349,7 +1335,7 @@ "Test output:\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "Function:\n", - " Metric(evaluation_fn=.letter_count_metric at 0x142815510>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=.letter_count_metric at 0x2bc1eb880>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"letter_count\",\n", " \"description\": \"Counts the number of times the given letter appears in the text\",\n", " \"best_value\": null,\n", @@ -1363,7 +1349,7 @@ "Test output:\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "Function:\n", - " Metric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x142817760>, metric_metadata=EvaluationMetricMetadata({\n", + " TestSuiteMetric(evaluation_fn=._make_evaluation_fn.._evaluation_fn at 0x2bc355120>, metric_metadata=EvaluationMetricMetadata({\n", " \"name\": \"emotional_valence\",\n", " \"description\": \"Emotional valence\",\n", " \"best_value\": null,\n", @@ -1387,7 +1373,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1435,7 +1421,7 @@ " Missing\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience\n", " 160\n", - " 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\n", + " 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\n", " brevity\n", " Absolute text length\n", " 1\n", @@ -1446,7 +1432,7 @@ " Missing\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience\n", " False\n", - " 0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6\n", + " 12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d\n", " substring_match\n", " True (pass) if contains given substring\n", " True\n", @@ -1456,7 +1442,7 @@ " 2\n", " Missing\n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point at Chelsea Market, the visit to Queens for food tours, and the conclusion at Smorgasburg for an outdoor food market experience.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point, the main activity in Queens, and the final destination.\"\\n})\n", " 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7\n", " text_ratings\n", " Text ratings\n", @@ -1490,7 +1476,7 @@ " Missing\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities.\n", " 267\n", - " 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3\n", + " 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085\n", " brevity\n", " Absolute text length\n", " 1\n", @@ -1501,7 +1487,7 @@ " Missing\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities.\n", " True\n", - " 53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f\n", + " 17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3\n", " substring_match\n", " True (pass) if contains given substring\n", " True\n", @@ -1511,7 +1497,7 @@ " 7\n", " Missing\n", " 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities.\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides a clear description of each attraction.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a concise description of the attractions and activities at each location.\"\\n})\n", " 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7\n", " text_ratings\n", " Text ratings\n", @@ -1569,26 +1555,26 @@ "8 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "9 1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. \n", "\n", - " value \\\n", - "0 160 \n", - "1 False \n", - "2 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point at Chelsea Market, the visit to Queens for food tours, and the conclusion at Smorgasburg for an outdoor food market experience.\"\\n}) \n", - "3 0 \n", - "4 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", - "5 267 \n", - "6 True \n", - "7 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides a clear description of each attraction.\"\\n}) \n", - "8 0 \n", - "9 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + " value \\\n", + "0 160 \n", + "1 False \n", + "2 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point, the main activity in Queens, and the final destination.\"\\n}) \n", + "3 0 \n", + "4 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", + "5 267 \n", + "6 True \n", + "7 CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a concise description of the attractions and activities at each location.\"\\n}) \n", + "8 0 \n", + "9 CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n}) \n", "\n", " metric_id \\\n", - "0 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3 \n", - "1 0c461362f44884023dda5537ce88263ba20d555562bac8abc05bcde0ce1aacf6 \n", + "0 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085 \n", + "1 12b2b88421a53f87fa1502c48a3bfa8b84aa22af3528178f0ec8d699db041d8d \n", "2 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7 \n", "3 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9 \n", "4 a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb \n", - "5 24952ce05ce6dcbd370ccc3b39d410edeab8e1cf420130a83cf9388df6bcfdc3 \n", - "6 53e4c7163f49fdc7727286e638ff07bcb570faaa334456775c616c2f4ad3eb3f \n", + "5 5b29b6ba68aeeadc42b7333015f4b158f7514f68c05fef79a702e98cf9983085 \n", + "6 17bb1efe1fb306bce98240f3534f5d29c68564e4e7c1c0db17198247d19754e3 \n", "7 300b32bb8a01befd5e729eaf73506bdba01f910c0db0c8f70136dd2e48e298a7 \n", "8 855b84d49dadc258f82d949bf3d57a100c788e6e093e4615e8b4e03567f1ffc9 \n", "9 a351b0b7ab1639eb32695430b3e1bb65c96d11b528730c103d5879234a3bd8bb \n", @@ -1630,7 +1616,7 @@ "9 None None " ] }, - "execution_count": 42, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1646,7 +1632,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -1699,7 +1685,7 @@ " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", " 0\n", " True\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides a clear description of each attraction.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a concise description of the attractions and activities at each location.\"\\n})\n", " \n", " \n", " Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience\n", @@ -1707,7 +1693,7 @@ " CustomMetricPydanticObject(data={\\n \"emotional_valence\": \"happy\",\\n \"confidence_probability\": 0.9\\n})\n", " 0\n", " False\n", - " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point at Chelsea Market, the visit to Queens for food tours, and the conclusion at Smorgasburg for an outdoor food market experience.\"\\n})\n", + " CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point, the main activity in Queens, and the final destination.\"\\n})\n", " \n", " \n", "\n", @@ -1734,13 +1720,13 @@ "1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. True \n", "Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience False \n", "\n", - "metric_name text_ratings \n", - "aiconfig_output \n", - "1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text is concise and provides a clear description of each attraction.\"\\n}) \n", - "Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point at Chelsea Market, the visit to Queens for food tours, and the conclusion at Smorgasburg for an outdoor food market experience.\"\\n}) " + "metric_name text_ratings \n", + "aiconfig_output \n", + "1. Empire State Building: Observation deck visit, explore exhibits and historical displays. 2. Rockefeller Center: Visit \"Top of the Rock\", ice-skating, NBC Studio tour, shopping and dining. 3. Chrysler Building: Admire exterior and iconic spire, photo opportunities. CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 5,\\n \"conciseness_confidence\": 0.9,\\n \"conciseness_reasoning\": \"The text provides a concise description of the attractions and activities at each location.\"\\n}) \n", + "Begin at Chelsea Market for diverse food options. Continue to Queens for immersive food tours. Conclude at Smorgasburg for unique outdoor food market experience CustomMetricPydanticObject(data={\\n \"conciseness_rating\": 4,\\n \"conciseness_confidence\": 0.8,\\n \"conciseness_reasoning\": \"The text provides a clear and concise description of the itinerary, mentioning the starting point, the main activity in Queens, and the final destination.\"\\n}) " ] }, - "execution_count": 43, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/python/src/aiconfig/eval/examples/travel/travel_parametrized.aiconfig.json b/python/src/aiconfig/eval/test_suite_examples/travel/travel_parametrized.aiconfig.json similarity index 100% rename from python/src/aiconfig/eval/examples/travel/travel_parametrized.aiconfig.json rename to python/src/aiconfig/eval/test_suite_examples/travel/travel_parametrized.aiconfig.json diff --git a/python/src/aiconfig/eval/examples/travel/travel_promptfoo_config.yaml b/python/src/aiconfig/eval/test_suite_examples/travel/travel_promptfoo_config.yaml similarity index 100% rename from python/src/aiconfig/eval/examples/travel/travel_promptfoo_config.yaml rename to python/src/aiconfig/eval/test_suite_examples/travel/travel_promptfoo_config.yaml diff --git a/python/src/aiconfig/eval/lib.py b/python/src/aiconfig/eval/test_suite_lib.py similarity index 59% rename from python/src/aiconfig/eval/lib.py rename to python/src/aiconfig/eval/test_suite_lib.py index 1260a86fa..dbf48379e 100644 --- a/python/src/aiconfig/eval/lib.py +++ b/python/src/aiconfig/eval/test_suite_lib.py @@ -2,15 +2,16 @@ import json import logging from dataclasses import dataclass -from frozendict import frozendict from functools import partial -from typing import Any, Generic, NewType, Sequence, Tuple, TypeVar +from typing import Any, Generic, Sequence, Tuple, TypeVar -import aiconfig.eval.common as common +import aiconfig.eval.test_suite_common as test_suite_common import lastmile_utils.lib.core.api as core_utils import pandas as pd from aiconfig.Config import AIConfigRuntime -from aiconfig.eval.metrics import Metric +from aiconfig.eval import common +from aiconfig.eval.test_suite_metrics import TestSuiteMetric +from frozendict import frozendict from result import Err, Ok, Result logging.basicConfig(format=core_utils.LOGGER_FMT) @@ -19,20 +20,12 @@ # TODO: figure out a way to do heterogenous list without Any # Each test is a (input_datum, Metric) pair -UserTestSuiteWithInputs = Sequence[Tuple[str | dict[str, str], Metric[str, Any]]] +UserTestSuiteWithInputs = Sequence[ + Tuple[str | dict[str, str], TestSuiteMetric[str, Any]] +] # Each test is a (output_datum, Metric) pair -UserTestSuiteOutputsOnly = Sequence[Tuple[str, Metric[str, Any]]] - - -# NOTE: it's probably better to avoid NewType in the future, because it doesn't -# ... create a ... new type. For example, you can't pattern match against it. -TextOutput = NewType("TextOutput", str) - - -@dataclass(frozen=True) -class TextBasedInputDatum: - value: str | frozendict[str, str] +UserTestSuiteOutputsOnly = Sequence[Tuple[str, TestSuiteMetric[str, Any]]] @dataclass(frozen=True) @@ -70,7 +63,11 @@ async def run_test_suite_outputs_only( test_suite: UserTestSuiteOutputsOnly, settings: TestSuiteOutputsOnlySettings = TestSuiteOutputsOnlySettings(), ) -> pd.DataFrame: - res = await run_test_suite_helper(TestSuiteOutputsOnlySpec(test_suite=test_suite, general_settings=settings.general_settings)) + res = await run_test_suite_helper( + TestSuiteOutputsOnlySpec( + test_suite=test_suite, general_settings=settings.general_settings + ) + ) return res.map(text_eval_res_to_df).unwrap_or_raise(ValueError) @@ -85,34 +82,50 @@ class NumericalEvalDataset(core_utils.Record): # TODO: # GenericBeforeBaseModelWarning: Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) for pydantic generics to work properly. # But swapping the order breaks -class SampleEvaluationResult(Generic[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue], core_utils.Record): +class SampleEvaluationResult( + Generic[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue], + core_utils.Record, +): input_datum: common.T_InputDatum | None output_datum: common.T_OutputDatum - metric_value: common.SampleMetricValue[common.T_OutputDatum, common.T_MetricValue] + metric_value: test_suite_common.SampleMetricValue[ + common.T_OutputDatum, common.T_MetricValue + ] @dataclass(frozen=True) -class SampleEvaluationParams(Generic[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue]): +class SampleEvaluationParams( + Generic[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue] +): # input_sample doesn't _need_ to be here, because we already have # output_sample ready to input to eval. # input_sample is here for documentation/debugging. input_sample: common.T_InputDatum | None output_sample: common.T_OutputDatum - metric: Metric[common.T_OutputDatum, common.T_MetricValue] + metric: TestSuiteMetric[common.T_OutputDatum, common.T_MetricValue] def __str__(self) -> str: return f"\nSampleEvaluationParams:\n\t{self.output_sample=}\n\t{self.metric=}" # TODO: don't use Any. -DatasetEvaluationResult = Sequence[SampleEvaluationResult[common.T_InputDatum, common.T_OutputDatum, Any]] -DatasetEvaluationParams = Sequence[SampleEvaluationParams[common.T_InputDatum, common.T_OutputDatum, Any]] -MetricList = list[Metric[common.T_OutputDatum, Any]] +DatasetEvaluationResult = Sequence[ + SampleEvaluationResult[common.T_InputDatum, common.T_OutputDatum, Any] +] +DatasetEvaluationParams = Sequence[ + SampleEvaluationParams[common.T_InputDatum, common.T_OutputDatum, Any] +] +MetricList = list[TestSuiteMetric[common.T_OutputDatum, Any]] async def _evaluate_for_sample( - eval_params: SampleEvaluationParams[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue], timeout_s: int -) -> SampleEvaluationResult[common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue]: + eval_params: SampleEvaluationParams[ + common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue + ], + timeout_s: int, +) -> SampleEvaluationResult[ + common.T_InputDatum, common.T_OutputDatum, common.T_MetricValue +]: sample, metric = ( eval_params.output_sample, eval_params.metric, @@ -121,7 +134,9 @@ async def _evaluate_for_sample( async def _calculate() -> common.T_MetricValue: return await metric.evaluation_fn(sample) - def _ok_with_log(res_: Result[common.T_MetricValue, str]) -> common.T_MetricValue | None: + def _ok_with_log( + res_: Result[common.T_MetricValue, str] + ) -> common.T_MetricValue | None: match res_: case Ok(res): return res @@ -133,7 +148,7 @@ def _ok_with_log(res_: Result[common.T_MetricValue, str]) -> common.T_MetricValu result = SampleEvaluationResult( input_datum=eval_params.input_sample, output_datum=sample, - metric_value=common.SampleMetricValue( + metric_value=test_suite_common.SampleMetricValue( # value=_ok_with_log(res_), metric_metadata=metric.metric_metadata, @@ -143,18 +158,36 @@ def _ok_with_log(res_: Result[common.T_MetricValue, str]) -> common.T_MetricValu async def evaluate( - evaluation_params_list: DatasetEvaluationParams[common.T_InputDatum, common.T_OutputDatum], eval_fn_timeout_s: int -) -> Result[DatasetEvaluationResult[common.T_InputDatum, common.T_OutputDatum], str]: - return Ok(await asyncio.gather(*map(partial(_evaluate_for_sample, timeout_s=eval_fn_timeout_s), evaluation_params_list))) + evaluation_params_list: DatasetEvaluationParams[ + common.T_InputDatum, common.T_OutputDatum + ], + eval_fn_timeout_s: int, +) -> Result[ + DatasetEvaluationResult[common.T_InputDatum, common.T_OutputDatum], str +]: + return Ok( + await asyncio.gather( + *map( + partial(_evaluate_for_sample, timeout_s=eval_fn_timeout_s), + evaluation_params_list, + ) + ) + ) def text_eval_res_to_df( - eval_res: DatasetEvaluationResult[TextBasedInputDatum, TextOutput], + eval_res: DatasetEvaluationResult[ + common.TextBasedInputDatum, common.TextOutput + ], ) -> pd.DataFrame: def _extract_text_based_input_for_display( - eval_res: DatasetEvaluationResult[TextBasedInputDatum, TextOutput], - ) -> DatasetEvaluationResult[str, TextOutput]: - def _extract_value(input_text_datum: TextBasedInputDatum | None) -> str | None: + eval_res: DatasetEvaluationResult[ + common.TextBasedInputDatum, common.TextOutput + ], + ) -> DatasetEvaluationResult[str, common.TextOutput]: + def _extract_value( + input_text_datum: common.TextBasedInputDatum | None, + ) -> str | None: if input_text_datum is None: return None else: @@ -162,7 +195,9 @@ def _extract_value(input_text_datum: TextBasedInputDatum | None) -> str | None: case str(input_text): return input_text case frozendict(): - return json.dumps(input_text_datum.value, sort_keys=True) + return json.dumps( + input_text_datum.value, sort_keys=True + ) return [ SampleEvaluationResult( @@ -199,28 +234,41 @@ def _extract_value(input_text_datum: TextBasedInputDatum | None) -> str | None: async def user_test_suite_with_inputs_to_eval_params_list( - test_suite: UserTestSuiteWithInputs, prompt_name: str, aiconfig: AIConfigRuntime -) -> Result[DatasetEvaluationParams[TextBasedInputDatum, TextOutput], str]: + test_suite: UserTestSuiteWithInputs, + prompt_name: str, + aiconfig: AIConfigRuntime, +) -> Result[ + DatasetEvaluationParams[common.TextBasedInputDatum, common.TextOutput], str +]: """ Example in/out: [("hello", brevity)] -> [SampleEvaluationParams("hello", "output_is_world", brevity)] """ - def _user_test_input_to_internal_type(input_datum_user_given: str | dict[str, str]) -> TextBasedInputDatum: + def _user_test_input_to_internal_type( + input_datum_user_given: str | dict[str, str] + ) -> common.TextBasedInputDatum: match input_datum_user_given: case str(input_datum): - return TextBasedInputDatum(input_datum) + return common.TextBasedInputDatum(input_datum) case dict(input_datum): - return TextBasedInputDatum(frozendict(input_datum)) + return common.TextBasedInputDatum(frozendict(input_datum)) - test_suite_internal_types = [(_user_test_input_to_internal_type(input_datum), metric) for input_datum, metric in test_suite] + test_suite_internal_types = [ + (_user_test_input_to_internal_type(input_datum), metric) + for input_datum, metric in test_suite + ] - out: DatasetEvaluationParams[TextBasedInputDatum, TextOutput] = [] + out: DatasetEvaluationParams[ + common.TextBasedInputDatum, common.TextOutput + ] = [] # Group by input so that we only run each input through the AIConfig once. # This is sort of an optimization because the user can give the same input # multiple times (with different metrics). - input_to_metrics_mapping: dict[TextBasedInputDatum, MetricList[TextOutput]] = {} + input_to_metrics_mapping: dict[ + common.TextBasedInputDatum, MetricList[common.TextOutput] + ] = {} for input_datum, metric in test_suite_internal_types: if input_datum not in input_to_metrics_mapping: input_to_metrics_mapping[input_datum] = [] @@ -228,18 +276,11 @@ def _user_test_input_to_internal_type(input_datum_user_given: str | dict[str, st all_inputs = list(input_to_metrics_mapping.keys()) - async def _run(input_datum: TextBasedInputDatum) -> Result[TextOutput, str]: - return (await run_aiconfig_on_text_based_input(aiconfig, prompt_name, input_datum)).map(TextOutput) - - # TODO: fix the race condition and then use gather - # https://github.com/lastmile-ai/aiconfig/issues/434 - res_outputs_: list[Result[TextOutput, str]] = [] - for input_datum in all_inputs: - res_outputs_.append(await _run(input_datum)) - res_outputs = core_utils.result_reduce_list_all_ok(res_outputs_) - # res_outputs = await core_utils.result_reduce_list_all_ok_async(list(map(_run, all_inputs))) + res_outputs = await common.batch_run_aiconfig_on_text_based_input( + aiconfig, prompt_name, all_inputs + ) - def _zip_inputs_outputs(outputs: list[TextOutput]): + def _zip_inputs_outputs(outputs: list[common.TextOutput]): # This zip is safe because we have defined an order for the keys in `all_inputs` # them apped run_aiconfig over that list. # Docs: https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently @@ -264,23 +305,18 @@ def _zip_inputs_outputs(outputs: list[TextOutput]): def user_test_suite_outputs_only_to_eval_params_list( test_suite: UserTestSuiteOutputsOnly, -) -> DatasetEvaluationParams[TextBasedInputDatum, TextOutput]: +) -> DatasetEvaluationParams[common.TextBasedInputDatum, common.TextOutput]: """ Example: [("the_output_is_world", brevity)] -> [SampleEvaluationParams(None, "the_output_is_world", brevity) """ - return [SampleEvaluationParams(input_sample=None, output_sample=TextOutput(output_datum), metric=metric) for output_datum, metric in test_suite] - - -async def run_aiconfig_on_text_based_input(runtime: AIConfigRuntime, prompt_name: str, params: TextBasedInputDatum) -> Result[str, str]: - def _get_params_for_aiconfig(params: TextBasedInputDatum) -> dict[str, str]: - match params.value: - case str(input_text): - return {"the_query": input_text} - case frozendict(): - return dict(params.value) - - params_for_aiconfig = _get_params_for_aiconfig(params) - return await common.run_aiconfig_get_output_text(runtime, prompt_name, params_for_aiconfig, run_with_dependencies=True) + return [ + SampleEvaluationParams( + input_sample=None, + output_sample=common.TextOutput(output_datum), + metric=metric, + ) + for output_datum, metric in test_suite + ] @dataclass(frozen=True) @@ -302,22 +338,47 @@ class TestSuiteOutputsOnlySpec: async def run_test_suite_helper( test_suite_spec: TestSuiteSpec, -) -> Result[DatasetEvaluationResult[TextBasedInputDatum, TextOutput], str]: +) -> Result[ + DatasetEvaluationResult[common.TextBasedInputDatum, common.TextOutput], str +]: async def _get_eval_params_list( test_suite_spec: TestSuiteSpec, - ) -> Result[DatasetEvaluationParams[TextBasedInputDatum, TextOutput], str]: + ) -> Result[ + DatasetEvaluationParams[common.TextBasedInputDatum, common.TextOutput], + str, + ]: match test_suite_spec: - case TestSuiteWithInputsSpec(test_suite=test_suite, prompt_name=prompt_name, aiconfig=aiconfig): - return await user_test_suite_with_inputs_to_eval_params_list(test_suite, prompt_name, aiconfig) + case TestSuiteWithInputsSpec( + test_suite=test_suite, + prompt_name=prompt_name, + aiconfig=aiconfig, + ): + return await user_test_suite_with_inputs_to_eval_params_list( + test_suite, prompt_name, aiconfig + ) case TestSuiteOutputsOnlySpec(test_suite=test_suite): - return Ok(user_test_suite_outputs_only_to_eval_params_list(test_suite)) + return Ok( + user_test_suite_outputs_only_to_eval_params_list( + test_suite + ) + ) eval_params_list = await _get_eval_params_list(test_suite_spec) async def _evaluate_with_timeout( - eval_params_list: DatasetEvaluationParams[TextBasedInputDatum, TextOutput], - ) -> Result[DatasetEvaluationResult[TextBasedInputDatum, TextOutput], str]: - return await evaluate(eval_params_list, eval_fn_timeout_s=test_suite_spec.general_settings.eval_fn_timeout_s) + eval_params_list: DatasetEvaluationParams[ + common.TextBasedInputDatum, common.TextOutput + ], + ) -> Result[ + DatasetEvaluationResult[common.TextBasedInputDatum, common.TextOutput], + str, + ]: + return await evaluate( + eval_params_list, + eval_fn_timeout_s=test_suite_spec.general_settings.eval_fn_timeout_s, + ) - res_evaluated = await eval_params_list.and_then_async(_evaluate_with_timeout) + res_evaluated = await eval_params_list.and_then_async( + _evaluate_with_timeout + ) return res_evaluated diff --git a/python/src/aiconfig/eval/metrics.py b/python/src/aiconfig/eval/test_suite_metrics.py similarity index 64% rename from python/src/aiconfig/eval/metrics.py rename to python/src/aiconfig/eval/test_suite_metrics.py index 02617ffee..41407a02f 100644 --- a/python/src/aiconfig/eval/metrics.py +++ b/python/src/aiconfig/eval/test_suite_metrics.py @@ -3,25 +3,46 @@ from abc import abstractmethod from dataclasses import dataclass from functools import partial, total_ordering -from typing import Any, Awaitable, Callable, Concatenate, Generic, ParamSpec, Protocol, Type +from typing import ( + Any, + Awaitable, + Callable, + Concatenate, + Generic, + ParamSpec, + Protocol, + Type, +) import lastmile_utils.lib.core.api as core_utils import nltk import pandas as pd -from aiconfig.eval import common -from aiconfig.eval.openai import OpenAIChatCompletionCreate, default_openai_chat_completion_create, make_fn_completion_text_to_serialized_json -from nltk.sentiment.vader import SentimentIntensityAnalyzer as NLTKSentimentIntensityAnalyzer +from aiconfig.eval import common, test_suite_common +from aiconfig.eval.openai import ( + OpenAIChatCompletionCreate, + default_openai_chat_completion_create, + make_fn_completion_text_to_serialized_json, +) +from nltk.sentiment.vader import ( + SentimentIntensityAnalyzer as NLTKSentimentIntensityAnalyzer, +) from result import Err, Ok, Result @dataclass(frozen=True) -class Metric(Generic[common.T_Evaluable, common.T_MetricValue]): +class TestSuiteMetric(Generic[common.T_Evaluable, common.T_MetricValue]): """See metrics.py for examples.""" - evaluation_fn: common.EvaluationFunction[common.T_Evaluable, common.T_MetricValue] - metric_metadata: common.EvaluationMetricMetadata[common.T_Evaluable, common.T_MetricValue] + evaluation_fn: test_suite_common.EvaluationFunction[ + common.T_Evaluable, common.T_MetricValue + ] + metric_metadata: common.EvaluationMetricMetadata[ + common.T_Evaluable, common.T_MetricValue + ] - async def __call__(self, datum: common.T_Evaluable) -> common.T_MetricValue: + async def __call__( + self, datum: common.T_Evaluable + ) -> common.T_MetricValue: """ For convenience, make a Metric callable. Similar to torch Module `forward()`. @@ -34,20 +55,26 @@ async def __call__(self, datum: common.T_Evaluable) -> common.T_MetricValue: @core_utils.parametrized def metric( - parametrized_evaluation_fn: Callable[Concatenate[common.T_Evaluable, PS], common.T_MetricValue], + parametrized_evaluation_fn: Callable[ + Concatenate[common.T_Evaluable, PS], common.T_MetricValue + ], name: str | None = None, description: str | None = None, best_value: common.T_MetricValue | None = None, worst_value: common.T_MetricValue | None = None, -) -> Callable[PS, Metric[common.T_Evaluable, common.T_MetricValue]]: +) -> Callable[PS, TestSuiteMetric[common.T_Evaluable, common.T_MetricValue]]: name_ = name or parametrized_evaluation_fn.__name__ description_ = description or name_ - def _construct(*args: PS.args, **kwargs: PS.kwargs) -> Metric[common.T_Evaluable, common.T_MetricValue]: - async def evaluation_fn(datum: common.T_Evaluable) -> common.T_MetricValue: + def _construct( + *args: PS.args, **kwargs: PS.kwargs + ) -> TestSuiteMetric[common.T_Evaluable, common.T_MetricValue]: + async def evaluation_fn( + datum: common.T_Evaluable, + ) -> common.T_MetricValue: return parametrized_evaluation_fn(datum, *args, **kwargs) - return Metric( + return TestSuiteMetric( evaluation_fn=evaluation_fn, metric_metadata=common.EvaluationMetricMetadata( name=name_, @@ -63,20 +90,26 @@ async def evaluation_fn(datum: common.T_Evaluable) -> common.T_MetricValue: @core_utils.parametrized def metric_async( - parametrized_evaluation_fn: Callable[Concatenate[common.T_Evaluable, PS], Awaitable[common.T_MetricValue]], + parametrized_evaluation_fn: Callable[ + Concatenate[common.T_Evaluable, PS], Awaitable[common.T_MetricValue] + ], name: str | None = None, description: str | None = None, best_value: common.T_MetricValue | None = None, worst_value: common.T_MetricValue | None = None, -) -> Callable[PS, Metric[common.T_Evaluable, common.T_MetricValue]]: +) -> Callable[PS, TestSuiteMetric[common.T_Evaluable, common.T_MetricValue]]: name_ = name or parametrized_evaluation_fn.__name__ description_ = description or name_ - def _construct(*args: PS.args, **kwargs: PS.kwargs) -> Metric[common.T_Evaluable, common.T_MetricValue]: - async def evaluation_fn(datum: common.T_Evaluable) -> common.T_MetricValue: + def _construct( + *args: PS.args, **kwargs: PS.kwargs + ) -> TestSuiteMetric[common.T_Evaluable, common.T_MetricValue]: + async def evaluation_fn( + datum: common.T_Evaluable, + ) -> common.T_MetricValue: return await parametrized_evaluation_fn(datum, *args, **kwargs) - return Metric( + return TestSuiteMetric( evaluation_fn=evaluation_fn, metric_metadata=common.EvaluationMetricMetadata( name=name_, @@ -117,7 +150,9 @@ def __eq__(self, other: Any) -> bool: def __lt__(self, other: Any) -> bool: if not isinstance(other, TextOverallPositiveSentiment): - raise TypeError(f"Cannot compare TextPositiveSentimentScores with {type(other)}") + raise TypeError( + f"Cannot compare TextPositiveSentimentScores with {type(other)}" + ) return self.pos - self.neg < other.pos - other.neg @@ -132,20 +167,26 @@ def _get_nltk_polarity_scores(text: str, model: str) -> dict[str, float]: return NLTKSentimentIntensityAnalyzer().polarity_scores(text) # type: ignore -def _get_sentiment_scores(output_datum: str, get_polarity_scores: GetPolarityScores) -> TextSentimentScores: +def _get_sentiment_scores( + output_datum: str, get_polarity_scores: GetPolarityScores +) -> TextSentimentScores: mapping: dict[str, float] = get_polarity_scores(output_datum) highest: str = pd.Series(mapping).idxmax() # type: ignore return TextSentimentScores(mapping=mapping, **mapping, highest=highest) -def make_get_sentiment_scores(get_polarity_scores: GetPolarityScores) -> common.EvaluationFunction[str, TextSentimentScores]: +def make_get_sentiment_scores( + get_polarity_scores: GetPolarityScores, +) -> test_suite_common.EvaluationFunction[str, TextSentimentScores]: async def _f(datum: str) -> TextSentimentScores: return _get_sentiment_scores(datum, get_polarity_scores) return _f -def make_get_sentiment_class(get_polarity_scores: GetPolarityScores) -> common.EvaluationFunction[str, str]: +def make_get_sentiment_class( + get_polarity_scores: GetPolarityScores, +) -> test_suite_common.EvaluationFunction[str, str]: async def _f(datum: str) -> str: scores = _get_sentiment_scores(datum, get_polarity_scores) return scores.highest @@ -153,7 +194,9 @@ async def _f(datum: str) -> str: return _f -def make_get_overall_positive_sentiment(get_polarity_scores: GetPolarityScores) -> common.EvaluationFunction[str, TextOverallPositiveSentiment]: +def make_get_overall_positive_sentiment( + get_polarity_scores: GetPolarityScores, +) -> test_suite_common.EvaluationFunction[str, TextOverallPositiveSentiment]: async def _f(datum: str) -> TextOverallPositiveSentiment: scores = _get_sentiment_scores(datum, get_polarity_scores) return TextOverallPositiveSentiment(pos=scores.pos, neg=scores.neg) @@ -163,14 +206,19 @@ async def _f(datum: str) -> TextOverallPositiveSentiment: def make_sentiment_scores_metric( get_polarity_scores: GetPolarityScores, - make_evaluation_fn: Callable[[GetPolarityScores], common.EvaluationFunction[str, common.T_MetricValue]], + make_evaluation_fn: Callable[ + [GetPolarityScores], + test_suite_common.EvaluationFunction[str, common.T_MetricValue], + ], name: str, description: str, best_value: common.T_MetricValue | None = None, worst_value: common.T_MetricValue | None = None, -) -> Metric[str, common.T_MetricValue]: - evaluation_fn: common.EvaluationFunction[str, common.T_MetricValue] = make_evaluation_fn(get_polarity_scores) - out: Metric[str, common.T_MetricValue] = Metric( +) -> TestSuiteMetric[str, common.T_MetricValue]: + evaluation_fn: test_suite_common.EvaluationFunction[ + str, common.T_MetricValue + ] = make_evaluation_fn(get_polarity_scores) + out: TestSuiteMetric[str, common.T_MetricValue] = TestSuiteMetric( evaluation_fn=evaluation_fn, metric_metadata=common.EvaluationMetricMetadata( # @@ -185,18 +233,32 @@ def make_sentiment_scores_metric( def make_structured_llm_metric( - chat_completion_create: common.CompletionTextToSerializedJSON, + chat_completion_create: test_suite_common.CompletionTextToSerializedJSON, eval_llm_name: str, - pydantic_basemodel_type: Type[common.T_BaseModel], + pydantic_basemodel_type: Type[test_suite_common.T_BaseModel], metric_name: str, metric_description: str, field_descriptions: dict[str, str] = {}, -) -> Metric[str, common.CustomMetricPydanticObject[common.T_BaseModel]]: +) -> TestSuiteMetric[ + str, + test_suite_common.CustomMetricPydanticObject[ + test_suite_common.T_BaseModel + ], +]: def _make_evaluation_fn( - basemodel_type: Type[common.T_BaseModel], - ) -> common.EvaluationFunction[str, common.CustomMetricPydanticObject[common.T_BaseModel]]: - async def _evaluation_fn(datum: str) -> common.CustomMetricPydanticObject[common.T_BaseModel]: - resp = common.get_llm_structured_response( + basemodel_type: Type[test_suite_common.T_BaseModel], + ) -> test_suite_common.EvaluationFunction[ + str, + test_suite_common.CustomMetricPydanticObject[ + test_suite_common.T_BaseModel + ], + ]: + async def _evaluation_fn( + datum: str, + ) -> test_suite_common.CustomMetricPydanticObject[ + test_suite_common.T_BaseModel + ]: + resp = test_suite_common.get_llm_structured_response( input_text=datum, chat_completion_create=chat_completion_create, basemodel_type=basemodel_type, @@ -207,11 +269,13 @@ async def _evaluation_fn(datum: str) -> common.CustomMetricPydanticObject[common case Err(e): raise ValueError(f"Error getting structured response: {e}") case Ok(data): - return common.CustomMetricPydanticObject(data=data) + return test_suite_common.CustomMetricPydanticObject( + data=data + ) return _evaluation_fn - return Metric( + return TestSuiteMetric( evaluation_fn=_make_evaluation_fn(pydantic_basemodel_type), metric_metadata=common.EvaluationMetricMetadata( name=metric_name, @@ -219,7 +283,9 @@ async def _evaluation_fn(datum: str) -> common.CustomMetricPydanticObject[common extra_metadata=dict( basemodel_type_name=pydantic_basemodel_type.__name__, eval_llm_name=eval_llm_name, - field_descriptions_json=json.dumps(field_descriptions, sort_keys=True), + field_descriptions_json=json.dumps( + field_descriptions, sort_keys=True + ), ), ), ) @@ -227,12 +293,20 @@ async def _evaluation_fn(datum: str) -> common.CustomMetricPydanticObject[common def _make_openai_structured_llm_metric_helper( eval_llm_name: str, - pydantic_basemodel_type: Type[common.T_BaseModel], + pydantic_basemodel_type: Type[test_suite_common.T_BaseModel], metric_name: str, metric_description: str, field_descriptions: dict[str, str], openai_chat_completion_create: OpenAIChatCompletionCreate | None = None, -) -> Result[Metric[str, common.CustomMetricPydanticObject[common.T_BaseModel]], str]: +) -> Result[ + TestSuiteMetric[ + str, + test_suite_common.CustomMetricPydanticObject[ + test_suite_common.T_BaseModel + ], + ], + str, +]: schema = pydantic_basemodel_type.model_json_schema() properties = schema["properties"] required = schema["required"] @@ -247,18 +321,23 @@ def _make_openai_structured_llm_metric_helper( def _with_description(key: str, value: dict[str, str]) -> dict[str, str]: if key in field_descriptions: - return core_utils.dict_union_allow_replace(value, {"description": field_descriptions[key]}) + return core_utils.dict_union_allow_replace( + value, {"description": field_descriptions[key]} + ) return value properties = {k: _with_description(k, v) for k, v in properties.items()} required = required or list(properties.keys()) - openai_eval_llm_chat_completion_create: common.CompletionTextToSerializedJSON = make_fn_completion_text_to_serialized_json( + openai_eval_llm_chat_completion_create: test_suite_common.CompletionTextToSerializedJSON = make_fn_completion_text_to_serialized_json( eval_llm_name=eval_llm_name, properties=properties, required=required, - openai_chat_completion_create=(openai_chat_completion_create or default_openai_chat_completion_create), + openai_chat_completion_create=( + openai_chat_completion_create + or default_openai_chat_completion_create + ), ) return Ok( @@ -281,12 +360,17 @@ def _with_description(key: str, value: dict[str, str]) -> dict[str, str]: def make_openai_structured_llm_metric( eval_llm_name: str, - pydantic_basemodel_type: Type[common.T_BaseModel], + pydantic_basemodel_type: Type[test_suite_common.T_BaseModel], metric_name: str, metric_description: str, field_descriptions: dict[str, str] = {}, openai_chat_completion_create: OpenAIChatCompletionCreate | None = None, -) -> Metric[str, common.CustomMetricPydanticObject[common.T_BaseModel]]: +) -> TestSuiteMetric[ + str, + test_suite_common.CustomMetricPydanticObject[ + test_suite_common.T_BaseModel + ], +]: res_metric = _make_openai_structured_llm_metric_helper( eval_llm_name=eval_llm_name, pydantic_basemodel_type=pydantic_basemodel_type, @@ -313,7 +397,9 @@ def make_openai_structured_llm_metric( best_value=True, worst_value=False, ) -def substring_match(datum: str, substring: str, case_sensitive: bool = True) -> bool: +def substring_match( + datum: str, substring: str, case_sensitive: bool = True +) -> bool: if case_sensitive: return substring in datum else: @@ -339,7 +425,7 @@ def make_brevity(datum: str): gpt3_5_text_ratings = make_openai_structured_llm_metric( eval_llm_name="gpt-3.5-turbo-0613", - pydantic_basemodel_type=common.TextRatingsData, + pydantic_basemodel_type=test_suite_common.TextRatingsData, metric_name="text_ratings", metric_description="Text ratings", field_descriptions=dict( @@ -350,21 +436,27 @@ def make_brevity(datum: str): ) nltk_sentiment_scores_vader = make_sentiment_scores_metric( - get_polarity_scores=partial(_get_nltk_polarity_scores, model="vader_lexicon"), + get_polarity_scores=partial( + _get_nltk_polarity_scores, model="vader_lexicon" + ), make_evaluation_fn=make_get_sentiment_scores, name="nltk_sentiment_scores_vader", description="NLTK sentiment scores using Vader", ) nltk_sentiment_class_vader = make_sentiment_scores_metric( - get_polarity_scores=partial(_get_nltk_polarity_scores, model="vader_lexicon"), + get_polarity_scores=partial( + _get_nltk_polarity_scores, model="vader_lexicon" + ), make_evaluation_fn=make_get_sentiment_class, name="nltk_sentiment_class_vader", description="Highest-probability NLTK sentiment class using Vader", ) nltk_sentiment_score_overall_positive = make_sentiment_scores_metric( - get_polarity_scores=partial(_get_nltk_polarity_scores, model="vader_lexicon"), + get_polarity_scores=partial( + _get_nltk_polarity_scores, model="vader_lexicon" + ), make_evaluation_fn=make_get_overall_positive_sentiment, name="nltk_sentiment_score_overall_positive", description="Positive minus negative", diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py index 006afbeae..b582185e2 100644 --- a/python/src/aiconfig/model_parser.py +++ b/python/src/aiconfig/model_parser.py @@ -66,7 +66,7 @@ async def run( aiconfig: AIConfig, options: Optional["InferenceOptions"] = None, parameters: Dict = {}, - **kwargs, # TODO: Remove this, just a hack for now to ensure that it doesn't break + **kwargs, # TODO: Remove this, just a hack for now to ensure that it doesn't break ) -> ExecuteResult: """ Execute model inference based on completion data to be constructed in deserialize(), which includes the input prompt and @@ -115,7 +115,9 @@ async def run_batch( prompt = aiconfig_deep_copy.get_prompt(prompt.name) # Asynchronously schedule 'run()' for execution with a set of parameters. # This approach enables concurrent processing of multiple aiconfigs. - task = asyncio.create_task(self.run(prompt, aiconfig_deep_copy, options, params, **kwargs)) + task = asyncio.create_task( + self.run(prompt, aiconfig_deep_copy, options, params, **kwargs) + ) tasks.append(task) # store reference to the deep copy in inference_results inference_results.append(aiconfig_deep_copy) @@ -144,7 +146,9 @@ def get_output_text( str: The output text from the model inference response. """ - def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dict[str, Any]: + def get_model_settings( + self, prompt: Prompt, aiconfig: "AIConfigRuntime" + ) -> Dict[str, Any]: """ Extracts the AI model's settings from the configuration. If both prompt and config level settings are defined, merge them with prompt settings taking precedence. @@ -158,7 +162,10 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic return aiconfig.get_global_settings(self.id()) # Check if the prompt exists in the config - if prompt.name not in aiconfig.prompt_index or aiconfig.prompt_index[prompt.name] != prompt: + if ( + prompt.name not in aiconfig.prompt_index + or aiconfig.prompt_index[prompt.name] != prompt + ): raise IndexError(f"Prompt '{prompt.name}' not in config.") model_metadata = prompt.metadata.model if prompt.metadata else None @@ -167,7 +174,9 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic # Use Default Model default_model = aiconfig.get_default_model() if not default_model: - raise KeyError(f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model.") + raise KeyError( + f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model." + ) return aiconfig.get_global_settings(default_model) elif isinstance(model_metadata, str): # Use Global settings @@ -176,7 +185,11 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic # Merge config and prompt settings with prompt settings taking precedent model_settings = {} global_settings = aiconfig.get_global_settings(model_metadata.name) - prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + prompt_settings = ( + prompt.metadata.model.settings + if prompt.metadata.model.settings is not None + else {} + ) model_settings.update(global_settings) model_settings.update(prompt_settings) @@ -188,7 +201,11 @@ def print_stream_callback(data, accumulated_data, index: int): """ Default streamCallback function that prints the output to the console. """ - print("\ndata: {}\naccumulated_data:{}\nindex:{}\n".format(data, accumulated_data, index)) + print( + "\ndata: {}\naccumulated_data:{}\nindex:{}\n".format( + data, accumulated_data, index + ) + ) def print_stream_delta(data, accumulated_data, index: int): diff --git a/python/src/aiconfig/registry.py b/python/src/aiconfig/registry.py index f60ac0973..fa5937bbc 100644 --- a/python/src/aiconfig/registry.py +++ b/python/src/aiconfig/registry.py @@ -32,7 +32,9 @@ class ModelParserRegistry: _parsers: Dict[str, ModelParser] = {} @staticmethod - def register_model_parser(model_parser: ModelParser, ids: Optional[List[str]] = None): + def register_model_parser( + model_parser: ModelParser, ids: Optional[List[str]] = None + ): """ Adds a model parser to the registry. This model parser is used to parse Prompts in the AIConfig that use the given model. @@ -108,10 +110,15 @@ def display_parsers() -> Dict[str, str]: """ returns a dictionary of model names and their correspondings model parser ids """ - return {model_name: model_parser.id() for model_name, model_parser in ModelParserRegistry._parsers.items()} + return { + model_name: model_parser.id() + for model_name, model_parser in ModelParserRegistry._parsers.items() + } -def update_model_parser_registry_with_config_runtime(config_runtime: "AIConfigRuntime"): +def update_model_parser_registry_with_config_runtime( + config_runtime: "AIConfigRuntime", +): """ Updates the model parser registry with any model parsers specified in the AIConfig. @@ -120,8 +127,13 @@ def update_model_parser_registry_with_config_runtime(config_runtime: "AIConfigRu """ if not config_runtime.metadata.model_parsers: return - for model_id, model_parser_id in config_runtime.metadata.model_parsers.items(): - retrieved_model_parser = ModelParserRegistry.get_model_parser(model_parser_id) # Fix + for ( + model_id, + model_parser_id, + ) in config_runtime.metadata.model_parsers.items(): + retrieved_model_parser = ModelParserRegistry.get_model_parser( + model_parser_id + ) # Fix if retrieved_model_parser is None: error_message = ( f"Unable to load AIConfig: It specifies {config_runtime.metadata.model_parsers}, " @@ -131,4 +143,6 @@ def update_model_parser_registry_with_config_runtime(config_runtime: "AIConfigRu ) raise Exception(error_message) - ModelParserRegistry.register_model_parser(retrieved_model_parser, [model_id]) + ModelParserRegistry.register_model_parser( + retrieved_model_parser, [model_id] + ) diff --git a/python/src/aiconfig/schema.py b/python/src/aiconfig/schema.py index d730c7578..87b820fd1 100644 --- a/python/src/aiconfig/schema.py +++ b/python/src/aiconfig/schema.py @@ -1,4 +1,3 @@ -import json import warnings from typing import Any, Dict, List, Literal, Optional, Union @@ -17,6 +16,7 @@ "callback_manager": True, } + class OutputDataWithStringValue(BaseModel): """ This represents the output content that is storied as a string, but we use @@ -79,6 +79,7 @@ class OutputDataWithToolCallsValue(BaseModel): OutputDataWithToolCallsValue, ] + class AttachmentDataWithStringValue(BaseModel): """ This represents the attachment data that is stored as a string, but we use @@ -89,6 +90,7 @@ class AttachmentDataWithStringValue(BaseModel): kind: Literal["file_uri", "base64"] value: str + class ExecuteResult(BaseModel): """ ExecuteResult represents the result of executing a prompt. @@ -104,7 +106,7 @@ class ExecuteResult(BaseModel): mime_type: Optional[str] = None # Output metadata metadata: Dict[str, Any] - + def to_json(self) -> JSONObject: """ Helper method used to ensure this is formatted to a valid JSON object @@ -276,7 +278,9 @@ def add_model(self, model_name: str, model_settings: InferenceSettings): Adds model settings to config level metadata """ if model_name in self.metadata.models: - raise Exception(f"Model '{model_name}' already exists. Use `update_model()`.") + raise Exception( + f"Model '{model_name}' already exists. Use `update_model()`." + ) self.metadata.models[model_name] = model_settings def delete_model(self, model_name: str): @@ -306,7 +310,9 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str: # If the prompt doesn't have a model, use the default model default_model = self.metadata.default_model if not default_model: - raise Exception(f"No model specified in AIConfig metadata, prompt {prompt.name} does not specify a model.") + raise Exception( + f"No model specified in AIConfig metadata, prompt {prompt.name} does not specify a model." + ) return default_model if isinstance(prompt.metadata.model, str): return prompt.metadata.model @@ -329,7 +335,9 @@ def get_default_model(self) -> Union[str, None]: """ return self.metadata.default_model - def set_model_parser(self, model_name: str, model_parser_id: Union[str, None]): + def set_model_parser( + self, model_name: str, model_parser_id: Union[str, None] + ): """ Adds a model name : model parser ID mapping to the AIConfig metadata. This model parser will be used to parse Promps in the AIConfig that use the given model. @@ -354,7 +362,9 @@ def get_metadata(self, prompt_name: Optional[str] = None): """ if prompt_name: if prompt_name not in self.prompt_index: - raise IndexError(f"Prompt '{prompt_name}' not found in config.") + raise IndexError( + f"Prompt '{prompt_name}' not found in config." + ) return self.prompt_index[prompt_name].metadata else: return self.metadata @@ -375,11 +385,17 @@ def get_parameters( prompt = prompt_or_prompt_name if isinstance(prompt_or_prompt_name, str): if prompt_or_prompt_name not in self.prompt_index: - raise IndexError(f"Prompt '{prompt_or_prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}") + raise IndexError( + f"Prompt '{prompt_or_prompt_name}' not found in config, available prompts are:\n {list(self.prompt_index.keys())}" + ) prompt = self.prompt_index[prompt_or_prompt_name] assert prompt is None or isinstance(prompt, Prompt) - if prompt is None or not prompt.metadata or not prompt.metadata.parameters: + if ( + prompt is None + or not prompt.metadata + or not prompt.metadata.parameters + ): return self.get_global_parameters() return self.get_prompt_parameters(prompt) @@ -422,7 +438,9 @@ def get_prompt_parameters( default_return_value JSONObject - Default value to return if prompt parameters are not defined. """ - return self._get_prompt_parameters_exact(prompt) or default_return_value + return ( + self._get_prompt_parameters_exact(prompt) or default_return_value + ) # pylint: enable=W0102 @@ -438,7 +456,12 @@ def _get_prompt_parameters_exact( return prompt.metadata return prompt.metadata.parameters - def set_parameter(self, parameter_name: str, parameter_value: Union[str, JSONObject], prompt_name: Optional[str] = None): + def set_parameter( + self, + parameter_name: str, + parameter_value: Union[str, JSONObject], + prompt_name: Optional[str] = None, + ): """ Sets a parameter in the AI configuration metadata. If a prompt_name is specified, it adds the parameter to a specific prompt's metadata @@ -487,7 +510,9 @@ def set_parameter(self, parameter_name: str, parameter_value: Union[str, JSONObj target_metadata.parameters = {} target_metadata.parameters[parameter_name] = parameter_value - def set_parameters(self, parameters: JSONObject, prompt_name: Optional[str] = None) -> None: + def set_parameters( + self, parameters: JSONObject, prompt_name: Optional[str] = None + ) -> None: """ Set the entire parameters dict for either a prompt (if specified) or the AIConfig (if prompt is not specified). It overwrites whatever @@ -511,9 +536,13 @@ def set_parameters(self, parameters: JSONObject, prompt_name: Optional[str] = No parameter_names_to_delete = [] if prompt_name: prompt = self.get_prompt(prompt_name) - parameter_names_to_delete = list(self.get_prompt_parameters(prompt).keys()) + parameter_names_to_delete = list( + self.get_prompt_parameters(prompt).keys() + ) else: - parameter_names_to_delete = list(self.get_global_parameters().keys()) + parameter_names_to_delete = list( + self.get_global_parameters().keys() + ) for parameter_name in parameter_names_to_delete: self.delete_parameter(parameter_name, prompt_name) @@ -540,7 +569,9 @@ def update_parameter( target_metadata = self.get_metadata(prompt_name) target_metadata.parameters[parameter_name] = parameter_value - def delete_parameter(self, parameter_name, prompt_name: Optional[str] = None): + def delete_parameter( + self, parameter_name, prompt_name: Optional[str] = None + ): """ Removes a parameter from the AI configuration metadata. If a prompt_name is specified, it removes the parameter from a particular prompt's metadata in the AI configuration. Else, it removes the parameter from the global @@ -556,8 +587,14 @@ def delete_parameter(self, parameter_name, prompt_name: Optional[str] = None): if parameter_name in target_metadata.parameters: del target_metadata.parameters[parameter_name] else: - scope_suffix = f"prompt '{prompt_name}'" if prompt_name is not None else "current AIConfig-scoped metadata" - raise KeyError(f"Parameter '{parameter_name}' does not exist for {scope_suffix}.") + scope_suffix = ( + f"prompt '{prompt_name}'" + if prompt_name is not None + else "current AIConfig-scoped metadata" + ) + raise KeyError( + f"Parameter '{parameter_name}' does not exist for {scope_suffix}." + ) def get_prompt(self, prompt_name: str) -> Prompt: """ @@ -570,10 +607,16 @@ def get_prompt(self, prompt_name: str) -> Prompt: Prompt: The prompt object. """ if prompt_name not in self.prompt_index: - raise IndexError("Prompt '{}' not found in config, available prompts are:\n {}".format(prompt_name, list(self.prompt_index.keys()))) + raise IndexError( + "Prompt '{}' not found in config, available prompts are:\n {}".format( + prompt_name, list(self.prompt_index.keys()) + ) + ) return self.prompt_index[prompt_name] - def add_prompt(self, prompt_name: str, prompt_data: Prompt, index: int | None = None): + def add_prompt( + self, prompt_name: str, prompt_data: Prompt, index: int | None = None + ): """ Adds a prompt to the .aiconfig. @@ -584,7 +627,11 @@ def add_prompt(self, prompt_name: str, prompt_data: Prompt, index: int | None = if prompt_name is None: prompt_name = prompt_data.name if prompt_name in self.prompt_index: - raise Exception("Prompt with name {} already exists. Use`update_prompt()`".format(prompt_name)) + raise Exception( + "Prompt with name {} already exists. Use`update_prompt()`".format( + prompt_name + ) + ) prompt_data.name = prompt_name self.prompt_index[prompt_name] = prompt_data @@ -602,7 +649,11 @@ def update_prompt(self, prompt_name: str, prompt_data: Prompt): prompt_data (Prompt): The prompt object containing the updated prompt data. """ if prompt_name not in self.prompt_index: - raise IndexError("Prompt '{}' not found in config, available prompts are:\n {}".format(prompt_name, list(self.prompt_index.keys()))) + raise IndexError( + "Prompt '{}' not found in config, available prompts are:\n {}".format( + prompt_name, list(self.prompt_index.keys()) + ) + ) self.prompt_index[prompt_name] = prompt_data # update prompt list @@ -621,13 +672,21 @@ def delete_prompt(self, prompt_name: str): prompt_name (str): The name of the prompt to delete. """ if prompt_name not in self.prompt_index: - raise IndexError("Prompt '{}' not found in config, available prompts are:\n {}".format(prompt_name, list(self.prompt_index.keys()))) + raise IndexError( + "Prompt '{}' not found in config, available prompts are:\n {}".format( + prompt_name, list(self.prompt_index.keys()) + ) + ) del self.prompt_index[prompt_name] # remove from prompt list - self.prompts = [prompt for prompt in self.prompts if prompt.name != prompt_name] + self.prompts = [ + prompt for prompt in self.prompts if prompt.name != prompt_name + ] - def get_model_metadata(self, inference_settings: InferenceSettings, model_id: str) -> ModelMetadata: + def get_model_metadata( + self, inference_settings: InferenceSettings, model_id: str + ) -> ModelMetadata: """ Generate a model metadata object based on the provided inference settings @@ -641,12 +700,16 @@ def get_model_metadata(self, inference_settings: InferenceSettings, model_id: st ModelMetadata: The model metadata. """ - overriden_settings = extract_override_settings(self, inference_settings, model_id) + overriden_settings = extract_override_settings( + self, inference_settings, model_id + ) if not overriden_settings: model_metadata = ModelMetadata(**{"name": model_id}) else: - model_metadata = ModelMetadata(**{"name": model_id, "settings": overriden_settings}) + model_metadata = ModelMetadata( + **{"name": model_id, "settings": overriden_settings} + ) return model_metadata # TODO (rossdan): If we pass in a new model under ModelMetadata, but that model is @@ -656,7 +719,12 @@ def get_model_metadata(self, inference_settings: InferenceSettings, model_id: st # that matches this class and do this automatically with the # `update_model_parser_registry_with_config_runtime`` function # Tracked in https://github.com/lastmile-ai/aiconfig/issues/503 - def update_model(self, model_name: Optional[str] = None, settings: Optional[InferenceSettings] = None, prompt_name: Optional[str] = None): + def update_model( + self, + model_name: Optional[str] = None, + settings: Optional[InferenceSettings] = None, + prompt_name: Optional[str] = None, + ): """ Updates model name and/or settings at the prompt (if specified) or AIConfig level. @@ -684,8 +752,12 @@ def update_model(self, model_name: Optional[str] = None, settings: Optional[Infe --> errors becasue no model name or settings provided """ if model_name is None and settings is None: - raise ValueError("Cannot update model. Either model name or model settings must be specified.") - if model_name is None and prompt_name is None: # Only settings param is set + raise ValueError( + "Cannot update model. Either model name or model settings must be specified." + ) + if ( + model_name is None and prompt_name is None + ): # Only settings param is set raise ValueError( """ Cannot update model. There are two things you are trying: \ @@ -738,14 +810,22 @@ def _update_model_name_for_prompt(self, model_name: str, prompt_name: str): if prompt.metadata is None: model_metadata = ModelMetadata(name=model_name, settings={}) prompt.metadata = PromptMetadata(model=model_metadata) - elif prompt.metadata.model is None or isinstance(prompt.metadata.model, str): + elif prompt.metadata.model is None or isinstance( + prompt.metadata.model, str + ): prompt.metadata.model = ModelMetadata(name=model_name, settings={}) else: # prompt.metadata.model is a ModelMetadata object - model_settings: InferenceSettings = prompt.metadata.model.settings or {} - prompt.metadata.model = ModelMetadata(name=model_name, settings=model_settings) + model_settings: InferenceSettings = ( + prompt.metadata.model.settings or {} + ) + prompt.metadata.model = ModelMetadata( + name=model_name, settings=model_settings + ) - def _update_model_settings_for_prompt(self, settings: InferenceSettings, prompt_name: str): + def _update_model_settings_for_prompt( + self, settings: InferenceSettings, prompt_name: str + ): """ Updates model name at the prompt level. We do not update at the AIConfig level because an AIConfig can have multiple models, so @@ -757,7 +837,9 @@ def _update_model_settings_for_prompt(self, settings: InferenceSettings, prompt_ """ prompt = self.get_prompt(prompt_name) if not prompt: - raise IndexError(f"Cannot update model settings for prompt '{prompt_name}'. Prompt '{prompt_name}' does not exist in AIConfig.") + raise IndexError( + f"Cannot update model settings for prompt '{prompt_name}'. Prompt '{prompt_name}' does not exist in AIConfig." + ) metadata_error_message = f""" Cannot update model settings for prompt '{prompt_name}' because it does not \ @@ -770,11 +852,18 @@ def _update_model_settings_for_prompt(self, settings: InferenceSettings, prompt_ if isinstance(prompt.metadata.model, str): model_name = prompt.metadata.model - prompt.metadata.model = ModelMetadata(name=model_name, settings=settings) + prompt.metadata.model = ModelMetadata( + name=model_name, settings=settings + ) else: prompt.metadata.model.settings = settings - def _update_model_for_aiconfig(self, model_name: str, settings: Union[InferenceSettings, None], prompt_name: Optional[str] = None): + def _update_model_for_aiconfig( + self, + model_name: str, + settings: Union[InferenceSettings, None], + prompt_name: Optional[str] = None, + ): """ Updates model name at the AIConfig level. @@ -802,10 +891,14 @@ def _update_model_for_aiconfig(self, model_name: str, settings: Union[InferenceS # If the model name already exists and settings is None, # this is essentially a no-op since we are preserving # existing settings for that model name - model_settings = settings or self.metadata.models.get(model_name, {}) + model_settings = settings or self.metadata.models.get( + model_name, {} + ) self.metadata.models[model_name] = model_settings - def set_metadata(self, key: str, value: Any, prompt_name: Optional[str] = None): + def set_metadata( + self, key: str, value: Any, prompt_name: Optional[str] = None + ): """ Sets a metadata property in the AIConfig @@ -817,7 +910,9 @@ def set_metadata(self, key: str, value: Any, prompt_name: Optional[str] = None): if prompt_name: prompt = self.get_prompt(prompt_name) if not prompt: - raise IndexError(f"Cannot set metadata property '{key}' for prompt {prompt_name}. Prompt {prompt_name} does not exist in AIConfig.") + raise IndexError( + f"Cannot set metadata property '{key}' for prompt {prompt_name}. Prompt {prompt_name} does not exist in AIConfig." + ) setattr(prompt.metadata, key, value) else: setattr(self.metadata, key, value) @@ -833,11 +928,15 @@ def delete_metadata(self, key: str, prompt_name: Optional[str] = None): if prompt_name: prompt = self.get_prompt(prompt_name) if not prompt: - raise IndexError(f"Cannot delete metadata. Prompt '{prompt_name}' not found in config.") + raise IndexError( + f"Cannot delete metadata. Prompt '{prompt_name}' not found in config." + ) if hasattr(prompt.metadata, key): delattr(prompt.metadata, key) else: - raise KeyError(f"Metadata '{key}' does not exist for Prompt {prompt_name}.") + raise KeyError( + f"Metadata '{key}' does not exist for Prompt {prompt_name}." + ) else: if hasattr(self.metadata, key): delattr(self.metadata, key) @@ -846,7 +945,9 @@ def delete_metadata(self, key: str, prompt_name: Optional[str] = None): # TODO: rename _get_metadata to get_metadata - def add_output(self, prompt_name: str, output: Output, overwrite: bool = False): + def add_output( + self, prompt_name: str, output: Output, overwrite: bool = False + ): """ Add an output to the prompt with the given name in the AIConfig @@ -857,15 +958,21 @@ def add_output(self, prompt_name: str, output: Output, overwrite: bool = False): """ prompt = self.get_prompt(prompt_name) if not prompt: - raise IndexError(f"Cannot add output. Prompt '{prompt_name}' not found in config.") + raise IndexError( + f"Cannot add output. Prompt '{prompt_name}' not found in config." + ) if not output: - raise ValueError(f"Cannot add output to prompt '{prompt_name}'. Output is not defined.") + raise ValueError( + f"Cannot add output to prompt '{prompt_name}'. Output is not defined." + ) if overwrite: prompt.outputs = [output] else: prompt.outputs.append(output) - def add_outputs(self, prompt_name: str, outputs: List[Output], overwrite: bool = False): + def add_outputs( + self, prompt_name: str, outputs: List[Output], overwrite: bool = False + ): """ Add multiple outputs to the prompt with the given name in the AIConfig @@ -876,9 +983,13 @@ def add_outputs(self, prompt_name: str, outputs: List[Output], overwrite: bool = """ prompt = self.get_prompt(prompt_name) if not prompt: - raise IndexError(f"Cannot add outputs. Prompt '{prompt_name}' not found in config.") + raise IndexError( + f"Cannot add outputs. Prompt '{prompt_name}' not found in config." + ) if not outputs: - raise ValueError(f"Cannot add outputs. No outputs provided for prompt '{prompt_name}'.") + raise ValueError( + f"Cannot add outputs. No outputs provided for prompt '{prompt_name}'." + ) if overwrite: prompt.outputs = outputs else: diff --git a/python/src/aiconfig/scripts/aiconfig_cli.py b/python/src/aiconfig/scripts/aiconfig_cli.py index 44d698ac3..a2ad2ad64 100644 --- a/python/src/aiconfig/scripts/aiconfig_cli.py +++ b/python/src/aiconfig/scripts/aiconfig_cli.py @@ -6,14 +6,16 @@ import subprocess import sys +import aiconfig.scripts.rage.rage as rage import lastmile_utils.lib.core.api as core_utils -from ruamel.yaml import YAML - from aiconfig.editor.server.server import run_backend_server -from aiconfig.editor.server.server_utils import DEFAULT_AICONFIGRC, EditServerConfig, ServerMode +from aiconfig.editor.server.server_utils import ( + DEFAULT_AICONFIGRC, + EditServerConfig, + ServerMode, +) from result import Err, Ok, Result - -import aiconfig.scripts.rage.rage as rage +from ruamel.yaml import YAML class AIConfigCLIConfig(core_utils.Record): @@ -38,16 +40,25 @@ async def main_with_args(argv: list[str]) -> int: def run_subcommand(argv: list[str]) -> Result[str, str]: LOGGER.info("Running subcommand") - subparser_record_types = {"edit": EditServerConfig, "rage": rage.RageConfig} - main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_record_types=subparser_record_types) + subparser_record_types = { + "edit": EditServerConfig, + "rage": rage.RageConfig, + } + main_parser = core_utils.argparsify( + AIConfigCLIConfig, subparser_record_types=subparser_record_types + ) # Try to parse the CLI args into a config. - cli_config: Result[AIConfigCLIConfig, str] = core_utils.parse_args(main_parser, argv[1:], AIConfigCLIConfig) + cli_config: Result[AIConfigCLIConfig, str] = core_utils.parse_args( + main_parser, argv[1:], AIConfigCLIConfig + ) # If cli_config is Ok(), pass its contents to _get_cli_process_result_from_config(). # Otherwise, short circuit and assign process_result to the Err. # Nothing gets mutated except for log level (see inside _get_cli_process_result_from_config() - process_result = cli_config.and_then(_set_log_level_and_create_default_yaml) + process_result = cli_config.and_then( + _set_log_level_and_create_default_yaml + ) LOGGER.info(f"{process_result=}") subparser_name = core_utils.get_subparser_name(main_parser, argv[1:]) @@ -55,12 +66,16 @@ def run_subcommand(argv: list[str]) -> Result[str, str]: if subparser_name == "edit": LOGGER.debug("Running edit subcommand") - edit_config = core_utils.parse_args(main_parser, argv[1:], EditServerConfig) + edit_config = core_utils.parse_args( + main_parser, argv[1:], EditServerConfig + ) LOGGER.debug(f"{edit_config.is_ok()=}") out = _run_editor_servers_with_configs(edit_config, cli_config) return out elif subparser_name == "rage": - res_rage_config = core_utils.parse_args(main_parser, argv[1:], rage.RageConfig) + res_rage_config = core_utils.parse_args( + main_parser, argv[1:], rage.RageConfig + ) res_rage = res_rage_config.and_then(rage.rage) match res_rage: case Ok(msg): @@ -71,11 +86,18 @@ def run_subcommand(argv: list[str]) -> Result[str, str]: return Err(f"Unknown subparser: {subparser_name}") -def _run_editor_servers_with_configs(edit_config: Result[EditServerConfig, str], cli_config: Result[AIConfigCLIConfig, str]) -> Result[str, str]: +def _run_editor_servers_with_configs( + edit_config: Result[EditServerConfig, str], + cli_config: Result[AIConfigCLIConfig, str], +) -> Result[str, str]: if not (edit_config.is_ok() and cli_config.is_ok()): - return Err(f"Something went wrong with configs: {edit_config=}, {cli_config=}") + return Err( + f"Something went wrong with configs: {edit_config=}, {cli_config=}" + ) - server_outcomes = _run_editor_servers(edit_config.unwrap(), cli_config.unwrap().aiconfigrc_path) + server_outcomes = _run_editor_servers( + edit_config.unwrap(), cli_config.unwrap().aiconfigrc_path + ) if server_outcomes.is_err(): return Err(f"Something went wrong with servers: {server_outcomes=}") @@ -100,7 +122,9 @@ def is_port_in_use(port: int) -> bool: return s.connect_ex(("localhost", port)) == 0 -def _run_editor_servers(edit_config: EditServerConfig, aiconfigrc_path: str) -> Result[list[str], str]: +def _run_editor_servers( + edit_config: EditServerConfig, aiconfigrc_path: str +) -> Result[list[str], str]: port = edit_config.server_port while is_port_in_use(port): @@ -116,7 +140,11 @@ def _run_editor_servers(edit_config: EditServerConfig, aiconfigrc_path: str) -> # Check if server is already running LOGGER.info("Running editor servers") - frontend_procs = _run_frontend_server_background() if edit_config.server_mode in [ServerMode.DEBUG_SERVERS] else Ok([]) + frontend_procs = ( + _run_frontend_server_background() + if edit_config.server_mode in [ServerMode.DEBUG_SERVERS] + else Ok([]) + ) match frontend_procs: case Ok(_): pass @@ -138,7 +166,9 @@ def _run_editor_servers(edit_config: EditServerConfig, aiconfigrc_path: str) -> return core_utils.result_reduce_list_all_ok(results) -def _set_log_level_and_create_default_yaml(cli_config: AIConfigCLIConfig) -> Result[bool, str]: +def _set_log_level_and_create_default_yaml( + cli_config: AIConfigCLIConfig, +) -> Result[bool, str]: """ This function has 2 jobs (currently): 1. Set the log level @@ -179,16 +209,24 @@ def _read() -> str: return core_utils.ErrWithTraceback(e) -def _run_frontend_server_background() -> Result[list[subprocess.Popen[bytes]], str]: +def _run_frontend_server_background() -> ( + Result[list[subprocess.Popen[bytes]], str] +): LOGGER.info("Running frontend server in background") p1, p2 = None, None try: - p1 = subprocess.Popen(["yarn"], cwd="python/src/aiconfig/editor/client") + p1 = subprocess.Popen( + ["yarn"], cwd="python/src/aiconfig/editor/client" + ) except Exception as e: return core_utils.ErrWithTraceback(e) try: - p2 = subprocess.Popen(["yarn", "start"], cwd="python/src/aiconfig/editor/client", stdin=subprocess.PIPE) + p2 = subprocess.Popen( + ["yarn", "start"], + cwd="python/src/aiconfig/editor/client", + stdin=subprocess.PIPE, + ) except Exception as e: return core_utils.ErrWithTraceback(e) diff --git a/python/src/aiconfig/scripts/rage/rage.py b/python/src/aiconfig/scripts/rage/rage.py index 8bdc5e8c8..72229e9f5 100644 --- a/python/src/aiconfig/scripts/rage/rage.py +++ b/python/src/aiconfig/scripts/rage/rage.py @@ -28,7 +28,9 @@ def rage(config: RageConfig) -> Result[None, str]: print("Please open an issue! :) Here's a template:") out = _create_issue_draft() - print("Our sincerest apologies and gratitude. If you opened an issue, will comment on it as soon as possible.") + print( + "Our sincerest apologies and gratitude. If you opened an issue, will comment on it as soon as possible." + ) print("\n\n") print("Done raging! :)") @@ -47,7 +49,9 @@ def _try_run_command(command: str) -> Result[str, str]: # Try to execute the command and capture output try: - process = subprocess.Popen(cmd_parts, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen( + cmd_parts, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout, stderr = process.communicate() # Decode bytes to string and return formatted message @@ -60,8 +64,12 @@ def _try_run_command(command: str) -> Result[str, str]: return core_utils.ErrWithTraceback(e) -def _fmt_command_input_output(command: str, command_output: Result[str, str]) -> str: - command_output_str = command_output.unwrap_or("Couldn't run command :(. Please run manually.") +def _fmt_command_input_output( + command: str, command_output: Result[str, str] +) -> str: + command_output_str = command_output.unwrap_or( + "Couldn't run command :(. Please run manually." + ) return dedent( f""" Command: {command} @@ -74,8 +82,16 @@ def _create_issue_draft() -> Result[None, str]: title = "Bug Report [TODO: brief description]" server_log_path = _look_for_log("editor_flask_server.log") aiconfig_log_path = _look_for_log("aiconfig.log") - server_log_str = f"see {server_log_path}" if server_log_path else "Couldn't find the log file :( Check your terminal output" - aiconfig_log_str = f"see {aiconfig_log_path}" if aiconfig_log_path else "Couldn't find the log file :( Check your terminal output" + server_log_str = ( + f"see {server_log_path}" + if server_log_path + else "Couldn't find the log file :( Check your terminal output" + ) + aiconfig_log_str = ( + f"see {aiconfig_log_path}" + if aiconfig_log_path + else "Couldn't find the log file :( Check your terminal output" + ) commands_to_run = [ "which pip", @@ -85,7 +101,12 @@ def _create_issue_draft() -> Result[None, str]: "python --version", "python3 --version", ] - command_inputs_outputs = dedent("\n\n".join(_fmt_command_input_output(command, _try_run_command(command)) for command in commands_to_run)) + command_inputs_outputs = dedent( + "\n\n".join( + _fmt_command_input_output(command, _try_run_command(command)) + for command in commands_to_run + ) + ) pip_list_output = _get_pip_list_filtered() body = dedent( f""" @@ -111,7 +132,9 @@ def _create_issue_draft() -> Result[None, str]: print("\n\n\n\nIssue draft:\n\n") print(f"Title: {title}") print(f"Body:\n{body}") - open_draft = _get_yes_or_no_input("Would you like to open a draft issue in your browser? [Y/n] ") + open_draft = _get_yes_or_no_input( + "Would you like to open a draft issue in your browser? [Y/n] " + ) if open_draft: _troll_the_user_part_2() _open_github_issue_draft( @@ -130,8 +153,14 @@ def _get_pip_list_filtered() -> str: output = _try_run_command("pip3 list") match output: case Ok(output_str): - filtered_lines = [line for line in output_str.split("\n") if "aiconfig" in line.lower()] - filtered_str = Ok("aiconfig packages:\n" + "\n".join(filtered_lines)) + filtered_lines = [ + line + for line in output_str.split("\n") + if "aiconfig" in line.lower() + ] + filtered_str = Ok( + "aiconfig packages:\n" + "\n".join(filtered_lines) + ) return _fmt_command_input_output("pip3 list", filtered_str) case Err(_): return "\nCommand: pip3 list | grep aiconfig\nCouldn't run command :(. Please run manually." @@ -140,11 +169,15 @@ def _get_pip_list_filtered() -> str: def _look_for_log(logfile: str) -> str | None: print() if os.path.exists(logfile): - print(f"Found {logfile}! Please include its contents in your bug report.") + print( + f"Found {logfile}! Please include its contents in your bug report." + ) return os.path.abspath(logfile) else: print(f"No {logfile} found. This might be another bug :)") - print("For now, please include your terminal output in your bug report.") + print( + "For now, please include your terminal output in your bug report." + ) return None @@ -171,7 +204,9 @@ def get_animation(): if type == "spinner": return next(spinning) else: - return "".join(np.random.choice(["♩", "♫", "♬", "♪"]) for _ in range(5)) + return "".join( + np.random.choice(["♩", "♫", "♬", "♪"]) for _ in range(5) + ) end_time = time.time() + seconds @@ -184,7 +219,9 @@ def get_animation(): def _troll_the_user_part_1(): _spin(2) - print("Please hold. Your call is important to us.\nA representative will be with you shortly.") + print( + "Please hold. Your call is important to us.\nA representative will be with you shortly." + ) _spin(3, type="music") print("Looking for your server logs...") _spin(4, type="music") @@ -194,7 +231,9 @@ def _troll_the_user_part_1(): _spin(5) print("I'm glad we're finally spending time together.") _spin(4, type="music") - print("Please continue holding. We appreciate your continued support, or whatever.") + print( + "Please continue holding. We appreciate your continued support, or whatever." + ) _spin(3, type="music") @@ -209,11 +248,17 @@ def _troll_the_user_part_2(): _spin(3) -def _open_github_issue_draft(repo: str, title: str, body: str, labels: list[str] | None = None) -> bool: +def _open_github_issue_draft( + repo: str, title: str, body: str, labels: list[str] | None = None +) -> bool: base_url = f"https://github.com/{repo}/issues/new" title_str = f"title={quote(title)}" body_str = f"body={quote(body)}" - labels_str = f"labels={','.join([quote(label) for label in labels])}" if labels else "" + labels_str = ( + f"labels={','.join([quote(label) for label in labels])}" + if labels + else "" + ) issue_url = f"{base_url}?{title_str}&{body_str}&{labels_str}" try: @@ -221,6 +266,8 @@ def _open_github_issue_draft(repo: str, title: str, body: str, labels: list[str] return True except Exception as e: logger.debug(f"exn={e}") - logger.warning(f"Couldn't open your browser for you. I guess you'll have to do it yourself :)") + logger.warning( + f"Couldn't open your browser for you. I guess you'll have to do it yourself :)" + ) logger.warning(f"Please open an issue here: {issue_url}") return False diff --git a/python/src/aiconfig/scripts/run_aiconfig.py b/python/src/aiconfig/scripts/run_aiconfig.py index cb5e038bc..50a7b4c89 100644 --- a/python/src/aiconfig/scripts/run_aiconfig.py +++ b/python/src/aiconfig/scripts/run_aiconfig.py @@ -5,7 +5,10 @@ import lastmile_utils.lib.core.api as core_utils import result from aiconfig.Config import AIConfigRuntime -from aiconfig.eval.lib import TextBasedInputDatum, run_aiconfig_on_text_based_input +from aiconfig.eval.lib import ( + TextBasedInputDatum, + run_aiconfig_on_text_based_input, +) from result import Result logging.basicConfig(format=core_utils.LOGGER_FMT) @@ -45,7 +48,9 @@ async def main(): def _load_settings(settings_path: str) -> Result[Settings, str]: - return core_utils.pydantic_model_validate_from_json_file_path(settings_path, Settings) + return core_utils.pydantic_model_validate_from_json_file_path( + settings_path, Settings + ) if __name__ == "__main__": diff --git a/python/src/aiconfig/util/config_utils.py b/python/src/aiconfig/util/config_utils.py index 2f3513a49..bf1bd2052 100644 --- a/python/src/aiconfig/util/config_utils.py +++ b/python/src/aiconfig/util/config_utils.py @@ -1,8 +1,9 @@ import copy -import dotenv import os -from typing import TYPE_CHECKING, Union -from result import Result, Ok, Err +from typing import TYPE_CHECKING + +import dotenv +from result import Err, Ok, Result if TYPE_CHECKING: pass @@ -13,13 +14,13 @@ def get_api_key_from_environment( - api_key_name: str, - required: bool = True) -> Result[str | None, str]: + api_key_name: str, required: bool = True +) -> Result[str | None, str]: """Get the API key if it exists, return None or error if it doesn't Args: api_key_name (str): The keyname that we're trying to import from env variable - required (bool, optional): If this is true, we raise an error if the + required (bool, optional): If this is true, we raise an error if the key is not found Returns: @@ -31,14 +32,20 @@ def get_api_key_from_environment( return Ok(os.getenv(api_key_name)) -def _get_api_key_from_environment_required(api_key_name: str) -> Result[str, str]: +def _get_api_key_from_environment_required( + api_key_name: str, +) -> Result[str, str]: try: return Ok(os.environ[api_key_name]) except KeyError: return Err(f"Missing API key '{api_key_name}' in environment") -def extract_override_settings(config_runtime: "AIConfig", inference_settings: "InferenceSettings", model_id: str): +def extract_override_settings( + config_runtime: "AIConfig", + inference_settings: "InferenceSettings", + model_id: str, +): """ Extract inference settings with overrides based on inference settings. @@ -62,7 +69,8 @@ def extract_override_settings(config_runtime: "AIConfig", inference_settings: "I override_settings = { key: copy.deepcopy(inference_settings[key]) for key in inference_settings - if key not in global_model_settings or global_model_settings.get(key) != inference_settings[key] + if key not in global_model_settings + or global_model_settings.get(key) != inference_settings[key] } return override_settings return inference_settings diff --git a/python/src/aiconfig/util/params.py b/python/src/aiconfig/util/params.py index 959d2a278..e7e2b2efa 100644 --- a/python/src/aiconfig/util/params.py +++ b/python/src/aiconfig/util/params.py @@ -32,7 +32,9 @@ def get_parameters_in_template(template) -> dict: re_pattern = r"{{[{]?(.*?)[}]?}}" # Find all Handlebars tags in the template - tags = [match.group(1).strip() for match in re.finditer(re_pattern, template)] + tags = [ + match.group(1).strip() for match in re.finditer(re_pattern, template) + ] # Initialize a dictionary to store parameters root = defaultdict(lambda: defaultdict(bool)) @@ -123,7 +125,9 @@ def resolve_parametrized_prompt(raw_prompt, params): return resolved_prompt -def find_dependencies_in_prompt(prompt_template: str, current_prompt_name: str, prompt_list: List[Prompt]) -> Set[str]: +def find_dependencies_in_prompt( + prompt_template: str, current_prompt_name: str, prompt_list: List[Prompt] +) -> Set[str]: """ Finds and returns a set of prompt IDs that are dependencies of the given prompt. @@ -156,7 +160,11 @@ def find_dependencies_in_prompt(prompt_template: str, current_prompt_name: str, return dependencies -def get_dependency_graph(root_prompt: Prompt, all_prompts: List[Prompt], prompt_dict: Dict[str, Prompt]) -> dict[str, List[str]]: +def get_dependency_graph( + root_prompt: Prompt, + all_prompts: List[Prompt], + prompt_dict: Dict[str, Prompt], +) -> dict[str, List[str]]: """ Generates an upstream dependency graph of prompts in the configuration, with each entry representing only its direct dependencies. Traversal is required to identify all upstream dependencies. The specified prompt serves as the root. @@ -179,8 +187,12 @@ def build_dependency_graph_recursive(current_prompt_name: str) -> dict: return visited.add(current_prompt_name) - prompt_template = prompt_dict[current_prompt_name].get_raw_prompt_from_config() - prompt_dependencies = find_dependencies_in_prompt(prompt_template, current_prompt_name, all_prompts) + prompt_template = prompt_dict[ + current_prompt_name + ].get_raw_prompt_from_config() + prompt_dependencies = find_dependencies_in_prompt( + prompt_template, current_prompt_name, all_prompts + ) for prompt_dependency in prompt_dependencies: dependency_graph[current_prompt_name].append(prompt_dependency) @@ -207,7 +219,9 @@ def resolve_parameters(params, prompt: Prompt, ai_config: "AIConfig"): return resolved_prompt -def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime") -> str | None: +def get_prompt_template( + prompt: Prompt, aiconfig: "AIConfigRuntime" +) -> str | None: """ Returns the template for a prompt. @@ -218,9 +232,13 @@ def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime") -> str | No Returns: str: Returns the template for a prompt. """ - model_parser = ModelParserRegistry.get_model_parser_for_prompt(prompt, aiconfig) + model_parser = ModelParserRegistry.get_model_parser_for_prompt( + prompt, aiconfig + ) # Circular type reference - from ..default_parsers.parameterized_model_parser import ParameterizedModelParser + from ..default_parsers.parameterized_model_parser import ( + ParameterizedModelParser, + ) if isinstance(model_parser, ParameterizedModelParser): return model_parser.get_prompt_template(prompt, aiconfig) @@ -233,7 +251,9 @@ def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime") -> str | No return None -def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntime"): +def collect_prompt_references( + current_prompt: Prompt, ai_config: "AIConfigRuntime" +): """ Collects references to all other prompts in the AIConfig. Only prompts that appear before the current prompt are collected. """ @@ -247,7 +267,13 @@ def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntim # not all model inputs are parameterizable continue - prompt_output = ai_config.get_output_text(previous_prompt, ai_config.get_latest_output(previous_prompt)) if previous_prompt.outputs else None + prompt_output = ( + ai_config.get_output_text( + previous_prompt, ai_config.get_latest_output(previous_prompt) + ) + if previous_prompt.outputs + else None + ) prompt_references[previous_prompt.name] = { "input": prompt_input, "output": prompt_output, @@ -255,13 +281,17 @@ def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntim return prompt_references -def resolve_prompt(current_prompt: Prompt, input_params: Dict, ai_config: "AIConfigRuntime") -> str: +def resolve_prompt( + current_prompt: Prompt, input_params: Dict, ai_config: "AIConfigRuntime" +) -> str: """ Parameterizes a prompt using provided parameters, references to other prompts, and parameters stored in config.. """ raw_prompt = get_prompt_template(current_prompt, ai_config) - return resolve_prompt_string(current_prompt, input_params, ai_config, raw_prompt) + return resolve_prompt_string( + current_prompt, input_params, ai_config, raw_prompt + ) def resolve_system_prompt( @@ -273,7 +303,9 @@ def resolve_system_prompt( """ Parameterizes a system prompt using provided prompt and parameters, references to other prompts, and parameters stored in config.. """ - return resolve_prompt_string(current_prompt, input_params, ai_config, system_prompt) + return resolve_prompt_string( + current_prompt, input_params, ai_config, system_prompt + ) def resolve_prompt_string( diff --git a/python/tests/mocks.py b/python/tests/mocks.py index c29aae2ab..9a123a5a9 100644 --- a/python/tests/mocks.py +++ b/python/tests/mocks.py @@ -11,11 +11,19 @@ async def __call__(self, prompt_name: str, params: dict[str, str]) -> str: pass -def make_mock_aiconfig_runtime(mock_run_text_to_text: MockRunTextToText | None = None) -> AIConfigRuntime: - async def _default_mock_run_text_to_text(prompt_name: str, params: dict[str, str]) -> str: +def make_mock_aiconfig_runtime( + mock_run_text_to_text: MockRunTextToText | None = None, +) -> AIConfigRuntime: + async def _default_mock_run_text_to_text( + prompt_name: str, params: dict[str, str] + ) -> str: return f"output_for_{prompt_name}_the_query_{params['the_query']}" - mock_run_text_to_text_impl = _default_mock_run_text_to_text if mock_run_text_to_text is None else mock_run_text_to_text + mock_run_text_to_text_impl = ( + _default_mock_run_text_to_text + if mock_run_text_to_text is None + else mock_run_text_to_text + ) class _MockAIConfigRuntime(AIConfigRuntime): def __init__(self): diff --git a/python/tests/parsers/test_dalle_parser.py b/python/tests/parsers/test_dalle_parser.py index a9d275c2b..475b67089 100644 --- a/python/tests/parsers/test_dalle_parser.py +++ b/python/tests/parsers/test_dalle_parser.py @@ -14,7 +14,9 @@ async def test_serialize_basic(set_temporary_env_vars: None): "size": "1024x1024", } aiconfig = AIConfigRuntime.create() - serialized_prompts = await aiconfig.serialize("dall-e-3", completion_params, prompt_name="panda_eating_dumplings") + serialized_prompts = await aiconfig.serialize( + "dall-e-3", completion_params, prompt_name="panda_eating_dumplings" + ) new_prompt = serialized_prompts[0] assert new_prompt == Prompt( name="panda_eating_dumplings", @@ -23,7 +25,11 @@ async def test_serialize_basic(set_temporary_env_vars: None): **{ "model": { "name": "dall-e-3", - "settings": {"model": "dall-e-3", "n": 1, "size": "1024x1024"}, + "settings": { + "model": "dall-e-3", + "n": 1, + "size": "1024x1024", + }, }, } ), diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py index 81a92f0c6..54fad4a2e 100644 --- a/python/tests/parsers/test_openai_util.py +++ b/python/tests/parsers/test_openai_util.py @@ -4,7 +4,13 @@ from aiconfig.default_parsers.openai import refine_chat_completion_params from mock import patch -from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata, ModelMetadata +from aiconfig.schema import ( + ExecuteResult, + ModelMetadata, + Prompt, + PromptInput, + PromptMetadata, +) from ..conftest import mock_openai_chat_completion from ..util.file_path_utils import get_absolute_file_path_from_relative @@ -27,9 +33,13 @@ def test_refine_chat_completion_params(): ), ) - aiconfig = AIConfigRuntime.create(name="test_refine_chat_completion_params", prompts=[prompt]) + aiconfig = AIConfigRuntime.create( + name="test_refine_chat_completion_params", prompts=[prompt] + ) - refined_params = refine_chat_completion_params(prompt.metadata.model.settings, aiconfig, prompt) + refined_params = refine_chat_completion_params( + prompt.metadata.model.settings, aiconfig, prompt + ) assert "system_prompt" not in refined_params assert "stream" in refined_params @@ -39,9 +49,15 @@ def test_refine_chat_completion_params(): @pytest.mark.asyncio async def test_get_output_text(set_temporary_env_vars): - with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion): + with patch.object( + openai.chat.completions, + "create", + side_effect=mock_openai_chat_completion, + ): config_relative_path = "../aiconfigs/basic_chatgpt_query_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) aiconfig = AIConfigRuntime.load(config_absolute_path) await aiconfig.run("prompt1", {}) @@ -56,7 +72,11 @@ async def test_get_output_text(set_temporary_env_vars): @pytest.mark.asyncio async def test_serialize(set_temporary_env_vars): - with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion): + with patch.object( + openai.chat.completions, + "create", + side_effect=mock_openai_chat_completion, + ): # Test with one input prompt and system. No output completion_params = { "model": "gpt-3.5-turbo", @@ -69,7 +89,9 @@ async def test_serialize(set_temporary_env_vars): } aiconfig = AIConfigRuntime.create() - serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, prompt_name="the prompt") + serialized_prompts = await aiconfig.serialize( + "gpt-3.5-turbo", completion_params, prompt_name="the prompt" + ) new_prompt = serialized_prompts[0] # assert prompt serialized correctly into config @@ -104,11 +126,16 @@ async def test_serialize(set_temporary_env_vars): "messages": [ {"role": "system", "content": "You are an expert greeter"}, {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hello! How can I assist you today?"}, + { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, ], } - serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt") + serialized_prompts = await aiconfig.serialize( + "gpt-3.5-turbo", completion_params, "prompt" + ) new_prompt = serialized_prompts[0] expected_prompt = Prompt( @@ -137,7 +164,10 @@ async def test_serialize(set_temporary_env_vars): execution_count=None, data="Hello! How can I assist you today?", metadata={ - "raw_response": {"role": "assistant", "content": "Hello! How can I assist you today?"}, + "raw_response": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + }, "role": "assistant", }, mime_type=None, @@ -157,7 +187,10 @@ async def test_serialize(set_temporary_env_vars): "temperature": 0.7, "max_tokens": 900, "messages": [ - {"role": "system", "content": "You are an expert decision maker"}, + { + "role": "system", + "content": "You are an expert decision maker", + }, {"role": "user", "content": "What is the weather today?"}, ], "functions": [ @@ -182,7 +215,9 @@ async def test_serialize(set_temporary_env_vars): ], } - serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt") + serialized_prompts = await aiconfig.serialize( + "gpt-3.5-turbo", completion_params, "prompt" + ) new_prompt = serialized_prompts[0] assert new_prompt == Prompt( name="prompt", @@ -212,7 +247,10 @@ async def test_serialize(set_temporary_env_vars): }, "unit": { "type": "string", - "enum": ["celsius", "fahrenheit"], + "enum": [ + "celsius", + "fahrenheit", + ], }, }, "required": ["location"], @@ -231,8 +269,14 @@ async def test_serialize(set_temporary_env_vars): "temperature": 0.7, "max_tokens": 900, "messages": [ - {"role": "system", "content": "You are an expert decision maker"}, - {"role": "user", "content": "What's the weather like in Boston today?"}, + { + "role": "system", + "content": "You are an expert decision maker", + }, + { + "role": "user", + "content": "What's the weather like in Boston today?", + }, { "role": "assistant", "content": None, @@ -273,7 +317,9 @@ async def test_serialize(set_temporary_env_vars): ], } - prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt") + prompts = await aiconfig.serialize( + "gpt-3.5-turbo", completion_params, "prompt" + ) new_prompt = prompts[1] expected_prompt = Prompt( diff --git a/python/tests/parsers/test_parser.py b/python/tests/parsers/test_parser.py index 32bed9709..b142a605e 100644 --- a/python/tests/parsers/test_parser.py +++ b/python/tests/parsers/test_parser.py @@ -1,7 +1,12 @@ import pytest from aiconfig.Config import AIConfigRuntime -from aiconfig.schema import ConfigMetadata, ModelMetadata, Prompt, PromptMetadata +from aiconfig.schema import ( + ConfigMetadata, + ModelMetadata, + Prompt, + PromptMetadata, +) from ..util.mock_parser import MockModelParser @@ -34,18 +39,24 @@ def test_get_model_settings(ai_config_runtime: AIConfigRuntime): prompt = ai_config_runtime.prompts[0] - assert mock_model_parser.get_model_settings(prompt, ai_config_runtime) == {} + assert ( + mock_model_parser.get_model_settings(prompt, ai_config_runtime) == {} + ) # settings is defined as {}. Should be returned as {} aiconfig = AIConfigRuntime( name="test", - metadata=ConfigMetadata(**{"models": {"fakemodel": {"fake_setting": "True"}}}), + metadata=ConfigMetadata( + **{"models": {"fakemodel": {"fake_setting": "True"}}} + ), # here is settings = None. This implies that settings were not passed in. Should default to global params prompts=[ Prompt( name="test", input="test", - metadata=PromptMetadata(model=ModelMetadata(name="test", settings={})), + metadata=PromptMetadata( + model=ModelMetadata(name="test", settings={}) + ), ) ], ) @@ -56,20 +67,26 @@ def test_get_model_settings(ai_config_runtime: AIConfigRuntime): # settings is defined as None. Should be returned as config level, ie {"fake_setting": "True"} aiconfig = AIConfigRuntime( name="test", - metadata=ConfigMetadata(**{"models": {"fakemodel": {"fake_setting": "True"}}}), + metadata=ConfigMetadata( + **{"models": {"fakemodel": {"fake_setting": "True"}}} + ), # here is settings = None. This implies that settings were not passed in. Should default to global params prompts=[ Prompt( name="test", input="test", - metadata=PromptMetadata(model=ModelMetadata(name="fakemodel", settings=None)), + metadata=PromptMetadata( + model=ModelMetadata(name="fakemodel", settings=None) + ), ) ], ) prompt = aiconfig.prompts[0] - assert mock_model_parser.get_model_settings(prompt, aiconfig) == {"fake_setting": "True"} + assert mock_model_parser.get_model_settings(prompt, aiconfig) == { + "fake_setting": "True" + } with pytest.raises(IndexError, match=r"Prompt '.*' not in config"): prompt = Prompt( diff --git a/python/tests/test_library_helpers.py b/python/tests/test_library_helpers.py index 0a5289715..f665a7e46 100644 --- a/python/tests/test_library_helpers.py +++ b/python/tests/test_library_helpers.py @@ -38,7 +38,9 @@ def test_collect_prompt_references(): # input is an aiconfig with a 4 prompts. Test collects prompt references for the 3rd prompt # collect_prompt_references should return references to 1 and 2. 3 is the prompt we are collecting references for, 4 is after. Both are expected to be skipped config_relative_path = "aiconfigs/GPT4 Coding Assistant_aiconfig.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) aiconfig = AIConfigRuntime.load(config_absolute_path) prompt3 = aiconfig.prompts[2] diff --git a/python/tests/test_load_config.py b/python/tests/test_load_config.py index 6d2af5f7c..7a4eec522 100644 --- a/python/tests/test_load_config.py +++ b/python/tests/test_load_config.py @@ -12,7 +12,9 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): """Test loading a basic chatgpt query config""" config_relative_path = "aiconfigs/basic_chatgpt_query_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("prompt1") @@ -22,7 +24,12 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): "top_p": 1, "temperature": 1, "stream": False, - "messages": [{"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"}], + "messages": [ + { + "content": "Hi! Tell me 10 cool things to do in NYC.", + "role": "user", + } + ], } @@ -30,7 +37,9 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): async def test_load_basic_dalle2_config(set_temporary_env_vars): """Test loading a basic Dall-E 2 config""" config_relative_path = "aiconfigs/basic_dalle2_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("panda_eating_dumplings") @@ -47,7 +56,9 @@ async def test_load_basic_dalle2_config(set_temporary_env_vars): async def test_load_basic_dalle3_config(set_temporary_env_vars): """Test loading a basic Dall-E 3 config""" config_relative_path = "aiconfigs/basic_dalle3_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("panda_eating_dumplings") @@ -64,7 +75,9 @@ async def test_load_basic_dalle3_config(set_temporary_env_vars): async def test_chained_gpt_config(set_temporary_env_vars): """Test loading a chained gpt config and resolving it, with chat context enabled""" config_relative_path = "aiconfigs/chained_gpt_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) data_for_inference1 = await config.resolve("prompt1") @@ -109,10 +122,14 @@ async def test_resolve_system_prompt(): Resolves a system prompt with a provided parameter """ config_relative_path = "aiconfigs/system_prompt_parameters_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) - data_for_inference = await config.resolve("prompt1", {"system": "skip odd numbers"}) + data_for_inference = await config.resolve( + "prompt1", {"system": "skip odd numbers"} + ) assert data_for_inference == { "temperature": 1, "model": "gpt-3.5-turbo", diff --git a/python/tests/test_parameter_api.py b/python/tests/test_parameter_api.py index 2f06153c8..257709b34 100644 --- a/python/tests/test_parameter_api.py +++ b/python/tests/test_parameter_api.py @@ -37,7 +37,10 @@ def test_delete_nonexistent_parameter(ai_config_runtime: AIConfigRuntime): ) # Ensure deleting a nonexistent parameter raises a KeyError - with pytest.raises(KeyError, match=f"Parameter '{parameter_name_to_delete}' does not exist."): + with pytest.raises( + KeyError, + match=f"Parameter '{parameter_name_to_delete}' does not exist.", + ): config.delete_parameter(parameter_name_to_delete) @@ -142,7 +145,9 @@ def test_set_parameter_for_prompt_no_metadata(ai_config: AIConfig): assert prompt.metadata is None prompt_parameter_name = "prompt_param" prompt_parameter_value = "prompt_value" - ai_config.set_parameter(prompt_parameter_name, prompt_parameter_value, prompt_name) + ai_config.set_parameter( + prompt_parameter_name, prompt_parameter_value, prompt_name + ) # Ensure the prompt parameter is set correctly assert prompt.metadata is not None @@ -179,7 +184,9 @@ def test_set_parameter_for_prompt_no_parameters(ai_config: AIConfig): assert prompt.metadata.parameters == {} prompt_parameter_name = "prompt_param" prompt_parameter_value = "prompt_value" - ai_config.set_parameter(prompt_parameter_name, prompt_parameter_value, prompt_name) + ai_config.set_parameter( + prompt_parameter_name, prompt_parameter_value, prompt_name + ) # Ensure the prompt parameter is set correctly assert prompt.metadata is not None @@ -222,7 +229,9 @@ def test_set_parameter_for_prompt_has_parameters(ai_config: AIConfig): ) prompt_parameter_value = "prompt_value" - ai_config.set_parameter(prompt_parameter_name, prompt_parameter_value, prompt_name) + ai_config.set_parameter( + prompt_parameter_name, prompt_parameter_value, prompt_name + ) # Ensure the prompt parameter is set correctly assert prompt.metadata is not None @@ -258,7 +267,9 @@ def test_delete_existing_parameter(ai_config: AIConfig): parameter_name_to_delete = "param_to_delete" parameter_value = "param_value" - ai_config.set_parameter(parameter_name_to_delete, parameter_value, prompt_name=None) + ai_config.set_parameter( + parameter_name_to_delete, parameter_value, prompt_name=None + ) ai_config.delete_parameter(parameter_name_to_delete, prompt_name=None) assert ai_config.metadata.parameters is not None diff --git a/python/tests/test_programmatically_create_an_AIConfig.py b/python/tests/test_programmatically_create_an_AIConfig.py index 245bba3cd..39188315a 100644 --- a/python/tests/test_programmatically_create_an_AIConfig.py +++ b/python/tests/test_programmatically_create_an_AIConfig.py @@ -2,7 +2,14 @@ from aiconfig.Config import AIConfigRuntime from aiconfig.util.config_utils import extract_override_settings -from aiconfig.schema import AIConfig, ConfigMetadata, ExecuteResult, ModelMetadata, Prompt, PromptMetadata +from aiconfig.schema import ( + AIConfig, + ConfigMetadata, + ExecuteResult, + ModelMetadata, + Prompt, + PromptMetadata, +) @pytest.fixture @@ -61,7 +68,9 @@ def test_delete_nonexistent_model(ai_config: AIConfig): non_existent_model = "non_existent_model" # Ensure trying to delete a non-existent model raises an exception - with pytest.raises(Exception, match=f"Model '{non_existent_model}' does not exist."): + with pytest.raises( + Exception, match=f"Model '{non_existent_model}' does not exist." + ): ai_config.delete_model(non_existent_model) @@ -156,7 +165,9 @@ def test_create_config_with_schema_version(): """ Test creating an AIConfig with a specific schema version. """ - config_runtime = AIConfigRuntime.create("AIConfig with Schema", schema_version="v1") + config_runtime = AIConfigRuntime.create( + "AIConfig with Schema", schema_version="v1" + ) config = config_runtime @@ -183,7 +194,9 @@ def test_add_prompt_with_duplicate_name(ai_config_runtime: AIConfigRuntime): config.add_prompt("prompt1", prompt_data1) # Ensure adding a prompt with a duplicate name raises an exception - with pytest.raises(Exception, match=r"Prompt with name prompt1 already exists."): + with pytest.raises( + Exception, match=r"Prompt with name prompt1 already exists." + ): config.add_prompt("prompt1", prompt_data2) @@ -201,7 +214,10 @@ def test_update_nonexistent_prompt(ai_config_runtime: AIConfigRuntime): ) # Ensure updating a nonexistent prompt raises an exception - with pytest.raises(IndexError, match=f"Prompt '{nonexistent_prompt_name}' not found in config"): + with pytest.raises( + IndexError, + match=f"Prompt '{nonexistent_prompt_name}' not found in config", + ): config.update_prompt(nonexistent_prompt_name, prompt_data) @@ -213,11 +229,15 @@ def test_delete_nonexistent_prompt(ai_config_runtime: AIConfigRuntime): prompt_name = "nonexistent_prompt" # Ensure deleting a nonexistent prompt raises an exception - with pytest.raises(IndexError, match=f"Prompt '{prompt_name}' not found in config"): + with pytest.raises( + IndexError, match=f"Prompt '{prompt_name}' not found in config" + ): config.delete_prompt(prompt_name) -def test_get_metadata_with_nonexistent_prompt(ai_config_runtime: AIConfigRuntime): +def test_get_metadata_with_nonexistent_prompt( + ai_config_runtime: AIConfigRuntime, +): """ Test the retrieval of metadata from a non-existent prompt. """ @@ -225,7 +245,9 @@ def test_get_metadata_with_nonexistent_prompt(ai_config_runtime: AIConfigRuntime prompt_name = "nonexistent_prompt" # Ensure that attempting to retrieve metadata for a non-existent prompt raises an exception - with pytest.raises(IndexError, match=f"Prompt '{prompt_name}' not found in config"): + with pytest.raises( + IndexError, match=f"Prompt '{prompt_name}' not found in config" + ): config.get_metadata(prompt_name) @@ -247,7 +269,9 @@ def test_load_saved_config(tmp_path): config_runtime = AIConfigRuntime.create("My AIConfig") # Create a configuration-level parameter - config_runtime.set_parameter("config_param", "config_value", prompt_name=None) + config_runtime.set_parameter( + "config_param", "config_value", prompt_name=None + ) # Create a sample prompt for testing prompt_data = Prompt( @@ -258,7 +282,9 @@ def test_load_saved_config(tmp_path): config_runtime.add_prompt("prompt1", prompt_data) # Set a prompt-level parameter - config_runtime.set_parameter("prompt_param", "prompt_value", prompt_name="prompt1") + config_runtime.set_parameter( + "prompt_param", "prompt_value", prompt_name="prompt1" + ) json_config_filepath = tmp_path / "my_aiconfig.json" config_runtime.save(json_config_filepath) @@ -267,10 +293,14 @@ def test_load_saved_config(tmp_path): # Ensure the loaded AIConfig contains the expected data assert loaded_config.name == "My AIConfig" - assert loaded_config.metadata.parameters == {"config_param": "config_value"} + assert loaded_config.metadata.parameters == { + "config_param": "config_value" + } assert "prompt1" in loaded_config.prompt_index assert loaded_config.prompt_index["prompt1"].metadata is not None - assert loaded_config.prompt_index["prompt1"].metadata.parameters == {"prompt_param": "prompt_value"} + assert loaded_config.prompt_index["prompt1"].metadata.parameters == { + "prompt_param": "prompt_value" + } def test_set_config_name(ai_config_runtime: AIConfigRuntime): @@ -295,7 +325,9 @@ def test_get_prompt_existing(ai_config_runtime: AIConfigRuntime): assert retrieved_prompt == prompt -def test_get_prompt_after_deleting_previous(ai_config_runtime: AIConfigRuntime): +def test_get_prompt_after_deleting_previous( + ai_config_runtime: AIConfigRuntime, +): prompt1 = Prompt( name="GreetingPrompt", input="Hello, how are you?", @@ -314,7 +346,9 @@ def test_get_prompt_after_deleting_previous(ai_config_runtime: AIConfigRuntime): def test_get_prompt_nonexistent(ai_config_runtime: AIConfigRuntime): - with pytest.raises(IndexError, match=r"Prompt 'GreetingPrompt' not found in config"): + with pytest.raises( + IndexError, match=r"Prompt 'GreetingPrompt' not found in config" + ): ai_config_runtime.get_prompt("GreetingPrompt") @@ -324,27 +358,35 @@ def test_update_model_for_ai_config(ai_config_runtime: AIConfigRuntime): model_name = "testmodel" settings = {"topP": 0.9} ai_config_runtime.update_model(model_name, settings) - pytest.warns(match=f"No prompt name was given to update the model name to '{model_name}'.") + pytest.warns( + match=f"No prompt name was given to update the model name to '{model_name}'." + ) assert ai_config_runtime.metadata.models is not None assert ai_config_runtime.metadata.models[model_name] == settings # Existing model name, no settings --> no-op ai_config_runtime.update_model(model_name) - pytest.warns(match=f"No prompt name was given to update the model name to '{model_name}'.") + pytest.warns( + match=f"No prompt name was given to update the model name to '{model_name}'." + ) assert ai_config_runtime.metadata.models is not None assert ai_config_runtime.metadata.models[model_name] == settings # Existing model name, new settings --> update new_settings = {"topP": 0.75} ai_config_runtime.update_model(model_name, new_settings) - pytest.warns(match=f"No prompt name was given to update the model name to '{model_name}'.") + pytest.warns( + match=f"No prompt name was given to update the model name to '{model_name}'." + ) assert ai_config_runtime.metadata.models is not None assert ai_config_runtime.metadata.models[model_name] == new_settings # New model name, no settings --> update new_model_name = "testmodel_without_settings" ai_config_runtime.update_model(new_model_name) - pytest.warns(match=f"No prompt name was given to update the model name to '{new_model_name}'.") + pytest.warns( + match=f"No prompt name was given to update the model name to '{new_model_name}'." + ) assert ai_config_runtime.metadata.models is not None assert ai_config_runtime.metadata.models[new_model_name] == {} @@ -359,21 +401,27 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): ai_config_runtime.update_model(model_name, settings, prompt1.name) prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None - assert prompt.metadata.model == ModelMetadata(name=model_name, settings=settings) + assert prompt.metadata.model == ModelMetadata( + name=model_name, settings=settings + ) # New model name, no settings --> update name only new_model_name = "testmodel_new_name" ai_config_runtime.update_model(new_model_name, None, prompt1.name) prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None - assert prompt.metadata.model == ModelMetadata(name=new_model_name, settings=settings) + assert prompt.metadata.model == ModelMetadata( + name=new_model_name, settings=settings + ) # Same model name, new settings --> update settings only settings = {"topP": 0.9} ai_config_runtime.update_model(new_model_name, settings, prompt1.name) prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None - assert prompt.metadata.model == ModelMetadata(name=new_model_name, settings=settings) + assert prompt.metadata.model == ModelMetadata( + name=new_model_name, settings=settings + ) # New name, no settings, prompt with model name as string --> update prompt2 = Prompt( @@ -386,7 +434,9 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): ai_config_runtime.update_model(new_name_again, None, prompt2.name) prompt = ai_config_runtime.get_prompt(prompt2.name) assert prompt.metadata is not None - assert prompt.metadata.model == ModelMetadata(name=new_name_again, settings={}) + assert prompt.metadata.model == ModelMetadata( + name=new_name_again, settings={} + ) # New name, no settings, prompt with metadata but no model --> update tags = ["my_fancy_tags"] @@ -404,7 +454,9 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): assert prompt.metadata.tags == tags -def test_update_model_with_invalid_arguments(ai_config_runtime: AIConfigRuntime): +def test_update_model_with_invalid_arguments( + ai_config_runtime: AIConfigRuntime, +): """Test trying to update model with invalid arguments.""" with pytest.raises( ValueError, @@ -416,7 +468,9 @@ def test_update_model_with_invalid_arguments(ai_config_runtime: AIConfigRuntime) ValueError, match=r"Cannot update model. There are two things you are trying:", ): - ai_config_runtime.update_model(model_name=None, settings={"top": 0.9}, prompt_name=None) + ai_config_runtime.update_model( + model_name=None, settings={"top": 0.9}, prompt_name=None + ) def test_set_and_delete_metadata_ai_config(ai_config_runtime: AIConfigRuntime): @@ -430,7 +484,9 @@ def test_set_and_delete_metadata_ai_config(ai_config_runtime: AIConfigRuntime): assert hasattr(ai_config_runtime.get_metadata(), "testkey") is False -def test_set_and_delete_metadata_ai_config_prompt(ai_config_runtime: AIConfigRuntime): +def test_set_and_delete_metadata_ai_config_prompt( + ai_config_runtime: AIConfigRuntime, +): """Test deleting a non-existent metadata key at the AIConfig level.""" prompt = Prompt( name="GreetingPrompt", @@ -440,14 +496,24 @@ def test_set_and_delete_metadata_ai_config_prompt(ai_config_runtime: AIConfigRun ai_config_runtime.add_prompt(prompt.name, prompt) ai_config_runtime.set_metadata("testkey", "testvalue", "GreetingPrompt") - assert ai_config_runtime.get_prompt("GreetingPrompt").metadata.testkey == "testvalue" + assert ( + ai_config_runtime.get_prompt("GreetingPrompt").metadata.testkey + == "testvalue" + ) ai_config_runtime.delete_metadata("testkey", "GreetingPrompt") - assert hasattr(ai_config_runtime.get_prompt("GreetingPrompt").metadata, "testkey") is False + assert ( + hasattr( + ai_config_runtime.get_prompt("GreetingPrompt").metadata, "testkey" + ) + is False + ) -def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRuntime): +def test_add_output_existing_prompt_no_overwrite( + ai_config_runtime: AIConfigRuntime, +): """Test adding an output to an existing prompt without overwriting.""" prompt = Prompt( name="GreetingPrompt", @@ -459,7 +525,9 @@ def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRunt output_type="execute_result", execution_count=0, data="test output", - metadata={"raw_response": {"role": "assistant", "content": "test output"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "test output"} + }, ) ai_config_runtime.add_output("GreetingPrompt", test_result) @@ -469,23 +537,34 @@ def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRunt output_type="execute_result", execution_count=0, data="test output", - metadata={"raw_response": {"role": "assistant", "content": "test output for second time"}}, + metadata={ + "raw_response": { + "role": "assistant", + "content": "test output for second time", + } + }, ) ai_config_runtime.add_output("GreetingPrompt", test_result2) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + ) ai_config_runtime.delete_output("GreetingPrompt") assert ai_config_runtime.get_latest_output("GreetingPrompt") == None -def test_add_outputs_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRuntime): +def test_add_outputs_existing_prompt_no_overwrite( + ai_config_runtime: AIConfigRuntime, +): """Test adding outputs to an existing prompt without overwriting.""" original_result = ExecuteResult( output_type="execute_result", execution_count=0, data="original result", - metadata={"raw_response": {"role": "assistant", "content": "original result"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "original result"} + }, ) prompt = Prompt( name="GreetingPrompt", @@ -495,33 +574,48 @@ def test_add_outputs_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRun ) ai_config_runtime.add_prompt(prompt.name, prompt) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_result + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") + == original_result + ) test_result1 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 1", - metadata={"raw_response": {"role": "assistant", "content": "test output 1"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "test output 1"} + }, ) test_result2 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 2", - metadata={"raw_response": {"role": "assistant", "content": "test output 2"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "test output 2"} + }, + ) + ai_config_runtime.add_outputs( + "GreetingPrompt", [test_result1, test_result2] ) - ai_config_runtime.add_outputs("GreetingPrompt", [test_result1, test_result2]) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + ) assert prompt.outputs == [original_result, test_result1, test_result2] -def test_add_outputs_existing_prompt_with_overwrite(ai_config_runtime: AIConfigRuntime): +def test_add_outputs_existing_prompt_with_overwrite( + ai_config_runtime: AIConfigRuntime, +): """Test adding outputs to an existing prompt with overwriting.""" original_result = ExecuteResult( output_type="execute_result", execution_count=0, data="original result", - metadata={"raw_response": {"role": "assistant", "content": "original result"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "original result"} + }, ) prompt = Prompt( name="GreetingPrompt", @@ -531,23 +625,34 @@ def test_add_outputs_existing_prompt_with_overwrite(ai_config_runtime: AIConfigR ) ai_config_runtime.add_prompt(prompt.name, prompt) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_result + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") + == original_result + ) test_result1 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 1", - metadata={"raw_response": {"role": "assistant", "content": "test output 1"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "test output 1"} + }, ) test_result2 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 2", - metadata={"raw_response": {"role": "assistant", "content": "test output 2"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "test output 2"} + }, + ) + ai_config_runtime.add_outputs( + "GreetingPrompt", [test_result1, test_result2], True ) - ai_config_runtime.add_outputs("GreetingPrompt", [test_result1, test_result2], True) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 + ) assert prompt.outputs == [test_result1, test_result2] @@ -574,13 +679,17 @@ def test_add_undefined_outputs_to_prompt(ai_config_runtime: AIConfigRuntime): ai_config_runtime.add_outputs("GreetingPrompt", [], True) -def test_add_output_existing_prompt_overwrite(ai_config_runtime: AIConfigRuntime): +def test_add_output_existing_prompt_overwrite( + ai_config_runtime: AIConfigRuntime, +): """Test adding an output to an existing prompt with overwriting.""" original_output = ExecuteResult( output_type="execute_result", execution_count=0, data="original output", - metadata={"raw_response": {"role": "assistant", "content": "original output"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "original output"} + }, ) prompt = Prompt( name="GreetingPrompt", @@ -590,16 +699,24 @@ def test_add_output_existing_prompt_overwrite(ai_config_runtime: AIConfigRuntime ) ai_config_runtime.add_prompt(prompt.name, prompt) # check that the original_output is there - assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_output + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") + == original_output + ) expected_output = ExecuteResult( output_type="execute_result", execution_count=0, data="original output", - metadata={"raw_response": {"role": "assistant", "content": "original output"}}, + metadata={ + "raw_response": {"role": "assistant", "content": "original output"} + }, ) # overwrite the original_output ai_config_runtime.add_output("GreetingPrompt", expected_output, True) - assert ai_config_runtime.get_latest_output("GreetingPrompt") == expected_output + assert ( + ai_config_runtime.get_latest_output("GreetingPrompt") + == expected_output + ) def test_add_undefined_output_to_prompt(ai_config_runtime: AIConfigRuntime): @@ -628,17 +745,25 @@ def test_extract_override_settings(ai_config_runtime: AIConfigRuntime): initial_settings = {"topP": 0.9} # Test Case 1: No global setting, Expect an override - override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") + override = extract_override_settings( + ai_config_runtime, initial_settings, "testmodel" + ) assert override == {"topP": 0.9} # Test Case 2: Global Settings differ, expect override ai_config_runtime.add_model("testmodel", {"topP": 0.8}) - override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") + override = extract_override_settings( + ai_config_runtime, initial_settings, "testmodel" + ) assert override == {"topP": 0.9} # Test Case 3: Global Settings match settings, expect no override - ai_config_runtime.update_model(model_name="testmodel", settings={"topP": 0.9}) - override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") + ai_config_runtime.update_model( + model_name="testmodel", settings={"topP": 0.9} + ) + override = extract_override_settings( + ai_config_runtime, initial_settings, "testmodel" + ) assert override == {} # Test Case 4: Global settings defined and empty settings defined. Expect no override diff --git a/python/tests/test_registry.py b/python/tests/test_registry.py index 2b84149dd..6fc65e115 100644 --- a/python/tests/test_registry.py +++ b/python/tests/test_registry.py @@ -24,7 +24,9 @@ class TestModelParserRegistry: @classmethod def setup_class(cls): # Store the original value of static_variable - cls.original_static_variable = copy.deepcopy(ModelParserRegistry._parsers) + cls.original_static_variable = copy.deepcopy( + ModelParserRegistry._parsers + ) @classmethod def teardown_class(cls): @@ -33,7 +35,9 @@ def teardown_class(cls): def test_register_multiple_ids_to_one_parser(self): mock_model_parser = MockModelParser() - ModelParserRegistry.register_model_parser(mock_model_parser, ["id1", "id2"]) + ModelParserRegistry.register_model_parser( + mock_model_parser, ["id1", "id2"] + ) assert ModelParserRegistry.get_model_parser("id1") == mock_model_parser assert ModelParserRegistry.get_model_parser("id2") == mock_model_parser @@ -42,7 +46,10 @@ def test_register_single_model_parser(self): mock_model_parser = MockModelParser() ModelParserRegistry.register_model_parser(mock_model_parser) - assert ModelParserRegistry.get_model_parser(mock_model_parser.id()) == mock_model_parser + assert ( + ModelParserRegistry.get_model_parser(mock_model_parser.id()) + == mock_model_parser + ) def test_register_multiple_model_parsers_with_different_ids(self): # Create model parsers @@ -50,19 +57,29 @@ def test_register_multiple_model_parsers_with_different_ids(self): model_parser_2 = MockModelParser() # Register the model parsers with different IDs - ModelParserRegistry.register_model_parser(model_parser_1, ids=["model-4"]) - ModelParserRegistry.register_model_parser(model_parser_2, ids=["model-5"]) + ModelParserRegistry.register_model_parser( + model_parser_1, ids=["model-4"] + ) + ModelParserRegistry.register_model_parser( + model_parser_2, ids=["model-5"] + ) # Assert that each model parser is registered under its respective ID - assert ModelParserRegistry.get_model_parser("model-4") == model_parser_1 - assert ModelParserRegistry.get_model_parser("model-5") == model_parser_2 + assert ( + ModelParserRegistry.get_model_parser("model-4") == model_parser_1 + ) + assert ( + ModelParserRegistry.get_model_parser("model-5") == model_parser_2 + ) def test_retrieve_model_parser(self): # Create a model parser model_parser = MockModelParser() # Register the model parser - ModelParserRegistry.register_model_parser(model_parser, ids=["model-6"]) + ModelParserRegistry.register_model_parser( + model_parser, ids=["model-6"] + ) # Retrieve the model parser using its ID retrieved_parser = ModelParserRegistry.get_model_parser("model-6") @@ -75,12 +92,16 @@ def test_retrieve_nonexistent_model_parser(self): with pytest.raises(KeyError): ModelParserRegistry.get_model_parser("nonexistent-model") - def test_retrieve_model_parser_for_prompt(self, ai_config_runtime: AIConfigRuntime): + def test_retrieve_model_parser_for_prompt( + self, ai_config_runtime: AIConfigRuntime + ): # Create a model parser model_parser = MockModelParser() # Register the model parser with a specific model name - ModelParserRegistry.register_model_parser(model_parser, ids=["model-7"]) + ModelParserRegistry.register_model_parser( + model_parser, ids=["model-7"] + ) # Create a Prompt object with the registered model name prompt = Prompt( @@ -93,12 +114,16 @@ def test_retrieve_model_parser_for_prompt(self, ai_config_runtime: AIConfigRunti ai_config_runtime.add_prompt(prompt.name, prompt) # Retrieve the model parser for the Prompt - retrieved_parser = ModelParserRegistry.get_model_parser_for_prompt(prompt, ai_config_runtime) + retrieved_parser = ModelParserRegistry.get_model_parser_for_prompt( + prompt, ai_config_runtime + ) # Assert that the retrieved model parser is the same as the registered one assert retrieved_parser == model_parser - def test_retrieve_model_parser_for_prompt_with_nonexistent_model(self, ai_config_runtime: AIConfigRuntime): + def test_retrieve_model_parser_for_prompt_with_nonexistent_model( + self, ai_config_runtime: AIConfigRuntime + ): # Create a Prompt object with a model name that is not registered prompt = Prompt( **{ @@ -111,14 +136,18 @@ def test_retrieve_model_parser_for_prompt_with_nonexistent_model(self, ai_config # Attempt to retrieve a model parser for the Prompt with pytest.raises(KeyError): - ModelParserRegistry.get_model_parser_for_prompt(prompt, ai_config_runtime).id() + ModelParserRegistry.get_model_parser_for_prompt( + prompt, ai_config_runtime + ).id() def test_remove_model_parser(self): # Create a model parser model_parser = MockModelParser() # Register the model parser - ModelParserRegistry.register_model_parser(model_parser, ids=["model-8"]) + ModelParserRegistry.register_model_parser( + model_parser, ids=["model-8"] + ) assert ModelParserRegistry.get_model_parser("model-8") == model_parser # Remove the registered model parser @@ -134,8 +163,12 @@ def test_clear_registry(self): model_parser_2 = MockModelParser() # Register the model parsers - ModelParserRegistry.register_model_parser(model_parser_1, ids=["model-9"]) - ModelParserRegistry.register_model_parser(model_parser_2, ids=["model-10"]) + ModelParserRegistry.register_model_parser( + model_parser_1, ids=["model-9"] + ) + ModelParserRegistry.register_model_parser( + model_parser_2, ids=["model-10"] + ) # Clear the registry ModelParserRegistry.clear_registry() diff --git a/python/tests/test_resolve.py b/python/tests/test_resolve.py index 555643131..04a312f5a 100644 --- a/python/tests/test_resolve.py +++ b/python/tests/test_resolve.py @@ -11,12 +11,19 @@ async def test_resolve_default_model_config_with_openai_parser(): Test that the default model config is resolved correctly. `basic_default_model_aiconfig.json` is an aiconfig with 1 prompt that has no settings or model defined besides the default. """ config_relative_path = "aiconfigs/basic_default_model_aiconfig.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) resolved_params = await config.resolve("prompt1") assert resolved_params == { - "messages": [{"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"}], + "messages": [ + { + "content": "Hi! Tell me 10 cool things to do in NYC.", + "role": "user", + } + ], "model": "gpt-3.5-turbo", "temperature": 1, "top_p": 1, diff --git a/python/tests/test_run_config.py b/python/tests/test_run_config.py index 99dd4d7c3..886f5c2b3 100644 --- a/python/tests/test_run_config.py +++ b/python/tests/test_run_config.py @@ -13,9 +13,15 @@ async def test_load_parametrized_data_config(set_temporary_env_vars): Config has 2 prompts. Prompt2 uses prompt1.output in its input. """ - with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion): + with patch.object( + openai.chat.completions, + "create", + side_effect=mock_openai_chat_completion, + ): config_relative_path = "aiconfigs/parametrized_data_config.json" - config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) + config_absolute_path = get_absolute_file_path_from_relative( + __file__, config_relative_path + ) config = AIConfigRuntime.load(config_absolute_path) prompt1_params = { diff --git a/python/tests/test_eval.py b/python/tests/test_test_suite_eval.py similarity index 67% rename from python/tests/test_eval.py rename to python/tests/test_test_suite_eval.py index 34f3087ce..8832177ca 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_test_suite_eval.py @@ -3,30 +3,48 @@ import logging import os from typing import Any -from frozendict import frozendict import hypothesis import hypothesis.strategies as st import lastmile_utils.lib.core.api as core_utils import pandas as pd import pytest -from aiconfig.eval.api import TestSuiteWithInputsSettings, metrics, run_test_suite_outputs_only, run_test_suite_with_inputs -from aiconfig.eval.lib import MetricList, TestSuiteGeneralSettings, TestSuiteWithInputsSpec, run_test_suite_helper, text_eval_res_to_df +from aiconfig.eval.api import ( + TestSuiteWithInputsSettings, + run_test_suite_outputs_only, + run_test_suite_with_inputs, + test_suite_metrics, +) +from aiconfig.eval.test_suite_lib import ( + MetricList, + TestSuiteGeneralSettings, + TestSuiteWithInputsSpec, + run_test_suite_helper, + text_eval_res_to_df, +) +from frozendict import frozendict from result import Err, Ok from . import mocks -brevity = metrics.brevity -substring_match = metrics.substring_match +brevity = test_suite_metrics.brevity +substring_match = test_suite_metrics.substring_match MOCK_NLTK_SENTIMENT_SCORE_MAPPING = { "nltk is amazing": {"pos": 0.9, "neu": 0.1, "neg": 0.0, "compound": 0.9}, - "whats for dinner?": {"pos": 0.0, "neu": 0.9, "neg": 0.1, "compound": -0.9}, + "whats for dinner?": { + "pos": 0.0, + "neu": 0.9, + "neg": 0.1, + "compound": -0.9, + }, "oh, bother": {"pos": 0.0, "neu": 0.1, "neg": 0.9, "compound": -0.9}, } -def _compute_mock_sentiment_class_mapping(score_mapping: dict[str, dict[str, float]]) -> dict[str, str]: +def _compute_mock_sentiment_class_mapping( + score_mapping: dict[str, dict[str, float]] +) -> dict[str, str]: out: dict[str, str] = {} for k, scores in score_mapping.items(): max_class, max_score = "", float("-inf") @@ -39,7 +57,9 @@ def _compute_mock_sentiment_class_mapping(score_mapping: dict[str, dict[str, flo return out -MOCK_NLTK_SENTIMENT_CLASS_MAPPING = _compute_mock_sentiment_class_mapping(MOCK_NLTK_SENTIMENT_SCORE_MAPPING) +MOCK_NLTK_SENTIMENT_CLASS_MAPPING = _compute_mock_sentiment_class_mapping( + MOCK_NLTK_SENTIMENT_SCORE_MAPPING +) def set_pd(): @@ -60,19 +80,26 @@ async def test_metrics(): assert await brevity("hello") == 5.0 assert await substring_match("lo w")("hello world") == 1.0 - assert await substring_match("hello", case_sensitive=False)("HELLO world") == 1.0 - assert await substring_match("hello", case_sensitive=True)("HELLO world") == 0.0 + assert ( + await substring_match("hello", case_sensitive=False)("HELLO world") + == 1.0 + ) + assert ( + await substring_match("hello", case_sensitive=True)("HELLO world") + == 0.0 + ) @pytest.mark.asyncio async def test_run_with_inputs_sanity_check(): """No easy way to mock LLM calls from outside run_test_suite_with_inputs. - Instead, give empty list and just test the imports and sanity check output.""" + Instead, give empty list and just test the imports and sanity check output. + """ path = os.path.join( current_dir(), - "../src/aiconfig/eval/examples/travel/travel_parametrized.aiconfig.json", + "../src/aiconfig/eval/test_suite_examples/travel/travel_parametrized.aiconfig.json", ) out = await run_test_suite_with_inputs( [], @@ -152,7 +179,10 @@ async def test_run_test_suite_with_inputs(data: st.DataObject): out = await run_test_suite_helper( TestSuiteWithInputsSpec( - test_suite=user_test_suite_with_inputs, prompt_name="prompt0", aiconfig=mock_aiconfig, general_settings=TestSuiteGeneralSettings() + test_suite=user_test_suite_with_inputs, + prompt_name="prompt0", + aiconfig=mock_aiconfig, + general_settings=TestSuiteGeneralSettings(), ) ) @@ -173,7 +203,10 @@ async def test_run_test_suite_with_inputs(data: st.DataObject): "worst_possible_value", ] - input_pairs = {(input_datum, metric.metric_metadata.id) for input_datum, metric in user_test_suite_with_inputs} + input_pairs = { + (input_datum, metric.metric_metadata.id) + for input_datum, metric in user_test_suite_with_inputs + } result_pairs = set( # type: ignore[no-untyped-call] df[["input", "metric_id"]].itertuples(index=False, name=None) # type: ignore[no-untyped-call] ) @@ -202,7 +235,9 @@ async def test_run_test_suite_with_inputs_general_params(data: st.DataObject): Also see test_run_with_inputs_sanity_check. """ metrics_list = [brevity, substring_match("hello")] - inputs = st.dictionaries(st.text(min_size=1), st.text(min_size=1), min_size=0, max_size=2) + inputs = st.dictionaries( + st.text(min_size=1), st.text(min_size=1), min_size=0, max_size=2 + ) test_pairs = st.tuples(inputs, st.sampled_from(metrics_list)) user_test_suite_with_inputs = data.draw( st.lists( @@ -211,14 +246,21 @@ async def test_run_test_suite_with_inputs_general_params(data: st.DataObject): ) ) - async def mock_run_text_to_text(prompt_name: str, params: dict[str, str]) -> str: - return f"{prompt_name}_output." + ",".join(f"{key=};{value=}" for key, value in params.items()) + async def mock_run_text_to_text( + prompt_name: str, params: dict[str, str] + ) -> str: + return f"{prompt_name}_output." + ",".join( + f"{key=};{value=}" for key, value in params.items() + ) mock_aiconfig = mocks.make_mock_aiconfig_runtime(mock_run_text_to_text) out = await run_test_suite_helper( TestSuiteWithInputsSpec( - test_suite=user_test_suite_with_inputs, prompt_name="prompt0", aiconfig=mock_aiconfig, general_settings=TestSuiteGeneralSettings() + test_suite=user_test_suite_with_inputs, + prompt_name="prompt0", + aiconfig=mock_aiconfig, + general_settings=TestSuiteGeneralSettings(), ) ) @@ -251,7 +293,9 @@ async def mock_run_text_to_text(prompt_name: str, params: dict[str, str]) -> str df[["input", "metric_id"]].itertuples(index=False, name=None) # type: ignore[no-untyped-call] ) - assert input_pairs == result_pairs, f"fail: {input_pairs=}, {result_pairs=}" + assert ( + input_pairs == result_pairs + ), f"fail: {input_pairs=}, {result_pairs=}" df_brevity = df[df["metric_name"] == "brevity"] # type: ignore assert (df_brevity["aiconfig_output"].apply(len) == df_brevity["value"]).all() # type: ignore @@ -267,30 +311,42 @@ def _make_mock_nltk_metrics() -> MetricList[str]: def _mock_get_nltk_polarity_scores(text: str) -> dict[str, float]: return MOCK_NLTK_SENTIMENT_SCORE_MAPPING[text] - mock_nltk_sentiment_scores_vader = metrics.make_sentiment_scores_metric( - get_polarity_scores=_mock_get_nltk_polarity_scores, - make_evaluation_fn=metrics.make_get_sentiment_scores, - name="nltk_sentiment_scores_vader", - description="NLTK sentiment scores using Vader", + mock_nltk_sentiment_scores_vader = ( + test_suite_metrics.make_sentiment_scores_metric( + get_polarity_scores=_mock_get_nltk_polarity_scores, + make_evaluation_fn=test_suite_metrics.make_get_sentiment_scores, + name="nltk_sentiment_scores_vader", + description="NLTK sentiment scores using Vader", + ) ) - mock_nltk_sentiment_class_vader = metrics.make_sentiment_scores_metric( - get_polarity_scores=_mock_get_nltk_polarity_scores, - make_evaluation_fn=metrics.make_get_sentiment_class, - name="nltk_sentiment_class_vader", - description="Highest-probability NLTK sentiment class using Vader", + mock_nltk_sentiment_class_vader = ( + test_suite_metrics.make_sentiment_scores_metric( + get_polarity_scores=_mock_get_nltk_polarity_scores, + make_evaluation_fn=test_suite_metrics.make_get_sentiment_class, + name="nltk_sentiment_class_vader", + description="Highest-probability NLTK sentiment class using Vader", + ) ) - mock_nltk_sentiment_score_overall_positive = metrics.make_sentiment_scores_metric( + mock_nltk_sentiment_score_overall_positive = test_suite_metrics.make_sentiment_scores_metric( get_polarity_scores=_mock_get_nltk_polarity_scores, - make_evaluation_fn=metrics.make_get_overall_positive_sentiment, + make_evaluation_fn=test_suite_metrics.make_get_overall_positive_sentiment, name="nltk_sentiment_score_overall_positive", description="Positive minus negative", - best_value=metrics.TextOverallPositiveSentiment(pos=1.0, neg=0.0), - worst_value=metrics.TextOverallPositiveSentiment(pos=0.0, neg=1.0), + best_value=test_suite_metrics.TextOverallPositiveSentiment( + pos=1.0, neg=0.0 + ), + worst_value=test_suite_metrics.TextOverallPositiveSentiment( + pos=0.0, neg=1.0 + ), ) - return [mock_nltk_sentiment_scores_vader, mock_nltk_sentiment_class_vader, mock_nltk_sentiment_score_overall_positive] + return [ + mock_nltk_sentiment_scores_vader, + mock_nltk_sentiment_class_vader, + mock_nltk_sentiment_score_overall_positive, + ] @pytest.mark.asyncio @@ -304,31 +360,41 @@ async def test_custom_metric_type(): ) df = await run_test_suite_outputs_only(user_test_suite_outputs_only) result = df.set_index(["metric_name", "aiconfig_output"]).value.unstack(0).to_dict() # type: ignore - assert result["nltk_sentiment_class_vader"] == MOCK_NLTK_SENTIMENT_CLASS_MAPPING + assert ( + result["nltk_sentiment_class_vader"] + == MOCK_NLTK_SENTIMENT_CLASS_MAPPING + ) - assert all(isinstance(v, metrics.TextSentimentScores) for v in result["nltk_sentiment_scores_vader"].values()) # type: ignore + assert all(isinstance(v, test_suite_metrics.TextSentimentScores) for v in result["nltk_sentiment_scores_vader"].values()) # type: ignore - assert all(isinstance(v, metrics.TextOverallPositiveSentiment) for v in result["nltk_sentiment_score_overall_positive"].values()) # type: ignore + assert all(isinstance(v, test_suite_metrics.TextOverallPositiveSentiment) for v in result["nltk_sentiment_score_overall_positive"].values()) # type: ignore - neutral = metrics.TextOverallPositiveSentiment(pos=0.0, neg=0.0) + neutral = test_suite_metrics.TextOverallPositiveSentiment(pos=0.0, neg=0.0) - assert result["nltk_sentiment_score_overall_positive"]["nltk is amazing"] > neutral - assert result["nltk_sentiment_score_overall_positive"]["oh, bother"] < neutral + assert ( + result["nltk_sentiment_score_overall_positive"]["nltk is amazing"] + > neutral + ) + assert ( + result["nltk_sentiment_score_overall_positive"]["oh, bother"] < neutral + ) @pytest.mark.asyncio async def test_exception_metric(caplog: pytest.LogCaptureFixture): - user_test_suite_outputs_only = list( - itertools.product( - ["Hundred Acre Wood", ""], - [brevity], - ) - ) + user_test_suite_outputs_only = [ + ("Hundred Acre Wood", brevity), + ("", brevity), + ] + with caplog.at_level(logging.ERROR): df = await run_test_suite_outputs_only(user_test_suite_outputs_only) - print(df[["metric_name"]]) + print(df[["metric_name"]]) # type: ignore[pandas] mapping: dict[str, Any] = df.query("metric_name=='brevity'").set_index("aiconfig_output").value.to_dict() # type: ignore assert mapping["Hundred Acre Wood"] == 17.0 assert pd.isnull(mapping[""]) # type: ignore - assert any("Brevity is meaningless for empty string." in record.msg for record in caplog.records) + assert any( + "Brevity is meaningless for empty string." in record.msg + for record in caplog.records + ) diff --git a/python/tests/test_eval_model_graded_openai.py b/python/tests/test_test_suite_eval_model_graded_openai.py similarity index 66% rename from python/tests/test_eval_model_graded_openai.py rename to python/tests/test_test_suite_eval_model_graded_openai.py index 8d7dc6848..198430ef8 100644 --- a/python/tests/test_eval_model_graded_openai.py +++ b/python/tests/test_test_suite_eval_model_graded_openai.py @@ -5,12 +5,14 @@ import openai.types.chat.chat_completion as openai_chat_completion_types import openai.types.chat.chat_completion_message_tool_call as openai_tool_call_types import pytest -from aiconfig.eval import common -from aiconfig.eval.api import metrics, run_test_suite_outputs_only +from aiconfig.eval import test_suite_common +from aiconfig.eval.api import run_test_suite_outputs_only, test_suite_metrics from result import Ok, Result -def _mock_response(function_args: common.SerializedJSON) -> openai_chat_types.ChatCompletion: +def _mock_response( + function_args: test_suite_common.SerializedJSON, +) -> openai_chat_types.ChatCompletion: return openai_chat_types.ChatCompletion( id="123", choices=[ @@ -39,7 +41,9 @@ def _mock_response(function_args: common.SerializedJSON) -> openai_chat_types.Ch ) -def _make_mock_openai_chat_completion_create(function_arguments_return: common.SerializedJSON) -> lib_openai.OpenAIChatCompletionCreate: +def _make_mock_openai_chat_completion_create( + function_arguments_return: test_suite_common.SerializedJSON, +) -> lib_openai.OpenAIChatCompletionCreate: def _mock_openai_chat_completion_create( completion_params: lib_openai.OpenAIChatCompletionParams, ) -> Result[openai_chat_types.ChatCompletion, str]: @@ -55,11 +59,13 @@ def _mock_openai_chat_completion_create( @pytest.mark.asyncio async def test_openai_structured_eval(): _mock_create = _make_mock_openai_chat_completion_create( - common.SerializedJSON('{"conciseness_rating": 5, "conciseness_confidence": 0.9, "conciseness_reasoning": "I think it\'s pretty concise."}') + test_suite_common.SerializedJSON( + '{"conciseness_rating": 5, "conciseness_confidence": 0.9, "conciseness_reasoning": "I think it\'s pretty concise."}' + ) ) - mock_metric = metrics.make_openai_structured_llm_metric( + mock_metric = test_suite_metrics.make_openai_structured_llm_metric( eval_llm_name="gpt-3.5-turbo-0613", - pydantic_basemodel_type=common.TextRatingsData, + pydantic_basemodel_type=test_suite_common.TextRatingsData, metric_name="text_ratings", metric_description="Text ratings", field_descriptions=dict( @@ -74,22 +80,33 @@ async def test_openai_structured_eval(): ("one two three", mock_metric), ] df = await run_test_suite_outputs_only(user_test_suite_outputs_only) - metric_data = cast(common.CustomMetricPydanticObject[common.TextRatingsData], df.loc[0, "value"]).data - assert isinstance(metric_data, common.TextRatingsData) + metric_data = cast( + test_suite_common.CustomMetricPydanticObject[ + test_suite_common.TextRatingsData + ], + df.loc[0, "value"], + ).data + assert isinstance(metric_data, test_suite_common.TextRatingsData) metric_json = metric_data.to_dict() - assert metric_json == {"conciseness_rating": 5, "conciseness_confidence": 0.9, "conciseness_reasoning": "I think it's pretty concise."} + assert metric_json == { + "conciseness_rating": 5, + "conciseness_confidence": 0.9, + "conciseness_reasoning": "I think it's pretty concise.", + } @pytest.mark.asyncio async def test_bad_structured_eval_metric(): _mock_create = _make_mock_openai_chat_completion_create( - common.SerializedJSON('{"conciseness_rating": 5, "conciseness_confidence": 0.9, "conciseness_reasoning": "I think it\'s pretty concise."}') + test_suite_common.SerializedJSON( + '{"conciseness_rating": 5, "conciseness_confidence": 0.9, "conciseness_reasoning": "I think it\'s pretty concise."}' + ) ) with pytest.raises(ValueError) as exc: - _ = metrics.make_openai_structured_llm_metric( + _ = test_suite_metrics.make_openai_structured_llm_metric( eval_llm_name="gpt-3.5-turbo-0613", - pydantic_basemodel_type=common.TextRatingsData, + pydantic_basemodel_type=test_suite_common.TextRatingsData, metric_name="text_ratings", metric_description="Text ratings", field_descriptions=dict( @@ -101,4 +118,7 @@ async def test_bad_structured_eval_metric(): openai_chat_completion_create=_mock_create, ) - assert "The following field_descriptions keys are not in the schema" in str(exc) + assert ( + "The following field_descriptions keys are not in the schema" + in str(exc) + ) diff --git a/python/tests/test_util/test_params.py b/python/tests/test_util/test_params.py index 9952c84e1..a6d09ad86 100644 --- a/python/tests/test_util/test_params.py +++ b/python/tests/test_util/test_params.py @@ -1,5 +1,9 @@ import pytest -from aiconfig.util.params import find_dependencies_in_prompt, get_dependency_graph, get_parameters_in_template +from aiconfig.util.params import ( + find_dependencies_in_prompt, + get_dependency_graph, + get_parameters_in_template, +) from aiconfig.schema import Prompt, PromptMetadata @@ -26,7 +30,9 @@ def test_template_with_no_parameters(): def test_template_with_empty_params(): - result = get_parameters_in_template("This is a plain text template with a fake param {{}}.") + result = get_parameters_in_template( + "This is a plain text template with a fake param {{}}." + ) assert result == {} @@ -75,27 +81,39 @@ def test_find_dependencies_in_prompt(prompt_list_with_5_prompts): prompt_template = "I am referring to {{prompt1.input}} and this {{prompt4.output}}" # only allowed to reference upstream prompts current_prompt_name = "prompt2" - result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) + result = find_dependencies_in_prompt( + prompt_template, current_prompt_name, prompt_list_with_5_prompts + ) assert result == {"prompt1"} -def test_find_dependencies_in_prompt_with_no_dependencies(prompt_list_with_5_prompts): +def test_find_dependencies_in_prompt_with_no_dependencies( + prompt_list_with_5_prompts, +): # generate a list of 5 Prompts with name prompt1, prompt2, ... prompt_template = "I am referring to {{}}" current_prompt_name = "prompt2" - result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) + result = find_dependencies_in_prompt( + prompt_template, current_prompt_name, prompt_list_with_5_prompts + ) assert not result -def test_find_dependencies_in_prompt_with_two_dependencies(prompt_list_with_5_prompts): +def test_find_dependencies_in_prompt_with_two_dependencies( + prompt_list_with_5_prompts, +): # generate a list of 5 Prompts with name prompt1, prompt2, ... - prompt_template = "I am referring to {{prompt2.output}} and {{prompt1.output}}" # + prompt_template = ( + "I am referring to {{prompt2.output}} and {{prompt1.output}}" # + ) current_prompt_name = "prompt4" - result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) + result = find_dependencies_in_prompt( + prompt_template, current_prompt_name, prompt_list_with_5_prompts + ) assert result == {"prompt1", "prompt2"} @@ -107,7 +125,9 @@ def test_find_dependencies_in_prompt_with_no_prompt_reference( prompt_template = "I am referring to {{fakeprompt.output}} and {{fakeprompt.output}}" # should return none, no prompt references here current_prompt_name = "prompt4" - result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) + result = find_dependencies_in_prompt( + prompt_template, current_prompt_name, prompt_list_with_5_prompts + ) assert not result diff --git a/python/tests/util/file_path_utils.py b/python/tests/util/file_path_utils.py index 70794ed6a..8228cd4e5 100644 --- a/python/tests/util/file_path_utils.py +++ b/python/tests/util/file_path_utils.py @@ -1,7 +1,9 @@ import os -def get_absolute_file_path_from_relative(working_file_path: str, relative_file_path: str) -> str: +def get_absolute_file_path_from_relative( + working_file_path: str, relative_file_path: str +) -> str: """ Returns the absolute file path of a file given its relative file path and the file path of the calling file.