diff --git a/README.md b/README.md index 4bd087f..552a48d 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,14 @@ Install this library using `pip`: pip install crossfit +### Installation from source (for cuda 12.x) + +``` +git clone https://github.com/rapidsai/crossfit.git +cd crossfit +pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]" +``` + ## Usage Usage instructions go here. diff --git a/crossfit/__init__.py b/crossfit/__init__.py index 9d79be9..bf0141e 100644 --- a/crossfit/__init__.py +++ b/crossfit/__init__.py @@ -25,6 +25,41 @@ from crossfit.metric import * from crossfit.op import * + +class LazyLoader: + def __init__(self, name): + self._name = name + self._module = None + self._error = None + + def _load(self): + if self._module is None and self._error is None: + try: + parts = self._name.split(".") + module_name = ".".join(parts[:-1]) + attribute_name = parts[-1] + module = __import__(module_name, fromlist=[attribute_name]) + self._module = getattr(module, attribute_name) + except ImportError as e: + self._error = e + except AttributeError as e: + self._error = AttributeError( + f"Module '{module_name}' has no attribute '{attribute_name}'" + ) + + def __getattr__(self, item): + self._load() + if self._error is not None: + raise ImportError(f"Failed to import {self._name}: {self._error}") + return getattr(self._module, item) + + def __call__(self, *args, **kwargs): + self._load() + if self._error is not None: + raise ImportError(f"Failed to import {self._name}: {self._error}") + return self._module(*args, **kwargs) + + __all__ = [ "Aggregator", "backend", @@ -40,25 +75,25 @@ "Serial", ] +# Using the lazy import function +HFModel = LazyLoader("crossfit.backend.torch.HFModel") +SentenceTransformerModel = LazyLoader("crossfit.backend.torch.SentenceTransformerModel") +TorchExactSearch = LazyLoader("crossfit.backend.torch.TorchExactSearch") +IRDataset = LazyLoader("crossfit.dataset.base.IRDataset") +MultiDataset = LazyLoader("crossfit.dataset.base.MultiDataset") +load_dataset = LazyLoader("crossfit.dataset.load.load_dataset") +embed = LazyLoader("crossfit.report.beir.embed.embed") +beir_report = LazyLoader("crossfit.report.beir.report.beir_report") -try: - from crossfit.backend.torch import HFModel, SentenceTransformerModel, TorchExactSearch - from crossfit.dataset.base import IRDataset, MultiDataset - from crossfit.dataset.load import load_dataset - from crossfit.report.beir.embed import embed - from crossfit.report.beir.report import beir_report - - __all__.extend( - [ - "embed", - "beir_report", - "load_dataset", - "TorchExactSearch", - "SentenceTransformerModel", - "HFModel", - "MultiDataset", - "IRDataset", - ] - ) -except ImportError as e: - pass +__all__.extend( + [ + "embed", + "beir_report", + "load_dataset", + "TorchExactSearch", + "SentenceTransformerModel", + "HFModel", + "MultiDataset", + "IRDataset", + ] +) diff --git a/requirements/cuda12x.txt b/requirements/cuda12x.txt new file mode 100644 index 0000000..54e1c2d --- /dev/null +++ b/requirements/cuda12x.txt @@ -0,0 +1,13 @@ +cudf-cu12>=24.4 +dask-cudf-cu12>=24.4 +cuml-cu12>=24.4 +pylibraft-cu12>=24.4 +raft-dask-cu12>=24.4 +cuvs-cu12>=24.4 +dask-cuda>=24.6 +torch>=2.0 +transformers>=4.0 +curated-transformers>=1.0 +bitsandbytes>=0.30 +sentence-transformers>=2.0 +sentencepiece diff --git a/setup.py b/setup.py index ba8a42e..ea906a9 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def get_long_description(): def read_requirements(filename): base = os.path.abspath(os.path.dirname(__file__)) - with codecs.open(os.path.join(base, filename), "rb", "utf-8") as f: + with codecs.open(os.path.join(base, filename), "r", "utf-8") as f: lineiter = (line.strip() for line in f) return [line for line in lineiter if line and not line.startswith("#")] @@ -40,12 +40,15 @@ def read_requirements(filename): requirements = { "base": read_requirements("requirements/base.txt"), + "cuda12x": read_requirements("requirements/cuda12x.txt"), "dev": _dev, "tensorflow": read_requirements("requirements/tensorflow.txt"), "pytorch": read_requirements("requirements/pytorch.txt"), "jax": read_requirements("requirements/jax.txt"), } + dev_requirements = { + "cuda12x-dev": requirements["cuda12x"] + _dev, "tensorflow-dev": requirements["tensorflow"] + _dev, "pytorch-dev": requirements["pytorch"] + _dev, "jax-dev": requirements["jax"] + _dev, @@ -75,6 +78,6 @@ def read_requirements(filename): **dev_requirements, "all": list(itertools.chain(*list(requirements.values()))), }, - python_requires=">=3.7", + python_requires=">=3.7, <3.12", test_suite="tests", )