
Commit

debug migration issue in release pipeline (#726)
* debug

* display python version

* python version

* PromptTemplate update import

* bad escape fix

* add msg to exception

* pass kwargs in Groundedness

* pass kwargs with GroundTruthAgreement

* give default value to ground_truth_imp

* migrate db on reset
piotrm0 authored Dec 27, 2023
1 parent 27b664c commit eff8562
Showing 13 changed files with 47 additions and 28 deletions.
1 change: 1 addition & 0 deletions .azure_pipelines/ci-eval-pr.yaml
@@ -37,6 +37,7 @@ jobs:
set -e
source activate $(condaEnvFileSuffix)
which python
+ python --version
displayName: Which Python
- bash: |
1 change: 1 addition & 0 deletions .azure_pipelines/ci-eval.yaml
@@ -53,6 +53,7 @@ jobs:
set -e
source activate $(condaEnvFileSuffix)
which python
+ python --version
displayName: Which Python
- bash: |
1 change: 1 addition & 0 deletions .azure_pipelines/ci.yaml
@@ -79,6 +79,7 @@ jobs:
set -e
source activate $(condaEnvFileSuffix)
which python
+ python --version
displayName: Which Python
- bash: |
set -e
2 changes: 1 addition & 1 deletion docs/trulens_eval/langchain_instrumentation.ipynb
@@ -99,7 +99,7 @@
"outputs": [],
"source": [
"from langchain import LLMChain\n",
"from langchain import PromptTemplate\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.callbacks import AsyncIteratorCallbackHandler\n",
"from langchain.chains import LLMChain\n",
"from langchain.chat_models.openai import ChatOpenAI\n",
@@ -50,7 +50,7 @@
"from langchain.llms import OpenAI\n",
"from langchain.prompts import ChatPromptTemplate, PromptTemplate\n",
"from langchain.prompts import HumanMessagePromptTemplate\n",
"from langchain import PromptTemplate\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.llms import OpenAI\n",
"from langchain import LLMChain\n",
"\n",
@@ -39,7 +39,8 @@
"source": [
"from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, pipeline\n",
"from langchain.llms import HuggingFacePipeline\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"import openai\n",
"import torch\n",
"from trulens_eval.schema import Select\n",
18 changes: 10 additions & 8 deletions trulens_eval/trulens_eval/database/sqlalchemy_db.py
@@ -11,6 +11,7 @@
from sqlalchemy import create_engine
from sqlalchemy import Engine
from sqlalchemy import select
+ from sqlalchemy.schema import MetaData
from sqlalchemy.orm import sessionmaker

from trulens_eval import schema
@@ -47,7 +48,11 @@

@for_all_methods(
run_before(lambda self, *args, **kwargs: check_db_revision(self.engine)),
_except=["migrate_database", "reload_engine"]
_except=[
"migrate_database",
"reload_engine",
"reset_database" # migrates database automatically
]
)
class SqlAlchemyDB(DB):
engine_params: dict = Field(default_factory=dict)
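`for_all_methods(run_before(...), _except=[...])` runs the revision check before every public method except the listed ones; `reset_database` is now excluded because it triggers its own migration. A self-contained sketch of what such a decorator pattern can look like (an assumed shape for illustration, not the trulens_eval implementation):

    import functools

    def run_before(check):
        # Wrap a method so `check` runs with the same arguments first.
        def decorator(method):
            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                check(self, *args, **kwargs)
                return method(self, *args, **kwargs)
            return wrapper
        return decorator

    def for_all_methods(decorator, _except=()):
        # Apply `decorator` to every public method of a class except those named in `_except`.
        def transform(cls):
            for name, attr in list(vars(cls).items()):
                if callable(attr) and not name.startswith("_") and name not in _except:
                    setattr(cls, name, decorator(attr))
            return cls
        return transform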
@@ -132,14 +137,11 @@ def migrate_database(self):
logger.info("Your database does not need migration.")

def reset_database(self):
- deleted = 0
- with self.Session.begin() as session:
- deleted += session.query(AppDefinition).delete()
- deleted += session.query(FeedbackDefinition).delete()
- deleted += session.query(Record).delete()
- deleted += session.query(FeedbackResult).delete()
+ meta = MetaData()
+ meta.reflect(bind=self.engine)
+ meta.drop_all(bind=self.engine)

- logger.info(f"Deleted {deleted} rows.")
+ self.migrate_database()

def insert_record(self, record: schema.Record) -> schema.RecordID:
# TODO: thread safety
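Instead of deleting rows table by table, the new reset reflects whatever tables exist, drops them all, and then re-runs the migration to rebuild the schema. The same SQLAlchemy pattern in isolation, against a throwaway in-memory database (the table name here is illustrative, not a TruLens table):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    engine = create_engine("sqlite:///:memory:")

    # Create a sample table so there is something to reflect and drop.
    scratch = MetaData()
    Table("records", scratch, Column("id", Integer, primary_key=True))
    scratch.create_all(engine)

    # Reflect the live schema and drop everything, as reset_database() now does
    # before calling migrate_database() to rebuild it.
    meta = MetaData()
    meta.reflect(bind=engine)
    print("dropping:", list(meta.tables))  # ['records']
    meta.drop_all(bind=engine)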
15 changes: 8 additions & 7 deletions trulens_eval/trulens_eval/db_migration.py
@@ -523,8 +523,6 @@ def _check_needs_migration(version: str, warn=False) -> None:
"""
compat_version = _get_compatibility_version(version)

print("compat_version=", compat_version)

if migration_versions.index(compat_version) > 0:

if _upgrade_possible(compat_version):
@@ -629,13 +627,16 @@ def migrate(db) -> None:
Args:
db (DB): the db object
"""
- # NOTE TO DEVELOPER: If this method fails: It's likely you made a db breaking change.
- # Follow these steps to add a compatibility change
+ # NOTE TO DEVELOPER: If this method fails: It's likely you made a db
+ # breaking change. Follow these steps to add a compatibility change
# - Update the __init__ version to the next one (if not already)
# - In this file: add that version to `migration_versions` variable`
- # - Add the migration step in `upgrade_paths` of the form `from_version`:(`to_version_you_just_created`, `migration_function`)
- # - AFTER YOU PASS TESTS - add your newest db into `release_dbs/<version_you_just_created>/default.sqlite`
- # - This is created by running the all_tools and llama_quickstart from a fresh db (you can `rm -rf` the sqlite file )
+ # - Add the migration step in `upgrade_paths` of the form
+ # `from_version`:(`to_version_you_just_created`, `migration_function`)
+ # - AFTER YOU PASS TESTS - add your newest db into
+ # `release_dbs/<version_you_just_created>/default.sqlite`
+ # - This is created by running the all_tools and llama_quickstart from a
+ # fresh db (you can `rm -rf` the sqlite file )
# - TODO: automate this step
original_db_file = db.filename
global saved_db_locations
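To make the rewrapped note concrete, an `upgrade_paths` entry of the shape it describes could look like the following; the version strings and function are made up purely for illustration (index 0 of `migration_versions` being the newest version matches the `_check_needs_migration` logic above):

    # Hypothetical example only -- these versions and this function are not real entries.
    migration_versions = ["0.20.0", "0.19.0", "0.18.0"]  # newest first

    def migrate_0_19_to_0_20(db):
        """Made-up migration step for illustration."""
        ...

    upgrade_paths = {
        # from_version: (to_version_you_just_created, migration_function)
        "0.19.0": ("0.20.0", migrate_0_19_to_0_20),
    }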
7 changes: 4 additions & 3 deletions trulens_eval/trulens_eval/feedback/groundedness.py
@@ -1,5 +1,5 @@
import logging
- from typing import Dict, List
+ from typing import Dict, List, Optional

import numpy as np
from tqdm.auto import tqdm
@@ -23,7 +23,7 @@ class Groundedness(SerialModel, WithClassInfo):
"""
groundedness_provider: Provider

- def __init__(self, groundedness_provider: Provider = None):
+ def __init__(self, groundedness_provider: Optional[Provider] = None, **kwargs):
"""Instantiates the groundedness providers. Currently the groundedness functions work well with a summarizer.
This class will use an LLM to find the relevant strings in a text. The groundedness_provider can
either be an LLM provider (such as OpenAI) or NLI with huggingface.
@@ -53,7 +53,8 @@ def __init__(self, groundedness_provider: Provider = None):
groundedness_provider = OpenAI()
super().__init__(
groundedness_provider=groundedness_provider,
- obj=self # for WithClassInfo
+ obj=self, # for WithClassInfo
+ **kwargs
)

def groundedness_measure(self, source: str, statement: str) -> float:
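Accepting and forwarding `**kwargs` lets extra serialized fields (such as the `WithClassInfo` metadata) pass through the constructor when an instance is rebuilt from its serialized form; direct use is unchanged. Roughly, with the provider import path assumed:

    from trulens_eval.feedback import Groundedness
    from trulens_eval.feedback.provider.openai import OpenAI  # import path assumed

    grounded = Groundedness(groundedness_provider=OpenAI())
    score = grounded.groundedness_measure(
        source="The sky appears blue because of Rayleigh scattering.",
        statement="Rayleigh scattering explains the sky's color.",
    )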
8 changes: 5 additions & 3 deletions trulens_eval/trulens_eval/feedback/groundtruth.py
@@ -33,7 +33,7 @@ class GroundTruthAgreement(SerialModel, WithClassInfo):
# It's a class member because creating it is expensive
bert_scorer: object

- ground_truth_imp: Optional[Callable] = pydantic.Field(exclude=True)
+ ground_truth_imp: Optional[Callable] = pydantic.Field(None, exclude=True)

model_config: ClassVar[dict] = dict(
arbitrary_types_allowed = True
@@ -43,7 +43,8 @@ def __init__(
self,
ground_truth: Union[List, Callable, FunctionOrMethod],
provider: Optional[Provider] = None,
- bert_scorer: Optional["BERTScorer"] = None
+ bert_scorer: Optional["BERTScorer"] = None,
+ **kwargs
):
"""Measures Agreement against a Ground Truth.
@@ -94,7 +95,8 @@
ground_truth_imp=ground_truth_imp,
provider=provider,
bert_scorer=bert_scorer,
- obj=self # for WithClassInfo
+ obj=self, # for WithClassInfo
+ **kwargs
)

def _find_response(self, prompt: str) -> Optional[str]:
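Same idea as in Groundedness: the explicit `None` default keeps `ground_truth_imp` optional under pydantic, and `**kwargs` flows through to the base constructor. Typical construction with a golden set stays as before, e.g. (import path and dictionary keys assumed):

    from trulens_eval.feedback import GroundTruthAgreement

    golden_set = [
        {"query": "What is TruLens?",
         "response": "A library for evaluating and tracking LLM applications."},
    ]
    ground_truth = GroundTruthAgreement(golden_set)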
2 changes: 1 addition & 1 deletion trulens_eval/trulens_eval/feedback/v2/feedback.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import ClassVar, List, Optional

- from langchain import PromptTemplate
+ from langchain.prompts import PromptTemplate
from langchain.evaluation.criteria.eval_chain import _SUPPORTED_CRITERIA
import pydantic

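The notebooks and `feedback/v2/feedback.py` all move from the top-level `from langchain import PromptTemplate` to `from langchain.prompts import PromptTemplate`, since newer LangChain releases deprecate the top-level re-export. A quick check that the new import path behaves the same:

    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Summarize the following text:\n\n{text}")
    print(prompt.format(text="TruLens evaluates and tracks LLM applications."))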
2 changes: 1 addition & 1 deletion trulens_eval/trulens_eval/utils/generated.py
@@ -17,7 +17,7 @@ def re_0_10_rating(str_val):
matches = pat_0_10.fullmatch(str_val)
if not matches:
# Try soft match
- matches = re.search('([0-9]+)(?=\D*$)', str_val)
+ matches = re.search(r'([0-9]+)(?=\D*$)', str_val)
if not matches:
logger.warning(f"0-10 rating regex failed to match on: '{str_val}'")
return -10 # so this will be reported as -1 after division by 10
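Switching to a raw string stops Python from treating `\D` as an invalid escape in the source (a warning on recent Python versions) while leaving the pattern itself unchanged; the soft match still picks out the last run of digits. For instance:

    import re

    pat = re.compile(r'([0-9]+)(?=\D*$)')  # last digits followed only by non-digits

    for s in ["8", "I would rate this 7 out of 10."]:
        m = pat.search(s)
        print(repr(s), "->", m.group(1) if m else None)
    # '8' -> 8
    # 'I would rate this 7 out of 10.' -> 10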
13 changes: 11 additions & 2 deletions trulens_eval/trulens_eval/utils/pyschema.py
@@ -392,8 +392,17 @@ def load(self) -> object:
else:
extra_kwargs = {}

- bindings = self.init_bindings.load(sig, extra_kwargs=extra_kwargs)
-
+ try:
+ bindings = self.init_bindings.load(sig, extra_kwargs=extra_kwargs)
+
+ except Exception as e:
+ msg = f"Error binding constructor args for object:\n"
+ msg += str(e) + "\n"
+ msg += f"\tobj={self}\n"
+ msg += f"\targs={self.init_bindings.args}\n"
+ msg += f"\tkwargs={self.init_bindings.kwargs}\n"
+ raise type(e)(msg)
return cls(*bindings.args, **bindings.kwargs)


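Wrapping the bindings load means a failed deserialization re-raises the same exception type with the object, args, and kwargs appended, so the offending constructor shows up in the traceback. The general shape of that pattern, sketched outside trulens_eval:

    import inspect

    def bind_with_context(fn, *args, **kwargs):
        # Sketch of the error-wrapping idea above, not the actual pyschema helper.
        try:
            return inspect.signature(fn).bind(*args, **kwargs)
        except Exception as e:
            msg = "Error binding constructor args:\n"
            msg += f"{e}\n\tfn={fn!r}\n\targs={args}\n\tkwargs={kwargs}\n"
            raise type(e)(msg)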
