diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py
index b1cba0a3471..42f11f7780d 100644
--- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py
+++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py
@@ -1,14 +1,16 @@
 from typing import Iterable, Optional
 
+import datasets
 from allennlp.data import DatasetReader, Token, Field
 from allennlp.data.fields import TextField, LabelField, ListField
 from allennlp.data.instance import Instance
-from datasets import load_dataset, Dataset, DatasetDict
+from datasets import load_dataset, Dataset, DatasetDict, Split
 from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
 from datasets.features import Value
 
-# TODO pab complete the documentation comments
-class HuggingfaceDatasetSplitReader(DatasetReader):
+
+# TODO pab-vmware complete the documentation comments
+class HuggingfaceDatasetReader(DatasetReader):
     """
     This reader implementation wraps the huggingface datasets package to utilize its dataset management functionality
     and load the information in AllenNLP-friendly formats
@@ -44,15 +46,17 @@ class HuggingfaceDatasetSplitReader(DatasetReader):
     pre_load : `bool`, optional (default=`False`)
     """
 
+    SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION]
+
     def __init__(
-        self,
-        max_instances: Optional[int] = None,
-        manual_distributed_sharding: bool = False,
-        manual_multiprocess_sharding: bool = False,
-        serialization_dir: Optional[str] = None,
-        dataset_name: str = None,
-        config_name: Optional[str] = None,
-        pre_load: Optional[bool] = False
+            self,
+            max_instances: Optional[int] = None,
+            manual_distributed_sharding: bool = False,
+            manual_multiprocess_sharding: bool = False,
+            serialization_dir: Optional[str] = None,
+            dataset_name: str = None,
+            config_name: Optional[str] = None,
+            pre_load: Optional[bool] = False
     ) -> None:
         super().__init__(
             max_instances,
@@ -77,22 +81,29 @@ def load_dataset(self):
         else:
             self.datasets = load_dataset(self.dataset_name)
 
-    def load_dataset_split(self, split):
-        if self.config_name is not None:
-            self.datasets[split] = load_dataset(self.dataset_name, self.config_name, split=split)
+    def load_dataset_split(self, split: str):
+        # TODO add support for datasets.split.NamedSplit
+        if split in self.SUPPORTED_SPLITS:
+            if self.config_name is not None:
+                self.datasets[split] = load_dataset(self.dataset_name, self.config_name, split=split)
+            else:
+                self.datasets[split] = load_dataset(self.dataset_name, split=split)
         else:
-            self.datasets[split] = load_dataset(self.dataset_name, split=split)
+            raise ValueError(f"Only default splits: {self.SUPPORTED_SPLITS} are currently supported.")
 
-    def _read(self, file_path) -> Iterable[Instance]:
+    def _read(self, file_path: str) -> Iterable[Instance]:
         """
         Reads the dataset and converts each entry to an AllenNLP-friendly instance
         """
+        if file_path is None:
+            raise ValueError("parameter split cannot be None")
+
+        # If the split is not loaded yet, load the specific split
         if file_path not in self.datasets:
            self.load_dataset_split(file_path)
 
-        if self.datasets is not None and self.datasets[file_path] is not None:
-            for entry in self.datasets[file_path]:
-                yield self.text_to_instance(entry)
+        for entry in self.datasets[file_path]:
+            yield self.text_to_instance(entry)
 
     def raise_feature_not_supported_value_error(self, value):
         raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
@@ -179,7 +190,6 @@ def text_to_instance(self, *inputs) -> Instance:
             else:
                 self.raise_feature_not_supported_value_error(value)
 
-
         # datasets.Translation cannot be mapped directly
         # but its dict structure can be mapped to a ListField of 2 ListFields
         elif isinstance(value, Translation):
diff --git a/tests/data/dataset_readers/huggingface_datasets_test.py b/tests/data/dataset_readers/huggingface_datasets_test.py
index 688dc2d110a..6f29a081498 100644
--- a/tests/data/dataset_readers/huggingface_datasets_test.py
+++ b/tests/data/dataset_readers/huggingface_datasets_test.py
@@ -1,13 +1,12 @@
 import pytest
 
-from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader
+from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader
 import logging
 
 logger = logging.getLogger(__name__)
 
-# TODO these UTs are actually downloading the datasets and will be very very slow
-# TODO add UT were we compare huggingface wrapped reader with an explicitly coded dataset
+# TODO add UT where we compare the huggingface wrapped reader with an explicitly coded dataset
 
 
 class HuggingfaceDatasetSplitReaderTest:
     """
@@ -15,7 +14,7 @@ class HuggingfaceDatasetSplitReaderTest:
     """
     @pytest.mark.parametrize("dataset, config, split", (("glue", "cola", "train"), ("glue", "cola", "test")))
     def test_read_for_datasets_requiring_config(self, dataset, config, split):
-        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
+        huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config)
         instances = list(huggingface_reader.read(split))
         assert len(instances) == len(huggingface_reader.datasets[split])
         print(instances[0], huggingface_reader.datasets[split][0])
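
For context, a minimal usage sketch of the `HuggingfaceDatasetReader` introduced by this patch, mirroring the parametrized test above. It assumes the `datasets` package is installed and that the `glue`/`cola` dataset can be downloaded; it is not part of the diff itself.

```python
from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader

# GLUE bundles several tasks, so a config_name is required (here: CoLA).
reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola")

# read() receives the split name through its file_path argument; _read()
# loads the split on demand via load_dataset_split() if it is not cached yet.
for instance in reader.read("train"):
    print(instance)  # an allennlp.data.Instance built by text_to_instance()
    break
```

Note that only the default `train`/`test`/`validation` splits are accepted; any other split name raises the `ValueError` added in `load_dataset_split`.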