diff --git a/nemo_curator/classifiers/domain.py b/nemo_curator/classifiers/domain.py index 1751c3b1..f77bade7 100644 --- a/nemo_curator/classifiers/domain.py +++ b/nemo_curator/classifiers/domain.py @@ -95,6 +95,7 @@ def __init__( self.prob_column = prob_column self.labels = list(config.label2id.keys()) + self.labels.sort(key=lambda x: config.label2id[x]) self.out_dim = len(self.labels) model = DomainModel( diff --git a/nemo_curator/classifiers/quality.py b/nemo_curator/classifiers/quality.py index 91f33997..c0f2bf77 100644 --- a/nemo_curator/classifiers/quality.py +++ b/nemo_curator/classifiers/quality.py @@ -92,6 +92,7 @@ def __init__( self.prob_column = prob_column self.labels = list(config.label2id.keys()) + self.labels.sort(key=lambda x: config.label2id[x]) self.out_dim = len(self.labels) model = QualityModel(