Skip to content

Commit

Permalink
updates to app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aaravnavani committed Sep 5, 2024
1 parent d2b7d3e commit f8a8ec8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
12 changes: 10 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
from typing import List
from transformers import pipeline
import os
import torch

app = FastAPI()
# Initialize the NSFW model once

env = os.environ.get("env", "dev")
torch_device = "cuda" if env == "prod" else "cpu"

if env == "prod" and torch.cuda.is_available():
torch_device = "cuda"
elif env == "prod" and not torch.cuda.is_available():
print("Warning: CUDA is not available. Falling back to CPU.")
torch_device = "cpu"
else:
torch_device = "cpu"

class InferenceData(BaseModel):
name: str
shape: List[int]
Expand Down
10 changes: 5 additions & 5 deletions validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def validate_each_sentence(
# Split the value into sentences using nltk sentence tokenizer.
sentences = nltk.sent_tokenize(value)

unsupported_sentences, supported_sentences = [], []
clean_sentences, nsfw_sentences = [], []
error_spans: List[ErrorSpan] = []
char_index = 0

Expand All @@ -108,7 +108,7 @@ def validate_each_sentence(
pred_labels = sentence_predictions[idx]

if pred_labels:
unsupported_sentences.append(sentence)
nsfw_sentences.append(sentence)
error_spans.append(
ErrorSpan(
start=char_index,
Expand All @@ -117,11 +117,11 @@ def validate_each_sentence(
)
)
else:
supported_sentences.append(sentence)
clean_sentences.append(sentence)
char_index += len(sentence) + 1

if unsupported_sentences:
unsupported_sentences_text = "- " + "\n- ".join(unsupported_sentences)
if nsfw_sentences:
nsfw_sentences_text = "- " + "\n- ".join(nsfw_sentences)

return FailResult(
metadata=metadata,
Expand Down

0 comments on commit f8a8ec8

Please sign in to comment.