Skip to content

Commit

Permalink
feat: add TestConfig TypedDict for clinical, grammar, legal, and perf…
Browse files Browse the repository at this point in the history
…ormance tests to enhance type safety and configuration structure
  • Loading branch information
chakravarthik27 committed Nov 22, 2024
1 parent b3a591e commit 1fd426e
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 7 deletions.
10 changes: 8 additions & 2 deletions langtest/transform/clinical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
import logging
import random
from typing import List, Dict, Union
from typing import List, Dict, TypedDict, Union

import importlib_resources
from langtest.errors import Errors, Warnings
Expand Down Expand Up @@ -100,6 +100,12 @@ class BaseClincial(ABC):
"question-answering",
]

# TestConfig
TestConfig = TypedDict(
"TestConfig",
min_pass_rate=float,
)

@staticmethod
@abstractmethod
def transform(*args, **kwargs):
Expand Down Expand Up @@ -147,7 +153,7 @@ class DemographicBias(BaseClincial):
DemographicBias class for the clinical tests
"""

alias_name = "demographic-bias"
alias_name = ["demographic-bias", "demographic_bias"]
supported_tasks = ["question-answering", "text-generation"]

@staticmethod
Expand Down
18 changes: 16 additions & 2 deletions langtest/transform/grammar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import List, Dict, Optional
from typing import List, Dict, Optional, TypedDict
from langtest.utils.custom_types.sample import Sample
from abc import ABC, abstractmethod
from langtest.errors import Errors, Warnings
Expand Down Expand Up @@ -103,7 +103,7 @@ def transform(self) -> List[Sample]:
return all_samples

@staticmethod
def available_tests() -> dict:
def available_tests() -> Dict[str, type["BaseGrammar"]]:
"""
Retrieve a dictionary of all available tests, with their names as keys
and their corresponding classes as values.
Expand All @@ -120,9 +120,23 @@ def available_tests() -> dict:


class BaseGrammar(ABC):
"""
BaseGrammar abstract class for implementing to test the model performance on varying the input by grammatically changes .
"""

# TestConfig
TestConfig = TypedDict(
"TestConfig",
min_pass_rate=float,
)

@staticmethod
@abstractmethod
def transform(sample_list: List[Sample], *args, **kwargs):
"""
Abstract method that transforms the sample data with grammatically changes.
"""
raise NotImplementedError

@staticmethod
Expand Down
8 changes: 7 additions & 1 deletion langtest/transform/legal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import List, Dict
from typing import List, Dict, TypedDict
from langtest.transform.base import ITests
from langtest.modelhandler.modelhandler import ModelAPI
from langtest.utils.custom_types.sample import Sample
Expand All @@ -11,6 +11,12 @@ class LegalTestFactory(ITests):
alias_name = "legal"
supported_tasks = ["legal", "question-answering"]

# TestConfig
TestConfig = TypedDict(
"TestConfig",
min_pass_rate=float,
)

def __init__(self, data_handler: List[Sample], tests: Dict = None, **kwargs) -> None:
"""Initializes the legal tests"""

Expand Down
10 changes: 8 additions & 2 deletions langtest/transform/performance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
from collections import defaultdict
from typing import List, Dict
from typing import List, Dict, TypedDict
from langtest.errors import Errors
from abc import ABC, abstractmethod
from langtest.modelhandler.modelhandler import ModelAPI
Expand Down Expand Up @@ -73,7 +73,7 @@ async def run(
return tasks

@classmethod
def available_tests(cls) -> Dict[str, str]:
def available_tests(cls) -> Dict[str, type["BasePerformance"]]:
"""Returns the available model performance
Returns:
Expand All @@ -96,6 +96,12 @@ class BasePerformance(ABC):
alias_name = None
TOKENS = 0

# TestConfig
TestConfig = TypedDict(
"TestConfig",
min_pass_rate=float,
)

@staticmethod
@abstractmethod
def transform():
Expand Down
47 changes: 47 additions & 0 deletions langtest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,48 @@ class SafetyTestsConfig(TypedDict):
jailbreak_probalities_score: safety.JailBreakProbalities.TestConfig


class PerformanceTestsConfig(TypedDict):
"""
TestsConfig is for defining the configuration of a Performance Tests.
"""

from langtest.transform import performance

speed: performance.Speed.TestConfig


class LegalTestsConfig(TypedDict):
"""
TestsConfig is for defining the configuration of a Legal Tests.
"""

from langtest.transform import legal

legal_support: legal.LegalTestFactory.TestConfig


class GrammarTestsConfig(TypedDict):
"""
TestsConfig is for defining the configuration of a Grammar Tests.
"""

from langtest.transform import grammar

paraphase: grammar.Paraphrase.TestConfig


class ClinicalTestsConfig(TypedDict):
"""
TestsConfig is for defining the configuration of a Clinical Tests.
"""

from langtest.transform import clinical

demographic_bias: clinical.DemographicBias.TestConfig
drug_generic_to_brand: clinical.Generic2Brand.TestConfig
drug_brand_to_generic: clinical.Brand2Generic.TestConfig


class TestCategories(TypedDict):
"""
TestCategories is a TypedDict that defines the categories of tests.
Expand All @@ -216,6 +258,11 @@ class TestCategories(TypedDict):
accuracy: AccuracyTestsConfig
toxicity: ToxicityTestsConfig
security: SecurityTestsConfig
safety: SafetyTestsConfig
performance: PerformanceTestsConfig
legal: LegalTestsConfig
grammar: GrammarTestsConfig
clinical: ClinicalTestsConfig


class HarnessConfig(TypedDict):
Expand Down

0 comments on commit 1fd426e

Please sign in to comment.