Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): lance write huggingface dataset directly #1882

Merged
merged 7 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/integrations/huggingface.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Lance ❤️ HuggingFace
--------------------

The HuggingFace Hub has become the go to place for ML practitioners to find pre-trained models and useful datasets.

HuggingFace datasets can be written directly into Lance format by using the
:meth:`lance.write_dataset` method. You can write the entire dataset or a particular split. For example:


.. code-block:: python

# Huggingface datasets
import datasets
import lance

lance.write_dataset(datasets.load_dataset(
"poloclub/diffusiondb", split="train[:10]",
), "diffusiondb_train.lance")
1 change: 1 addition & 0 deletions docs/integrations/integrations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ Integrations

.. toctree::

Huggingface <./huggingface>
Tensorflow <./tensorflow>
18 changes: 7 additions & 11 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
name = "pylance"
dependencies = ["pyarrow>=12", "numpy>=1.22"]
description = "python wrapper for Lance columnar format"
authors = [
{ name = "Lance Devs", email = "[email protected]" },
]
authors = [{ name = "Lance Devs", email = "[email protected]" }]
license = { file = "LICENSE" }
repository = "https://github.com/eto-ai/lance"
readme = "README.md"
Expand Down Expand Up @@ -48,20 +46,18 @@ build-backend = "maturin"

[project.optional-dependencies]
tests = [
"pandas",
"pytest",
"datasets",
"duckdb",
"ml_dtypes",
"pillow",
"pandas",
"polars[pyarrow,pandas]",
"pytest",
"tensorflow",
"tqdm",
]
benchmarks = [
"pytest-benchmark",
]
torch = [
"torch",
]
benchmarks = ["pytest-benchmark"]
torch = ["torch"]

[tool.ruff]
select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001"]
Expand Down
17 changes: 16 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
import pyarrow.dataset
from pyarrow import RecordBatch, Schema

from .dependencies import _check_for_numpy, _check_for_pandas, torch
from .dependencies import (
_check_for_hugging_face,
_check_for_numpy,
_check_for_pandas,
torch,
)
from .dependencies import numpy as np
from .dependencies import pandas as pd
from .fragment import FragmentMetadata, LanceFragment
Expand Down Expand Up @@ -1960,6 +1965,7 @@ def write_dataset(
data_obj: Reader-like
The data to be written. Acceptable types are:
- Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader
- Huggingface dataset
uri: str or Path
Where to write the dataset to (directory)
schema: Schema, optional
Expand Down Expand Up @@ -1988,6 +1994,15 @@ def write_dataset(
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
"""
if _check_for_hugging_face(data_obj):
# Huggingface datasets
from .dependencies import datasets

if isinstance(data_obj, datasets.Dataset):
if schema is None:
schema = data_obj.features.arrow_schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the datasets that have embeddings, are they usually list or fsl? do we need to check/cast or anything like that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is too smart at the lance level tho.

data_obj = data_obj.data.to_batches()

reader = _coerce_reader(data_obj, schema)
_validate_schema(reader.schema)
# TODO add support for passing in LanceDataset and LanceScanner here
Expand Down
11 changes: 11 additions & 0 deletions python/python/lance/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True


class _LazyModule(ModuleType):
Expand Down Expand Up @@ -164,6 +165,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:


if TYPE_CHECKING:
import datasets
import numpy
import pandas
import polars
Expand All @@ -174,6 +176,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
polars, _POLARS_AVAILABLE = _lazy_import("polars")
torch, _TORCH_AVAILABLE = _lazy_import("torch")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -210,6 +213,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
)


def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
return _HUGGING_FACE_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "datasets"
)


__all__ = [
# lazy-load third party libs
"numpy",
Expand All @@ -221,10 +230,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
"_check_for_pandas",
"_check_for_polars",
"_check_for_torch",
"_check_for_hugging_face",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_TORCH_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
]
33 changes: 33 additions & 0 deletions python/python/tests/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import lance
import pytest

datasets = pytest.importorskip("datasets")


def test_write_hf_dataset(tmp_path: Path):
hf_ds = datasets.load_dataset(
"poloclub/diffusiondb",
name="2m_first_1k",
split="train[:50]",
trust_remote_code=True,
)

ds = lance.write_dataset(hf_ds, tmp_path)
assert ds.count_rows() == 50

assert ds.schema == hf_ds.features.arrow_schema
Loading