Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NLTK dataset #27

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions validator/post-install.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,54 @@
import detoxify

# Legacy bootstrap: fetch the "punkt" tokenizer data only when
# nltk.data.find() cannot locate it.
_punkt_missing = False
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    # The lookup failed, so remember that the resource still has to be fetched.
    _punkt_missing = True
if _punkt_missing:
    nltk.download("punkt")
print("NLTK stuff loaded successfully.")
def load_nltk_data():
    """Make sure the Punkt tokenizer data for the installed NLTK is available.

    The tokenizer resource changed at NLTK 3.8.2: older releases use
    ``punkt`` and newer ones use ``punkt_tab``. The installed NLTK version
    decides which resource is looked up and, when missing, downloaded.

    This is best effort: a failure while parsing the version or fetching the
    resource is reported on stdout together with manual install instructions
    instead of being raised.
    """
    import re
    import nltk
    from importlib.metadata import version

    nltk_version = version("nltk")
    nltk_breaking_version = "3.8.2"  # The version where the dataset changed

    def parse_major_minor_patch(version_string: str):
        """Extract the major, minor, and patch numbers from a version string.

        Only the leading numeric release segment is read, so suffixes such as
        ``b1``, ``rc1`` or ``.post1`` are ignored rather than rejected. A
        missing patch component defaults to 0.

        Raises:
            ValueError: If the string does not start with ``MAJOR.MINOR``.
        """
        match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
        if match is None:
            raise ValueError(f"Invalid semantic version: '{version_string}'")
        major, minor, patch = match.groups(default="0")
        return int(major), int(minor), int(patch)

    def ensure_tokenizer_resource(resource_name: str):
        """Download ``tokenizers/<resource_name>`` unless NLTK already finds it.

        Raises:
            RuntimeError: If ``nltk.download`` reports that it did not succeed.
        """
        try:
            nltk.data.find(f"tokenizers/{resource_name}")
            return
        except LookupError:
            pass
        # nltk.download() signals failure through its return value instead of
        # raising, so turn a falsy result into an error the caller can report.
        if not nltk.download(resource_name):
            raise RuntimeError(f"nltk.download({resource_name!r}) did not succeed")

    try:
        breaking = parse_major_minor_patch(nltk_breaking_version)
        installed = parse_major_minor_patch(nltk_version)
        # Tuples compare element-wise, giving a correct numeric version order.
        ensure_tokenizer_resource("punkt_tab" if installed >= breaking else "punkt")
    except Exception as exc:
        # Print one formatted string; passing a tuple to print() would show
        # its repr with escaped newlines instead of readable instructions.
        print(
            "Error auto-installing nltk dataset, please install manually.\n"
            f"Reason: {exc}\n"
            "This can be done with:\n"
            "Version < 3.8.2:\n    import nltk\n    nltk.download('punkt')\n"
            "Version >= 3.8.2:\n    import nltk\n    nltk.download('punkt_tab')"
        )

# Run the NLTK tokenizer-data bootstrap defined above.
load_nltk_data()

# Instantiate the Detoxify "unbiased-small" model. Presumably constructing it
# is what downloads/caches the model weights at install time — TODO confirm.
model = detoxify.Detoxify("unbiased-small")
print("Detoxify's 'unbiased-small' toxicity model downloaded successfully!")
Loading