class StereosetTask(BaseTask):
    """StereoSet task.

    Turns raw dataset rows into ``StereoSetSample`` objects.
    """

    _name = "stereoset"
    # NOTE(review): these entries name text/mask columns, none of which
    # ``create_sample`` below reads. Presumably ``column_mapping`` consults
    # this table -- confirm whether it should list the StereoSet fields.
    _default_col = {
        "text": ["text", "sentence"],
        "mask1": ["mask1"],
        "mask2": ["mask2"],
    }
    sample_class = StereoSetSample

    def create_sample(
        cls,
        row_data: dict,
        bias_type: str = "bias_type",
        test_type: str = "type",
        target_column: str = "target",
        context: str = "context",
        sent_stereo: str = "stereotype",
        sent_antistereo: str = "anti-stereotype",
        sent_unrelated: str = "unrelated",
        *args,
        **kwargs,
    ) -> StereoSetSample:
        """Create a StereoSet sample from a single dataset row.

        Each ``*`` argument below is the name of the column in ``row_data``
        that holds the corresponding sample field. A column name that is
        present in ``row_data`` is used as-is; otherwise the mapping returned
        by ``cls.column_mapping`` is consulted for that name.

        Args:
            row_data: One dataset row, keyed by column name.
            bias_type: Column holding the bias type.
            test_type: Column holding the test type.
            target_column: Column holding the target.
            context: Column holding the context.
            sent_stereo: Column holding the stereotypical sentence.
            sent_antistereo: Column holding the anti-stereotypical sentence.
            sent_unrelated: Column holding the unrelated sentence.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Returns:
            A ``StereoSetSample`` populated from ``row_data``.

        Raises:
            KeyError: If one or more columns can be found neither directly in
                ``row_data`` nor through the auto-detected mapping. The
                message lists every unresolved column.
        """
        # NOTE(review): the first parameter is named ``cls`` but no
        # ``@classmethod`` decorator is visible here, so it receives whatever
        # object the method is bound to -- confirm against the callers.

        # Sample field name -> column name requested by the caller.
        requested = {
            "test_type": test_type,
            "bias_type": bias_type,
            "target": target_column,
            "context": context,
            "sent_stereo": sent_stereo,
            "sent_antistereo": sent_antistereo,
            "sent_unrelated": sent_unrelated,
        }

        if all(column in row_data for column in requested.values()):
            # Every requested column exists: use the names directly.
            resolved = requested
        else:
            # At least one column is absent. The previous check accepted the
            # row as soon as *any* column matched, which then failed on the
            # first absent one; resolve each column on its own instead.
            keys = list(row_data.keys())
            detected = cls.column_mapping(keys)

            resolved = {}
            missing = []
            for field, column in requested.items():
                if column in row_data:
                    resolved[field] = column
                elif column in detected and detected[column] in row_data:
                    resolved[field] = detected[column]
                else:
                    missing.append(column)

            if missing:
                raise KeyError(
                    f"StereoSet row is missing required column(s) {missing}; "
                    f"available columns: {keys}"
                )

        return StereoSetSample(
            **{field: row_data[column] for field, column in resolved.items()}
        )