forked from NVIDIA/NeMo-Curator
Commit
Tutorial for reproducing Zyda2 dataset (NVIDIA#303)
* Zyda2 tutorial
  Signed-off-by: Yury Tokpanov <[email protected]>
* Fix linter errors
  Signed-off-by: Yury Tokpanov <[email protected]>

Signed-off-by: Yury Tokpanov <[email protected]>
1 parent f130aed, commit 65affd6. Showing 28 changed files with 2,001 additions and 0 deletions.
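Not part of the commit: all of the tutorial scripts below read their configuration from environment variables rather than command-line flags. A minimal sketch of the expected setup, assuming placeholder values; only the variable names are taken from the scripts.

# Hypothetical configuration sketch. Variable names come from the scripts below;
# the values are placeholders, not values from the commit.
import os

os.environ.setdefault("DATA_BASE", "/raid/zyda2")                 # root of the raw/, processed/ and fuzzy/ trees
os.environ.setdefault("CPU_WORKERS", "32")                        # LocalCluster size for the CPU processing stages
os.environ.setdefault("SCHEDULER_FILE", "/raid/scheduler.json")   # Dask scheduler file for the GPU dedup stages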
helper.py: shared routine that repartitions a folder of raw parquet data and stamps each document with a nemo_id.
import os

from nemo_curator import AddId
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.file_utils import get_all_files_paths_under


def ensure_directory_exists(filename: str):
    os.makedirs(os.path.dirname(filename), exist_ok=True)


def process_data(input_folder, output_folder, prefix, partition_size="512MB"):
    # Read every parquet shard under input_folder, repartition to a uniform
    # partition size, stamp each document with a "nemo_id" built from the
    # given prefix, and write the result back out as parquet.
    raw_files = get_all_files_paths_under(input_folder)
    raw_data = DocumentDataset.read_parquet(raw_files)
    raw_data_rep = DocumentDataset(
        raw_data.df.repartition(partition_size=partition_size)
    )
    add_id = AddId(id_field="nemo_id", id_prefix=prefix)
    data_with_id = add_id(raw_data_rep)
    ensure_directory_exists(output_folder)
    data_with_id.to_parquet(output_folder)
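The processing scripts in this commit all drive process_data the same way: start a local Dask cluster, point it at a raw folder, and pass a dataset-specific prefix. A minimal standalone sketch, assuming placeholder paths and worker count:

# Hypothetical standalone use of helper.process_data; mirrors how the scripts below call it.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"  # set before Dask imports, as the scripts below do

from dask.distributed import Client, LocalCluster
from helper import process_data

if __name__ == "__main__":
    cluster = LocalCluster(n_workers=4, processes=True)      # placeholder worker count
    client = Client(cluster)
    process_data(
        input_folder="/data/raw/example-dataset",            # placeholder input folder of parquet files
        output_folder="/data/processed/example-dataset",     # placeholder output folder
        prefix="example-dataset",                            # becomes the nemo_id prefix
    )
    client.shutdown()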
Processing script: add nemo_id prefixes dclm-gs1 through dclm-gs10 to the ten DCLM-baseline-1.0 global shards.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import logging

from dask.distributed import Client, LocalCluster
from helper import process_data

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)

DATA_BASE = os.environ.get("DATA_BASE")
INPUT_BASE = os.path.join(DATA_BASE, "raw/dclm-baseline-1.0-parquet/filtered")
OUTPUT_BASE = os.path.join(DATA_BASE, "processed/dclm-baseline-1.0-parquet")
# Cast to int: LocalCluster expects an integer worker count.
CPU_WORKERS = int(os.environ.get("CPU_WORKERS"))


if __name__ == "__main__":
    logging.info("Starting Dask cluster")
    cluster = LocalCluster(n_workers=CPU_WORKERS, processes=True, memory_limit="48GB")
    client = Client(cluster)
    logging.info(client)

    components = [
        "global-shard_01_of_10",
        "global-shard_02_of_10",
        "global-shard_03_of_10",
        "global-shard_04_of_10",
        "global-shard_05_of_10",
        "global-shard_06_of_10",
        "global-shard_07_of_10",
        "global-shard_08_of_10",
        "global-shard_09_of_10",
        "global-shard_10_of_10",
    ]

    for i, component in enumerate(components, start=1):
        input_path = os.path.join(INPUT_BASE, component)
        if not os.path.exists(input_path):
            continue
        output_path = os.path.join(OUTPUT_BASE, component)
        logging.info(f"Processing {component}")
        process_data(
            input_folder=input_path,
            output_folder=output_path,
            prefix=f"dclm-gs{i}",
        )
        logging.info("Done!")

    client.cluster.close()
    client.shutdown()
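A possible spot-check after this stage (my addition, not in the tutorial): read one processed shard back and confirm the nemo_id column added by AddId is present. The path layout and field name follow the script above; the use of plain dask.dataframe is illustrative.

# Hypothetical spot-check of one processed DCLM shard.
import os
import dask.dataframe as dd

shard = os.path.join(
    os.environ["DATA_BASE"],
    "processed/dclm-baseline-1.0-parquet/global-shard_01_of_10",
)
ids = dd.read_parquet(shard, columns=["nemo_id"])
print(ids.head())  # values should carry the "dclm-gs1" prefix configured above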
Processing script: add the dolma-cc nemo_id prefix to the Dolma v1.7 Common Crawl parquet data.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import logging

from dask.distributed import Client, LocalCluster
from helper import process_data

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)

DATA_BASE = os.environ.get("DATA_BASE")
INPUT_BASE = os.path.join(DATA_BASE, "raw/dolma-v1_7-cc-parquet")
OUTPUT_BASE = os.path.join(DATA_BASE, "processed/dolma-v1_7-cc-parquet")
# Cast to int: LocalCluster expects an integer worker count.
CPU_WORKERS = int(os.environ.get("CPU_WORKERS"))


if __name__ == "__main__":
    logging.info("Starting Dask cluster")
    cluster = LocalCluster(n_workers=CPU_WORKERS, processes=True, memory_limit="48GB")
    client = Client(cluster)
    logging.info(client)

    logging.info("Processing Dolma-CC")
    process_data(input_folder=INPUT_BASE, output_folder=OUTPUT_BASE, prefix="dolma-cc")
    logging.info("Done!")

    client.cluster.close()
    client.shutdown()
Processing script: add fwe2-* nemo_id prefixes to the FineWeb-Edu-score-2 components, processing the smallest folders first and trimming memory between components.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import ctypes
import gc
import logging
from pathlib import Path

from dask.distributed import Client, LocalCluster
from helper import process_data

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)

DATA_BASE = os.environ.get("DATA_BASE")
INPUT_BASE = os.path.join(DATA_BASE, "raw/fineweb-edu-score-2/data")
OUTPUT_BASE = os.path.join(DATA_BASE, "processed/fineweb-edu-score-2")
# Cast to int: LocalCluster expects an integer worker count.
CPU_WORKERS = int(os.environ.get("CPU_WORKERS"))


def trim_memory() -> int:
    # Ask glibc to return freed memory to the OS between components.
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)


def get_folder_size(folder_path):
    return sum(
        file.stat().st_size for file in Path(folder_path).rglob("*") if file.is_file()
    )


def sort_folders_by_size(parent_directory):
    folders = [
        f
        for f in os.listdir(parent_directory)
        if os.path.isdir(os.path.join(parent_directory, f))
    ]
    folder_sizes = [
        (folder, get_folder_size(os.path.join(parent_directory, folder)))
        for folder in folders
    ]
    return sorted(folder_sizes, key=lambda x: x[1])


def bytes_to_human_readable(size_in_bytes):
    suffixes = ["B", "KB", "MB", "GB", "TB", "PB"]
    suffix_index = 0
    size = float(size_in_bytes)
    while size >= 1024 and suffix_index < len(suffixes) - 1:
        size /= 1024.0
        suffix_index += 1
    return f"{size:.2f} {suffixes[suffix_index]}"


if __name__ == "__main__":
    logging.info("Starting Dask cluster")
    cluster = LocalCluster(n_workers=CPU_WORKERS, processes=True, memory_limit="240GB")
    client = Client(cluster)
    logging.info(client)

    # Process the smallest components first.
    components_with_sizes = sort_folders_by_size(INPUT_BASE)

    for component, component_size in components_with_sizes:
        input_path = os.path.join(INPUT_BASE, component)
        if not os.path.exists(input_path) or not os.path.isdir(input_path):
            continue
        output_path = os.path.join(OUTPUT_BASE, component)
        logging.info(
            f"Processing {component}, size = {bytes_to_human_readable(component_size)}"
        )
        process_data(
            input_folder=input_path,
            output_folder=output_path,
            prefix=f"fwe2-{component}",
        )
        logging.info("Trimming memory")
        gc.collect()
        client.run(trim_memory)
        logging.info("Done!")

    client.cluster.close()
    client.shutdown()
Processing script: add per-component nemo_id prefixes to the Zyda (no-StarCoder) components.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import logging

from dask.distributed import Client, LocalCluster
from helper import process_data

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)

DATA_BASE = os.environ.get("DATA_BASE")
INPUT_BASE = os.path.join(DATA_BASE, "raw/data/zyda_no_starcoder")
OUTPUT_BASE = os.path.join(DATA_BASE, "processed/zyda-parquet")
# Cast to int: LocalCluster expects an integer worker count.
CPU_WORKERS = int(os.environ.get("CPU_WORKERS"))


if __name__ == "__main__":
    logging.info("Starting Dask cluster")
    cluster = LocalCluster(n_workers=CPU_WORKERS, processes=True, memory_limit="48GB")
    client = Client(cluster)
    logging.info(client)

    components = [
        "zyda_arxiv",
        "zyda_peS2o",
        "zyda_pile-uncopyrighted",
        "zyda_slimpajama",
        "zyda_c4-en",
        "zyda_refinedweb",
    ]

    for component in components:
        input_path = os.path.join(INPUT_BASE, component)
        if not os.path.exists(input_path):
            continue
        output_path = os.path.join(OUTPUT_BASE, component)
        logging.info(f"Processing {component}")
        process_data(
            input_folder=input_path, output_folder=output_path, prefix=component
        )
        logging.info("Done!")

    client.cluster.close()
    client.shutdown()
Fuzzy dedup, MinHash stage: compute 128-permutation MinHash signatures over all processed data on GPU with dask_cudf.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import logging
import time

import dask_cudf

from nemo_curator import MinHash
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
from nemo_curator.utils.file_utils import get_all_files_paths_under

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)


def read_folder(input_folder, columns=["nemo_id", "text"]):
    data_paths = get_all_files_paths_under(input_folder)
    data_paths = [f for f in data_paths if f.endswith(".parquet")]
    data_paths.sort()
    logging.info(f"Number of files being read: {len(data_paths)}")
    text_ddf = dask_cudf.read_parquet(
        data_paths,
        columns=columns,
    )
    return text_ddf


DATA_BASE = os.environ.get("DATA_BASE")
SCHEDULER_FILE = os.environ.get("SCHEDULER_FILE")


if __name__ == "__main__":
    client = get_client(scheduler_file=SCHEDULER_FILE)
    logging.info(f"Number of dask workers: {get_num_workers(client)}")

    minhash_base_output_path = os.path.join(DATA_BASE, "fuzzy/minhash")
    minhash_output_dir = os.path.join(minhash_base_output_path, "data")

    # Relevant parameters
    minhash_id_field = "nemo_id"
    minhash_text_field = "text"
    seed = 10
    minhash_length = 128
    char_ngram = 25
    use_64bit_hash = False

    # Reading all the data
    text_ddf = read_folder(
        input_folder=os.path.join(DATA_BASE, "processed"),
        columns=[minhash_id_field, minhash_text_field],
    )

    # Computing minhashes
    t0 = time.time()
    minhasher = MinHash(
        seed=seed,
        num_hashes=minhash_length,
        char_ngrams=char_ngram,
        use_64bit_hash=use_64bit_hash,
        id_field=minhash_id_field,
        text_field=minhash_text_field,
        cache_dir=minhash_output_dir,
    )
    res = minhasher(DocumentDataset(text_ddf)).df
    logging.info(f"Time taken for MinHash: {time.time() - t0:.2f} sec.")
Fuzzy dedup, LSH stage: convert string ids to integer (dataset_id, doc_id) pairs and bucket the MinHash signatures on GPU.
import os

os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"

import logging
import time

import cudf
import dask_cudf
import numpy as np

from nemo_curator import LSH
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client, get_num_workers
from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import convert_str_id_to_int

logging.basicConfig(format="%(asctime)s: %(message)s", level=logging.INFO)


DATA_BASE = os.environ.get("DATA_BASE")
SCHEDULER_FILE = os.environ.get("SCHEDULER_FILE")


if __name__ == "__main__":
    client = get_client(scheduler_file=SCHEDULER_FILE)
    logging.info(f"Number of dask workers: {get_num_workers(client)}")

    minhash_base_output_path = os.path.join(DATA_BASE, "fuzzy/minhash")
    minhash_output_dir = os.path.join(minhash_base_output_path, "data")

    # Input
    lsh_input_data_path = minhash_output_dir

    # Output
    lsh_base_output_path = os.path.join(DATA_BASE, "fuzzy/lsh")
    lsh_output_dir = os.path.join(lsh_base_output_path, "data")

    # Relevant parameters
    lsh_id_field = "nemo_id"
    minhash_field = "_minhash_signature"
    minhash_length = 128
    num_bands = 8
    buckets_per_shuffle = 8

    t0 = time.time()

    # Load MinHash output and convert string ids to integer (dataset_id, doc_id) pairs
    logging.info("Converting ids")
    df = dask_cudf.read_parquet(lsh_input_data_path, backend="cudf")
    df = df.map_partitions(
        convert_str_id_to_int,
        id_column=lsh_id_field,
        meta=cudf.DataFrame(
            {minhash_field: [[1, 2, 3]], "doc_id": [1], "dataset_id": np.uint32(1)}
        ),
    )

    # Run LSH
    lsh = LSH(
        cache_dir=lsh_output_dir,
        num_hashes=minhash_length,
        num_buckets=num_bands,
        buckets_per_shuffle=buckets_per_shuffle,
        id_fields=["dataset_id", "doc_id"],
        minhash_field=minhash_field,
    )
    res = lsh(DocumentDataset(df))

    t1 = time.time()
    logging.info(f"Time taken for LSH: {t1 - t0} s")