diff --git a/src/datatrove/pipeline/filters/fasttext_filter.py b/src/datatrove/pipeline/filters/fasttext_filter.py index 1cd8d101..5ebb3cc8 100644 --- a/src/datatrove/pipeline/filters/fasttext_filter.py +++ b/src/datatrove/pipeline/filters/fasttext_filter.py @@ -36,7 +36,7 @@ class FastTextClassifierFilter(BaseFilter): """ name = "🤖 fastText" - _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"] + _requires_dependencies = [("fasttext", "fasttext-numpy2-wheel"), "fasteners"] def __init__( self, diff --git a/src/datatrove/pipeline/filters/language_filter.py b/src/datatrove/pipeline/filters/language_filter.py index e57dce07..29359aa7 100644 --- a/src/datatrove/pipeline/filters/language_filter.py +++ b/src/datatrove/pipeline/filters/language_filter.py @@ -8,7 +8,7 @@ class LanguageFilter(BaseFilter): name = "🌍 Language ID" - _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"] + _requires_dependencies = [("fasttext", "fasttext-numpy2-wheel"), "fasteners"] def __init__( self, diff --git a/src/datatrove/utils/_import_utils.py b/src/datatrove/utils/_import_utils.py index bdb875e1..42deed36 100644 --- a/src/datatrove/utils/_import_utils.py +++ b/src/datatrove/utils/_import_utils.py @@ -8,12 +8,27 @@ def check_required_dependencies(step_name: str, required_dependencies: list[str] | list[tuple[str, str]]): + """Check whether the required dependencies are installed or not. + + Args: + step_name: str + The name of the step + required_dependencies: List[str] | List[tuple[str, str]] + required dependencies. If the format is a tuple, it is checked as (module name, pip name). + When provided as a tuple, an error will be raised if the top-level module name is correct but the pip distribution name differs + (e.g., (fasttext, fasttext-numpy2-wheel)). + + """ missing_dependencies: dict[str, str] = {} for dependency in required_dependencies: dependency = dependency if isinstance(dependency, tuple) else (dependency, dependency) package_name, pip_name = dependency + # case1: in case we didn't install package if not _is_package_available(package_name): missing_dependencies[package_name] = pip_name + # case2: top-level package is installed but distribution is incorrect. (i. e. fasttext-numpy2-wheel; compatibility for numpy2) + if not _is_distribution_available(pip_name): + missing_dependencies[package_name] = pip_name if missing_dependencies: _raise_error_for_missing_dependencies(step_name, missing_dependencies) @@ -66,6 +81,15 @@ def is_tokenizers_available(): return _is_package_available("tokenizers") +# Distribution Check +def _is_distribution_available(distribution_name: str): + found = None + for dist in importlib.metadata.distributions(): + if dist.metadata["Name"] == distribution_name: + found = True + return found + + # Used in tests diff --git a/src/datatrove/utils/lid.py b/src/datatrove/utils/lid.py index 1a6a597c..2ab5298d 100644 --- a/src/datatrove/utils/lid.py +++ b/src/datatrove/utils/lid.py @@ -38,7 +38,7 @@ def __init__(self, languages: list[str] | None = None, k: int = -1) -> None: @property def model(self): if self._model is None: - check_required_dependencies("lid", [("fasttext", "fasttext-wheel")]) + check_required_dependencies("lid", [("fasttext", "fasttext-numpy2-wheel")]) from fasttext.FastText import _FastText model_file = cached_asset_path_or_download(