-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
basic langchain provider implementation (#664)
* basic langchain provider implementation * langchain provider, minor refactor * update docstring --------- Co-authored-by: Josh Reini <[email protected]>
- Loading branch information
1 parent
5cddd81
commit f29bdc9
Showing
7 changed files
with
175 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
from trulens_eval.feedback.provider.base import Provider | ||
from trulens_eval.feedback.provider.bedrock import Bedrock | ||
from trulens_eval.feedback.provider.hugs import Huggingface | ||
from trulens_eval.feedback.provider.langchain import Langchain | ||
from trulens_eval.feedback.provider.litellm import LiteLLM | ||
from trulens_eval.feedback.provider.openai import OpenAI | ||
|
||
# Public API of the feedback provider package. The rendered diff showed two
# competing definitions (the old one missing "Langchain"); this is the single
# consolidated list matching the imports above.
__all__ = [
    "Provider",
    "OpenAI",
    "Huggingface",
    "LiteLLM",
    "Bedrock",
    "Langchain",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import inspect | ||
import logging | ||
from typing import Any, Callable, Dict, Optional, Union | ||
|
||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.llms.base import BaseLLM | ||
|
||
from trulens_eval.feedback.provider.endpoint.base import Endpoint | ||
from trulens_eval.feedback.provider.endpoint.base import EndpointCallback | ||
from trulens_eval.utils.pyschema import WithClassInfo | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LangchainCallback(EndpointCallback):
    """Callback for calls made through a `LangchainEndpoint`.

    Currently a plain pass-through to `EndpointCallback`; it exists as a
    distinct subclass so langchain-specific response accounting can be
    attached here later without changing the endpoint wiring.
    """

    class Config:
        # Langchain response objects are not pydantic models, so permit
        # arbitrary field types on this callback.
        arbitrary_types_allowed = True

    def handle_classification(self, response: Dict) -> None:
        """Delegate classification bookkeeping to the base callback."""
        super().handle_classification(response)

    def handle_generation(self, response: Any) -> None:
        """Delegate generation bookkeeping to the base callback."""
        super().handle_generation(response)
|
||
|
||
class LangchainEndpoint(Endpoint, WithClassInfo):
    """
    Langchain endpoint.

    Wraps a langchain LLM or chat model so that calls made through it can be
    tracked by trulens_eval's endpoint machinery.
    """

    # The wrapped langchain model; validated in `__init__`.
    chain: Union[BaseLLM, BaseChatModel]

    def __new__(cls, *args, **kwargs):
        # NOTE(review): deliberately skips Endpoint.__new__ and passes the
        # fixed name "langchain" to the grandparent's __new__ — presumably to
        # hit per-name instance handling above Endpoint; confirm against
        # Endpoint.__new__ before changing.
        return super(Endpoint, cls).__new__(cls, name="langchain")

    def handle_wrapped_call(
        self,
        func: Callable,
        bindings: inspect.BoundArguments,
        response: Any,
        callback: Optional[EndpointCallback],
    ) -> None:
        """
        Record a completed call on the wrapped chain.

        TODO: Implement this and wrapped. Response parsing is not implemented
        yet, so generation is recorded with `response=None` until cost/usage
        extraction is added.
        """
        self.global_callback.handle_generation(response=None)
        if callback is not None:
            callback.handle_generation(response=None)

    def __init__(self, chain: Union[BaseLLM, BaseChatModel], *args, **kwargs):
        """
        Args:
            chain: The langchain LLM (`BaseLLM`) or chat model
                (`BaseChatModel`) to wrap.

        Raises:
            ValueError: If `chain` is missing or not a supported langchain
                model type.
        """
        if chain is None:
            raise ValueError("`chain` must be specified.")

        # Idiomatic tuple form instead of two chained isinstance checks.
        if not isinstance(chain, (BaseLLM, BaseChatModel)):
            raise ValueError(
                f"`chain` must be of type {BaseLLM.__name__} or {BaseChatModel.__name__}"
            )

        kwargs["chain"] = chain
        kwargs["name"] = "langchain"
        kwargs["callback_class"] = LangchainCallback
        kwargs["obj"] = self

        super().__init__(*args, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import json | ||
import logging | ||
from typing import Dict, Optional, Sequence, Union | ||
|
||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.llms.base import BaseLLM | ||
|
||
from trulens_eval.feedback.provider.base import LLMProvider | ||
from trulens_eval.feedback.provider.endpoint import LangchainEndpoint | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Langchain(LLMProvider):
    """Out of the box feedback functions using Langchain LLMs and ChatModels"""

    endpoint: LangchainEndpoint

    def __init__(
        self,
        chain: Union[BaseLLM, BaseChatModel],
        model_engine: str = "",
        *args,
        **kwargs
    ):
        """
        Create a Langchain Provider with out of the box feedback functions.

        **Usage:**
        ```
        from trulens_eval.feedback.provider.langchain import Langchain
        from langchain.llms import OpenAI
        gpt3_llm = OpenAI(model="gpt-3.5-turbo-instruct")
        langchain_provider = Langchain(chain = gpt3_llm)
        ```

        Args:
            chain (Union[BaseLLM, BaseChatModel]): Langchain LLMs or chat models
            model_engine (str): Display name for the model; falls back to the
                chain's class name when left empty.
        """
        init_kwargs = dict(kwargs)
        # Default the engine name to the chain's class name when not given.
        init_kwargs["model_engine"] = model_engine or type(chain).__name__
        init_kwargs["endpoint"] = LangchainEndpoint(
            *args, chain=chain, **dict(kwargs)
        )

        super().__init__(**init_kwargs)

    def _create_chat_completion(
        self,
        prompt: Optional[str] = None,
        messages: Optional[Sequence[Dict]] = None,
        **kwargs
    ) -> str:
        """Run the wrapped chain on a prompt or a message sequence.

        Exactly one of `prompt` / `messages` must be provided; messages are
        serialized to a single JSON string since the chain is driven through
        `predict`.
        """
        if prompt is None and messages is None:
            raise ValueError("`prompt` or `messages` must be specified.")

        text = prompt if prompt is not None else json.dumps(messages)
        return self.endpoint.chain.predict(text, **kwargs)