Skip to content

Commit

Permalink
feat(python): lance write huggingface dataset directly (#1882)
Browse files Browse the repository at this point in the history
Adds the ability to write a HuggingFace dataset directly with `lance.write_dataset`.
  • Loading branch information
eddyxu authored Jan 30, 2024
1 parent 1365378 commit b3db3cc
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 12 deletions.
18 changes: 18 additions & 0 deletions docs/integrations/huggingface.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Lance ❤️ HuggingFace
--------------------

The HuggingFace Hub has become the go-to place for ML practitioners to find pre-trained models and useful datasets.

HuggingFace datasets can be written directly into Lance format by using the
:func:`lance.write_dataset` function. You can write the entire dataset or a particular split. For example:


.. code-block:: python

    # Huggingface datasets
    import datasets
    import lance

    lance.write_dataset(datasets.load_dataset(
        "poloclub/diffusiondb", split="train[:10]",
    ), "diffusiondb_train.lance")
1 change: 1 addition & 0 deletions docs/integrations/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ Integrations

.. toctree::

Huggingface <./huggingface>
Tensorflow <./tensorflow>
18 changes: 7 additions & 11 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
name = "pylance"
dependencies = ["pyarrow>=12", "numpy>=1.22"]
description = "python wrapper for Lance columnar format"
authors = [
{ name = "Lance Devs", email = "[email protected]" },
]
authors = [{ name = "Lance Devs", email = "[email protected]" }]
license = { file = "LICENSE" }
repository = "https://github.com/eto-ai/lance"
readme = "README.md"
Expand Down Expand Up @@ -48,20 +46,18 @@ build-backend = "maturin"

[project.optional-dependencies]
tests = [
"pandas",
"pytest",
"datasets",
"duckdb",
"ml_dtypes",
"pillow",
"pandas",
"polars[pyarrow,pandas]",
"pytest",
"tensorflow",
"tqdm",
]
benchmarks = [
"pytest-benchmark",
]
torch = [
"torch",
]
benchmarks = ["pytest-benchmark"]
torch = ["torch"]

[tool.ruff]
select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001", "B019"]
Expand Down
17 changes: 16 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
import pyarrow.dataset
from pyarrow import RecordBatch, Schema

from .dependencies import _check_for_numpy, _check_for_pandas, torch
from .dependencies import (
_check_for_hugging_face,
_check_for_numpy,
_check_for_pandas,
torch,
)
from .dependencies import numpy as np
from .dependencies import pandas as pd
from .fragment import FragmentMetadata, LanceFragment
Expand Down Expand Up @@ -1992,6 +1997,7 @@ def write_dataset(
data_obj: Reader-like
The data to be written. Acceptable types are:
- Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader
- Huggingface dataset
uri: str or Path
Where to write the dataset to (directory)
schema: Schema, optional
Expand Down Expand Up @@ -2020,6 +2026,15 @@ def write_dataset(
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
"""
if _check_for_hugging_face(data_obj):
# Huggingface datasets
from .dependencies import datasets

if isinstance(data_obj, datasets.Dataset):
if schema is None:
schema = data_obj.features.arrow_schema
data_obj = data_obj.data.to_batches()

reader = _coerce_reader(data_obj, schema)
_validate_schema(reader.schema)
# TODO add support for passing in LanceDataset and LanceScanner here
Expand Down
11 changes: 11 additions & 0 deletions python/python/lance/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True


class _LazyModule(ModuleType):
Expand Down Expand Up @@ -164,6 +165,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:


if TYPE_CHECKING:
import datasets
import numpy
import pandas
import polars
Expand All @@ -174,6 +176,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
polars, _POLARS_AVAILABLE = _lazy_import("polars")
torch, _TORCH_AVAILABLE = _lazy_import("torch")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -210,6 +213,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
)


def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
    """Cheaply test whether ``obj`` may belong to the HuggingFace ``datasets`` package.

    Parameters
    ----------
    obj : Any
        The object to inspect, or the class itself when ``check_type`` is False.
    check_type : bool, default True
        When True, ``type(obj)`` is handed to ``_might_be``; when False,
        ``obj`` is handed over unchanged.

    Returns
    -------
    bool
        False when ``_HUGGING_FACE_AVAILABLE`` is falsy; otherwise whatever
        ``_might_be`` reports for the candidate and the module name
        ``"datasets"``.
    """
    # Short-circuit first so `_might_be` is never consulted when the
    # availability flag is off.
    if not _HUGGING_FACE_AVAILABLE:
        return False

    candidate = type(obj) if check_type else obj
    # NOTE(review): `_might_be` presumably matches on the class's originating
    # module name -- confirm against its definition elsewhere in this module.
    return _might_be(cast(Hashable, candidate), "datasets")


__all__ = [
# lazy-load third party libs
"numpy",
Expand All @@ -221,10 +230,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
"_check_for_pandas",
"_check_for_polars",
"_check_for_torch",
"_check_for_hugging_face",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_TORCH_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
]
33 changes: 33 additions & 0 deletions python/python/tests/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import lance
import pytest

datasets = pytest.importorskip("datasets")


def test_write_hf_dataset(tmp_path: Path):
    """Write a HuggingFace dataset with ``lance.write_dataset`` and verify it.

    Loads the first 50 rows of the ``train`` split of ``poloclub/diffusiondb``
    (config ``2m_first_1k``), writes them to ``tmp_path``, then checks the row
    count and that the written schema equals the HuggingFace arrow schema.
    """
    # NOTE(review): presumably needs network access to fetch the dataset from
    # the HuggingFace Hub -- confirm CI allows it.
    load_kwargs = {
        "name": "2m_first_1k",
        "split": "train[:50]",
        "trust_remote_code": True,
    }
    source = datasets.load_dataset("poloclub/diffusiondb", **load_kwargs)

    written = lance.write_dataset(source, tmp_path)

    assert written.count_rows() == 50

    expected_schema = source.features.arrow_schema
    assert written.schema == expected_schema

0 comments on commit b3db3cc

Please sign in to comment.