Skip to content

Commit

Permalink
add multi-proc to dataset dict (huggingface#612)
Browse files Browse the repository at this point in the history
* add multi-proc to dataset dict

* update list_datasets and list_metrics
  • Loading branch information
thomwolf authored Sep 11, 2020
1 parent d4f29c0 commit 537de11
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
8 changes: 8 additions & 0 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def map(
features: Optional[Features] = None,
disable_nullable: bool = False,
fn_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
) -> "DatasetDict":
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if the function does update examples).
Expand Down Expand Up @@ -272,6 +273,8 @@ def map(
instead of the automatically generated one.
disable_nullable (`bool`, defaults to `True`): Disallow null values in the table.
fn_kwargs (`Optional[Dict]`, defaults to `None`): Keyword arguments to be passed to `function`
num_proc (`Optional[int]`, defaults to `None`): Number of processes for multiprocessing. By default it doesn't
use multiprocessing.
"""
self._check_values_type()
if cache_file_names is None:
Expand All @@ -292,6 +295,7 @@ def map(
features=features,
disable_nullable=disable_nullable,
fn_kwargs=fn_kwargs,
num_proc=num_proc,
)
for k, dataset in self.items()
}
Expand All @@ -309,6 +313,7 @@ def filter(
cache_file_names: Optional[Dict[str, str]] = None,
writer_batch_size: Optional[int] = 1000,
fn_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
) -> "DatasetDict":
"""Apply a filter function to all the elements in the table in batches
and update the table so that the dataset only includes examples according to the filter function.
Expand All @@ -335,6 +340,8 @@ def filter(
writer_batch_size (`int`, defaults to `1000`): Number of rows per write operation for the cache file writer.
Higher value gives smaller cache files, lower value consume less temporary memory while running `.map()`.
fn_kwargs (`Optional[Dict]`, defaults to `None`): Keyword arguments to be passed to `function`
num_proc (`Optional[int]`, defaults to `None`): Number of processes for multiprocessing. By default it doesn't
use multiprocessing.
"""
self._check_values_type()
if cache_file_names is None:
Expand All @@ -352,6 +359,7 @@ def filter(
cache_file_name=cache_file_names[k],
writer_batch_size=writer_batch_size,
fn_kwargs=fn_kwargs,
num_proc=num_proc,
)
for k, dataset in self.items()
}
Expand Down
22 changes: 16 additions & 6 deletions src/datasets/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,26 @@
logger = get_logger(__name__)


def list_datasets(with_community_datasets=True, with_details=False):
    """List all the dataset scripts available on the HuggingFace AWS bucket.

    Args:
        with_community_datasets (Optional ``bool``): Include the community provided datasets (default: ``True``)
        with_details (Optional ``bool``): Return the full details on the datasets instead of only the short name (default: ``False``)
    """
    api = HfApi()
    # `not with_details` already yields a bool, so no extra bool() wrapper is needed:
    # short names (ids) are requested exactly when full details are NOT requested.
    return api.dataset_list(with_community_datasets=with_community_datasets, id_only=not with_details)


def list_metrics(with_community_metrics=True, id_only=False, with_details=False):
    """List all the metric scripts available on the HuggingFace AWS bucket.

    Args:
        with_community_metrics (Optional ``bool``): Include the community provided metrics (default: ``True``)
        id_only (Optional ``bool``): Ignored; kept only so existing callers passing it keep working.
            The result granularity is now controlled by ``with_details``.
        with_details (Optional ``bool``): Return the full details on the metrics instead of only the short name (default: ``False``)
    """
    api = HfApi()
    # NOTE(review): `id_only` is intentionally not forwarded — granularity is derived
    # from `with_details` alone, mirroring list_datasets. Short names are requested
    # exactly when full details are NOT requested (`not with_details` is already a bool).
    return api.metric_list(with_community_metrics=with_community_metrics, id_only=not with_details)


def inspect_dataset(path: str, local_path: str, download_config: Optional[DownloadConfig] = None, **download_kwargs):
Expand Down

0 comments on commit 537de11

Please sign in to comment.