[eng week] Set run output type to be list[Output] instead of Output
This diff is the same as #343, rebased onto main; Sapling decided to create a new PR.


This makes the run API more scalable, and the Prompt schema already types `outputs` as `List[Output]` anyway.
Rossdan Craig [email protected] committed Dec 26, 2023
1 parent 7733c2d commit 534ee8e
Showing 9 changed files with 24 additions and 24 deletions.
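
For callers, the visible effect of this change is that `run` now resolves to a list of outputs instead of a single output. A minimal before/after sketch of a call site (the config filename and prompt name below are hypothetical):

import asyncio

from aiconfig import AIConfigRuntime

async def main():
    # Hypothetical config file and prompt name, for illustration only
    config = AIConfigRuntime.load("my_app.aiconfig.json")

    # Before this change: run(...) returned a single Output
    #   output = await config.run("summarize")

    # After this change: run(...) returns List[Output]
    outputs = await config.run("summarize")
    first = outputs[0]  # callers that expect one result take the first element

asyncio.run(main())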
4 changes: 2 additions & 2 deletions cookbooks/HuggingFace/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 # HuggingFace API imports
 from huggingface_hub import InferenceClient
@@ -232,7 +232,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
4 changes: 2 additions & 2 deletions cookbooks/HuggingFace/python/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 # HuggingFace API imports
 from huggingface_hub import InferenceClient
@@ -232,7 +232,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
@@ -254,7 +254,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
8 changes: 4 additions & 4 deletions extensions/llama/python/llama.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, List
 
 from aiconfig.Config import AIConfigRuntime
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
@@ -61,16 +61,16 @@ async def run_inference(
         aiconfig: AIConfigRuntime,
         options: InferenceOptions | None,
         parameters: dict,
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         resolved = await self.deserialize(prompt, aiconfig, parameters)
         model_input = resolved["model_input"]
         result = await self._run_inference_helper(model_input, options)
 
         self.qa.append((model_input, result.data[0]))
 
-        return result
+        return [result]
 
-    async def _run_inference_helper(self, model_input, options) -> ExecuteResult:
+    async def _run_inference_helper(self, model_input, options) -> List[Output]:
         llm = Llama(self.model_path)
         acc = ""
         stream = options.stream if options else True
4 changes: 2 additions & 2 deletions python/src/aiconfig/default_parsers/dalle.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import openai
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
@@ -151,7 +151,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig, _options, parameters
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
4 changes: 2 additions & 2 deletions python/src/aiconfig/default_parsers/hf.py
@@ -1,5 +1,5 @@
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
@@ -254,7 +254,7 @@ async def deserialize(
 
     async def run_inference(
         self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
2 changes: 1 addition & 1 deletion python/src/aiconfig/default_parsers/openai.py
@@ -223,7 +223,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
4 changes: 2 additions & 2 deletions python/src/aiconfig/default_parsers/palm.py
@@ -114,7 +114,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
@@ -313,7 +313,7 @@ async def run_inference(
         aiconfig: "AIConfigRuntime",
         options: InferenceOptions,
         parameters,
-    ) -> Output:
+    ) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
@@ -3,12 +3,12 @@
 
 import typing
 from abc import abstractmethod
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 from aiconfig.model_parser import InferenceOptions, ModelParser
 from aiconfig.util.params import get_dependency_graph, resolve_prompt_string
 
-from aiconfig.schema import AIConfig, ExecuteResult, JSONObject, Prompt, PromptInput
+from aiconfig.schema import AIConfig, ExecuteResult, JSONObject, Output, Prompt, PromptInput
 
 if typing.TYPE_CHECKING:
     from aiconfig import AIConfigRuntime
@@ -36,7 +36,7 @@ async def deserialize(
         """
 
     @abstractmethod
-    async def run_inference(self):
+    async def run_inference(self) -> List[Output]:
        pass
 
     async def run(
@@ -46,7 +46,7 @@ async def run(
         options: Optional[InferenceOptions] = None,
         parameters: Dict = {},
         **kwargs,
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         # maybe use prompt metadata instead of kwargs?
         if kwargs.get("run_with_dependencies", False):
             return await self.run_with_dependencies(
@@ -57,7 +57,7 @@ async def run(
 
     async def run_with_dependencies(
         self, prompt: Prompt, aiconfig: AIConfig, options=None, parameters: Dict = {}
-    ) -> ExecuteResult:
+    ) -> List[Output]:
         """
         Executes the AI model with the resolved dependencies and prompt references and returns the API response.
@@ -95,11 +95,11 @@ async def execute_recursive(prompt_name):
            for dependency_prompt_name in dependency_graph[prompt_name]:
                await execute_recursive(dependency_prompt_name)
 
-           output = await aiconfig.run(prompt_name, parameters, options)
+           outputs = await aiconfig.run(prompt_name, parameters, options)
 
            # Return the output of the original prompt being executed
            if prompt_name == prompt.name:
-               return output
+               return outputs
 
        return await execute_recursive(prompt.name)
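
Under the new contract, a custom parser's `run_inference` returns a list even when the model produces exactly one result, mirroring the `return [result]` change in `llama.py` above. A minimal sketch of a conforming subclass (the class name and echo behavior are invented, the `ExecuteResult` field values are assumptions about the schema, and the other abstract methods are omitted for brevity):

from typing import Any, Dict, List

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import ExecuteResult, Output, Prompt

class EchoParser(ParameterizedModelParser):
    """Hypothetical parser that echoes the prompt back as its only output."""

    async def run_inference(
        self, prompt: Prompt, aiconfig, options: InferenceOptions, parameters: Dict[Any, Any]
    ) -> List[Output]:
        result = ExecuteResult(
            output_type="execute_result",  # assumed field values, per the aiconfig schema
            execution_count=0,
            data=f"echo: {prompt.name}",  # stand-in for real model output
            metadata={},
        )
        # Even a single result is wrapped in a list under the new return type
        return [result]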
