diff --git a/.gitignore b/.gitignore index 21390fcb7..6ec10aee5 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,7 @@ MANIFEST # Notebook tests generate these files: imagenet_class_index.json imagenet_class_index.json.* + +*/.env +*/*.db +*/*.sqlite diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..e7046524e --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +SHELL := /bin/bash +CONDA_ENV := demo3 +CONDA := source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate $(CONDA_ENV) + +format: + $(CONDA); bash format.sh + +lab: + $(CONDA); jupyter lab --ip=0.0.0.0 --no-browser --ServerApp.token=deadbeef diff --git a/README.md b/README.md new file mode 100644 index 000000000..3de58075e --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +# Welcome to TruLens! + +![TruLens](https://www.trulens.org/Assets/image/Neural_Network_Explainability.png) + +TruLens provides a set of tools for developing and monitoring neural nets, including large language models. This includes both TruLens-Eval, for evaluating LLMs and LLM-based applications, and TruLens-Explain, for deep learning explainability. TruLens-Eval and TruLens-Explain are housed in separate packages and can be used independently. + +**TruLens-Eval** contains instrumentation and evaluation tools for large language model (LLM) based applications. It supports the iterative development and monitoring of a wide range of LLM applications by wrapping your application to log key metadata across the entire chain (or off chain if your project does not use chains) on your local machine. Importantly, it also gives you the tools you need to evaluate the quality of your LLM-based applications. + +For more information, see [TruLens-Eval Documentation](trulens_eval/install.md). + +**TruLens-Explain** is a cross-framework library for deep learning explainability. It provides a uniform abstraction layer over a number of different frameworks, including TensorFlow, PyTorch, and Keras, and allows input and internal explanations. + +For more information, see [TruLens-Explain Documentation](trulens_explain/install.md).
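As a quick taste of how TruLens-Eval is used, here is a minimal, hedged sketch condensed from the trulens_eval quickstart added later in this change; it assumes that quickstart's API (`tru_chain.TruChain`, `call_with_record`, `Tru.add_chain`/`add_record`) and valid OpenAI credentials, and is an illustration rather than a definitive reference.

```python
from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

from trulens_eval.tru import Tru
from trulens_eval import tru_chain

tru = Tru()

# A small LangChain chain to instrument (same pattern as the quickstart below).
chain = LLMChain(
    llm=ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0.9),
    prompt=PromptTemplate(
        template="Provide a helpful response to the following: {prompt}",
        input_variables=["prompt"],
    ),
)

# Wrap the chain so that calls to it are instrumented and can be logged and evaluated.
tc = tru_chain.TruChain(chain)

# The wrapped chain returns the LLM response together with a "record" of the execution.
response, record = tc.call_with_record("que hora es?")

# Log the chain and the record to the local database.
tru.add_chain(chain_json=tc.json)
tru.add_record(prompt="que hora es?", response=response['text'], record_json=record)
```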
diff --git a/trulens_explain/docs/Assets/Scripts/app.js b/docs/Assets/Scripts/app.js similarity index 100% rename from trulens_explain/docs/Assets/Scripts/app.js rename to docs/Assets/Scripts/app.js diff --git a/trulens_explain/docs/Assets/favicon.svg b/docs/Assets/favicon.svg similarity index 100% rename from trulens_explain/docs/Assets/favicon.svg rename to docs/Assets/favicon.svg diff --git a/trulens_explain/docs/Assets/favicon/android-chrome-192x192.png b/docs/Assets/favicon/android-chrome-192x192.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/android-chrome-192x192.png rename to docs/Assets/favicon/android-chrome-192x192.png diff --git a/trulens_explain/docs/Assets/favicon/android-chrome-512x512.png b/docs/Assets/favicon/android-chrome-512x512.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/android-chrome-512x512.png rename to docs/Assets/favicon/android-chrome-512x512.png diff --git a/trulens_explain/docs/Assets/favicon/apple-touch-icon.png b/docs/Assets/favicon/apple-touch-icon.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/apple-touch-icon.png rename to docs/Assets/favicon/apple-touch-icon.png diff --git a/trulens_explain/docs/Assets/favicon/browserconfig.xml b/docs/Assets/favicon/browserconfig.xml similarity index 100% rename from trulens_explain/docs/Assets/favicon/browserconfig.xml rename to docs/Assets/favicon/browserconfig.xml diff --git a/trulens_explain/docs/Assets/favicon/favicon-16x16.png b/docs/Assets/favicon/favicon-16x16.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/favicon-16x16.png rename to docs/Assets/favicon/favicon-16x16.png diff --git a/trulens_explain/docs/Assets/favicon/favicon-32x32.png b/docs/Assets/favicon/favicon-32x32.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/favicon-32x32.png rename to docs/Assets/favicon/favicon-32x32.png diff --git a/trulens_explain/docs/Assets/favicon/favicon.ico b/docs/Assets/favicon/favicon.ico similarity index 100% rename from trulens_explain/docs/Assets/favicon/favicon.ico rename to docs/Assets/favicon/favicon.ico diff --git a/trulens_explain/docs/Assets/favicon/mstile-144x144.png b/docs/Assets/favicon/mstile-144x144.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/mstile-144x144.png rename to docs/Assets/favicon/mstile-144x144.png diff --git a/trulens_explain/docs/Assets/favicon/mstile-150x150.png b/docs/Assets/favicon/mstile-150x150.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/mstile-150x150.png rename to docs/Assets/favicon/mstile-150x150.png diff --git a/trulens_explain/docs/Assets/favicon/mstile-310x150.png b/docs/Assets/favicon/mstile-310x150.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/mstile-310x150.png rename to docs/Assets/favicon/mstile-310x150.png diff --git a/trulens_explain/docs/Assets/favicon/mstile-310x310.png b/docs/Assets/favicon/mstile-310x310.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/mstile-310x310.png rename to docs/Assets/favicon/mstile-310x310.png diff --git a/trulens_explain/docs/Assets/favicon/mstile-70x70.png b/docs/Assets/favicon/mstile-70x70.png similarity index 100% rename from trulens_explain/docs/Assets/favicon/mstile-70x70.png rename to docs/Assets/favicon/mstile-70x70.png diff --git a/trulens_explain/docs/Assets/favicon/safari-pinned-tab.svg b/docs/Assets/favicon/safari-pinned-tab.svg similarity index 100% rename from 
trulens_explain/docs/Assets/favicon/safari-pinned-tab.svg rename to docs/Assets/favicon/safari-pinned-tab.svg diff --git a/trulens_explain/docs/Assets/favicon/site.webmanifest b/docs/Assets/favicon/site.webmanifest similarity index 100% rename from trulens_explain/docs/Assets/favicon/site.webmanifest rename to docs/Assets/favicon/site.webmanifest diff --git a/docs/Assets/image/Chain_Explore.png b/docs/Assets/image/Chain_Explore.png new file mode 100644 index 000000000..a0630e7bc Binary files /dev/null and b/docs/Assets/image/Chain_Explore.png differ diff --git a/docs/Assets/image/Evaluations.png b/docs/Assets/image/Evaluations.png new file mode 100644 index 000000000..cbbaac15b Binary files /dev/null and b/docs/Assets/image/Evaluations.png differ diff --git a/docs/Assets/image/Leaderboard.png b/docs/Assets/image/Leaderboard.png new file mode 100644 index 000000000..9a91e7872 Binary files /dev/null and b/docs/Assets/image/Leaderboard.png differ diff --git a/trulens_explain/docs/Assets/image/Neural_Network_Explainability.png b/docs/Assets/image/Neural_Network_Explainability.png similarity index 100% rename from trulens_explain/docs/Assets/image/Neural_Network_Explainability.png rename to docs/Assets/image/Neural_Network_Explainability.png diff --git a/docs/Assets/image/TruLens_Architecture.png b/docs/Assets/image/TruLens_Architecture.png new file mode 100644 index 000000000..c05555bfd Binary files /dev/null and b/docs/Assets/image/TruLens_Architecture.png differ diff --git a/trulens_explain/docs/CNAME b/docs/CNAME similarity index 100% rename from trulens_explain/docs/CNAME rename to docs/CNAME diff --git a/trulens_explain/docs/conf.py b/docs/conf.py similarity index 97% rename from trulens_explain/docs/conf.py rename to docs/conf.py index dd4506585..509f3cd03 100644 --- a/trulens_explain/docs/conf.py +++ b/docs/conf.py @@ -20,8 +20,8 @@ # -- Project information ----------------------------------------------------- project = 'trulens' -copyright = '2020, Klas Leino' -author = 'Klas Leino' +copyright = '2023, TruEra' +author = 'TruEra' # -- General configuration --------------------------------------------------- diff --git a/trulens_explain/docs/img/favicon.ico b/docs/img/favicon.ico similarity index 100% rename from trulens_explain/docs/img/favicon.ico rename to docs/img/favicon.ico diff --git a/trulens_explain/docs/img/squid.png b/docs/img/squid.png similarity index 100% rename from trulens_explain/docs/img/squid.png rename to docs/img/squid.png diff --git a/trulens_explain/docs/index.md b/docs/index.md similarity index 100% rename from trulens_explain/docs/index.md rename to docs/index.md diff --git a/trulens_explain/docs/javascript/config.js b/docs/javascript/config.js similarity index 100% rename from trulens_explain/docs/javascript/config.js rename to docs/javascript/config.js diff --git a/trulens_explain/docs/javascript/tex-mml-chtml-3.0.0.js b/docs/javascript/tex-mml-chtml-3.0.0.js similarity index 100% rename from trulens_explain/docs/javascript/tex-mml-chtml-3.0.0.js rename to docs/javascript/tex-mml-chtml-3.0.0.js diff --git a/trulens_explain/docs/overrides/home.html b/docs/overrides/home.html similarity index 100% rename from trulens_explain/docs/overrides/home.html rename to docs/overrides/home.html diff --git a/trulens_explain/docs/robots.txt b/docs/robots.txt similarity index 100% rename from trulens_explain/docs/robots.txt rename to docs/robots.txt diff --git a/trulens_explain/docs/stylesheets/Base/Mixins/_responsive.scss b/docs/stylesheets/Base/Mixins/_responsive.scss 
similarity index 100% rename from trulens_explain/docs/stylesheets/Base/Mixins/_responsive.scss rename to docs/stylesheets/Base/Mixins/_responsive.scss diff --git a/trulens_explain/docs/stylesheets/Base/_base.scss b/docs/stylesheets/Base/_base.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Base/_base.scss rename to docs/stylesheets/Base/_base.scss diff --git a/trulens_explain/docs/stylesheets/Base/_classes.scss b/docs/stylesheets/Base/_classes.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Base/_classes.scss rename to docs/stylesheets/Base/_classes.scss diff --git a/trulens_explain/docs/stylesheets/Base/_shared.scss b/docs/stylesheets/Base/_shared.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Base/_shared.scss rename to docs/stylesheets/Base/_shared.scss diff --git a/trulens_explain/docs/stylesheets/Base/_typography.scss b/docs/stylesheets/Base/_typography.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Base/_typography.scss rename to docs/stylesheets/Base/_typography.scss diff --git a/trulens_explain/docs/stylesheets/Base/_variables.scss b/docs/stylesheets/Base/_variables.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Base/_variables.scss rename to docs/stylesheets/Base/_variables.scss diff --git a/trulens_explain/docs/stylesheets/Components/_box.scss b/docs/stylesheets/Components/_box.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_box.scss rename to docs/stylesheets/Components/_box.scss diff --git a/trulens_explain/docs/stylesheets/Components/_buttons.scss b/docs/stylesheets/Components/_buttons.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_buttons.scss rename to docs/stylesheets/Components/_buttons.scss diff --git a/trulens_explain/docs/stylesheets/Components/_container.scss b/docs/stylesheets/Components/_container.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_container.scss rename to docs/stylesheets/Components/_container.scss diff --git a/trulens_explain/docs/stylesheets/Components/_list.scss b/docs/stylesheets/Components/_list.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_list.scss rename to docs/stylesheets/Components/_list.scss diff --git a/trulens_explain/docs/stylesheets/Components/_navigation.scss b/docs/stylesheets/Components/_navigation.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_navigation.scss rename to docs/stylesheets/Components/_navigation.scss diff --git a/trulens_explain/docs/stylesheets/Components/_section.scss b/docs/stylesheets/Components/_section.scss similarity index 100% rename from trulens_explain/docs/stylesheets/Components/_section.scss rename to docs/stylesheets/Components/_section.scss diff --git a/trulens_explain/docs/stylesheets/extra.css b/docs/stylesheets/extra.css similarity index 100% rename from trulens_explain/docs/stylesheets/extra.css rename to docs/stylesheets/extra.css diff --git a/trulens_explain/docs/stylesheets/style.css b/docs/stylesheets/style.css similarity index 100% rename from trulens_explain/docs/stylesheets/style.css rename to docs/stylesheets/style.css diff --git a/trulens_explain/docs/stylesheets/style.css.map b/docs/stylesheets/style.css.map similarity index 100% rename from trulens_explain/docs/stylesheets/style.css.map rename to docs/stylesheets/style.css.map diff --git a/trulens_explain/docs/stylesheets/style.scss 
b/docs/stylesheets/style.scss similarity index 100% rename from trulens_explain/docs/stylesheets/style.scss rename to docs/stylesheets/style.scss diff --git a/docs/trulens_eval/Assets/image/Chain_Explore.png b/docs/trulens_eval/Assets/image/Chain_Explore.png new file mode 100644 index 000000000..a0630e7bc Binary files /dev/null and b/docs/trulens_eval/Assets/image/Chain_Explore.png differ diff --git a/docs/trulens_eval/Assets/image/Evaluations.png b/docs/trulens_eval/Assets/image/Evaluations.png new file mode 100644 index 000000000..cbbaac15b Binary files /dev/null and b/docs/trulens_eval/Assets/image/Evaluations.png differ diff --git a/docs/trulens_eval/Assets/image/Leaderboard.png b/docs/trulens_eval/Assets/image/Leaderboard.png new file mode 100644 index 000000000..9a91e7872 Binary files /dev/null and b/docs/trulens_eval/Assets/image/Leaderboard.png differ diff --git a/docs/trulens_eval/Assets/image/TruLens_Architecture.png b/docs/trulens_eval/Assets/image/TruLens_Architecture.png new file mode 100644 index 000000000..c05555bfd Binary files /dev/null and b/docs/trulens_eval/Assets/image/TruLens_Architecture.png differ diff --git a/docs/trulens_eval/api/tru.md b/docs/trulens_eval/api/tru.md new file mode 100644 index 000000000..e902a42c8 --- /dev/null +++ b/docs/trulens_eval/api/tru.md @@ -0,0 +1,3 @@ +# Tru + +::: trulens_eval.trulens_eval.tru.Tru diff --git a/docs/trulens_eval/api/tru_feedback.md b/docs/trulens_eval/api/tru_feedback.md new file mode 100644 index 000000000..c3a35fd9f --- /dev/null +++ b/docs/trulens_eval/api/tru_feedback.md @@ -0,0 +1,3 @@ +# Feedback Functions + +::: trulens_eval.trulens_eval.tru_feedback diff --git a/docs/trulens_eval/api/truchain.md b/docs/trulens_eval/api/truchain.md new file mode 100644 index 000000000..3d13bd810 --- /dev/null +++ b/docs/trulens_eval/api/truchain.md @@ -0,0 +1,3 @@ +# Tru Chain + +::: trulens_eval.trulens_eval.tru_chain \ No newline at end of file diff --git a/docs/trulens_eval/install.md b/docs/trulens_eval/install.md new file mode 100644 index 000000000..f94f7d754 --- /dev/null +++ b/docs/trulens_eval/install.md @@ -0,0 +1,27 @@ +## Getting access to TruLens + +These installation instructions assume that you have conda installed and added to your path. + +1. Create a virtual environment (or modify an existing one). +``` +conda create -n "<my_name>" python=3 # Skip if using existing environment. +conda activate <my_name> +``` + +2. [Pip installation] Install the trulens-eval pip package. +``` +pip install trulens-eval +``` + +3. [Local installation] If you would like to develop or modify trulens, you can download the source code by cloning the trulens repo. +``` +git clone https://github.com/truera/trulens.git +``` + +4. [Local installation] Install the trulens repo. +``` +cd trulens/trulens_eval +pip install -e . +``` + + diff --git a/docs/trulens_eval/quickstart.md b/docs/trulens_eval/quickstart.md new file mode 100644 index 000000000..62831e777 --- /dev/null +++ b/docs/trulens_eval/quickstart.md @@ -0,0 +1,267 @@ +## Quickstart + +### Playground + +To quickly play around with the TruLens Eval library, download this notebook: [trulens_eval_quickstart.ipynb](https://github.com/truera/trulens/blob/main/trulens_eval/trulens_eval_quickstart.ipynb). + + +### Install & Use + +Install trulens-eval from pypi.
+ +``` +pip install trulens-eval +``` + +Imports from langchain to build the app, and from trulens_eval for evaluation: + +```python +from IPython.display import JSON +# imports from langchain to build app +from langchain import PromptTemplate +from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ChatPromptTemplate +from langchain.prompts.chat import HumanMessagePromptTemplate +# imports from trulens to log and get feedback on chain +from trulens_eval.tru import Tru +from trulens_eval import tru_chain +tru = Tru() +``` + +### API Keys + +Our example chat app and feedback functions call external APIs such as OpenAI or Huggingface. You can add keys by setting the environment variables. + +#### In Python + +```python +import os +os.environ["OPENAI_API_KEY"] = "..." +``` +#### In Terminal + +```bash +export OPENAI_API_KEY="..." +``` + +### Create a basic LLM chain to evaluate + +This example uses langchain and OpenAI, but the same process can be followed with any framework and model provider. Once you've created your chain, just call TruChain to wrap it. Doing so allows you to capture the chain metadata for logging. + +```python +full_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template="Provide a helpful response with relevant background information for the following: {prompt}", + input_variables=["prompt"], + ) + ) +chat_prompt_template = ChatPromptTemplate.from_messages([full_prompt]) + +chat = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0.9) + +chain = LLMChain(llm=chat, prompt=chat_prompt_template) + +# wrap with truchain to instrument your chain +tc = tru_chain.TruChain(chain) +``` + +### Set up logging and instrumentation + +Make the first call to your LLM Application. The instrumented chain can operate like the original but can also produce a log or "record" of the chain execution. + +```python +prompt_input = 'que hora es?' +gpt3_response, record = tc.call_with_record(prompt_input) +``` + +We can log the records, but first we need to log the chain itself. + +```python +tru.add_chain(chain_json=tc.json) +``` + +Now we can log the record: +```python +tru.add_record( + prompt=prompt_input, # prompt input + response=gpt3_response['text'], # LLM response + record_json=record # record is returned by the TruChain wrapper +) +``` + +## Evaluate Quality + +Following the request to your app, you can then evaluate LLM quality using feedback functions. This is completed in a sequential call to minimize latency for your application, and evaluations will also be logged to your local machine. + +To get feedback on the quality of your LLM, you can use any of the provided feedback functions or add your own. + +To assess your LLM quality, you can provide the feedback functions to tru.run_feedback_functions() in a list as shown below. Here we'll just add a simple language match checker. +```python +from trulens_eval.tru_feedback import Feedback, Huggingface + +os.environ["HUGGINGFACE_API_KEY"] = "..." + +# Initialize Huggingface-based feedback function collection class: +hugs = Huggingface() + +# Define a language match feedback function using HuggingFace. +f_lang_match = Feedback(hugs.language_match).on( + text1="prompt", text2="response" +) + +# Run feedback functions. This might take a moment if the public api needs to load the language model used by the feedback function.
+feedback_result = f_lang_match.run_on_record( + chain_json=tc.json, record_json=record +) + +JSON(feedback_result) + +# We can also run a collection of feedback functions +feedback_results = tru.run_feedback_functions( + record_json=record, + feedback_functions=[f_lang_match] +) +display(feedback_results) +``` + +After capturing feedback, you can then log it to your local database: +```python +tru.add_feedback(feedback_results) +``` + +### Automatic logging +The above logging and feedback function evaluation steps can be done by TruChain. +```python +truchain = tru_chain.TruChain( + chain, + chain_id='Chain1_ChatApplication', + feedbacks=[f_lang_match], + tru=tru +) +# Note: providing `db: TruDB` causes the above constructor to log the wrapped chain in the database specified. +# Note: any `feedbacks` specified here will be evaluated and logged whenever the chain is used. + +truchain("This will be automatically logged.") +``` + +### Out-of-band Feedback evaluation + +In the above example, the feedback function evaluation is done in the same process as the chain evaluation. The alternative approach is to use the provided persistent evaluator started via `tru.start_deferred_feedback_evaluator`. Then specify the `feedback_mode` for `TruChain` as `deferred` to let the evaluator handle the feedback functions. + +For demonstration purposes, we start the evaluator here but it can be started in another process. +```python +truchain: tru_chain.TruChain = tru_chain.TruChain( + chain, + chain_id='Chain1_ChatApplication', + feedbacks=[f_lang_match], + tru=tru, + feedback_mode="deferred" +) + +tru.start_evaluator() +truchain("This will be logged by deferred evaluator.") +tru.stop_evaluator() +``` + + +### Run the dashboard! +```python +tru.run_dashboard() # open a streamlit app to explore +# tru.stop_dashboard() # stop if needed +``` + +### Chain Leaderboard: Quickly identify quality issues. + +Understand how your LLM application is performing at a glance. Once you've set up logging and evaluation in your application, you can view key performance statistics including cost and average feedback value across all of your LLM apps using the chain leaderboard. As you iterate new versions of your LLM application, you can compare their performance across all of the different quality metrics you've set up. + +Note: Average feedback values are returned and displayed in a range from 0 (worst) to 1 (best). + +![Chain Leaderboard](Assets/image/Leaderboard.png) + +To dive deeper on a particular chain, click "Select Chain". + +### Understand chain performance with Evaluations + +To learn more about the performance of a particular chain or LLM model, we can select it to view its evaluations at the record level. LLM quality is assessed through the use of feedback functions. Feedback functions are extensible methods for determining the quality of LLM responses and can be applied to any downstream LLM task. Out of the box we provide a number of feedback functions for assessing model agreement, sentiment, relevance and more. + +The evaluations tab provides record-level metadata and feedback on the quality of your LLM application. + +![Evaluations](Assets/image/Evaluations.png) + +Click on a record to dive deep into all of the details of your chain stack and underlying LLM, captured by tru_chain. + +![Explore a Chain](Assets/image/Chain_Explore.png) + +If you prefer the raw format, you can quickly get it using the "Display full chain json" or "Display full record json" buttons at the bottom of the page.
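If you would rather get the same raw JSON programmatically (for example in a notebook) than through the dashboard buttons, a small hedged sketch along the following lines should work; it simply reuses the `JSON` display helper imported earlier together with the `record` and `tc.json` objects created above, and is an illustration rather than a documented API.

```python
from IPython.display import JSON

# The record returned by the wrapped chain and the wrapped chain's own JSON
# description are plain JSON-like structures, so they can be rendered directly.
JSON(record)   # full record of the chain execution
JSON(tc.json)  # full description of the instrumented chain
```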
+ +### Out-of-the-box Feedback Functions +The following feedback functions are available out of the box: + +#### Relevance + +This evaluates the *relevance* of the LLM response to the given text by LLM prompting. + +Relevance is currently only available with the OpenAI ChatCompletion API. + +#### Sentiment + +This evaluates the *positive sentiment* of either the prompt or response. + +Sentiment is currently available to use with OpenAI, HuggingFace or Cohere as the model provider. + +* The OpenAI sentiment feedback function prompts a Chat Completion model to rate the sentiment from 1 to 10, and then scales the response down to 0-1. +* The HuggingFace sentiment feedback function returns a raw score from 0 to 1. +* The Cohere sentiment feedback function uses the classification endpoint and a small set of examples stored in feedback_prompts.py to return either a 0 or a 1. + +#### Model Agreement + +Model agreement uses OpenAI to attempt an honest answer at your prompt with system prompts for correctness, and then evaluates the agreement of your LLM response to this model on a scale from 1 to 10. The agreement with each honest bot is then averaged and scaled from 0 to 1. + +#### Language Match + +This evaluates if the language of the prompt and response match. + +Language match is currently only available to use with HuggingFace as the model provider. This feedback function returns a score in the range from 0 to 1, where 1 indicates match and 0 indicates mismatch. + +#### Toxicity + +This evaluates the toxicity of the prompt or response. + +Toxicity is currently only available to be used with HuggingFace, and uses a classification endpoint to return a score from 0 to 1. The feedback function is negated as not_toxicity, and returns a 1 if not toxic and a 0 if toxic. + +#### Moderation + +The OpenAI Moderation API is made available for use as feedback functions. This includes hate, hate/threatening, self-harm, sexual, sexual/minors, violence, and violence/graphic. Each is negated (ex: not_hate) so that a 0 would indicate that the moderation rule is violated. These feedback functions return a score in the range 0 to 1. + +## Adding new feedback functions + +Feedback functions are an extensible framework for evaluating LLMs. You can add your own feedback functions to evaluate the qualities required by your application by updating trulens_eval/tru_feedback.py. If your contributions would be useful for others, we encourage you to contribute to trulens! + +Feedback functions are organized by model provider into Provider classes. + +The process for adding new feedback functions is: +1. Create a new Provider class or locate an existing one that applies to your feedback function. If your feedback function does not rely on a model provider, you can create a standalone class: + +```python +class StandAlone(Provider): + def __init__(self): + pass +``` + +2. Add a new feedback function method to your selected class. Your new method can either take a single text (str) as a parameter or both prompt (str) and response (str). It should return a float between 0 (worst) and 1 (best); a fuller standalone sketch appears after this section. + +```python +def feedback(self, text: str) -> float: + """ + Describe how the model works + + Parameters: + text (str): Text to evaluate. + Can also be prompt (str) and response (str). + + Returns: + float: A value between 0 (worst) and 1 (best).
+ """ + return float +``` diff --git a/docs/trulens_explain/api/attribution.md b/docs/trulens_explain/api/attribution.md new file mode 100644 index 000000000..aac5a7c40 --- /dev/null +++ b/docs/trulens_explain/api/attribution.md @@ -0,0 +1,3 @@ +# Attribution Methods + +::: trulens_explain.trulens.nn.attribution \ No newline at end of file diff --git a/docs/trulens_explain/api/distributions.md b/docs/trulens_explain/api/distributions.md new file mode 100644 index 000000000..3ca0253ee --- /dev/null +++ b/docs/trulens_explain/api/distributions.md @@ -0,0 +1,3 @@ +# Distributions of Interest + +::: trulens_explain.trulens.nn.distributions \ No newline at end of file diff --git a/docs/trulens_explain/api/model_wrappers.md b/docs/trulens_explain/api/model_wrappers.md new file mode 100644 index 000000000..d64f83e8a --- /dev/null +++ b/docs/trulens_explain/api/model_wrappers.md @@ -0,0 +1,3 @@ +# Model Wrappers + +::: trulens_explain.trulens.nn.models \ No newline at end of file diff --git a/docs/trulens_explain/api/quantities.md b/docs/trulens_explain/api/quantities.md new file mode 100644 index 000000000..5f187f7b3 --- /dev/null +++ b/docs/trulens_explain/api/quantities.md @@ -0,0 +1,3 @@ +# Quantities of Interest + +::: trulens_explain.trulens.nn.quantities \ No newline at end of file diff --git a/docs/trulens_explain/api/slices.md b/docs/trulens_explain/api/slices.md new file mode 100644 index 000000000..cc7f17eb2 --- /dev/null +++ b/docs/trulens_explain/api/slices.md @@ -0,0 +1,3 @@ +# Slices + +::: trulens_explain.trulens.nn.slices \ No newline at end of file diff --git a/docs/trulens_explain/api/visualizations.md b/docs/trulens_explain/api/visualizations.md new file mode 100644 index 000000000..6bd9e79e0 --- /dev/null +++ b/docs/trulens_explain/api/visualizations.md @@ -0,0 +1,3 @@ +# Visualization Methods + +::: trulens_explain.trulens.visualizations \ No newline at end of file diff --git a/trulens_explain/docs/attribution_parameterization.md b/docs/trulens_explain/attribution_parameterization.md similarity index 100% rename from trulens_explain/docs/attribution_parameterization.md rename to docs/trulens_explain/attribution_parameterization.md diff --git a/trulens_explain/docs/install.md b/docs/trulens_explain/install.md similarity index 97% rename from trulens_explain/docs/install.md rename to docs/trulens_explain/install.md index 94c182a3f..36ed26628 100644 --- a/trulens_explain/docs/install.md +++ b/docs/trulens_explain/install.md @@ -27,7 +27,7 @@ git clone https://github.com/truera/trulens.git 4. [Locall installation] Install the trulens repo. ``` -cd trulens +cd trulens_explain pip install -e . ``` diff --git a/trulens_explain/docs/quickstart.md b/docs/trulens_explain/quickstart.md similarity index 74% rename from trulens_explain/docs/quickstart.md rename to docs/trulens_explain/quickstart.md index 7a36ccbbc..3e730a23a 100644 --- a/trulens_explain/docs/quickstart.md +++ b/docs/trulens_explain/quickstart.md @@ -8,4 +8,4 @@ To quickly play around with the TruLens library, check out the following CoLab n ### Install & Use -Check out the [Installation](https://truera.github.io/trulens/install/) instructions for information on how to install the library, use it, and contribute. +Check out the [Installation](https://truera.github.io/trulens/trulens_explain/install/) instructions for information on how to install the library, use it, and contribute. 
diff --git a/trulens_explain/docs/welcome.md b/docs/welcome.md similarity index 100% rename from trulens_explain/docs/welcome.md rename to docs/welcome.md diff --git a/trulens_explain/mkdocs.yml b/mkdocs.yml similarity index 64% rename from trulens_explain/mkdocs.yml rename to mkdocs.yml index 8f4f6f772..918deccab 100644 --- a/trulens_explain/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,7 @@ plugins: - "^__init__$" # but always include __init__ modules and methods - "^__call__$" # and __call__ methods watch: - - trulens + - trulens_explain/trulens - search @@ -55,22 +55,32 @@ theme: # text: Source Sans Pro favicon: img/favicon.ico logo: img/squid.png + features: + - navigation.sections nav: - Home: index.md - Welcome to TruLens!: welcome.md - - Installation: install.md - - Quickstart: quickstart.md - - Attributions for Different Use Cases: attribution_parameterization.md - - API Reference: - - Attribution: api/attribution.md - - Models: api/model_wrappers.md - - Slices: api/slices.md - - Quantities: api/quantities.md - - Distributions: api/distributions.md - - Visualizations: api/visualizations.md - - Resources: - - NeurIPS Demo: https://truera.github.io/neurips-demo-2021/ + - Eval: + - Installation: trulens_eval/install.md + - Quickstart: trulens_eval/quickstart.md + - API Reference: + - Tru: trulens_eval/api/tru.md + - TruChain: trulens_eval/api/truchain.md + - Feedback Functions: trulens_eval/api/tru_feedback.md + - Explain: + - Installation: trulens_explain/install.md + - Quickstart: trulens_explain/quickstart.md + - Attributions for Different Use Cases: trulens_explain/attribution_parameterization.md + - API Reference: + - Attribution: trulens_explain/api/attribution.md + - Models: trulens_explain/api/model_wrappers.md + - Slices: trulens_explain/api/slices.md + - Quantities: trulens_explain/api/quantities.md + - Distributions: trulens_explain/api/distributions.md + - Visualizations: trulens_explain/api/visualizations.md +# - Resources: +# - NeurIPS Demo: https://truera.github.io/neurips-demo-2021/ extra_css: - stylesheets/extra.css diff --git a/trulens_eval/.streamlit/config.toml b/trulens_eval/.streamlit/config.toml new file mode 100644 index 000000000..43a50cbc9 --- /dev/null +++ b/trulens_eval/.streamlit/config.toml @@ -0,0 +1,6 @@ +[theme] +primaryColor="#0A2C37" +backgroundColor="#FFFFFF" +secondaryBackgroundColor="F5F5F5" +textColor="#0A2C37" +font="sans serif" diff --git a/trulens_eval/.streamlit/credentials.toml b/trulens_eval/.streamlit/credentials.toml new file mode 100644 index 000000000..b87ca415f --- /dev/null +++ b/trulens_eval/.streamlit/credentials.toml @@ -0,0 +1,2 @@ +[general] +email="" diff --git a/trulens_eval/MANIFEST.in b/trulens_eval/MANIFEST.in new file mode 100644 index 000000000..419ce56ad --- /dev/null +++ b/trulens_eval/MANIFEST.in @@ -0,0 +1 @@ +include trulens_eval/ux/trulens_logo.svg \ No newline at end of file diff --git a/trulens_eval/Makefile b/trulens_eval/Makefile new file mode 100644 index 000000000..2e671047c --- /dev/null +++ b/trulens_eval/Makefile @@ -0,0 +1,26 @@ +SHELL := /bin/bash +CONDA_ENV := py38_trulens_eval +CONDA := source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate $(CONDA_ENV) + +slackbot: + $(CONDA); (PYTHONPATH=. python -u trulens_eval/slackbot.py) + +# 1>&2 > slackbot.log) + +test: + $(CONDA); python -m pytest -s test_tru_chain.py + +format: + $(CONDA); bash format.sh + +lab: + $(CONDA); jupyter lab --ip=0.0.0.0 --no-browser --ServerApp.token=deadbeef + +example_app: + $(CONDA); PYTHONPATH=. 
streamlit run trulens_eval/Example_Application.py + +example_trubot: + $(CONDA); PYTHONPATH=. streamlit run trulens_eval/Example_TruBot.py + +leaderboard: + $(CONDA); PYTHONPATH=. streamlit run trulens_eval/Leaderboard.py diff --git a/trulens_eval/README.md b/trulens_eval/README.md new file mode 100644 index 000000000..b6a0f58cb --- /dev/null +++ b/trulens_eval/README.md @@ -0,0 +1,276 @@ +# Welcome to TruLens-Eval! + +![TruLens](https://www.trulens.org/Assets/image/Neural_Network_Explainability.png) + +Evaluate and track your LLM experiments with TruLens. As you work on your models and prompts, TruLens-Eval supports the iterative development and monitoring of a wide range of LLM applications by wrapping your application to log key metadata across the entire chain (or off chain if your project does not use chains) on your local machine. + +Using feedback functions, you can objectively evaluate the quality of the responses provided by an LLM to your requests. This is completed with minimal latency, as it is achieved in a sequential call for your application, and evaluations are logged to your local machine. Finally, we provide an easy-to-use Streamlit dashboard run locally on your machine for you to better understand your LLM’s performance. + +![Architecture Diagram](../docs/Assets/image/TruLens_Architecture.png) + +# Quick Usage +To quickly play around with the TruLens Eval library, download this notebook: [trulens_eval_quickstart.ipynb](https://github.com/truera/trulens/blob/main/trulens_eval/trulens_eval_quickstart.ipynb). + + + +# Installation and Setup + +Install trulens-eval from pypi. + +``` +pip install trulens-eval +``` + +Imports from langchain to build the app, and from trulens_eval for evaluation: + +```python +from IPython.display import JSON +# imports from langchain to build app +from langchain import PromptTemplate +from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ChatPromptTemplate +from langchain.prompts.chat import HumanMessagePromptTemplate +# imports from trulens to log and get feedback on chain +from trulens_eval.tru import Tru +from trulens_eval import tru_chain +tru = Tru() +``` + +## API Keys + +Our example chat app and feedback functions call external APIs such as OpenAI or Huggingface. You can add keys by setting the environment variables. + +### In Python + +```python +import os +os.environ["OPENAI_API_KEY"] = "..." +``` +### In Terminal + +```bash +export OPENAI_API_KEY="..." +``` + +## Create a basic LLM chain to evaluate + +This example uses langchain and OpenAI, but the same process can be followed with any framework and model provider. Once you've created your chain, just call TruChain to wrap it. Doing so allows you to capture the chain metadata for logging. + +```python +full_prompt = HumanMessagePromptTemplate( + prompt=PromptTemplate( + template="Provide a helpful response with relevant background information for the following: {prompt}", + input_variables=["prompt"], + ) + ) +chat_prompt_template = ChatPromptTemplate.from_messages([full_prompt]) + +chat = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0.9) + +chain = LLMChain(llm=chat, prompt=chat_prompt_template) + +# wrap with truchain to instrument your chain +tc = tru_chain.TruChain(chain) +``` + +## Set up logging and instrumentation + +Make the first call to your LLM Application. The instrumented chain can operate like the original but can also produce a log or "record" of the chain execution. + +```python +prompt_input = 'que hora es?'
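+# call_with_record runs the wrapped chain and also returns a "record" of the execution for logging below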
+gpt3_response, record = tc.call_with_record(prompt_input) +``` + +We can log the records, but first we need to log the chain itself. + +```python +tru.add_chain(chain_json=tc.json) +``` + +Now we can log the record: +```python +tru.add_record( + prompt=prompt_input, # prompt input + response=gpt3_response['text'], # LLM response + record_json=record # record is returned by the TruChain wrapper +) +``` + +# Evaluate Quality + +Following the request to your app, you can then evaluate LLM quality using feedback functions. This is completed in a sequential call to minimize latency for your application, and evaluations will also be logged to your local machine. + +To get feedback on the quality of your LLM, you can use any of the provided feedback functions or add your own. + +To assess your LLM quality, you can provide the feedback functions to tru.run_feedback_functions() in a list as shown below. Here we'll just add a simple language match checker. +```python +from trulens_eval.tru_feedback import Feedback, Huggingface + +os.environ["HUGGINGFACE_API_KEY"] = "..." + +# Initialize Huggingface-based feedback function collection class: +hugs = Huggingface() + +# Define a language match feedback function using HuggingFace. +f_lang_match = Feedback(hugs.language_match).on( + text1="prompt", text2="response" +) + +# Run feedback functions. This might take a moment if the public api needs to load the language model used by the feedback function. +feedback_result = f_lang_match.run_on_record( + chain_json=tc.json, record_json=record +) + +JSON(feedback_result) + +# We can also run a collection of feedback functions +feedback_results = tru.run_feedback_functions( + record_json=record, + feedback_functions=[f_lang_match] +) +display(feedback_results) +``` + +After capturing feedback, you can then log it to your local database: +```python +tru.add_feedback(feedback_results) +``` + +## Automatic logging +The above logging and feedback function evaluation steps can be done by TruChain. +```python +tc = tru_chain.TruChain( + chain, + chain_id='Chain1_ChatApplication', + feedbacks=[f_lang_match], + tru=tru +) +# Note: providing `db: TruDB` causes the above constructor to log the wrapped chain in the database specified. +# Note: any `feedbacks` specified here will be evaluated and logged whenever the chain is used. + +tc("This will be automatically logged.") +``` + +## Out-of-band Feedback evaluation + +In the above example, the feedback function evaluation is done in the same process as the chain evaluation. The alternative approach is to use the provided persistent evaluator started via `tru.start_deferred_feedback_evaluator`. Then specify the `feedback_mode` for `TruChain` as `deferred` to let the evaluator handle the feedback functions. + +For demonstration purposes, we start the evaluator here but it can be started in another process. +```python +tc: tru_chain.TruChain = tru_chain.TruChain( + chain, + chain_id='Chain1_ChatApplication', + feedbacks=[f_lang_match], + tru=tru, + feedback_mode="deferred" +) + +tru.start_evaluator() +tc("This will be logged by deferred evaluator.") +tru.stop_evaluator() +``` + + +## Run the dashboard! +```python +tru.run_dashboard() # open a streamlit app to explore +# tru.stop_dashboard() # stop if needed +``` + +## Chain Leaderboard: Quickly identify quality issues. + +Understand how your LLM application is performing at a glance.
Once you've set up logging and evaluation in your application, you can view key performance statistics including cost and average feedback value across all of your LLM apps using the chain leaderboard. As you iterate new versions of your LLM application, you can compare their performance across all of the different quality metrics you've set up. + +Note: Average feedback values are returned and displayed in a range from 0 (worst) to 1 (best). + +![Chain Leaderboard](../docs/Assets/image/Leaderboard.png) + +To dive deeper on a particular chain, click "Select Chain". + +## Understand chain performance with Evaluations + +To learn more about the performance of a particular chain or LLM model, we can select it to view its evaluations at the record level. LLM quality is assessed through the use of feedback functions. Feedback functions are extensible methods for determining the quality of LLM responses and can be applied to any downstream LLM task. Out of the box we provide a number of feedback functions for assessing model agreement, sentiment, relevance and more. + +The evaluations tab provides record-level metadata and feedback on the quality of your LLM application. + + +![Evaluations](../docs/Assets/image/Evaluations.png) + +Click on a record to dive deep into all of the details of your chain stack and underlying LLM, captured by tru_chain. + +![TruChain Details](../docs/Assets/image/Chain_Explore.png) + +If you prefer the raw format, you can quickly get it using the "Display full chain json" or "Display full record json" buttons at the bottom of the page. + +## Out-of-the-box Feedback Functions +The following feedback functions are available out of the box: + +### Relevance + +This evaluates the *relevance* of the LLM response to the given text by LLM prompting. + +Relevance is currently only available with the OpenAI ChatCompletion API. + +### Sentiment + +This evaluates the *positive sentiment* of either the prompt or response. + +Sentiment is currently available to use with OpenAI, HuggingFace or Cohere as the model provider. + +* The OpenAI sentiment feedback function prompts a Chat Completion model to rate the sentiment from 1 to 10, and then scales the response down to 0-1. +* The HuggingFace sentiment feedback function returns a raw score from 0 to 1. +* The Cohere sentiment feedback function uses the classification endpoint and a small set of examples stored in feedback_prompts.py to return either a 0 or a 1. + +### Model Agreement

Model agreement uses OpenAI to attempt an honest answer at your prompt with system prompts for correctness, and then evaluates the agreement of your LLM response to this model on a scale from 1 to 10. The agreement with each honest bot is then averaged and scaled from 0 to 1. + +### Language Match + +This evaluates if the language of the prompt and response match. + +Language match is currently only available to use with HuggingFace as the model provider. This feedback function returns a score in the range from 0 to 1, where 1 indicates match and 0 indicates mismatch. + +### Toxicity + +This evaluates the toxicity of the prompt or response. + +Toxicity is currently only available to be used with HuggingFace, and uses a classification endpoint to return a score from 0 to 1. The feedback function is negated as not_toxicity, and returns a 1 if not toxic and a 0 if toxic. + +### Moderation + +The OpenAI Moderation API is made available for use as feedback functions. This includes hate, hate/threatening, self-harm, sexual, sexual/minors, violence, and violence/graphic.
Each is negated (ex: not_hate) so that a 0 would indicate that the moderation rule is violated. These feedback functions return a score in the range 0 to 1. + +# Adding new feedback functions + +Feedback functions are an extensible framework for evaluating LLMs. You can add your own feedback functions to evaluate the qualities required by your application by updating trulens_eval/tru_feedback.py. If your contributions would be useful for others, we encourage you to contribute to trulens! + +Feedback functions are organized by model provider into Provider classes. + +The process for adding new feedback functions is: +1. Create a new Provider class or locate an existing one that applies to your feedback function. If your feedback function does not rely on a model provider, you can create a standalone class: + +```python +class StandAlone(Provider): + def __init__(self): + pass +``` + +2. Add a new feedback function method to your selected class. Your new method can either take a single text (str) as a parameter or both promopt (str) and response (str). It should return a float between 0 (worst) and 1 (best). + +```python +def feedback(self, text: str) -> float: + """ + Describe how the model works + + Parameters: + text (str): Text to evaluate. + Can also be prompt (str) and response (str). + + Returns: + float: A value between 0 (worst) and 1 (best). + """ + return float +``` diff --git a/trulens_eval/benchmarking.ipynb b/trulens_eval/benchmarking.ipynb new file mode 100644 index 000000000..868a9d503 --- /dev/null +++ b/trulens_eval/benchmarking.ipynb @@ -0,0 +1,95 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "got OPENAI_API_KEY\n", + "got COHERE_API_KEY\n", + "got KAGGLE_USERNAME\n", + "got KAGGLE_KEY\n", + "got HUGGINGFACE_API_KEY\n", + "got HUGGINGFACE_HEADERS\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jreini/opt/anaconda3/envs/tru_llm/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from keys import *\n", + "import benchmark\n", + "import pandas as pd\n", + "import openai\n", + "openai.api_key = OPENAI_API_KEY\n", + "\n", + "import tru_feedback" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset imdb (/Users/jreini/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n", + "100%|██████████| 3/3 [00:00<00:00, 105.44it/s]\n" + ] + } + ], + "source": [ + "imdb = benchmark.load_data('imdb (binary sentiment)')\n", + "imdb25 = benchmark.sample_data(imdb, 25)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "positive_sentiment_benchmarked = benchmark.rate_limited_benchmark_on_data(imdb25, 'sentiment-positive', rate_limit = 10, evaluation_choice=\"response\", provider=\"openai\", model_engine=\"gpt-3.5-turbo\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.16 ('tru_llm')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d21f7c0bcad57942e36e4792dcf2729b091974a5bb8779ce77766f08b1284f72" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_eval/prompt_response.ipynb b/trulens_eval/prompt_response.ipynb new file mode 100644 index 000000000..2fea86901 --- /dev/null +++ b/trulens_eval/prompt_response.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "got OPENAI_API_KEY\n", + "got HUGGINGFACE_API_KEY\n", + "got COHERE_API_KEY\n", + "got KAGGLE_USERNAME\n", + "got KAGGLE_KEY\n" + ] + } + ], + "source": [ + "from keys import *\n", + "import benchmark\n", + "import pandas as pd\n", + "import openai\n", + "openai.api_key = OPENAI_API_KEY\n", + "\n", + "import tru_feedback" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/ricardoshih/.cache/huggingface/datasets/databricks___json/databricks--databricks-dolly-15k-6e0f9ea7eaa0ee08/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "eac4aea8a3d948b28c93d7a31e26bb3d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
promptresponselabel
1Which is a species of fish? Tope or RopeTope1
2Why can camels survive for long without water?Camels use the fat in their humps to keep them...1
3Alice's parents have three daughters: Amy, Jes...The name of the third daughter is Alice1
4When was Tomoaki Komorida born?Tomoaki Komorida was born on July 10,1981.1
5If I have more pieces at the time of stalemate...No. \\nStalemate is a drawn position. It doesn'...1
6Given a reference text about Lollapalooza, whe...Lollapalooze is an annual musical festival hel...1
7Who gave the UN the land in NY to build their HQJohn D Rockerfeller1
8Why mobile is bad for humanWe are always engaged one phone which is not g...1
9Who was John Moses Browning?John Moses Browning is one of the most well-kn...1
\n", + "" + ], + "text/plain": [ + " prompt \n", + "1 Which is a species of fish? Tope or Rope \\\n", + "2 Why can camels survive for long without water? \n", + "3 Alice's parents have three daughters: Amy, Jes... \n", + "4 When was Tomoaki Komorida born? \n", + "5 If I have more pieces at the time of stalemate... \n", + "6 Given a reference text about Lollapalooza, whe... \n", + "7 Who gave the UN the land in NY to build their HQ \n", + "8 Why mobile is bad for human \n", + "9 Who was John Moses Browning? \n", + "\n", + " response label \n", + "1 Tope 1 \n", + "2 Camels use the fat in their humps to keep them... 1 \n", + "3 The name of the third daughter is Alice 1 \n", + "4 Tomoaki Komorida was born on July 10,1981. 1 \n", + "5 No. \\nStalemate is a drawn position. It doesn'... 1 \n", + "6 Lollapalooze is an annual musical festival hel... 1 \n", + "7 John D Rockerfeller 1 \n", + "8 We are always engaged one phone which is not g... 1 \n", + "9 John Moses Browning is one of the most well-kn... 1 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'prompt' in data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "RateLimitError", + "evalue": "That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 05ca742434ae11c87334ef50ec81b3db in your message.)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRateLimitError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m factagreement_benchmarked \u001b[38;5;241m=\u001b[39m \u001b[43mbenchmark\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbenchmark_on_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfactagreement\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevaluation_choice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mresponse\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprovider\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mopenai\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_engine\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgpt-3.5-turbo\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/llm-experiments/benchmark.py:109\u001b[0m, in \u001b[0;36mbenchmark_on_data\u001b[0;34m(data, feedback_function_name, evaluation_choice, provider, model_engine)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized feedback_function_name. 
Please use one of {list(tru_feedback.FEEDBACK_FUNCTIONS.keys())} \")",
+        "[... intermediate traceback frames and ANSI color codes elided: pandas DataFrame.apply -> benchmark_on_data.<locals>.<lambda> -> tru_feedback.get_factagreement -> tru_feedback._get_answer_agreement -> openai.ChatCompletion.create -> openai APIRequestor._interpret_response_line ...]",
+        "RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. 
(Please include the request ID 05ca742434ae11c87334ef50ec81b3db in your message.)" + ] + } + ], + "source": [ + "factagreement_benchmarked = benchmark.benchmark_on_data(data, 'factagreement', evaluation_choice=\"response\", provider=\"openai\", model_engine=\"gpt-3.5-turbo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "3_10", + "language": "python", + "name": "myenv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "vscode": { + "interpreter": { + "hash": "d21f7c0bcad57942e36e4792dcf2729b091974a5bb8779ce77766f08b1284f72" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_eval/pytest.ini b/trulens_eval/pytest.ini new file mode 100644 index 000000000..200d45814 --- /dev/null +++ b/trulens_eval/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + slow: marks tests as slow + nonfree: marks tests that may incur API costs diff --git a/trulens_eval/requirements.txt b/trulens_eval/requirements.txt new file mode 100644 index 000000000..df4f1d234 --- /dev/null +++ b/trulens_eval/requirements.txt @@ -0,0 +1,34 @@ +# common requirements +python-dotenv +langchain +typing-inspect==0.8.0 # langchain with python < 3.9 fix +typing_extensions==4.5.0 # langchain with python < 3.9 fix +# slack bot and its indexing requirements: +sentencepiece +transformers +pyllama +tokenizers +protobuf +accelerate +openai +pinecone-client +tiktoken +slack_bolt +requests +beautifulsoup4 +unstructured +pypdf +pdfminer.six +# TruChain requirements: +tinydb +pydantic +merkle_json +# app requirements: +streamlit +streamlit-aggrid +streamlit-extras +datasets +cohere +kaggle +watchdog +millify diff --git a/trulens_eval/setup.cfg b/trulens_eval/setup.cfg new file mode 100644 index 000000000..7107cc1a4 --- /dev/null +++ b/trulens_eval/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = trulens_eval +version = attr: trulens_eval.__version__ +url = https://www.trulens.org +license = MIT +author = Truera Inc +author_email = all@truera.com +description = Library with langchain instrumentation to evaluate LLM based applications. 
+long_description = file: README.md +long_description_content_type = text/markdown +classifiers = + Programming Language :: Python :: 3 + Operating System :: OS Independent + Development Status :: 3 - Alpha + License :: OSI Approved :: MIT License diff --git a/trulens_eval/setup.py b/trulens_eval/setup.py new file mode 100644 index 000000000..46545a99a --- /dev/null +++ b/trulens_eval/setup.py @@ -0,0 +1,33 @@ +from setuptools import find_namespace_packages +from setuptools import setup + +setup( + name="trulens_eval", + include_package_data=True, + packages=find_namespace_packages( + include=["trulens_eval", "trulens_eval.*"] + ), + python_requires='>=3.8', + install_requires=[ + 'cohere>=4.4.1', + 'datasets>=2.12.0', + 'python-dotenv>=1.0.0', + 'kaggle>=1.5.13', + 'langchain>=0.0.170', + 'merkle-json>=1.0.0', + 'millify>=0.1.1', + 'openai>=0.27.6', + 'pinecone-client>=2.2.1', + 'pydantic>=1.10.7', + 'requests>=2.30.0', + 'slack-bolt>=1.18.0', + 'slack-sdk>=3.21.3', + 'streamlit>=1.22.0', + 'streamlit-aggrid>=0.3.4.post3', + 'streamlit-extras>=0.2.7', + 'tinydb>=4.7.1', + 'transformers>=4.10.0', + 'typing-inspect==0.8.0', # langchain with python < 3.9 fix + 'typing_extensions==4.5.0' # langchain with python < 3.9 fix + ], +) diff --git a/trulens_eval/slackbot.ipynb b/trulens_eval/slackbot.ipynb new file mode 100644 index 000000000..eb7250485 --- /dev/null +++ b/trulens_eval/slackbot.ipynb @@ -0,0 +1,784 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Slackbot-related work" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.keys import *\n", + "from trulens_eval.slackbot import get_or_make_chain, get_answer\n", + "from trulens_eval.util import TP\n", + "from trulens_eval import Tru\n", + "from trulens_eval.tru_feedback import Huggingface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Tru().start_dashboard(_dev=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "thread = Tru().start_evaluator()\n", + "# Tru().stop_evaluator()\n", + "# Tru().reset_database()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selectors = [0,1,3,4]\n", + "messages = [\"Who is Shayak?\", \"Wer ist Shayak?\", \"Kim jest Shayak?\", \"¿Quién es Shayak?\", \"Was ist QII?\", \"Co jest QII?\"]\n", + "\n", + "# selectors = selectors[0:2]\n", + "# messages = messages[0:2]\n", + "\n", + "def test_bot(selector, question):\n", + " print(selector, question)\n", + " chain = get_or_make_chain(cid=question + str(selector), selector=selector)\n", + " answer = get_answer(chain=chain, question=question)\n", + " return answer\n", + "\n", + "results = []\n", + "\n", + "for s in selectors:\n", + " for m in messages:\n", + " results.append(TP().promise(test_bot, selector=s, question=m))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TP().finish()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TP().promises.qsize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "for res in results:\n", + " print(res.get())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.tru_db import Record, TruDB, LocalSQLite, Chain, Query\n", + "from trulens_eval import Tru\n", + "from trulens_eval.util import TP\n", + "from trulens_eval import tru_feedback\n", + "from IPython.display import JSON\n", + "from ipywidgets import widgets\n", + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db = tru.db" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conn, c = db._connect()\n", + "# c.execute(\"select * from records\")\n", + "# rows = c.fetchall()\n", + "c.execute(\"delete from records where chain_id='2/relevance_prompt'\")\n", + "db._close(conn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conn, c = tru.db._connect()\n", + "c.execute(\"select * from records\")\n", + "rows = c.fetchall()\n", + "tru.db._close(conn)\n", + "rows" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.provider_apis import Endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.callbacks import get_openai_callback\n", + "from langchain.chains import ConversationalRetrievalChain\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.llms.base import BaseLLM\n", + "from langchain.llms import OpenAI\n", + "from langchain.memory import ConversationSummaryBufferMemory\n", + "from langchain.vectorstores import Pinecone\n", + "import pinecone\n", + "\n", + "from trulens_eval import tru\n", + "from trulens_eval import tru_chain\n", + "from trulens_eval.keys import *\n", + "from trulens_eval.keys import PINECONE_API_KEY\n", + "from trulens_eval.keys import PINECONE_ENV\n", + "\n", + "# Set up GPT-3 model\n", + "model_name = \"gpt-3.5-turbo\"\n", + "\n", + "chain_id = \"TruBot_relevance\"\n", + "\n", + "# Pinecone configuration.\n", + "pinecone.init(\n", + " api_key=PINECONE_API_KEY, # find at app.pinecone.io\n", + " environment=PINECONE_ENV # next to api key in console\n", + ")\n", + "\n", + "identity = lambda h: h\n", + "\n", + "# Embedding needed for Pinecone vector db.\n", + "embedding = OpenAIEmbeddings(model='text-embedding-ada-002') # 1536 dims\n", + "docsearch = Pinecone.from_existing_index(\n", + " index_name=\"llmdemo\", embedding=embedding\n", + ")\n", + "retriever = docsearch.as_retriever()\n", + "\n", + "# LLM for completing prompts, and other tasks.\n", + "llm = OpenAI(temperature=0, max_tokens=128)\n", + "\n", + "# Conversation memory.\n", + "memory = ConversationSummaryBufferMemory(\n", + " max_token_limit=650,\n", + " llm=llm,\n", + " memory_key=\"chat_history\",\n", + " output_key='answer'\n", + ")\n", + "\n", + "# Conversational chain puts it all together.\n", + "chain = ConversationalRetrievalChain.from_llm(\n", + " llm=llm,\n", + " retriever=retriever,\n", + " return_source_documents=True,\n", + " memory=memory,\n", + " get_chat_history=identity,\n", + " 
max_tokens_limit=4096\n", + ")\n", + "\n", + "\"\"\"\n", + "# Language mismatch fix:\n", + "chain.combine_docs_chain.llm_chain.prompt.template = \\\n", + " \"Use the following pieces of context to answer the question at the end \" \\\n", + " \"in the same language as the question. If you don't know the answer, \" \\\n", + " \"just say that you don't know, don't try to make up an answer.\\n\\n\" \\\n", + " \"{context}\\n\\n\" \\\n", + " \"Question: {question}\\n\" \\\n", + " \"Helpful Answer: \"\n", + "\"\"\"\n", + "\n", + "# Contexts fix\n", + "chain.combine_docs_chain.llm_chain.prompt.template = \\\n", + " \"Use only the relevant contexts to answer the question at the end \" \\\n", + " \". Some pieces of context may not be relevant. If you don't know the answer, \" \\\n", + " \"just say that you don't know, don't try to make up an answer.\\n\\n\" \\\n", + " \"Contexts: \\n{context}\\n\\n\" \\\n", + " \"Question: {question}\\n\" \\\n", + " \"Helpful Answer: \"\n", + "\n", + "chain.combine_docs_chain.document_prompt.template=\"\\tContext: {page_content}\"\n", + "\n", + "# Trulens instrumentation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hugs = tru_feedback.Huggingface()\n", + "openai = tru_feedback.OpenAI()\n", + "\n", + "f_toxic = tru_feedback.Feedback(hugs.not_toxic).on_response()\n", + "f_lang_match = tru_feedback.Feedback(hugs.language_match).on(text1=\"prompt\", text2=\"response\")\n", + "f_relevance = tru_feedback.Feedback(openai.relevance).on(prompt=\"input\", response=\"output\")\n", + "f_qs_relevance = tru_feedback.Feedback(openai.qs_relevance) \\\n", + " .on(question=\"input\", statement=Record.chain.combine_docs_chain._call.args.inputs.input_documents) \\\n", + " .on_multiple(multiarg=\"statement\", each_query=Query().page_content)\n", + "\n", + "# feedbacks = tru.run_feedback_functions(chain=tc, record=record, feedback_functions=[f_qs_relevance, f_toxic, f_lang_match, f_relevance])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def filter_statements(question: str, statements: str, threshold: float = 0.5):\n", + " promises = []\n", + " for statement in statements:\n", + " promises.append((statement, TP().promise(openai.qs_relevance, question=question, statement=statement)))\n", + " \n", + " results = []\n", + " for statement, promise in promises:\n", + " results.append((statement, promise.get()))\n", + "\n", + " results = map(lambda sr: sr[0], filter(lambda sr: sr[1] >= threshold, results))\n", + "\n", + " return list(results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "good = filter_statements(question=\"Who is Shayak?\", statements=[\"Piotr is a person.\", \"Shayak is a person.\", \"Shammek is a person.\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever.get_relevant_documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# help(retriever)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test = WithFilterDocuments.of_vectorstoreretriever(retriever=retriever, filter_func=ffunc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test.get_relevant_documents(\"Who is Shayak?\")" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tc = tru_chain.TruChain(\n", + " chain,\n", + " chain_id=chain_id,\n", + " feedbacks=[f_toxic, f_lang_match, f_relevance, f_qs_relevance],\n", + " db=tru.lms\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res, record = tc.call_with_record(\"What is TruEra?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "display(record)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tc.json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.tru_feedback import Feedback" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "obj = f_qs_relevance.json\n", + "display(obj)\n", + "#display(Feedback.of_json(obj).to_json())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from tinydb import Query\n", + "from trulens_eval.tru_db import Query, Record" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dir(Record)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feedbacks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for doc in TruDB.project(query=Record.chain.combine_docs_chain._call.args.inputs.input_documents, obj=record):\n", + " print(doc)\n", + " content = TruDB.project(query=Record.page_content, obj=doc)\n", + " print(content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feedbacks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# e = Endpoint(name=\"openai\", rpm=120)\n", + "# print(e.pace.qsize())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru.endpoint_openai.tqdm.display()\n", + "i = 0\n", + "while True:\n", + " # print(e.pace.qsize())\n", + " tru.endpoint_openai.pace_me()\n", + " # print(i)\n", + " i+=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# tru_feedback.huggingface_language_match(prompt=\"Hello there?\", response=\"How are you?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# db = LocalTinyDB(\"slackbot.json\")\n", + "#tru.init_db(\"slackbot.sql\")\n", + "#db = LocalSQLite(\"slackbot.sql.db\")\n", + "db = LocalSQLite()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df, dff = db.get_records_and_feedback(chain_ids=[])" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import PrettyPrinter\n", + "pp = PrettyPrinter(compact=True)\n", + "\n", + "for i, row in df.iterrows():\n", + " \n", + " display(widgets.HTML(f\"Question: {row.input}\"))\n", + " \n", + " display(widgets.HTML(f\"Answer: {row.output}\"))\n", + " \n", + " details = json.loads(eval(row.details))\n", + "\n", + " display(widgets.HTML(str(details['chain']['combine_docs_chain']['llm_chain']['prompt']['template'])))\n", + " \n", + " for doc in details['chain']['combine_docs_chain']['_call']['args']['inputs']['input_documents']:\n", + " display(widgets.HTML(f\"\"\"\n", + "
\n", + " Context chunk: {doc['page_content']}\n", + " \"\"\"))\n", + "\n", + " \"\"\"
\n", + "\n", + " source: {doc['metadata']['source']}\n", + "
\"\"\"\n", + "\n", + " print()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = db.select(\n", + " Record,\n", + " Record.record_id,\n", + " Record.chain_id,\n", + " Record.chain._call.args.inputs.question,\n", + " Record.chain._call.rets.answer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for row_id, row in df.iterrows():\n", + " record_id = row.record_id\n", + " chain_id = row.chain_id\n", + "\n", + " main_question = row['Record.chain._call.args.inputs.question']\n", + " main_answer = row['Record.chain._call.rets.answer']\n", + "\n", + " print(chain_id, record_id, main_question)\n", + " \"\"\"\n", + " \n", + " print(question, answer)\n", + "\n", + " # Run feedback function and get value\n", + "\n", + " feedback = tru.run_feedback_function(\n", + " main_question, main_answer, [\n", + " tru_feedback.get_not_hate_function(\n", + " evaluation_choice='prompt',\n", + " provider='openai',\n", + " model_engine='moderation'\n", + " ),\n", + " tru_feedback.get_sentimentpositive_function(\n", + " evaluation_choice='response',\n", + " provider='openai',\n", + " model_engine='gpt-3.5-turbo'\n", + " ),\n", + " tru_feedback.get_relevance_function(\n", + " evaluation_choice='both',\n", + " provider='openai',\n", + " model_engine='gpt-3.5-turbo'\n", + " )\n", + " ]\n", + " )\n", + " print(f\"will insert overall feedback for chain {chain_id}, record {record_id}\")\n", + " db.insert_feedback(record_id=record_id, chain_id=chain_id, feedback=feedback)\n", + " \"\"\"\n", + " \n", + " # display(JSON(row.Record))\n", + " # print(row.Record['chain'])\n", + "\n", + " model_name = \"gpt-3.5-turbo\"\n", + "\n", + " \"\"\"\n", + " for page in TruDB.project(query=Record.chain.combine_docs_chain._call.args.inputs.input_documents, obj=row.Record):\n", + " answer = page['page_content']\n", + " feedback = tru.run_feedback_function(\n", + " main_question,\n", + " answer,\n", + "\t [\n", + " tru_feedback.get_qs_relevance_function(\n", + " evaluation_choice='prompt',\n", + " provider='openai',\n", + " model_engine=model_name\n", + " )]\n", + " )\n", + " db.insert_feedback(record_id=record_id, chain_id=chain_id, feedback=feedback)\n", + "\n", + " \"\"\"\n", + "\n", + " feedback = tru.run_feedback_function(\n", + " main_question, main_answer, [\n", + " tru_feedback.get_language_match_function(\n", + " provider='huggingface'\n", + " )\n", + " ]\n", + " )\n", + " print(f\"will insert language match feedback for chain {chain_id}, record {record_id}\")\n", + " db.insert_feedback(record_id=record_id, chain_id=chain_id, feedback=feedback)\n", + "\n", + " # feedback = tru.run_feedback_function(\n", + "\n", + " #for leaf in TruDB.leafs(row.Record):\n", + " # print(leaf)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "row.record_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feedback = {'openai_hate_function': 1.849137515819166e-05,\n", + " 'openai_sentimentpositive_feedback_function': 1,\n", + " 'openai_relevance_function': 10}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "db.insert_feedback(2, feedback)" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [ + "db.select(Record, table=db.feedbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"gpt-3.5-turbo\"\n", + "feedback = tru.run_feedback_function(\n", + " \"Who is Piotr?\",\n", + " \"Piotr Mardziel works on transparency and accountability in machine learning with applications to security, privacy, and fairness. He holds Bachelor’s and Master’s degrees from the Worcester Polytechnic Institute and a PhD in computer science from University of Maryland, College Park. He has conducted post-doctoral research at Carnegie Mellon University, as well as taught classes in trustworthy machine learning at Stanford University and machine learning privacy and security at Carnegie Mellon University.\",\n", + "\t [\n", + " tru_feedback.get_qs_relevance_function(\n", + " evaluation_choice='prompt',\n", + " provider='openai',\n", + " model_engine=model_name\n", + " )])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feedback" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "demo3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_eval/trulens_eval/.env.example b/trulens_eval/trulens_eval/.env.example new file mode 100644 index 000000000..551eb4f63 --- /dev/null +++ b/trulens_eval/trulens_eval/.env.example @@ -0,0 +1,19 @@ +# Once you add your API keys below, make sure to not share it with anyone! The API key should remain private. + +# models + +## openai +OPENAI_API_KEY = "" + +## cohere +COHERE_API_KEY = "" + +## huggingface: +HUGGINGFACE_API_KEY = "" +HUGGINGFACE_HEADERS = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"} + +# benchmarking data + +## kaggle +KAGGLE_USERNAME = "" +KAGGLE_KEY = "" diff --git a/trulens_eval/trulens_eval/Example_TruBot.py b/trulens_eval/trulens_eval/Example_TruBot.py new file mode 100644 index 000000000..d23833a84 --- /dev/null +++ b/trulens_eval/trulens_eval/Example_TruBot.py @@ -0,0 +1,160 @@ +import os + +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +from langchain.callbacks import get_openai_callback +from langchain.chains import ConversationalRetrievalChain +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.llms import OpenAI +from langchain.memory import ConversationSummaryBufferMemory +from langchain.vectorstores import Pinecone +import numpy as np +import pinecone +import streamlit as st + +from trulens_eval import tru +from trulens_eval import tru_chain +from trulens_eval import tru_feedback +from trulens_eval.keys import * +from trulens_eval.keys import PINECONE_API_KEY +from trulens_eval.keys import PINECONE_ENV +from trulens_eval.tru_db import Record +from trulens_eval.tru_feedback import Feedback + +# Set up GPT-3 model +model_name = "gpt-3.5-turbo" + +chain_id = "TruBot" +# chain_id = "TruBot_langprompt" +# chain_id = "TruBot_relevance" + +# Pinecone configuration. 
+pinecone.init( + api_key=PINECONE_API_KEY, # find at app.pinecone.io + environment=PINECONE_ENV # next to api key in console +) + +identity = lambda h: h + +hugs = tru_feedback.Huggingface() +openai = tru_feedback.OpenAI() + +f_lang_match = Feedback(hugs.language_match).on( + text1="prompt", text2="response" +) + +f_qa_relevance = Feedback(openai.relevance).on( + prompt="input", response="output" +) + +f_qs_relevance = Feedback(openai.qs_relevance).on( + question="input", + statement=Record.chain.combine_docs_chain._call.args.inputs.input_documents +).on_multiple( + multiarg="statement", each_query=Record.page_content, agg=np.min +) + + +# @st.cache_data +def generate_response(prompt): + # Embedding needed for Pinecone vector db. + embedding = OpenAIEmbeddings(model='text-embedding-ada-002') # 1536 dims + docsearch = Pinecone.from_existing_index( + index_name="llmdemo", embedding=embedding + ) + retriever = docsearch.as_retriever() + + # LLM for completing prompts, and other tasks. + llm = OpenAI(temperature=0, max_tokens=128) + + # Conversation memory. + memory = ConversationSummaryBufferMemory( + max_token_limit=650, + llm=llm, + memory_key="chat_history", + output_key='answer' + ) + + # Conversational chain puts it all together. + chain = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=retriever, + return_source_documents=True, + memory=memory, + get_chat_history=identity, + max_tokens_limit=4096 + ) + + # Language mismatch fix: + if "langprompt" in chain_id: + chain.combine_docs_chain.llm_chain.prompt.template = \ + "Use the following pieces of CONTEXT to answer the question at the end " \ + "in the same language as the question. If you don't know the answer, " \ + "just say that you don't know, don't try to make up an answer.\n" \ + "\n" \ + "CONTEXT: {context}\n" \ + "\n" \ + "Question: {question}\n" \ + "Helpful Answer: " + + elif "relevance" in chain_id: + # Contexts fix + chain.combine_docs_chain.llm_chain.prompt.template = \ + "Use only the relevant contexts to answer the question at the end " \ + ". Some pieces of context may not be relevant. If you don't know the answer, " \ + "just say that you don't know, don't try to make up an answer.\n" \ + "\n" \ + "Contexts: \n" \ + "{context}\n" \ + "\n" \ + "Question: {question}\n" \ + "Helpful Answer: " + + # space is important + + chain.combine_docs_chain.document_prompt.template = "\tContext: {page_content}" + + # Trulens instrumentation. 
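# A small, self-contained illustration (hypothetical documents and question) of how the
# "relevance" prompt override above is assembled at query time: each retrieved document
# is rendered with document_prompt ("\tContext: {page_content}") and the joined result
# fills {context} in the combine_docs_chain template. The join below uses a single
# newline for brevity; LangChain's own document separator may differ.

document_prompt = "\tContext: {page_content}"
template = (
    "Use only the relevant contexts to answer the question at the end "
    ". Some pieces of context may not be relevant. If you don't know the answer, "
    "just say that you don't know, don't try to make up an answer.\n"
    "\n"
    "Contexts: \n"
    "{context}\n"
    "\n"
    "Question: {question}\n"
    "Helpful Answer: "
)

pages = [
    "TruEra provides tools for evaluating and monitoring ML models.",
    "TruLens adds instrumentation and feedback functions for LLM applications.",
]
context = "\n".join(document_prompt.format(page_content=page) for page in pages)

print(template.format(context=context, question="What is TruEra?"))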
+ tc = tru_chain.TruChain(chain, chain_id=chain_id) + + return tc, tc.call_with_record(dict(question=prompt)) + + +# Set up Streamlit app +st.title("TruBot") +user_input = st.text_input("Ask a question about TruEra") + +if user_input: + # Generate GPT-3 response + prompt_input = user_input + # add context manager to capture tokens and cost of the chain + + with get_openai_callback() as cb: + chain, (response, record) = generate_response(prompt_input) + total_tokens = cb.total_tokens + total_cost = cb.total_cost + + answer = response['answer'] + + # Display response + st.write(answer) + + record_id = tru.add_data( + chain_id=chain_id, + prompt=prompt_input, + response=answer, + record=record, + tags='dev', + total_tokens=total_tokens, + total_cost=total_cost + ) + + # Run feedback function and get value + feedbacks = tru.run_feedback_functions( + chain=chain, + record=record, + feedback_functions=[f_lang_match, f_qa_relevance, f_qs_relevance] + ) + + # Add value to database + tru.add_feedback(record_id, feedbacks) diff --git a/trulens_eval/trulens_eval/Leaderboard.py b/trulens_eval/trulens_eval/Leaderboard.py new file mode 100644 index 000000000..cc6bb4780 --- /dev/null +++ b/trulens_eval/trulens_eval/Leaderboard.py @@ -0,0 +1,98 @@ +import math + +from millify import millify +import numpy as np +import streamlit as st +from streamlit_extras.switch_page_button import switch_page + +st.runtime.legacy_caching.clear_cache() + +from trulens_eval import Tru +from trulens_eval import tru_db + +st.set_page_config(page_title="Leaderboard", layout="wide") + +from trulens_eval.ux.add_logo import add_logo + +add_logo() + +tru = Tru() +lms = tru.db + + +def app(): + # Set the title and subtitle of the app + st.title('Chain Leaderboard') + st.write( + 'Average feedback values displayed in the range from 0 (worst) to 1 (best).' + ) + df, feedback_col_names = lms.get_records_and_feedback([]) + + if df.empty: + st.write("No records yet...") + return + + df = df.sort_values(by="chain_id") + + if df.empty: + st.write("No records yet...") + + chains = list(df.chain_id.unique()) + st.markdown("""---""") + + for chain in chains: + st.header(chain) + col1, col2, col3, *feedback_cols, col99 = st.columns( + 4 + len(feedback_col_names) + ) + chain_df = df.loc[df.chain_id == chain] + #model_df_feedback = df.loc[df.chain_id == model] + + col1.metric("Records", len(chain_df)) + col2.metric( + "Cost", + f"${millify(round(sum(cost for cost in chain_df.total_cost if cost is not None), 5), precision = 2)}" + ) + col3.metric( + "Tokens", + millify( + sum( + tokens for tokens in chain_df.total_tokens + if tokens is not None + ), + precision=2 + ) + ) + + for i, col_name in enumerate(feedback_col_names): + mean = chain_df[col_name].mean() + + if i < len(feedback_cols): + if math.isnan(mean): + pass + + else: + feedback_cols[i].metric(col_name, round(mean, 2)) + + else: + if math.isnan(mean): + pass + + else: + feedback_cols[i].metric(col_name, round(mean, 2)) + + with col99: + if st.button('Select Chain', key=f"model-selector-{chain}"): + st.session_state.chain = chain + switch_page('Evaluations') + + st.markdown("""---""") + + +# Define the main function to run the app +def main(): + app() + + +if __name__ == '__main__': + main() diff --git a/trulens_eval/trulens_eval/__init__.py b/trulens_eval/trulens_eval/__init__.py new file mode 100644 index 000000000..db68f3d42 --- /dev/null +++ b/trulens_eval/trulens_eval/__init__.py @@ -0,0 +1,13 @@ +""" +Imports of most common parts of the library. 
Should include everything to get started. +""" + +__version__ = "0.0.1" + +from trulens_eval.tru_chain import TruChain +from trulens_eval.tru_feedback import Feedback +from trulens_eval.tru_feedback import OpenAI +from trulens_eval.tru_feedback import Huggingface +from trulens_eval.tru import Tru + +__all__ = ['TruChain', 'Feedback', 'OpenAI', 'Huggingface', 'Tru'] diff --git a/trulens_eval/trulens_eval/benchmark.py b/trulens_eval/trulens_eval/benchmark.py new file mode 100644 index 000000000..5a33f8493 --- /dev/null +++ b/trulens_eval/trulens_eval/benchmark.py @@ -0,0 +1,165 @@ +import time +import zipfile + +from datasets import load_dataset +from kaggle.api.kaggle_api_extended import KaggleApi +import pandas as pd + +from trulens_eval import tru_feedback + + +def load_data(dataset_choice): + if dataset_choice == 'imdb (binary sentiment)': + data = load_dataset('imdb') + train = pd.DataFrame(data['train']) + test = pd.DataFrame(data['test']) + data = pd.concat([train, test]) + elif dataset_choice == 'jigsaw (binary toxicity)': + kaggle_api = KaggleApi() + kaggle_api.authenticate() + + kaggle_api.dataset_download_files( + 'julian3833/jigsaw-unintended-bias-in-toxicity-classification' + ) + with zipfile.ZipFile( + 'jigsaw-unintended-bias-in-toxicity-classification.zip') as z: + with z.open('all_data.csv') as f: + data = pd.read_csv( + f, header=0, sep=',', quotechar='"' + )[['comment_text', + 'toxicity']].rename(columns={'comment_text': 'text'}) + + data['label'] = data['toxicity'] >= 0.5 + data['label'] = data['label'].astype(int) + elif dataset_choice == 'fake news (binary)': + kaggle_api = KaggleApi() + kaggle_api.authenticate() + + kaggle_api.dataset_download_files( + 'clmentbisaillon/fake-and-real-news-dataset' + ) + with zipfile.ZipFile('fake-and-real-news-dataset.zip') as z: + with z.open('True.csv') as f: + realdata = pd.read_csv( + f, header=0, sep=',', quotechar='"' + )[['title', 'text']] + realdata['label'] = 0 + realdata = pd.DataFrame(realdata) + with z.open('Fake.csv') as f: + fakedata = pd.read_csv( + f, header=0, sep=',', quotechar='"' + )[['title', 'text']] + fakedata['label'] = 1 + fakedata = pd.DataFrame(fakedata) + data = pd.concat([realdata, fakedata]) + data['text'] = 'title: ' + data['title'] + '; text: ' + data['text'] + + return data + + +def sample_data(data, num_samples): + return data.sample(num_samples) + + +def get_rate_limited_feedback_function( + feedback_function_name, provider, model_engine, rate_limit, + evaluation_choice +): + rate_limit = rate_limit + interval = 60 / rate_limit + last_call_time = time.time() + + def rate_limited_feedback(prompt='', response='', **kwargs): + nonlocal last_call_time + + elapsed_time = time.time() - last_call_time + + if elapsed_time < interval: + time.sleep(interval - elapsed_time) + + if feedback_function_name in tru_feedback.FEEDBACK_FUNCTIONS: + feedback_function = tru_feedback.FEEDBACK_FUNCTIONS[ + feedback_function_name]( + provider=provider, + model_engine=model_engine, + evaluation_choice=evaluation_choice, + **kwargs + ) + else: + raise ValueError( + f"Unrecognized feedback_function_name. 
Please use one of {list(tru_feedback.FEEDBACK_FUNCTIONS.keys())} " + ) + + result = feedback_function(prompt=prompt, response=response, **kwargs) + last_call_time = time.time() + + return result + + return rate_limited_feedback + + +def benchmark_on_data( + data, feedback_function_name, evaluation_choice, provider, model_engine +): + if feedback_function_name in tru_feedback.FEEDBACK_FUNCTIONS: + feedback_function = tru_feedback.FEEDBACK_FUNCTIONS[ + feedback_function_name]( + evaluation_choice=evaluation_choice, + provider=provider, + model_engine=model_engine + ) + else: + raise ValueError( + f"Unrecognized feedback_function_name. Please use one of {list(tru_feedback.FEEDBACK_FUNCTIONS.keys())} " + ) + if 'prompt' in data and 'response' in data: + data['feedback'] = data.apply( + lambda x: feedback_function(x['prompt'], x['response']), axis=1 + ) + else: + data['feedback'] = data['text'].apply( + lambda x: feedback_function('', x) + ) + + data['correct'] = data['label'] == data['feedback'] + + score = data['correct'].sum() / len(data) + + print( + feedback_function, 'scored: ', '{:.1%}'.format(score), + 'on the benchmark: ', "imdb" + ) + return data + + +def rate_limited_benchmark_on_data( + data, feedback_function_name, rate_limit, evaluation_choice, provider, + model_engine +): + rate_limited_feedback_function = get_rate_limited_feedback_function( + feedback_function_name, provider, model_engine, rate_limit, + evaluation_choice + ) + if 'prompt' in data and 'response' in data: + data['feedback'] = data.apply( + lambda x: + rate_limited_feedback_function(x['prompt'], x['response']), + axis=1 + ) + else: + data['feedback'] = data['text'].apply( + lambda x: rate_limited_feedback_function( + prompt='', + response=x, + ) + ) + + data['correct'] = data['label'] == data['feedback'] + + score = data['correct'].sum() / len(data) + + print( + feedback_function_name, 'scored: ', '{:.1%}'.format(score), + 'on the benchmark: ', "imdb" + ) + return data diff --git a/trulens_eval/trulens_eval/feedback_prompts.py b/trulens_eval/trulens_eval/feedback_prompts.py new file mode 100644 index 000000000..73f150eb1 --- /dev/null +++ b/trulens_eval/trulens_eval/feedback_prompts.py @@ -0,0 +1,96 @@ +from cohere.responses.classify import Example + +QS_RELEVANCE = """You are a RELEVANCE classifier; providing the relevance of the given STATEMENT to the given QUESTION. +Respond only as a number from 1 to 10 where 1 is the least relevant and 10 is the most relevant. +Never elaborate. + +QUESTION: {question} + +STATEMENT: {statement} + +RELEVANCE: """ + +PR_RELEVANCE = """ +You are a relevance classifier, providing the relevance of a given response to the given prompt. +Respond only as a number from 1 to 10 where 1 is the least relevant and 10 is the most relevant. +Never elaborate. + +Prompt: {prompt} + +Response: {response} + +Relevance: """ + +SENTIMENT_SYSTEM_PROMPT = f"Please classify the sentiment of the following text as 1 if positive or 0 if not positive. Respond with only a '1' or '0', nothing more." +RELEVANCE_SYSTEM_PROMPT = f"You are a relevance classifier, providing the relevance of a given response to a particular prompt. \n" +"Provide all responses only as a number from 1 to 10 where 1 is the least relevant and 10 is the most. Always respond with an integer between 1 and 10. \n" +"Never elaborate. 
The prompt is: " +RELEVANCE_CONTENT_PROMPT = f"For that prompt, how relevant is this response on the scale between 1 and 10: " + +COHERE_SENTIMENT_EXAMPLES = [ + Example("The order came 5 days early", "1"), + Example("I just got a promotion at work and I\'m so excited!", "1"), + Example( + "My best friend surprised me with tickets to my favorite band's concert.", + "1" + ), + Example( + "I\'m so grateful for my family's support during a difficult time.", "1" + ), + Example("It\'s kind of grungy, but the pumpkin pie slaps", "1"), + Example( + "I love spending time in nature and feeling connected to the earth.", + "1" + ), + Example("I had an amazing meal at the new restaurant in town", "1"), + Example("The pizza is good, but the staff is horrible to us", "0"), + Example("The package was damaged", "0"), + Example("I\'m feeling really sick and can\'t seem to shake it off", "0"), + Example("I got into a car accident and my car is completely totaled.", "0"), + Example( + "My boss gave me a bad performance review and I might get fired", "0" + ), + Example("I got into a car accident and my car is completely totaled.", "0"), + Example( + "I\'m so disapointed in myself for not following through on my goals", + "0" + ) +] + +COHERE_NOT_DISINFORMATION_EXAMPLES = [ + Example( + "Bud Light Official SALES REPORT Just Released ′ 50% DROP In Sales ′ Total COLLAPSE ′ Bankruptcy?", + "0" + ), + Example( + "The Centers for Disease Control and Prevention quietly confirmed that at least 118,000 children and young adults have “died suddenly” in the U.S. since the COVID-19 vaccines rolled out,", + "0" + ), + Example( + "Silicon Valley Bank collapses, in biggest failure since financial crisis", + "1" + ), + Example( + "Biden admin says Alabama health officials didn’t address sewage system failures disproportionately affecting Black residents", + "1" + ) +] + + +CORRECT_SYSTEM_PROMPT = \ +""" +You are a fact bot and you answer with verifiable facts +""" + +AGREEMENT_SYSTEM_PROMPT = \ +""" +You will continually start seeing responses to the prompt: + +%s + +The right answer is: + +%s + +Answer only with an integer from 1 to 10 based on how close the responses are to the right answer. +""" diff --git a/trulens_eval/trulens_eval/keys.py b/trulens_eval/trulens_eval/keys.py new file mode 100644 index 000000000..c39abe17b --- /dev/null +++ b/trulens_eval/trulens_eval/keys.py @@ -0,0 +1,43 @@ +""" +Read secrets from .env for exporting to python scripts. Usage: +```python + from keys import * +``` +Will get you access to all of the vars defined in .env in wherever you put that import statement. 
+""" + +import os + +import cohere +import dotenv + +config = dotenv.dotenv_values(".env") + +for k, v in config.items(): + print(f"KEY SET: {k}") + globals()[k] = v + + # set them into environment as well + os.environ[k] = v + +if 'OPENAI_API_KEY' in os.environ: + import openai + openai.api_key = os.environ["OPENAI_API_KEY"] + +global cohere_agent +cohere_agent = None + + +def get_cohere_agent(): + global cohere_agent + if cohere_agent is None: + cohere.api_key = os.environ['COHERE_API_KEY'] + cohere_agent = cohere.Client(cohere.api_key) + return cohere_agent + + +def get_huggingface_headers(): + HUGGINGFACE_HEADERS = { + "Authorization": f"Bearer {os.environ['HUGGINGFACE_API_KEY']}" + } + return HUGGINGFACE_HEADERS diff --git a/trulens_eval/trulens_eval/pages/Evaluations.py b/trulens_eval/trulens_eval/pages/Evaluations.py new file mode 100644 index 000000000..0a3c14df3 --- /dev/null +++ b/trulens_eval/trulens_eval/pages/Evaluations.py @@ -0,0 +1,242 @@ +import json +from typing import Dict, List + +import matplotlib.pyplot as plt +import pandas as pd +from st_aggrid import AgGrid +from st_aggrid.grid_options_builder import GridOptionsBuilder +from st_aggrid.shared import GridUpdateMode +from st_aggrid.shared import JsCode +import streamlit as st +from ux.add_logo import add_logo + +from trulens_eval import Tru +from trulens_eval import tru_db +from trulens_eval.tru_db import is_empty +from trulens_eval.tru_db import is_noserio +from trulens_eval.tru_db import TruDB + +st.set_page_config(page_title="Evaluations", layout="wide") + +st.title("Evaluations") + +st.runtime.legacy_caching.clear_cache() + +add_logo() + +tru = Tru() +lms = tru.db + +df_results, feedback_cols = lms.get_records_and_feedback([]) + +if df_results.empty: + st.write("No records yet...") + +else: + chains = list(df_results.chain_id.unique()) + + if 'Chains' in st.session_state: + chain = st.session_state.chain + else: + chain = chains + + options = st.multiselect('Filter Chains', chains, default=chain) + + if (len(options) == 0): + st.header("All Chains") + chain_df = df_results + + elif (len(options) == 1): + st.header(options[0]) + + chain_df = df_results[df_results.chain_id.isin(options)] + + else: + st.header("Multiple Chains Selected") + + chain_df = df_results[df_results.chain_id.isin(options)] + + tab1, tab2 = st.tabs(["Records", "Feedback Functions"]) + + with tab1: + gridOptions = {'alwaysShowHorizontalScroll': True} + evaluations_df = chain_df + gb = GridOptionsBuilder.from_dataframe(evaluations_df) + + cellstyle_jscode = JsCode( + """ + function(params) { + if (parseFloat(params.value) < 0.5) { + return { + 'color': 'black', + 'backgroundColor': '#FCE6E6' + } + } else if (parseFloat(params.value) >= 0.5) { + return { + 'color': 'black', + 'backgroundColor': '#4CAF50' + } + } else { + return { + 'color': 'black', + 'backgroundColor': 'white' + } + } + }; + """ + ) + + gb.configure_column('record_id', header_name='Record ID') + gb.configure_column('chain_id', header_name='Chain ID') + gb.configure_column('input', header_name='User Input') + gb.configure_column( + 'output', + header_name='Response', + ) + gb.configure_column('total_tokens', header_name='Total Tokens') + gb.configure_column('total_cost', header_name='Total Cost') + gb.configure_column('tags', header_name='Tags') + gb.configure_column('ts', header_name='Time Stamp') + + for feedback_col in evaluations_df.columns.drop(['chain_id', 'ts', + 'total_tokens', + 'total_cost']): + gb.configure_column(feedback_col, cellStyle=cellstyle_jscode) + 
gb.configure_pagination() + gb.configure_side_bar() + gb.configure_selection(selection_mode="single", use_checkbox=False) + + #gb.configure_default_column(groupable=True, value=True, enableRowGroup=True, aggFunc="sum", editable=True) + gridOptions = gb.build() + data = AgGrid( + evaluations_df, + gridOptions=gridOptions, + update_mode=GridUpdateMode.SELECTION_CHANGED, + allow_unsafe_jscode=True + ) + + selected_rows = data['selected_rows'] + selected_rows = pd.DataFrame(selected_rows) + + if len(selected_rows) == 0: + st.write("Hint: select a row to display chain metadata") + + else: + st.header(f"Selected Chain ID: {selected_rows['chain_id'][0]}") + st.text(f"Selected Record ID: {selected_rows['record_id'][0]}") + prompt = selected_rows['input'][0] + response = selected_rows['output'][0] + with st.expander("Input Prompt", expanded=True): + st.write(prompt) + + with st.expander("Response", expanded=True): + st.write(response) + + record_str = selected_rows['record_json'][0] + record_json = json.loads(record_str) + + details = selected_rows['chain_json'][0] + details_json = json.loads(details) + #json.loads(details)) # ??? + + chain_json = details_json['chain'] + + llm_queries = list( + TruDB.matching_objects( + details_json, + match=lambda q, o: len(q._path) > 0 and "llm" == q._path[-1] + ) + ) + + prompt_queries = list( + TruDB.matching_objects( + details_json, + match=lambda q, o: len(q._path) > 0 and "prompt" == q._path[ + -1] and "_call" not in q._path + ) + ) + + max_len = max(len(llm_queries), len(prompt_queries)) + + for i in range(max_len): + if i < len(llm_queries): + query, llm_details_json = llm_queries[i] + path_str = TruDB._query_str(query) + st.header(f"Chain Step {i}: {path_str.replace('.llm', '')}") + st.subheader(f"LLM Details:") + + llm_kv = { + k: v + for k, v in llm_details_json.items() + if (v is not None) and not is_empty(v) and + not is_noserio(v) + } + # CSS to inject contained in a string + hide_table_row_index = """ + + """ + df = pd.DataFrame.from_dict(llm_kv, orient='index') + # Inject CSS with Markdown + st.markdown(hide_table_row_index, unsafe_allow_html=True) + st.table(df.transpose()) + + if i < len(prompt_queries): + query, prompt_details_json = prompt_queries[i] + path_str = TruDB._query_str(query) + st.subheader(f"Prompt Details:") + + prompt_types = { + k: v + for k, v in prompt_details_json.items() + if (v is not None) and not is_empty(v) and + not is_noserio(v) + } + + for key, value in prompt_types.items(): + with st.expander(key.capitalize(), expanded=True): + if isinstance(value, (Dict, List)): + st.write(value) + else: + if isinstance(value, str) and len(value) > 32: + st.text(value) + else: + st.write(value) + st.header("More options:") + if st.button("Display full chain json"): + + st.write(details_json) + + if st.button("Display full record json"): + + st.write(record_json) + + with tab2: + feedback = feedback_cols + cols = 4 + rows = len(feedback) // cols + 1 + + for row_num in range(rows): + with st.container(): + columns = st.columns(cols) + for col_num in range(cols): + with columns[col_num]: + ind = row_num * cols + col_num + if ind < len(feedback): + # Generate histogram + fig, ax = plt.subplots() + bins = [ + 0, 0.2, 0.4, 0.6, 0.8, 1.0 + ] # Quintile buckets + ax.hist( + chain_df[feedback[ind]], + bins=bins, + edgecolor='black', + color='#2D736D' + ) + ax.set_xlabel('Feedback Value') + ax.set_ylabel('Frequency') + ax.set_title(feedback[ind], loc='center') + st.pyplot(fig) diff --git a/trulens_eval/trulens_eval/pages/Progress.py 
b/trulens_eval/trulens_eval/pages/Progress.py new file mode 100644 index 000000000..45fdfef84 --- /dev/null +++ b/trulens_eval/trulens_eval/pages/Progress.py @@ -0,0 +1,56 @@ +from datetime import datetime +import json +from typing import Dict, List + +from trulens_eval.keys import * + +import pandas as pd +from st_aggrid import AgGrid +from st_aggrid.grid_options_builder import GridOptionsBuilder +from st_aggrid.shared import GridUpdateMode +from st_aggrid.shared import JsCode +import streamlit as st +from trulens_eval.tru_feedback import Feedback +from trulens_eval.util import TP +from ux.add_logo import add_logo +from trulens_eval.tru_db import is_empty + + +from trulens_eval import tru_db, Tru +from trulens_eval.provider_apis import Endpoint + +from trulens_eval.tru_db import is_noserio +from trulens_eval.tru_db import TruDB + +st.set_page_config(page_title="Feedback Progress", layout="wide") + +st.title("Feedback Progress") + +st.runtime.legacy_caching.clear_cache() + +add_logo() + +tru = Tru() +lms = tru.db + +e_openai = Endpoint("openai") +e_hugs = Endpoint("huggingface") +e_cohere = Endpoint("cohere") + +endpoints = [e_openai, e_hugs, e_cohere] + +tab1, tab2, tab3 = st.tabs(["Progress", "Endpoints", "Feedback Functions"]) + +with tab1: + feedbacks = lms.get_feedback(status=[-1,0,1]) + st.write(feedbacks) + +with tab2: + for e in endpoints: + st.header(e.name.upper()) + st.metric("RPM", e.rpm) + st.write(e.tqdm) + +with tab3: + feedbacks = lms.get_feedback_defs() + st.write(feedbacks) diff --git a/trulens_eval/trulens_eval/pages/__init__.py b/trulens_eval/trulens_eval/pages/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trulens_eval/trulens_eval/provider_apis.py b/trulens_eval/trulens_eval/provider_apis.py new file mode 100644 index 000000000..507309a04 --- /dev/null +++ b/trulens_eval/trulens_eval/provider_apis.py @@ -0,0 +1,122 @@ +import logging +from multiprocessing import Queue +# from queue import Queue +from threading import Thread +from time import sleep +from typing import Any, Optional, Sequence + +import requests +from tqdm.auto import tqdm +from trulens_eval.tru_db import JSON + +from trulens_eval.util import SingletonPerName +from trulens_eval.util import TP + + +class Endpoint(SingletonPerName): + + def __init__( + self, name: str, rpm: float = 60, retries: int = 3, post_headers=None + ): + """ + Pacing and utilities for API endpoints. + + Args: + + - name: str -- api name / identifier. + + - rpm: float -- requests per minute. + + - retries: int -- number of retries before failure. + + - post_headers: Dict -- http post headers if this endpoint uses http + post. + """ + + if hasattr(self, "rpm"): + # already initialized via the SingletonPerName mechanism + return + + logging.debug(f"*** Creating {name} endpoint ***") + + self.rpm = rpm + self.retries = retries + self.pace = Queue( + maxsize=rpm // 6 + ) # 10 second's worth of accumulated api + self.tqdm = tqdm(desc=f"{name} api", unit="requests") + self.name = name + self.post_headers = post_headers + + self._start_pacer() + + def pace_me(self): + """ + Block until we can make a request to this endpoint. 
+ """ + + self.pace.get() + self.tqdm.update(1) + return + + def post(self, url: str, payload: JSON, timeout: Optional[int] = None) -> Any: + extra = dict() + if self.post_headers is not None: + extra['headers'] = self.post_headers + + self.pace_me() + ret = requests.post(url, json=payload, timeout=timeout, **extra) + + j = ret.json() + + # Huggingface public api sometimes tells us that a model is loading and how long to wait: + if "estimated_time" in j: + wait_time = j['estimated_time'] + logging.error(f"Waiting for {j} ({wait_time}) second(s).") + sleep(wait_time+2) + return self.post(url, payload) + + assert isinstance( + j, Sequence + ) and len(j) > 0, f"Post did not return a sequence: {j}" + + return j[0] + + def run_me(self, thunk): + """ + Run the given thunk, returning itse output, on pace with the api. + Retries request multiple times if self.retries > 0. + """ + + retries = self.retries + 1 + retry_delay = 2.0 + + while retries > 0: + try: + self.pace_me() + ret = thunk() + return ret + except Exception as e: + retries -= 1 + logging.error( + f"{self.name} request failed {type(e)}={e}. Retries={retries}." + ) + if retries > 0: + sleep(retry_delay) + retry_delay *= 2 + + raise RuntimeError( + f"API {self.name} request failed {self.retries+1} time(s)." + ) + + def _start_pacer(self): + + def keep_pace(): + while True: + sleep(60.0 / self.rpm) + self.pace.put(True) + + thread = Thread(target=keep_pace) + thread.start() + + self.pacer_thread = thread diff --git a/trulens_eval/trulens_eval/slackbot.py b/trulens_eval/trulens_eval/slackbot.py new file mode 100644 index 000000000..cbcfc7521 --- /dev/null +++ b/trulens_eval/trulens_eval/slackbot.py @@ -0,0 +1,411 @@ +import logging +import os +from pprint import PrettyPrinter +from typing import Callable, Dict, List, Set, Tuple + +import numpy as np + +# This needs to be before some others to make sure api keys are ready before +# relevant classes are loaded. +from trulens_eval.keys import * + + +# This is here so that import organizer does not move the keys import below this +# line. +_ = None + +from langchain.chains import ConversationalRetrievalChain +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.llms import OpenAI +from langchain.memory import ConversationSummaryBufferMemory +from langchain.schema import Document +from langchain.vectorstores import Pinecone +from langchain.vectorstores.base import VectorStoreRetriever +import pinecone +from pydantic import Field +from slack_bolt import App +from slack_sdk import WebClient + +from trulens_eval import Tru +from trulens_eval import tru_feedback +from trulens_eval.tru_chain import TruChain +from trulens_eval.tru_db import LocalSQLite +from trulens_eval.tru_db import Record +from trulens_eval.tru_feedback import Feedback +from trulens_eval.util import TP + +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +pp = PrettyPrinter() + +PORT = 3000 +verb = False + +# create a conversational chain with relevant models and vector store + +# Pinecone configuration. +pinecone.init( + api_key=PINECONE_API_KEY, # find at app.pinecone.io + environment=PINECONE_ENV # next to api key in console +) + +# Cache of conversations. Keys are SlackAPI conversation ids (channel ids or +# otherwise) and values are TruChain to handle that conversation. +convos: Dict[str, TruChain] = dict() + +# Keep track of timestamps of messages already handled. Sometimes the same +# message gets received more than once if there is a network hickup. 
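The `Endpoint` class above meters outgoing requests against a requests-per-minute budget using a pacer thread and a bounded queue, and retries failed calls with exponential backoff. A rough usage sketch; the token and model URL are placeholders, not values from this diff:

```python
from trulens_eval.provider_apis import Endpoint

# One instance per provider name; SingletonPerName returns the existing
# instance if an Endpoint with the same name was already constructed.
hf_endpoint = Endpoint(
    name="huggingface",
    rpm=30,  # budget of 30 requests per minute
    post_headers={"Authorization": "Bearer <HUGGINGFACE_API_KEY>"},
)

# run_me paces the call and retries it on failure (self.retries times),
# doubling the delay between attempts.
value = hf_endpoint.run_me(lambda: 1 + 1)

# post() paces, POSTs JSON, waits out Huggingface "model loading" replies,
# and returns the first element of the JSON response:
# result = hf_endpoint.post(
#     "https://api-inference.huggingface.co/models/<model>",
#     {"inputs": "TruLens paces API calls."},
# )

# Note: the pacer thread started in the constructor keeps running for the
# life of the process.
```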
+handled_ts: Set[Tuple[str, str]] = set() + +# DB to save models and records. +tru = Tru()#LocalSQLite("trubot.sqlite")) + +ident = lambda h: h + +chain_ids = { + 0: "0/default", + 1: "1/lang_prompt", + 2: "2/relevance_prompt", + 3: "3/filtered_context", + 4: "4/filtered_context_and_lang_prompt" +} + +# Construct feedback functions. + +hugs = tru_feedback.Huggingface() +openai = tru_feedback.OpenAI() + +# Language match between question/answer. +f_lang_match = Feedback(hugs.language_match).on( + text1="prompt", text2="response" +) + +# Question/answer relevance between overall question and answer. +f_qa_relevance = Feedback(openai.relevance).on( + prompt="input", response="output" +) + +# Question/statement relevance between question and each context chunk. +f_qs_relevance = Feedback(openai.qs_relevance).on( + question="input", + statement=Record.chain.combine_docs_chain._call.args.inputs.input_documents +).on_multiple( + multiarg="statement", each_query=Record.page_content, agg=np.min +) + +class WithFilterDocuments(VectorStoreRetriever): + filter_func: Callable = Field(exclude=True) + + def __init__(self, filter_func: Callable, *args, **kwargs): + super().__init__(filter_func=filter_func, *args, **kwargs) + # self.filter_func = filter_func + + def get_relevant_documents(self, query: str) -> List[Document]: + docs = super().get_relevant_documents(query) + + promises = [] + for doc in docs: + promises.append( + (doc, TP().promise(self.filter_func, query=query, doc=doc)) + ) + + results = [] + for doc, promise in promises: + results.append((doc, promise.get())) + + docs_filtered = map(lambda sr: sr[0], filter(lambda sr: sr[1], results)) + + return list(docs_filtered) + + @staticmethod + def of_vectorstoreretriever(retriever, filter_func: Callable): + return WithFilterDocuments(filter_func=filter_func, **retriever.dict()) + +def filter_by_relevance(query, doc): + return openai.qs_relevance(question=query, statement=doc.page_content) > 0.5 + +def get_or_make_chain(cid: str, selector: int = 0) -> TruChain: + """ + Create a new chain for the given conversation id `cid` or return an existing + one. Return the new or existing chain. + """ + + # NOTE(piotrm): Unsure about the thread safety of the various components so + # making new ones for each conversation. + + if cid in convos: + return convos[cid] + + if selector not in chain_ids: + selector = 0 + + chain_id = chain_ids[selector] + + pp.pprint(f"Starting a new conversation with {chain_id}.") + + # Embedding needed for Pinecone vector db. + embedding = OpenAIEmbeddings(model='text-embedding-ada-002') # 1536 dims + docsearch = Pinecone.from_existing_index( + index_name="llmdemo", embedding=embedding + ) + + retriever = docsearch.as_retriever() + + if "filtered" in chain_id: + retriever = WithFilterDocuments.of_vectorstoreretriever( + retriever=retriever, filter_func=filter_by_relevance + ) + + # LLM for completing prompts, and other tasks. + llm = OpenAI(temperature=0, max_tokens=128) + + # Conversation memory. + memory = ConversationSummaryBufferMemory( + max_token_limit=650, + llm=llm, + memory_key="chat_history", + output_key='answer' + ) + + # Conversational chain puts it all together. + chain = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=retriever, + verbose=verb, + return_source_documents=True, + memory=memory, + get_chat_history=ident, + max_tokens_limit=4096 + ) + + # Need to copy these otherwise various chains will feature templates that + # point to the same objects. 
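`WithFilterDocuments` above wraps a `VectorStoreRetriever` and drops retrieved chunks whose `qs_relevance` score falls below a threshold. The core of that pattern, stripped of Pinecone and LangChain specifics, looks roughly like this; the scoring function here is a trivial stand-in for `openai.qs_relevance`:

```python
from typing import Callable, List


class Document:
    # Minimal stand-in for langchain.schema.Document, to keep the sketch self-contained.
    def __init__(self, page_content: str):
        self.page_content = page_content


def filter_documents(query: str, docs: List[Document],
                     score: Callable[[str, str], float],
                     threshold: float = 0.5) -> List[Document]:
    # Keep only chunks the feedback function considers relevant enough,
    # mirroring filter_by_relevance above (which also uses a 0.5 threshold).
    return [d for d in docs if score(query, d.page_content) > threshold]


docs = [Document("TruLens wraps chains."), Document("Unrelated text.")]
kept = filter_documents(
    "What does TruLens do?", docs,
    score=lambda q, s: 1.0 if "TruLens" in s else 0.0
)
assert [d.page_content for d in kept] == ["TruLens wraps chains."]
```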
+ chain.combine_docs_chain.llm_chain.prompt = \ + chain.combine_docs_chain.llm_chain.prompt.copy() + chain.combine_docs_chain.document_prompt = \ + chain.combine_docs_chain.document_prompt.copy() + + if "lang" in chain_id: + # Language mismatch fix: + chain.combine_docs_chain.llm_chain.prompt.template = \ + "Use the following pieces of context to answer the question at the end " \ + "in the same language as the question. If you don't know the answer, " \ + "just say that you don't know, don't try to make up an answer.\n" \ + "\n" \ + "{context}\n" \ + "\n" \ + "Question: {question}\n" \ + "Helpful Answer: " + + elif "relevance" in chain_id: + # Contexts fix + + # whitespace important in "Contexts! " + chain.combine_docs_chain.llm_chain.prompt.template = \ + "Use only the relevant contexts to answer the question at the end " \ + ". Some pieces of context may not be relevant. If you don't know the answer, " \ + "just say that you don't know, don't try to make up an answer.\n" \ + "\n" \ + "Contexts: \n" \ + "{context}\n" \ + "\n" \ + "Question: {question}\n" \ + "Helpful Answer: " + + # "\t" important here: + chain.combine_docs_chain.document_prompt.template = "\tContext: {page_content}" + + # Trulens instrumentation. + tc = tru.Chain( + chain=chain, + chain_id=chain_id, + feedbacks=[f_lang_match, f_qa_relevance, f_qs_relevance], + feedback_mode="deferred" + ) + + convos[cid] = tc + + return tc + + +def get_answer(chain: TruChain, question: str) -> Tuple[str, str]: + """ + Use the given `chain` to respond to `question`. Return the answer text and + sources elaboration text. + """ + + # Pace our API usage. This is not perfect since the chain makes multiple api calls + # internally. + openai.endpoint.pace_me() + + outs = chain(dict(question=question)) + + result = outs['answer'] + sources = outs['source_documents'] + + result_sources = "Sources:\n" + + temp = set() + + for doc in sources: + src = doc.metadata['source'] + if src not in temp: + result_sources += " - " + doc.metadata['source'] + if 'page' in doc.metadata: + result_sources += f" (page {int(doc.metadata['page'])})\n" + else: + result_sources += "\n" + + temp.add(src) + + return result, result_sources + + +def answer_message(client, body: dict, logger): + """ + SlackAPI handler of message received. + """ + + pp.pprint(body) + + ts = body['event']['ts'] + user = body['event']['user'] + + if (ts, user) in handled_ts: + print(f"WARNING: I already handled message with ts={ts}, user={user} .") + return + else: + handled_ts.add((ts, user)) + + message = body['event']['text'] + channel = body['event']['channel'] + + if "thread_ts" in body['event']: + client.chat_postMessage( + channel=channel, thread_ts=ts, text=f"Looking..." + ) + + convo_id = body['event']['thread_ts'] + + chain = get_or_make_chain(convo_id) + + else: + convo_id = ts + + if len(message) >= 2 and message[0].lower() == "s" and message[1] in [ + "0", "1", "2", "3", "4", "5" + ]: + selector = int(message[1]) + chain = get_or_make_chain(convo_id, selector=selector) + + client.chat_postMessage( + channel=channel, + thread_ts=ts, + text=f"I will use chain {chain.chain_id} for this conversation." + ) + + if len(message) == 2: + return + else: + message = message[2:] + + else: + chain = get_or_make_chain(convo_id) + + client.chat_postMessage( + channel=channel, + thread_ts=ts, + text=f"Hi. Let me check that for you..." 
+ ) + + res, res_sources = get_answer(chain, message) + + client.chat_postMessage( + channel=channel, + thread_ts=ts, + text=str(res) + "\n" + str(res_sources), + blocks=[ + dict(type="section", text=dict(type='mrkdwn', text=str(res))), + dict( + type="context", + elements=[dict(type='mrkdwn', text=str(res_sources))] + ) + ] + ) + + pp.pprint(res) + pp.pprint(res_sources) + + logger.info(body) + + +# WebClient instantiates a client that can call API methods When using Bolt, you +# can use either `app.client` or the `client` passed to listeners. +client = WebClient(token=SLACK_TOKEN) +logger = logging.getLogger(__name__) + +# Initializes your app with your bot token and signing secret +app = App(token=SLACK_TOKEN, signing_secret=SLACK_SIGNING_SECRET) + + +@app.event("app_home_opened") +def update_home_tab(client, event, logger): + try: + # views.publish is the method that your app uses to push a view to the Home tab + client.views_publish( + # the user that opened your app's app home + user_id=event["user"], + # the view object that appears in the app home + view={ + "type": + "home", + "callback_id": + "home_view", + + # body of the view + "blocks": + [ + { + "type": "section", + "text": + { + "type": + "mrkdwn", + "text": + "*I'm here to answer questions and test feedback functions.* :tada: Note that all of my conversations and thinking are recorded." + } + } + ] + } + ) + + except Exception as e: + logger.error(f"Error publishing home tab: {e}") + + +@app.event("message") +def handle_message_events(body, logger): + """ + Handle direct messages to the bot. + """ + + answer_message(client, body, logger) + + +@app.event("app_mention") +def handle_app_mention_events(body, logger): + """ + Handle messages that mention the bot. + """ + + answer_message(client, body, logger) + + +def start_bot(): + tru.start_evaluator() + app.start(port=int(PORT)) + + +# Start your app +if __name__ == "__main__": + start_bot() diff --git a/trulens_eval/trulens_eval/test_tru_chain.py b/trulens_eval/trulens_eval/test_tru_chain.py new file mode 100644 index 000000000..1ad3d4851 --- /dev/null +++ b/trulens_eval/trulens_eval/test_tru_chain.py @@ -0,0 +1,172 @@ +# from llama.hf import LLaMATokenizer + +from langchain import LLMChain +from langchain import PromptTemplate +from langchain.chains import ConversationalRetrievalChain +from langchain.chains import SimpleSequentialChain +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.llms import HuggingFacePipeline +from langchain.memory import ConversationBufferWindowMemory +from langchain.vectorstores import Pinecone +import pinecone +import pytest +import torch +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer +from transformers import pipeline + +from trulens_eval.keys import PINECONE_API_KEY +from trulens_eval.keys import PINECONE_ENV +from trulens_eval.tru_chain import TruChain + + +class TestTruChain(): + + def setup_method(self): + print("setup") + + self.llm_model_id = "gpt2" + # This model is pretty bad but using it for tests because it is free and + # relatively small. 
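The slackbot earlier in this diff wraps its conversational chain with `tru.Chain` and defers feedback evaluation to the background evaluator. A condensed sketch of that wiring with a minimal stand-in chain; it assumes `OPENAI_API_KEY` and `HUGGINGFACE_API_KEY` are available (for example via the `keys.py` / `.env` mechanism shown above), and the `chain_id` is just an example name:

```python
from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI

from trulens_eval import Tru, tru_feedback
from trulens_eval.tru_feedback import Feedback

tru = Tru()
hugs = tru_feedback.Huggingface()

# Language match between question and answer, as in the slackbot above.
f_lang_match = Feedback(hugs.language_match).on(text1="prompt", text2="response")

# A small stand-in chain (requires OPENAI_API_KEY in the environment).
prompt = PromptTemplate(template="Q: {question} A:", input_variables=["question"])
chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0, max_tokens=64))

tc = tru.Chain(
    chain=chain,
    chain_id="demo/deferred_feedback",
    feedbacks=[f_lang_match],
    feedback_mode="deferred",   # evaluated later by the background evaluator
)

tru.start_evaluator()           # runs deferred feedback functions
answer, record = tc.call_with_record(dict(question="Bonjour, comment ça va?"))
```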
+ + # model_id = "decapoda-research/llama-7b-hf" + # model_id = "decapoda-research/llama-13b-hf" + + self.model = AutoModelForCausalLM.from_pretrained( + self.llm_model_id, + device_map='auto', + torch_dtype=torch.float16, + local_files_only=True + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + self.llm_model_id, local_files_only=True + ) + + self.pipe = pipeline( + "text-generation", + model=self.model, + tokenizer=self.tokenizer, + max_new_tokens=16, + device_map="auto", + early_stopping=True + ) + + self.llm = HuggingFacePipeline(pipeline=self.pipe) + + def test_qa_prompt(self): + # Test of a small q/a chain using a prompt and a single call to an llm. + + # llm = OpenAI() + + template = """Q: {question} A:""" + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=self.llm) + + tru_chain = TruChain(chain=llm_chain) + + assert tru_chain.model is not None + + tru_chain.run(dict(question="How are you?")) + tru_chain.run(dict(question="How are you today?")) + + assert len(tru_chain.db.select()) == 2 + + def test_qa_prompt_with_memory(self): + # Test of a small q/a chain using a prompt and a single call to an llm. + # Also has memory. + + # llm = OpenAI() + + template = """Q: {question} A:""" + prompt = PromptTemplate(template=template, input_variables=["question"]) + + memory = ConversationBufferWindowMemory(k=2) + + llm_chain = LLMChain(prompt=prompt, llm=self.llm, memory=memory) + + tru_chain = TruChain(chain=llm_chain) + + assert tru_chain.model is not None + + tru_chain.run(dict(question="How are you?")) + tru_chain.run(dict(question="How are you today?")) + + assert len(tru_chain.db.select()) == 2 + + @pytest.mark.nonfree + def test_qa_db(self): + # Test a q/a chain that uses a vector store to look up context to include in + # llm prompt. + + # WARNING: this test incurs calls to pinecone and openai APIs and may cost money. + + index_name = "llmdemo" + + embedding = OpenAIEmbeddings( + model='text-embedding-ada-002' + ) # 1536 dims + + pinecone.init( + api_key=PINECONE_API_KEY, # find at app.pinecone.io + environment=PINECONE_ENV # next to api key in console + ) + docsearch = Pinecone.from_existing_index( + index_name=index_name, embedding=embedding + ) + + # llm = OpenAI(temperature=0,max_tokens=128) + + retriever = docsearch.as_retriever() + chain = ConversationalRetrievalChain.from_llm( + llm=self.llm, retriever=retriever, return_source_documents=True + ) + + tru_chain = TruChain(chain) + assert tru_chain.model is not None + tru_chain(dict(question="How do I add a model?", chat_history=[])) + + assert len(tru_chain.db.select()) == 1 + + def test_sequential(self): + # Test of a sequential chain that contains the same llm twice with + # different prompts. + + template = """Q: {question} A:""" + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=self.llm) + + template_2 = """Reverse this sentence: {sentence}.""" + prompt_2 = PromptTemplate( + template=template_2, input_variables=["sentence"] + ) + llm_chain_2 = LLMChain(prompt=prompt_2, llm=self.llm) + + seq_chain = SimpleSequentialChain( + chains=[llm_chain, llm_chain_2], + input_key="question", + output_key="answer" + ) + seq_chain.run( + question="What is the average air speed velocity of a laden swallow?" + ) + + tru_chain = TruChain(seq_chain) + assert tru_chain.model is not None + + # This run should not be recorded. 
+ seq_chain.run( + question="What is the average air speed velocity of a laden swallow?" + ) + + # These two should. + tru_chain.run( + question= + "What is the average air speed velocity of a laden european swallow?" + ) + tru_chain.run( + question= + "What is the average air speed velocity of a laden african swallow?" + ) + + assert len(tru_chain.db.select()) == 2 diff --git a/trulens_eval/trulens_eval/tru.py b/trulens_eval/trulens_eval/tru.py new file mode 100644 index 000000000..5b0e3fea9 --- /dev/null +++ b/trulens_eval/trulens_eval/tru.py @@ -0,0 +1,344 @@ +from datetime import datetime +import logging +from multiprocessing import Process +import os +from pathlib import Path +import subprocess +from threading import Thread +import threading +from time import sleep +from typing import Iterable, List, Optional, Sequence, Union + +import pkg_resources + +from trulens_eval.tru_db import JSON +from trulens_eval.tru_db import LocalSQLite +from trulens_eval.tru_db import TruDB +from trulens_eval.tru_feedback import Feedback +from trulens_eval.util import TP, SingletonPerName + + +class Tru(SingletonPerName): + """ + Tru is the main class that provides an entry points to trulens-eval. Tru lets you: + + * Log chain prompts and outputs + * Log chain Metadata + * Run and log feedback functions + * Run streamlit dashboard to view experiment results + + All data is logged to the current working directory to default.sqlite. + """ + DEFAULT_DATABASE_FILE = "default.sqlite" + + # Process or Thread of the deferred feedback function evaluator. + evaluator_proc = None + + # Process of the dashboard app. + dashboard_proc = None + + def Chain(self, *args, **kwargs): + """ + Create a TruChain with database managed by self. + """ + + from trulens_eval.tru_chain import TruChain + + return TruChain(tru=self, *args, **kwargs) + + def __init__(self): + """ + TruLens instrumentation, logging, and feedback functions for chains. + Creates a local database 'default.sqlite' in current working directory. + """ + + if hasattr(self, "db"): + # Already initialized by SingletonByName mechanism. + return + + self.db = LocalSQLite(Tru.DEFAULT_DATABASE_FILE) + + def reset_database(self): + """ + Reset the database. Clears all tables. + """ + + self.db.reset_database() + + def add_record( + self, + prompt: str, + response: str, + record_json: JSON, + tags: Optional[str] = "", + ts: Optional[int] = None, + total_tokens: Optional[int] = None, + total_cost: Optional[float] = None, + ): + """ + Add a record to the database. + + Parameters: + + prompt (str): Chain input or "prompt". + + response (str): Chain output or "response". + + record_json (JSON): Record as produced by `TruChain.call_with_record`. + + tags (str, optional): Additional metadata to include with the record. + + ts (int, optional): Timestamp of record creation. + + total_tokens (int, optional): The number of tokens generated in + producing the response. + + total_cost (float, optional): The cost of producing the response. + + Returns: + str: Unique record identifier. 
+ + """ + ts = ts or datetime.now() + total_tokens = total_tokens or record_json['_cost']['total_tokens'] + total_cost = total_cost or record_json['_cost']['total_cost'] + + chain_id = record_json['chain_id'] + + record_id = self.db.insert_record( + chain_id=chain_id, + input=prompt, + output=response, + record_json=record_json, + ts=ts, + tags=tags, + total_tokens=total_tokens, + total_cost=total_cost + ) + + return record_id + + def run_feedback_functions( + self, + record_json: JSON, + feedback_functions: Sequence['Feedback'], + chain_json: Optional[JSON] = None, + ) -> Sequence[JSON]: + """ + Run a collection of feedback functions and report their result. + + Parameters: + + record_json (JSON): The record on which to evaluate the feedback + functions. + + chain_json (JSON, optional): The chain that produced the given record. + If not provided, it is looked up from the given database `db`. + + feedback_functions (Sequence[Feedback]): A collection of feedback + functions to evaluate. + + Returns nothing. + """ + + chain_id = record_json['chain_id'] + + if chain_json is None: + chain_json = self.db.get_chain(chain_id=chain_id) + if chain_json is None: + raise RuntimeError( + "Chain {chain_id} not present in db. " + "Either add it with `tru.add_chain` or provide `chain_json` to `tru.run_feedback_functions`." + ) + + else: + assert chain_id == chain_json[ + 'chain_id'], "Record was produced by a different chain." + + if self.db.get_chain(chain_id=chain_json['chain_id']) is None: + logging.warn( + "Chain {chain_id} was not present in database. Adding it." + ) + self.add_chain(chain_json=chain_json) + + evals = [] + + for func in feedback_functions: + evals.append( + TP().promise( + lambda f: f.run_on_record( + chain_json=chain_json, record_json=record_json + ), func + ) + ) + + evals = map(lambda p: p.get(), evals) + + return list(evals) + + def add_chain( + self, chain_json: JSON, chain_id: Optional[str] = None + ) -> None: + """ + Add a chain to the database. + """ + + self.db.insert_chain(chain_id=chain_id, chain_json=chain_json) + + def add_feedback(self, result_json: JSON) -> None: + """ + Add a single feedback result to the database. + """ + + if 'record_id' not in result_json or result_json['record_id'] is None: + raise RuntimeError( + "Result does not include record_id. " + "To log feedback, log the record first using `tru.add_record`." + ) + + self.db.insert_feedback(result_json=result_json, status=2) + + def add_feedbacks(self, result_jsons: Iterable[JSON]) -> None: + """ + Add multiple feedback results to the database. + """ + + for result_json in result_jsons: + self.add_feedback(result_json=result_json) + + def get_chain(self, chain_id: str) -> JSON: + """ + Look up a chain from the database. + """ + + return self.db.get_chain(chain_id) + + def get_records_and_feedback(self, chain_ids: List[str]): + """ + Get records, their feeback results, and feedback names from the database. + """ + + df, feedback_columns = self.db.get_records_and_feedback(chain_ids) + + return df, feedback_columns + + def start_evaluator(self, fork=False) -> Union[Process, Thread]: + """ + Start a deferred feedback function evaluation thread. + """ + + assert not fork, "Fork mode not yet implemented." 
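`TruChain` performs this logging automatically, but the methods above can also be driven by hand, for example to log a record produced by `call_with_record` and evaluate feedback on it. A sketch, where `tc` is assumed to be an existing `TruChain` wrapping a single-output chain and `f_qa_relevance` an existing `Feedback` instance (see the slackbot setup earlier in this diff):

```python
from trulens_eval import Tru

tru = Tru()

# `tc` and `f_qa_relevance` are assumed to be defined elsewhere.
ret, record_json = tc.call_with_record(dict(question="What does TruLens log?"))

record_id = tru.add_record(
    prompt="What does TruLens log?",
    response=ret[tc.output_keys[0]],
    record_json=record_json,   # supplies chain_id and _cost internally
    tags="manual",
)

# Evaluate feedback on the record; passing chain_json avoids a db lookup.
results = tru.run_feedback_functions(
    record_json=record_json,
    feedback_functions=[f_qa_relevance],
    chain_json=tc.json,
)
print(record_id, results)
```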
+ + if self.evaluator_proc is not None: + raise RuntimeError("Evaluator is already running in this process.") + + from trulens_eval.tru_feedback import Feedback + + if not fork: + self.evaluator_stop = threading.Event() + + def runloop(): + while fork or not self.evaluator_stop.is_set(): + print("Looking for things to do. Stop me with `tru.stop_evaluator()`.", end='') + Feedback.evaluate_deferred(tru=self) + TP().finish(timeout=10) + if fork: + sleep(10) + else: + self.evaluator_stop.wait(10) + + print("Evaluator stopped.") + + if fork: + proc = Process(target=runloop) + else: + proc = Thread(target=runloop) + + # Start a persistent thread or process that evaluates feedback functions. + + self.evaluator_proc = proc + proc.start() + + return proc + + def stop_evaluator(self): + """ + Stop the deferred feedback evaluation thread. + """ + + if self.evaluator_proc is None: + raise RuntimeError("Evaluator not running this process.") + + if isinstance(self.evaluator_proc, Process): + self.evaluator_proc.terminate() + + elif isinstance(self.evaluator_proc, Thread): + self.evaluator_stop.set() + self.evaluator_proc.join() + self.evaluator_stop = None + + self.evaluator_proc = None + + def stop_dashboard(self) -> None: + """Stop existing dashboard if running. + + Raises: + ValueError: Dashboard is already running. + """ + if Tru.dashboard_proc is None: + raise ValueError("Dashboard not running.") + + Tru.dashboard_proc.kill() + Tru.dashboard_proc = None + + def run_dashboard(self, _dev: bool = False) -> Process: + """ Runs a streamlit dashboard to view logged results and chains + + Raises: + ValueError: Dashboard is already running. + + Returns: + Process: Process containing streamlit dashboard. + """ + + if Tru.dashboard_proc is not None: + raise ValueError("Dashboard already running. Run tru.stop_dashboard() to stop existing dashboard.") + + # Create .streamlit directory if it doesn't exist + streamlit_dir = os.path.join(os.getcwd(), '.streamlit') + os.makedirs(streamlit_dir, exist_ok=True) + + # Create config.toml file + config_path = os.path.join(streamlit_dir, 'config.toml') + with open(config_path, 'w') as f: + f.write('[theme]\n') + f.write('primaryColor="#0A2C37"\n') + f.write('backgroundColor="#FFFFFF"\n') + f.write('secondaryBackgroundColor="F5F5F5"\n') + f.write('textColor="#0A2C37"\n') + f.write('font="sans serif"\n') + + cred_path = os.path.join(streamlit_dir, 'credentials.toml') + with open(cred_path, 'w') as f: + f.write('[general]\n') + f.write('email=""\n') + + #run leaderboard with subprocess + leaderboard_path = pkg_resources.resource_filename( + 'trulens_eval', 'Leaderboard.py' + ) + + env_opts = {} + if _dev: + env_opts['env'] = os.environ + env_opts['env']['PYTHONPATH'] = str(Path.cwd()) + + proc = subprocess.Popen( + ["streamlit", "run", "--server.headless=True", leaderboard_path], **env_opts + ) + + Tru.dashboard_proc = proc + + return proc + + start_dashboard = run_dashboard \ No newline at end of file diff --git a/trulens_eval/trulens_eval/tru_chain.py b/trulens_eval/trulens_eval/tru_chain.py new file mode 100644 index 000000000..854508a4d --- /dev/null +++ b/trulens_eval/trulens_eval/tru_chain.py @@ -0,0 +1,656 @@ +""" +# Langchain instrumentation and monitoring. + +## Limitations + +- Uncertain thread safety. + +- If the same wrapped sub-chain is called multiple times within a single call to + the root chain, the record of this execution will not be exact with regards to + the path to the call information. 
All call dictionaries will appear in a list + addressed by the last subchain (by order in which it is instrumented). For + example, in a sequential chain containing two of the same chain, call records + will be addressed to the second of the (same) chains and contain a list + describing calls of both the first and second. + +- Some chains cannot be serialized/jsonized. Sequential chain is an example. + This is a limitation of langchain itself. + +## Basic Usage + +- Wrap an existing chain: + +```python + + tc = TruChain(t.llm_chain) + +``` + +- Retrieve the parameters of the wrapped chain: + +```python + + tc.chain + +``` + +Output: + +```json + +{'memory': None, + 'verbose': False, + 'chain': {'memory': None, + 'verbose': True, + 'prompt': {'input_variables': ['question'], + 'output_parser': None, + 'partial_variables': {}, + 'template': 'Q: {question} A:', + 'template_format': 'f-string', + 'validate_template': True, + '_type': 'prompt'}, + 'llm': {'model_id': 'gpt2', + 'model_kwargs': None, + '_type': 'huggingface_pipeline'}, + 'output_key': 'text', + '_type': 'llm_chain'}, + '_type': 'TruChain'} + + ``` + +- Make calls like you would to the wrapped chain. + +```python + + rec1: dict = tc("hello there") + rec2: dict = tc("hello there general kanobi") + +``` + +""" + +from collections import defaultdict +from datetime import datetime +from inspect import BoundArguments +from inspect import signature +from inspect import stack +import logging +import os +import threading as th +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import langchain +from langchain.callbacks import get_openai_callback +from langchain.chains.base import Chain +from pydantic import BaseModel +from pydantic import Field + +from trulens_eval.tru_db import JSON, noserio +from trulens_eval.tru_db import obj_id_of_obj +from trulens_eval.tru_db import Query +from trulens_eval.tru_db import TruDB +from trulens_eval.tru_feedback import Feedback +from trulens_eval.util import TP + +langchain.verbose = False + +# Addresses of chains or their contents. This is used to refer chains/parameters +# even in cases where the live object is not in memory (i.e. on some remote +# app). +Path = Tuple[Union[str, int], ...] + +# Records of a chain run are dictionaries with these keys: +# +# - 'args': Dict[str, Any] -- chain __call__ input args. +# - 'rets': Dict[str, Any] -- chain __call__ return dict for calls that succeed. +# - 'error': str -- exception text if not successful. +# - 'start_time': datetime +# - 'end_time': datetime -- runtime info. +# - 'pid': int -- process id for debugging multiprocessing. +# - 'tid': int -- thread id for debuggin threading. +# - 'chain_stack': List[Path] -- call stack of chain runs. Elements address +# chains. + + +class TruChain(Chain): + """ + Wrap a langchain Chain to capture its configuration and evaluation steps. + """ + + # The wrapped/instrumented chain. + chain: Chain = None + + # Chain name/id. Will be a hash of chain definition/configuration if not provided. + chain_id: Optional[str] = None + + # Flag of whether the chain is currently recording records. This is set + # automatically but is imperfect in threaded situations. The second check + # for recording is based on the call stack, see _call. + recording: Optional[bool] = Field(exclude=True) + + # Feedback functions to evaluate on each record. + feedbacks: Optional[Sequence[Feedback]] = Field(exclude=True) + + # Feedback evaluation mode. 
+ # - "withchain" - Try to run feedback functions immediately and before chain + # returns a record. + # - "withchainthread" - Try to run feedback functions in the same process as + # the chain but after it produces a record. + # - "deferred" - Evaluate later via the process started by + # `tru.start_deferred_feedback_evaluator`. + # - None - No evaluation will happen even if feedback functions are specified. + + # NOTE: Custom feedback functions cannot be run deferred and will be run as + # if "withchainthread" was set. + feedback_mode: Optional[str] = "withchainthread" + + # Database interfaces for models/records/feedbacks. + tru: Optional['Tru'] = Field(exclude=True) + + # Database interfaces for models/records/feedbacks. + db: Optional[TruDB] = Field(exclude=True) + + def __init__( + self, + chain: Chain, + chain_id: Optional[str] = None, + verbose: bool = False, + feedbacks: Optional[Sequence[Feedback]] = None, + feedback_mode: Optional[str] = "withchainthread", + tru: Optional['Tru'] = None + ): + """ + Wrap a chain for monitoring. + + Arguments: + + - chain: Chain -- the chain to wrap. + - chain_id: Optional[str] -- chain name or id. If not given, the + name is constructed from wrapped chain parameters. + """ + + Chain.__init__(self, verbose=verbose) + + self.chain = chain + + self._instrument_object(obj=self.chain, query=Query().chain) + self.recording = False + + chain_def = self.json + + # Track chain id. This will produce a name if not provided. + self.chain_id = chain_id or obj_id_of_obj(obj=chain_def, prefix="chain") + + if feedbacks is not None and tru is None: + raise ValueError("Feedback logging requires `tru` to be specified.") + + self.feedbacks = feedbacks or [] + + assert feedback_mode in [ + 'withchain', 'withchainthread', 'deferred', None + ], "`feedback_mode` must be one of 'withchain', 'withchainthread', 'deferred', or None." + self.feedback_mode = feedback_mode + + if tru is not None: + self.db = tru.db + + if feedback_mode is None: + logging.warn( + "`tru` is specified but `feedback_mode` is None. " + "No feedback evaluation and logging will occur." + ) + else: + if feedback_mode is not None: + logging.warn( + f"`feedback_mode` is {feedback_mode} but `tru` was not specified. Reverting to None." + + ) + self.feedback_mode = None + feedback_mode = None + # Import here to avoid circular imports. + # from trulens_eval import Tru + # tru = Tru() + + self.tru = tru + + if tru is not None and feedback_mode is not None: + logging.debug( + "Inserting chain and feedback function definitions to db." + ) + self.db.insert_chain(chain_id=self.chain_id, chain_json=self.json) + for f in self.feedbacks: + self.db.insert_feedback_def(f.json) + + @property + def json(self): + temp = TruDB.jsonify(self) # not using self.dict() + # Need these to run feedback functions when they don't specify their + # inputs exactly. + + temp['input_keys'] = self.input_keys + temp['output_keys'] = self.output_keys + + return temp + + # Chain requirement + @property + def _chain_type(self): + return "TruChain" + + # Chain requirement + @property + def input_keys(self) -> List[str]: + return self.chain.input_keys + + # Chain requirement + @property + def output_keys(self) -> List[str]: + return self.chain.output_keys + + def call_with_record(self, *args, **kwargs): + """ Run the chain and also return a record metadata object. + + + Returns: + Any: chain output + dict: record metadata + """ + # Mark us as recording calls. Should be sufficient for non-threaded + # cases. 
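For orientation, the record dictionary assembled by `call_with_record` below nests a `_call` entry (arguments, return values, timing, process/thread ids, chain stack) under the path of each instrumented sub-chain, plus top-level cost and chain-id fields. Roughly, for a simple single-LLM chain; field names follow the code below, values are illustrative only:

```python
# Illustrative shape only; not produced by running this snippet.
record_json = {
    "chain_id": "chain_hash_...",                 # from obj_id_of_obj / self.chain_id
    "_cost": {"total_tokens": 42, "total_cost": 0.00009},
    "chain": {
        "_call": {
            "args": {"inputs": {"question": "How are you?"}},
            "rets": {"text": "Fine, thanks."},
            "start_time": "2023-05-01 12:00:00.000000",
            "end_time": "2023-05-01 12:00:01.000000",
            "pid": 12345,
            "tid": 6789,
            "chain_stack": [["chain"]],
        }
    },
}

# _handle_record later reads the main input/output from these two paths:
main_input = record_json["chain"]["_call"]["args"]["inputs"]["question"]
main_output = record_json["chain"]["_call"]["rets"]["text"]
```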
+ self.recording = True + + # Wrapped calls will look this up by traversing the call stack. This + # should work with threads. + record = defaultdict(list) + + ret = None + error = None + + total_tokens = None + total_cost = None + + try: + # TODO: do this only if there is an openai model inside the chain: + with get_openai_callback() as cb: + ret = self.chain.__call__(*args, **kwargs) + total_tokens = cb.total_tokens + total_cost = cb.total_cost + + except BaseException as e: + error = e + logging.error(f"Chain raised an exception: {e}") + + self.recording = False + + assert len(record) > 0, "No information recorded in call." + + ret_record = dict() + chain_json = self.json + + for path, calls in record.items(): + obj = TruDB._project(path=path, obj=chain_json) + + if obj is None: + logging.warn(f"Cannot locate {path} in chain.") + + ret_record = TruDB._set_in_json( + path=path, in_json=ret_record, val={"_call": calls} + ) + + ret_record['_cost'] = dict( + total_tokens=total_tokens, total_cost=total_cost + ) + ret_record['chain_id'] = self.chain_id + + if error is not None: + + if self.feedback_mode == "withchain": + self._handle_error(record_json=ret_record, error=error) + + elif self.feedback_mode in ["deferred", "withchainthread"]: + TP().runlater( + self._handle_error, record_json=ret_record, error=error + ) + + raise error + + if self.feedback_mode == "withchain": + self._handle_record(record_json=ret_record) + + elif self.feedback_mode in ["deferred", "withchainthread"]: + TP().runlater(self._handle_record, record_json=ret_record) + + return ret, ret_record + + # langchain.chains.base.py:Chain + def __call__(self, *args, **kwargs) -> Dict[str, Any]: + """ + Wrapped call to self.chain.__call__ with instrumentation. If you need to + get the record, use `call_with_record` instead. + """ + + ret, record = self.call_with_record(*args, **kwargs) + + return ret + + def _handle_record(self, record_json: JSON): + """ + Write out record-related info to database if set. + """ + + if self.tru is None or self.feedback_mode is None: + return + + main_input = record_json['chain']['_call']['args']['inputs'][ + self.input_keys[0]] + main_output = record_json['chain']['_call']['rets'][self.output_keys[0]] + + record_id = self.tru.add_record( + prompt=main_input, + response=main_output, + record_json=record_json, + tags='dev', # TODO: generalize + total_tokens=record_json['_cost']['total_tokens'], + total_cost=record_json['_cost']['total_cost'] + ) + + if len(self.feedbacks) == 0: + return + + # Add empty (to run) feedback to db. + if self.feedback_mode == "deferred": + for f in self.feedbacks: + feedback_id = f.feedback_id + self.db.insert_feedback(record_id, feedback_id) + + elif self.feedback_mode in ["withchain", "withchainthread"]: + + results = self.tru.run_feedback_functions( + record_json=record_json, + feedback_functions=self.feedbacks, + chain_json=self.json + ) + + for result_json in results: + self.tru.add_feedback(result_json) + + def _handle_error(self, record, error): + if self.db is None: + return + + pass + + # Chain requirement + # TODO(piotrm): figure out whether the combination of _call and __call__ is working right. + def _call(self, *args, **kwargs) -> Any: + return self.chain._call(*args, **kwargs) + + def _get_local_in_call_stack( + self, key: str, func: Callable, offset: int = 1 + ) -> Optional[Any]: + """ + Get the value of the local variable named `key` in the stack at the + nearest frame executing `func`. Returns None if `func` is not in call + stack. 
Raises RuntimeError if `func` is in call stack but does not have + `key` in its locals. + """ + + for fi in stack()[offset + 1:]: # + 1 to skip this method itself + if id(fi.frame.f_code) == id(func.__code__): + locs = fi.frame.f_locals + if key in locs: + return locs[key] + else: + raise RuntimeError(f"No local named {key} in {func} found.") + + return None + + def _instrument_dict(self, cls, obj: Any): + """ + Replacement for langchain's dict method to one that does not fail under + non-serialization situations. + """ + + if obj.memory is not None: + + # logging.warn( + # f"Will not be able to serialize object of type {cls} because it has memory." + # ) + + pass + + def safe_dict(s, json: bool = True, **kwargs: Any) -> Dict: + """ + Return dictionary representation `s`. If `json` is set, will make + sure output can be serialized. + """ + + #if s.memory is not None: + # continue anyway + # raise ValueError("Saving of memory is not yet supported.") + + _dict = super(cls, s).dict(**kwargs) + _dict["_type"] = s._chain_type + + # TODO: json + + return _dict + + safe_dict._instrumented = getattr(cls, "dict") + + return safe_dict + + def _instrument_type_method(self, obj, prop): + """ + Instrument the Langchain class's method _*_type which is presently used + to control chain saving. Override the exception behaviour. Note that + _chain_type is defined as a property in langchain. + """ + + # Properties doesn't let us new define new attributes like "_instrument" + # so we put it on fget instead. + if hasattr(prop.fget, "_instrumented"): + prop = prop.fget._instrumented + + def safe_type(s) -> Union[str, Dict]: + # self should be chain + try: + ret = prop.fget(s) + return ret + + except NotImplementedError as e: + + return noserio(obj, error=f"{e.__class__.__name__}='{str(e)}'") + + safe_type._instrumented = prop + new_prop = property(fget=safe_type) + + return new_prop + + def _instrument_call(self, query: Query, func: Callable): + """ + Instrument a Chain.__call__ method to capture its inputs/outputs/errors. + """ + + if hasattr(func, "_instrumented"): + if self.verbose: + print(f"{func} is already instrumented") + + # Already instrumented. Note that this may happen under expected + # operation when the same chain is used multiple times as part of a + # larger chain. + + # TODO: How to consistently address calls to chains that appear more + # than once in the wrapped chain or are called more than once. + func = func._instrumented + + sig = signature(func) + + def wrapper(*args, **kwargs): + # If not within TruChain._call, call the wrapped function without + # any recording. This check is not perfect in threaded situations so + # the next call stack-based lookup handles the rarer cases. + + # NOTE(piotrm): Disabling this for now as it is not thread safe. + #if not self.recording: + # return func(*args, **kwargs) + + # Look up whether TruChain._call was called earlier in the stack and + # "record" variable was defined there. Will use that for recording + # the wrapped call. + record = self._get_local_in_call_stack( + key="record", func=TruChain.call_with_record + ) + + if record is None: + return func(*args, **kwargs) + + # Otherwise keep track of inputs and outputs (or exception). 
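The wrapper above does not pass the record object through arguments; instead it locates the `record` local of the nearest `call_with_record` frame on the Python call stack. A stripped-down, self-contained illustration of that lookup pattern (not TruLens code):

```python
from inspect import stack
from typing import Any, Callable, Optional


def get_local_in_call_stack(key: str, func: Callable) -> Optional[Any]:
    # Walk caller frames looking for one currently executing `func`, then
    # return its local named `key`; None if no such frame exists.
    for fi in stack()[1:]:
        if id(fi.frame.f_code) == id(func.__code__):
            return fi.frame.f_locals.get(key)
    return None


def outer():
    record = {"calls": []}   # the local that the nested call will discover
    return inner(), record


def inner():
    # No `record` argument is passed; it is found on the call stack instead.
    return get_local_in_call_stack("record", outer)


found, record = outer()
assert found is record
```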
+ + error = None + ret = None + + start_time = datetime.now() + + chain_stack = self._get_local_in_call_stack( + key="chain_stack", func=wrapper, offset=1 + ) or [] + chain_stack = chain_stack + [query._path] + + try: + # Using sig bind here so we can produce a list of key-value + # pairs even if positional arguments were provided. + bindings: BoundArguments = sig.bind(*args, **kwargs) + ret = func(*bindings.args, **bindings.kwargs) + + except BaseException as e: + error = e + + end_time = datetime.now() + + # Don't include self in the recorded arguments. + nonself = { + k: TruDB.jsonify(v) + for k, v in bindings.arguments.items() + if k != "self" + } + row_args = dict( + args=nonself, + start_time=str(start_time), + end_time=str(end_time), + pid=os.getpid(), + tid=th.get_native_id(), + chain_stack=chain_stack + ) + + if error is not None: + row_args['error'] = error + else: + row_args['rets'] = ret + + # If there already is a call recorded at the same path, turn the + # calls into a list. + if query._path in record: + existing_call = record[query._path] + if isinstance(existing_call, dict): + record[query._path] = [existing_call, row_args] + else: + record[query._path].append(row_args) + else: + # Otherwise record just the one call not inside a list. + record[query._path] = row_args + + if error is not None: + raise error + + return ret + + wrapper._instrumented = func + + # Put the address of the instrumented chain in the wrapper so that we + # don't pollute its list of fields. Note that this address may be + # deceptive if the same subchain appears multiple times in the wrapped + # chain. + wrapper._query = query + + return wrapper + + def _instrument_object(self, obj, query: Query): + if self.verbose: + print(f"instrumenting {query._path} {obj.__class__.__name__}") + + cls = obj.__class__ + + # NOTE: We cannot instrument chain directly and have to instead + # instrument its class. The pydantic BaseModel does not allow instance + # attributes that are not fields: + # https://github.com/pydantic/pydantic/blob/11079e7e9c458c610860a5776dc398a4764d538d/pydantic/main.py#LL370C13-L370C13 + # . + for base in cls.mro(): + # All of mro() may need instrumentation here if some subchains call + # superchains, and we want to capture the intermediate steps. + + if not base.__module__.startswith("langchain."): + continue + + if hasattr(base, "_call"): + original_fun = getattr(base, "_call") + + if self.verbose: + print(f"instrumenting {base}._call") + + setattr( + base, "_call", + self._instrument_call(query=query, func=original_fun) + ) + + if hasattr(base, "_chain_type"): + if self.verbose: + print(f"instrumenting {base}._chain_type") + + prop = getattr(base, "_chain_type") + setattr( + base, "_chain_type", + self._instrument_type_method(obj=obj, prop=prop) + ) + + if hasattr(base, "_prompt_type"): + if self.verbose: + print(f"instrumenting {base}._chain_prompt") + + prop = getattr(base, "_prompt_type") + setattr( + base, "_prompt_type", + self._instrument_type_method(obj=obj, prop=prop) + ) + + if isinstance(obj, Chain): + if self.verbose: + print(f"instrumenting {base}.dict") + + setattr(base, "dict", self._instrument_dict(cls=base, obj=obj)) + + # Not using chain.dict() here as that recursively converts subchains to + # dicts but we want to traverse the instantiations here. + if isinstance(obj, BaseModel): + + for k in obj.__fields__: + # NOTE(piotrm): may be better to use inspect.getmembers_static . 
+ v = getattr(obj, k) + + if isinstance(v, str): + pass + + elif v.__class__.__module__.startswith("langchain."): + self._instrument_object(obj=v, query=query[k]) + + elif isinstance(v, Sequence): + for i, sv in enumerate(v): + if isinstance(sv, Chain): + self._instrument_object(obj=sv, query=query[k][i]) + + # TODO: check if we want to instrument anything not accessible through __fields__ . + else: + logging.debug( + f"Do not know how to instrument object {str(obj)[:32]} of type {type(obj)}." + ) diff --git a/trulens_eval/trulens_eval/tru_db.py b/trulens_eval/trulens_eval/tru_db.py new file mode 100644 index 000000000..5cb6ef79e --- /dev/null +++ b/trulens_eval/trulens_eval/tru_db.py @@ -0,0 +1,934 @@ +import abc +import json +import logging +from pathlib import Path +import sqlite3 +from typing import ( + Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union +) + +from merkle_json import MerkleJson +import pandas as pd +import pydantic +from tinydb import Query as TinyQuery +from tinydb.queries import QueryInstance as TinyQueryInstance +from trulens_eval.util import UNCIODE_YIELD, UNICODE_CHECK + +mj = MerkleJson() +NoneType = type(None) + +JSON_BASES = (str, int, float, NoneType) +JSON_BASES_T = Union[str, int, float, NoneType] +# JSON = (List, Dict) + JSON_BASES +# JSON_T = Union[JSON_BASES_T, List, Dict] +JSON = Dict + +def is_empty(obj): + try: + return len(obj) == 0 + except Exception: + return False + + +def is_noserio(obj): + """ + Determines whether the given json object represents some non-serializable + object. See `noserio`. + """ + return isinstance(obj, dict) and "_NON_SERIALIZED_OBJECT" in obj + + +def noserio(obj, **extra: Dict) -> dict: + """ + Create a json structure to represent a non-serializable object. Any + additional keyword arguments are included. + """ + + inner = { + "id": id(obj), + "class": obj.__class__.__name__, + "module": obj.__class__.__module__, + "bases": list(map(lambda b: b.__name__, obj.__class__.__bases__)) + } + inner.update(extra) + + return {'_NON_SERIALIZED_OBJECT': inner} + + +def obj_id_of_obj(obj: dict, prefix="obj"): + """ + Create an id from a json-able structure/definition. Should produce the same + name if definition stays the same. + """ + + return f"{prefix}_hash_{mj.hash(obj)}" + + +def json_str_of_obj(obj: Any) -> str: + """ + Encode the given json object as a string. + """ + return json.dumps(obj, default=json_default) + + +def json_default(obj: Any) -> str: + """ + Produce a representation of an object which cannot be json-serialized. + """ + + if isinstance(obj, pydantic.BaseModel): + try: + return json.dumps(obj.dict()) + except Exception as e: + return noserio(obj, exception=e) + + # Intentionally not including much in this indicator to make sure the model + # hashing procedure does not get randomized due to something here. + + return noserio(obj) + + +# Typing for type hints. +Query = TinyQuery + +# Instance for constructing queries for record json like `Record.chain.llm`. +Record = Query()._record + +# Instance for constructing queries for chain json. +Chain = Query()._chain + +# Type of conditions, constructed from query/record like `Record.chain != None`. 
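The helpers above give the database a JSON-safe view of arbitrary chain components: values that cannot be serialized are replaced by a `_NON_SERIALIZED_OBJECT` marker, and ids are derived from a Merkle hash of the resulting definition, so the same configuration always maps to the same id. A small sketch:

```python
from trulens_eval.tru_db import TruDB, is_noserio, obj_id_of_obj


class Opaque:
    pass  # no sensible JSON form


json_view = TruDB.jsonify({"llm": Opaque(), "temperature": 0.7})
assert is_noserio(json_view["llm"])        # replaced by the marker dict
assert json_view["temperature"] == 0.7     # JSON bases pass through unchanged

# Same definition -> same id, so re-registering a chain reuses its chain_id.
assert obj_id_of_obj(json_view, prefix="chain") == obj_id_of_obj(json_view, prefix="chain")
```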
+Condition = TinyQueryInstance + + +def query_of_path(path: List[Union[str, int]]) -> Query: + if path[0] == "_record": + ret = Record + path = path[1:] + elif path[0] == "_chain": + ret = Chain + path = path[1:] + else: + ret = Query() + + for attr in path: + ret = getattr(ret, attr) + + return ret + + +def path_of_query(query: Query) -> List[Union[str, int]]: + return query._path + + +class TruDB(abc.ABC): + + # Use TinyDB queries for looking up parts of records/models and/or filtering + # on those parts. + + @abc.abstractmethod + def reset_database(self): + """Delete all data.""" + + raise NotImplementedError() + + @abc.abstractmethod + def select( + self, + *query: Tuple[Query], + where: Optional[Condition] = None + ) -> pd.DataFrame: + """ + Select `query` fields from the records database, filtering documents + that do not match the `where` condition. + """ + + raise NotImplementedError() + + @abc.abstractmethod + def insert_record( + self, chain_id: str, input: str, output: str, record_json: JSON, + ts: int, tags: str, total_tokens: int, total_cost: float + ) -> int: + """ + Insert a new `record` into db, indicating its `model` as well. Return + record id. + """ + + raise NotImplementedError() + + @abc.abstractmethod + def insert_chain( + self, chain_json: JSON, chain_id: Optional[str] = None + ) -> str: + """ + Insert a new `chain` into db under the given `chain_id`. If name not + provided, generate a name from chain definition. Return the name. + """ + + raise NotImplementedError() + + @abc.abstractmethod + def insert_feedback_def(self, feedback_json: dict): + raise NotImplementedError() + + @abc.abstractmethod + def insert_feedback( + self, + record_id: str, + feedback_id: str, + last_ts: Optional[int] = None, # "last timestamp" + status: Optional[int] = None, + result_json: Optional[JSON] = None, + total_cost: Optional[float] = None, + total_tokens: Optional[int] = None, + ) -> str: + raise NotImplementedError() + + @abc.abstractmethod + def get_records_and_feedback( + self, chain_ids: List[str] + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + raise NotImplementedError() + + @staticmethod + def jsonify(obj: Any, dicted=None) -> JSON: + """ + Convert the given object into types that can be serialized in json. + """ + + dicted = dicted or dict() + + if isinstance(obj, JSON_BASES): + return obj + + if id(obj) in dicted: + return {'_CIRCULAR_REFERENCE': id(obj)} + + new_dicted = {k: v for k, v in dicted.items()} + + if isinstance(obj, Dict): + temp = {} + new_dicted[id(obj)] = temp + temp.update( + { + k: TruDB.jsonify(v, dicted=new_dicted) + for k, v in obj.items() + } + ) + return temp + + elif isinstance(obj, Sequence): + temp = [] + new_dicted[id(obj)] = temp + for x in (TruDB.jsonify(v, dicted=new_dicted) for v in obj): + temp.append(x) + return temp + + elif isinstance(obj, Set): + temp = [] + new_dicted[id(obj)] = temp + for x in (TruDB.jsonify(v, dicted=new_dicted) for v in obj): + temp.append(x) + return temp + + elif isinstance(obj, pydantic.BaseModel): + temp = {} + new_dicted[id(obj)] = temp + temp.update( + { + k: TruDB.jsonify(getattr(obj, k), dicted=new_dicted) + for k in obj.__fields__ + } + ) + return temp + + else: + logging.debug( + f"Don't know how to jsonify an object '{str(obj)[0:32]}' of type '{type(obj)}'." + ) + return noserio(obj) + + @staticmethod + def leaf_queries(obj_json: JSON, query: Query = None) -> Iterable[Query]: + """ + Get all queries for the given object that select all of its leaf values. 
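+
+        For example (an illustrative sketch): for the object
+        `{"a": 1, "b": [2, 3]}` and the default `Record` root, this yields the
+        queries `Record.a`, `Record.b[0]` and `Record.b[1]`.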
+ """ + + query = query or Record + + if isinstance(obj_json, (str, int, float, NoneType)): + yield query + + elif isinstance(obj_json, Dict): + for k, v in obj_json.items(): + sub_query = query[k] + for res in TruDB.leaf_queries(obj_json[k], sub_query): + yield res + + elif isinstance(obj_json, Sequence): + for i, v in enumerate(obj_json): + sub_query = query[i] + for res in TruDB.leaf_queries(obj_json[i], sub_query): + yield res + + else: + yield query + + @staticmethod + def all_queries(obj: Any, query: Query = None) -> Iterable[Query]: + """ + Get all queries for the given object. + """ + + query = query or Record + + if isinstance(obj, (str, int, float, NoneType)): + yield query + + elif isinstance(obj, pydantic.BaseModel): + yield query + + for k in obj.__fields__: + v = getattr(obj, k) + sub_query = query[k] + for res in TruDB.all_queries(v, sub_query): + yield res + + elif isinstance(obj, Dict): + yield query + + for k, v in obj.items(): + sub_query = query[k] + for res in TruDB.all_queries(obj[k], sub_query): + yield res + + elif isinstance(obj, Sequence): + yield query + + for i, v in enumerate(obj): + sub_query = query[i] + for res in TruDB.all_queries(obj[i], sub_query): + yield res + + else: + yield query + + @staticmethod + def all_objects(obj: Any, + query: Query = None) -> Iterable[Tuple[Query, Any]]: + """ + Get all queries for the given object. + """ + + query = query or Record + + if isinstance(obj, (str, int, float, NoneType)): + yield (query, obj) + + elif isinstance(obj, pydantic.BaseModel): + yield (query, obj) + + for k in obj.__fields__: + v = getattr(obj, k) + sub_query = query[k] + for res in TruDB.all_objects(v, sub_query): + yield res + + elif isinstance(obj, Dict): + yield (query, obj) + + for k, v in obj.items(): + sub_query = query[k] + for res in TruDB.all_objects(obj[k], sub_query): + yield res + + elif isinstance(obj, Sequence): + yield (query, obj) + + for i, v in enumerate(obj): + sub_query = query[i] + for res in TruDB.all_objects(obj[i], sub_query): + yield res + + else: + yield (query, obj) + + @staticmethod + def leafs(obj: Any) -> Iterable[Tuple[str, Any]]: + for q in TruDB.leaf_queries(obj): + path_str = TruDB._query_str(q) + val = TruDB._project(q._path, obj) + yield (path_str, val) + + @staticmethod + def matching_queries(obj: Any, match: Callable) -> Iterable[Query]: + for q in TruDB.all_queries(obj): + val = TruDB._project(q._path, obj) + if match(q, val): + yield q + + @staticmethod + def matching_objects(obj: Any, + match: Callable) -> Iterable[Tuple[Query, Any]]: + for q, val in TruDB.all_objects(obj): + if match(q, val): + yield (q, val) + + @staticmethod + def _query_str(query: Query) -> str: + + def render(ks): + if len(ks) == 0: + return "" + + first = ks[0] + if len(ks) > 1: + rest = ks[1:] + else: + rest = () + + if isinstance(first, str): + return f".{first}{render(rest)}" + elif isinstance(first, int): + return f"[{first}]{render(rest)}" + else: + RuntimeError( + f"Don't know how to render path element {first} of type {type(first)}." 
+ ) + + return "Record" + render(query._path) + + @staticmethod + def set_in_json(query: Query, in_json: JSON, val: JSON) -> JSON: + return TruDB._set_in_json(query._path, in_json=in_json, val=val) + + @staticmethod + def _set_in_json(path, in_json: JSON, val: JSON) -> JSON: + if len(path) == 0: + if isinstance(in_json, Dict): + assert isinstance(val, Dict) + in_json = {k: v for k, v in in_json.items()} + in_json.update(val) + return in_json + + assert in_json is None, f"Cannot set non-None json object: {in_json}" + + return val + + if len(path) == 1: + first = path[0] + rest = [] + else: + first = path[0] + rest = path[1:] + + if isinstance(first, str): + if isinstance(in_json, Dict): + in_json = {k: v for k, v in in_json.items()} + if not first in in_json: + in_json[first] = None + elif in_json is None: + in_json = {first: None} + else: + raise RuntimeError( + f"Do not know how to set path {path} in {in_json}." + ) + + in_json[first] = TruDB._set_in_json( + path=rest, in_json=in_json[first], val=val + ) + return in_json + + elif isinstance(first, int): + if isinstance(in_json, Sequence): + # In case it is some immutable sequence. Also copy. + in_json = list(in_json) + elif in_json is None: + in_json = [] + else: + raise RuntimeError( + f"Do not know how to set path {path} in {in_json}." + ) + + while len(in_json) <= first: + in_json.append(None) + + in_json[first] = TruDB._set_in_json( + path=rest, in_json=in_json[first], val=val + ) + return in_json + + else: + raise RuntimeError( + f"Do not know how to set path {path} in {in_json}." + ) + + @staticmethod + def project( + query: Query, + record_json: JSON, + chain_json: JSON, + obj: Optional[JSON] = None + ): + path = query._path + if path[0] == "_record": + if len(path) == 1: + return record_json + return TruDB._project(path=path[1:], obj=record_json) + elif path[0] == "_chain": + if len(path) == 1: + return chain_json + return TruDB._project(path=path[1:], obj=chain_json) + else: + return TruDB._project(path=path, obj=obj) + + @staticmethod + def _project(path: List, obj: Any): + if len(path) == 0: + return obj + + first = path[0] + if len(path) > 1: + rest = path[1:] + else: + rest = () + + if isinstance(first, str): + if isinstance(obj, pydantic.BaseModel): + if not hasattr(obj, first): + logging.warn( + f"Cannot project {str(obj)[0:32]} with path {path} because {first} is not an attribute here." + ) + return None + return TruDB._project(path=rest, obj=getattr(obj, first)) + + elif isinstance(obj, Dict): + if first not in obj: + logging.warn( + f"Cannot project {str(obj)[0:32]} with path {path} because {first} is not a key here." + ) + return None + return TruDB._project(path=rest, obj=obj[first]) + + else: + logging.warn( + f"Cannot project {str(obj)[0:32]} with path {path} because object is not a dict or model." + ) + return None + + elif isinstance(first, int): + if not isinstance(obj, Sequence) or first >= len(obj): + logging.warn( + f"Cannot project {str(obj)[0:32]} with path {path}." 
+ ) + return None + + return TruDB._project(path=rest, obj=obj[first]) + else: + raise RuntimeError( + f"Don't know how to locate element with key of type {first}" + ) + + +class LocalSQLite(TruDB): + + TABLE_RECORDS = "records" + TABLE_FEEDBACKS = "feedbacks" + TABLE_FEEDBACK_DEFS = "feedback_defs" + TABLE_CHAINS = "chains" + + def __str__(self): + return f"SQLite({self.filename})" + + def reset_database(self): + self._clear_tables() + self._build_tables() + + def _clear_tables(self): + conn, c = self._connect() + + # Create table if it does not exist + c.execute( + f'''DELETE FROM {self.TABLE_RECORDS}''' + ) + c.execute( + f'''DELETE FROM {self.TABLE_FEEDBACKS}''' + ) + c.execute( + f'''DELETE FROM {self.TABLE_FEEDBACK_DEFS}''' + ) + c.execute( + f'''DELETE FROM {self.TABLE_CHAINS}''' + ) + self._close(conn) + + def _build_tables(self): + conn, c = self._connect() + + # Create table if it does not exist + c.execute( + f'''CREATE TABLE IF NOT EXISTS {self.TABLE_RECORDS} ( + record_id TEXT, + chain_id TEXT, + input TEXT, + output TEXT, + record_json TEXT, + tags TEXT, + ts INTEGER, + total_tokens INTEGER, + total_cost REAL, + PRIMARY KEY (record_id, chain_id) + )''' + ) + c.execute( + f'''CREATE TABLE IF NOT EXISTS {self.TABLE_FEEDBACKS} ( + record_id TEXT, + feedback_id TEXT, + last_ts INTEGER, + status INTEGER, + result_json TEXT, + total_tokens INTEGER, + total_cost REAL, + PRIMARY KEY (record_id, feedback_id) + )''' + ) + c.execute( + f'''CREATE TABLE IF NOT EXISTS {self.TABLE_FEEDBACK_DEFS} ( + feedback_id TEXT PRIMARY KEY, + feedback_json TEXT + )''' + ) + c.execute( + f'''CREATE TABLE IF NOT EXISTS {self.TABLE_CHAINS} ( + chain_id TEXT PRIMARY KEY, + chain_json TEXT + )''' + ) + self._close(conn) + + def __init__(self, filename: Optional[Path] = 'default.sqlite'): + self.filename = filename + self._build_tables() + + def _connect(self): + conn = sqlite3.connect(self.filename) + c = conn.cursor() + return conn, c + + def _close(self, conn): + conn.commit() + conn.close() + + # TruDB requirement + def insert_record( + self, chain_id: str, input: str, output: str, record_json: dict, + ts: int, tags: str, total_tokens: int, total_cost: float + ) -> int: + assert isinstance( + record_json, Dict + ), f"Attempting to add a record that is not a dict, is {type(record_json)} instead." + + conn, c = self._connect() + + record_id = obj_id_of_obj(obj=record_json, prefix="record") + record_json['record_id'] = record_id + record_str = json_str_of_obj(record_json) + + c.execute( + f"INSERT INTO {self.TABLE_RECORDS} VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + record_id, chain_id, input, output, record_str, tags, ts, + total_tokens, total_cost + ) + ) + self._close(conn) + + print(f"{UNICODE_CHECK} record {record_id} from {chain_id} -> {self.filename}") + + return record_id + + # TruDB requirement + def insert_chain( + self, chain_json: dict, chain_id: Optional[str] = None + ) -> str: + chain_id = chain_id or chain_json['chain_id'] or obj_id_of_obj( + obj=chain_json, prefix="chain" + ) + chain_str = json_str_of_obj(chain_json) + + conn, c = self._connect() + c.execute( + f"INSERT OR IGNORE INTO {self.TABLE_CHAINS} VALUES (?, ?)", + (chain_id, chain_str) + ) + self._close(conn) + + print(f"{UNICODE_CHECK} chain {chain_id} -> {self.filename}") + + return chain_id + + def insert_feedback_def(self, feedback_json: dict): + """ + Insert a feedback definition into the database. 
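+        An existing definition with the same `feedback_id` is replaced.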
+ """ + + feedback_id = feedback_json['feedback_id'] + feedback_str = json_str_of_obj(feedback_json) + + conn, c = self._connect() + c.execute( + f"INSERT OR REPLACE INTO {self.TABLE_FEEDBACK_DEFS} VALUES (?, ?)", + (feedback_id, feedback_str) + ) + self._close(conn) + + print(f"{UNICODE_CHECK} feedback def. {feedback_id} -> {self.filename}") + + def get_feedback_defs(self, feedback_id: Optional[str] = None): + clause = "" + args = () + if feedback_id is not None: + clause = "WHERE feedback_id=?" + args = (feedback_id,) + + query = f""" + SELECT + feedback_id, feedback_json + FROM {self.TABLE_FEEDBACK_DEFS} + {clause} + """ + + conn, c = self._connect() + c.execute(query, args) + rows = c.fetchall() + self._close(conn) + + from trulens_eval.tru_feedback import Feedback + + df_rows = [] + + for row in rows: + row = list(row) + row[1] = Feedback.of_json(json.loads(row[1])) + df_rows.append(row) + + return pd.DataFrame(rows, columns=['feedback_id', 'feedback']) + + def insert_feedback( + self, + record_id: Optional[str] = None, + feedback_id: Optional[str] = None, + last_ts: Optional[int] = None, # "last timestamp" + status: Optional[int] = None, + result_json: Optional[dict] = None, + total_cost: Optional[float] = None, + total_tokens: Optional[int] = None, + ): + """ + Insert a record-feedback link to db or update an existing one. + """ + + if record_id is None or feedback_id is None: + assert result_json is not None, "`result_json` needs to be given if `record_id` or `feedback_id` are not provided." + record_id = result_json['record_id'] + feedback_id = result_json['feedback_id'] + + last_ts = last_ts or 0 + status = status or 0 + result_json = result_json or dict() + total_cost = total_cost = 0.0 + total_tokens = total_tokens or 0 + result_str = json_str_of_obj(result_json) + + conn, c = self._connect() + c.execute( + f"INSERT OR REPLACE INTO {self.TABLE_FEEDBACKS} VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + record_id, feedback_id, last_ts, status, result_str, + total_tokens, total_cost + ) + ) + self._close(conn) + + if status == 2: + print(f"{UNICODE_CHECK} feedback {feedback_id} on {record_id} -> {self.filename}") + else: + print(f"{UNCIODE_YIELD} feedback {feedback_id} on {record_id} -> {self.filename}") + + def get_feedback( + self, + record_id: Optional[str] = None, + feedback_id: Optional[str] = None, + status: Optional[int] = None, + last_ts_before: Optional[int] = None + ): + + clauses = [] + vars = [] + if record_id is not None: + clauses.append("record_id=?") + vars.append(record_id) + if feedback_id is not None: + clauses.append("feedback_id=?") + vars.append(feedback_id) + if status is not None: + if isinstance(status, Sequence): + clauses.append( + "status in (" + (",".join(["?"] * len(status))) + ")" + ) + for v in status: + vars.append(v) + else: + clauses.append("status=?") + vars.append(status) + if last_ts_before is not None: + clauses.append("last_ts<=?") + vars.append(last_ts_before) + + where_clause = " AND ".join(clauses) + if len(where_clause) > 0: + where_clause = " AND " + where_clause + + query = f""" + SELECT + f.record_id, f.feedback_id, f.last_ts, f.status, + f.result_json, f.total_cost, f.total_tokens, + fd.feedback_json, r.record_json, c.chain_json + FROM {self.TABLE_FEEDBACKS} f + JOIN {self.TABLE_FEEDBACK_DEFS} fd + JOIN {self.TABLE_RECORDS} r + JOIN {self.TABLE_CHAINS} c + WHERE f.feedback_id=fd.feedback_id + AND r.record_id=f.record_id + AND r.chain_id=c.chain_id + {where_clause} + """ + + conn, c = self._connect() + c.execute(query, vars) + rows = 
c.fetchall() + self._close(conn) + + from trulens_eval.tru_feedback import Feedback + + df_rows = [] + for row in rows: + row = list(row) + row[4] = json.loads(row[4]) # result_json + row[7] = json.loads(row[7]) # feedback_json + row[8] = json.loads(row[8]) # record_json + row[9] = json.loads(row[9]) # chain_json + + df_rows.append(row) + + return pd.DataFrame( + df_rows, + columns=[ + 'record_id', 'feedback_id', 'last_ts', 'status', 'result_json', + 'total_cost', 'total_tokens', 'feedback_json', 'record_json', + 'chain_json' + ] + ) + + # TO REMOVE: + # TruDB requirement + def select( + self, + *query: Tuple[Query], + where: Optional[Condition] = None + ) -> pd.DataFrame: + raise NotImplementedError + """ + # get the record json dumps from sql + record_strs = ... # TODO(shayak) + + records: Sequence[Dict] = map(json.loads, record_strs) + + db = LocalTinyDB() # in-memory db if filename not provided + for record in records: + db.insert_record(chain_id=record['chain_id'], record=record) + + return db.select(*query, where) + """ + + def get_chain(self, chain_id: str) -> JSON: + conn, c = self._connect() + c.execute( + f"SELECT chain_json FROM {self.TABLE_CHAINS} WHERE chain_id=?", + (chain_id,) + ) + result = c.fetchone()[0] + conn.close() + + return json.loads(result) + + def get_records_and_feedback(self, chain_ids: List[str]) -> Tuple[pd.DataFrame, Sequence[str]]: + # This returns all models if the list of chain_ids is empty. + conn, c = self._connect() + query = f""" + SELECT r.record_id, f.result_json + FROM {self.TABLE_RECORDS} r + LEFT JOIN {self.TABLE_FEEDBACKS} f + ON r.record_id = f.record_id + """ + if len(chain_ids) > 0: + chain_id_list = ', '.join('?' * len(chain_ids)) + query = query + f" WHERE r.chain_id IN ({chain_id_list})" + + c.execute(query) + rows = c.fetchall() + conn.close() + + df_results = pd.DataFrame( + rows, columns=[description[0] for description in c.description] + ) + + if len(df_results) == 0: + return df_results, [] + + conn, c = self._connect() + query = f""" + SELECT DISTINCT r.*, c.chain_json + FROM {self.TABLE_RECORDS} r + JOIN {self.TABLE_CHAINS} c + ON r.chain_id = c.chain_id + """ + if len(chain_ids) > 0: + chain_id_list = ', '.join('?' 
* len(chain_ids)) + query = query + f" WHERE r.chain_id IN ({chain_id_list})" + + c.execute(query) + rows = c.fetchall() + conn.close() + + df_records = pd.DataFrame( + rows, columns=[description[0] for description in c.description] + ) + + if len(df_records) == 0: + return df_records, [] + + # Apply the function to the 'data' column to convert it into separate columns + df_results['result_json'] = df_results['result_json'].apply(lambda d: {} if d is None else json.loads(d)) + + if "record_id" not in df_results.columns: + return df_results, [] + + df_results = df_results.groupby("record_id").agg( + lambda dicts: {key: val for d in dicts for key, val in d.items()} + ).reset_index() + + df_results = df_results['result_json'].apply(pd.Series) + + result_cols = [col for col in df_results.columns if col not in ['feedback_id', 'record_id', '_success', "_error"]] + + if len(df_results) == 0 or len(result_cols) == 0: + return df_records, [] + + assert "record_id" in df_results.columns + assert "record_id" in df_records.columns + + combined_df = df_records.merge(df_results, on=['record_id']) + + return combined_df, result_cols diff --git a/trulens_eval/trulens_eval/tru_feedback.py b/trulens_eval/trulens_eval/tru_feedback.py new file mode 100644 index 000000000..f1225e5f1 --- /dev/null +++ b/trulens_eval/trulens_eval/tru_feedback.py @@ -0,0 +1,945 @@ +""" +# Feedback Functions + +Initialize feedback function providers: + +```python + hugs = Huggingface() + openai = OpenAI() +``` + +Run feedback functions. See examples below on how to create them: + +```python + feedbacks = tru.run_feedback_functions( + chain=chain, + record=record, + feedback_functions=[f_lang_match, f_qs_relevance] + ) +``` + +## Examples: + +Non-toxicity of response: + +```python + f_non_toxic = Feedback(hugs.not_toxic).on_response() +``` + +Language match feedback function: + +```python + f_lang_match = Feedback(hugs.language_match).on(text1="prompt", text2="response") +``` + +""" + +from datetime import datetime +from inspect import Signature +from inspect import signature +import logging +from multiprocessing.pool import AsyncResult +import re +from time import sleep +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union + +import numpy as np +import openai +import requests +from tqdm.auto import tqdm + +from trulens_eval import feedback_prompts +from trulens_eval.keys import * +from trulens_eval.provider_apis import Endpoint +from trulens_eval.tru_db import JSON, Query, obj_id_of_obj, query_of_path +from trulens_eval.tru_db import Record +from trulens_eval.tru_db import TruDB +from trulens_eval.util import TP + +# openai + +# (external) feedback- +# provider +# model + +# feedback_collator: +# - record, feedback_imp, selector -> dict (input, output, other) + +# (external) feedback: +# - *args, **kwargs -> real +# - dict -> real +# - (record, selectors) -> real +# - str, List[str] -> real +# agg(relevance(str, str[0]), +# relevance(str, str[1]) +# ...) + +# (internal) feedback_imp: +# - Option 1 : input, output -> real +# - Option 2: dict (input, output, other) -> real + +Selection = Union[Query, str] +# "prompt" or "input" mean overall chain input text +# "response" or "output"mean overall chain output text +# Otherwise a Query is a path into a record structure. 
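+#
+# Illustrative example of how selections are used (a sketch; `hugs` and the
+# "text" output key are assumptions for the example, not part of this module):
+#
+#   hugs = Huggingface()
+#   f_lang_match = Feedback(hugs.language_match).on(
+#       text1="prompt",
+#       text2=Record.chain._call.rets["text"],
+#   )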
+ +PROVIDER_CLASS_NAMES = ['OpenAI', 'Huggingface', 'Cohere'] + + +def check_provider(cls_or_name: Union[Type, str]) -> None: + if isinstance(cls_or_name, str): + cls_name = cls_or_name + else: + cls_name = cls_or_name.__name__ + + assert cls_name in PROVIDER_CLASS_NAMES, f"Unsupported provider class {cls_name}" + + +class Feedback(): + + def __init__( + self, + imp: Optional[Callable] = None, + selectors: Optional[Dict[str, Selection]] = None, + feedback_id: Optional[str] = None + ): + """ + A Feedback function container. + + Parameters: + + - imp: Optional[Callable] -- implementation of the feedback function. + - selectors: Optional[Dict[str, Selection]] -- mapping of implementation + argument names to where to get them from a record. + """ + + # Verify that `imp` expects the arguments specified in `selectors`: + if imp is not None and selectors is not None: + sig: Signature = signature(imp) + for argname in selectors.keys(): + assert argname in sig.parameters, ( + f"{argname} is not an argument to {imp.__name__}. " + f"Its arguments are {list(sig.parameters.keys())}." + ) + + self.imp = imp + self.selectors = selectors + + if feedback_id is not None: + self._feedback_id = feedback_id + + if imp is not None and selectors is not None: + # These are for serialization to/from json and for db storage. + + assert hasattr( + imp, "__self__" + ), "Feedback implementation is not a method (it may be a function)." + self.provider = imp.__self__ + check_provider(self.provider.__class__.__name__) + self.imp_method_name = imp.__name__ + self._json = self.to_json() + self._feedback_id = feedback_id or obj_id_of_obj(self._json, prefix="feedback") + self._json['feedback_id'] = self._feedback_id + + @staticmethod + def evaluate_deferred(tru: 'Tru'): + db = tru.db + + def prepare_feedback(row): + record_json = row.record_json + + feedback = Feedback.of_json(row.feedback_json) + feedback.run_and_log(record_json=record_json, tru=tru) + + feedbacks = db.get_feedback() + + for i, row in feedbacks.iterrows(): + if row.status == 0: + tqdm.write(f"Starting run for row {i}.") + + TP().runlater(prepare_feedback, row) + elif row.status in [1]: + now = datetime.now().timestamp() + if now - row.last_ts > 30: + tqdm.write(f"Incomplete row {i} last made progress over 30 seconds ago. Retrying.") + TP().runlater(prepare_feedback, row) + else: + tqdm.write(f"Incomplete row {i} last made progress less than 30 seconds ago. Giving it more time.") + + elif row.status in [-1]: + now = datetime.now().timestamp() + if now - row.last_ts > 60*5: + tqdm.write(f"Failed row {i} last made progress over 5 minutes ago. Retrying.") + TP().runlater(prepare_feedback, row) + else: + tqdm.write(f"Failed row {i} last made progress less than 5 minutes ago. Not touching it for now.") + + elif row.status == 2: + pass + + # TP().finish() + # TP().runrepeatedly(runner) + + @property + def json(self): + assert hasattr(self, "_json"), "Cannot json-size partially defined feedback function." + return self._json + + @property + def feedback_id(self): + assert hasattr(self, "_feedback_id"), "Cannot get id of partially defined feedback function." 
+ return self._feedback_id + + @staticmethod + def selection_to_json(select: Selection) -> dict: + if isinstance(select, str): + return select + elif isinstance(select, Query): + return select._path + else: + raise ValueError(f"Unknown selection type {type(select)}.") + + @staticmethod + def selection_of_json(obj: Union[List, str]) -> Selection: + if isinstance(obj, str): + return obj + elif isinstance(obj, (List, Tuple)): + return query_of_path(obj) # TODO + else: + raise ValueError(f"Unknown selection encoding of type {type(obj)}.") + + def to_json(self) -> dict: + selectors_json = { + k: Feedback.selection_to_json(v) for k, v in self.selectors.items() + } + return { + 'selectors': selectors_json, + 'imp_method_name': self.imp_method_name, + 'provider': self.provider.to_json() + } + + @staticmethod + def of_json(obj) -> 'Feedback': + assert "selectors" in obj, "Feedback encoding has no 'selectors' field." + assert "imp_method_name" in obj, "Feedback encoding has no 'imp_method_name' field." + assert "provider" in obj, "Feedback encoding has no 'provider' field." + + imp_method_name = obj['imp_method_name'] + selectors = { + k: Feedback.selection_of_json(v) + for k, v in obj['selectors'].items() + } + provider = Provider.of_json(obj['provider']) + + assert hasattr( + provider, imp_method_name + ), f"Provider {provider.__name__} has no feedback function {imp_method_name}." + imp = getattr(provider, imp_method_name) + + return Feedback(imp, selectors=selectors) + + def on_multiple( + self, + multiarg: str, + each_query: Optional[Query] = None, + agg: Callable = np.mean + ) -> 'Feedback': + """ + Create a variant of `self` whose implementation will accept multiple + values for argument `multiarg`, aggregating feedback results for each. + Optionally each input element is further projected with `each_query`. + + Parameters: + + - multiarg: str -- implementation argument that expects multiple values. + - each_query: Optional[Query] -- a query providing the path from each + input to `multiarg` to some inner value which will be sent to `self.imp`. + """ + + def wrapped_imp(**kwargs): + assert multiarg in kwargs, f"Feedback function expected {multiarg} keyword argument." + + multi = kwargs[multiarg] + + assert isinstance( + multi, Sequence + ), f"Feedback function expected a sequence on {multiarg} argument." + + rets: List[AsyncResult[float]] = [] + + for aval in multi: + + if each_query is not None: + aval = TruDB.project(query=each_query, obj=aval) + + kwargs[multiarg] = aval + + rets.append(TP().promise(self.imp, **kwargs)) + + rets: List[float] = list(map(lambda r: r.get(), rets)) + + rets = np.array(rets) + + return agg(rets) + + wrapped_imp.__name__ = self.imp.__name__ + + wrapped_imp.__self__ = self.imp.__self__ # needed for serialization + + # Copy over signature from wrapped function. Otherwise signature of the + # wrapped method will include just kwargs which is insufficient for + # verify arguments (see Feedback.__init__). + wrapped_imp.__signature__ = signature(self.imp) + + return Feedback(imp=wrapped_imp, selectors=self.selectors) + + def on_prompt(self, arg: str = "text"): + """ + Create a variant of `self` that will take in the main chain input or + "prompt" as input, sending it as an argument `arg` to implementation. 
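+
+        For example (illustrative, assuming `hugs = Huggingface()`):
+        `Feedback(hugs.positive_sentiment).on_prompt()` sends the chain's input
+        text to the `text` argument of `positive_sentiment`.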
+ """ + + return Feedback(imp=self.imp, selectors={arg: "prompt"}) + + on_input = on_prompt + + def on_response(self, arg: str = "text"): + """ + Create a variant of `self` that will take in the main chain output or + "response" as input, sending it as an argument `arg` to implementation. + """ + + return Feedback(imp=self.imp, selectors={arg: "response"}) + + on_output = on_response + + def on(self, **selectors): + """ + Create a variant of `self` with the same implementation but the given `selectors`. + """ + + return Feedback(imp=self.imp, selectors=selectors) + + def run_on_record(self, chain_json: JSON, record_json: JSON) -> Any: + """ + Run the feedback function on the given `record`. The `chain` that + produced the record is also required to determine input/output argument + names. + """ + + if 'record_id' not in record_json: + record_json['record_id'] = None + + try: + ins = self.extract_selection(chain_json=chain_json, record_json=record_json) + ret = self.imp(**ins) + + return { + '_success': True, + 'feedback_id': self.feedback_id, + 'record_id': record_json['record_id'], + self.name: ret + } + + except Exception as e: + return { + '_success': False, + 'feedback_id': self.feedback_id, + 'record_id': record_json['record_id'], + '_error': str(e) + } + + def run_and_log(self, record_json: JSON, tru: 'Tru') -> None: + record_id = record_json['record_id'] + chain_id = record_json['chain_id'] + + ts_now = datetime.now().timestamp() + + db = tru.db + + try: + db.insert_feedback( + record_id=record_id, + feedback_id=self.feedback_id, + last_ts = ts_now, + status = 1 # in progress + ) + + chain_json = db.get_chain(chain_id=chain_id) + + res = self.run_on_record(chain_json=chain_json, record_json=record_json) + + except Exception as e: + print(e) + res = { + '_success': False, + 'feedback_id': self.feedback_id, + 'record_id': record_json['record_id'], + '_error': str(e) + } + + ts_now = datetime.now().timestamp() + + if res['_success']: + db.insert_feedback( + record_id=record_id, + feedback_id=self.feedback_id, + last_ts = ts_now, + status = 2, # done and good + result_json=res, + total_cost=-1.0, # todo + total_tokens=-1 # todo + ) + else: + # TODO: indicate failure better + db.insert_feedback( + record_id=record_id, + feedback_id=self.feedback_id, + last_ts = ts_now, + status = -1, # failure + result_json=res, + total_cost=-1.0, # todo + total_tokens=-1 # todo + ) + + @property + def name(self): + """ + Name of the feedback function. Presently derived from the name of the + function implementing it. + """ + + return self.imp.__name__ + + def extract_selection( + self, + chain_json: Dict, + record_json: Dict + ) -> Dict[str, Any]: + """ + Given the `chain` that produced the given `record`, extract from + `record` the values that will be sent as arguments to the implementation + as specified by `self.selectors`. + """ + + ret = {} + + for k, v in self.selectors.items(): + if isinstance(v, Query): + q = v + + elif v == "prompt" or v == "input": + if len(chain_json['input_keys']) > 1: + #logging.warn( + # f"Chain has more than one input, guessing the first one is prompt." + #) + pass + + input_key = chain_json['input_keys'][0] + + q = Record.chain._call.args.inputs[input_key] + + elif v == "response" or v == "output": + if len(chain_json['output_keys']) > 1: + #logging.warn( + # "Chain has more than one ouput, guessing the first one is response." 
+ #) + pass + + output_key = chain_json['output_keys'][0] + + q = Record.chain._call.rets[output_key] + + else: + raise RuntimeError(f"Unhandled selection type {type(v)}.") + + val = TruDB.project(query=q, record_json=record_json, chain_json=chain_json) + ret[k] = val + + return ret + + +pat_1_10 = re.compile(r"\s*([1-9][0-9]*)\s*") + + +def _re_1_10_rating(str_val): + matches = pat_1_10.fullmatch(str_val) + if not matches: + # Try soft match + matches = re.search('[1-9][0-9]*', str_val) + if not matches: + logging.warn(f"1-10 rating regex failed to match on: '{str_val}'") + return -10 # so this will be reported as -1 after division by 10 + + return int(matches.group()) + + +class Provider(): + + @staticmethod + def of_json(obj: Dict) -> 'Provider': + cls_name = obj['class'] + check_provider(cls_name) + + cls = eval(cls_name) + kwargs = {k: v for k, v in obj.items() if k != "class"} + + return cls(**kwargs) + + def to_json(self: 'Provider', **extras) -> Dict: + obj = {'class': self.__class__.__name__} + obj.update(**extras) + return obj + +class OpenAI(Provider): + + def __init__(self, model_engine: str = "gpt-3.5-turbo"): + """ + A set of OpenAI Feedback Functions. + + Parameters: + + - model_engine (str, optional): The specific model version. Defaults to + "gpt-3.5-turbo". + """ + self.model_engine = model_engine + self.endpoint = Endpoint(name="openai") + + def to_json(self) -> Dict: + return Provider.to_json(self, model_engine=self.model_engine) + + def _moderation(self, text: str): + return self.endpoint.run_me( + lambda: openai.Moderation.create(input=text) + ) + + def moderation_not_hate(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is hate + speech. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "hate" and 1 being "not + hate". + """ + openai_response = self._moderation(text) + return 1 - float( + openai_response["results"][0]["category_scores"]["hate"] + ) + + def moderation_not_hatethreatening(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is + threatening speech. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "threatening" and 1 being + "not threatening". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["hate/threatening"] + ) + + def moderation_not_selfharm(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is about + self harm. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "self harm" and 1 being "not + self harm". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["self-harm"] + ) + + def moderation_not_sexual(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is sexual + speech. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "sexual" and 1 being "not + sexual". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["sexual"] + ) + + def moderation_not_sexualminors(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is about + sexual minors. + + Parameters: + text (str): Text to evaluate. 
+ + Returns: + float: A value between 0 and 1. 0 being "sexual minors" and 1 being + "not sexual minors". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["sexual/minors"] + ) + + def moderation_not_violence(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is about + violence. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "violence" and 1 being "not + violence". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["violence"] + ) + + def moderation_not_violencegraphic(self, text: str) -> float: + """ + Uses OpenAI's Moderation API. A function that checks if text is about + graphic violence. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "graphic violence" and 1 + being "not graphic violence". + """ + openai_response = self._moderation(text) + + return 1 - int( + openai_response["results"][0]["category_scores"]["violence/graphic"] + ) + + def qs_relevance(self, question: str, statement: str) -> float: + """ + Uses OpenAI's Chat Completion Model. A function that completes a + template to check the relevance of the statement to the question. + + Parameters: + question (str): A question being asked. statement (str): A statement + to the question. + + Returns: + float: A value between 0 and 1. 0 being "not relevant" and 1 being + "relevant". + """ + return _re_1_10_rating( + self.endpoint.run_me( + lambda: openai.ChatCompletion.create( + model=self.model_engine, + temperature=0.0, + messages=[ + { + "role": + "system", + "content": + str.format( + feedback_prompts.QS_RELEVANCE, + question=question, + statement=statement + ) + } + ] + )["choices"][0]["message"]["content"] + ) + ) / 10 + + def relevance(self, prompt: str, response: str) -> float: + """ + Uses OpenAI's Chat Completion Model. A function that completes a + template to check the relevance of the response to a prompt. + + Parameters: + prompt (str): A text prompt to an agent. response (str): The agent's + response to the prompt. + + Returns: + float: A value between 0 and 1. 0 being "not relevant" and 1 being + "relevant". + """ + return _re_1_10_rating( + self.endpoint.run_me( + lambda: openai.ChatCompletion.create( + model=self.model_engine, + temperature=0.0, + messages=[ + { + "role": + "system", + "content": + str.format( + feedback_prompts.PR_RELEVANCE, + prompt=prompt, + response=response + ) + } + ] + )["choices"][0]["message"]["content"] + ) + ) / 10 + + def model_agreement(self, prompt: str, response: str) -> float: + """ + Uses OpenAI's Chat GPT Model. A function that gives Chat GPT the same + prompt and gets a response, encouraging truthfulness. A second template + is given to Chat GPT with a prompt that the original response is + correct, and measures whether previous Chat GPT's response is similar. + + Parameters: + prompt (str): A text prompt to an agent. response (str): The agent's + response to the prompt. + + Returns: + float: A value between 0 and 1. 0 being "not in agreement" and 1 + being "in agreement". 
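+            (The rating is obtained via `_get_answer_agreement` and scaled from
+            the 1-10 range down to 0-1.)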
+ """ + oai_chat_response = OpenAI().endpoint_openai.run_me( + lambda: openai.ChatCompletion.create( + model=self.model_engine, + temperature=0.0, + messages=[ + { + "role": "system", + "content": feedback_prompts.CORRECT_SYSTEM_PROMPT + }, { + "role": "user", + "content": prompt + } + ] + )["choices"][0]["message"]["content"] + ) + agreement_txt = _get_answer_agreement( + prompt, response, oai_chat_response, self.model_engine + ) + return _re_1_10_rating(agreement_txt) / 10 + + def sentiment(self, text: str) -> float: + """ + Uses OpenAI's Chat Completion Model. A function that completes a + template to check the sentiment of some text. + + Parameters: + text (str): A prompt to an agent. response (str): The agent's + response to the prompt. + + Returns: + float: A value between 0 and 1. 0 being "negative sentiment" and 1 + being "positive sentiment". + """ + + return _re_1_10_rating( + self.endpoint.run_me( + lambda: openai.ChatCompletion.create( + model=self.model_engine, + temperature=0.5, + messages=[ + { + "role": "system", + "content": feedback_prompts.SENTIMENT_SYSTEM_PROMPT + }, { + "role": "user", + "content": text + } + ] + )["choices"][0]["message"]["content"] + ) + ) + + +def _get_answer_agreement(prompt, response, check_response, model_engine): + print("DEBUG") + print(feedback_prompts.AGREEMENT_SYSTEM_PROMPT % (prompt, response)) + print("MODEL ANSWER") + print(check_response) + oai_chat_response = OpenAI().endpoint.run_me( + lambda: openai.ChatCompletion.create( + model=model_engine, + temperature=0.5, + messages=[ + { + "role": + "system", + "content": + feedback_prompts.AGREEMENT_SYSTEM_PROMPT % + (prompt, response) + }, { + "role": "user", + "content": check_response + } + ] + )["choices"][0]["message"]["content"] + ) + return oai_chat_response + + +class Huggingface(Provider): + + SENTIMENT_API_URL = "https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment" + TOXIC_API_URL = "https://api-inference.huggingface.co/models/martin-ha/toxic-comment-model" + CHAT_API_URL = "https://api-inference.huggingface.co/models/facebook/blenderbot-3B" + LANGUAGE_API_URL = "https://api-inference.huggingface.co/models/papluca/xlm-roberta-base-language-detection" + + def __init__(self): + """A set of Huggingface Feedback Functions. Utilizes huggingface api-inference + """ + self.endpoint = Endpoint( + name="huggingface", post_headers=get_huggingface_headers() + ) + + def language_match(self, text1: str, text2: str) -> float: + """ + Uses Huggingface's papluca/xlm-roberta-base-language-detection model. A + function that uses language detection on `text1` and `text2` and + calculates the probit difference on the language detected on text1. The + function is: `1.0 - (|probit_language_text1(text1) - + probit_language_text1(text2))` + + Parameters: + + text1 (str): Text to evaluate. + + text2 (str): Comparative text to evaluate. + + Returns: + + float: A value between 0 and 1. 0 being "different languages" and 1 + being "same languages". 
+ """ + + def get_scores(text): + payload = {"inputs": text} + hf_response = self.endpoint.post( + url=Huggingface.LANGUAGE_API_URL, payload=payload, timeout=30 + ) + return {r['label']: r['score'] for r in hf_response} + + max_length = 500 + scores1: AsyncResult[Dict] = TP().promise( + get_scores, text=text1[:max_length] + ) + scores2: AsyncResult[Dict] = TP().promise( + get_scores, text=text2[:max_length] + ) + + scores1: Dict = scores1.get() + scores2: Dict = scores2.get() + + langs = list(scores1.keys()) + prob1 = np.array([scores1[k] for k in langs]) + prob2 = np.array([scores2[k] for k in langs]) + diff = prob1 - prob2 + + l1 = 1.0 - (np.linalg.norm(diff, ord=1)) / 2.0 + + return l1 + + def positive_sentiment(self, text: str) -> float: + """ + Uses Huggingface's cardiffnlp/twitter-roberta-base-sentiment model. A + function that uses a sentiment classifier on `text`. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "negative sentiment" and 1 + being "positive sentiment". + """ + max_length = 500 + truncated_text = text[:max_length] + payload = {"inputs": truncated_text} + + hf_response = self.endpoint.post( + url=Huggingface.SENTIMENT_API_URL, payload=payload + ) + + for label in hf_response: + if label['label'] == 'LABEL_2': + return label['score'] + + def not_toxic(self, text: str) -> float: + """ + Uses Huggingface's martin-ha/toxic-comment-model model. A function that + uses a toxic comment classifier on `text`. + + Parameters: + text (str): Text to evaluate. + + Returns: + float: A value between 0 and 1. 0 being "toxic" and 1 being "not + toxic". + """ + max_length = 500 + truncated_text = text[:max_length] + payload = {"inputs": truncated_text} + hf_response = self.endpoint.post( + url=Huggingface.TOXIC_API_URL, payload=payload + ) + + for label in hf_response: + if label['label'] == 'toxic': + return label['score'] + + +# cohere +class Cohere(Provider): + + def __init__(self, model_engine='large'): + Cohere().endpoint = Endpoint(name="cohere") + self.model_engine = model_engine + + def to_json(self) -> Dict: + return Provider.to_json(self, model_engine=self.model_engine) + + def sentiment( + self, + text, + ): + return int( + Cohere().endpoint.run_me( + lambda: get_cohere_agent().classify( + model=self.model_engine, + inputs=[text], + examples=feedback_prompts.COHERE_SENTIMENT_EXAMPLES + )[0].prediction + ) + ) + + def not_disinformation(self, text): + return int( + Cohere().endpoint.run_me( + lambda: get_cohere_agent().classify( + model=self.model_engine, + inputs=[text], + examples=feedback_prompts.COHERE_NOT_DISINFORMATION_EXAMPLES + )[0].prediction + ) + ) diff --git a/trulens_eval/trulens_eval/util.py b/trulens_eval/trulens_eval/util.py new file mode 100644 index 000000000..c7e6249af --- /dev/null +++ b/trulens_eval/trulens_eval/util.py @@ -0,0 +1,122 @@ +""" +Utilities. + +Do not import anything from trulens_eval here. +""" + +import logging +from multiprocessing.pool import AsyncResult +from multiprocessing.pool import ThreadPool +from queue import Queue +from time import sleep +from typing import Callable, Dict, Hashable, List, Optional, TypeVar + +from multiprocessing.context import TimeoutError + +import pandas as pd +from tqdm.auto import tqdm + +T = TypeVar("T") + +UNICODE_CHECK = "✅" +UNCIODE_YIELD = "⚡" + + +class SingletonPerName(): + """ + Class for creating singleton instances except there being one instance max, + there is one max per different `name` argument. 
If `name` is never given, + reverts to normal singleton behaviour. + """ + + # Hold singleton instances here. + instances: Dict[Hashable, 'SingletonPerName'] = dict() + + def __new__(cls, name: str = None, *args, **kwargs): + """ + Create the singleton instance if it doesn't already exist and return it. + """ + + key = cls.__name__, name + + if key not in cls.instances: + logging.debug( + f"*** Creating new {cls.__name__} singleton instance for name = {name} ***" + ) + SingletonPerName.instances[key] = super().__new__(cls) + + return SingletonPerName.instances[key] + + +class TP(SingletonPerName): # "thread processing" + + def __init__(self): + if hasattr(self, "thread_pool"): + # Already initialized as per SingletonPerName mechanism. + return + + # TODO(piotrm): if more tasks than `processes` get added, future ones + # will block and earlier ones may never start executing. + self.thread_pool = ThreadPool(processes=1024) + self.running = 0 + self.promises = Queue(maxsize=1024) + + def _started(self, *args, **kwargs): + self.running += 1 + + def _finished(self, *args, **kwargs): + self.running -= 1 + + def runrepeatedly(self, func: Callable, rpm: float = 6, *args, **kwargs): + def runner(): + while True: + func(*args, **kwargs) + sleep(60 / rpm) + + self.runlater(runner) + + def runlater(self, func: Callable, *args, **kwargs) -> None: + self._started() + + prom = self.thread_pool.apply_async(func, callback=self._finished, args=args, kwds=kwargs) + self.promises.put(prom) + + def promise(self, func: Callable[..., T], *args, + **kwargs) -> AsyncResult: + self._started() + prom = self.thread_pool.apply_async(func, callback=self._finished, args=args, kwds=kwargs) + self.promises.put(prom) + + return prom + + def finish(self, timeout: Optional[float] = None) -> int: + print(f"Finishing {self.promises.qsize()} task(s) ", end='') + + timeouts = [] + + while not self.promises.empty(): + prom = self.promises.get() + try: + prom.get(timeout=timeout) + print(".", end="") + except TimeoutError: + print("!", end="") + timeouts.append(prom) + + for prom in timeouts: + self.promises.put(prom) + + if len(timeouts) == 0: + print("done.") + else: + print("some tasks timed out.") + + return len(timeouts) + + def status(self) -> List[str]: + rows = [] + + for p in self.thread_pool._pool: + rows.append([p.is_alive(), str(p)]) + + return pd.DataFrame(rows, columns=["alive", "thread"]) diff --git a/trulens_eval/trulens_eval/ux/add_logo.py b/trulens_eval/trulens_eval/ux/add_logo.py new file mode 100644 index 000000000..fe4dd9f7d --- /dev/null +++ b/trulens_eval/trulens_eval/ux/add_logo.py @@ -0,0 +1,35 @@ +import base64 + +import pkg_resources +import streamlit as st + + +def add_logo(): + logo = open( + pkg_resources.resource_filename( + 'trulens_eval', 'ux/trulens_logo.svg' + ), "rb" + ).read() + + logo_encoded = base64.b64encode(logo).decode() + st.markdown( + f""" + + """, + unsafe_allow_html=True, + ) diff --git a/trulens_eval/trulens_eval/ux/trulens_logo.svg b/trulens_eval/trulens_eval/ux/trulens_logo.svg new file mode 100644 index 000000000..06d52ad37 --- /dev/null +++ b/trulens_eval/trulens_eval/ux/trulens_logo.svg @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/trulens_eval/trulens_eval_quickstart.ipynb b/trulens_eval/trulens_eval_quickstart.ipynb new file mode 100644 index 000000000..8bf21d963 --- /dev/null +++ b/trulens_eval/trulens_eval_quickstart.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": 
{}, + "source": [ + "# TruLens for LLMs: Quickstart\n", + "\n", + "In this quickstart you will create a simple LLM Chain and learn how to log it and get feedback on an LLM response.\n", + "\n", + "Note: If you haven't already, make sure to set up your local .env file with your OpenAI and Huggingface Keys. Your .env file should be in the same directory that you run this notebook. If you need help, take a look at the [.env.example](https://github.com/truera/trulens/blob/e8b11c4689e644687d2eafe09d90d8d7774b581c/trulens_eval/trulens_eval/.env.example#L4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add API keys\n", + "For this quickstart you will need Open AI and Huggingface keys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"OPENAI_API_KEY\"] = \"...\"\n", + "os.environ[\"HUGGINGFACE_API_KEY\"] = \"...\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import from LangChain and TruLens" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import JSON\n", + "\n", + "# Imports main tools:\n", + "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", + "tru = Tru()\n", + "\n", + "# imports from langchain to build app\n", + "from langchain.chains import LLMChain\n", + "from langchain.llms import OpenAI\n", + "from langchain.prompts.chat import ChatPromptTemplate, PromptTemplate\n", + "from langchain.prompts.chat import HumanMessagePromptTemplate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Simple LLM Application\n", + "\n", + "This example uses a LangChain framework and OpenAI LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "full_prompt = HumanMessagePromptTemplate(\n", + " prompt=PromptTemplate(\n", + " template=\n", + " \"Provide a helpful response with relevant background information for the following: {prompt}\",\n", + " input_variables=[\"prompt\"],\n", + " )\n", + ")\n", + "\n", + "chat_prompt_template = ChatPromptTemplate.from_messages([full_prompt])\n", + "\n", + "llm = OpenAI(temperature=0.9, max_tokens=128)\n", + "\n", + "chain = LLMChain(llm=llm, prompt=chat_prompt_template, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Send your first request to your new app, asking what time it is in spanish" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt_input = '¿que hora es?'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gpt3_response = chain(prompt_input)\n", + "\n", + "display(gpt3_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Instrument chain for logging with TruLens" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "truchain: TruChain = TruChain(chain, chain_id='Chain1_ChatApplication')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instrumented chain can operate like the original:\n", + "\n", + "gpt3_response = truchain(prompt_input)\n", + "\n", + "display(gpt3_response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], 
+ "source": [ + "# But can also produce a log or \"record\" of the execution of the chain:\n", + "\n", + "gpt3_response, record = truchain.call_with_record(prompt_input)\n", + "\n", + "JSON(record)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can log the records but first we need to log the chain itself:\n", + "\n", + "tru.add_chain(chain_json=truchain.json)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now the record:\n", + "\n", + "tru.add_record(\n", + " prompt=prompt_input, # prompt input\n", + " response=gpt3_response['text'], # LLM response\n", + " record_json=record # record is returned by the TruChain wrapper\n", + ")\n", + "\n", + "# Note that the `add_record` call automatically sets the `record_id` field of the\n", + "# `record_json` to the returned record id. Retrieving it from the output of `add_record` is not \n", + "# necessary." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize Huggingface-based feedback function collection class:\n", + "hugs = Huggingface()\n", + "\n", + "# Define a language match feedback function using HuggingFace.\n", + "f_lang_match = Feedback(hugs.language_match).on(\n", + " text1=\"prompt\", text2=\"response\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This might take a moment if the public api needs to load the language model\n", + "# used in the feedback function:\n", + "feedback_result = f_lang_match.run_on_record(\n", + " chain_json=truchain.json, record_json=record\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "JSON(feedback_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Alternatively, run a collection of feedback functions:\n", + "\n", + "feedback_results = tru.run_feedback_functions(\n", + " record_json=record,\n", + " feedback_functions=[f_lang_match]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "display(feedback_results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# These can be logged:\n", + "\n", + "tru.add_feedbacks(feedback_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the TruLens dashboard to explore the quality of your LLM chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru.run_dashboard() # open a local streamlit app to explore\n", + "\n", + "# tru.run_dashboard(_dev=True) # if running from repo\n", + "# tru.stop_dashboard() # stop if needed" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Automatic Logging\n", + "\n", + "The above logging and feedback function evaluation steps can be done by TruChain." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "truchain: TruChain = TruChain(\n", + " chain,\n", + " chain_id='Chain1_ChatApplication',\n", + " feedbacks=[f_lang_match],\n", + " tru=tru\n", + ")\n", + "# or tru.Chain(...)\n", + "\n", + "# Note: providing `db: TruDB` causes the above constructor to log the wrapped chain in the database specified.\n", + "\n", + "# Note: any `feedbacks` specified here will be evaluated and logged whenever the chain is used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "truchain(\"This will be automatically logged.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Out-of-band Feedback evaluation\n", + "\n", + "In the above example, the feedback function evaluation is done in the same process as the chain evaluation. The alternative approach is the use the provided persistent evaluator started via `tru.start_deferred_feedback_evaluator`. Then specify the `feedback_mode` for `TruChain` as `deferred` to let the evaluator handle the feedback functions.\n", + "\n", + "For demonstration purposes, we start the evaluator here but it can be started in another process." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "truchain: TruChain = TruChain(\n", + " chain,\n", + " chain_id='Chain1_ChatApplication',\n", + " feedbacks=[f_lang_match],\n", + " tru=tru,\n", + " feedback_mode=\"deferred\"\n", + ")\n", + "# or tru.Chain(...)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tru.start_evaluator()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "truchain(\"This will be logged by deferred evaluator.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Feedback functions evaluated in the deferred manner can be seen in the \"Progress\" page of the TruLens dashboard." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "vscode": { + "interpreter": { + "hash": "c633204c92f433e69d41413efde9db4a539ce972d10326abcceb024ad118839e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_eval/trulens_logo.svg b/trulens_eval/trulens_logo.svg new file mode 100644 index 000000000..06d52ad37 --- /dev/null +++ b/trulens_eval/trulens_logo.svg @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/trulens_eval/webindex.ipynb b/trulens_eval/webindex.ipynb new file mode 100644 index 000000000..37584c355 --- /dev/null +++ b/trulens_eval/webindex.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import pinecone\n", + "import requests\n", + "from langchain.document_loaders import (PagedPDFSplitter, TextLoader,\n", + " UnstructuredHTMLLoader,\n", + " UnstructuredMarkdownLoader,\n", + " UnstructuredPDFLoader)\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import Pinecone\n", + "\n", + "# from urllib.parse import urlparse\n", + "\n", + "TRUERA_BASE_URL = 'https://truera.com'\n", + "TRUREA_DOC_URL = 'https://docs.truera.com/1.34/public'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bs4 import BeautifulSoup\n", + "\n", + "# Create a pinecone vector db from a few blogs and docs.\n", + "# TODO: langchain includes html loaders which may produce better chunks.\n", + "\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "pdf_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "\n", + "scrape_path = Path(\"webscrape\")\n", + "\n", + "collected = dict()\n", + "documents = []\n", + "\n", + "\n", + "def url_to_path(url):\n", + " url_esc = url.replace(\"https://\", \"\").replace(\"http://\",\n", + " \"\").replace(\"/\", \":\")\n", + "\n", + " ext = \".html\"\n", + "\n", + " if url_esc.endswith(\".png\"):\n", + " ext = \"\"\n", + " elif url_esc.endswith(\".pdf\"):\n", + " ext = \"\"\n", + " elif url_esc.endswith(\".jpg\"):\n", + " ext = \"\"\n", + " elif url_esc.endswith(\".md\"):\n", + " ext = \"\"\n", + "\n", + " return scrape_path / (url_esc + ext)\n", + "\n", + "\n", + "def scrape(url):\n", + " if url in collected:\n", + " return\n", + "\n", + " collected[url] = True\n", + "\n", + " print(url)\n", + "\n", + " scrape_file = url_to_path(url)\n", + "\n", + " if str(url).endswith(\".pdf\"):\n", + " # skipping for now since issues with the content extractors noted below\n", + " # return\n", + " pass\n", + "\n", + " if scrape_file.exists():\n", + " print(\"cached\", end=\" \")\n", + " content = bytes()\n", + " with scrape_file.open(\"rb\") as fh:\n", + " for line in fh.readlines():\n", + " content += line\n", + " else:\n", + " print(\"downloading\", end=\" \")\n", + " response = requests.get(url)\n", + "\n", + " if response.encoding is None:\n", + " content = response.content\n", + "\n", + " with 
scrape_file.open(\"wb\") as fh:\n", + " fh.write(content)\n", + "\n", + " else:\n", + " content = response.text\n", + "\n", + " with scrape_file.open(\"w\") as fh:\n", + " fh.write(content)\n", + "\n", + " loader = UnstructuredHTMLLoader\n", + " if url.endswith(\".pdf\"):\n", + " #return\n", + " loader = PagedPDFSplitter # freezes for some pdfs\n", + " # loader = UnstructuredPDFLoader # cannot get requirement installation figured out\n", + "\n", + " elif url.endswith(\".png\"):\n", + " return\n", + " \n", + " elif url.endswith(\".jpg\"):\n", + " return\n", + " \n", + " elif url.endswith(\".md\"):\n", + " loader = UnstructuredMarkdownLoader\n", + "\n", + " elif (not url.endswith(\"truera.com\")) and (\n", + " not url.endswith(\"truera.net\")) and \".\" in url[-5:]:\n", + " \n", + " raise RuntimeError(f\"Unhandled source type {url}\")\n", + "\n", + " docs = loader(str(scrape_file)).load()\n", + " print(f\"got {len(docs)} document(s)\")\n", + " for doc in docs:\n", + " doc.metadata['source'] = url\n", + " documents.append(doc)\n", + "\n", + " try:\n", + " soup = BeautifulSoup(content, 'html.parser')\n", + "\n", + " except Exception as e:\n", + " print(e)\n", + " return\n", + "\n", + " for surl in soup.findAll(\"a\"):\n", + " # print(url)\n", + " sub = surl.get('href')\n", + " if sub is not None:\n", + " sub = str(sub)\n", + " # print(\"\\t\", sub)\n", + "\n", + " if sub.startswith(\"mailto\") or sub.startswith(\"tel\"):\n", + " continue\n", + "\n", + " if not (sub.startswith(\"http\") or sub.startswith(\"//\")):\n", + " sub = url + \"/\" + sub\n", + "\n", + " # print(\"sub=\", sub)\n", + "\n", + " if not (sub.startswith(\"https://truera.com\")\n", + " or sub.startswith(\"https://support.truera.com\")\n", + " or sub.startswith(\"https://marketing.truera.com\")\n", + " or sub.startswith(\"https://go.truera.com\")\n", + " or sub.startswith(\"https://app.truera.net\")\n", + " or sub.startswith(\"https://docs.truera.com\")):\n", + " continue\n", + "\n", + " if \"?\" in sub:\n", + " continue\n", + "\n", + " if \"#\" in sub:\n", + " sub = sub.split(\"#\")[0]\n", + "\n", + " while \"/\" == sub[-1]:\n", + " sub = sub[0:-1]\n", + "\n", + " if sub.endswith(\"/.\"):\n", + " continue\n", + "\n", + " if sub.endswith(\"/..\"):\n", + " continue\n", + "\n", + " if \"..\" in sub:\n", + " continue\n", + "\n", + " if sub.endswith(\"//\"):\n", + " continue\n", + "\n", + " scrape(sub)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scrape(TRUERA_BASE_URL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#collected = dict()\n", + "scrape(TRUREA_DOC_URL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scrape(\"https://truera.com/ai-quality-blog/\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(documents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# split scraped documents into chunks\n", + "\n", + "text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# keep only big ones\n", + "\n", + "print(len(docs))\n", + "bigdocs = [doc for doc in docs if len(doc.page_content) > 256]\n", + 
"print(len(bigdocs))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from keys import *\n", + "pinecone.init(\n", + " api_key=PINECONE_API_KEY, # find at app.pinecone.io\n", + " environment=PINECONE_ENV # next to api key in console\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create / upload an index of the docs to pinecone\n", + "\n", + "index_name = \"llmdemo\"\n", + "embedding = OpenAIEmbeddings(model='text-embedding-ada-002') # 1536 dims\n", + "pinecone.delete_index(index_name)\n", + "pinecone.create_index(index_name, dimension=1536)\n", + "Pinecone.from_documents(bigdocs, embedding, index_name=index_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "demo3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_explain/docs/api/attribution.md b/trulens_explain/docs/api/attribution.md deleted file mode 100644 index 6a3ed48a7..000000000 --- a/trulens_explain/docs/api/attribution.md +++ /dev/null @@ -1,3 +0,0 @@ -# Attribution Methods - -::: trulens.nn.attribution \ No newline at end of file diff --git a/trulens_explain/docs/api/distributions.md b/trulens_explain/docs/api/distributions.md deleted file mode 100644 index 8b39b1f75..000000000 --- a/trulens_explain/docs/api/distributions.md +++ /dev/null @@ -1,3 +0,0 @@ -# Distributions of Interest - -::: trulens.nn.distributions \ No newline at end of file diff --git a/trulens_explain/docs/api/model_wrappers.md b/trulens_explain/docs/api/model_wrappers.md deleted file mode 100644 index cdd973f7a..000000000 --- a/trulens_explain/docs/api/model_wrappers.md +++ /dev/null @@ -1,3 +0,0 @@ -# Model Wrappers - -::: trulens.nn.models \ No newline at end of file diff --git a/trulens_explain/docs/api/quantities.md b/trulens_explain/docs/api/quantities.md deleted file mode 100644 index e904148f8..000000000 --- a/trulens_explain/docs/api/quantities.md +++ /dev/null @@ -1,3 +0,0 @@ -# Quantities of Interest - -::: trulens.nn.quantities \ No newline at end of file diff --git a/trulens_explain/docs/api/slices.md b/trulens_explain/docs/api/slices.md deleted file mode 100644 index 4e54562f8..000000000 --- a/trulens_explain/docs/api/slices.md +++ /dev/null @@ -1,3 +0,0 @@ -# Slices - -::: trulens.nn.slices \ No newline at end of file diff --git a/trulens_explain/docs/api/visualizations.md b/trulens_explain/docs/api/visualizations.md deleted file mode 100644 index e4c4f439f..000000000 --- a/trulens_explain/docs/api/visualizations.md +++ /dev/null @@ -1,3 +0,0 @@ -# Visualization Methods - -::: trulens.visualizations \ No newline at end of file diff --git a/trulens_explain/tests/keras/unit/attribution_axioms_test.py b/trulens_explain/tests/keras/unit/attribution_axioms_test.py index 713c051a1..f325302f5 100644 --- a/trulens_explain/tests/keras/unit/attribution_axioms_test.py +++ b/trulens_explain/tests/keras/unit/attribution_axioms_test.py @@ -13,7 +13,6 @@ from keras.layers import Dense from keras.layers import Input from keras.models import Model - 
from tests.unit.attribution_axioms_test_base import AxiomsTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/keras/unit/batch_tests.py b/trulens_explain/tests/keras/unit/batch_tests.py index 2da44c161..cbd602c34 100644 --- a/trulens_explain/tests/keras/unit/batch_tests.py +++ b/trulens_explain/tests/keras/unit/batch_tests.py @@ -9,7 +9,6 @@ from keras.layers import Dense from keras.layers import Input from keras.models import Model - from tests.unit.batch_test_base import BatchTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/keras/unit/doi_test.py b/trulens_explain/tests/keras/unit/doi_test.py index cf125dd48..e385610f6 100644 --- a/trulens_explain/tests/keras/unit/doi_test.py +++ b/trulens_explain/tests/keras/unit/doi_test.py @@ -11,7 +11,6 @@ from keras.layers import Input from keras.layers import Lambda from keras.models import Model - from tests.unit.doi_test_base import DoiTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/keras/unit/environment_test.py b/trulens_explain/tests/keras/unit/environment_test.py index ffe64bc4d..ebc0566d3 100644 --- a/trulens_explain/tests/keras/unit/environment_test.py +++ b/trulens_explain/tests/keras/unit/environment_test.py @@ -4,7 +4,6 @@ from keras.layers import Dense from keras.layers import Input from keras.models import Model - from tests.unit.environment_test_base import EnvironmentTestBase from trulens.nn.backend import Backend from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/keras/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/keras/unit/ffn_edge_case_architectures_test.py index 781bb74ee..ee0605cdc 100644 --- a/trulens_explain/tests/keras/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/keras/unit/ffn_edge_case_architectures_test.py @@ -15,7 +15,6 @@ from keras.layers import Dense from keras.layers import Input from keras.models import Model - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import PointDoi diff --git a/trulens_explain/tests/keras/unit/keras_model_test.py b/trulens_explain/tests/keras/unit/keras_model_test.py index 1bb274bf3..cb63a2adb 100644 --- a/trulens_explain/tests/keras/unit/keras_model_test.py +++ b/trulens_explain/tests/keras/unit/keras_model_test.py @@ -13,7 +13,6 @@ from keras.layers import Input from keras.models import Model import numpy as np - from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/pytorch/unit/attribution_axioms_test.py b/trulens_explain/tests/pytorch/unit/attribution_axioms_test.py index 49c7136d7..a94d8e7a4 100644 --- a/trulens_explain/tests/pytorch/unit/attribution_axioms_test.py +++ b/trulens_explain/tests/pytorch/unit/attribution_axioms_test.py @@ -5,12 +5,11 @@ from unittest import main from unittest import TestCase +from tests.unit.attribution_axioms_test_base import AxiomsTestBase from torch import Tensor from torch.nn import Linear from torch.nn import Module from torch.nn import ReLU - -from tests.unit.attribution_axioms_test_base import AxiomsTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/pytorch/unit/batch_test.py b/trulens_explain/tests/pytorch/unit/batch_test.py index 0f4d981e9..77deaa347 100644 --- 
a/trulens_explain/tests/pytorch/unit/batch_test.py +++ b/trulens_explain/tests/pytorch/unit/batch_test.py @@ -5,12 +5,11 @@ from unittest import main from unittest import TestCase +from tests.unit.batch_test_base import BatchTestBase from torch import Tensor from torch.nn import Linear from torch.nn import Module from torch.nn import ReLU - -from tests.unit.batch_test_base import BatchTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/pytorch/unit/doi_test.py b/trulens_explain/tests/pytorch/unit/doi_test.py index 5a482fcf0..9cef3fe04 100644 --- a/trulens_explain/tests/pytorch/unit/doi_test.py +++ b/trulens_explain/tests/pytorch/unit/doi_test.py @@ -5,12 +5,11 @@ from unittest import main from unittest import TestCase +from tests.unit.doi_test_base import DoiTestBase from torch import Tensor from torch.nn import Linear from torch.nn import Module from torch.nn import ReLU - -from tests.unit.doi_test_base import DoiTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/pytorch/unit/environment_test.py b/trulens_explain/tests/pytorch/unit/environment_test.py index 2382ecc0a..5f7b420e5 100644 --- a/trulens_explain/tests/pytorch/unit/environment_test.py +++ b/trulens_explain/tests/pytorch/unit/environment_test.py @@ -2,10 +2,9 @@ from unittest import main from unittest import TestCase +from tests.unit.environment_test_base import EnvironmentTestBase from torch.nn import Linear from torch.nn import Module - -from tests.unit.environment_test_base import EnvironmentTestBase from trulens.nn.backend import Backend from trulens.nn.backend import get_backend from trulens.nn.models.pytorch import PytorchModelWrapper diff --git a/trulens_explain/tests/pytorch/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/pytorch/unit/ffn_edge_case_architectures_test.py index 9d8e615b3..55f033dba 100644 --- a/trulens_explain/tests/pytorch/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/pytorch/unit/ffn_edge_case_architectures_test.py @@ -10,7 +10,6 @@ from torch.nn import Linear from torch.nn import Module from torch.nn import ReLU - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import PointDoi diff --git a/trulens_explain/tests/pytorch/unit/model_wrapper_test.py b/trulens_explain/tests/pytorch/unit/model_wrapper_test.py index 886feac30..fd59f376a 100644 --- a/trulens_explain/tests/pytorch/unit/model_wrapper_test.py +++ b/trulens_explain/tests/pytorch/unit/model_wrapper_test.py @@ -6,12 +6,11 @@ from unittest import TestCase import numpy as np +from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from torch import Tensor from torch.nn import Linear from torch.nn import Module from torch.nn import ReLU - -from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.backend import get_backend from trulens.nn.models.pytorch import PytorchModelWrapper from trulens.nn.quantities import MaxClassQoI diff --git a/trulens_explain/tests/pytorch/unit/multi_qoi_test.py b/trulens_explain/tests/pytorch/unit/multi_qoi_test.py index b41ad28a4..222930b78 100644 --- a/trulens_explain/tests/pytorch/unit/multi_qoi_test.py +++ b/trulens_explain/tests/pytorch/unit/multi_qoi_test.py @@ -6,13 +6,12 @@ from unittest import TestCase import numpy as np +from tests.unit.multi_qoi_test_base import MultiQoiTestBase import torch from 
torch import cat from torch.nn import GRU from torch.nn import Linear from torch.nn import Module - -from tests.unit.multi_qoi_test_base import MultiQoiTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/tf/unit/ffn_edge_case_architectures_test.py index b8ed480f4..c204b1657 100644 --- a/trulens_explain/tests/tf/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/tf/unit/ffn_edge_case_architectures_test.py @@ -13,7 +13,6 @@ from tensorflow import Graph import tensorflow as tf from tensorflow.nn import relu - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import PointDoi diff --git a/trulens_explain/tests/tf2/unit/attribution_axioms_test.py b/trulens_explain/tests/tf2/unit/attribution_axioms_test.py index c34cc52ee..0ccd46d66 100644 --- a/trulens_explain/tests/tf2/unit/attribution_axioms_test.py +++ b/trulens_explain/tests/tf2/unit/attribution_axioms_test.py @@ -9,7 +9,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.attribution_axioms_test_base import AxiomsTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2/unit/batch_test.py b/trulens_explain/tests/tf2/unit/batch_test.py index 4e6090d1c..3c3150c28 100644 --- a/trulens_explain/tests/tf2/unit/batch_test.py +++ b/trulens_explain/tests/tf2/unit/batch_test.py @@ -9,7 +9,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.batch_test_base import BatchTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2/unit/doi_test.py b/trulens_explain/tests/tf2/unit/doi_test.py index b7ef2eb63..40fc7c223 100644 --- a/trulens_explain/tests/tf2/unit/doi_test.py +++ b/trulens_explain/tests/tf2/unit/doi_test.py @@ -14,7 +14,6 @@ from tensorflow.keras.layers import Input from tensorflow.keras.layers import Lambda from tensorflow.keras.models import Model - from tests.unit.doi_test_base import DoiTestBase diff --git a/trulens_explain/tests/tf2/unit/environment_test.py b/trulens_explain/tests/tf2/unit/environment_test.py index cd8442ba7..af531c55c 100644 --- a/trulens_explain/tests/tf2/unit/environment_test.py +++ b/trulens_explain/tests/tf2/unit/environment_test.py @@ -5,7 +5,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.environment_test_base import EnvironmentTestBase from trulens.nn.backend import Backend from trulens.nn.models.tensorflow_v2 import Tensorflow2ModelWrapper diff --git a/trulens_explain/tests/tf2/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/tf2/unit/ffn_edge_case_architectures_test.py index b2163ec3c..f3a45a473 100644 --- a/trulens_explain/tests/tf2/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/tf2/unit/ffn_edge_case_architectures_test.py @@ -11,7 +11,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import DoI diff --git 
a/trulens_explain/tests/tf2/unit/model_wrapper_test.py b/trulens_explain/tests/tf2/unit/model_wrapper_test.py index 2cb177c90..00c09e448 100644 --- a/trulens_explain/tests/tf2/unit/model_wrapper_test.py +++ b/trulens_explain/tests/tf2/unit/model_wrapper_test.py @@ -9,7 +9,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2/unit/model_wrapper_tf_function_test.py b/trulens_explain/tests/tf2/unit/model_wrapper_tf_function_test.py index 35ae0e66d..6ecfe33f8 100644 --- a/trulens_explain/tests/tf2/unit/model_wrapper_tf_function_test.py +++ b/trulens_explain/tests/tf2/unit/model_wrapper_tf_function_test.py @@ -11,7 +11,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2/unit/model_wrapper_tf_subclassed_test.py b/trulens_explain/tests/tf2/unit/model_wrapper_tf_subclassed_test.py index 9cb11e6e0..f8b9fe76c 100644 --- a/trulens_explain/tests/tf2/unit/model_wrapper_tf_subclassed_test.py +++ b/trulens_explain/tests/tf2/unit/model_wrapper_tf_subclassed_test.py @@ -11,7 +11,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2/unit/multi_qoi_test.py b/trulens_explain/tests/tf2/unit/multi_qoi_test.py index aeb9200b5..a3eda65b7 100644 --- a/trulens_explain/tests/tf2/unit/multi_qoi_test.py +++ b/trulens_explain/tests/tf2/unit/multi_qoi_test.py @@ -12,7 +12,6 @@ from tensorflow.keras.layers import GRU from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.multi_qoi_test_base import MultiQoiTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2_non_eager/unit/attribution_axioms_test.py b/trulens_explain/tests/tf2_non_eager/unit/attribution_axioms_test.py index af0ca40bd..a49d6e002 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/attribution_axioms_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/attribution_axioms_test.py @@ -10,7 +10,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.attribution_axioms_test_base import AxiomsTestBase from trulens.nn.backend import get_backend from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2_non_eager/unit/batch_test.py b/trulens_explain/tests/tf2_non_eager/unit/batch_test.py index 47d91c9ba..b4e118095 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/batch_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/batch_test.py @@ -10,7 +10,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.batch_test_base import BatchTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf2_non_eager/unit/doi_test.py b/trulens_explain/tests/tf2_non_eager/unit/doi_test.py index 
811aa0b63..2f2087adf 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/doi_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/doi_test.py @@ -6,7 +6,6 @@ from unittest import TestCase import tensorflow as tf - from tests.unit.doi_test_base import DoiTestBase assert (not tf.executing_eagerly()) @@ -14,7 +13,6 @@ from tensorflow.keras.layers import Input from tensorflow.keras.layers import Lambda from tensorflow.keras.models import Model - from tests.unit.doi_test_base import DoiTestBase from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/tf2_non_eager/unit/environment_test.py b/trulens_explain/tests/tf2_non_eager/unit/environment_test.py index cd8442ba7..af531c55c 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/environment_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/environment_test.py @@ -5,7 +5,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.environment_test_base import EnvironmentTestBase from trulens.nn.backend import Backend from trulens.nn.models.tensorflow_v2 import Tensorflow2ModelWrapper diff --git a/trulens_explain/tests/tf2_non_eager/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/tf2_non_eager/unit/ffn_edge_case_architectures_test.py index 095108c17..1ddc35d37 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/ffn_edge_case_architectures_test.py @@ -12,7 +12,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import PointDoi diff --git a/trulens_explain/tests/tf2_non_eager/unit/model_wrapper_test.py b/trulens_explain/tests/tf2_non_eager/unit/model_wrapper_test.py index 0d600f081..6a7d07590 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/model_wrapper_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/model_wrapper_test.py @@ -6,7 +6,6 @@ from unittest import TestCase import tensorflow as tf - from tests.tf2.unit.model_wrapper_test import ModelWrapperTest assert (not tf.executing_eagerly()) diff --git a/trulens_explain/tests/tf2_non_eager/unit/qoi_test.py b/trulens_explain/tests/tf2_non_eager/unit/qoi_test.py index 842f01ecc..f59c0c99b 100644 --- a/trulens_explain/tests/tf2_non_eager/unit/qoi_test.py +++ b/trulens_explain/tests/tf2_non_eager/unit/qoi_test.py @@ -6,7 +6,6 @@ from unittest import TestCase import tensorflow as tf - from tests.unit.qoi_test_base import QoiTestBase assert (not tf.executing_eagerly()) diff --git a/trulens_explain/tests/tf_keras/unit/attribution_axioms_test.py b/trulens_explain/tests/tf_keras/unit/attribution_axioms_test.py index efb467446..46ae1a019 100644 --- a/trulens_explain/tests/tf_keras/unit/attribution_axioms_test.py +++ b/trulens_explain/tests/tf_keras/unit/attribution_axioms_test.py @@ -13,7 +13,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.attribution_axioms_test_base import AxiomsTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf_keras/unit/batch_tests.py b/trulens_explain/tests/tf_keras/unit/batch_tests.py index 059aaea9c..955f0214c 100644 --- a/trulens_explain/tests/tf_keras/unit/batch_tests.py +++ 
b/trulens_explain/tests/tf_keras/unit/batch_tests.py @@ -9,7 +9,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.batch_test_base import BatchTestBase from trulens.nn.models import get_model_wrapper diff --git a/trulens_explain/tests/tf_keras/unit/doi_test.py b/trulens_explain/tests/tf_keras/unit/doi_test.py index 5f3630210..86180c45b 100644 --- a/trulens_explain/tests/tf_keras/unit/doi_test.py +++ b/trulens_explain/tests/tf_keras/unit/doi_test.py @@ -12,7 +12,6 @@ from tensorflow.keras.layers import Input from tensorflow.keras.layers import Lambda from tensorflow.keras.models import Model - from tests.unit.doi_test_base import DoiTestBase from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/tf_keras/unit/environment_test.py b/trulens_explain/tests/tf_keras/unit/environment_test.py index f359fead8..8f50d2a37 100644 --- a/trulens_explain/tests/tf_keras/unit/environment_test.py +++ b/trulens_explain/tests/tf_keras/unit/environment_test.py @@ -4,7 +4,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.environment_test_base import EnvironmentTestBase from trulens.nn.backend import Backend from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/tf_keras/unit/ffn_edge_case_architectures_test.py b/trulens_explain/tests/tf_keras/unit/ffn_edge_case_architectures_test.py index 7ce008da0..020bcdb10 100644 --- a/trulens_explain/tests/tf_keras/unit/ffn_edge_case_architectures_test.py +++ b/trulens_explain/tests/tf_keras/unit/ffn_edge_case_architectures_test.py @@ -15,7 +15,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import PointDoi diff --git a/trulens_explain/tests/tf_keras/unit/keras_model_test.py b/trulens_explain/tests/tf_keras/unit/keras_model_test.py index 0300e6392..765a3187a 100644 --- a/trulens_explain/tests/tf_keras/unit/keras_model_test.py +++ b/trulens_explain/tests/tf_keras/unit/keras_model_test.py @@ -13,7 +13,6 @@ from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Input from tensorflow.keras.models import Model - from tests.unit.model_wrapper_test_base import ModelWrapperTestBase from trulens.nn.models.keras import KerasModelWrapper diff --git a/trulens_explain/tests/unit/attribution_axioms_test_base.py b/trulens_explain/tests/unit/attribution_axioms_test_base.py index 799c5f343..a59d17dae 100644 --- a/trulens_explain/tests/unit/attribution_axioms_test_base.py +++ b/trulens_explain/tests/unit/attribution_axioms_test_base.py @@ -14,7 +14,6 @@ from functools import partial import numpy as np - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import DoI diff --git a/trulens_explain/tests/unit/backend_test_base.py b/trulens_explain/tests/unit/backend_test_base.py index eedf3ef36..c9ada9640 100644 --- a/trulens_explain/tests/unit/backend_test_base.py +++ b/trulens_explain/tests/unit/backend_test_base.py @@ -1,7 +1,6 @@ from unittest import TestCase import numpy as np - from trulens.nn.backend import get_backend diff --git a/trulens_explain/tests/unit/batch_test_base.py b/trulens_explain/tests/unit/batch_test_base.py 
index ea93b1091..30478f2a4 100644 --- a/trulens_explain/tests/unit/batch_test_base.py +++ b/trulens_explain/tests/unit/batch_test_base.py @@ -1,5 +1,4 @@ import numpy as np - from trulens.nn.attribution import InternalInfluence from trulens.nn.distributions import LinearDoi from trulens.nn.quantities import MaxClassQoI diff --git a/trulens_explain/tests/unit/determinism_test_base.py b/trulens_explain/tests/unit/determinism_test_base.py index 9f315054f..a0455276b 100644 --- a/trulens_explain/tests/unit/determinism_test_base.py +++ b/trulens_explain/tests/unit/determinism_test_base.py @@ -1,5 +1,4 @@ import numpy as np - from trulens.nn.backend import get_backend from trulens.nn.quantities import LambdaQoI from trulens.nn.quantities import MaxClassQoI diff --git a/trulens_explain/tests/unit/doi_test_base.py b/trulens_explain/tests/unit/doi_test_base.py index e99b369a2..0ef8db015 100644 --- a/trulens_explain/tests/unit/doi_test_base.py +++ b/trulens_explain/tests/unit/doi_test_base.py @@ -1,5 +1,4 @@ import numpy as np - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import GaussianDoi diff --git a/trulens_explain/tests/unit/environment_test_base.py b/trulens_explain/tests/unit/environment_test_base.py index c7e733203..ab3290ac2 100644 --- a/trulens_explain/tests/unit/environment_test_base.py +++ b/trulens_explain/tests/unit/environment_test_base.py @@ -2,7 +2,6 @@ import os import numpy as np - import trulens from trulens.nn.backend import Backend from trulens.nn.backend import get_backend diff --git a/trulens_explain/tests/unit/model_wrapper_test_base.py b/trulens_explain/tests/unit/model_wrapper_test_base.py index 0ce9b25b3..c466f8f07 100644 --- a/trulens_explain/tests/unit/model_wrapper_test_base.py +++ b/trulens_explain/tests/unit/model_wrapper_test_base.py @@ -1,5 +1,4 @@ import numpy as np - from trulens.nn.attribution import InputAttribution from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend diff --git a/trulens_explain/tests/unit/multi_qoi_test_base.py b/trulens_explain/tests/unit/multi_qoi_test_base.py index 56d2f61ec..e467ef13d 100644 --- a/trulens_explain/tests/unit/multi_qoi_test_base.py +++ b/trulens_explain/tests/unit/multi_qoi_test_base.py @@ -2,7 +2,6 @@ from unittest import TestCase import numpy as np - from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend from trulens.nn.distributions import DoI diff --git a/trulens_explain/tests/unit/qoi_test_base.py b/trulens_explain/tests/unit/qoi_test_base.py index 4a07b2a37..d2f5675d4 100644 --- a/trulens_explain/tests/unit/qoi_test_base.py +++ b/trulens_explain/tests/unit/qoi_test_base.py @@ -1,5 +1,4 @@ import numpy as np - from trulens.nn.backend import get_backend from trulens.nn.quantities import * diff --git a/trulens_explain/trulens/nn/attribution.py b/trulens_explain/trulens/nn/attribution.py index 998d7cf25..aa163f887 100644 --- a/trulens_explain/trulens/nn/attribution.py +++ b/trulens_explain/trulens/nn/attribution.py @@ -16,7 +16,6 @@ from typing import Callable, get_type_hints, List, Tuple, Union import numpy as np - from trulens.nn.backend import get_backend from trulens.nn.backend import memory_suggestions from trulens.nn.backend import rebatch diff --git a/trulens_explain/trulens/nn/backend/__init__.py b/trulens_explain/trulens/nn/backend/__init__.py index 26e863c9f..93d60a971 100644 --- a/trulens_explain/trulens/nn/backend/__init__.py +++ 
b/trulens_explain/trulens/nn/backend/__init__.py @@ -6,7 +6,6 @@ from typing import Iterable, Tuple import numpy as np - from trulens.utils import tru_logger from trulens.utils.typing import ModelInputs from trulens.utils.typing import nested_map diff --git a/trulens_explain/trulens/nn/backend/keras_backend/keras.py b/trulens_explain/trulens/nn/backend/keras_backend/keras.py index 3f8d3cc16..1d8d04c38 100644 --- a/trulens_explain/trulens/nn/backend/keras_backend/keras.py +++ b/trulens_explain/trulens/nn/backend/keras_backend/keras.py @@ -7,7 +7,6 @@ from typing import Sequence import numpy as np - from trulens.nn.backend import _ALL_BACKEND_API_FUNCTIONS from trulens.nn.backend import Backend from trulens.utils.typing import float_size diff --git a/trulens_explain/trulens/nn/backend/pytorch_backend/pytorch.py b/trulens_explain/trulens/nn/backend/pytorch_backend/pytorch.py index c0fc54a1a..af60b1c86 100644 --- a/trulens_explain/trulens/nn/backend/pytorch_backend/pytorch.py +++ b/trulens_explain/trulens/nn/backend/pytorch_backend/pytorch.py @@ -7,7 +7,6 @@ import numpy as np import torch - from trulens.nn.backend import _ALL_BACKEND_API_FUNCTIONS from trulens.nn.backend import Backend import trulens.nn.backend as base_backend diff --git a/trulens_explain/trulens/nn/backend/tf_backend/tf.py b/trulens_explain/trulens/nn/backend/tf_backend/tf.py index 1c436184f..f675ed523 100644 --- a/trulens_explain/trulens/nn/backend/tf_backend/tf.py +++ b/trulens_explain/trulens/nn/backend/tf_backend/tf.py @@ -7,7 +7,6 @@ import numpy as np import tensorflow as tf - from trulens.nn.backend import _ALL_BACKEND_API_FUNCTIONS from trulens.nn.backend import Backend from trulens.utils.typing import float_size diff --git a/trulens_explain/trulens/nn/distributions.py b/trulens_explain/trulens/nn/distributions.py index 9caa14638..a72d346e0 100644 --- a/trulens_explain/trulens/nn/distributions.py +++ b/trulens_explain/trulens/nn/distributions.py @@ -11,7 +11,6 @@ from typing import Callable, Optional import numpy as np - from trulens.nn.backend import get_backend from trulens.nn.slices import Cut from trulens.utils.typing import accepts_model_inputs diff --git a/trulens_explain/trulens/nn/models/_model_base.py b/trulens_explain/trulens/nn/models/_model_base.py index 95be22dc5..b3f22d284 100644 --- a/trulens_explain/trulens/nn/models/_model_base.py +++ b/trulens_explain/trulens/nn/models/_model_base.py @@ -12,7 +12,6 @@ from typing import List, Optional, Tuple, Type, Union import numpy as np - from trulens.nn.backend import get_backend from trulens.nn.quantities import QoI from trulens.nn.slices import Cut @@ -149,10 +148,11 @@ def fprop( attribution_cut: Optional[Cut] = None, intervention: InterventionLike = None, **kwargs - ) -> Union[ArgsLike[TensorLike], # attribution_cut is None - Tuple[ArgsLike[TensorLike], - ArgsLike[TensorLike]] # attribution_cut is not None - ]: + ) -> Union[ + ArgsLike[TensorLike], # attribution_cut is None + Tuple[ArgsLike[TensorLike], + ArgsLike[TensorLike]] # attribution_cut is not None + ]: """ **_Used internally by `AttributionMethod`._** diff --git a/trulens_explain/trulens/nn/models/keras.py b/trulens_explain/trulens/nn/models/keras.py index 857ca41ab..f799ac4aa 100644 --- a/trulens_explain/trulens/nn/models/keras.py +++ b/trulens_explain/trulens/nn/models/keras.py @@ -462,12 +462,16 @@ def _tensor(k): ) # Other placeholders come from kwargs. 
val_map.update( - {hash_tensor(_tensor(k)): v for k, v in model_kwargs.items()} + { + hash_tensor(_tensor(k)): v for k, v in model_kwargs.items() + } ) # Finally, interventions override any previously set tensors. val_map.update( - {hash_tensor(k): v for k, v in zip(doi_tensors, intervention)} + { + hash_tensor(k): v for k, v in zip(doi_tensors, intervention) + } ) all_inputs = [unhash_tensor(k) for k in val_map] diff --git a/trulens_explain/trulens/nn/models/pytorch.py b/trulens_explain/trulens/nn/models/pytorch.py index c2728d361..a78385c56 100644 --- a/trulens_explain/trulens/nn/models/pytorch.py +++ b/trulens_explain/trulens/nn/models/pytorch.py @@ -5,7 +5,6 @@ import numpy as np import torch - from trulens.nn.backend import get_backend from trulens.nn.backend.pytorch_backend import pytorch from trulens.nn.backend.pytorch_backend.pytorch import memory_suggestions @@ -224,8 +223,8 @@ def _get_hook_val(k): elif isinstance(cut, LogitCut): return_output = many_of_om( - hooks['logits' if self._logit_layer is None else self. - _logit_layer] + hooks['logits' if self._logit_layer is + None else self._logit_layer] ) elif isinstance(cut.name, DATA_CONTAINER_TYPE): diff --git a/trulens_explain/trulens/nn/models/tensorflow_v1.py b/trulens_explain/trulens/nn/models/tensorflow_v1.py index ccd28ae82..87acddfba 100644 --- a/trulens_explain/trulens/nn/models/tensorflow_v1.py +++ b/trulens_explain/trulens/nn/models/tensorflow_v1.py @@ -3,7 +3,6 @@ import numpy as np import tensorflow as tf - from trulens.nn.backend import get_backend from trulens.nn.models._model_base import ModelWrapper from trulens.nn.quantities import QoI @@ -207,7 +206,9 @@ def _tensor(k): # raise ValueError(f"Expected to get {len(doi_tensors)} inputs for intervention but got {len(args)} args and {len(kwargs)} kwargs.") intervention_dict.update( - {k: v for k, v in zip(doi_tensors[0:len(args)], args)} + { + k: v for k, v in zip(doi_tensors[0:len(args)], args) + } ) intervention_dict.update({_tensor(k): v for k, v in kwargs.items()}) diff --git a/trulens_explain/trulens/nn/models/tensorflow_v2.py b/trulens_explain/trulens/nn/models/tensorflow_v2.py index a644ea195..d49ccd9c3 100644 --- a/trulens_explain/trulens/nn/models/tensorflow_v2.py +++ b/trulens_explain/trulens/nn/models/tensorflow_v2.py @@ -1,7 +1,6 @@ from typing import Tuple import tensorflow as tf - from trulens.nn.backend import get_backend from trulens.nn.models.keras import \ KerasModelWrapper # dangerous to have this here if tf-less keras gets imported diff --git a/trulens_explain/trulens/utils/typing.py b/trulens_explain/trulens/utils/typing.py index 8750e095f..101d4ef0b 100644 --- a/trulens_explain/trulens/utils/typing.py +++ b/trulens_explain/trulens/utils/typing.py @@ -187,11 +187,12 @@ class Uniform(Generic[V], List[V]): # Each backend should define this. 
Tensor = TypeVar("Tensor") -ModelLike = Union['tf.Graph', # tf1 - 'keras.Model', # keras - 'tensorflow.keras.Model', # tf2 - 'torch.nn.Module', # pytorch - ] +ModelLike = Union[ + 'tf.Graph', # tf1 + 'keras.Model', # keras + 'tensorflow.keras.Model', # tf2 + 'torch.nn.Module', # pytorch +] # Atomic model inputs (at least from our perspective) TensorLike = Union[np.ndarray, Tensor] diff --git a/trulens_explain/trulens/visualizations.py b/trulens_explain/trulens/visualizations.py index f8d40c3f3..52bf8a893 100644 --- a/trulens_explain/trulens/visualizations.py +++ b/trulens_explain/trulens/visualizations.py @@ -27,7 +27,6 @@ import matplotlib.pyplot as plt import numpy as np from scipy.ndimage.filters import gaussian_filter - from trulens.nn.attribution import AttributionMethod from trulens.nn.attribution import InternalInfluence from trulens.nn.backend import get_backend