Skip to content

Commit

Permalink
basic langchain provider implementation (#664)
Browse files Browse the repository at this point in the history
* basic langchain provider implementation

* langchain provider, minor refactor

* update docstring

---------

Co-authored-by: Josh Reini <[email protected]>
  • Loading branch information
Nvillaluenga and joshreini1 authored Dec 12, 2023
1 parent 5cddd81 commit f29bdc9
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 20 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@ credentials.toml
# Generated Files, not enabling this yet
# trulens_eval/generated_files/*.{py,md,ipynb}

**/*.sqlite
**/*.sqlite

# Virtual environment
.venv
32 changes: 17 additions & 15 deletions trulens_eval/trulens_eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from trulens_eval.feedback import Bedrock
from trulens_eval.feedback import Feedback
from trulens_eval.feedback import Huggingface
from trulens_eval.feedback import Langchain
from trulens_eval.feedback import LiteLLM
from trulens_eval.feedback import OpenAI
from trulens_eval.feedback.provider import Provider
Expand All @@ -98,19 +99,20 @@
from trulens_eval.utils.threading import TP

__all__ = [
'Tru',
'TruBasicApp',
'TruCustomApp',
'TruChain',
'TruLlama',
'Feedback',
'OpenAI',
'LiteLLM',
'Bedrock',
'Huggingface',
'FeedbackMode',
'Provider',
'Query', # to deprecate in 0.3.0
'Select',
'TP'
"Tru",
"TruBasicApp",
"TruCustomApp",
"TruChain",
"TruLlama",
"Feedback",
"OpenAI",
"Langchain",
"LiteLLM",
"Bedrock",
"Huggingface",
"FeedbackMode",
"Provider",
"Query", # to deprecate in 0.3.0
"Select",
"TP",
]
16 changes: 14 additions & 2 deletions trulens_eval/trulens_eval/feedback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,31 @@

# Specific feedback functions:
from trulens_eval.feedback.embeddings import Embeddings

# Main class holding and running feedback functions:
from trulens_eval.feedback.feedback import Feedback
from trulens_eval.feedback.groundedness import Groundedness
from trulens_eval.feedback.groundtruth import GroundTruthAgreement
from trulens_eval.feedback.provider.bedrock import Bedrock
from trulens_eval.feedback.provider.cohere import Cohere

# Providers of feedback functions evaluation:
from trulens_eval.feedback.provider.hugs import Huggingface
from trulens_eval.feedback.provider.langchain import Langchain
from trulens_eval.feedback.provider.litellm import LiteLLM
from trulens_eval.feedback.provider.openai import AzureOpenAI
from trulens_eval.feedback.provider.openai import OpenAI

__all__ = [
'Feedback', 'Embeddings', 'Groundedness', 'GroundTruthAgreement', 'OpenAI',
'AzureOpenAI', 'Huggingface', 'Cohere', 'LiteLLM', 'Bedrock'
"Feedback",
"Embeddings",
"Groundedness",
"GroundTruthAgreement",
"OpenAI",
"AzureOpenAI",
"Huggingface",
"Cohere",
"LiteLLM",
"Bedrock",
"Langchain",
]
10 changes: 9 additions & 1 deletion trulens_eval/trulens_eval/feedback/provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from trulens_eval.feedback.provider.base import Provider
from trulens_eval.feedback.provider.bedrock import Bedrock
from trulens_eval.feedback.provider.hugs import Huggingface
from trulens_eval.feedback.provider.langchain import Langchain
from trulens_eval.feedback.provider.litellm import LiteLLM
from trulens_eval.feedback.provider.openai import OpenAI

__all__ = ['Provider', 'OpenAI', 'Huggingface', 'LiteLLM', 'Bedrock']
__all__ = [
"Provider",
"OpenAI",
"Huggingface",
"LiteLLM",
"Bedrock",
"Langchain",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from trulens_eval.feedback.provider.endpoint.base import Endpoint
from trulens_eval.feedback.provider.endpoint.bedrock import BedrockEndpoint
from trulens_eval.feedback.provider.endpoint.hugs import HuggingfaceEndpoint
from trulens_eval.feedback.provider.endpoint.langchain import LangchainEndpoint
from trulens_eval.feedback.provider.endpoint.litellm import LiteLLMEndpoint
from trulens_eval.feedback.provider.endpoint.openai import OpenAIEndpoint, OpenAIClient
from trulens_eval.feedback.provider.endpoint.openai import OpenAIClient
from trulens_eval.feedback.provider.endpoint.openai import OpenAIEndpoint

__all__ = [
"Endpoint",
Expand All @@ -13,4 +15,5 @@
"LiteLLMEndpoint",
"BedrockEndpoint",
"OpenAIClient",
"LangchainEndpoint",
]
62 changes: 62 additions & 0 deletions trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import inspect
import logging
from typing import Any, Callable, Dict, Optional, Union

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM

from trulens_eval.feedback.provider.endpoint.base import Endpoint
from trulens_eval.feedback.provider.endpoint.base import EndpointCallback
from trulens_eval.utils.pyschema import WithClassInfo

logger = logging.getLogger(__name__)


class LangchainCallback(EndpointCallback):
    """Endpoint callback for langchain-backed calls.

    Currently a thin pass-through: both handlers simply delegate to the
    base `EndpointCallback` implementation.
    """

    class Config:
        # Langchain model objects are not pydantic models; allow them as
        # field values on this callback.
        arbitrary_types_allowed = True

    def handle_generation(self, response: Any) -> None:
        """Delegate generation accounting to the base callback."""
        super().handle_generation(response)

    def handle_classification(self, response: Dict) -> None:
        """Delegate classification accounting to the base callback."""
        super().handle_classification(response)


class LangchainEndpoint(Endpoint, WithClassInfo):
    """
    Langchain endpoint. Wraps a langchain LLM or chat model so calls through
    it can be tracked by trulens callbacks.
    """

    # The wrapped langchain model through which completions are made.
    chain: Union[BaseLLM, BaseChatModel]

    def __new__(cls, *args, **kwargs):
        # NOTE(review): deliberately skips Endpoint.__new__ and pins the
        # endpoint name to "langchain" — presumably to match how other
        # endpoints register by name; confirm against Endpoint.__new__.
        return super(Endpoint, cls).__new__(cls, name="langchain")

    def __init__(self, chain: Union[BaseLLM, BaseChatModel], *args, **kwargs):
        """
        Create a langchain endpoint around the given model.

        Args:
            chain (Union[BaseLLM, BaseChatModel]): The langchain LLM or chat
                model to route calls through.

        Raises:
            ValueError: If `chain` is None or not a langchain LLM/chat model.
        """
        if chain is None:
            raise ValueError("`chain` must be specified.")

        # Idiomatic single isinstance check with a tuple of accepted types.
        if not isinstance(chain, (BaseLLM, BaseChatModel)):
            raise ValueError(
                f"`chain` must be of type {BaseLLM.__name__} or {BaseChatModel.__name__}"
            )

        kwargs["chain"] = chain
        kwargs["name"] = "langchain"
        kwargs["callback_class"] = LangchainCallback
        kwargs["obj"] = self

        super().__init__(*args, **kwargs)

    def handle_wrapped_call(
        self,
        func: Callable,
        bindings: inspect.BoundArguments,
        response: Any,
        callback: Optional[EndpointCallback],
    ) -> None:
        """Record a wrapped call on the global and per-feedback callbacks.

        Note: the actual response is not yet forwarded (passed as None).
        """
        # TODO: Implement this and wrapped
        self.global_callback.handle_generation(response=None)
        if callback is not None:
            callback.handle_generation(response=None)
65 changes: 65 additions & 0 deletions trulens_eval/trulens_eval/feedback/provider/langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import logging
from typing import Dict, Optional, Sequence, Union

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM

from trulens_eval.feedback.provider.base import LLMProvider
from trulens_eval.feedback.provider.endpoint import LangchainEndpoint

logger = logging.getLogger(__name__)


class Langchain(LLMProvider):
    """Out of the box feedback functions using Langchain LLMs and ChatModels"""

    # Endpoint wrapping the user-supplied langchain model.
    endpoint: LangchainEndpoint

    def __init__(
        self,
        chain: Union[BaseLLM, BaseChatModel],
        model_engine: str = "",
        *args,
        **kwargs
    ):
        """
        Create a Langchain Provider with out of the box feedback functions.

        **Usage:**
        ```
        from trulens_eval.feedback.provider.langchain import Langchain
        from langchain.llms import OpenAI
        gpt3_llm = OpenAI(model="gpt-3.5-turbo-instruct")
        langchain_provider = Langchain(chain = gpt3_llm)
        ```

        Args:
            chain (Union[BaseLLM, BaseChatModel]): Langchain LLMs or chat models
            model_engine (str): Display name for the model; defaults to the
                chain's class name when left empty.
        """
        self_kwargs = kwargs.copy()
        # Fall back to the chain's class name when no engine name is given.
        self_kwargs["model_engine"] = model_engine or type(chain).__name__
        self_kwargs["endpoint"] = LangchainEndpoint(
            *args, chain=chain, **kwargs.copy()
        )

        super().__init__(**self_kwargs)

    def _create_chat_completion(
        self,
        prompt: Optional[str] = None,
        messages: Optional[Sequence[Dict]] = None,
        **kwargs
    ) -> str:
        """
        Run a completion through the wrapped chain and return its text.

        Exactly one of `prompt` or `messages` should be given; when only
        `messages` is provided it is serialized to JSON and used as the
        prompt. If both are given, `prompt` takes precedence.

        Raises:
            ValueError: If neither `prompt` nor `messages` is provided.
        """
        if prompt is None:
            if messages is None:
                raise ValueError("`prompt` or `messages` must be specified.")
            # No plain prompt given: serialize the message list instead.
            prompt = json.dumps(messages)

        # Single predict call shared by both input forms (was duplicated
        # in each branch).
        return self.endpoint.chain.predict(prompt, **kwargs)

0 comments on commit f29bdc9

Please sign in to comment.