Merge pull request #5 from RomiconEZ/testing_artifacts
Testing artifacts
nizamovtimur authored Sep 11, 2024
2 parents 6ac8bb5 + b739f21 commit 20a5b8d
Showing 37 changed files with 9,348 additions and 8,625 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -84,4 +84,5 @@ coverage.xml
report.xml

# CMake
cmake-build-*/
cmake-build-*/
/tests/artifacts/
8,054 changes: 4,028 additions & 4,026 deletions notebooks/llamator-api-example.ipynb

Large diffs are not rendered by default.

8,095 changes: 4,049 additions & 4,046 deletions notebooks/llamator-selenium-example.ipynb

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions requirements-dev.txt
@@ -22,15 +22,17 @@ bump2version>=1.0.1,<2.0.0

# Project dependencies
openai==1.6.1
langchain==0.0.353
langchain-community==0.0.7
langchain-core==0.1.4
argparse==1.4.0
python-dotenv==1.0.0
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
tqdm==4.66.1
colorama==0.4.6
prettytable==3.10.0
pandas==2.2.2
inquirer==3.2.4
prompt-toolkit==3.0.43
fastparquet==2024.2.0
fastparquet==2024.2.0
yandexcloud==0.316.0
openpyxl==3.1.5
datetime==5.5
jupyter==1.1.1
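
The pins above move the project to the LangChain 0.2.x line (langchain, langchain-community, and langchain-core are versioned together and are expected to stay mutually compatible), drop the redundant argparse pin (argparse ships with the standard library), and add the new reporting and notebook dependencies. A minimal, standard-library-only sketch for checking a local environment against these pins:

    from importlib.metadata import version

    # Expected per the pins above: 0.2.16, 0.2.16, 0.2.38
    for pkg in ("langchain", "langchain-community", "langchain-core"):
        print(pkg, version(pkg))
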
10 changes: 6 additions & 4 deletions setup.cfg
@@ -31,17 +31,19 @@ python_requires = >=3.8
install_requires =
python-dotenv>=0.5.1
openai==1.6.1
langchain==0.0.353
langchain-community==0.0.7
langchain-core==0.1.4
argparse==1.4.0
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
tqdm==4.66.1
colorama==0.4.6
prettytable==3.10.0
pandas==2.2.2
inquirer==3.2.4
prompt-toolkit==3.0.43
fastparquet==2024.2.0
yandexcloud==0.316.0
openpyxl==3.1.5
datetime==5.5
[options.packages.find]
where=src

18 changes: 9 additions & 9 deletions src/llamator/attack_provider/attack_loader.py
@@ -1,16 +1,16 @@
from ..attacks import (
dynamic_test,
translation,
typoglycemia,
dan,
from ..attacks import ( # noqa
aim,
self_refine,
ethical_compliance,
ucar,
base64_injection,
complimentary_transition,
dan,
dynamic_test,
ethical_compliance,
harmful_behavior,
base64_injection,
self_refine,
sycophancy,
translation,
typoglycemia,
ucar,
)

# from ..attacks import (
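
The import block is now alphabetized and marked # noqa because these imports exist for their side effects: each attack module registers its test class with the shared registry at import time, so a linter would otherwise flag them as unused. A hedged sketch of that registration pattern (the register_test name is an assumption inferred from the registry's test_classes list; the attack class is illustrative):

    # Sketch of import-time side-effect registration (names assumed, not verbatim)
    test_classes = []  # module-level registry iterated by instantiate_tests

    def register_test(cls):
        # Runs once per attack module, when the module is imported
        test_classes.append(cls)
        return cls

    @register_test
    class ExampleAttack:  # hypothetical attack defined in attacks/*.py
        ...
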
17 changes: 14 additions & 3 deletions src/llamator/attack_provider/attack_registry.py
@@ -1,5 +1,6 @@
import logging
from typing import List, Type
import os
from typing import List, Optional, Type

from ..attack_provider.test_base import TestBase
from ..client.attack_config import AttackConfig
@@ -32,6 +33,7 @@ def instantiate_tests(
attack_config: AttackConfig,
basic_tests: List[str] = None,
custom_tests: List[Type[TestBase]] = None,
artifacts_path: Optional[str] = None, # New parameter for artifacts path
) -> List[Type[TestBase]]:
"""
Instantiate and return a list of test instances based on registered test classes
@@ -47,28 +49,37 @@ def instantiate_tests(
List of basic test names that need to be instantiated (default is None).
custom_tests : List[Type[TestBase]], optional
List of custom test classes that need to be instantiated (default is None).
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved (default is './artifacts').
Returns
-------
List[Type[TestBase]]
A list of instantiated test objects.
"""

csv_report_path = artifacts_path

if artifacts_path is not None:
# Create 'csv_report' directory inside artifacts_path
csv_report_path = os.path.join(artifacts_path, "csv_report")
os.makedirs(csv_report_path, exist_ok=True)

# List to store instantiated tests
tests = []

# Create instances of basic test classes
if basic_tests is not None:
for cls in test_classes:
test_instance = cls(client_config, attack_config)
test_instance = cls(client_config, attack_config, artifacts_path=csv_report_path)
if test_instance.test_name in basic_tests:
logger.debug(f"Instantiating attack test class: {cls.__name__}")
tests.append(test_instance)

# Create instances of custom test classes
if custom_tests is not None:
for custom_test in custom_tests:
test_instance = custom_test(client_config, attack_config)
test_instance = custom_test(client_config, attack_config, artifacts_path=csv_report_path)
logger.debug(f"Instantiating attack test class: {cls.__name__}")
tests.append(test_instance)

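
With the new artifacts_path parameter, the registry derives a csv_report subdirectory and hands it to every test it instantiates. A minimal usage sketch (the configs are built elsewhere; the test name is illustrative):

    # Hypothetical call site; client_config and attack_config already constructed
    tests = instantiate_tests(
        client_config,
        attack_config,
        basic_tests=["typoglycemia_attack"],  # illustrative test name
        artifacts_path="./artifacts",  # per-test CSVs land in ./artifacts/csv_report
    )
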
50 changes: 19 additions & 31 deletions src/llamator/attack_provider/run_tests.py
@@ -4,16 +4,13 @@
from pydantic import ValidationError

from ..attack_provider.attack_registry import instantiate_tests
from ..attack_provider.work_progress_pool import (
ProgressWorker,
ThreadSafeTaskIterator,
WorkProgressPool,
)
from ..attack_provider.work_progress_pool import ProgressWorker, ThreadSafeTaskIterator, WorkProgressPool
from ..client.attack_config import AttackConfig
from ..client.chat_client import *
from ..client.client_config import ClientConfig
from ..format_output.results_table import print_table
from .attack_loader import *
from .attack_loader import * # noqa

# from .attack_loader import * - to register attacks defined in 'attack/*.py'
from .test_base import StatusUpdate, TestBase, TestStatus

@@ -117,6 +114,7 @@ def run_tests(
threads_count: int,
basic_tests: List[str],
custom_tests: List[Type[TestBase]],
artifacts_path: Optional[str] = None,
):
"""
Run the tests on the given client and attack configurations.
@@ -133,6 +131,8 @@
A list of basic test names to be executed.
custom_tests : List[Type[TestBase]]
A list of custom test instances to be executed.
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved.
Returns
-------
@@ -145,7 +145,7 @@

# Instantiate all tests
tests: List[Type[TestBase]] = instantiate_tests(
client_config, attack_config, basic_tests=basic_tests, custom_tests=custom_tests
client_config, attack_config, basic_tests=basic_tests, custom_tests=custom_tests, artifacts_path=artifacts_path
)

# Run tests in parallel mode
@@ -155,7 +155,7 @@
report_results(tests)


def run_tests_in_parallel(tests: List[Type[TestBase]], threads_count: int):
def run_tests_in_parallel(tests: List[Type[TestBase]], threads_count: int = 1):
"""
Run the tests in parallel using a thread pool.
@@ -212,11 +212,7 @@ def report_results(tests: List[Type[TestBase]]):
data=sorted(
[
[
ERROR
if test.status.error_count > 0
else RESILIENT
if isResilient(test.status)
else VULNERABLE,
ERROR if test.status.error_count > 0 else RESILIENT if isResilient(test.status) else VULNERABLE,
f"{test.test_name + ' ':.<{50}}",
test.status.breach_count,
test.status.resilient_count,
@@ -288,14 +284,10 @@ def generate_summary(tests: List[Type[TestBase]]):
None
"""
resilient_tests_count = sum(isResilient(test.status) for test in tests)
failed_tests = [
f"{test.test_name}\n" if not isResilient(test.status) else "" for test in tests
]
failed_tests = [f"{test.test_name}\n" if not isResilient(test.status) else "" for test in tests]

total_tests_count = len(tests)
resilient_tests_percentage = (
resilient_tests_count / total_tests_count * 100 if total_tests_count > 0 else 0
)
resilient_tests_percentage = resilient_tests_count / total_tests_count * 100 if total_tests_count > 0 else 0

# Print a brief summary of the percentage of tests passed
print(
@@ -304,9 +296,7 @@ def generate_summary(tests: List[Type[TestBase]]):

# If there are failed tests, print the list of failed tests
if resilient_tests_count < total_tests_count:
print(
f"Your Model {BRIGHT_RED}failed{RESET} the following tests:\n{RED}{''.join(failed_tests)}{RESET}\n"
)
print(f"Your Model {BRIGHT_RED}failed{RESET} the following tests:\n{RED}{''.join(failed_tests)}{RESET}\n")


def setup_models_and_tests(
@@ -316,6 +306,7 @@
num_threads: int = 1,
tests: List[str] = None,
custom_tests: List[Type[TestBase]] = None,
artifacts_path: Optional[str] = None,
):
"""
Set up and validate the models, then run the tests.
@@ -334,6 +325,8 @@
A list of basic test names to be executed (default is None).
custom_tests : List[Type[TestBase]], optional
A list of custom test instances to be executed (default is None).
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved.
Returns
-------
Expand All @@ -343,20 +336,14 @@ def setup_models_and_tests(
try:
client_config = ClientConfig(tested_model)
except (ModuleNotFoundError, ValidationError) as e:
logger.warning(
f"Error accessing the Tested Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}"
)
logger.warning(f"Error accessing the Tested Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
return

# Attack model setup
try:
attack_config = AttackConfig(
attack_client=ClientConfig(attack_model), attack_prompts_count=num_attempts
)
attack_config = AttackConfig(attack_client=ClientConfig(attack_model), attack_prompts_count=num_attempts)
except (ModuleNotFoundError, ValidationError) as e:
logger.warning(
f"Error accessing the Attack Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}"
)
logger.warning(f"Error accessing the Attack Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
return

# Run tests
@@ -366,4 +353,5 @@
threads_count=num_threads,
basic_tests=tests,
custom_tests=custom_tests,
artifacts_path=artifacts_path,
)
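
Taken together, these changes thread a single artifacts_path argument from the public entry point down to every test instance. A hedged sketch of the top-level call (model objects and the test name are placeholders; keyword arguments are used because part of the signature is collapsed above):

    # Illustrative; attack_model and tested_model are the project's client objects
    setup_models_and_tests(
        attack_model=attack_model,
        tested_model=tested_model,
        num_attempts=2,
        num_threads=1,
        tests=["aim_jailbreak"],       # illustrative test name
        artifacts_path="./artifacts",  # forwarded to run_tests and instantiate_tests
    )
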
4 changes: 3 additions & 1 deletion src/llamator/attack_provider/test_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Generator, List
from typing import Generator, List, Optional

from ..client.attack_config import AttackConfig
from ..client.client_config import ClientConfig
@@ -95,12 +95,14 @@ def __init__(
attack_config: AttackConfig,
test_name: str = "Test Name",
test_description: str = "Test Description",
artifacts_path: Optional[str] = None,
):
self.test_name = test_name
self.test_description = test_description
self.client_config = client_config
self.attack_config = attack_config
self.status = TestStatus()
self.artifacts_path = artifacts_path

@abstractmethod
def run(self) -> Generator[StatusUpdate, None, None]:
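
Since TestBase now stores artifacts_path, subclasses can persist per-test artifacts without extra plumbing. A hedged sketch of a custom test that writes a CSV into that folder (the class, columns, and file name are illustrative):

    import os

    import pandas as pd

    class EchoProbeTest(TestBase):  # hypothetical custom test
        def __init__(self, client_config, attack_config, artifacts_path=None):
            super().__init__(
                client_config,
                attack_config,
                test_name="echo_probe",
                test_description="Illustrative custom test",
                artifacts_path=artifacts_path,
            )

        def run(self):
            # ... attack/response logic would go here ...
            if self.artifacts_path is not None:
                df = pd.DataFrame({"attack_text": ["hi"], "response_text": ["ok"]})
                df.to_csv(os.path.join(self.artifacts_path, f"{self.test_name}.csv"), index=False)
            yield from ()  # run() must be a generator of StatusUpdate objects
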
14 changes: 5 additions & 9 deletions src/llamator/attack_provider/util.py
@@ -17,18 +17,14 @@ def is_response_list(response: str, min_items_count: int = 3) -> bool:

# Check the simple list markers pattern
simple_markers_pattern = r"^\s*[-*>•◦▪]"
if (
len(re.findall(simple_markers_pattern, response, re.MULTILINE))
>= min_items_count
):
if len(re.findall(simple_markers_pattern, response, re.MULTILINE)) >= min_items_count:
return True

# Check the complex list markers pattern
complex_markers_pattern = r"^\s*(?:step|phase|stage|part|point|item|number|note|rule|section)*\s*(?:\d+\.\s+|[a-z][.:]\s+)"
if (
len(re.findall(complex_markers_pattern, response, re.MULTILINE | re.IGNORECASE))
>= min_items_count
):
complex_markers_pattern = (
r"^\s*(?:step|phase|stage|part|point|item|number|note|rule|section)*\s*(?:\d+\.\s+|[a-z][.:]\s+)"
)
if len(re.findall(complex_markers_pattern, response, re.MULTILINE | re.IGNORECASE)) >= min_items_count:
return True

return False
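
The behavior is unchanged by the reformatting: the function returns True as soon as either marker pattern matches at least min_items_count lines. A quick illustration (assuming the module is importable as packaged):

    from llamator.attack_provider.util import is_response_list

    bulleted = "- plan\n- draft\n- review"
    numbered = "Step 1. plan\nStep 2. draft\nStep 3. review"
    print(is_response_list(bulleted))     # True: three simple markers (-)
    print(is_response_list(numbered))     # True: three complex markers (Step N.)
    print(is_response_list(bulleted, 4))  # False: fewer than four items
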
9 changes: 2 additions & 7 deletions src/llamator/attack_provider/work_progress_pool.py
@@ -55,9 +55,7 @@ def __init__(self, num_workers):
ProgressWorker(worker_id, progress_bar=enable_per_test_progress_bars)
for worker_id in range(self.num_workers)
]
self.queue_progress_bar = tqdm(
total=1, desc=f"{colorama.Style.BRIGHT}{'Test progress ':.<54}{RESET}"
)
self.queue_progress_bar = tqdm(total=1, desc=f"{colorama.Style.BRIGHT}{'Test progress ':.<54}{RESET}")
self.semaphore = threading.Semaphore(
self.num_workers
) # Used to ensure that at most this number of tasks are immediately pending waiting for free worker slot
@@ -104,10 +102,7 @@ def run(self, tasks, tasks_count=None):

with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# Pass each worker its own progress bar reference
futures = [
executor.submit(self.worker_function, worker_id, tasks)
for worker_id in range(self.num_workers)
]
futures = [executor.submit(self.worker_function, worker_id, tasks) for worker_id in range(self.num_workers)]
# Wait for all workers to finish
for future in futures:
future.result()