diff --git a/docs/integrations/huggingface.rst b/docs/integrations/huggingface.rst new file mode 100644 index 0000000000..892e56be77 --- /dev/null +++ b/docs/integrations/huggingface.rst @@ -0,0 +1,18 @@ +Lance ❤️ HuggingFace +-------------------- + +The HuggingFace Hub has become the go-to place for ML practitioners to find pre-trained models and useful datasets. + +HuggingFace datasets can be written directly into Lance format by using the +:meth:`lance.write_dataset` method. You can write the entire dataset or a particular split. For example: + + +.. code-block:: python + + # Huggingface datasets + import datasets + import lance + + lance.write_dataset(datasets.load_dataset( + "poloclub/diffusiondb", split="train[:10]", + ), "diffusiondb_train.lance") \ No newline at end of file diff --git a/docs/integrations/integrations.rst b/docs/integrations/integrations.rst index c6e85e7d63..a84ca4f9b8 100644 --- a/docs/integrations/integrations.rst +++ b/docs/integrations/integrations.rst @@ -3,4 +3,5 @@ Integrations .. 
toctree:: + Huggingface <./huggingface> Tensorflow <./tensorflow> \ No newline at end of file diff --git a/python/pyproject.toml b/python/pyproject.toml index 6e2842d265..1282fcb51e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -2,9 +2,7 @@ name = "pylance" dependencies = ["pyarrow>=12", "numpy>=1.22"] description = "python wrapper for Lance columnar format" -authors = [ - { name = "Lance Devs", email = "dev@lancedb.com" }, -] +authors = [{ name = "Lance Devs", email = "dev@lancedb.com" }] license = { file = "LICENSE" } repository = "https://github.com/eto-ai/lance" readme = "README.md" @@ -48,20 +46,18 @@ build-backend = "maturin" [project.optional-dependencies] tests = [ - "pandas", - "pytest", + "datasets", "duckdb", "ml_dtypes", + "pillow", + "pandas", "polars[pyarrow,pandas]", + "pytest", "tensorflow", "tqdm", ] -benchmarks = [ - "pytest-benchmark", -] -torch = [ - "torch", -] +benchmarks = ["pytest-benchmark"] +torch = ["torch"] [tool.ruff] select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001"] diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 520d37ae53..9ad59721dd 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -42,7 +42,12 @@ import pyarrow.dataset from pyarrow import RecordBatch, Schema -from .dependencies import _check_for_numpy, _check_for_pandas, torch +from .dependencies import ( + _check_for_hugging_face, + _check_for_numpy, + _check_for_pandas, + torch, +) from .dependencies import numpy as np from .dependencies import pandas as pd from .fragment import FragmentMetadata, LanceFragment @@ -1960,6 +1965,7 @@ def write_dataset( data_obj: Reader-like The data to be written. 
Acceptable types are: - Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader + - Huggingface dataset uri: str or Path Where to write the dataset to (directory) schema: Schema, optional @@ -1988,6 +1994,15 @@ def write_dataset( a custom class that defines hooks to be called when each fragment is starting to write and finishing writing. """ + if _check_for_hugging_face(data_obj): + # Huggingface datasets + from .dependencies import datasets + + if isinstance(data_obj, datasets.Dataset): + if schema is None: + schema = data_obj.features.arrow_schema + data_obj = data_obj.data.to_batches() + reader = _coerce_reader(data_obj, schema) _validate_schema(reader.schema) # TODO add support for passing in LanceDataset and LanceScanner here diff --git a/python/python/lance/dependencies.py b/python/python/lance/dependencies.py index f020d7a921..13c8c2156a 100644 --- a/python/python/lance/dependencies.py +++ b/python/python/lance/dependencies.py @@ -34,6 +34,7 @@ _PANDAS_AVAILABLE = True _POLARS_AVAILABLE = True _TORCH_AVAILABLE = True +_HUGGING_FACE_AVAILABLE = True class _LazyModule(ModuleType): @@ -164,6 +165,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: if TYPE_CHECKING: + import datasets import numpy import pandas import polars @@ -174,6 +176,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: pandas, _PANDAS_AVAILABLE = _lazy_import("pandas") polars, _POLARS_AVAILABLE = _lazy_import("polars") torch, _TORCH_AVAILABLE = _lazy_import("torch") + datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets") @lru_cache(maxsize=None) @@ -210,6 +213,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool: ) +def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool: + return _HUGGING_FACE_AVAILABLE and _might_be( + cast(Hashable, type(obj) if check_type else obj), "datasets" + ) + + __all__ = [ # lazy-load third party libs "numpy", @@ -221,10 +230,12 @@ def _check_for_torch(obj: Any, *, 
check_type: bool = True) -> bool: "_check_for_pandas", "_check_for_polars", "_check_for_torch", + "_check_for_hugging_face", "_LazyModule", # exported flags/guards "_NUMPY_AVAILABLE", "_PANDAS_AVAILABLE", "_POLARS_AVAILABLE", "_TORCH_AVAILABLE", + "_HUGGING_FACE_AVAILABLE", ] diff --git a/python/python/tests/test_huggingface.py b/python/python/tests/test_huggingface.py new file mode 100644 index 0000000000..fd7a64167f --- /dev/null +++ b/python/python/tests/test_huggingface.py @@ -0,0 +1,33 @@ +# Copyright 2023 Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import lance +import pytest + +datasets = pytest.importorskip("datasets") + + +def test_write_hf_dataset(tmp_path: Path): + hf_ds = datasets.load_dataset( + "poloclub/diffusiondb", + name="2m_first_1k", + split="train[:50]", + trust_remote_code=True, + ) + + ds = lance.write_dataset(hf_ds, tmp_path) + assert ds.count_rows() == 50 + + assert ds.schema == hf_ds.features.arrow_schema