From 2e0af575e65b73b7cf9455d36a68c5dba51a6553 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 29 Aug 2024 08:52:45 -0700 Subject: [PATCH] Sort labels (#221) Signed-off-by: Ryan Wolf --- nemo_curator/classifiers/domain.py | 1 + nemo_curator/classifiers/quality.py | 1 + 2 files changed, 2 insertions(+) diff --git a/nemo_curator/classifiers/domain.py b/nemo_curator/classifiers/domain.py index 1751c3b1..f77bade7 100644 --- a/nemo_curator/classifiers/domain.py +++ b/nemo_curator/classifiers/domain.py @@ -95,6 +95,7 @@ def __init__( self.prob_column = prob_column self.labels = list(config.label2id.keys()) + self.labels.sort(key=lambda x: config.label2id[x]) self.out_dim = len(self.labels) model = DomainModel( diff --git a/nemo_curator/classifiers/quality.py b/nemo_curator/classifiers/quality.py index 91f33997..c0f2bf77 100644 --- a/nemo_curator/classifiers/quality.py +++ b/nemo_curator/classifiers/quality.py @@ -92,6 +92,7 @@ def __init__( self.prob_column = prob_column self.labels = list(config.label2id.keys()) + self.labels.sort(key=lambda x: config.label2id[x]) self.out_dim = len(self.labels) model = QualityModel(