-
-
Notifications
You must be signed in to change notification settings - Fork 291
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/pytest_format
- Loading branch information
Showing
39 changed files
with
2,320 additions
and
3,562 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
name: Pre-commit checks | ||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
workflow_dispatch: | ||
|
||
env: | ||
GSK_DISABLE_ANALYTICS: true | ||
SENTRY_ENABLED: false | ||
defaults: | ||
run: | ||
shell: bash | ||
jobs: | ||
pre-commit: | ||
name: Pre-commit checks | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
- uses: pre-commit/[email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import re | ||
import string | ||
import logging | ||
from typing import Tuple, List | ||
from dataclasses import dataclass | ||
|
||
from .base import BaseEvaluator, EvaluationResult | ||
from ...datasets.base import Dataset | ||
from ...models.base.model import BaseModel | ||
from ..errors import LLMGenerationError | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
@dataclass(frozen=True)
class StringMatcherConfig:
    """Immutable settings describing how expected strings are matched against a text.

    The dataclass is frozen, so instances cannot be modified after creation.
    """

    # Strings to look for. Variable-length tuple: any number of entries is allowed.
    expected_strings: Tuple[str, ...]
    # True: every expected string must match; False: a single match is enough.
    all_expected_strings_must_be_found: bool = True
    # True: the expected string must equal the whole text.
    exact_matching: bool = False
    # True: the expected string must appear delimited by regex word boundaries.
    word_matching: bool = False
    # False: both sides are lower-cased before comparison.
    case_sensitive: bool = True
    # False: characters in string.punctuation are stripped from both sides before comparison.
    punctuation_sensitive: bool = True
    # NOTE(review): not read anywhere in this module; presumably a label used by callers.
    evaluation_method_name: str = "StringMatchingMethod"
|
||
|
||
class StringMatcher:
    """Check whether a text contains (or equals) the strings listed in a config."""

    def __init__(self, config: "StringMatcherConfig") -> None:
        """Store the matching configuration used by every evaluation call."""
        self.config = config

    def normalize_text(self, text):
        """Return ``text`` normalised according to the config.

        Lower-cases the text when matching is case-insensitive and strips every
        character of ``string.punctuation`` when matching is
        punctuation-insensitive.
        """
        if not self.config.case_sensitive:
            text = text.lower()
        if not self.config.punctuation_sensitive:
            text = text.translate(str.maketrans("", "", string.punctuation))
        return text

    def evaluate_single_string(self, string: str, text: str):
        """Return True when the expected ``string`` matches ``text``.

        Both arguments are normalised first. The matching mode is then chosen
        in this order: exact equality, whole-word regex search, plain substring
        containment.
        """
        n_string = self.normalize_text(string)
        n_text = self.normalize_text(text)
        if self.config.exact_matching:
            return n_string == n_text
        if self.config.word_matching:
            # Search the *normalised* text: searching the raw text would ignore
            # the case/punctuation settings on one side of the comparison.
            return re.search(r"\b" + re.escape(n_string) + r"\b", n_text) is not None
        return n_string in n_text

    def evaluate(self, text: str):
        """Return whether ``text`` satisfies the configured expected strings.

        All expected strings must match when
        ``all_expected_strings_must_be_found`` is set; otherwise one match
        suffices.
        """
        # Lazy generator so all()/any() can short-circuit.
        matches = (self.evaluate_single_string(expected, text) for expected in self.config.expected_strings)
        combine = all if self.config.all_expected_strings_must_be_found else any
        return combine(matches)
|
||
|
||
class StringMatcherEvaluator(BaseEvaluator):
    """Evaluator that flags model outputs matching the configured strings."""

    def evaluate(self, model: BaseModel, dataset: Dataset, evaluator_configs: List[StringMatcherConfig]):
        """Run the model on ``dataset`` and string-match each output against its config.

        Rows, outputs and configs are paired positionally with ``zip``, so
        iteration stops at the shortest of them. An output that matches its
        config is recorded as a failure, and a non-matching output as a
        success. ``LLMGenerationError`` raised while matching is collected in
        ``errors`` instead of propagating.

        Returns an ``EvaluationResult`` holding the failure examples, the
        dataset rows restricted to the failing indices, the success examples
        and the collected errors.
        """
        feature_records = dataset.df.loc[:, model.meta.feature_names].to_dict("records")
        predictions = model.predict(dataset).prediction

        passed_examples, failed_examples, failed_rows, error_records = [], [], [], []

        paired = zip(dataset.df.index, feature_records, predictions, evaluator_configs)
        for row_label, sample, output, matcher_config in paired:
            matcher = StringMatcher(matcher_config)

            try:
                matched = matcher.evaluate(output)
            except LLMGenerationError as err:
                error_records.append({"message": str(err), "sample": sample})
                continue

            example = {"input_vars": sample, "model_output": output}
            if matched:
                # A match means the sample failed the check.
                failed_examples.append(example)
                failed_rows.append(row_label)
            else:
                passed_examples.append(example)

        return EvaluationResult(
            failure_examples=failed_examples,
            output_ds=dataset.slice(lambda df: df.loc[failed_rows], row_level=False),
            success_examples=passed_examples,
            errors=error_records,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.