Skip to content

Commit

Permalink
resolved: stereoset task issues
Browse files Browse the repository at this point in the history
chakravarthik27 committed Oct 24, 2023
1 parent c5e6cc8 commit 76a4ea6
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions langtest/tasks/task.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
SensitivitySample,
LLMAnswerSample,
CrowsPairsSample,
StereoSetSample,
)
from langtest.utils.custom_types.predictions import NERPrediction

@@ -707,3 +708,65 @@ def create_sample(
mask1=row_data[column_mapper[mask1]],
mask2=row_data[column_mapper[mask2]],
)


class StereosetTask(BaseTask):
"""StereoSet task."""

_name = "stereoset"
_default_col = {
"text": ["text", "sentence"],
"mask1": ["mask1"],
"mask2": ["mask2"],
}
sample_class = StereoSetSample

def create_sample(
cls,
row_data: dict,
bias_type: str = "bias_type",
test_type: str = "type",
target_column: str = "target",
context: str = "context",
sent_stereo: str = "stereotype",
sent_antistereo: str = "anti-stereotype",
sent_unrelated: str = "unrelated",
*args,
**kwargs,
) -> StereoSetSample:
"""Create a sample."""
keys = list(row_data.keys())
if set(
[
bias_type,
test_type,
target_column,
context,
sent_stereo,
sent_antistereo,
sent_unrelated,
]
).intersection(set(keys)):
# if the column names are provided, use them directly
column_mapper = {
bias_type: bias_type,
test_type: test_type,
target_column: target_column,
context: context,
sent_stereo: sent_stereo,
sent_antistereo: sent_antistereo,
sent_unrelated: sent_unrelated,
}
else:
# auto-detect the default column names from the row_data
column_mapper = cls.column_mapping(keys)

return StereoSetSample(
test_type=row_data[column_mapper[test_type]],
bias_type=row_data[column_mapper[bias_type]],
target=row_data[column_mapper[target_column]],
context=row_data[column_mapper[context]],
sent_stereo=row_data[column_mapper[sent_stereo]],
sent_antistereo=row_data[column_mapper[sent_antistereo]],
sent_unrelated=row_data[column_mapper[sent_unrelated]],
)

0 comments on commit 76a4ea6

Please sign in to comment.