Skip to content

Commit

Permalink
fix(utils): Enhance the dependencies check to include pip distributio…
Browse files Browse the repository at this point in the history
…n validation.(e.g. fasttext)
  • Loading branch information
aiqwe committed Jan 6, 2025
1 parent 2548cdf commit 3759e67
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/datatrove/pipeline/filters/fasttext_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class FastTextClassifierFilter(BaseFilter):
"""

name = "🤖 fastText"
_requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"]
_requires_dependencies = [("fasttext", "fasttext-numpy2-wheel"), "fasteners"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/pipeline/filters/language_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class LanguageFilter(BaseFilter):
name = "🌍 Language ID"
_requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"]
_requires_dependencies = [("fasttext", "fasttext-numpy2-wheel"), "fasteners"]

def __init__(
self,
Expand Down
24 changes: 24 additions & 0 deletions src/datatrove/utils/_import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,27 @@


def check_required_dependencies(step_name: str, required_dependencies: list[str] | list[tuple[str, str]]):
"""Check whether the required dependencies are installed or not.
Args:
step_name: str
The name of the step
required_dependencies: List[str] | List[tuple[str, str]]
required dependencies. If the format is a tuple, it is checked as (module name, pip name).
When provided as a tuple, an error will be raised if the top-level module name is correct but the pip distribution name differs
(e.g., (fasttext, fasttext-numpy2-wheel)).
"""
missing_dependencies: dict[str, str] = {}
for dependency in required_dependencies:
dependency = dependency if isinstance(dependency, tuple) else (dependency, dependency)
package_name, pip_name = dependency
# case1: in case we didn't install package
if not _is_package_available(package_name):
missing_dependencies[package_name] = pip_name
# case2: top-level package is installed but distribution is incorrect. (i. e. fasttext-numpy2-wheel; compatibility for numpy2)
if not _is_distribution_available(pip_name):
missing_dependencies[package_name] = pip_name
if missing_dependencies:
_raise_error_for_missing_dependencies(step_name, missing_dependencies)

Expand Down Expand Up @@ -66,6 +81,15 @@ def is_tokenizers_available():
return _is_package_available("tokenizers")


# Distribution Check
def _is_distribution_available(distribution_name: str):
found = None
for dist in importlib.metadata.distributions():
if dist.metadata["Name"] == distribution_name:
found = True
return found


# Used in tests


Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/utils/lid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, languages: list[str] | None = None, k: int = -1) -> None:
@property
def model(self):
if self._model is None:
check_required_dependencies("lid", [("fasttext", "fasttext-wheel")])
check_required_dependencies("lid", [("fasttext", "fasttext-numpy2-wheel")])
from fasttext.FastText import _FastText

model_file = cached_asset_path_or_download(
Expand Down

0 comments on commit 3759e67

Please sign in to comment.