diff --git a/docs/user-guide/download.rst b/docs/user-guide/download.rst index 465e2e4f..c05d9e6f 100644 --- a/docs/user-guide/download.rst +++ b/docs/user-guide/download.rst @@ -36,45 +36,106 @@ By "extraction", we typically mean the process of converting a data format from Common crawl has an S3 bucket and a direct HTTPS endpoint. If you want to use the S3 bucket, ensure you have properly set up your credentials with `s5cmd `_. Otherwise, the HTTPS endpoints will be used with ``wget``. Here is a small example of how to use it: - .. code-block:: python - - from nemo_curator.download import download_common_crawl - - common_crawl = download_common_crawl("/extracted/output/folder", "2020-50", "2021-04", output_type="jsonl") - - * ``"/extracted/output/folder"`` is the path to on your local filesystem where the final extracted files will be placed. - * ``"2020-50"`` is the first common crawl snapshot that will be included in the download. **Note:** Not every year and week has a snapshot. Ensure that your range includes at least one valid Common Crawl snapshot. A list of valid Common Crawl snapshots can be found `here `_. - * ``"2021-04"`` is the last common crawl snapshot that will be included in the download. - * ``output_type="jsonl"`` is the file format that will be used for storing the data on disk. Currently ``"jsonl"`` and ``"parquet"`` are supported. +.. code-block:: python + + import os + from nemo_curator import get_client + from nemo_curator.download import download_common_crawl + from nemo_curator.datasets import DocumentDataset + + def main(): + # Initialize a distributed Dask client + client = get_client(cluster_type="cpu") + + # Parameters for downloading Common Crawl data. + # - output_folder: directory for temporary download/extraction files + # - start_snapshot and end_snapshot define the range to fetch + # - output_type: specifies file format for the extracted data (e.g., "jsonl") + output_folder = "/extracted/output/folder" + start_snapshot = "2020-50" + end_snapshot = "2021-04" + output_type = "jsonl" + os.makedirs(output_folder, exist_ok=True) + + # Download and extract the Common Crawl data. + # The function returns a DocumentDataset that contains the extracted documents. + # Note: The output folder and output type are passed here to store intermediate files + # and check if the data has already been downloaded. They should match the final location + # and format of the extracted data. + common_crawl_dataset = download_common_crawl( + output_folder, start_snapshot, end_snapshot, output_type=output_type + ) + + # Write the extracted dataset to JSON format. + # The 'to_json' method will write one JSON document per line, + # preserving the original shard information if write_to_filename is True. + common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True) + print("Extracted dataset saved to:", output_folder) + + if __name__ == "__main__": + main() + +* ``"/extracted/output/folder"`` is the path to on your local filesystem where the final extracted files will be placed. +* ``"2020-50"`` is the first common crawl snapshot that will be included in the download. **Note:** Not every year and week has a snapshot. Ensure that your range includes at least one valid Common Crawl snapshot. A list of valid Common Crawl snapshots can be found `here `_. +* ``"2021-04"`` is the last common crawl snapshot that will be included in the download. +* ``output_type="jsonl"`` is the file format that will be used for storing the data on disk. Currently ``"jsonl"`` and ``"parquet"`` are supported. You can choose to modify the HTML text extraction algorithm used in ``download_common_crawl``. See an example below. - .. code-block:: python +.. code-block:: python - from nemo_curator.download import ( + import os + from nemo_curator import get_client + from nemo_curator.download import ( ResiliparseExtractor, TrafilaturaExtractor, download_common_crawl, - ) - - # Change the extraction algorithm - extraction_algorithm = ResiliparseExtractor() - # Alternatively - # extraction_algorithm = TrafilaturaExtractor() - - common_crawl = download_common_crawl( - "/extracted/output/folder", - "2020-50", - "2021-04", - output_type="jsonl", - algorithm=extraction_algorithm, - ) - - Above, we changed the extraction algorithm from the default ``JusTextExtractor``. **Note:** Please see the Trafilatura documentation `here `_ and `here `_ for more information about custom Trafilatura extraction parameters. - - The return value ``common_crawl`` will be in NeMo Curator's standard ``DocumentDataset`` format. Check out the function's docstring for more parameters you can use. - - NeMo Curator's Common Crawl extraction process looks like this under the hood: + ) + from nemo_curator.datasets import DocumentDataset + + def main(): + # Initialize a distributed Dask client + client = get_client(cluster_type="cpu") + + # Parameters for downloading Common Crawl data. + # - output_folder: directory for temporary download/extraction files + # - start_snapshot and end_snapshot define the range to fetch + # - output_type: specifies file format for the extracted data (e.g., "jsonl") + output_folder = "/extracted/output/folder" + start_snapshot = "2020-50" + end_snapshot = "2021-04" + output_type = "jsonl" + os.makedirs(output_folder, exist_ok=True) + + # Change the extraction algorithm to Resiliparse + extraction_algorithm = ResiliparseExtractor() + # Alternatively, change the extraction algorithm to Trafilatura + # extraction_algorithm = TrafilaturaExtractor() + + # Download and extract the Common Crawl data using the Resiliparse extraction algorithm. + # The function returns a DocumentDataset that contains the extracted documents. + common_crawl_dataset = download_common_crawl( + output_folder, + start_snapshot, + end_snapshot, + output_type=output_type, + algorithm=extraction_algorithm, + ) + + # Write the extracted dataset to JSON format. + # The 'to_json' method writes one JSON document per line, + # preserving the original shard information if write_to_filename is True. + common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True) + print("Extracted dataset saved to:", output_folder) + + if __name__ == "__main__": + main() + +Above, we changed the extraction algorithm from the default ``JusTextExtractor``. **Note:** Please see the Trafilatura documentation `here `_ + +The return value ``common_crawl`` will be in NeMo Curator's standard ``DocumentDataset`` format. Check out the function's docstring for more parameters you can use. + +NeMo Curator's Common Crawl extraction process looks like this under the hood: 1. Decode the HTML within the record from binary to text. 2. If the HTML can be properly decoded, then with `pyCLD2 `_, perform language detection on the input HTML. diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py index 585bdbc5..bc268e92 100644 --- a/nemo_curator/classifiers/base.py +++ b/nemo_curator/classifiers/base.py @@ -123,10 +123,13 @@ def _run_classifier_helper( prob_col: str = None, ) -> "dask_cudf.DataFrame": - if prob_col: - df[prob_col] = 0 - else: + if prob_col is None: prob_col = "_prob" + labeler = op.Labeler(labels, cols=[prob_col], suffix=label_col) + else: + labeler = op.Labeler( + labels, cols=[prob_col], keep_cols=[prob_col], suffix=label_col + ) columns_to_keep_list = df.columns.to_list() @@ -140,7 +143,7 @@ def _run_classifier_helper( batch_size=batch_size, pred_output_col=prob_col, ), - op.Labeler(labels, cols=[prob_col], suffix=label_col), + labeler, repartition=df.npartitions, keep_cols=columns_to_keep_list, ) diff --git a/nemo_curator/classifiers/prompt_task_complexity.py b/nemo_curator/classifiers/prompt_task_complexity.py index 32db8382..116f9471 100644 --- a/nemo_curator/classifiers/prompt_task_complexity.py +++ b/nemo_curator/classifiers/prompt_task_complexity.py @@ -337,11 +337,15 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df columns_to_keep_list = df.columns.to_list() - df["sliced_text"] = df[self.text_field].str.slice(0, self.max_chars) model = self.model classifier_pipe = op.Sequential( - op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"), + op.Tokenizer( + model, + cols=[self.text_field], + tokenizer_type="default", + max_chars=self.max_chars, + ), op.Predictor( model, sorted_data_loader=True, diff --git a/nemo_curator/download/arxiv.py b/nemo_curator/download/arxiv.py index 538d567d..e7e5327a 100644 --- a/nemo_curator/download/arxiv.py +++ b/nemo_curator/download/arxiv.py @@ -18,6 +18,7 @@ import subprocess import tarfile import tempfile +from typing import Literal, Optional from nemo_curator.datasets import DocumentDataset from nemo_curator.download.doc_builder import ( @@ -218,12 +219,12 @@ def extract(self, content): for file_content in content ) except Exception: - return {}, None + return None # Don't return meta if cleaned_latex_file_str is not None: if len(cleaned_latex_file_str) > 0: - return {}, cleaned_latex_file_str + return {"text": cleaned_latex_file_str} def _clean_tex_file(self, file_content, arg_macros, non_arg_macros): r"""function takes a tex file as input and returns a cleaned version. The @@ -365,25 +366,44 @@ def _build_non_arg_macros_dict(self, file_content): def download_arxiv( output_path: str, - output_type: str = "jsonl", - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + output_type: Literal["jsonl", "parquet"] = "jsonl", + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads Arxiv tar files and extracts them + Download Arxiv tar files and extract the contained LaTeX projects. + + This function obtains a list of Arxiv tar file URLs (via get_arxiv_urls), downloads the tar files, + and then extracts the contained LaTeX source files. The resulting documents (after extraction) are + assembled into a DocumentDataset. Args: - output_path: The path to the root directory of the files - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. + output_path (str): + The root directory where both the final extracted files and the raw download subdirectory will be stored. + The extracted files (in the format specified by output_type) are eventually saved in this directory. + output_type (Literal["jsonl", "parquet"], optional): + The file format/extension used for saving the extracted documents (e.g., "jsonl" or "parquet"). + Default is "jsonl". This is not used for the output file, but is used to check if an extracted output already exists and read it if so. + raw_download_dir (Optional[str], optional): + The directory where the raw downloaded tar files will be kept. If None, a folder named "downloads" + under output_path is used. + keep_raw_download (bool, optional): + If True, the raw tar files (before extraction) are not removed after processing. Default is False. + force_download (bool, optional): + If False, then if an output file already exists for a given URL, re-downloading and re-extraction will be skipped. + Default is False. + url_limit (Optional[int], optional): + Limits the maximum number of Arxiv tar file URLs to download and process. + If None, all available URLs (from get_arxiv_urls) are processed. + record_limit (Optional[int], optional): + Limits the maximum number of records to extract from each tar file. + If None, all available records are extracted. + Returns: + DocumentDataset: + A dataset object containing the extracted documents. """ arxiv_urls = get_arxiv_urls() if url_limit: @@ -416,6 +436,7 @@ def download_arxiv( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/download/commoncrawl.py b/nemo_curator/download/commoncrawl.py index 2e715e11..6e03628a 100644 --- a/nemo_curator/download/commoncrawl.py +++ b/nemo_curator/download/commoncrawl.py @@ -18,6 +18,7 @@ import unicodedata from abc import ABC, abstractmethod from copy import deepcopy +from typing import Literal, Optional from urllib.parse import urlparse import justext @@ -447,48 +448,55 @@ def extract(self, content): if text is not None: if len(text) > 0: text = "\n\n".join(text) - meta = {"language": lang} - return meta, text + meta = {"language": lang, "text": text} + return meta else: - return None, None + return None def download_common_crawl( output_path: str, start_snapshot: str, end_snapshot: str, - output_type: str = "jsonl", + output_type: Literal["jsonl", "parquet"] = "jsonl", algorithm=JusTextExtractor(), - news=False, - aws=False, - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + news: bool = False, + aws: bool = False, + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads Common Crawl WARC snapshots and extracts them using jusText, Resiliparse, or Trafilatura + Downloads Common Crawl WARC snapshots and extracts text content using a specified extraction algorithm. Args: - output_path: The path to the root directory of the files - start_snapshot: The first common crawl snapshot to include. Snapshots must be - specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, - (specified with news=True flag) this changes to Year-Month (YYYY-MM). - end_snapshot: The last common crawl snapshot to include. Must be chronologically - after the starting snapshot. - output_type: The file type to save the data as. - algorithm: A JusTextExtractor, ResiliparseExtractor, or TrafilaturaExtractor object. - news: If True, gets WARC URLs for the CC-NEWS dataset instead of the CC-MAIN datasets. - Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month). - aws: Whether to download from Common Crawl's S3 bucket. If True, uses s5cmd to download. - If False, uses wget. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. + output_path (str): The root directory used for managing download and extraction. + • Raw WARC files are stored in a "downloads" subdirectory under this path. + • This path is also checked for existing extraction results; if found, extraction can be skipped. + • Note: This function returns a DocumentDataset, and writing the extracted data to disk is the caller's responsibility. + start_snapshot (str): Identifier for the earliest snapshot to process. + • For CC-MAIN datasets, use the format 'YYYY-WeekNumber' (e.g., '2020-50' or '2021-04'). + • For CC-NEWS datasets (when news=True), use the 'YYYY-MM' (Year-Month) format. + end_snapshot (str): Identifier for the latest snapshot to process, which must be chronologically after start_snapshot. + output_type (Literal["jsonl", "parquet"]): The file format for the extracted output. Must be either "jsonl" or "parquet". + • This is not used for the output file, but is used to check if an extracted output already exists. + algorithm: The text extraction algorithm instance to use for HTML processing. + • This can be a JusTextExtractor (default), ResiliparseExtractor, or TrafilaturaExtractor object. + news (bool): When True, indicates that URLs should be retrieved from the CC-NEWS dataset. + • This also means snapshot identifiers should follow the 'YYYY-MM' format. + aws (bool): If True, downloads are sourced from Common Crawl's S3 bucket using s5cmd; + • If False, wget is used to fetch the files via HTTPS. + raw_download_dir: Optional; the directory to temporarily store raw WARC files. + • If not provided, defaults to a "downloads" folder within output_path. + keep_raw_download (bool): If True, retains the downloaded raw WARC files after extraction. + • If False, these raw files may be removed following extraction. + force_download (bool): If False, skips re-downloading or re-extracting snapshots if outputs already exist in output_path. + url_limit: Optional; the maximum number of WARC files to download from the snapshot range. + • If None, all available files within the specified snapshots are downloaded. + record_limit: Optional; the maximum number of records to extract from each WARC file. + • If None, all available records are extracted. """ common_crawl_urls = get_common_crawl_urls( starting_snapshot=start_snapshot, ending_snapshot=end_snapshot, news=news @@ -538,6 +546,7 @@ def download_common_crawl( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/download/doc_builder.py b/nemo_curator/download/doc_builder.py index 5ad3094e..9d3e849b 100644 --- a/nemo_curator/download/doc_builder.py +++ b/nemo_curator/download/doc_builder.py @@ -15,17 +15,14 @@ import importlib import os from abc import ABC, abstractmethod -from typing import List, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import dask.dataframe as dd import pandas as pd from dask import compute, delayed from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.distributed_utils import ( - read_single_partition, - single_partition_write_with_filename, -) +from nemo_curator.utils.distributed_utils import read_single_partition class DocumentDownloader(ABC): @@ -108,14 +105,36 @@ def _download_and_extract_single_partition( downloader: DocumentDownloader, iterator: DocumentIterator, extractor: DocumentExtractor, - output_type: str, + output_type: Literal["jsonl", "parquet"], keep_raw_download: bool, force_download: bool, input_meta: Union[str, dict] = None, filename_col: str = "file_name", + record_limit: Optional[int] = None, ) -> pd.DataFrame: + """ + Downloads a single partition from a URL and extracts its contents in-memory without writing + the extracted output to disk. The provided output_path is used only to check if an extracted + output already exists and, if so, skips re-downloading and extraction. + + Parameters: + paths (Tuple[str, str]): A tuple (url, output_path) where 'url' is the source URL and + 'output_path' is the expected location of a previously extracted output. + downloader (DocumentDownloader): An object to download the content from the URL. + iterator (DocumentIterator): An object to iterate over records in the downloaded file. + extractor (DocumentExtractor): An object to extract the desired content from each record. + output_type (Literal["jsonl", "parquet"]): The output file format/extension. Must be either "jsonl" or "parquet". Defaults to "jsonl". This parameter is only used to verify whether an extracted output already exists. + keep_raw_download (bool): If False, deletes the raw download file after extraction. + force_download (bool): If False and output_path exists, skips downloading and extraction. + input_meta (Union[str, dict], optional): Metadata describing the input file's structure. + filename_col (str, optional): Name of the column to store the filename within the result DataFrame. + record_limit (int, optional): Limit the number of records to extract from each file. + Returns: + pd.DataFrame: A DataFrame containing the extracted records. + """ url, output_path = paths + # If an extracted output already exists and we're not forcing a download, load and return it. if os.path.exists(output_path) and not force_download: partition = read_single_partition( [output_path], @@ -125,31 +144,26 @@ def _download_and_extract_single_partition( ) return partition + # Download the file and extract its records in memory. downloaded_file = downloader.download(url) records = [] - # Iterate over all records in file for item in iterator.iterate(downloaded_file): + if record_limit is not None and len(records) >= record_limit: + break record_meta, content = item - # Extract the text from the record extracted = extractor.extract(content) if extracted is not None: - text_meta, text = extracted - if text is not None: - line = { - "text": text, - **text_meta, - **record_meta, - } - records.append(line) + # Merge the extracted data and record metadata into one dictionary. + line = {**extracted, **record_meta} + records.append(line) partition = pd.DataFrame(records) - filename = os.path.basename(output_path) - output_dir = os.path.dirname(output_path) - partition[filename_col] = filename - single_partition_write_with_filename( - partition, output_dir, output_type=output_type, filename_col=filename_col - ) - if not keep_raw_download: + # Add a filename column for consistency using the basename of the output_path. + partition[filename_col] = os.path.basename(output_path) + + # Since we're not writing the extracted partition to disk, the output_path is not used here. + # Clean up the raw downloaded file if it's not meant to be kept. + if not keep_raw_download and os.path.exists(downloaded_file): os.remove(downloaded_file) return partition @@ -162,42 +176,76 @@ def download_and_extract( iterator: DocumentIterator, extractor: DocumentExtractor, output_format: dict, - output_type: str = "jsonl", - keep_raw_download=False, - force_download=False, + output_type: Literal["jsonl", "parquet"] = "jsonl", + keep_raw_download: bool = False, + force_download: bool = False, input_meta: Union[str, dict] = None, filename_col: str = "file_name", + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads and extracts a dataset into a format accepted by the NeMo Curator + Download files from the given URLs, extract their records, and + construct a DocumentDataset. + + For each URL provided, this function downloads the corresponding + file (unless an extracted output already exists and force_download is + False), iterates over its records, extracts the desired content, and + finally converts all records into a DocumentDataset. Args: - urls: A list of urls to download the dataset from - output_paths: A list of paths to save the final extracted output to. - The raw output of the downloader will be saved using the path given by downloader.download(url). - downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage - iterator: A DocumentIterator that handles iterating through the downloaded file's format - extractor: A DocumentExtractor that handles extracting the data from its raw format into text - output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction. - output_type: The file type to save the dataset as. - keep_raw_download: Whether to keep the pre-extracted download file. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - input_meta: A dictionary or a string formatted as a dictionary, which outlines - the field names and their respective data types within the JSONL input file. - filename_col : The name of the column that contains the filename. Default is "filename_col" + urls (List[str]): + A list of URLs from which to download dataset files. + output_paths (List[str]): + A list of file paths where the extracted outputs should be + found. If a file already exists at a given path and force_download + is False, that partition is skipped. + downloader (DocumentDownloader): + The downloader instance responsible for fetching files from + the specified URLs. + iterator (DocumentIterator): + The iterator instance used to traverse the downloaded file + and yield records. + extractor (DocumentExtractor): + The extractor instance used to obtain the desired content from + each record. + output_format (dict): + A dictionary mapping column names to the data types for the + extracted records. + output_type (Literal["jsonl", "parquet"], optional): + The output file format/extension. Must be either "jsonl" or "parquet". + Defaults to "jsonl". This parameter is only used to verify whether + an extracted output already exists. + keep_raw_download (bool, optional): + If True, the raw downloaded files are retained after extraction. + Defaults to False. + force_download (bool, optional): + If False and an output file already exists at a given path, the + download and extraction for that file are skipped. + Defaults to False. + input_meta (Union[str, dict], optional): + Optional metadata describing the input file's schema. + Defaults to None. + filename_col (str, optional): + The name for the column in the resulting dataset that records + the basename of the output file. Defaults to "file_name". + record_limit (int, optional): Limit the number of records to extract from each file. + Defaults to None. Returns: - A DocumentDataset of the downloaded data + DocumentDataset: + A dataset composed of the records extracted from the downloaded + files. """ - if len(urls) == 0: - raise ValueError("No urls were provided to download") - + # Validate parameters + if not urls: + raise ValueError("No URLs were provided to download") if len(urls) != len(output_paths): raise ValueError( - f"Different number of urls and output_paths. {len(urls)} urls vs {len(output_paths)} output_paths" + f"Different number of URLs and output_paths. {len(urls)} URLs vs {len(output_paths)} output_paths" ) + # Ensure consistent ordering of output_format keys. output_format = dict(sorted(output_format.items())) + df = dd.from_map( _download_and_extract_single_partition, zip(urls, output_paths), @@ -210,6 +258,7 @@ def download_and_extract( enforce_metadata=False, input_meta=input_meta, filename_col=filename_col, + record_limit=record_limit, meta=output_format, ) diff --git a/nemo_curator/download/wikipedia.py b/nemo_curator/download/wikipedia.py index 54b85d45..e494ed38 100644 --- a/nemo_curator/download/wikipedia.py +++ b/nemo_curator/download/wikipedia.py @@ -18,6 +18,7 @@ import re import subprocess import xml.etree.cElementTree as etree +from typing import Literal, Optional from urllib.parse import quote, urlparse import mwparserfromhell @@ -742,35 +743,49 @@ def try_remove_obj(obj, section): ) ) # Don't return any meta here - return {}, "\n\n".join(section_text) + return {"text": "\n\n".join(section_text)} def download_wikipedia( output_path: str, language: str = "en", - dump_date=None, - output_type: str = "jsonl", - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + dump_date: Optional[str] = None, + output_type: Literal["jsonl", "parquet"] = "jsonl", + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell + Downloads and extracts articles from a Wikipedia dump. + + This function retrieves a list of Wikipedia dump URLs for the specified language and dump date, + downloads the compressed bz2 dump file (if it is not already present), and extracts its articles + using mwparserfromhell. The resulting articles are saved in the specified output format (e.g., "jsonl") + along with relevant metadata. Args: - output_path: The path to the root directory of the files - language: The language of the Wikipedia articles to download - dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. - If None, latest dump is used. - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the bz2 files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. + output_path (str): The root directory where the final extracted files and intermediate outputs + (if any) are stored. + language (str, optional): The language code for the Wikipedia dump to download. Default is "en". + dump_date (Optional[str], optional): The dump date in "YYYYMMDD" format. If None, the latest + available dump is downloaded. + output_type (Literal["jsonl", "parquet"], optional): The file format/extension for saving the extracted documents (e.g., "jsonl"). + Defaults to "jsonl". This is not used for the output file, but is used to check if an extracted output + already exists and read it if so. + raw_download_dir (Optional[str], optional): Directory used for temporary storage of raw bz2 dump files. + If None, a subdirectory named "downloads" under output_path is used. + keep_raw_download (bool, optional): If True, retains the raw bz2 files after extraction. + Default is False. + force_download (bool, optional): If False, skips re-downloading or re-extracting files that already exist. + url_limit (Optional[int], optional): The maximum number of dump file URLs to process. If None, all + available URLs are processed. + record_limit (Optional[int], optional): Limit the number of records to extract from each file. + If None, all available records are extracted. + + Returns: + DocumentDataset: A dataset object containing the extracted Wikipedia articles along with associated metadata. """ wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date) if url_limit: @@ -812,6 +827,7 @@ def download_wikipedia( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py b/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py index 582ec4c5..756584f3 100644 --- a/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py +++ b/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py b/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py index 112453a2..d58867ce 100644 --- a/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py +++ b/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo_curator/scripts/download_and_extract.py b/nemo_curator/scripts/download_and_extract.py index ef5de89d..aff5fc4d 100644 --- a/nemo_curator/scripts/download_and_extract.py +++ b/nemo_curator/scripts/download_and_extract.py @@ -17,7 +17,7 @@ from nemo_curator.download.doc_builder import batch_download, download_and_extract from nemo_curator.utils.config_utils import build_downloader -from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.distributed_utils import get_client, write_to_disk from nemo_curator.utils.file_utils import ( expand_outdir_and_mkdir, get_all_files_paths_under, @@ -74,10 +74,16 @@ def main(args): keep_raw_download=args.keep_downloaded_files, force_download=args.overwrite_existing_json, input_meta=args.input_meta, + record_limit=args.record_limit, ) # Sample to trigger the dask computation - sample = dataset.df.sample(frac=10 / len(dataset)).compute() + write_to_disk( + dataset.df, + args.output_json_dir, + write_to_filename=True, + output_type="jsonl", + ) def attach_args( @@ -149,6 +155,12 @@ def attach_args( default=None, help="Output directory to store the extracted text in JSONL files.", ) + parser.add_argument( + "--record-limit", + type=int, + default=None, + help="Limit the number of records to extract from each file.", + ) ArgumentHelper.attach_bool_arg( parser, "overwrite-existing-json", diff --git a/pyproject.toml b/pyproject.toml index 0c7b37ac..77f90b63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,20 +78,20 @@ dynamic = ["version"] [project.optional-dependencies] # Installs CPU + GPU text curation modules cuda12x = [ - "cudf-cu12>=24.12", - "cugraph-cu12>=24.12", - "cuml-cu12>=24.12", - "dask-cuda>=24.12", - "dask-cudf-cu12>=24.12", + "cudf-cu12>=25.02", + "cugraph-cu12>=25.02", + "cuml-cu12>=25.02", + "dask-cuda>=25.02", + "dask-cudf-cu12>=25.02", "spacy[cuda12x]>=3.6.0, <3.8.0", ] # Installs CPU + GPU text curation modules with RAPIDS Nightlies cuda12x_nightly = [ - "cudf-cu12>=25.02.0a0,<=25.02", - "cugraph-cu12>=25.02.0a0,<=25.02", - "cuml-cu12>=25.02.0a0,<=25.02", - "dask-cuda>=25.02.0a0,<=25.02", - "dask-cudf-cu12>=25.02.0a0,<=25.02", + "cudf-cu12>=25.04.0a0,<=25.04", + "cugraph-cu12>=25.04.0a0,<=25.04", + "cuml-cu12>=25.04.0a0,<=25.04", + "dask-cuda>=25.04.0a0,<=25.04", + "dask-cudf-cu12>=25.04.0a0,<=25.04", "spacy[cuda12x]>=3.6.0, <3.8.0", ] # Installs CPU + GPU text and image curation modules diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py index d6d2852c..81b1112e 100644 --- a/tests/test_classifiers.py +++ b/tests/test_classifiers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -139,9 +139,6 @@ def test_fineweb_edu_classifier(gpu_client, domain_dataset): assert result_pred.equals(expected_pred) -@pytest.mark.skip( - reason="Skipping until https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier is published" -) @pytest.mark.gpu def test_fineweb_mixtral_classifier(gpu_client, domain_dataset): from nemo_curator.classifiers import FineWebMixtralEduClassifier @@ -155,9 +152,6 @@ def test_fineweb_mixtral_classifier(gpu_client, domain_dataset): assert result_pred.equals(expected_pred) -@pytest.mark.skip( - reason="Skipping until https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier is published" -) @pytest.mark.gpu def test_fineweb_nemotron_classifier(gpu_client, domain_dataset): from nemo_curator.classifiers import FineWebNemotronEduClassifier diff --git a/tests/test_download.py b/tests/test_download.py index b4dd1be5..51eb4e31 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,4 +1,21 @@ -from pathlib import Path +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import bz2 +import os +import subprocess +import tarfile +from urllib.parse import urlparse import pytest @@ -7,13 +24,38 @@ TrafilaturaExtractor, download_and_extract, ) +from nemo_curator.download.arxiv import ArxivDownloader, ArxivExtractor, ArxivIterator from nemo_curator.download.commoncrawl import ( CommonCrawlWARCDownloader, CommonCrawlWARCExtractor, CommonCrawlWARCIterator, + JusTextExtractor, + ResiliparseExtractor, get_common_crawl_urls, get_stop_list_dict, ) +from nemo_curator.download.wikipedia import ( + WikipediaDownloader, + WikipediaExtractor, + WikipediaIterator, +) + + +class DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class FakeCompletedProcess: + def __init__(self): + self.returncode = 0 + + +def fake_run_success(cmd, stdout, stderr): + return FakeCompletedProcess() @pytest.fixture @@ -107,66 +149,18 @@ def test_trafilatura_extract_text(self, html_string): assert result == expected - def test_common_crawl_urls(self): - start_snapshot = "2021-04" - end_snapshot = "2021-10" - urls = get_common_crawl_urls(start_snapshot, end_snapshot) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz" - ) - assert len(urls) == 143840 - def test_incorrect_snapshot_order(self): with pytest.raises(ValueError): end_snapshot = "2021-04" start_snapshot = "2021-10" urls = get_common_crawl_urls(start_snapshot, end_snapshot) - def test_common_crawl_news_urls(self): - start_snapshot = "2021-04" - end_snapshot = "2021-10" - urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/04/CC-NEWS-20210401004522-01022.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/10/CC-NEWS-20211031225258-00089.warc.gz" - ) - assert len(urls) == 3838 - def test_incorrect_snapshot_order_news(self): with pytest.raises(ValueError): end_snapshot = "2021-04" start_snapshot = "2021-10" urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True) - @pytest.mark.skip( - reason="Skipping until we figure out how to get this to a non flaky state" - ) - def test_uneven_common_crawl_range(self): - start_snapshot = "2021-03" - end_snapshot = "2021-11" - urls = get_common_crawl_urls(start_snapshot, end_snapshot) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz" - ) - assert len(urls) == 143840 - def test_no_urls(self): with pytest.raises(ValueError): output_format = { @@ -194,3 +188,319 @@ def test_url_path_mismatch(self): CommonCrawlWARCExtractor(), output_format, ) + + +class TestWikipedia: + def test_wikipedia_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary directory and simulate an already-downloaded file. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + + url = "https://en.wikipedia.org/dummy-file" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "dummy-file" + file_path = os.path.join(str(download_dir), output_name) + + # Write a dummy file to simulate an existing download. + with open(file_path, "w") as f: + f.write("existing content") + + downloader = WikipediaDownloader(str(download_dir), verbose=False) + + # Monkey-patch subprocess.run (should not be called since file exists). + monkeypatch.setattr(subprocess, "run", fake_run_success) + + result = downloader.download(url) + assert result == file_path + + def test_wikipedia_downloader_new_file(self, tmp_path, monkeypatch): + download_dir = tmp_path / "downloads" + download_dir.mkdir() + + url = "https://en.wikipedia.org/new-file" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "new-file" + file_path = os.path.join(str(download_dir), output_name) + + # Ensure the file does not exist. + if os.path.exists(file_path): + os.remove(file_path) + + downloader = WikipediaDownloader(str(download_dir), verbose=False) + downloader._lock = DummyLock() + + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + assert called_run + + def test_wikipedia_iterator(self, tmp_path): + # Create a minimal valid XML resembling a Wikipedia dump with one page. + xml_content = """ + + + Test Article + 0 + 123 + + Test content with [[link]] + + +""" + # Compress the XML content using bz2. + compressed_data = bz2.compress(xml_content.encode("utf-8")) + + # Write the compressed data to a temporary file. + temp_file = tmp_path / "test_wiki.xml.bz2" + temp_file.write_bytes(compressed_data) + + iterator = WikipediaIterator(language="en") + pages = list(iterator.iterate(str(temp_file))) + + assert len(pages) == 1 + metadata, raw_text = pages[0] + assert metadata["title"] == "Test Article" + assert metadata["id"] == "123" + # The URL is constructed by quoting the title. + expected_url = "https://en.wikipedia.org/wiki/Test%20Article" + assert metadata["url"] == expected_url + assert "Test content with" in raw_text + + def test_wikipedia_extractor(self): + extractor = WikipediaExtractor(language="en") + # Sample wiki markup; note the presence of a heading and a magic word. + content = "== Heading ==\nThis is a sample article. __NOTOC__" + result = extractor.extract(content) + + # # The extractor should return a dict with a "text" key. + assert isinstance(result, dict) + extracted_text = result.get("text", "") + # Verify that the magic word was removed. + assert "__NOTOC__" not in extracted_text + # Verify that the main content appears. + assert "This is a sample article." in extracted_text + + +class TestArxiv: + def test_arxiv_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary download directory and simulate an already-downloaded tar file. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + tar_filename = "dummy.tar" + file_path = os.path.join(str(download_dir), tar_filename) + # Write dummy content to simulate an existing download. + with open(file_path, "w") as f: + f.write("existing content") + + downloader = ArxivDownloader(str(download_dir), verbose=False) + # Monkey-patch subprocess.run (should not be called since file exists). + monkeypatch.setattr(subprocess, "run", fake_run_success) + result = downloader.download(tar_filename) + assert result == file_path + + def test_arxiv_downloader_new_file(self, tmp_path, monkeypatch): + # Create a temporary download directory and ensure the tar file does not exist. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + tar_filename = "dummy.tar" + file_path = os.path.join(str(download_dir), tar_filename) + if os.path.exists(file_path): + os.remove(file_path) + + downloader = ArxivDownloader(str(download_dir), verbose=False) + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + result = downloader.download(tar_filename) + assert result == file_path + assert called_run + + def test_arxiv_iterator(self, tmp_path): + # Create an inner tar archive containing a .tex file. + inner_tar_path = tmp_path / "2103.00001.tar" + dummy_tex_filename = "2103.00001.tex" + dummy_tex_content = "This is a dummy LaTeX content." + with tarfile.open(inner_tar_path, "w") as inner_tar: + # Create a temporary tex file to add into the inner tar archive. + temp_tex_path = tmp_path / dummy_tex_filename + with open(temp_tex_path, "w") as f: + f.write(dummy_tex_content) + inner_tar.add(temp_tex_path, arcname=dummy_tex_filename) + + # Create an outer tar archive that contains the inner tar archive. + outer_tar_path = tmp_path / "dummy_main.tar" + with tarfile.open(outer_tar_path, "w") as outer_tar: + outer_tar.add(inner_tar_path, arcname="2103.00001.tar") + + iterator = ArxivIterator(log_frequency=1) + results = list(iterator.iterate(str(outer_tar_path))) + # Expect one paper extracted. + assert len(results) == 1 + metadata, tex_files = results[0] + # The ArxivIterator extracts the arxiv id from the inner archive's filename. + assert metadata["id"] == "2103.00001" + # The source_id is set to the outer tar file's basename. + assert metadata["source_id"] == "dummy_main.tar" + # Verify that the tex extraction returns the dummy content. + assert isinstance(tex_files, list) + assert dummy_tex_content in tex_files[0] + + def test_arxiv_extractor(self): + extractor = ArxivExtractor() + # Create a minimal LaTeX document including comments and a section header. + content = r""" + % This is a comment line that should be removed. + \section{Introduction} + This is the introduction of the paper. + % Another comment that should vanish. + """ + result = extractor.extract([content]) + assert isinstance(result, dict) + extracted_text = result.get("text", "") + # Verify that comments have been removed. + assert "% This is a comment" not in extracted_text + # Verify that the section header content is retained. + assert "Introduction" in extracted_text + assert "This is the introduction" in extracted_text + + +class TestCommonCrawl: + def test_common_crawl_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary downloads directory and simulate an already-downloaded file. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + url = "http://dummy/commoncrawl.warc" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc" + file_path = os.path.join(str(download_dir), output_name) + # Write dummy content to simulate an existing download. + with open(file_path, "w") as f: + f.write("existing content") + + downloader = CommonCrawlWARCDownloader( + str(download_dir), aws=False, verbose=False + ) + + # Monkey-patch subprocess.run to track if it gets called. + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + # Since the file already exists, no download should be attempted. + assert not called_run + + def test_common_crawl_downloader_new_file(self, tmp_path, monkeypatch): + # Create a temporary downloads directory; ensure the file does not exist. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + url = "http://dummy/commoncrawl.warc" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc" + file_path = os.path.join(str(download_dir), output_name) + if os.path.exists(file_path): + os.remove(file_path) + + downloader = CommonCrawlWARCDownloader( + str(download_dir), aws=False, verbose=False + ) + + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + # Since the file did not exist, a download call (and subprocess.run) should have been made. + assert called_run + + def test_common_crawl_iterator(self, tmp_path): + # Create a minimal valid WARC file with a single "response" record. + raw_warc_path = tmp_path / "dummy.warc" + http_response = ( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/html\r\n" + "\r\n" + "

Common Crawl test paragraph with some content.

\r\n" + ) + http_response_bytes = http_response.encode("utf-8") + content_length = len(http_response_bytes) + warc_record = ( + ( + f"WARC/1.0\r\n" + f"WARC-Type: response\r\n" + f"WARC-Record-ID: \r\n" + f"WARC-Date: 2022-01-01T00:00:00Z\r\n" + f"WARC-Target-URI: http://example.com\r\n" + f"Content-Length: {content_length}\r\n" + f"\r\n" + ).encode("utf-8") + + http_response_bytes + + b"\r\n\r\n" + ) + raw_warc_path.write_bytes(warc_record) + + iterator = CommonCrawlWARCIterator(log_frequency=1) + records = list(iterator.iterate(str(raw_warc_path))) + assert len(records) == 1 + meta, content = records[0] + # Check that the URL from the header is captured. + assert "example.com" in meta["url"] + # Verify that the content includes our test paragraph. + assert b"Common Crawl test paragraph" in content + + def test_common_crawl_extractor_justext(self): + extractor = CommonCrawlWARCExtractor(algorithm=JusTextExtractor()) + html = ( + "

Common Crawl test paragraph for justext extractor. " + "Four score and seven years ago our fathers brought forth on this continent a new nation, " + "conceived in liberty, and dedicated to the proposition that all men are created equal.

" + ) + content = html.encode("utf-8") + result = extractor.extract(content) + print(result) + assert result is not None + # The extracted text should include our test paragraph. + assert "Common Crawl test paragraph for justext extractor." in result["text"] + assert "language" in result + + def test_common_crawl_extractor_resiliparse(self): + extractor = CommonCrawlWARCExtractor(algorithm=ResiliparseExtractor()) + html = ( + "

Common Crawl test paragraph for resiliparse extractor. " + "Four score and seven years ago our fathers brought forth on this continent a new nation, " + "conceived in liberty, and dedicated to the proposition that all men are created equal.

" + ) + content = html.encode("utf-8") + result = extractor.extract(content) + print(result) + assert result is not None + assert ( + "Common Crawl test paragraph for resiliparse extractor." in result["text"] + ) + assert "language" in result diff --git a/tests/test_read_data.py b/tests/test_read_data.py index 0c9a2aa5..d362fb7e 100644 --- a/tests/test_read_data.py +++ b/tests/test_read_data.py @@ -540,6 +540,7 @@ def test_read_data_different_columns_files_per_partition( assert len(df) == NUM_FILES * NUM_RECORDS +@pytest.mark.skip(reason="Parquet tests are failing after upgrading to RAPIDS 25.02") @pytest.mark.parametrize( "backend,file_type", [ diff --git a/tutorials/distributed_data_classification/fineweb-edu-ensemble-classification.ipynb b/tutorials/distributed_data_classification/fineweb-edu-ensemble-classification.ipynb new file mode 100644 index 00000000..b74972a1 --- /dev/null +++ b/tutorials/distributed_data_classification/fineweb-edu-ensemble-classification.ipynb @@ -0,0 +1,1311 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8f61f035-ab4e-4713-86d5-bb34bc0e8d75", + "metadata": {}, + "source": [ + "# Distributed Data Classification Using NeMo Curator: \n", + "### Ensembling `FineWeb Mixtral Educational Classifier`, `FineWeb Nemotron-4 Educational Classifier`, and `fasttext-oh-eli5`\n", + "\n", + "This notebook demonstrates distributed data classification by ensembling:\n", + "1. NeMo Curator’s [`FineWebMixtralEduClassifier`](https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier)\n", + "2. NeMo Curator’s [`FineWebNemotronEduClassifier`](https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier)\n", + "3. Fast Text's [`fasttext-oh-eli5`](https://huggingface.co/mlfoundations/fasttext-oh-eli5) from Hugging Face.\n", + "\n", + "The FineWeb educational classifiers (excluding FastText) leverage [CrossFit](https://github.com/rapidsai/crossfit), a RAPIDS-accelerated library for intelligent batching, to enhance offline inference performance on large datasets.\n", + "\n", + "Before running this notebook, follow the [Getting Started](https://github.com/NVIDIA/NeMo-Curator?tab=readme-ov-file#get-started) guide to install NeMo Curator.\n", + "\n", + "##### **Note on Curating Nemotron-CC**\n", + "This notebook showcases the classification script that was used in curating **Nemotron-CC**, a refined long-horizon pretraining dataset for large language models. As detailed in the paper [\"Nemotron-CC: Transforming Common Crawl into a Refined Long-Horizon Pretraining Dataset\"](https://arxiv.org/abs/2412.02595), Nemotron-CC was designed to improve the trade-off between dataset quality and quantity using a combination of **classifier ensembling, synthetic data rephrasing, and reduced reliance on heuristic filters**.\n", + "\n", + "By leveraging these techniques, **8B parameter models trained on 1T tokens with a high-quality subset of Nemotron-CC** achieved an **MMLU improvement of 5.6** over DCLM, demonstrating significant gains in benchmark performance. Furthermore, **Nemotron-CC’s full dataset (6.3T tokens)** provides **4× more unique real tokens than DCLM**, making it particularly effective for long-token-horizon training, such as 15T-token-scale LLMs.\n", + "\n", + "The dataset is publicly available at [Nemotron-CC](https://data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/index.html).\n", + "\n", + " \n", + "## Steps in This Notebook \n", + "1. **Compute floating-point classification scores** for each classifier. \n", + "2. **Determine percentile-based score thresholds** to categorize results. \n", + "3. **Convert floating-point scores to integer scores** (0-19 scale). \n", + "4. **Ensemble the results** using the maximum classifier score. \n", + "5. **Store results** in directories or cloud buckets based on classification scores.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c1882d2c-e2c1-4a59-9f9c-c12a76e9e04c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTHONWARNINGS=ignore\n" + ] + } + ], + "source": [ + "# Silence Warnings (HuggingFace internal warnings)\n", + "\n", + "%env PYTHONWARNINGS=ignore\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9ef3ef29-ab79-4fea-9050-017b9e9203dd", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import fasttext\n", + "import pandas as pd\n", + "import cudf\n", + "import dask_cudf\n", + "import numpy as np\n", + "import cupy as cp\n", + "from pathlib import Path\n", + "from typing import Optional, Tuple, Any, Dict, List\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "from nemo_curator import get_client\n", + "from nemo_curator.classifiers import FineWebNemotronEduClassifier, FineWebMixtralEduClassifier\n", + "from nemo_curator.datasets import DocumentDataset\n", + "from nemo_curator.utils.distributed_utils import load_object_on_worker\n", + "from nemo_curator.utils.distributed_utils import get_device_total_memory" + ] + }, + { + "cell_type": "markdown", + "id": "325f2af0-c7a2-488b-8fb6-d35623159f06", + "metadata": {}, + "source": [ + "### Initializing NeMo Curator Client\n", + "This step initializes the NeMo Curator client to enable distributed classification using GPU-based processing." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "374004c9-fd63-490f-bc81-875fc2f15ae9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDF Spilling is enabled\n" + ] + } + ], + "source": [ + "client = get_client(cluster_type=\"gpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "ab00c794-3655-44ee-be33-108958c01f43", + "metadata": {}, + "source": [ + "### Setting Output File Paths\n", + "Defines the paths where classification results, threshold values, and final bucketed results will be stored." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "47ca63af-78d2-4854-bb23-c6461b74d23e", + "metadata": {}, + "outputs": [], + "source": [ + "# Define output directories\n", + "OUTPUT_BASE_DIR = \"output_data_dir/\"\n", + "OUTPUT_CLASSIFICATION_RESULTS = os.path.join(OUTPUT_BASE_DIR, \"classification_results\")\n", + "OUTPUT_CLASSIFIER_THRESHOLDS = os.path.join(OUTPUT_BASE_DIR, \"classifier_thresholds.json\")\n", + "OUTPUT_BUCKETED_RESULTS = os.path.join(OUTPUT_BASE_DIR, \"bucketed_results\")" + ] + }, + { + "cell_type": "markdown", + "id": "15d6977b-885a-4029-a868-bc6d336085ed", + "metadata": {}, + "source": [ + "# Preparing Text Data for Classification\n", + "- We create a sample dataset with diverse topics.\n", + "- Optionally, users can provide a directory containing JSONL files for classification." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a0eb4676-44e6-41ab-abb6-413f78bc9787", + "metadata": {}, + "outputs": [], + "source": [ + "# Create sample DataFrame\n", + "text = [\n", + " \"Quantum computing is set to revolutionize the field of cryptography.\",\n", + " \"Investing in index funds is a popular strategy for long-term financial growth.\",\n", + " \"Recent advancements in gene therapy offer new hope for treating genetic disorders.\",\n", + " \"Online learning platforms have transformed the way students access educational resources.\",\n", + " \"Traveling to Europe during the off-season can be a more budget-friendly option.\",\n", + " \"Training regimens for athletes have become more sophisticated with the use of data analytics.\",\n", + " \"Streaming services are changing the way people consume television and film content.\",\n", + " \"Vegan recipes have gained popularity as more people adopt plant-based diets.\",\n", + " \"Climate change research is critical for developing sustainable environmental policies.\",\n", + " \"Telemedicine has become increasingly popular due to its convenience and accessibility.\",\n", + "]\n", + "df = cudf.DataFrame({\"text\": text})\n", + "input_dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))\n", + "write_to_filename = False\n", + "\n", + "# Alternatively, read existing directory of JSONL files\n", + "# input_file_path=\"/input_data_dir/\"\n", + "# input_dataset = DocumentDataset.read_json(\n", + "# input_file_path, backend=\"cudf\", add_filename=True\n", + "# )\n", + "# write_to_filename = True" + ] + }, + { + "cell_type": "markdown", + "id": "56b43d1a-7954-48b0-9c39-fe07c3ca06dc", + "metadata": {}, + "source": [ + "# Step 1: Run the Classifiers\n", + "\n", + "1. Compute the floating-point classification score for each classifier.\n", + "\n", + "**Note:** Dask operations are lazy, meaning the classifiers won’t execute until an eager operation like `to_json`, `compute`, or `persist` is called." + ] + }, + { + "cell_type": "markdown", + "id": "16962500-d2a4-4a40-8804-e7accd44abf5", + "metadata": {}, + "source": [ + "### FastText Quality Classifier\n", + "\n", + "The **FastText Quality Classifier** uses the [`fasttext-oh-eli5`](https://huggingface.co/mlfoundations/fasttext-oh-eli5) model from Hugging Face to assess text quality. It distinguishes **high-quality** (`__label__hq`) responses from lower-quality ones (`__label__cc`). \n", + "\n", + "NeMo Curator allows users to define custom modules like this, enabling seamless integration of specialized models. \n", + "\n", + "- **Model:** [`mlfoundations/fasttext-oh-eli5`](https://huggingface.co/mlfoundations/fasttext-oh-eli5) \n", + "- **Training Data:** Reddit ELI5 vs. Wikipedia (200k examples) \n", + "- **Output:** Confidence score + optional binary classification (where 1 represents high quality text and 0 represents low quality text) \n", + "\n", + "🔗 **More details:** [Hugging Face Model Card](https://huggingface.co/mlfoundations/fasttext-oh-eli5)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "55b09d94-abe6-4c14-aa2a-5302ca0a7f4b", + "metadata": {}, + "outputs": [], + "source": [ + "class FastTextQualityClassifier:\n", + " \"\"\"\n", + " A classifier that uses a fastText model to predict a confidence score for text.\n", + "\n", + " It appends one or two output columns to the data:\n", + " - A float column representing the confidence score.\n", + " - Optionally, an integer column (1 if the top label contains \"hq\", else 0).\n", + "\n", + " The model is loaded from the Hugging Face Hub during initialization.\n", + "\n", + " Args:\n", + " pred_column (str): Name of the output column for the confidence score.\n", + " int_column (str, optional): Name of the output column for the binary indicator.\n", + " If not provided, only the pred_column is added.\n", + " \"\"\"\n", + "\n", + " def __init__(self, pred_column: str, int_column: Optional[str] = None) -> None:\n", + " self.pred_column: str = pred_column\n", + " self.int_column: Optional[str] = int_column\n", + "\n", + " self.repo_id: str = \"mlfoundations/fasttext-oh-eli5\"\n", + " self.model_filename: str = \"openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train.bin\"\n", + " # Download the fastText model from Hugging Face Hub.\n", + " self.model_path: str = hf_hub_download(repo_id=self.repo_id, filename=self.model_filename)\n", + " self.model_identifier: str = f\"{self.repo_id}/{self.model_filename}\"\n", + "\n", + " def _load_fasttext_model(self) -> Any:\n", + " \"\"\"Load and return the fastText model.\"\"\"\n", + " return fasttext.load_model(self.model_path)\n", + "\n", + " def predict_text(self, text: str) -> Tuple[float, int]:\n", + " \"\"\"\n", + " Predict the confidence score and binary indicator for a given text.\n", + "\n", + " Args:\n", + " text (str): The input text to classify.\n", + "\n", + " Returns:\n", + " Tuple[float, int]: A tuple containing the confidence score (float) and binary indicator (int).\n", + " \"\"\"\n", + " model = load_object_on_worker(self.model_identifier, self._load_fasttext_model, {})\n", + " predictions = model.predict(text, k=2) \n", + " # predictions[0]: labels, predictions[1]: scores\n", + " # If the top predicted label contains \"hq\", return the first score; otherwise, use the second.\n", + " if \"hq\" in predictions[0][0]:\n", + " return predictions[1][0], 1\n", + " else:\n", + " return predictions[1][1], 0\n", + "\n", + " def _predict_on_partition(self, df: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"\n", + " Apply predictions to a pandas DataFrame partition.\n", + "\n", + " Assumes the DataFrame has a \"text\" column.\n", + "\n", + " Args:\n", + " df (pd.DataFrame): Input DataFrame partition.\n", + "\n", + " Returns:\n", + " pd.DataFrame: DataFrame with added prediction columns.\n", + " \"\"\"\n", + " # Load the model on the worker.\n", + " model = load_object_on_worker(self.model_identifier, self._load_fasttext_model, {})\n", + " results = df[\"text\"].apply(self.predict_text)\n", + " df[self.pred_column] = results.apply(lambda x: x[0]).astype(np.float32)\n", + " if self.int_column is not None:\n", + " df[self.int_column] = results.apply(lambda x: x[1]).astype(np.int32)\n", + " return df\n", + "\n", + " def __call__(self, dataset: DocumentDataset) -> DocumentDataset:\n", + " \"\"\"\n", + " Apply the classifier to a distributed dataset.\n", + "\n", + " The dataset should have a \"text\" column. The classifier converts the dataset\n", + " to a pandas backend, applies predictions to each partition, and then converts the result\n", + " back to cudf.\n", + "\n", + " Args:\n", + " dataset: A distributed DataFrame (e.g., a Dask DataFrame) containing a \"text\" column.\n", + "\n", + " Returns:\n", + " DocumentDataset: The dataset with added prediction columns.\n", + " \"\"\"\n", + " meta = dataset.df._meta\n", + " if hasattr(meta, \"to_pandas\"):\n", + " meta = meta.to_pandas()\n", + " meta[self.pred_column] = np.float32(0.0)\n", + " if self.int_column is not None:\n", + " meta[self.int_column] = np.int32(0)\n", + "\n", + " processed_df = dataset.df.to_backend(\"pandas\").map_partitions(self._predict_on_partition, meta=meta)\n", + " processed_df = processed_df.to_backend(\"cudf\")\n", + " return DocumentDataset(processed_df)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8c7e28c3-8e25-417a-a1c7-7f5b237d18a0", + "metadata": {}, + "outputs": [], + "source": [ + "# Define classifier score mapping\n", + "classifier_scores = {\n", + " \"nemotron-score\": {\n", + " \"int_score\": \"fineweb-nemotron-edu-score-int\",\n", + " \"float_score\": \"fineweb-nemotron-edu-score\"\n", + " },\n", + " \"mixtral-score\": {\n", + " \"int_score\": \"fineweb-mixtral-edu-score-int\",\n", + " \"float_score\": \"fineweb-mixtral-edu-score\"\n", + " },\n", + " \"fasttext-score\": {\n", + " \"int_score\": \"fasttext-quality-score-int\",\n", + " \"float_score\": \"fasttext-quality-score\"\n", + " }\n", + "}\n", + "\n", + "\n", + "\n", + "# Initialize classifiers\n", + "classifiers = [\n", + " FineWebNemotronEduClassifier(batch_size=1024,\n", + " pred_column=classifier_scores[\"nemotron-score\"][\"float_score\"],\n", + " int_column=classifier_scores[\"nemotron-score\"][\"int_score\"]),\n", + " FineWebMixtralEduClassifier(batch_size=1024,\n", + " pred_column=classifier_scores[\"mixtral-score\"][\"float_score\"],\n", + " int_column=classifier_scores[\"mixtral-score\"][\"int_score\"]),\n", + " FastTextQualityClassifier(pred_column=classifier_scores[\"fasttext-score\"][\"float_score\"],\n", + " int_column=classifier_scores[\"fasttext-score\"][\"int_score\"])\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a672e5d8-bb1e-4fe4-bdd7-f9859a449158", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting FineWeb Nemotron-4 Edu Classifier inference\n", + "Starting FineWeb Mixtral Edu Classifier inference\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU: tcp://127.0.0.1:33001, Part: 0: 100%|██████████| 10/10 [00:02<00:00, 4.22it/s]\n", + "GPU: tcp://127.0.0.1:33001, Part: 0: 100%|██████████| 10/10 [00:01<00:00, 7.40it/s]\n" + ] + } + ], + "source": [ + "output_dataset = input_dataset\n", + "for classifier in classifiers:\n", + " output_dataset = classifier(dataset=output_dataset)\n", + "\n", + "# Dropping int columns\n", + "# As we add new based on a threshold (in the following columns)\n", + "output_dataset = output_dataset.df.drop(columns=[v[\"int_score\"] for v in classifier_scores.values()])\n", + "output_dataset.to_parquet(path=OUTPUT_CLASSIFICATION_RESULTS)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ec5cca63-ad01-4481-b910-8bcc735ece3a", + "metadata": {}, + "outputs": [], + "source": [ + "del classifiers, output_dataset, input_dataset" + ] + }, + { + "cell_type": "markdown", + "id": "1aa51ac8-373c-4318-9bf5-f063321cb3e0", + "metadata": {}, + "source": [ + "### Read Back in the scored Data Frame" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "229d2466-9064-4e2e-957e-07e949d2ae1a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading 1 files with blocksize='1gb' / files_per_partition=None\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fasttext-quality-scorefineweb-mixtral-edu-scorefineweb-mixtral-edu-score-labelfineweb-nemotron-edu-scorefineweb-nemotron-edu-score-labeltext
00.9990111.347656low_quality1.391602low_qualityQuantum computing is set to revolutionize the ...
10.9962640.827637low_quality0.889160low_qualityInvesting in index funds is a popular strategy...
20.0000901.420898low_quality1.345703low_qualityRecent advancements in gene therapy offer new ...
30.0003771.572266low_quality1.727539low_qualityOnline learning platforms have transformed the...
40.9918680.345215low_quality0.248657low_qualityTraveling to Europe during the off-season can ...
\n", + "
" + ], + "text/plain": [ + " fasttext-quality-score fineweb-mixtral-edu-score \\\n", + "0 0.999011 1.347656 \n", + "1 0.996264 0.827637 \n", + "2 0.000090 1.420898 \n", + "3 0.000377 1.572266 \n", + "4 0.991868 0.345215 \n", + "\n", + " fineweb-mixtral-edu-score-label fineweb-nemotron-edu-score \\\n", + "0 low_quality 1.391602 \n", + "1 low_quality 0.889160 \n", + "2 low_quality 1.345703 \n", + "3 low_quality 1.727539 \n", + "4 low_quality 0.248657 \n", + "\n", + " fineweb-nemotron-edu-score-label \\\n", + "0 low_quality \n", + "1 low_quality \n", + "2 low_quality \n", + "3 low_quality \n", + "4 low_quality \n", + "\n", + " text \n", + "0 Quantum computing is set to revolutionize the ... \n", + "1 Investing in index funds is a popular strategy... \n", + "2 Recent advancements in gene therapy offer new ... \n", + "3 Online learning platforms have transformed the... \n", + "4 Traveling to Europe during the off-season can ... " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scored_data = DocumentDataset.read_parquet(OUTPUT_CLASSIFICATION_RESULTS, backend=\"cudf\")\n", + "scored_data.df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "e7ef568a-6c17-4b7c-b201-627f33df26fa", + "metadata": {}, + "source": [ + "# Step 2: Compute Score Thresholds\n", + "\n", + "### Why Compute Thresholds?\n", + "- To categorize classification scores into percentile-based bins.\n", + "- Ensures results are comparable across different classifiers.\n", + "\n", + "### Approach:\n", + "1. **Extract classifier scores** from the sampled dataset.\n", + "2. **Compute weighted percentiles** for each classifier.\n", + "3. **Save percentile thresholds** for later use in mapping scores.\n", + "\n", + "> **Note:** The percentile calculation is weighted by token count so that longer texts (with more tokens) have a greater impact on the thresholds. This ensures that the bins accurately reflect the distribution of content, giving a more meaningful categorization of the scores." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "26a34e8a-c893-454c-8f93-d09cd60d99ce", + "metadata": {}, + "outputs": [], + "source": [ + "def weighted_percentile(data, percentiles, weights):\n", + " \"\"\"\n", + " Compute weighted percentiles with the \"inverted_cdf\" method.\n", + "\n", + " Parameters:\n", + " data : array-like, the data values.\n", + " percentiles : scalar or array-like, percentiles in [0, 100].\n", + " weights : array-like, the weights for each data value.\n", + " \n", + " Returns:\n", + " The weighted percentile values.\n", + " \"\"\"\n", + " data = np.asarray(data)\n", + " weights = np.asarray(weights)\n", + " \n", + " # Sort data and associated weights\n", + " sorter = np.argsort(data)\n", + " data_sorted = data[sorter]\n", + " weights_sorted = weights[sorter]\n", + " \n", + " # Compute the cumulative sum of weights and normalize it to [0, 1]\n", + " cum_weights = np.cumsum(weights_sorted)\n", + " total_weight = cum_weights[-1]\n", + " normalized_cum_weights = cum_weights / total_weight\n", + "\n", + " # For each desired percentile, find the first data value where\n", + " # the normalized cumulative weight is >= (percentile / 100).\n", + " percentiles = np.atleast_1d(percentiles)\n", + " results = []\n", + " for p in percentiles:\n", + " # np.searchsorted returns the index where (p/100) should be inserted \n", + " # to maintain order.\n", + " idx = np.searchsorted(normalized_cum_weights, p / 100.0, side='left')\n", + " results.append(data_sorted[idx])\n", + " \n", + " return np.array(results)\n", + "\n", + "\n", + "def compute_thresholds(score_ar: np.ndarray, token_ar: np.ndarray) -> Dict[str, float]:\n", + " \"\"\"\n", + " Compute percentile-based thresholds for a given score column using weighted percentiles.\n", + "\n", + " Args:\n", + " score_ar (np.ndarray): Array containing the scores.\n", + " token_ar (np.ndarray): Array containing token counts for weighting.\n", + "\n", + " Returns:\n", + " Dict[str, float]: Dictionary containing percentile thresholds.\n", + " \"\"\"\n", + " percentiles = np.arange(5, 100, 5)\n", + " # NumPy < 2.0 does not support the \"inverted_cdf\" method for computing percentiles \n", + " # with weights directly via np.percentile (see commented-out equivalent code below).\n", + " # To achieve the same result, we manually implement the weighted percentile computation\n", + " # using NumPy primitives.\n", + " # thresholds = np.percentile(cc_df_score, percentiles, weights=cc_df_tokens, method='inverted_cdf')\n", + " thresholds = weighted_percentile(score_ar, percentiles, weights=token_ar)\n", + " return {int(percentile): float(thresh) for percentile, thresh in zip(percentiles, thresholds)}\n", + "\n", + "\n", + "def compute_thresholds_for_score_columns(\n", + " df: cudf.DataFrame, text_col_name: str, score_col_names: List[str]\n", + ") -> Dict[str, Dict[str, float]]:\n", + " \"\"\"\n", + " Compute percentile-based thresholds for all specified score columns in a DataFrame.\n", + "\n", + " Args:\n", + " df (cudf.DataFrame): The DataFrame containing the score columns and text column.\n", + " text_col_name (str): The name of the text column used to derive token counts.\n", + " score_col_names (List[str]): List of column names for which thresholds should be computed.\n", + "\n", + " Returns:\n", + " Dict[str, Dict[str, float]]: A dictionary mapping each score column to its percentile thresholds.\n", + " \"\"\"\n", + " threshold_dict = {}\n", + " token_series = df[text_col_name].str.byte_count()\n", + "\n", + " for score_col in score_col_names:\n", + " threshold_dict[score_col] = compute_thresholds(df[score_col].values.get(), token_series.values.get())\n", + "\n", + " return threshold_dict\n", + "\n", + "\n", + "def save_thresholds(threshold_dict: Dict[str, Dict[str, float]], file_name) -> None:\n", + " \"\"\"\n", + " Save computed thresholds to a JSON file.\n", + "\n", + " Args:\n", + " threshold_dict (Dict[str, Dict[str, float]]): The dictionary containing computed thresholds.\n", + " file_name (str, optional): The name of the output JSON file. Defaults to \"thresholds.json\".\n", + " Returns:\n", + " None\n", + " \"\"\"\n", + " with open(file_name, 'w') as fout:\n", + " json.dump(file_name, fout, indent=4)\n", + " print(f\"Thresholds saved to {file_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c8b0a650-6290-4e60-b388-43950e1f7357", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thresholds saved to output_data_dir/classifier_thresholds.json\n" + ] + } + ], + "source": [ + "# Adjust fraction based on how much can fit in a single GPU (1/2 ish)\n", + "gpu_memory_available = get_device_total_memory()/2\n", + "frac = max(1, scored_data.df.memory_usage(deep=True).sum().compute()/gpu_memory_available)\n", + "sampled_data = scored_data.df.sample(frac=frac).repartition(npartitions=1)\n", + "\n", + "score_col_names = [v[\"float_score\"] for v in classifier_scores.values()]\n", + "threshold_dict = sampled_data.map_partitions(compute_thresholds_for_score_columns, text_col_name=\"text\", score_col_names=score_col_names).compute().iloc[0]\n", + "save_thresholds(threshold_dict, OUTPUT_CLASSIFIER_THRESHOLDS)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "83696e60-be44-434f-acab-ef275253732a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'fineweb-nemotron-edu-score': {5: 0.2486572265625,\n", + " 10: 0.81884765625,\n", + " 15: 0.81884765625,\n", + " 20: 0.81884765625,\n", + " 25: 0.8427734375,\n", + " 30: 0.85400390625,\n", + " 35: 0.85400390625,\n", + " 40: 0.88916015625,\n", + " 45: 0.88916015625,\n", + " 50: 1.2880859375,\n", + " 55: 1.2880859375,\n", + " 60: 1.345703125,\n", + " 65: 1.345703125,\n", + " 70: 1.3916015625,\n", + " 75: 1.3916015625,\n", + " 80: 1.3994140625,\n", + " 85: 1.3994140625,\n", + " 90: 1.7275390625,\n", + " 95: 1.7275390625},\n", + " 'fineweb-mixtral-edu-score': {5: 0.34521484375,\n", + " 10: 0.7822265625,\n", + " 15: 0.7822265625,\n", + " 20: 0.82763671875,\n", + " 25: 0.82763671875,\n", + " 30: 0.9501953125,\n", + " 35: 0.9501953125,\n", + " 40: 1.0234375,\n", + " 45: 1.0234375,\n", + " 50: 1.34765625,\n", + " 55: 1.34765625,\n", + " 60: 1.4208984375,\n", + " 65: 1.4208984375,\n", + " 70: 1.42578125,\n", + " 75: 1.42578125,\n", + " 80: 1.572265625,\n", + " 85: 1.572265625,\n", + " 90: 1.783203125,\n", + " 95: 1.783203125},\n", + " 'fasttext-quality-score': {5: 9.026021871250123e-05,\n", + " 10: 9.026021871250123e-05,\n", + " 15: 0.00011704424832714722,\n", + " 20: 0.00011704424832714722,\n", + " 25: 0.00037683334085159004,\n", + " 30: 0.00037683334085159004,\n", + " 35: 0.0006898035062476993,\n", + " 40: 0.0006898035062476993,\n", + " 45: 0.9918678402900696,\n", + " 50: 0.9918678402900696,\n", + " 55: 0.9919403195381165,\n", + " 60: 0.9919403195381165,\n", + " 65: 0.9962636232376099,\n", + " 70: 0.9962636232376099,\n", + " 75: 0.9990114569664001,\n", + " 80: 0.9990114569664001,\n", + " 85: 0.9997979998588562,\n", + " 90: 0.9997979998588562,\n", + " 95: 0.9999129772186279}}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "threshold_dict" + ] + }, + { + "cell_type": "markdown", + "id": "790a9c41-80ee-4885-8c7e-2b34b4e8117c", + "metadata": {}, + "source": [ + "# Step 3: Convert Floating-Point Scores to Integer Scores\n", + "\n", + "### Why Convert?\n", + "- Floating-point scores are mapped to integer categories (0-19) for easier comparison.\n", + "- Integer scores are computed using **percentile-based thresholds**.\n", + "\n", + "### Process:\n", + "1. **Retrieve percentile thresholds** from saved JSON.\n", + "2. **Apply the thresholds to map scores to integer bins**.\n", + "3. **Store integer scores in the dataset** for final ensemble computation." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4d94e75a-0a78-4554-bb39-087b009db1c3", + "metadata": {}, + "outputs": [], + "source": [ + "def map_scores(df, score_col_name: str, score_int_name: str, bins: List[float]):\n", + " \"\"\"\n", + " Given a DataFrame df and a column of original scores, \n", + " use cp.digitize to map them into integer bins using the given thresholds.\n", + " \"\"\"\n", + " pred_orig_score = cp.array(df[score_col_name])\n", + " pred_int_score = cp.digitize(pred_orig_score, bins)\n", + " df[score_int_name] = pred_int_score\n", + " return df\n", + "\n", + "def map_score_columns(df: cudf.DataFrame, score_col_names: List[str], threshold_dict: Dict[str, dict]):\n", + " \"\"\"\n", + " For each score column in score_col_names, this function:\n", + " 1. Creates a new column name by appending '-int'\n", + " 2. Retrieves the corresponding thresholds from threshold_dict,\n", + " sorts them (using the keys which are assumed to be strings of numbers),\n", + " 3. Passes the bins to map_scores to create the integer score column.\n", + " \"\"\"\n", + " for score_col_name in score_col_names:\n", + " # Build the new integer score column name.\n", + " score_int_name = score_col_name + \"-int\"\n", + " thresholds = threshold_dict.get(score_col_name)\n", + " if thresholds is None:\n", + " raise ValueError(f\"No thresholds found for score column '{score_col_name}'\")\n", + " \n", + " sorted_keys = sorted(thresholds.keys(), key=lambda x: int(x))\n", + " # Use cp.array to create a CuPy array from the list of threshold values.\n", + " bins = cp.array([thresholds[k] for k in sorted_keys])\n", + " \n", + " # Map the original score column to the new integer score column.\n", + " df = map_scores(df, score_col_name, score_int_name, bins)\n", + " return df\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "70e6a00e-5e42-493a-9dcb-682df8eead0d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fasttext-quality-scorefineweb-mixtral-edu-scorefineweb-mixtral-edu-score-labelfineweb-nemotron-edu-scorefineweb-nemotron-edu-score-labeltextfineweb-nemotron-edu-score-intfineweb-mixtral-edu-score-intfasttext-quality-score-int
00.9990111.347656low_quality1.391602low_qualityQuantum computing is set to revolutionize the ...151116
10.9962640.827637low_quality0.889160low_qualityInvesting in index funds is a popular strategy...9514
20.0000901.420898low_quality1.345703low_qualityRecent advancements in gene therapy offer new ...13132
30.0003771.572266low_quality1.727539low_qualityOnline learning platforms have transformed the...19176
40.9918680.345215low_quality0.248657low_qualityTraveling to Europe during the off-season can ...1110
\n", + "
" + ], + "text/plain": [ + " fasttext-quality-score fineweb-mixtral-edu-score \\\n", + "0 0.999011 1.347656 \n", + "1 0.996264 0.827637 \n", + "2 0.000090 1.420898 \n", + "3 0.000377 1.572266 \n", + "4 0.991868 0.345215 \n", + "\n", + " fineweb-mixtral-edu-score-label fineweb-nemotron-edu-score \\\n", + "0 low_quality 1.391602 \n", + "1 low_quality 0.889160 \n", + "2 low_quality 1.345703 \n", + "3 low_quality 1.727539 \n", + "4 low_quality 0.248657 \n", + "\n", + " fineweb-nemotron-edu-score-label \\\n", + "0 low_quality \n", + "1 low_quality \n", + "2 low_quality \n", + "3 low_quality \n", + "4 low_quality \n", + "\n", + " text \\\n", + "0 Quantum computing is set to revolutionize the ... \n", + "1 Investing in index funds is a popular strategy... \n", + "2 Recent advancements in gene therapy offer new ... \n", + "3 Online learning platforms have transformed the... \n", + "4 Traveling to Europe during the off-season can ... \n", + "\n", + " fineweb-nemotron-edu-score-int fineweb-mixtral-edu-score-int \\\n", + "0 15 11 \n", + "1 9 5 \n", + "2 13 13 \n", + "3 19 17 \n", + "4 1 1 \n", + "\n", + " fasttext-quality-score-int \n", + "0 16 \n", + "1 14 \n", + "2 2 \n", + "3 6 \n", + "4 10 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scored_data.df = scored_data.df.map_partitions(map_score_columns, score_col_names, threshold_dict)\n", + "scored_data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "9cf526dd-363f-4199-bb9d-2ea9b8897fae", + "metadata": {}, + "source": [ + "# Step 4: Compute the Final Ensembled Score\n", + "\n", + "### Purpose:\n", + "- To combine the predictions from multiple classifiers into a **single representative score**.\n", + "- The ensemble score is computed as the **maximum of all integer scores** across classifiers.\n", + "\n", + "### Approach:\n", + "1. **Extract integer scores from each classifier.**\n", + "2. **Compute the max integer score for each data point.**\n", + "3. **Store the final ensemble score in the dataset.**" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "380f453d-4d6e-43cc-9d33-3d4d64b854d4", + "metadata": {}, + "outputs": [], + "source": [ + "int_column_names = [f'{v[\"float_score\"]}-int' for v in classifier_scores.values()]\n", + "scored_data.df['ensemble-max-int'] = scored_data.df[int_column_names].max(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "469cafbe-f8d2-466d-9e80-2522c59a0a1a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fasttext-quality-scorefineweb-mixtral-edu-scorefineweb-mixtral-edu-score-labelfineweb-nemotron-edu-scorefineweb-nemotron-edu-score-labeltextfineweb-nemotron-edu-score-intfineweb-mixtral-edu-score-intfasttext-quality-score-intensemble-max-int
00.9990111.347656low_quality1.391602low_qualityQuantum computing is set to revolutionize the ...15111616
10.9962640.827637low_quality0.889160low_qualityInvesting in index funds is a popular strategy...951414
20.0000901.420898low_quality1.345703low_qualityRecent advancements in gene therapy offer new ...1313213
30.0003771.572266low_quality1.727539low_qualityOnline learning platforms have transformed the...1917619
40.9918680.345215low_quality0.248657low_qualityTraveling to Europe during the off-season can ...111010
\n", + "
" + ], + "text/plain": [ + " fasttext-quality-score fineweb-mixtral-edu-score \\\n", + "0 0.999011 1.347656 \n", + "1 0.996264 0.827637 \n", + "2 0.000090 1.420898 \n", + "3 0.000377 1.572266 \n", + "4 0.991868 0.345215 \n", + "\n", + " fineweb-mixtral-edu-score-label fineweb-nemotron-edu-score \\\n", + "0 low_quality 1.391602 \n", + "1 low_quality 0.889160 \n", + "2 low_quality 1.345703 \n", + "3 low_quality 1.727539 \n", + "4 low_quality 0.248657 \n", + "\n", + " fineweb-nemotron-edu-score-label \\\n", + "0 low_quality \n", + "1 low_quality \n", + "2 low_quality \n", + "3 low_quality \n", + "4 low_quality \n", + "\n", + " text \\\n", + "0 Quantum computing is set to revolutionize the ... \n", + "1 Investing in index funds is a popular strategy... \n", + "2 Recent advancements in gene therapy offer new ... \n", + "3 Online learning platforms have transformed the... \n", + "4 Traveling to Europe during the off-season can ... \n", + "\n", + " fineweb-nemotron-edu-score-int fineweb-mixtral-edu-score-int \\\n", + "0 15 11 \n", + "1 9 5 \n", + "2 13 13 \n", + "3 19 17 \n", + "4 1 1 \n", + "\n", + " fasttext-quality-score-int ensemble-max-int \n", + "0 16 16 \n", + "1 14 14 \n", + "2 2 13 \n", + "3 6 19 \n", + "4 10 10 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scored_data.df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "35ba68b8-8566-401a-882b-eb2ae0414138", + "metadata": {}, + "source": [ + "# Step 5: Write Results to Partitioned Buckets\n", + "\n", + "\n", + "### Purpose:\n", + "- Organize and store classified results in a **structured, partitioned format** to facilitate **annealing-based training** for downstream **LLM fine-tuning** and optimization." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5b6bfcc8-5fef-41df-9e04-c50c35538ff3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing to disk complete for 1 partition(s)\n" + ] + } + ], + "source": [ + "scored_data.to_parquet(OUTPUT_BUCKETED_RESULTS, partition_on=\"ensemble-max-int\")" + ] + }, + { + "cell_type": "markdown", + "id": "8052be9b-6889-4254-bf21-ef1c8b41b82f", + "metadata": {}, + "source": [ + "# Verify Results\n", + "\n", + "### Process:\n", + "1. **List available partitions** (each corresponds to a score bucket).\n", + "2. **Read a sample partition** and validate data integrity." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "67f0fc7b-eca6-4326-9a58-54d27daaf06a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['ensemble-max-int=1', 'ensemble-max-int=10', 'ensemble-max-int=12', 'ensemble-max-int=13', 'ensemble-max-int=14', 'ensemble-max-int=16', 'ensemble-max-int=17', 'ensemble-max-int=18', 'ensemble-max-int=19', 'ensemble-max-int=3', 'ensemble-max-int=5', 'ensemble-max-int=7', 'ensemble-max-int=9']\n", + "Reading 1 files with blocksize='1gb' / files_per_partition=None\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ensemble-max-intfasttext-quality-scorefasttext-quality-score-intfineweb-mixtral-edu-scorefineweb-mixtral-edu-score-intfineweb-nemotron-edu-scorefineweb-nemotron-edu-score-inttext
010.13574210.13574210.1357421Traveling to Europe during the off-season can ...
\n", + "
" + ], + "text/plain": [ + " ensemble-max-int fasttext-quality-score fasttext-quality-score-int \\\n", + "0 1 0.135742 1 \n", + "\n", + " fineweb-mixtral-edu-score fineweb-mixtral-edu-score-int \\\n", + "0 0.135742 1 \n", + "\n", + " fineweb-nemotron-edu-score fineweb-nemotron-edu-score-int \\\n", + "0 0.135742 1 \n", + "\n", + " text \n", + "0 Traveling to Europe during the off-season can ... " + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_buckets = sorted(os.listdir(OUTPUT_BUCKETED_RESULTS))\n", + "print(all_buckets)\n", + "first_bucket= DocumentDataset.read_parquet(os.path.join(OUTPUT_BUCKETED_RESULTS, all_buckets[0]))\n", + "first_bucket.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/image-curation/image-curation.ipynb b/tutorials/image-curation/image-curation.ipynb index 8e1b5170..47bed501 100644 --- a/tutorials/image-curation/image-curation.ipynb +++ b/tutorials/image-curation/image-curation.ipynb @@ -36,7 +36,10 @@ "metadata": {}, "source": [ "## Install NeMo Curator\n", - "This installs NeMo Curator and some additional libraries for helper functions in the notebook" + "\n", + "If you have not already, please install NeMo Curator by following the [README](https://github.com/NVIDIA/NeMo-Curator?tab=readme-ov-file#nemo-curator); you should install either `nemo-curator[all]` or `nemo-curator[image]` for this tutorial. If you are using the NeMo Framework Container, then NeMo Curator is already installed and no action is needed.\n", + "\n", + "We also need to install some additional libraries for helper functions in the notebook:" ] }, { @@ -49,9 +52,7 @@ }, "outputs": [], "source": [ - "!pip install ipywidgets aiofiles\n", - "# Install from source by default\n", - "!pip install --extra-index-url https://pypi.nvidia.com ../../[image]" + "!pip install ipywidgets aiofiles" ] }, {