diff --git a/docs/user-guide/DocumentDataset.rst b/docs/user-guide/DocumentDataset.rst index 351e41a9..0086314a 100644 --- a/docs/user-guide/DocumentDataset.rst +++ b/docs/user-guide/DocumentDataset.rst @@ -137,3 +137,87 @@ In these cases, we recommend processing the input dataset in batches using a sim This will read in 64 shards at a time, process them, and write them back to disk. Like ``get_remaining_files``, it only includes files that are in the input directory and not in the output directory. + +############################ +Blending and Shuffling +############################ + +Blending data from multiple sources can be a great way to improve downstream model performance. +This blending can be done during model training itself (i.e., *online* blending) or it can be done before training (i.e., *offline* blending). +Online blending is useful for rapidly iterating in the training process. +Meanwhile, offline blending is useful if you want to distribute the dataset. +Online blending is currently possible in `NeMo via NVIDIA Megatron Core `_, and NeMo Curator offers a way to perform blending offline. + +Let's take a look at how datasets can be combined using ``nc.blend_datasets``: + +.. code-block:: python + + import nemo_curator as nc + + books = DocumentDataset.read_json("books_dataset/") + articles = DocumentDataset.read_json("articles_dataset/") + journals = DocumentDataset.read_json("journals_dataset/") + + datasets = [books, articles, journals] + target_samples = 1000 + weights = [5.0, 2.0, 1.0] + + blended_dataset = nc.blend_datasets(target_samples, datasets, weights) + + blended_dataset.to_json("blended_dataset/") + + +* ``datasets = [books, articles, journals]`` Here, we are choosing to blend three different datasets. + These datasets do not have to be in the same file format or be similar in size. + As long as they can be read in as a DocumentDataset, they can be blended. + The samples from each dataset are always drawn "in order". + The precise order depends on the format. + For sharded jsonl files, entries are drawn starting from the beginning of the file whose name comes first in sorted order. +* ``target_samples = 1000`` This is the desired number of samples in the resulting dataset. + By sample, we mean a document or, more generally, a single datapoint. + There may end up being more samples in the dataset depending on the weights. +* ``weights = [5.0, 2.0, 1.0]`` The relative number of samples that should be taken from each dataset. + Given these weights, the blended dataset will have five times as many samples from books as there are samples from journals. + Similarly, there will be two times as many samples from articles when compared to samples from journals. + Weights can be a list of non-negative real numbers. + ``nc.blend_datasets`` will normalize the weights and combine them with the target number of samples to determine + how many samples should be taken from each dataset. + In the case of the books dataset, the calculation would be the following. + + .. math:: + + \lceil target\_samples \cdot w_i\rceil=\lceil 1000\cdot \frac{5}{8}\rceil=625 + If any dataset has fewer samples than its calculated quota, it will be oversampled to meet that quota. + For example, if the books dataset only had 500 documents in it, the first 125 would be repeated to achieve + the 625 samples. +* ``blended_dataset = nc.blend_datasets(target_samples, datasets, weights)`` We now call the function itself.
+ Afterwards, we are left with a blended dataset that we can operate on like any other dataset. + We can apply filters, deduplicate, or classify the documents. + +Because blending datasets involves combining data from multiple sources, the sharding of the original datasets +cannot be preserved. The options ``add_filename=True`` and ``write_to_filename=True`` for reading and writing +datasets are therefore incompatible with ``nc.blend_datasets``. + + +Shuffling can be another important aspect of dataset management. +NeMo Curator's ``nc.Shuffle`` allows users to reorder all entries in the dataset. + +Here is a small example of how this can be done: + +.. code-block:: python + + import nemo_curator as nc + + books = DocumentDataset.read_json("books_dataset/") + + shuffle = nc.Shuffle(seed=42) + + shuffled_books = shuffle(books) + + shuffled_books.to_json("shuffled_books/") + +* ``shuffle = nc.Shuffle(seed=42)`` This creates a shuffle operation that can be chained with + the various other modules in NeMo Curator. In this example, we fix the seed to be 42. + Setting the seed guarantees determinism, but may be 20-30% slower + depending on the dataset size. +* ``shuffled_books = shuffle(books)`` The dataset has now been shuffled, and we can save it to the filesystem. diff --git a/examples/blend_and_shuffle.py b/examples/blend_and_shuffle.py new file mode 100644 index 00000000..e070d5d2 --- /dev/null +++ b/examples/blend_and_shuffle.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import nemo_curator as nc +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import add_distributed_args + + +def main(args): + # Params + dataset_paths = ["/path/to/first", "/path/to/second", "/path/to/third"] + dataset_weights = [5.0, 2.0, 1.0] + target_size = 1000 + output_path = "/path/to/output" + + # Set up Dask client + client = get_client(args, args.device) + + # Blend the datasets + datasets = [DocumentDataset.read_json(path) for path in dataset_paths] + blended_dataset = nc.blend_datasets(target_size, datasets, dataset_weights) + + # Shuffle the blended dataset + shuffle = nc.Shuffle(seed=42) + blended_dataset = shuffle(blended_dataset) + + # Save the blend + blended_dataset.to_json(output_path) + + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): + return add_distributed_args(parser) + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py index 37592b18..a97aa196 100644 --- a/nemo_curator/datasets/doc_dataset.py +++ b/nemo_curator/datasets/doc_dataset.py @@ -24,7 +24,7 @@ class DocumentDataset: Internally it may be distributed across multiple nodes, and may be on GPUs.
""" - def __init__(self, dataset_df): + def __init__(self, dataset_df: dd.DataFrame): self.df = dataset_df def __len__(self): diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 434ebecf..0867942d 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -22,6 +22,7 @@ from nemo_curator.utils.import_utils import gpu_only_import_from from .add_id import AddId +from .dataset_ops import blend_datasets, Shuffle from .exact_dedup import ExactDuplicates from .filter import Filter, Score, ScoreFilter from .meta import Sequential @@ -50,4 +51,6 @@ "Sequential", "TaskDecontamination", "AddId", + "blend_datasets", + "Shuffle", ] diff --git a/nemo_curator/modules/dataset_ops.py b/nemo_curator/modules/dataset_ops.py new file mode 100644 index 00000000..38589b1e --- /dev/null +++ b/nemo_curator/modules/dataset_ops.py @@ -0,0 +1,183 @@ +import math +from typing import Any, Callable, List, Optional + +import dask.dataframe as dd +import numpy as np + +from nemo_curator.datasets.doc_dataset import DocumentDataset + + +def default_filename(partition_num: int) -> str: + return f"file_{partition_num:010d}.jsonl" + + +class Shuffle: + def __init__( + self, + seed: Optional[int] = None, + npartitions: Optional[int] = None, + partition_to_filename: Callable[[int], str] = default_filename, + ) -> None: + """ + Randomly permutes the dataset. This will make the original "filename" column invalid, so if the column is present it will be overwritten. + Args: + seed: The random seed that will be used to determine which partition (file) each datapoint goes to. + Setting the seed will guarantee determinism, but may be slightly slower (20-30% slower) + depending on the dataset size. + npartitions: The output number of partitions to create in the dataset. + If None, it will retain the same number of partitions as the original dataset. + partition_to_filename: If the filename column is present, it will be overwritten. + Passing a function in through this argument allows the user to configure what the filename + will look like given the partition number. The default method names the partition + f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format. 
+ """ + self.seed = seed + self.npartitions = npartitions + self.partition_to_filename = partition_to_filename + self.rand_col = "_shuffle_rand" + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + if self.seed is None: + return self.shuffle_nondeterministic(dataset) + else: + return self.shuffle_deterministic(dataset) + + def shuffle_deterministic(self, dataset: DocumentDataset) -> DocumentDataset: + new_npartitions = ( + dataset.df.npartitions if self.npartitions is None else self.npartitions + ) + + dataset.df[self.rand_col] = dataset.df.map_partitions(self._add_rand_col) + + shuffled_df = dataset.df.set_index(self.rand_col, npartitions=new_npartitions) + shuffled_df = shuffled_df.reset_index(drop=True) + + if "filename" in shuffled_df: + shuffled_df["filename"] = shuffled_df.map_partitions(self._add_filename) + + return DocumentDataset(shuffled_df) + + def shuffle_nondeterministic(self, dataset: DocumentDataset) -> DocumentDataset: + new_npartitions = ( + dataset.df.npartitions if self.npartitions is None else self.npartitions + ) + + dataset.df[self.rand_col] = dataset.df.map_partitions(self._add_rand_col) + + shuffled_df = dataset.df.shuffle( + self.rand_col, npartitions=new_npartitions, ignore_index=True + ) + shuffled_df = shuffled_df.drop(columns=[self.rand_col]) + shuffled_df = shuffled_df.map_partitions(self._partition_shuffle) + + return DocumentDataset(shuffled_df) + + def _add_rand_col(self, partition, partition_info=None): + if partition_info is None: + partition_info = { + "number": 0, + } + + if self.seed is not None: + np.random.seed(self.seed + partition_info["number"]) + rand_col = np.random.randint(0, np.iinfo("int64").max, size=len(partition)) + + return rand_col + + def _partition_shuffle(self, partition, partition_info=None): + if partition_info is None: + return partition + + partition_num = partition_info["number"] + if self.seed is not None: + random_state = self.seed + partition_num + else: + random_state = None + + partition = partition.sample(frac=1, random_state=random_state).reset_index( + drop=True + ) + + if "filename" in partition: + filename = self.partition_to_filename(partition_num) + partition["filename"] = filename + + return partition + + def _add_filename(self, partition, partition_info=None): + if partition_info is None: + return ["filename"] * len(partition) + + filename = self.partition_to_filename(partition_info["number"]) + + return [filename for _ in range(len(partition))] + + +def blend_datasets( + target_size: int, datasets: List[DocumentDataset], sampling_weights: List[float] +) -> DocumentDataset: + """ + Combined multiple datasets into one with different amounts of each dataset + Args: + target_size: The number of documents the resulting dataset should have. + The actual size of the dataset may be slightly larger if the normalized weights do not allow + for even mixtures of the datasets. + datasets: A list of all datasets to combine together + sampling_weights: A list of weights to assign to each dataset in the input. Weights will be + normalized across the whole list as a part of the sampling process. For example, if the normalized + sampling weight for dataset 1 is 0.02, 2% ofthe total samples will be sampled from dataset 1. + There are guaranteed to be math.ceil(normalized_weight_i * target_size) elements from dataset i in + the final blend. + """ + if len(datasets) != len(sampling_weights): + raise ValueError( + f"Different number of datasets and weights specified. 
{len(datasets)} datasets and {len(sampling_weights)}" + ) + + weight_sum = sum(sampling_weights) + sampling_weights = [weight / weight_sum for weight in sampling_weights] + num_documents_per_dataset = [ + math.ceil(weight * target_size) for weight in sampling_weights + ] + + blend_components = [] + for dataset, num_documents in zip(datasets, num_documents_per_dataset): + # Repeatedly sample from the dataset + while num_documents > 0: + sample = _partition_head(dataset.df, num_documents) + blend_components.append(sample) + num_documents -= len(sample) + + blended_dataset = dd.concat(blend_components) + + return DocumentDataset(blended_dataset) + + +def _partition_head(ddf: dd.DataFrame, n: int) -> dd.DataFrame: + """ + Returns the first n rows in a dataframe while preserving the partitions. + Meant as a replacement for ddf.head(npartitions=-1, compute=False) as it + uses too much memory at large scales + + Args: + ddf: The dataframe to get the first rows from + n: The number of rows to get + """ + original_meta = ddf.dtypes.to_dict() + partition_lengths = ddf.map_partitions(len) + num_partitions = 0 + total_size = 0 + last_length = 0 + for length in partition_lengths: + total_size += length + num_partitions += 1 + last_length = length + if total_size >= n: + break + + delayed_df = ddf.to_delayed() + excess_elems = max(0, total_size - n) + delayed_df = delayed_df[:num_partitions] + delayed_df[-1] = delayed_df[-1].head(last_length - excess_elems) + + return dd.from_delayed(delayed_df, meta=original_meta) diff --git a/nemo_curator/scripts/blend_datasets.py b/nemo_curator/scripts/blend_datasets.py new file mode 100644 index 00000000..4f0fc253 --- /dev/null +++ b/nemo_curator/scripts/blend_datasets.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
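+ +# Example invocation (an illustrative sketch: the paths below are placeholders, +# --shuffle assumes attach_bool_arg exposes a flag by that name, the remaining flags +# are defined in attach_args below, and the blend_datasets entry point is registered +# in setup.py): +# +# blend_datasets \ +#     --input-data-dirs=/path/to/books,/path/to/articles,/path/to/journals \ +#     --weights=5.0,2.0,1.0 \ +#     --target-samples=1000 \ +#     --output-data-dir=/path/to/blended_output \ +#     --shuffle --seed=42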
+ +import argparse + +import nemo_curator as nc +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import ( + expand_outdir_and_mkdir, + get_all_files_paths_under, +) +from nemo_curator.utils.script_utils import add_distributed_args, attach_bool_arg + + +def main(args): + client = get_client(args, args.device) + + out_dir = expand_outdir_and_mkdir(args.output_data_dir) + + input_dirs = args.input_data_dirs.split(",") + weights = [float(weight) for weight in args.weights.split(",")] + + datasets = [ + DocumentDataset( + read_data( + get_all_files_paths_under(path), + file_type=args.input_file_type, + backend="pandas", + ) + ) + for path in input_dirs + ] + + output_dataset = nc.blend_datasets(args.target_samples, datasets, weights) + + if args.shuffle: + shuffle = nc.Shuffle(seed=args.seed) + output_dataset = shuffle(output_dataset) + + write_to_disk(output_dataset.df, out_dir, output_type=args.output_file_type) + + client.close() + + +def attach_args( + parser=argparse.ArgumentParser( + """ +Blends a collection of datasets together based on certain weights. + +It takes as input a comma-separated list of dataset directories, the +corresponding weights that should be associated with each dataset, +and the target number of samples to aggregate across all the datasets. +The file shards of the resulting dataset are not guaranteed to be even +or to reflect the original dataset(s). + +A blend is created from these datasets and saved to the specified output directory. +Optionally, the user can choose to shuffle this dataset as well. + """, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) +): + parser.add_argument( + "--input-data-dirs", + type=str, + default=None, + help="Comma-separated list of directories containing dataset " + "files that are accessible to all nodes.", + ) + parser.add_argument( + "--weights", + type=str, + default=None, + help="Comma-separated list of floating-point weights corresponding " + "to each dataset passed in --input-data-dirs.", + ) + parser.add_argument( + "--output-data-dir", + type=str, + default=None, + help="The output directory to which the blended dataset will be written.", + ) + parser.add_argument( + "--target-samples", + type=int, + default=10000, + help="The number of samples to be included in the output dataset." + " There may be more samples in order to accurately reflect the " + "weight balance, but there will never be fewer.", + ) + attach_bool_arg( + parser, + "shuffle", + default=False, + help_str="Shuffles the dataset after blending.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="If specified, the random seed used for shuffling.", + ) + parser.add_argument( + "--input-file-type", + type=str, + default="jsonl", + help="File type of the dataset to be read in. Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + parser.add_argument( + "--output-file-type", + type=str, + default="jsonl", + help="File type the dataset will be written to. 
Supported file formats" + " include 'jsonl' (default), 'pickle', or 'parquet'.", + ) + + parser = add_distributed_args(parser) + + return parser + + +def console_script(): + main(attach_args().parse_args()) diff --git a/setup.py b/setup.py index 91c32a29..357e33e5 100644 --- a/setup.py +++ b/setup.py @@ -102,6 +102,7 @@ "quality_classifier_multiple_models_inference=nemo_curator.distributed_data_classification.quality_classifier_multiple_models_inference:console_script", "quality_classifier_inference=nemo_curator.distributed_data_classification.quality_classifier_inference:console_script", "verify_results=nemo_curator.distributed_data_classification.verify_results:console_script", + "blend_datasets=nemo_curator.scripts.blend_datasets:console_script", ], }, ) diff --git a/tests/test_blend_datasets.py b/tests/test_blend_datasets.py new file mode 100644 index 00000000..7c7f7e28 --- /dev/null +++ b/tests/test_blend_datasets.py @@ -0,0 +1,103 @@ +import dask.dataframe as dd +import pandas as pd + +import nemo_curator as nc +from nemo_curator.datasets import DocumentDataset + + +def list_to_dataset(documents, col_name="text", npartitions=2): + data = {col_name: documents} + pdf = pd.DataFrame(data) + + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + + +def all_equal(left_dataset, right_dataset): + left_result = left_dataset.df.compute() + right_result = right_dataset.df.compute() + + l_cols = set(left_result.columns) + r_cols = set(right_result.columns) + assert l_cols == r_cols + for col in left_result.columns: + left = left_result[col].reset_index(drop=True) + right = right_result[col].reset_index(drop=True) + assert all(left == right), f"Mismatch in {col} column.\n{left}\n{right}\n" + + +class TestBlending: + def test_blend_as_original(self): + first_dataset = list_to_dataset(["one", "two", "three"]) + result_dataset = nc.blend_datasets(len(first_dataset), [first_dataset], [1.0]) + all_equal(first_dataset, result_dataset) + + def test_equal_blend(self): + first_dataset = list_to_dataset(["a", "a"]) + second_dataset = list_to_dataset(["b", "b"]) + result_dataset = nc.blend_datasets( + 2, [first_dataset, second_dataset], [0.5, 0.5] + ) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 2 + assert counts["a"] == 1 + assert counts["b"] == 1 + + def test_equal_blend_with_weights(self): + first_dataset = list_to_dataset(["a", "a"]) + second_dataset = list_to_dataset(["b", "b"]) + result_dataset = nc.blend_datasets( + 2, [first_dataset, second_dataset], [2.0, 2.0] + ) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 2 + assert counts["a"] == 1 + assert counts["b"] == 1 + + def test_uneven_blend(self): + first_dataset = list_to_dataset(["a", "a"]) + second_dataset = list_to_dataset(["b", "b"]) + result_dataset = nc.blend_datasets( + 4, [first_dataset, second_dataset], [3.0, 1.0] + ) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 4 + assert counts["a"] == 3 + assert counts["b"] == 1 + + def test_very_uneven_blend(self): + first_dataset = list_to_dataset(["a", "a"]) + second_dataset = list_to_dataset(["b", "b"]) + result_dataset = nc.blend_datasets( + 4, [first_dataset, second_dataset], [1.0, 0.0] + ) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 4 + assert counts["a"] == 4 + assert "b" not in counts + + def test_proper_uneven_blend(self): + first_dataset = list_to_dataset(["a", "b", "c", "d"]) + 
second_dataset = list_to_dataset(["e", "f"]) + result_dataset = nc.blend_datasets( + 8, [first_dataset, second_dataset], [1.0, 0.0] + ) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 8 + assert counts["a"] == 2 + assert counts["b"] == 2 + assert counts["c"] == 2 + assert counts["d"] == 2 + + def test_four_dataset_blend(self): + datasets = [] + datasets.append(list_to_dataset(["a", "a"])) + datasets.append(list_to_dataset(["b", "b", "b"])) + datasets.append(list_to_dataset(["c"])) + datasets.append(list_to_dataset(["d", "d", "d", "d"])) + result_dataset = nc.blend_datasets(8, datasets, [1.0, 2.0, 3.0, 4.0]) + counts = result_dataset.df["text"].value_counts().compute() + assert len(result_dataset) == 10 + assert counts["a"] == 1 + assert counts["b"] == 2 + assert counts["c"] == 3 + assert counts["d"] == 4 diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py new file mode 100644 index 00000000..a23d4790 --- /dev/null +++ b/tests/test_shuffle.py @@ -0,0 +1,186 @@ +import dask.dataframe as dd +import pandas as pd +from dask.distributed import Client, LocalCluster + +import nemo_curator as nc +from nemo_curator.datasets import DocumentDataset + + +def list_to_dataset(documents, col_name="text", npartitions=2): + data = {col_name: documents} + pdf = pd.DataFrame(data) + + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + + +def all_equal(left_dataset, right_dataset): + left_result = left_dataset.df.compute() + right_result = right_dataset.df.compute() + + l_cols = set(left_result.columns) + r_cols = set(right_result.columns) + assert l_cols == r_cols + for col in left_result.columns: + left = left_result[col].reset_index(drop=True) + right = right_result[col].reset_index(drop=True) + assert all(left == right), f"Mismatch in {col} column.\n{left}\n{right}\n" + + +class TestShuffleNondeterministic: + def test_shuffle(self): + # Single threaded Dask is the only way to guarantee shuffle determinism + # Docs: https://docs.dask.org/en/latest/generated/dask.dataframe.DataFrame.shuffle.html + with LocalCluster(n_workers=1, threads_per_worker=1) as cluster: + with Client(cluster): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"] + ) + expected_dataset = list_to_dataset( + ["two", "five", "three", "one", "four"] + ) + shuffle = nc.Shuffle(seed=42) + result_dataset = shuffle.shuffle_nondeterministic(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_new_partitions(self): + with LocalCluster(n_workers=1, threads_per_worker=1) as cluster: + with Client(cluster): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=3 + ) + expected_dataset = list_to_dataset( + ["two", "five", "three", "one", "four"], npartitions=3 + ) + shuffle = nc.Shuffle(seed=42, npartitions=2) + result_dataset = shuffle.shuffle_nondeterministic(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_filename(self): + with LocalCluster(n_workers=1, threads_per_worker=1) as cluster: + with Client(cluster): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=1 + ) + original_dataset.df["filename"] = "original.jsonl" + + expected_data = { + "text": ["one", "two", "three", "five", "four"], + "filename": [ + "file_0000000000.jsonl", + "file_0000000000.jsonl", + "file_0000000000.jsonl", + "file_0000000001.jsonl", + "file_0000000001.jsonl", + ], + } + pdf = pd.DataFrame(expected_data) + expected_dataset = 
DocumentDataset(dd.from_pandas(pdf, npartitions=2)) + + shuffle = nc.Shuffle(seed=42, npartitions=2) + result_dataset = shuffle.shuffle_nondeterministic(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_custom_filenames(self): + with LocalCluster(n_workers=1, threads_per_worker=1) as cluster: + with Client(cluster): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=1 + ) + original_dataset.df["filename"] = "original.jsonl" + + expected_data = { + "text": ["one", "two", "three", "five", "four"], + "filename": [ + "my_0.test", + "my_0.test", + "my_0.test", + "my_1.test", + "my_1.test", + ], + } + pdf = pd.DataFrame(expected_data) + expected_dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2)) + + def filename_fn(x): + return f"my_{x}.test" + + shuffle = nc.Shuffle( + seed=42, npartitions=2, partition_to_filename=filename_fn + ) + result_dataset = shuffle.shuffle_nondeterministic(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_shuffle_no_seed(self): + original_dataset = list_to_dataset(["one", "two", "three", "four", "five"]) + shuffle = nc.Shuffle() + result_dataset = shuffle(original_dataset) + assert len(result_dataset.df.compute()) == 5 + + +class TestShuffleDeterministic: + def test_shuffle(self): + original_dataset = list_to_dataset(["one", "two", "three", "four", "five"]) + expected_dataset = list_to_dataset(["five", "four", "three", "one", "two"]) + shuffle = nc.Shuffle(seed=42) + result_dataset = shuffle(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_new_partitions(self): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=3 + ) + expected_dataset = list_to_dataset( + ["four", "three", "five", "one", "two"], npartitions=3 + ) + shuffle = nc.Shuffle(seed=42, npartitions=2) + result_dataset = shuffle(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_filename(self): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=1 + ) + original_dataset.df["filename"] = "original.jsonl" + + expected_data = { + "text": ["four", "five", "three", "one", "two"], + "filename": [ + "file_0000000000.jsonl", + "file_0000000001.jsonl", + "file_0000000001.jsonl", + "file_0000000001.jsonl", + "file_0000000001.jsonl", + ], + } + pdf = pd.DataFrame(expected_data) + expected_dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2)) + + shuffle = nc.Shuffle(seed=42, npartitions=2) + result_dataset = shuffle(original_dataset) + all_equal(expected_dataset, result_dataset) + + def test_custom_filenames(self): + original_dataset = list_to_dataset( + ["one", "two", "three", "four", "five"], npartitions=1 + ) + original_dataset.df["filename"] = "original.jsonl" + + expected_data = { + "text": ["four", "five", "three", "one", "two"], + "filename": [ + "my_0.test", + "my_1.test", + "my_1.test", + "my_1.test", + "my_1.test", + ], + } + pdf = pd.DataFrame(expected_data) + expected_dataset = DocumentDataset(dd.from_pandas(pdf, npartitions=2)) + + def filename_fn(x): + return f"my_{x}.test" + + shuffle = nc.Shuffle(seed=42, npartitions=2, partition_to_filename=filename_fn) + result_dataset = shuffle(original_dataset) + all_equal(expected_dataset, result_dataset)