Skip to content

Commit

Permalink
fix return types
Browse files Browse the repository at this point in the history
  • Loading branch information
aaravnavani committed Sep 3, 2024
1 parent 9e2e2be commit 860744e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
return self.validate_each_sentence(value, metadata)
return self.validate_full_text(value, metadata)

def _inference_local(self, value: str | list) -> List[bool]:
def _inference_local(self, value: str | list) -> ValidationResult:
"""Local inference method for the NSFW text validator."""

if isinstance(value, str):
Expand All @@ -153,7 +153,7 @@ def _inference_local(self, value: str | list) -> List[bool]:
predictions.append(self.is_nsfw(text))
return predictions

def _inference_remote(self, value: str | list) -> List[bool]:
def _inference_remote(self, value: str | list) -> ValidationResult:
"""Remote inference method for the NSFW text validator."""

if isinstance(value, str):
Expand Down Expand Up @@ -181,7 +181,7 @@ def _inference_remote(self, value: str | list) -> List[bool]:
if not response or "outputs" not in response:
raise ValueError("Invalid response from remote inference", response)

data = [bool(output["data"][0]) for output in response["outputs"]]
data = [output["data"][0] for output in response["outputs"]]
return data

def get_error_spans(self, original: str, fixed: str) -> List[ErrorSpan]:
Expand Down

0 comments on commit 860744e

Please sign in to comment.