From 1fd426edd63149ae7d56604612357652e8964379 Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Fri, 22 Nov 2024 14:58:46 +0530 Subject: [PATCH] feat: add TestConfig TypedDict for clinical, grammar, legal, and performance tests to enhance type safety and configuration structure --- langtest/transform/clinical.py | 10 +++++-- langtest/transform/grammar.py | 18 ++++++++++-- langtest/transform/legal.py | 8 +++++- langtest/transform/performance.py | 10 +++++-- langtest/types.py | 47 +++++++++++++++++++++++++++++++ 5 files changed, 86 insertions(+), 7 deletions(-) diff --git a/langtest/transform/clinical.py b/langtest/transform/clinical.py index 2c54425fa..62de4e9f1 100644 --- a/langtest/transform/clinical.py +++ b/langtest/transform/clinical.py @@ -3,7 +3,7 @@ from collections import defaultdict import logging import random -from typing import List, Dict, Union +from typing import List, Dict, TypedDict, Union import importlib_resources from langtest.errors import Errors, Warnings @@ -100,6 +100,12 @@ class BaseClincial(ABC): "question-answering", ] + # TestConfig + TestConfig = TypedDict( + "TestConfig", + min_pass_rate=float, + ) + @staticmethod @abstractmethod def transform(*args, **kwargs): @@ -147,7 +153,7 @@ class DemographicBias(BaseClincial): DemographicBias class for the clinical tests """ - alias_name = "demographic-bias" + alias_name = ["demographic-bias", "demographic_bias"] supported_tasks = ["question-answering", "text-generation"] @staticmethod diff --git a/langtest/transform/grammar.py b/langtest/transform/grammar.py index 893854ec8..a5a1ef7e4 100644 --- a/langtest/transform/grammar.py +++ b/langtest/transform/grammar.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict, Optional +from typing import List, Dict, Optional, TypedDict from langtest.utils.custom_types.sample import Sample from abc import ABC, abstractmethod from langtest.errors import Errors, Warnings @@ -103,7 +103,7 @@ def transform(self) -> List[Sample]: return all_samples @staticmethod - def available_tests() -> dict: + def available_tests() -> Dict[str, type["BaseGrammar"]]: """ Retrieve a dictionary of all available tests, with their names as keys and their corresponding classes as values. @@ -120,9 +120,23 @@ def available_tests() -> dict: class BaseGrammar(ABC): + """ + BaseGrammar abstract class for implementing to test the model performance on varying the input by grammatically changes . + + """ + + # TestConfig + TestConfig = TypedDict( + "TestConfig", + min_pass_rate=float, + ) + @staticmethod @abstractmethod def transform(sample_list: List[Sample], *args, **kwargs): + """ + Abstract method that transforms the sample data with grammatically changes. + """ raise NotImplementedError @staticmethod diff --git a/langtest/transform/legal.py b/langtest/transform/legal.py index e914ad06c..ce717f6d1 100644 --- a/langtest/transform/legal.py +++ b/langtest/transform/legal.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict +from typing import List, Dict, TypedDict from langtest.transform.base import ITests from langtest.modelhandler.modelhandler import ModelAPI from langtest.utils.custom_types.sample import Sample @@ -11,6 +11,12 @@ class LegalTestFactory(ITests): alias_name = "legal" supported_tasks = ["legal", "question-answering"] + # TestConfig + TestConfig = TypedDict( + "TestConfig", + min_pass_rate=float, + ) + def __init__(self, data_handler: List[Sample], tests: Dict = None, **kwargs) -> None: """Initializes the legal tests""" diff --git a/langtest/transform/performance.py b/langtest/transform/performance.py index b7ffea612..6f0906bc7 100644 --- a/langtest/transform/performance.py +++ b/langtest/transform/performance.py @@ -1,7 +1,7 @@ import asyncio import time from collections import defaultdict -from typing import List, Dict +from typing import List, Dict, TypedDict from langtest.errors import Errors from abc import ABC, abstractmethod from langtest.modelhandler.modelhandler import ModelAPI @@ -73,7 +73,7 @@ async def run( return tasks @classmethod - def available_tests(cls) -> Dict[str, str]: + def available_tests(cls) -> Dict[str, type["BasePerformance"]]: """Returns the available model performance Returns: @@ -96,6 +96,12 @@ class BasePerformance(ABC): alias_name = None TOKENS = 0 + # TestConfig + TestConfig = TypedDict( + "TestConfig", + min_pass_rate=float, + ) + @staticmethod @abstractmethod def transform(): diff --git a/langtest/types.py b/langtest/types.py index 717f8a09c..9f1118482 100644 --- a/langtest/types.py +++ b/langtest/types.py @@ -203,6 +203,48 @@ class SafetyTestsConfig(TypedDict): jailbreak_probalities_score: safety.JailBreakProbalities.TestConfig +class PerformanceTestsConfig(TypedDict): + """ + TestsConfig is for defining the configuration of a Performance Tests. + """ + + from langtest.transform import performance + + speed: performance.Speed.TestConfig + + +class LegalTestsConfig(TypedDict): + """ + TestsConfig is for defining the configuration of a Legal Tests. + """ + + from langtest.transform import legal + + legal_support: legal.LegalTestFactory.TestConfig + + +class GrammarTestsConfig(TypedDict): + """ + TestsConfig is for defining the configuration of a Grammar Tests. + """ + + from langtest.transform import grammar + + paraphase: grammar.Paraphrase.TestConfig + + +class ClinicalTestsConfig(TypedDict): + """ + TestsConfig is for defining the configuration of a Clinical Tests. + """ + + from langtest.transform import clinical + + demographic_bias: clinical.DemographicBias.TestConfig + drug_generic_to_brand: clinical.Generic2Brand.TestConfig + drug_brand_to_generic: clinical.Brand2Generic.TestConfig + + class TestCategories(TypedDict): """ TestCategories is a TypedDict that defines the categories of tests. @@ -216,6 +258,11 @@ class TestCategories(TypedDict): accuracy: AccuracyTestsConfig toxicity: ToxicityTestsConfig security: SecurityTestsConfig + safety: SafetyTestsConfig + performance: PerformanceTestsConfig + legal: LegalTestsConfig + grammar: GrammarTestsConfig + clinical: ClinicalTestsConfig class HarnessConfig(TypedDict):