diff --git a/validator/main.py b/validator/main.py index ccdc10f..051238f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -34,7 +34,7 @@ def __init__( threshold: float = 0.9, on_fail: Optional[Union[str, Callable]] = None, ): - super().__init__(on_fail=on_fail) + super().__init__(on_fail=on_fail) # type: ignore valid_on_fail_operations = {"fix", "noop", "exception"} if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations: raise Exception( @@ -115,15 +115,17 @@ def fix_passage(self, text: str) -> str: # This normally will be called by _inference. # Remote inference is unsupported for this model on account of the NER. - def _inference_local(self, sentences: List[str]) -> List[float]: + def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore scores = list() predictions = self.classification_model(sentences) for pred in predictions: - if pred['label'] == 'Biased': - scores.append(pred['score']) - elif pred['label'] == 'Non-biased': - scores.append(-pred['score']) + label = pred['label'] # type: ignore + score = pred['score'] # type: ignore + if label == 'Biased': + scores.append(score) + elif label == 'Non-biased': + scores.append(-score) else: # This should never happen: - raise Exception("Unexpected prediction label: {}".format(pred['label'])) + raise Exception("Unexpected prediction label: {}".format(label)) return scores