diff --git a/cookbooks/Basic-Prompt-Routing/assistant_app.py b/cookbooks/Basic-Prompt-Routing/assistant_app.py index 36c7f2911..db4d4b562 100644 --- a/cookbooks/Basic-Prompt-Routing/assistant_app.py +++ b/cookbooks/Basic-Prompt-Routing/assistant_app.py @@ -1,14 +1,14 @@ import asyncio import os +# Create ~/.env file with this line: export OPENAI_API_KEY= +# You can get your key from https://platform.openai.com/api-keys +import dotenv import openai import streamlit as st from aiconfig import AIConfigRuntime -# Create ~/.env file with this line: export OPENAI_API_KEY= -# You can get your key from https://platform.openai.com/api-keys -import dotenv dotenv.load_dotenv() openai.api_key = os.getenv("OPENAI_API_KEY") @@ -32,12 +32,8 @@ async def assistant_response(prompt): # Streamlit Setup st.title("AI Teaching Assistant") -st.markdown( - "Ask a math, physics, or general question. Based on your question, an AI math prof, physics prof, or general assistant will respond." -) -st.markdown( - "**This is a simple demo of prompt routing - based on your question, an LLM decides which AI teacher responds.**" -) +st.markdown("Ask a math, physics, or general question. Based on your question, an AI math prof, physics prof, or general assistant will respond.") +st.markdown("**This is a simple demo of prompt routing - based on your question, an LLM decides which AI teacher responds.**") # Chat setup if "messages" not in st.session_state: diff --git a/cookbooks/Cli-Mate/cli-mate.py b/cookbooks/Cli-Mate/cli-mate.py index f869d2da8..c7f9e0b35 100644 --- a/cookbooks/Cli-Mate/cli-mate.py +++ b/cookbooks/Cli-Mate/cli-mate.py @@ -53,19 +53,13 @@ async def query(aiconfig_path: str, question: str) -> list[ExecuteResult]: return result -async def get_mod_result( - aiconfig_path: str, source_code: str, question: str -) -> list[ExecuteResult]: - question_about_code = ( - f"QUERY ABOUT SOURCE CODE:\n{question}\nSOURCE CODE:\n```{source_code}\n```" - ) +async def get_mod_result(aiconfig_path: str, source_code: str, question: str) -> list[ExecuteResult]: + question_about_code = f"QUERY ABOUT SOURCE CODE:\n{question}\nSOURCE CODE:\n```{source_code}\n```" return await query(aiconfig_path, question_about_code) -async def mod_code( - aiconfig_path: str, source_code_file: str, question: str, update_file: bool = False -): +async def mod_code(aiconfig_path: str, source_code_file: str, question: str, update_file: bool = False): # read source code from file with open(source_code_file, "r", encoding="utf8") as file: source_code = file.read() @@ -99,9 +93,7 @@ def signal_handler(_: int, __: FrameType | None): i = 0 while True: try: - user_input = await event_loop.run_in_executor( - None, session.prompt, "Query: [ctrl-D to exit] " - ) + user_input = await event_loop.run_in_executor(None, session.prompt, "Query: [ctrl-D to exit] ") except KeyboardInterrupt: continue except EOFError: @@ -152,9 +144,7 @@ async def main(): subparsers = parser.add_subparsers(dest="command") loop_parser = subparsers.add_parser("loop") - loop_parser.add_argument( - "-scf", "--source-code-file", help="Specify a source code file." - ) + loop_parser.add_argument("-scf", "--source-code-file", help="Specify a source code file.") args = parser.parse_args() diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py index aad4e1f2b..8257b326e 100644 --- a/cookbooks/HuggingFace/hf.py +++ b/cookbooks/HuggingFace/hf.py @@ -3,10 +3,7 @@ # HuggingFace API imports from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( - TextGenerationResponse, - TextGenerationStreamResponse, -) +from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse # ModelParser Utils # Type hint imports @@ -166,14 +163,7 @@ def id(self) -> str: """ return "HuggingFaceTextParser" - def serialize( - self, - prompt_name: str, - data: Any, - ai_config: "AIConfigRuntime", - parameters: Optional[Dict] = None, - **kwargs - ) -> List[Prompt]: + def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -196,9 +186,7 @@ def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata( - model=model_metadata, parameters=parameters, **kwargs - ), + metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), ) return [prompt] @@ -230,9 +218,7 @@ async def deserialize( return completion_data - async def run_inference( - self, prompt: Prompt, aiconfig, options, parameters - ) -> List[Output]: + async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -247,9 +233,7 @@ async def run_inference( completion_data = await self.deserialize(prompt, aiconfig, options, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. - stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) diff --git a/cookbooks/HuggingFace/python/hf.py b/cookbooks/HuggingFace/python/hf.py index aad4e1f2b..8257b326e 100644 --- a/cookbooks/HuggingFace/python/hf.py +++ b/cookbooks/HuggingFace/python/hf.py @@ -3,10 +3,7 @@ # HuggingFace API imports from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( - TextGenerationResponse, - TextGenerationStreamResponse, -) +from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse # ModelParser Utils # Type hint imports @@ -166,14 +163,7 @@ def id(self) -> str: """ return "HuggingFaceTextParser" - def serialize( - self, - prompt_name: str, - data: Any, - ai_config: "AIConfigRuntime", - parameters: Optional[Dict] = None, - **kwargs - ) -> List[Prompt]: + def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -196,9 +186,7 @@ def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata( - model=model_metadata, parameters=parameters, **kwargs - ), + metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), ) return [prompt] @@ -230,9 +218,7 @@ async def deserialize( return completion_data - async def run_inference( - self, prompt: Prompt, aiconfig, options, parameters - ) -> List[Output]: + async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -247,9 +233,7 @@ async def run_inference( completion_data = await self.deserialize(prompt, aiconfig, options, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. - stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) diff --git a/cookbooks/Wizard-GPT/wizard-gpt.py b/cookbooks/Wizard-GPT/wizard-gpt.py index 71b03dfdb..976de85ef 100644 --- a/cookbooks/Wizard-GPT/wizard-gpt.py +++ b/cookbooks/Wizard-GPT/wizard-gpt.py @@ -1,12 +1,14 @@ import asyncio +import os -from aiconfig import AIConfigRuntime, InferenceOptions, Prompt +import dotenv # Create ~/.env file with this line: export OPENAI_API_KEY= -# You can get your key from https://platform.openai.com/api-keys +# You can get your key from https://platform.openai.com/api-keys import openai -import dotenv -import os + +from aiconfig import AIConfigRuntime, InferenceOptions, Prompt + dotenv.load_dotenv() openai.api_key = os.getenv("OPENAI_API_KEY") diff --git a/python/src/aiconfig/ChatCompletion.py b/python/src/aiconfig/ChatCompletion.py index 4426507ac..8bcbd2d1c 100644 --- a/python/src/aiconfig/ChatCompletion.py +++ b/python/src/aiconfig/ChatCompletion.py @@ -29,10 +29,7 @@ def validate_and_add_prompts_to_config(prompts: List[Prompt], aiconfig) -> None: in_config = False for config_prompt in aiconfig.prompts: # check for duplicates (same input and settings.) - if ( - config_prompt.input == new_prompt.input - and new_prompt.metadata == config_prompt.metadata - ): + if config_prompt.input == new_prompt.input and new_prompt.metadata == config_prompt.metadata: in_config = True # update outputs if different if config_prompt.outputs != new_prompt.outputs: @@ -58,9 +55,7 @@ def extract_outputs_from_response(response) -> List[Output]: response = response.model_dump(exclude_none=True) - response_without_choices = { - key: copy.deepcopy(value) for key, value in response.items() if key != "choices" - } + response_without_choices = {key: copy.deepcopy(value) for key, value in response.items() if key != "choices"} for i, choice in enumerate(response.get("choices")): response_without_choices.update({"finish_reason": choice.get("finish_reason")}) output = ExecuteResult( @@ -77,7 +72,7 @@ def extract_outputs_from_response(response) -> List[Output]: def async_run_serialize_helper( - aiconfig: AIConfigRuntime, + aiconfig: AIConfigRuntime, request_kwargs: Dict, ) -> List[Prompt]: """ @@ -88,9 +83,7 @@ def async_run_serialize_helper( serialized_prompts = None async def run_and_await_serialize(): - result = await aiconfig.serialize( - request_kwargs.get("model"), request_kwargs, "prompt" - ) + result = await aiconfig.serialize(request_kwargs.get("model"), request_kwargs, "prompt") return result # serialize prompts from ChatCompletion kwargs @@ -125,17 +118,9 @@ def _get_aiconfig_runtime(output_aiconfig_path: str) -> AIConfigRuntime: except IOError: return AIConfigRuntime.create(**(aiconfig_settings or {})) - output_aiconfig = ( - output_aiconfig_ref - if isinstance(output_aiconfig_ref, AIConfigRuntime) - else _get_aiconfig_runtime(output_aiconfig_ref) - ) + output_aiconfig = output_aiconfig_ref if isinstance(output_aiconfig_ref, AIConfigRuntime) else _get_aiconfig_runtime(output_aiconfig_ref) - output_config_file_path = ( - output_aiconfig_ref - if isinstance(output_aiconfig_ref, str) - else output_aiconfig_ref.file_path - ) + output_config_file_path = output_aiconfig_ref if isinstance(output_aiconfig_ref, str) else output_aiconfig_ref.file_path # TODO: openai makes it hard to statically annotate. def _create_chat_completion_with_config_saving(*args, **kwargs) -> Any: # type: ignore @@ -147,9 +132,7 @@ def _create_chat_completion_with_config_saving(*args, **kwargs) -> Any: # type: outputs = [] # Check if response is a stream - stream = kwargs.get("stream", False) is True and isinstance( - response, openai.Stream - ) + stream = kwargs.get("stream", False) is True and isinstance(response, openai.Stream) # Convert Response to output for last prompt if not stream: @@ -189,9 +172,7 @@ def generate_streamed_response() -> Generator[Any, None, None]: ) stream_outputs[index] = output yield chunk - stream_outputs = [ - stream_outputs[i] for i in sorted(list(stream_outputs.keys())) - ] + stream_outputs = [stream_outputs[i] for i in sorted(list(stream_outputs.keys()))] # Add outputs to last prompt serialized_prompts[-1].outputs = stream_outputs diff --git a/python/src/aiconfig/Config.py b/python/src/aiconfig/Config.py index 794c16d35..20cba84c2 100644 --- a/python/src/aiconfig/Config.py +++ b/python/src/aiconfig/Config.py @@ -1,22 +1,20 @@ import json import os -import yaml from typing import Any, Dict, List, Literal, Optional, Tuple import requests +import yaml from aiconfig.callback import CallbackEvent, CallbackManager from aiconfig.default_parsers.anyscale_endpoint import DefaultAnyscaleEndpointParser from aiconfig.default_parsers.openai import DefaultOpenAIParser from aiconfig.default_parsers.palm import PaLMChatParser, PaLMTextParser from aiconfig.model_parser import InferenceOptions, ModelParser + from aiconfig.schema import JSONObject from .default_parsers.dalle import DalleImageGenerationParser from .default_parsers.hf import HuggingFaceTextGenerationParser -from .registry import ( - ModelParserRegistry, - update_model_parser_registry_with_config_runtime, -) +from .registry import ModelParserRegistry, update_model_parser_registry_with_config_runtime from .schema import AIConfig, Prompt from .util.config_utils import is_yaml_ext diff --git a/python/src/aiconfig/callback.py b/python/src/aiconfig/callback.py index da6f386de..d38895391 100644 --- a/python/src/aiconfig/callback.py +++ b/python/src/aiconfig/callback.py @@ -2,17 +2,7 @@ import asyncio import logging import time -from typing import ( - Any, - Awaitable, - Callable, - Coroutine, - Final, - List, - Sequence, - TypeAlias, - Union, -) +from typing import Any, Awaitable, Callable, Coroutine, Final, List, Sequence, TypeAlias, Union from pydantic import BaseModel, ConfigDict @@ -43,9 +33,7 @@ def __init__(self, name: str, file: str, data: Any, ts_ns: int = time.time_ns()) Result: TypeAlias = Union[Ok[Any], Err[Any]] -async def execute_coroutine_with_timeout( - coroutine: Coroutine[Any, Any, Any], timeout: int -) -> Result: +async def execute_coroutine_with_timeout(coroutine: Coroutine[Any, Any, Any], timeout: int) -> Result: """ Execute a coroutine with a timeout, return an Ok result or an Err on Exception @@ -123,9 +111,7 @@ def setup_logger(): name = "my-logger" log_file = "aiconfig.log" - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler = logging.FileHandler(log_file) handler.setFormatter(formatter) diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py index bd967577a..3bec22a66 100644 --- a/python/src/aiconfig/editor/server/server.py +++ b/python/src/aiconfig/editor/server/server.py @@ -1,6 +1,6 @@ import logging -from typing import Any, Dict, Type, Union import webbrowser +from typing import Any, Dict, Type, Union import lastmile_utils.lib.core.api as core_utils import result diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py index 1f2a6d694..13e559020 100644 --- a/python/src/aiconfig/model_parser.py +++ b/python/src/aiconfig/model_parser.py @@ -1,14 +1,7 @@ -from abc import ABC, abstractmethod import asyncio import copy -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, -) +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from aiconfig.schema import AIConfig, ExecuteResult, Output, Prompt @@ -33,7 +26,7 @@ async def serialize( **kwargs, ) -> List[Prompt]: """ - Serialize a prompt and additional metadata/model settings into a + Serialize a prompt and additional metadata/model settings into a list of Prompt objects that can be saved in the AIConfig. Args: @@ -150,9 +143,7 @@ def get_output_text( str: The output text from the model inference response. """ - def get_model_settings( - self, prompt: Prompt, aiconfig: "AIConfigRuntime" - ) -> Dict[str, Any]: + def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dict[str, Any]: """ Extracts the AI model's settings from the configuration. If both prompt and config level settings are defined, merge them with prompt settings taking precedence. @@ -166,10 +157,7 @@ def get_model_settings( return aiconfig.get_global_settings(self.id()) # Check if the prompt exists in the config - if ( - prompt.name not in aiconfig.prompt_index - or aiconfig.prompt_index[prompt.name] != prompt - ): + if prompt.name not in aiconfig.prompt_index or aiconfig.prompt_index[prompt.name] != prompt: raise IndexError(f"Prompt '{prompt.name}' not in config.") model_metadata = prompt.metadata.model if prompt.metadata else None @@ -178,9 +166,7 @@ def get_model_settings( # Use Default Model default_model = aiconfig.get_default_model() if not default_model: - raise KeyError( - f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model." - ) + raise KeyError(f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model.") return aiconfig.get_global_settings(default_model) elif isinstance(model_metadata, str): # Use Global settings @@ -189,11 +175,7 @@ def get_model_settings( # Merge config and prompt settings with prompt settings taking precedent model_settings = {} global_settings = aiconfig.get_global_settings(model_metadata.name) - prompt_setings = ( - prompt.metadata.model.settings - if prompt.metadata.model.settings is not None - else {} - ) + prompt_setings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} model_settings.update(global_settings) model_settings.update(prompt_setings) @@ -205,11 +187,7 @@ def print_stream_callback(data, accumulated_data, index: int): """ Default streamCallback function that prints the output to the console. """ - print( - "\ndata: {}\naccumulated_data:{}\nindex:{}\n".format( - data, accumulated_data, index - ) - ) + print("\ndata: {}\naccumulated_data:{}\nindex:{}\n".format(data, accumulated_data, index)) def print_stream_delta(data, accumulated_data, index: int): diff --git a/python/src/aiconfig/scripts/aiconfig_cli.py b/python/src/aiconfig/scripts/aiconfig_cli.py index 31d0a82be..f57de74dc 100644 --- a/python/src/aiconfig/scripts/aiconfig_cli.py +++ b/python/src/aiconfig/scripts/aiconfig_cli.py @@ -1,16 +1,15 @@ import asyncio import logging import signal +import socket import subprocess import sys -import socket import lastmile_utils.lib.core.api as core_utils import result -from result import Err, Ok, Result from aiconfig.editor.server.server import run_backend_server - from aiconfig.editor.server.server_utils import EditServerConfig, ServerMode +from result import Err, Ok, Result class AIConfigCLIConfig(core_utils.Record): @@ -65,19 +64,19 @@ def _sigint(procs: list[subprocess.Popen[bytes]]) -> Result[str, str]: p.send_signal(signal.SIGINT) return Ok("Sent SIGINT to frontend servers.") + def is_port_in_use(port: int) -> bool: - """ + """ Checks if a port is in use at localhost. - + Creates a temporary connection. Context manager will automatically close the connection """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 + return s.connect_ex(("localhost", port)) == 0 def _run_editor_servers(edit_config: EditServerConfig) -> Result[list[str], str]: - port = edit_config.server_port while is_port_in_use(port): diff --git a/python/src/aiconfig/scripts/run_aiconfig.py b/python/src/aiconfig/scripts/run_aiconfig.py index d89dca482..25a2a7460 100644 --- a/python/src/aiconfig/scripts/run_aiconfig.py +++ b/python/src/aiconfig/scripts/run_aiconfig.py @@ -2,12 +2,10 @@ import logging import sys -from result import Ok, Result - - +import lastmile_utils.lib.core.api as core_utils from aiconfig.Config import AIConfigRuntime from aiconfig.eval.lib import run_aiconfig_helper -import lastmile_utils.lib.core.api as core_utils +from result import Ok, Result logging.basicConfig(format=core_utils.LOGGER_FMT) LOGGER = logging.getLogger(__name__) @@ -35,9 +33,7 @@ def _load_aiconfig(settings: Settings) -> Result[AIConfigRuntime, str]: settings = res_settings.unwrap() res_aiconfig = _load_aiconfig(settings) aiconfig = res_aiconfig.unwrap() - final_value = await run_aiconfig_helper( - runtime=aiconfig, prompt_name=settings.prompt_name, question=question - ) + final_value = await run_aiconfig_helper(runtime=aiconfig, prompt_name=settings.prompt_name, question=question) print(final_value.unwrap_or_else(str)) diff --git a/python/src/aiconfig/util/params.py b/python/src/aiconfig/util/params.py index 8fcc72c1e..33905c034 100644 --- a/python/src/aiconfig/util/params.py +++ b/python/src/aiconfig/util/params.py @@ -123,9 +123,7 @@ def resolve_parametrized_prompt(raw_prompt, params): return resolved_prompt -def find_dependencies_in_prompt( - prompt_template: str, current_prompt_name: str, prompt_list: List[Prompt] -) -> Set[str]: +def find_dependencies_in_prompt(prompt_template: str, current_prompt_name: str, prompt_list: List[Prompt]) -> Set[str]: """ Finds and returns a set of prompt IDs that are dependencies of the given prompt. @@ -158,9 +156,7 @@ def find_dependencies_in_prompt( return dependencies -def get_dependency_graph( - root_prompt: Prompt, all_prompts: List[Prompt], prompt_dict: Dict[str, Prompt] -) -> dict[str, List[str]]: +def get_dependency_graph(root_prompt: Prompt, all_prompts: List[Prompt], prompt_dict: Dict[str, Prompt]) -> dict[str, List[str]]: """ Generates an upstream dependency graph of prompts in the configuration, with each entry representing only its direct dependencies. Traversal is required to identify all upstream dependencies. The specified prompt serves as the root. @@ -184,9 +180,7 @@ def build_dependency_graph_recursive(current_prompt_name: str) -> dict: visited.add(current_prompt_name) prompt_template = prompt_dict[current_prompt_name].get_raw_prompt_from_config() - prompt_dependencies = find_dependencies_in_prompt( - prompt_template, current_prompt_name, all_prompts - ) + prompt_dependencies = find_dependencies_in_prompt(prompt_template, current_prompt_name, all_prompts) for prompt_dependency in prompt_dependencies: dependency_graph[current_prompt_name].append(prompt_dependency) @@ -236,9 +230,7 @@ def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime"): elif isinstance(prompt.input.data, str): return prompt.input.data else: - raise Exception( - f"Cannot get prompt template string from prompt: {prompt.input}" - ) + raise Exception(f"Cannot get prompt template string from prompt: {prompt.input}") def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntime"): @@ -251,13 +243,7 @@ def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntim break prompt_input = get_prompt_template(previous_prompt, ai_config) - prompt_output = ( - ai_config.get_output_text( - previous_prompt, ai_config.get_latest_output(previous_prompt) - ) - if previous_prompt.outputs - else None - ) + prompt_output = ai_config.get_output_text(previous_prompt, ai_config.get_latest_output(previous_prompt)) if previous_prompt.outputs else None prompt_references[previous_prompt.name] = { "input": prompt_input, "output": prompt_output, @@ -265,9 +251,7 @@ def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntim return prompt_references -def resolve_prompt( - current_prompt: Prompt, input_params: Dict, ai_config: "AIConfigRuntime" -) -> str: +def resolve_prompt(current_prompt: Prompt, input_params: Dict, ai_config: "AIConfigRuntime") -> str: """ Parameterizes a prompt using provided parameters, references to other prompts, and parameters stored in config.. """ diff --git a/python/tests/mocks.py b/python/tests/mocks.py index 8af3bb340..bcc210400 100644 --- a/python/tests/mocks.py +++ b/python/tests/mocks.py @@ -1,4 +1,3 @@ - from typing import Any from aiconfig.Config import AIConfigRuntime @@ -27,4 +26,3 @@ async def run_and_get_output_text( assert params_.keys() == {"the_query"}, 'For eval, AIConfig params must have just the key "the_query".' the_query = params_["the_query"] return f"output_for_{prompt_name}_the_query_{the_query}" - diff --git a/python/tests/parameterization/test_parameterization.py b/python/tests/parameterization/test_parameterization.py index 0c5704c8d..05dc47a76 100644 --- a/python/tests/parameterization/test_parameterization.py +++ b/python/tests/parameterization/test_parameterization.py @@ -123,7 +123,7 @@ async def test_prompt_params_only(): config = config3 ai_config = AIConfigRuntime(**config) params = {"name": "Tanya User"} - resolved_params = await ai_config.resolve("prompt1", params) + resolved_params = await ai_config.resolve("prompt1", params) prompt_from_resolved_params = resolved_params["messages"][0]["content"] assert prompt_from_resolved_params == "Hello, Tanya User" diff --git a/python/tests/parsers/test_dalle_parser.py b/python/tests/parsers/test_dalle_parser.py index 1aedd6e37..a9d275c2b 100644 --- a/python/tests/parsers/test_dalle_parser.py +++ b/python/tests/parsers/test_dalle_parser.py @@ -14,9 +14,7 @@ async def test_serialize_basic(set_temporary_env_vars: None): "size": "1024x1024", } aiconfig = AIConfigRuntime.create() - serialized_prompts = await aiconfig.serialize( - "dall-e-3", completion_params, prompt_name="panda_eating_dumplings" - ) + serialized_prompts = await aiconfig.serialize("dall-e-3", completion_params, prompt_name="panda_eating_dumplings") new_prompt = serialized_prompts[0] assert new_prompt == Prompt( name="panda_eating_dumplings", diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py index 9064d9cb1..e8ae540db 100644 --- a/python/tests/parsers/test_openai_util.py +++ b/python/tests/parsers/test_openai_util.py @@ -2,15 +2,9 @@ import pytest from aiconfig.Config import AIConfigRuntime from aiconfig.default_parsers.openai import refine_chat_completion_params -from aiconfig.schema import ( - ExecuteResult, - OutputDataWithValue, - Prompt, - PromptInput, - PromptMetadata, -) from mock import patch +from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata from ..conftest import mock_openai_chat_completion from ..util.file_path_utils import get_absolute_file_path_from_relative diff --git a/python/tests/parsers/test_parser.py b/python/tests/parsers/test_parser.py index 1c4691592..32bed9709 100644 --- a/python/tests/parsers/test_parser.py +++ b/python/tests/parsers/test_parser.py @@ -62,18 +62,14 @@ def test_get_model_settings(ai_config_runtime: AIConfigRuntime): Prompt( name="test", input="test", - metadata=PromptMetadata( - model=ModelMetadata(name="fakemodel", settings=None) - ), + metadata=PromptMetadata(model=ModelMetadata(name="fakemodel", settings=None)), ) ], ) prompt = aiconfig.prompts[0] - assert mock_model_parser.get_model_settings(prompt, aiconfig) == { - "fake_setting": "True" - } + assert mock_model_parser.get_model_settings(prompt, aiconfig) == {"fake_setting": "True"} with pytest.raises(IndexError, match=r"Prompt '.*' not in config"): prompt = Prompt( diff --git a/python/tests/test_library_helpers.py b/python/tests/test_library_helpers.py index f665a7e46..0a5289715 100644 --- a/python/tests/test_library_helpers.py +++ b/python/tests/test_library_helpers.py @@ -38,9 +38,7 @@ def test_collect_prompt_references(): # input is an aiconfig with a 4 prompts. Test collects prompt references for the 3rd prompt # collect_prompt_references should return references to 1 and 2. 3 is the prompt we are collecting references for, 4 is after. Both are expected to be skipped config_relative_path = "aiconfigs/GPT4 Coding Assistant_aiconfig.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) aiconfig = AIConfigRuntime.load(config_absolute_path) prompt3 = aiconfig.prompts[2] diff --git a/python/tests/test_load_config.py b/python/tests/test_load_config.py index 10c4c3eaf..6d2af5f7c 100644 --- a/python/tests/test_load_config.py +++ b/python/tests/test_load_config.py @@ -12,9 +12,7 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): """Test loading a basic chatgpt query config""" config_relative_path = "aiconfigs/basic_chatgpt_query_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("prompt1") @@ -24,9 +22,7 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): "top_p": 1, "temperature": 1, "stream": False, - "messages": [ - {"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"} - ], + "messages": [{"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"}], } @@ -34,9 +30,7 @@ async def test_load_basic_chatgpt_query_config(set_temporary_env_vars): async def test_load_basic_dalle2_config(set_temporary_env_vars): """Test loading a basic Dall-E 2 config""" config_relative_path = "aiconfigs/basic_dalle2_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("panda_eating_dumplings") @@ -53,9 +47,7 @@ async def test_load_basic_dalle2_config(set_temporary_env_vars): async def test_load_basic_dalle3_config(set_temporary_env_vars): """Test loading a basic Dall-E 3 config""" config_relative_path = "aiconfigs/basic_dalle3_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("panda_eating_dumplings") @@ -72,9 +64,7 @@ async def test_load_basic_dalle3_config(set_temporary_env_vars): async def test_chained_gpt_config(set_temporary_env_vars): """Test loading a chained gpt config and resolving it, with chat context enabled""" config_relative_path = "aiconfigs/chained_gpt_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) data_for_inference1 = await config.resolve("prompt1") @@ -119,9 +109,7 @@ async def test_resolve_system_prompt(): Resolves a system prompt with a provided parameter """ config_relative_path = "aiconfigs/system_prompt_parameters_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) data_for_inference = await config.resolve("prompt1", {"system": "skip odd numbers"}) diff --git a/python/tests/test_parameter_api.py b/python/tests/test_parameter_api.py index a0acd5578..2f06153c8 100644 --- a/python/tests/test_parameter_api.py +++ b/python/tests/test_parameter_api.py @@ -1,15 +1,7 @@ import pytest from aiconfig.Config import AIConfigRuntime -from aiconfig.util.config_utils import extract_override_settings -from aiconfig.schema import ( - AIConfig, - ConfigMetadata, - ExecuteResult, - ModelMetadata, - Prompt, - PromptMetadata, -) +from aiconfig.schema import AIConfig, ConfigMetadata, Prompt, PromptMetadata @pytest.fixture @@ -17,6 +9,7 @@ def ai_config_runtime(): runtime = AIConfigRuntime.create("Untitled AIConfig") return runtime + @pytest.fixture def ai_config(): config = AIConfig( @@ -27,6 +20,7 @@ def ai_config(): ) return config + def test_delete_nonexistent_parameter(ai_config_runtime: AIConfigRuntime): """ Test deleting a nonexistent parameter. @@ -43,9 +37,7 @@ def test_delete_nonexistent_parameter(ai_config_runtime: AIConfigRuntime): ) # Ensure deleting a nonexistent parameter raises a KeyError - with pytest.raises( - KeyError, match=f"Parameter '{parameter_name_to_delete}' does not exist." - ): + with pytest.raises(KeyError, match=f"Parameter '{parameter_name_to_delete}' does not exist."): config.delete_parameter(parameter_name_to_delete) @@ -62,9 +54,9 @@ def test_set_parameter_for_aiconfig_empty_params(ai_config: AIConfig): input="This is a prompt", metadata=PromptMetadata( model="fakemodel", - parameters= { + parameters={ prompt_parameter_name: prompt_parameter_value, - } + }, ), ) ai_config.add_prompt(prompt_name, prompt) @@ -84,6 +76,7 @@ def test_set_parameter_for_aiconfig_empty_params(ai_config: AIConfig): aiconfig_parameter_name: aiconfig_parameter_value, } + def test_set_parameter_for_aiconfig_has_parameters(ai_config: AIConfig): """ Test setting a global parameter when it already has parameters. @@ -98,16 +91,16 @@ def test_set_parameter_for_aiconfig_has_parameters(ai_config: AIConfig): input="This is a prompt", metadata=PromptMetadata( model="fakemodel", - parameters= { + parameters={ prompt_parameter_name: prompt_parameter_value, - } + }, ), ) ai_config.add_prompt(prompt_name, prompt) aiconfig_parameter_name = "aiconfig_param" ai_config.metadata = ConfigMetadata( - parameters= { + parameters={ "random_key": "keep this parameter", aiconfig_parameter_name: "should update this value", } @@ -126,6 +119,7 @@ def test_set_parameter_for_aiconfig_has_parameters(ai_config: AIConfig): aiconfig_parameter_name: aiconfig_parameter_value, } + def test_set_parameter_for_prompt_no_metadata(ai_config: AIConfig): """ Test setting a prompt parameter when there is no prompt metadata. @@ -140,7 +134,7 @@ def test_set_parameter_for_prompt_no_metadata(ai_config: AIConfig): aiconfig_parameter_name = "aiconfig_param" aiconfig_parameter_value = "aiconfig_value" ai_config.metadata = ConfigMetadata( - parameters= { + parameters={ aiconfig_parameter_name: aiconfig_parameter_value, } ) @@ -159,7 +153,8 @@ def test_set_parameter_for_prompt_no_metadata(ai_config: AIConfig): assert ai_config.metadata.parameters == { aiconfig_parameter_name: aiconfig_parameter_value, } - + + def test_set_parameter_for_prompt_no_parameters(ai_config: AIConfig): """ Test setting a prompt parameter when there are no prompt parameters. @@ -175,7 +170,7 @@ def test_set_parameter_for_prompt_no_parameters(ai_config: AIConfig): aiconfig_parameter_name = "aiconfig_param" aiconfig_parameter_value = "aiconfig_value" ai_config.metadata = ConfigMetadata( - parameters= { + parameters={ aiconfig_parameter_name: aiconfig_parameter_value, } ) @@ -196,6 +191,7 @@ def test_set_parameter_for_prompt_no_parameters(ai_config: AIConfig): aiconfig_parameter_name: aiconfig_parameter_value, } + def test_set_parameter_for_prompt_has_parameters(ai_config: AIConfig): """ Test setting a prompt parameter when it already has parameters. @@ -209,10 +205,10 @@ def test_set_parameter_for_prompt_has_parameters(ai_config: AIConfig): input="This is a prompt", metadata=PromptMetadata( model="fakemodel", - parameters= { + parameters={ "random_key": "keep this parameter", prompt_parameter_name: "should update this value", - } + }, ), ) ai_config.add_prompt(prompt_name, prompt) @@ -220,7 +216,7 @@ def test_set_parameter_for_prompt_has_parameters(ai_config: AIConfig): aiconfig_parameter_name = "aiconfig_param" aiconfig_parameter_value = "aiconfig_value" ai_config.metadata = ConfigMetadata( - parameters= { + parameters={ aiconfig_parameter_name: aiconfig_parameter_value, } ) @@ -269,7 +265,7 @@ def test_delete_existing_parameter(ai_config: AIConfig): assert parameter_name_to_delete not in ai_config.metadata.parameters -# | With both local and global (should use local override) +# | With both local and global (should use local override) # | Without AIConfig but local is `{}` | def test_get_parameter_prompt_has_parameters(ai_config: AIConfig): """ @@ -328,9 +324,7 @@ def test_get_parameter_prompt_has_no_metadata( assert parameters == ai_config.metadata.parameters -def test_get_parameter_prompt_has_metadata_no_parameters( - ai_config: AIConfig -): +def test_get_parameter_prompt_has_metadata_no_parameters(ai_config: AIConfig): """ Test getting a parameter when only aiconfig param is set. Prompt has metadata but no parameters. @@ -355,9 +349,7 @@ def test_get_parameter_prompt_has_metadata_no_parameters( assert parameters == ai_config.metadata.parameters -def test_get_parameter_prompt_has_empty_parameters( - ai_config: AIConfig -): +def test_get_parameter_prompt_has_empty_parameters(ai_config: AIConfig): """ Test getting a parameter when only aiconfig param is set. Prompt has empty parameters. @@ -382,9 +374,7 @@ def test_get_parameter_prompt_has_empty_parameters( assert parameters == ai_config.metadata.parameters -def test_get_parameter_prompt_has_empty_parameters( - ai_config: AIConfig -): +def test_get_parameter_prompt_has_empty_parameters(ai_config: AIConfig): """ Test getting a parameter when only aiconfig param is set. Prompt has empty parameters. @@ -408,6 +398,7 @@ def test_get_parameter_prompt_has_empty_parameters( parameters = ai_config.get_parameters(prompt_name) assert parameters == ai_config.metadata.parameters + def test_get_parameter_aiconfig_has_parameters(ai_config: AIConfig): """ Test getting a parameter for an aiconfig diff --git a/python/tests/test_programmatically_create_an_AIConfig.py b/python/tests/test_programmatically_create_an_AIConfig.py index 52a5d45de..245bba3cd 100644 --- a/python/tests/test_programmatically_create_an_AIConfig.py +++ b/python/tests/test_programmatically_create_an_AIConfig.py @@ -2,14 +2,7 @@ from aiconfig.Config import AIConfigRuntime from aiconfig.util.config_utils import extract_override_settings -from aiconfig.schema import ( - AIConfig, - ConfigMetadata, - ExecuteResult, - ModelMetadata, - Prompt, - PromptMetadata, -) +from aiconfig.schema import AIConfig, ConfigMetadata, ExecuteResult, ModelMetadata, Prompt, PromptMetadata @pytest.fixture @@ -68,9 +61,7 @@ def test_delete_nonexistent_model(ai_config: AIConfig): non_existent_model = "non_existent_model" # Ensure trying to delete a non-existent model raises an exception - with pytest.raises( - Exception, match=f"Model '{non_existent_model}' does not exist." - ): + with pytest.raises(Exception, match=f"Model '{non_existent_model}' does not exist."): ai_config.delete_model(non_existent_model) @@ -279,9 +270,7 @@ def test_load_saved_config(tmp_path): assert loaded_config.metadata.parameters == {"config_param": "config_value"} assert "prompt1" in loaded_config.prompt_index assert loaded_config.prompt_index["prompt1"].metadata is not None - assert loaded_config.prompt_index["prompt1"].metadata.parameters == { - "prompt_param": "prompt_value" - } + assert loaded_config.prompt_index["prompt1"].metadata.parameters == {"prompt_param": "prompt_value"} def test_set_config_name(ai_config_runtime: AIConfigRuntime): @@ -325,9 +314,7 @@ def test_get_prompt_after_deleting_previous(ai_config_runtime: AIConfigRuntime): def test_get_prompt_nonexistent(ai_config_runtime: AIConfigRuntime): - with pytest.raises( - IndexError, match=r"Prompt 'GreetingPrompt' not found in config" - ): + with pytest.raises(IndexError, match=r"Prompt 'GreetingPrompt' not found in config"): ai_config_runtime.get_prompt("GreetingPrompt") @@ -364,11 +351,8 @@ def test_update_model_for_ai_config(ai_config_runtime: AIConfigRuntime): def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): """Test updating model for a specific prompt.""" - #New model name, new settings, no prompt metadata --> update - prompt1 = Prompt( - name="GreetingPrompt", - input="Hello, how are you?" - ) + # New model name, new settings, no prompt metadata --> update + prompt1 = Prompt(name="GreetingPrompt", input="Hello, how are you?") ai_config_runtime.add_prompt(prompt1.name, prompt1) model_name = "testmodel" settings = {"topP": 0.9} @@ -376,26 +360,26 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None assert prompt.metadata.model == ModelMetadata(name=model_name, settings=settings) - + # New model name, no settings --> update name only new_model_name = "testmodel_new_name" ai_config_runtime.update_model(new_model_name, None, prompt1.name) prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None assert prompt.metadata.model == ModelMetadata(name=new_model_name, settings=settings) - + # Same model name, new settings --> update settings only settings = {"topP": 0.9} ai_config_runtime.update_model(new_model_name, settings, prompt1.name) prompt = ai_config_runtime.get_prompt(prompt1.name) assert prompt.metadata is not None assert prompt.metadata.model == ModelMetadata(name=new_model_name, settings=settings) - + # New name, no settings, prompt with model name as string --> update prompt2 = Prompt( name="GreetingPromptModelAsStr", input="Hello, how are you?", - metadata=PromptMetadata(model= "some_random_model"), + metadata=PromptMetadata(model="some_random_model"), ) ai_config_runtime.add_prompt(prompt2.name, prompt2) new_name_again = "testmodel_new_name_model_as_str" @@ -409,7 +393,7 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): prompt3 = Prompt( name="GreetingsNumber3", input="Hello, how are you?", - metadata=PromptMetadata(tags= tags), + metadata=PromptMetadata(tags=tags), ) ai_config_runtime.add_prompt(prompt3.name, prompt3) new_name_3 = "new_name_number_3" @@ -419,6 +403,7 @@ def test_update_model_for_prompt(ai_config_runtime: AIConfigRuntime): assert prompt.metadata.model == ModelMetadata(name=new_name_3, settings={}) assert prompt.metadata.tags == tags + def test_update_model_with_invalid_arguments(ai_config_runtime: AIConfigRuntime): """Test trying to update model with invalid arguments.""" with pytest.raises( @@ -426,7 +411,7 @@ def test_update_model_with_invalid_arguments(ai_config_runtime: AIConfigRuntime) match=r"Cannot update model. Either model name or model settings must be specified.", ): ai_config_runtime.update_model(model_name=None, settings=None) - + with pytest.raises( ValueError, match=r"Cannot update model. There are two things you are trying:", @@ -455,16 +440,11 @@ def test_set_and_delete_metadata_ai_config_prompt(ai_config_runtime: AIConfigRun ai_config_runtime.add_prompt(prompt.name, prompt) ai_config_runtime.set_metadata("testkey", "testvalue", "GreetingPrompt") - assert ( - ai_config_runtime.get_prompt("GreetingPrompt").metadata.testkey == "testvalue" - ) + assert ai_config_runtime.get_prompt("GreetingPrompt").metadata.testkey == "testvalue" ai_config_runtime.delete_metadata("testkey", "GreetingPrompt") - assert ( - hasattr(ai_config_runtime.get_prompt("GreetingPrompt").metadata, "testkey") - is False - ) + assert hasattr(ai_config_runtime.get_prompt("GreetingPrompt").metadata, "testkey") is False def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRuntime): @@ -479,9 +459,7 @@ def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRunt output_type="execute_result", execution_count=0, data="test output", - metadata={ - "raw_response": {"role": "assistant", "content": "test output"} - }, + metadata={"raw_response": {"role": "assistant", "content": "test output"}}, ) ai_config_runtime.add_output("GreetingPrompt", test_result) @@ -491,26 +469,23 @@ def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRunt output_type="execute_result", execution_count=0, data="test output", - metadata={ - "raw_response": {"role": "assistant", "content": "test output for second time"} - }, - ) - + metadata={"raw_response": {"role": "assistant", "content": "test output for second time"}}, + ) + ai_config_runtime.add_output("GreetingPrompt", test_result2) assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 ai_config_runtime.delete_output("GreetingPrompt") assert ai_config_runtime.get_latest_output("GreetingPrompt") == None + def test_add_outputs_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRuntime): """Test adding outputs to an existing prompt without overwriting.""" original_result = ExecuteResult( output_type="execute_result", execution_count=0, data="original result", - metadata={ - "raw_response": {"role": "assistant", "content": "original result"} - }, + metadata={"raw_response": {"role": "assistant", "content": "original result"}}, ) prompt = Prompt( name="GreetingPrompt", @@ -519,39 +494,34 @@ def test_add_outputs_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRun outputs=[original_result], ) ai_config_runtime.add_prompt(prompt.name, prompt) - + assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_result test_result1 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 1", - metadata={ - "raw_response": {"role": "assistant", "content": "test output 1"} - }, + metadata={"raw_response": {"role": "assistant", "content": "test output 1"}}, ) test_result2 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 2", - metadata={ - "raw_response": {"role": "assistant", "content": "test output 2"} - }, + metadata={"raw_response": {"role": "assistant", "content": "test output 2"}}, ) ai_config_runtime.add_outputs("GreetingPrompt", [test_result1, test_result2]) assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 assert prompt.outputs == [original_result, test_result1, test_result2] + def test_add_outputs_existing_prompt_with_overwrite(ai_config_runtime: AIConfigRuntime): """Test adding outputs to an existing prompt with overwriting.""" original_result = ExecuteResult( output_type="execute_result", execution_count=0, data="original result", - metadata={ - "raw_response": {"role": "assistant", "content": "original result"} - }, + metadata={"raw_response": {"role": "assistant", "content": "original result"}}, ) prompt = Prompt( name="GreetingPrompt", @@ -560,30 +530,27 @@ def test_add_outputs_existing_prompt_with_overwrite(ai_config_runtime: AIConfigR outputs=[original_result], ) ai_config_runtime.add_prompt(prompt.name, prompt) - + assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_result test_result1 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 1", - metadata={ - "raw_response": {"role": "assistant", "content": "test output 1"} - }, + metadata={"raw_response": {"role": "assistant", "content": "test output 1"}}, ) test_result2 = ExecuteResult( output_type="execute_result", execution_count=0, data="test output 2", - metadata={ - "raw_response": {"role": "assistant", "content": "test output 2"} - }, + metadata={"raw_response": {"role": "assistant", "content": "test output 2"}}, ) ai_config_runtime.add_outputs("GreetingPrompt", [test_result1, test_result2], True) assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2 assert prompt.outputs == [test_result1, test_result2] + def test_add_undefined_outputs_to_prompt(ai_config_runtime: AIConfigRuntime): """Test for adding undefined outputs to an existing prompt with/without overwriting. Should result in an error.""" prompt = Prompt( @@ -606,15 +573,14 @@ def test_add_undefined_outputs_to_prompt(ai_config_runtime: AIConfigRuntime): ): ai_config_runtime.add_outputs("GreetingPrompt", [], True) + def test_add_output_existing_prompt_overwrite(ai_config_runtime: AIConfigRuntime): """Test adding an output to an existing prompt with overwriting.""" original_output = ExecuteResult( output_type="execute_result", execution_count=0, data="original output", - metadata={ - "raw_response": {"role": "assistant", "content": "original output"} - }, + metadata={"raw_response": {"role": "assistant", "content": "original output"}}, ) prompt = Prompt( name="GreetingPrompt", @@ -629,14 +595,13 @@ def test_add_output_existing_prompt_overwrite(ai_config_runtime: AIConfigRuntime output_type="execute_result", execution_count=0, data="original output", - metadata={ - "raw_response": {"role": "assistant", "content": "original output"} - }, + metadata={"raw_response": {"role": "assistant", "content": "original output"}}, ) # overwrite the original_output ai_config_runtime.add_output("GreetingPrompt", expected_output, True) assert ai_config_runtime.get_latest_output("GreetingPrompt") == expected_output + def test_add_undefined_output_to_prompt(ai_config_runtime: AIConfigRuntime): """Test for adding an undefined output to a prompt with/without overwriting. Should result in an error.""" prompt = Prompt( @@ -657,28 +622,23 @@ def test_add_undefined_output_to_prompt(ai_config_runtime: AIConfigRuntime): match=r"Cannot add output to prompt 'GreetingPrompt'. Output is not defined.", ): ai_config_runtime.add_output("GreetingPrompt", None, True) - + + def test_extract_override_settings(ai_config_runtime: AIConfigRuntime): initial_settings = {"topP": 0.9} # Test Case 1: No global setting, Expect an override - override = extract_override_settings( - ai_config_runtime, initial_settings, "testmodel" - ) + override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") assert override == {"topP": 0.9} # Test Case 2: Global Settings differ, expect override ai_config_runtime.add_model("testmodel", {"topP": 0.8}) - override = extract_override_settings( - ai_config_runtime, initial_settings, "testmodel" - ) + override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") assert override == {"topP": 0.9} # Test Case 3: Global Settings match settings, expect no override ai_config_runtime.update_model(model_name="testmodel", settings={"topP": 0.9}) - override = extract_override_settings( - ai_config_runtime, initial_settings, "testmodel" - ) + override = extract_override_settings(ai_config_runtime, initial_settings, "testmodel") assert override == {} # Test Case 4: Global settings defined and empty settings defined. Expect no override diff --git a/python/tests/test_registry.py b/python/tests/test_registry.py index ef6a4eb23..2b84149dd 100644 --- a/python/tests/test_registry.py +++ b/python/tests/test_registry.py @@ -42,10 +42,7 @@ def test_register_single_model_parser(self): mock_model_parser = MockModelParser() ModelParserRegistry.register_model_parser(mock_model_parser) - assert ( - ModelParserRegistry.get_model_parser(mock_model_parser.id()) - == mock_model_parser - ) + assert ModelParserRegistry.get_model_parser(mock_model_parser.id()) == mock_model_parser def test_register_multiple_model_parsers_with_different_ids(self): # Create model parsers @@ -96,16 +93,12 @@ def test_retrieve_model_parser_for_prompt(self, ai_config_runtime: AIConfigRunti ai_config_runtime.add_prompt(prompt.name, prompt) # Retrieve the model parser for the Prompt - retrieved_parser = ModelParserRegistry.get_model_parser_for_prompt( - prompt, ai_config_runtime - ) + retrieved_parser = ModelParserRegistry.get_model_parser_for_prompt(prompt, ai_config_runtime) # Assert that the retrieved model parser is the same as the registered one assert retrieved_parser == model_parser - def test_retrieve_model_parser_for_prompt_with_nonexistent_model( - self, ai_config_runtime: AIConfigRuntime - ): + def test_retrieve_model_parser_for_prompt_with_nonexistent_model(self, ai_config_runtime: AIConfigRuntime): # Create a Prompt object with a model name that is not registered prompt = Prompt( **{ @@ -118,9 +111,7 @@ def test_retrieve_model_parser_for_prompt_with_nonexistent_model( # Attempt to retrieve a model parser for the Prompt with pytest.raises(KeyError): - ModelParserRegistry.get_model_parser_for_prompt( - prompt, ai_config_runtime - ).id() + ModelParserRegistry.get_model_parser_for_prompt(prompt, ai_config_runtime).id() def test_remove_model_parser(self): # Create a model parser diff --git a/python/tests/test_resolve.py b/python/tests/test_resolve.py index ab4fe5fe2..555643131 100644 --- a/python/tests/test_resolve.py +++ b/python/tests/test_resolve.py @@ -11,16 +11,12 @@ async def test_resolve_default_model_config_with_openai_parser(): Test that the default model config is resolved correctly. `basic_default_model_aiconfig.json` is an aiconfig with 1 prompt that has no settings or model defined besides the default. """ config_relative_path = "aiconfigs/basic_default_model_aiconfig.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) resolved_params = await config.resolve("prompt1") assert resolved_params == { - "messages": [ - {"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"} - ], + "messages": [{"content": "Hi! Tell me 10 cool things to do in NYC.", "role": "user"}], "model": "gpt-3.5-turbo", "temperature": 1, "top_p": 1, diff --git a/python/tests/test_run_config.py b/python/tests/test_run_config.py index 9a44c88c6..99dd4d7c3 100644 --- a/python/tests/test_run_config.py +++ b/python/tests/test_run_config.py @@ -13,13 +13,9 @@ async def test_load_parametrized_data_config(set_temporary_env_vars): Config has 2 prompts. Prompt2 uses prompt1.output in its input. """ - with patch.object( - openai.chat.completions, "create", side_effect=mock_openai_chat_completion - ): + with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion): config_relative_path = "aiconfigs/parametrized_data_config.json" - config_absolute_path = get_absolute_file_path_from_relative( - __file__, config_relative_path - ) + config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path) config = AIConfigRuntime.load(config_absolute_path) prompt1_params = { diff --git a/python/tests/test_util/test_params.py b/python/tests/test_util/test_params.py index b5291f36e..9952c84e1 100644 --- a/python/tests/test_util/test_params.py +++ b/python/tests/test_util/test_params.py @@ -1,9 +1,5 @@ import pytest -from aiconfig.util.params import ( - find_dependencies_in_prompt, - get_dependency_graph, - get_parameters_in_template, -) +from aiconfig.util.params import find_dependencies_in_prompt, get_dependency_graph, get_parameters_in_template from aiconfig.schema import Prompt, PromptMetadata @@ -30,9 +26,7 @@ def test_template_with_no_parameters(): def test_template_with_empty_params(): - result = get_parameters_in_template( - "This is a plain text template with a fake param {{}}." - ) + result = get_parameters_in_template("This is a plain text template with a fake param {{}}.") assert result == {} @@ -81,9 +75,7 @@ def test_find_dependencies_in_prompt(prompt_list_with_5_prompts): prompt_template = "I am referring to {{prompt1.input}} and this {{prompt4.output}}" # only allowed to reference upstream prompts current_prompt_name = "prompt2" - result = find_dependencies_in_prompt( - prompt_template, current_prompt_name, prompt_list_with_5_prompts - ) + result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) assert result == {"prompt1"} @@ -93,9 +85,7 @@ def test_find_dependencies_in_prompt_with_no_dependencies(prompt_list_with_5_pro prompt_template = "I am referring to {{}}" current_prompt_name = "prompt2" - result = find_dependencies_in_prompt( - prompt_template, current_prompt_name, prompt_list_with_5_prompts - ) + result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) assert not result @@ -105,9 +95,7 @@ def test_find_dependencies_in_prompt_with_two_dependencies(prompt_list_with_5_pr prompt_template = "I am referring to {{prompt2.output}} and {{prompt1.output}}" # current_prompt_name = "prompt4" - result = find_dependencies_in_prompt( - prompt_template, current_prompt_name, prompt_list_with_5_prompts - ) + result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) assert result == {"prompt1", "prompt2"} @@ -119,9 +107,7 @@ def test_find_dependencies_in_prompt_with_no_prompt_reference( prompt_template = "I am referring to {{fakeprompt.output}} and {{fakeprompt.output}}" # should return none, no prompt references here current_prompt_name = "prompt4" - result = find_dependencies_in_prompt( - prompt_template, current_prompt_name, prompt_list_with_5_prompts - ) + result = find_dependencies_in_prompt(prompt_template, current_prompt_name, prompt_list_with_5_prompts) assert not result diff --git a/python/tests/util/file_path_utils.py b/python/tests/util/file_path_utils.py index 8228cd4e5..70794ed6a 100644 --- a/python/tests/util/file_path_utils.py +++ b/python/tests/util/file_path_utils.py @@ -1,9 +1,7 @@ import os -def get_absolute_file_path_from_relative( - working_file_path: str, relative_file_path: str -) -> str: +def get_absolute_file_path_from_relative(working_file_path: str, relative_file_path: str) -> str: """ Returns the absolute file path of a file given its relative file path and the file path of the calling file.