From 534ee8e4305ad7185b848d69491a473cdfa7801d Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Tue, 26 Dec 2023 13:12:18 -0500
Subject: [PATCH] [eng week] Set run output type to be list[Output] instead of Output

This diff is the same as #343, rebased onto main; Sapling decided to create a new PR.

Returning a list makes the run API more scalable, and the Prompt schema already defines `outputs` as `List[Output]` anyway.
---
 cookbooks/HuggingFace/hf.py                        |  4 ++--
 cookbooks/HuggingFace/python/hf.py                 |  4 ++--
 .../remote_inference_client/text_generation.py     |  4 ++--
 extensions/llama/python/llama.py                   |  8 ++++----
 python/src/aiconfig/default_parsers/dalle.py       |  4 ++--
 python/src/aiconfig/default_parsers/hf.py          |  4 ++--
 python/src/aiconfig/default_parsers/openai.py      |  2 +-
 python/src/aiconfig/default_parsers/palm.py        |  4 ++--
 .../default_parsers/parameterized_model_parser.py  | 14 +++++++-------
 9 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py
index a3c55eddf..485dc92e9 100644
--- a/cookbooks/HuggingFace/hf.py
+++ b/cookbooks/HuggingFace/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 # HuggingFace API imports
 from huggingface_hub import InferenceClient
@@ -232,7 +232,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/cookbooks/HuggingFace/python/hf.py b/cookbooks/HuggingFace/python/hf.py
index a3c55eddf..485dc92e9 100644
--- a/cookbooks/HuggingFace/python/hf.py
+++ b/cookbooks/HuggingFace/python/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 # HuggingFace API imports
 from huggingface_hub import InferenceClient
@@ -232,7 +232,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
index 691c0c454..254f36a87 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
@@ -254,7 +254,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/extensions/llama/python/llama.py b/extensions/llama/python/llama.py
index 79c76ffef..4902e6d80 100644
--- a/extensions/llama/python/llama.py
+++ b/extensions/llama/python/llama.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, List
 
 from aiconfig.Config import AIConfigRuntime
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
@@ -61,16 +61,16 @@ async def run_inference(
         aiconfig: AIConfigRuntime,
         options: InferenceOptions | None,
         parameters: dict,
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         resolved = await self.deserialize(prompt, aiconfig, parameters)
         model_input = resolved["model_input"]
         result = await self._run_inference_helper(model_input, options)
 
         self.qa.append((model_input, result.data[0]))
 
-        return result
+        return [result]
 
-    async def _run_inference_helper(self, model_input, options) -> ExecuteResult:
+    async def _run_inference_helper(self, model_input, options) -> List[Output]:
         llm = Llama(self.model_path)
         acc = ""
         stream = options.stream if options else True
diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py
index c8dbbd1ce..aee20b294 100644
--- a/python/src/aiconfig/default_parsers/dalle.py
+++ b/python/src/aiconfig/default_parsers/dalle.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import openai
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
@@ -151,7 +151,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, _options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py
index af224c69e..03800ed36 100644
--- a/python/src/aiconfig/default_parsers/hf.py
+++ b/python/src/aiconfig/default_parsers/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
@@ -254,7 +254,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py
index af61d430f..b9c64cdab 100644
--- a/python/src/aiconfig/default_parsers/openai.py
+++ b/python/src/aiconfig/default_parsers/openai.py
@@ -223,7 +223,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py
index 6aacda1f0..35260fc76 100644
--- a/python/src/aiconfig/default_parsers/palm.py
+++ b/python/src/aiconfig/default_parsers/palm.py
@@ -114,7 +114,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
@@ -313,7 +313,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method
         should perform the actual model inference based on the provided prompt and inference settings.
diff --git a/python/src/aiconfig/default_parsers/parameterized_model_parser.py b/python/src/aiconfig/default_parsers/parameterized_model_parser.py
index 71f0c2750..ff8ac479c 100644
--- a/python/src/aiconfig/default_parsers/parameterized_model_parser.py
+++ b/python/src/aiconfig/default_parsers/parameterized_model_parser.py
@@ -3,12 +3,12 @@
 import typing
 from abc import abstractmethod
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 from aiconfig.model_parser import InferenceOptions, ModelParser
 from aiconfig.util.params import get_dependency_graph, resolve_prompt_string
 
-from aiconfig.schema import AIConfig, ExecuteResult, JSONObject, Prompt, PromptInput
+from aiconfig.schema import AIConfig, ExecuteResult, JSONObject, Output, Prompt, PromptInput
 
 if typing.TYPE_CHECKING:
     from aiconfig import AIConfigRuntime
@@ -36,7 +36,7 @@ async def deserialize(
         """
 
     @abstractmethod
-    async def run_inference(self):
+    async def run_inference(self) -> List[Output]:
         pass
 
     async def run(
@@ -46,7 +46,7 @@ async def run(
         options: Optional[InferenceOptions] = None,
         parameters: Dict = {},
         **kwargs,
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         # maybe use prompt metadata instead of kwargs?
         if kwargs.get("run_with_dependencies", False):
             return await self.run_with_dependencies(
@@ -57,7 +57,7 @@ async def run(
 
     async def run_with_dependencies(
         self, prompt: Prompt, aiconfig: AIConfig, options=None, parameters: Dict = {}
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         """
         Executes the AI model with the resolved dependencies and
         prompt references and returns the API response.
@@ -95,11 +95,11 @@ async def execute_recursive(prompt_name):
             for dependency_prompt_name in dependency_graph[prompt_name]:
                 await execute_recursive(dependency_prompt_name)
 
-            output = await aiconfig.run(prompt_name, parameters, options)
+            outputs = await aiconfig.run(prompt_name, parameters, options)
 
             # Return the output of the original prompt being executed
             if prompt_name == prompt.name:
-                return output
+                return outputs
 
         return await execute_recursive(prompt.name)
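
Note for reviewers: the sketch below is not part of the patch; it only illustrates the caller-side effect of this change. The config file name and prompt name are placeholders, and it assumes the usual AIConfigRuntime.load / run entry points; the point is simply that run(...) now resolves to a list of outputs, so callers iterate or index instead of reading a single Output.

    # Illustrative caller-side sketch (not part of this patch).
    # "travel.aiconfig.json" and "get_activities" are placeholder names.
    import asyncio

    from aiconfig import AIConfigRuntime

    async def main():
        config = AIConfigRuntime.load("travel.aiconfig.json")

        # run(...) now returns List[Output] rather than a single Output
        outputs = await config.run("get_activities")

        for output in outputs:
            # ExecuteResult exposes the model response on .data
            print(output.data)

    asyncio.run(main())

Model parsers that produce a single result, like the llama extension above, simply wrap it in a one-element list (return [result]) to satisfy the new return type.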