Skip to content

Commit

Permalink
Simply reuse parent's init function instead of overriding private attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
rogeriochaves committed Jul 25, 2023
1 parent aff3242 commit f95a0aa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
9 changes: 4 additions & 5 deletions litechain/contrib/llms/gpt4all_chain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import AsyncGenerator, Callable, Iterable, Optional, TypeVar
from typing import AsyncGenerator, Callable, Iterable, Optional, TypeVar, cast

from gpt4all import GPT4All

Expand Down Expand Up @@ -65,10 +65,9 @@ def __init__(
n_batch=8,
n_threads: Optional[int] = None,
) -> None:
self.name = name
gpt4all = GPT4All(model, n_threads=n_threads)

async def generate(prompt: str) -> AsyncGenerator[str, None]:
async def generate(prompt: str) -> AsyncGenerator[U, None]:
loop = asyncio.get_event_loop()

def get_outputs() -> Iterable[str]:
Expand All @@ -87,6 +86,6 @@ def get_outputs() -> Iterable[str]:
outputs = await loop.run_in_executor(None, get_outputs)

for output in outputs:
yield output
yield cast(U, output)

self._call = lambda input: generate(call(input))
super().__init__(name, lambda input: generate(call(input)))
17 changes: 8 additions & 9 deletions litechain/contrib/llms/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def __init__(
temperature: Optional[float] = 0,
max_tokens: Optional[int] = None,
) -> None:
self.name = name

async def completion(prompt: str) -> AsyncGenerator[str, None]:
async def completion(prompt: str) -> AsyncGenerator[U, None]:
loop = asyncio.get_event_loop()

def get_completions():
Expand All @@ -88,7 +86,7 @@ def get_completions():
if "text" in output["choices"][0]:
yield output["choices"][0]["text"]

self._call = lambda input: completion(call(input))
super().__init__(name, lambda input: completion(call(input)))


@dataclass
Expand Down Expand Up @@ -263,7 +261,7 @@ class OpenAIChatChain(Chain[T, U]):
"""

def __init__(
self: "OpenAIChatChain[T, Union[OpenAIChatDelta, V]]",
self: "OpenAIChatChain[T, OpenAIChatDelta]",
name: str,
call: Callable[
[T],
Expand All @@ -275,11 +273,9 @@ def __init__(
temperature: Optional[float] = 0,
max_tokens: Optional[int] = None,
) -> None:
self.name = name

async def chat_completion(
messages: List[OpenAIChatMessage],
) -> AsyncGenerator[ChainOutput[Union[OpenAIChatDelta, V]], Any]:
) -> AsyncGenerator[ChainOutput[OpenAIChatDelta], None]:
loop = asyncio.get_event_loop()

def get_completions():
Expand Down Expand Up @@ -349,4 +345,7 @@ def get_completions():
yield self._output_wrap(pending_function_call)
pending_function_call = None

self._call = lambda input: chat_completion(call(input))
super().__init__(
name,
lambda input: cast(AsyncGenerator[U, None], chat_completion(call(input))),
)

0 comments on commit f95a0aa

Please sign in to comment.