
Commit

Merge branch 'main' into feat/pytest_format
andreybavt authored Jan 2, 2024
2 parents e7e6194 + bd478ef commit c1edb75
Showing 39 changed files with 2,320 additions and 3,562 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/pre-commit-checks.yml
@@ -0,0 +1,22 @@
name: Pre-commit checks
on:
  push:
    branches:
      - main
  pull_request:
  workflow_dispatch:

env:
  GSK_DISABLE_ANALYTICS: true
  SENTRY_ENABLED: false
defaults:
  run:
    shell: bash
jobs:
  pre-commit:
    name: Pre-commit checks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - uses: pre-commit/[email protected]
2 changes: 1 addition & 1 deletion conftest.py
@@ -77,7 +77,7 @@ def pytest_addoption(parser: Parser):

 def separate_process(item: Function) -> List[TestReport]:
     with NamedTemporaryFile(delete=False) as fp:
-        proc = subprocess.run(
+        subprocess.run(
             shell=True,
             check=False,
             stdout=sys.stdout,
6 changes: 6 additions & 0 deletions docs/community/contribution_guidelines/dev-environment.md
@@ -14,6 +14,12 @@ brew install pre-commit
pre-commit install
```

## Run pre-commit hook manually to fix easy issues
If the build is failing because some pre-commit checks don't pass, it is possible to fix the easy issues by running
```sh
pre-commit run --all-files
```
and then committing the fixed files.

## Troubleshooting

3 changes: 2 additions & 1 deletion docs/reference/tests/llm.rst
@@ -4,7 +4,8 @@ LLM tests
 Injections
 ----------
 .. autofunction:: giskard.testing.tests.llm.test_llm_char_injection
-.. autofunction:: giskard.testing.tests.llm.test_llm_prompt_injection
+.. autofunction:: giskard.testing.tests.llm.test_llm_single_output_against_strings
+.. autofunction:: giskard.testing.tests.llm.test_llm_output_against_strings

 LLM-as-a-judge
 --------------
19 changes: 11 additions & 8 deletions giskard/client/giskard_client.py
@@ -208,14 +208,17 @@ def load_model_meta(self, project_key: str, uuid: str) -> ModelMetaInfo:
     def load_dataset_meta(self, project_key: str, uuid: str) -> DatasetMeta:
         res = self._session.get(f"project/{project_key}/datasets/{uuid}").json()
         info = DatasetMetaInfo.parse_obj(res)  # Used for validation, and avoid extra and typos
-        analytics.track("hub:dataset:download", {
-            "project": anonymize(project_key),
-            "name": anonymize(info.name),
-            "target": anonymize(info.target),
-            "columnTypes": anonymize(info.columnTypes),
-            "columnDtypes": anonymize(info.columnDtypes),
-            "nb_rows": info.numberOfRows,
-        })
+        analytics.track(
+            "hub:dataset:download",
+            {
+                "project": anonymize(project_key),
+                "name": anonymize(info.name),
+                "target": anonymize(info.target),
+                "columnTypes": anonymize(info.columnTypes),
+                "columnDtypes": anonymize(info.columnDtypes),
+                "nb_rows": info.numberOfRows,
+            },
+        )
         return DatasetMeta(
             name=info.name,
             target=info.target,
4 changes: 2 additions & 2 deletions giskard/commands/cli_hub.py
@@ -192,9 +192,9 @@ def _pull_image(version):
     if not _check_downloaded(version):
         logger.info(f"Downloading image for version {version}")
         try:
-            analytics.track('giskard-server:install:start', {'version': version})
+            analytics.track("giskard-server:install:start", {"version": version})
             create_docker_client().images.pull(IMAGE_NAME, tag=version)
-            analytics.track('giskard-server:install:success', {'version': version})
+            analytics.track("giskard-server:install:success", {"version": version})
         except NotFound:
             logger.error(
                 f"Image {get_image_name(version)} not found. Use a valid `--version` argument or check the content of $GSK_HOME/server-settings.yml"
59 changes: 36 additions & 23 deletions giskard/core/suite.py
@@ -1,12 +1,11 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 import inspect
 import logging
 import traceback
 from dataclasses import dataclass
 from functools import singledispatchmethod

 from mlflow import MlflowClient
-from typing import Any, Dict, List, Optional, Tuple, Union

 from giskard.client.dtos import SuiteInfo, SuiteTestDTO, TestInputDTO, TestSuiteDTO
 from giskard.client.giskard_client import GiskardClient
@@ -21,8 +20,9 @@
 from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction
 from giskard.ml_worker.testing.test_result import TestMessage, TestMessageLevel, TestResult
 from giskard.models.base import BaseModel

 from ..client.python_utils import warning
 from ..utils.analytics_collector import analytics
+from ..utils.artifacts import serialize_parameter

 logger = logging.getLogger(__name__)

@@ -231,17 +231,15 @@ def single_binary_result(test_results: List):
     return all(res.passed for res in test_results)


-def build_test_input_dto(client, p, pname, ptype, project_key, uploaded_uuids):
+def build_test_input_dto(client, p, pname, ptype, project_key, uploaded_uuid_status: Dict[str, bool]):
     if issubclass(type(p), Dataset) or issubclass(type(p), BaseModel):
-        if str(p.id) not in uploaded_uuids:
-            p.upload(client, project_key)
-            uploaded_uuids.append(str(p.id))
-        return TestInputDTO(name=pname, value=str(p.id), type=ptype)
+        if _try_upload_artifact(p, client, project_key, uploaded_uuid_status):
+            return TestInputDTO(name=pname, value=str(p.id), type=ptype)
+        else:
+            return TestInputDTO(name=pname, value=pname, is_alias=True, type=ptype)
     elif issubclass(type(p), Artifact):
-        if str(p.meta.uuid) not in uploaded_uuids:
-            p.upload(client, None if "giskard" in p.meta.tags else project_key)
-
-            uploaded_uuids.append(str(p.meta.uuid))
+        if not _try_upload_artifact(p, client, None if "giskard" in p.meta.tags else project_key, uploaded_uuid_status):
+            return TestInputDTO(name=pname, value=pname, is_alias=True, type=ptype)

     kwargs_params = [
         f"kwargs[{pname}] = {repr(value)}" for pname, value in p.params.items() if pname not in p.meta.args
Expand All @@ -263,7 +261,7 @@ def build_test_input_dto(client, p, pname, ptype, project_key, uploaded_uuids):
pname,
p.meta.args[pname].type,
project_key,
uploaded_uuids,
uploaded_uuid_status,
)
for pname, value in p.params.items()
if pname in p.meta.args
@@ -446,26 +444,25 @@ def upload(self, client: GiskardClient, project_key: str):
         if self.name is None:
             self.name = "Unnamed test suite"

-        uploaded_uuids: List[str] = []
+        uploaded_uuid_status: Dict[str, bool] = dict()

         # Upload the default parameters if they are model or dataset
         for arg in self.default_params.values():
             if isinstance(arg, BaseModel) or isinstance(arg, Dataset):
-                arg.upload(client, project_key)
-                uploaded_uuids.append(str(arg.id))
+                _try_upload_artifact(arg, client, project_key, uploaded_uuid_status)

-        self.id = client.save_test_suite(self.to_dto(client, project_key, uploaded_uuids))
+        self.id = client.save_test_suite(self.to_dto(client, project_key, uploaded_uuid_status))
         project_id = client.get_project(project_key).project_id
-        print(f"Test suite has been saved: {client.host_url}/main/projects/{project_id}/test-suite/{self.id}/overview")
+        print(f"Test suite has been saved: {client.host_url}/main/projects/{project_key}/test-suite/{self.id}/overview")
         analytics.track("hub:test_suite:uploaded")
         return self

-    def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuids: Optional[List[str]] = None):
+    def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuid_status: Optional[Dict[str, bool]] = None):
         suite_tests: List[SuiteTestDTO] = list()

         # Avoid to upload the same artifacts several times
-        if uploaded_uuids is None:
-            uploaded_uuids = []
+        if uploaded_uuid_status is None:
+            uploaded_uuid_status = dict()

         for t in self.tests:
             params = dict(
@@ -476,7 +473,7 @@ def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuids: Optional[List[str]] = None):
                     pname,
                     t.giskard_test.meta.args[pname].type,
                     project_key,
-                    uploaded_uuids,
+                    uploaded_uuid_status,
                 )
                 for pname, p in t.provided_inputs.items()
                 if pname in t.giskard_test.meta.args
@@ -699,3 +696,19 @@ def format_test_result(result: Union[bool, TestResult]) -> str:
         return f"{{{'passed' if result.passed else 'failed'}, metric={result.metric}}}"
     else:
         return "passed" if result else "failed"
+
+
+def _try_upload_artifact(artifact, client, project_key: str, uploaded_uuid_status: Dict[str, bool]) -> bool:
+    artifact_id = serialize_parameter(artifact)
+
+    if artifact_id not in uploaded_uuid_status:
+        try:
+            artifact.upload(client, project_key)
+            uploaded_uuid_status[artifact_id] = True
+        except:  # noqa NOSONAR
+            warning(
+                f"Failed to upload {str(artifact)} used in the test suite. The test suite will be partially uploaded."
+            )
+            uploaded_uuid_status[artifact_id] = False
+
+    return uploaded_uuid_status[artifact_id]
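
For context, the new helper memoizes one upload outcome per serialized artifact, so a failing artifact produces a single warning and is never retried within the same suite upload. A minimal sketch of that caching behavior, with a hypothetical stub artifact standing in for a real model or dataset:

```python
from typing import Dict


class FlakyArtifact:
    """Hypothetical artifact whose upload always fails."""

    def upload(self, client, project_key):
        raise RuntimeError("network error")


def try_upload(artifact, client, project_key, status: Dict[str, bool]) -> bool:
    # Same caching logic as _try_upload_artifact, with id() standing in
    # for serialize_parameter().
    key = str(id(artifact))
    if key not in status:
        try:
            artifact.upload(client, project_key)
            status[key] = True
        except Exception:
            status[key] = False  # remembered, so the upload is not retried
    return status[key]


status: Dict[str, bool] = {}
art = FlakyArtifact()
print(try_upload(art, None, "my_project", status))  # False: upload failed
print(try_upload(art, None, "my_project", status))  # False again, no retry
```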
1 change: 0 additions & 1 deletion giskard/datasets/base/__init__.py
@@ -230,7 +230,6 @@ def __init__(
         }

         self.data_processor = DataProcessor()
-        analytics.track("wrap:dataset:success", {"nb_rows": self.number_of_rows})
         logger.info("Your 'pandas.DataFrame' is successfully wrapped by Giskard's 'Dataset' wrapper class.")

     @property
1 change: 1 addition & 0 deletions giskard/demo/titanic_classification.py
@@ -74,6 +74,7 @@ def get_model_and_df(model: str = ModelTypes.LOGISTIC_REGRESSION, max_iter: int
         clf = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", LogisticRegression(max_iter=max_iter))])
     elif model.lower() == ModelTypes.LGBM_CLASSIFIER.lower():
         from lightgbm import LGBMClassifier
+
         clf = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", LGBMClassifier(n_estimators=max_iter))])
     else:
         raise NotImplementedError(f"The model type '{model}' is not supported!")
13 changes: 11 additions & 2 deletions giskard/llm/evaluators/base.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass

 from typing import Sequence, Optional
+from abc import ABC, abstractmethod


 from ..client import LLMClient, get_default_client
 from ..errors import LLMGenerationError
@@ -53,7 +54,15 @@ def passed_ratio(self):
         return len(self.success_examples) / (len(self.success_examples) + len(self.failure_examples))


-class LLMBasedEvaluator:
+class BaseEvaluator(ABC):
+    """Base class for evaluators that define a way of detecting a LLM failure"""
+
+    @abstractmethod
+    def evaluate(self, model: BaseModel, dataset: Dataset):
+        ...
+
+
+class LLMBasedEvaluator(BaseEvaluator):
     _default_eval_prompt: str

     def __init__(self, eval_prompt=None, llm_temperature=0.1, llm_client: LLMClient = None):
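
With `BaseEvaluator` extracted, detection strategies that do not call an LLM can share the evaluator interface. A sketch of a hypothetical subclass; the length rule is illustrative, not part of Giskard, and it assumes `EvaluationResult`'s remaining fields (such as `output_ds`) are optional:

```python
from giskard.datasets.base import Dataset
from giskard.llm.evaluators.base import BaseEvaluator, EvaluationResult
from giskard.models.base import BaseModel


class MaxLengthEvaluator(BaseEvaluator):
    """Hypothetical evaluator: fails any sample whose output exceeds max_chars."""

    def __init__(self, max_chars: int = 500):
        self.max_chars = max_chars

    def evaluate(self, model: BaseModel, dataset: Dataset) -> EvaluationResult:
        outputs = model.predict(dataset).prediction
        succeeded, failed = [], []
        for output in outputs:
            sample = {"model_output": output}
            # Pass when the output stays under the (illustrative) length budget.
            (succeeded if len(output) <= self.max_chars else failed).append(sample)
        return EvaluationResult(
            failure_examples=failed, success_examples=succeeded, errors=[]
        )
```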
83 changes: 83 additions & 0 deletions giskard/llm/evaluators/string_matcher.py
@@ -0,0 +1,83 @@
import re
import string
import logging
from typing import Tuple, List
from dataclasses import dataclass

from .base import BaseEvaluator, EvaluationResult
from ...datasets.base import Dataset
from ...models.base.model import BaseModel
from ..errors import LLMGenerationError

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass(frozen=True)
class StringMatcherConfig:
    expected_strings: Tuple[str]
    all_expected_strings_must_be_found: bool = True
    exact_matching: bool = False
    word_matching: bool = False
    case_sensitive: bool = True
    punctuation_sensitive: bool = True
    evaluation_method_name: str = "StringMatchingMethod"


class StringMatcher:
    def __init__(self, config: StringMatcherConfig) -> None:
        self.config = config

    def normalize_text(self, text):
        if not self.config.case_sensitive:
            text = text.lower()
        if not self.config.punctuation_sensitive:
            text = text.translate(str.maketrans("", "", string.punctuation))
        return text

    def evaluate_single_string(self, string: str, text: str):
        n_string = self.normalize_text(string)
        n_text = self.normalize_text(text)
        if self.config.exact_matching:
            return n_string == n_text
        if self.config.word_matching:
            return re.search(r"\b" + re.escape(n_string) + r"\b", text) is not None
        return n_string in n_text

    def evaluate(self, text: str):
        matches = (self.evaluate_single_string(string, text) for string in self.config.expected_strings)
        if self.config.all_expected_strings_must_be_found:
            return all(matches)
        return any(matches)


class StringMatcherEvaluator(BaseEvaluator):
    def evaluate(self, model: BaseModel, dataset: Dataset, evaluator_configs: List[StringMatcherConfig]):
        succeeded = []
        failed = []
        failed_idx = []
        errored = []
        model_inputs = dataset.df.loc[:, model.meta.feature_names].to_dict("records")
        model_outputs = model.predict(dataset).prediction

        for idx, inputs, outputs, config in zip(dataset.df.index, model_inputs, model_outputs, evaluator_configs):
            string_matcher = StringMatcher(config)

            try:
                injection_success = string_matcher.evaluate(outputs)
            except LLMGenerationError as err:
                errored.append({"message": str(err), "sample": inputs})
                continue

            if not injection_success:
                succeeded.append({"input_vars": inputs, "model_output": outputs})
            else:
                failed.append({"input_vars": inputs, "model_output": outputs})
                failed_idx.append(idx)

        return EvaluationResult(
            failure_examples=failed,
            output_ds=dataset.slice(lambda df: df.loc[failed_idx], row_level=False),
            success_examples=succeeded,
            errors=errored,
        )
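
For reference, a standalone usage sketch of the new matcher; only names introduced in this file are used, and the expected string is made up:

```python
from giskard.llm.evaluators.string_matcher import StringMatcher, StringMatcherConfig

config = StringMatcherConfig(
    expected_strings=("i cannot help with that",),  # made-up refusal marker
    case_sensitive=False,         # lowercase both sides before matching
    punctuation_sensitive=False,  # strip punctuation before matching
)
matcher = StringMatcher(config)

# Default mode is substring matching (exact_matching and word_matching are off).
print(matcher.evaluate("I cannot help with that, sorry!"))  # True
print(matcher.evaluate("Sure, here is how to do it."))      # False
```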
7 changes: 6 additions & 1 deletion giskard/llm/generators/__init__.py
@@ -3,4 +3,9 @@
 from .implausible import ImplausibleDataGenerator
 from .sycophancy import SycophancyDataGenerator

-__all__ = ["BaseDataGenerator", "SycophancyDataGenerator", "ImplausibleDataGenerator", "AdversarialDataGenerator"]
+__all__ = [
+    "BaseDataGenerator",
+    "SycophancyDataGenerator",
+    "ImplausibleDataGenerator",
+    "AdversarialDataGenerator",
+]
12 changes: 7 additions & 5 deletions giskard/llm/generators/base.py
@@ -23,7 +23,13 @@
 LANGUAGE_REQUIREMENT_PROMPT = "You must generate input using different languages among the following list: {languages}."


-class LLMGenerator(ABC):
+class BaseGenerator(ABC):
+    @abstractmethod
+    def generate_dataset(self, model, num_samples=10, column_types=None) -> Dataset:
+        ...
+
+
+class LLMGenerator(BaseGenerator, ABC):
     _default_temperature = 0.5
     _default_model = "gpt-4"
     _default_prompt = DEFAULT_GENERATE_INPUTS_PROMPT
@@ -41,10 +47,6 @@ def __init__(
         self.languages = languages
         self.prompt = prompt if prompt is not None else self._default_prompt

-    @abstractmethod
-    def generate_dataset(self, model, num_samples=10, column_types=None) -> Dataset:
-        ...
-

 class BaseDataGenerator(LLMGenerator):
     def _make_generate_input_prompt(self, model: BaseModel, num_samples: int):
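
Separating `BaseGenerator` from `LLMGenerator` lets a generator skip the LLM entirely. A sketch of a hypothetical deterministic subclass, assuming `Dataset` accepts a DataFrame plus `column_types` as elsewhere in Giskard:

```python
import pandas as pd

from giskard.datasets.base import Dataset
from giskard.llm.generators.base import BaseGenerator


class FixedPoolGenerator(BaseGenerator):
    """Hypothetical generator that samples inputs from a fixed DataFrame pool."""

    def __init__(self, pool: pd.DataFrame):
        self.pool = pool

    def generate_dataset(self, model, num_samples=10, column_types=None) -> Dataset:
        # Draw deterministic samples from the pool instead of generating via LLM.
        sample = self.pool.sample(n=min(num_samples, len(self.pool)), random_state=0)
        return Dataset(sample.reset_index(drop=True), column_types=column_types)
```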
File renamed without changes.