diff --git a/docs/user-guide/download.rst b/docs/user-guide/download.rst index d4b854e4..d2b50e0c 100644 --- a/docs/user-guide/download.rst +++ b/docs/user-guide/download.rst @@ -37,10 +37,42 @@ By "extraction", we typically mean the process of converting a data format from Otherwise, the HTTPS endpoints will be used with ``wget``. Here is a small example of how to use it: .. code-block:: python - + import os + from nemo_curator import get_client from nemo_curator.download import download_common_crawl - - common_crawl = download_common_crawl("/extracted/output/folder", "2020-50", "2021-04", output_type="jsonl") + from nemo_curator.datasets import DocumentDataset + + def main(): + # Initialize a distributed Dask client + client = get_client(cluster_type="cpu") + + # Parameters for downloading Common Crawl data. + # - output_folder: directory for temporary download/extraction files + # - start_snapshot and end_snapshot define the range to fetch + # - output_type: specifies file format for the extracted data (e.g., "jsonl") + output_folder = "/extracted/output/folder" + start_snapshot = "2020-50" + end_snapshot = "2021-04" + output_type = "jsonl" + os.makedirs(output_folder, exist_ok=True) + + # Download and extract the Common Crawl data. + # The function returns a DocumentDataset that contains the extracted documents. + # Note: The output folder and output type are passed here to store intermediate files + # and check if the data has already been downloaded. They should match the final location + # and format of the extracted data. + common_crawl_dataset = download_common_crawl( + output_folder, start_snapshot, end_snapshot, output_type=output_type + ) + + # Write the extracted dataset to JSON format. + # The 'to_json' method will write one JSON document per line, + # preserving the original shard information if write_to_filename is True. + common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True) + print("Extracted dataset saved to:", output_folder) + + if __name__ == "__main__": + main() * ``"/extracted/output/folder"`` is the path to on your local filesystem where the final extracted files will be placed. * ``"2020-50"`` is the first common crawl snapshot that will be included in the download. **Note:** Not every year and week has a snapshot. Ensure that your range includes at least one valid Common Crawl snapshot. A list of valid Common Crawl snapshots can be found `here `_. @@ -50,21 +82,49 @@ By "extraction", we typically mean the process of converting a data format from You can choose to modify the HTML text extraction algorithm used in ``download_common_crawl``. See an example below. .. code-block:: python - + import os + from nemo_curator import get_client from nemo_curator.download import ( - ResiliparseExtractor, - download_common_crawl, - ) - - # Change the extraction algorithm - extraction_algorithm = ResiliparseExtractor() - common_crawl = download_common_crawl( - "/extracted/output/folder", - "2020-50", - "2021-04", - output_type="jsonl", - algorithm=extraction_algorithm, + ResiliparseExtractor, + download_common_crawl, ) + from nemo_curator.datasets import DocumentDataset + + def main(): + # Initialize a distributed Dask client + client = get_client(cluster_type="cpu") + + # Parameters for downloading Common Crawl data. 
+ # - output_folder: directory for temporary download/extraction files + # - start_snapshot and end_snapshot define the range to fetch + # - output_type: specifies file format for the extracted data (e.g., "jsonl") + output_folder = "/extracted/output/folder" + start_snapshot = "2020-50" + end_snapshot = "2021-04" + output_type = "jsonl" + os.makedirs(output_folder, exist_ok=True) + + # Change the extraction algorithm to use ResiliparseExtractor + extraction_algorithm = ResiliparseExtractor() + + # Download and extract the Common Crawl data using the Resiliparse extraction algorithm. + # The function returns a DocumentDataset that contains the extracted documents. + common_crawl_dataset = download_common_crawl( + output_folder, + start_snapshot, + end_snapshot, + output_type=output_type, + algorithm=extraction_algorithm, + ) + + # Write the extracted dataset to JSON format. + # The 'to_json' method writes one JSON document per line, + # preserving the original shard information if write_to_filename is True. + common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True) + print("Extracted dataset saved to:", output_folder) + + if __name__ == "__main__": + main() Above, we changed the extraction algorithm from the default ``JusTextExtractor``. diff --git a/nemo_curator/download/arxiv.py b/nemo_curator/download/arxiv.py index 538d567d..e7e5327a 100644 --- a/nemo_curator/download/arxiv.py +++ b/nemo_curator/download/arxiv.py @@ -18,6 +18,7 @@ import subprocess import tarfile import tempfile +from typing import Literal, Optional from nemo_curator.datasets import DocumentDataset from nemo_curator.download.doc_builder import ( @@ -218,12 +219,12 @@ def extract(self, content): for file_content in content ) except Exception: - return {}, None + return None # Don't return meta if cleaned_latex_file_str is not None: if len(cleaned_latex_file_str) > 0: - return {}, cleaned_latex_file_str + return {"text": cleaned_latex_file_str} def _clean_tex_file(self, file_content, arg_macros, non_arg_macros): r"""function takes a tex file as input and returns a cleaned version. The @@ -365,25 +366,44 @@ def _build_non_arg_macros_dict(self, file_content): def download_arxiv( output_path: str, - output_type: str = "jsonl", - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + output_type: Literal["jsonl", "parquet"] = "jsonl", + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads Arxiv tar files and extracts them + Download Arxiv tar files and extract the contained LaTeX projects. + + This function obtains a list of Arxiv tar file URLs (via get_arxiv_urls), downloads the tar files, + and then extracts the contained LaTeX source files. The resulting documents (after extraction) are + assembled into a DocumentDataset. Args: - output_path: The path to the root directory of the files - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. 
If None, all - files from the range of snapshots are downloaded. + output_path (str): + The root directory where both the final extracted files and the raw download subdirectory will be stored. + The extracted files (in the format specified by output_type) are eventually saved in this directory. + output_type (Literal["jsonl", "parquet"], optional): + The file format/extension used for saving the extracted documents (e.g., "jsonl" or "parquet"). + Default is "jsonl". This is not used for the output file, but is used to check if an extracted output already exists and read it if so. + raw_download_dir (Optional[str], optional): + The directory where the raw downloaded tar files will be kept. If None, a folder named "downloads" + under output_path is used. + keep_raw_download (bool, optional): + If True, the raw tar files (before extraction) are not removed after processing. Default is False. + force_download (bool, optional): + If False, then if an output file already exists for a given URL, re-downloading and re-extraction will be skipped. + Default is False. + url_limit (Optional[int], optional): + Limits the maximum number of Arxiv tar file URLs to download and process. + If None, all available URLs (from get_arxiv_urls) are processed. + record_limit (Optional[int], optional): + Limits the maximum number of records to extract from each tar file. + If None, all available records are extracted. + Returns: + DocumentDataset: + A dataset object containing the extracted documents. """ arxiv_urls = get_arxiv_urls() if url_limit: @@ -416,6 +436,7 @@ def download_arxiv( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/download/commoncrawl.py b/nemo_curator/download/commoncrawl.py index de7e333b..c651d4d6 100644 --- a/nemo_curator/download/commoncrawl.py +++ b/nemo_curator/download/commoncrawl.py @@ -17,6 +17,7 @@ import subprocess import unicodedata from abc import ABC, abstractmethod +from typing import Literal, Optional from urllib.parse import urlparse import justext @@ -352,48 +353,54 @@ def extract(self, content): if text is not None: if len(text) > 0: text = "\n\n".join(text) - meta = {"language": lang} - return meta, text + meta = {"language": lang, "text": text} + return meta else: - return None, None + return None def download_common_crawl( output_path: str, start_snapshot: str, end_snapshot: str, - output_type: str = "jsonl", + output_type: Literal["jsonl", "parquet"] = "jsonl", algorithm=JusTextExtractor(), - news=False, - aws=False, - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + news: bool = False, + aws: bool = False, + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads Common Crawl WARC snapshots and extracts them using jusText or Resiliparse + Downloads Common Crawl WARC snapshots and extracts text content using a specified extraction algorithm. Args: - output_path: The path to the root directory of the files - start_snapshot: The first common crawl snapshot to include. Snapshots must be - specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset, - (specified with news=True flag) this changes to Year-Month (YYYY-MM). - end_snapshot: The last common crawl snapshot to include. 
Must be chronologically - after the starting snapshot. - output_type: The file type to save the data as. - algorithm: A JusTextExtractor or ResiliparseExtractor object. - news: If True, gets WARC URLs for the CC-NEWS dataset instead of the CC-MAIN datasets. - Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month). - aws: Whether to download from Common Crawl's S3 bucket. If True, uses s5cmd to download. - If False, uses wget. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the compressed WARC files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. + output_path (str): The root directory used for managing download and extraction. + • Raw WARC files are stored in a "downloads" subdirectory under this path. + • This path is also checked for existing extraction results; if found, extraction can be skipped. + • Note: This function returns a DocumentDataset, and writing the extracted data to disk is the caller's responsibility. + start_snapshot (str): Identifier for the earliest snapshot to process. + • For CC-MAIN datasets, use the format 'YYYY-WeekNumber' (e.g., '2020-50' or '2021-04'). + • For CC-NEWS datasets (when news=True), use the 'YYYY-MM' (Year-Month) format. + end_snapshot (str): Identifier for the latest snapshot to process, which must be chronologically after start_snapshot. + output_type (Literal["jsonl", "parquet"]): The file format for the extracted output. Must be either "jsonl" or "parquet". + • This is not used for the output file, but is used to check if an extracted output already exists. + algorithm: The text extraction algorithm instance (e.g., JusTextExtractor or ResiliparseExtractor) to use for HTML processing. + news (bool): When True, indicates that URLs should be retrieved from the CC-NEWS dataset. + • This also means snapshot identifiers should follow the 'YYYY-MM' format. + aws (bool): If True, downloads are sourced from Common Crawl's S3 bucket using s5cmd; + • If False, wget is used to fetch the files via HTTPS. + raw_download_dir: Optional; the directory to temporarily store raw WARC files. + • If not provided, defaults to a "downloads" folder within output_path. + keep_raw_download (bool): If True, retains the downloaded raw WARC files after extraction. + • If False, these raw files may be removed following extraction. + force_download (bool): If False, skips re-downloading or re-extracting snapshots if outputs already exist in output_path. + url_limit: Optional; the maximum number of WARC files to download from the snapshot range. + • If None, all available files within the specified snapshots are downloaded. + record_limit: Optional; the maximum number of records to extract from each WARC file. + • If None, all available records are extracted. 
""" common_crawl_urls = get_common_crawl_urls( starting_snapshot=start_snapshot, ending_snapshot=end_snapshot, news=news @@ -443,6 +450,7 @@ def download_common_crawl( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/download/doc_builder.py b/nemo_curator/download/doc_builder.py index 5ad3094e..9d3e849b 100644 --- a/nemo_curator/download/doc_builder.py +++ b/nemo_curator/download/doc_builder.py @@ -15,17 +15,14 @@ import importlib import os from abc import ABC, abstractmethod -from typing import List, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import dask.dataframe as dd import pandas as pd from dask import compute, delayed from nemo_curator.datasets import DocumentDataset -from nemo_curator.utils.distributed_utils import ( - read_single_partition, - single_partition_write_with_filename, -) +from nemo_curator.utils.distributed_utils import read_single_partition class DocumentDownloader(ABC): @@ -108,14 +105,36 @@ def _download_and_extract_single_partition( downloader: DocumentDownloader, iterator: DocumentIterator, extractor: DocumentExtractor, - output_type: str, + output_type: Literal["jsonl", "parquet"], keep_raw_download: bool, force_download: bool, input_meta: Union[str, dict] = None, filename_col: str = "file_name", + record_limit: Optional[int] = None, ) -> pd.DataFrame: + """ + Downloads a single partition from a URL and extracts its contents in-memory without writing + the extracted output to disk. The provided output_path is used only to check if an extracted + output already exists and, if so, skips re-downloading and extraction. + + Parameters: + paths (Tuple[str, str]): A tuple (url, output_path) where 'url' is the source URL and + 'output_path' is the expected location of a previously extracted output. + downloader (DocumentDownloader): An object to download the content from the URL. + iterator (DocumentIterator): An object to iterate over records in the downloaded file. + extractor (DocumentExtractor): An object to extract the desired content from each record. + output_type (Literal["jsonl", "parquet"]): The output file format/extension. Must be either "jsonl" or "parquet". Defaults to "jsonl". This parameter is only used to verify whether an extracted output already exists. + keep_raw_download (bool): If False, deletes the raw download file after extraction. + force_download (bool): If False and output_path exists, skips downloading and extraction. + input_meta (Union[str, dict], optional): Metadata describing the input file's structure. + filename_col (str, optional): Name of the column to store the filename within the result DataFrame. + record_limit (int, optional): Limit the number of records to extract from each file. + Returns: + pd.DataFrame: A DataFrame containing the extracted records. + """ url, output_path = paths + # If an extracted output already exists and we're not forcing a download, load and return it. if os.path.exists(output_path) and not force_download: partition = read_single_partition( [output_path], @@ -125,31 +144,26 @@ def _download_and_extract_single_partition( ) return partition + # Download the file and extract its records in memory. 
downloaded_file = downloader.download(url) records = [] - # Iterate over all records in file for item in iterator.iterate(downloaded_file): + if record_limit is not None and len(records) >= record_limit: + break record_meta, content = item - # Extract the text from the record extracted = extractor.extract(content) if extracted is not None: - text_meta, text = extracted - if text is not None: - line = { - "text": text, - **text_meta, - **record_meta, - } - records.append(line) + # Merge the extracted data and record metadata into one dictionary. + line = {**extracted, **record_meta} + records.append(line) partition = pd.DataFrame(records) - filename = os.path.basename(output_path) - output_dir = os.path.dirname(output_path) - partition[filename_col] = filename - single_partition_write_with_filename( - partition, output_dir, output_type=output_type, filename_col=filename_col - ) - if not keep_raw_download: + # Add a filename column for consistency using the basename of the output_path. + partition[filename_col] = os.path.basename(output_path) + + # Since we're not writing the extracted partition to disk, the output_path is not used here. + # Clean up the raw downloaded file if it's not meant to be kept. + if not keep_raw_download and os.path.exists(downloaded_file): os.remove(downloaded_file) return partition @@ -162,42 +176,76 @@ def download_and_extract( iterator: DocumentIterator, extractor: DocumentExtractor, output_format: dict, - output_type: str = "jsonl", - keep_raw_download=False, - force_download=False, + output_type: Literal["jsonl", "parquet"] = "jsonl", + keep_raw_download: bool = False, + force_download: bool = False, input_meta: Union[str, dict] = None, filename_col: str = "file_name", + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads and extracts a dataset into a format accepted by the NeMo Curator + Download files from the given URLs, extract their records, and + construct a DocumentDataset. + + For each URL provided, this function downloads the corresponding + file (unless an extracted output already exists and force_download is + False), iterates over its records, extracts the desired content, and + finally converts all records into a DocumentDataset. Args: - urls: A list of urls to download the dataset from - output_paths: A list of paths to save the final extracted output to. - The raw output of the downloader will be saved using the path given by downloader.download(url). - downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage - iterator: A DocumentIterator that handles iterating through the downloaded file's format - extractor: A DocumentExtractor that handles extracting the data from its raw format into text - output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction. - output_type: The file type to save the dataset as. - keep_raw_download: Whether to keep the pre-extracted download file. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - input_meta: A dictionary or a string formatted as a dictionary, which outlines - the field names and their respective data types within the JSONL input file. - filename_col : The name of the column that contains the filename. Default is "filename_col" + urls (List[str]): + A list of URLs from which to download dataset files. + output_paths (List[str]): + A list of file paths where the extracted outputs should be + found. 
If a file already exists at a given path and force_download + is False, that partition is skipped. + downloader (DocumentDownloader): + The downloader instance responsible for fetching files from + the specified URLs. + iterator (DocumentIterator): + The iterator instance used to traverse the downloaded file + and yield records. + extractor (DocumentExtractor): + The extractor instance used to obtain the desired content from + each record. + output_format (dict): + A dictionary mapping column names to the data types for the + extracted records. + output_type (Literal["jsonl", "parquet"], optional): + The output file format/extension. Must be either "jsonl" or "parquet". + Defaults to "jsonl". This parameter is only used to verify whether + an extracted output already exists. + keep_raw_download (bool, optional): + If True, the raw downloaded files are retained after extraction. + Defaults to False. + force_download (bool, optional): + If False and an output file already exists at a given path, the + download and extraction for that file are skipped. + Defaults to False. + input_meta (Union[str, dict], optional): + Optional metadata describing the input file's schema. + Defaults to None. + filename_col (str, optional): + The name for the column in the resulting dataset that records + the basename of the output file. Defaults to "file_name". + record_limit (int, optional): Limit the number of records to extract from each file. + Defaults to None. Returns: - A DocumentDataset of the downloaded data + DocumentDataset: + A dataset composed of the records extracted from the downloaded + files. """ - if len(urls) == 0: - raise ValueError("No urls were provided to download") - + # Validate parameters + if not urls: + raise ValueError("No URLs were provided to download") if len(urls) != len(output_paths): raise ValueError( - f"Different number of urls and output_paths. {len(urls)} urls vs {len(output_paths)} output_paths" + f"Different number of URLs and output_paths. {len(urls)} URLs vs {len(output_paths)} output_paths" ) + # Ensure consistent ordering of output_format keys. 
output_format = dict(sorted(output_format.items())) + df = dd.from_map( _download_and_extract_single_partition, zip(urls, output_paths), @@ -210,6 +258,7 @@ def download_and_extract( enforce_metadata=False, input_meta=input_meta, filename_col=filename_col, + record_limit=record_limit, meta=output_format, ) diff --git a/nemo_curator/download/wikipedia.py b/nemo_curator/download/wikipedia.py index 54b85d45..e494ed38 100644 --- a/nemo_curator/download/wikipedia.py +++ b/nemo_curator/download/wikipedia.py @@ -18,6 +18,7 @@ import re import subprocess import xml.etree.cElementTree as etree +from typing import Literal, Optional from urllib.parse import quote, urlparse import mwparserfromhell @@ -742,35 +743,49 @@ def try_remove_obj(obj, section): ) ) # Don't return any meta here - return {}, "\n\n".join(section_text) + return {"text": "\n\n".join(section_text)} def download_wikipedia( output_path: str, language: str = "en", - dump_date=None, - output_type: str = "jsonl", - raw_download_dir=None, - keep_raw_download=False, - force_download=False, - url_limit=None, + dump_date: Optional[str] = None, + output_type: Literal["jsonl", "parquet"] = "jsonl", + raw_download_dir: Optional[str] = None, + keep_raw_download: bool = False, + force_download: bool = False, + url_limit: Optional[int] = None, + record_limit: Optional[int] = None, ) -> DocumentDataset: """ - Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell + Downloads and extracts articles from a Wikipedia dump. + + This function retrieves a list of Wikipedia dump URLs for the specified language and dump date, + downloads the compressed bz2 dump file (if it is not already present), and extracts its articles + using mwparserfromhell. The resulting articles are saved in the specified output format (e.g., "jsonl") + along with relevant metadata. Args: - output_path: The path to the root directory of the files - language: The language of the Wikipedia articles to download - dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use. - If None, latest dump is used. - output_type: The file type to save the data as. - raw_download_dir: Path to store the raw download files for intermediate processing. - If None, they are stored in a folder named "downloads" under output_path. - keep_raw_download: If True, keeps the bz2 files that have not been extracted. - force_download: If False, will skip processing all files in output_paths that already exist and - directly read from them instead. - url_limit: The maximum number of raw files to download from the snapshot. If None, all - files from the range of snapshots are downloaded. + output_path (str): The root directory where the final extracted files and intermediate outputs + (if any) are stored. + language (str, optional): The language code for the Wikipedia dump to download. Default is "en". + dump_date (Optional[str], optional): The dump date in "YYYYMMDD" format. If None, the latest + available dump is downloaded. + output_type (Literal["jsonl", "parquet"], optional): The file format/extension for saving the extracted documents (e.g., "jsonl"). + Defaults to "jsonl". This is not used for the output file, but is used to check if an extracted output + already exists and read it if so. + raw_download_dir (Optional[str], optional): Directory used for temporary storage of raw bz2 dump files. + If None, a subdirectory named "downloads" under output_path is used. + keep_raw_download (bool, optional): If True, retains the raw bz2 files after extraction. 
+ Default is False. + force_download (bool, optional): If False, skips re-downloading or re-extracting files that already exist. + url_limit (Optional[int], optional): The maximum number of dump file URLs to process. If None, all + available URLs are processed. + record_limit (Optional[int], optional): Limit the number of records to extract from each file. + If None, all available records are extracted. + + Returns: + DocumentDataset: A dataset object containing the extracted Wikipedia articles along with associated metadata. """ wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date) if url_limit: @@ -812,6 +827,7 @@ def download_wikipedia( keep_raw_download=keep_raw_download, force_download=force_download, filename_col="file_name", + record_limit=record_limit, ) return dataset diff --git a/nemo_curator/scripts/download_and_extract.py b/nemo_curator/scripts/download_and_extract.py index ef5de89d..aff5fc4d 100644 --- a/nemo_curator/scripts/download_and_extract.py +++ b/nemo_curator/scripts/download_and_extract.py @@ -17,7 +17,7 @@ from nemo_curator.download.doc_builder import batch_download, download_and_extract from nemo_curator.utils.config_utils import build_downloader -from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.distributed_utils import get_client, write_to_disk from nemo_curator.utils.file_utils import ( expand_outdir_and_mkdir, get_all_files_paths_under, @@ -74,10 +74,16 @@ def main(args): keep_raw_download=args.keep_downloaded_files, force_download=args.overwrite_existing_json, input_meta=args.input_meta, + record_limit=args.record_limit, ) # Sample to trigger the dask computation - sample = dataset.df.sample(frac=10 / len(dataset)).compute() + write_to_disk( + dataset.df, + args.output_json_dir, + write_to_filename=True, + output_type="jsonl", + ) def attach_args( @@ -149,6 +155,12 @@ def attach_args( default=None, help="Output directory to store the extracted text in JSONL files.", ) + parser.add_argument( + "--record-limit", + type=int, + default=None, + help="Limit the number of records to extract from each file.", + ) ArgumentHelper.attach_bool_arg( parser, "overwrite-existing-json", diff --git a/tests/test_download.py b/tests/test_download.py index e2a69cb1..b3389e92 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,15 +1,57 @@ -from pathlib import Path +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import bz2 +import os +import subprocess +import tarfile +from urllib.parse import urlparse import pytest from nemo_curator.download import ResiliparseExtractor, download_and_extract +from nemo_curator.download.arxiv import ArxivDownloader, ArxivExtractor, ArxivIterator from nemo_curator.download.commoncrawl import ( CommonCrawlWARCDownloader, CommonCrawlWARCExtractor, CommonCrawlWARCIterator, + JusTextExtractor, + ResiliparseExtractor, get_common_crawl_urls, get_stop_list_dict, ) +from nemo_curator.download.wikipedia import ( + WikipediaDownloader, + WikipediaExtractor, + WikipediaIterator, +) + + +class DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class FakeCompletedProcess: + def __init__(self): + self.returncode = 0 + + +def fake_run_success(cmd, stdout, stderr): + return FakeCompletedProcess() class TestDownload: @@ -82,66 +124,18 @@ def test_resiliparse_extract_text(self): assert result == expected - def test_common_crawl_urls(self): - start_snapshot = "2021-04" - end_snapshot = "2021-10" - urls = get_common_crawl_urls(start_snapshot, end_snapshot) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz" - ) - assert len(urls) == 143840 - def test_incorrect_snapshot_order(self): with pytest.raises(ValueError): end_snapshot = "2021-04" start_snapshot = "2021-10" urls = get_common_crawl_urls(start_snapshot, end_snapshot) - def test_common_crawl_news_urls(self): - start_snapshot = "2021-04" - end_snapshot = "2021-10" - urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/04/CC-NEWS-20210401004522-01022.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/10/CC-NEWS-20211031225258-00089.warc.gz" - ) - assert len(urls) == 3838 - def test_incorrect_snapshot_order_news(self): with pytest.raises(ValueError): end_snapshot = "2021-04" start_snapshot = "2021-10" urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True) - @pytest.mark.skip( - reason="Skipping until we figure out how to get this to a non flaky state" - ) - def test_uneven_common_crawl_range(self): - start_snapshot = "2021-03" - end_snapshot = "2021-11" - urls = get_common_crawl_urls(start_snapshot, end_snapshot) - - assert ( - urls[0] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz" - ) - assert ( - urls[-1] - == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz" - ) - assert len(urls) == 143840 - def test_no_urls(self): with pytest.raises(ValueError): output_format = { @@ -169,3 +163,319 @@ def test_url_path_mismatch(self): CommonCrawlWARCExtractor(), output_format, ) + + +class TestWikipedia: + def test_wikipedia_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary directory and simulate an already-downloaded file. 
+ download_dir = tmp_path / "downloads" + download_dir.mkdir() + + url = "https://en.wikipedia.org/dummy-file" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "dummy-file" + file_path = os.path.join(str(download_dir), output_name) + + # Write a dummy file to simulate an existing download. + with open(file_path, "w") as f: + f.write("existing content") + + downloader = WikipediaDownloader(str(download_dir), verbose=False) + + # Monkey-patch subprocess.run (should not be called since file exists). + monkeypatch.setattr(subprocess, "run", fake_run_success) + + result = downloader.download(url) + assert result == file_path + + def test_wikipedia_downloader_new_file(self, tmp_path, monkeypatch): + download_dir = tmp_path / "downloads" + download_dir.mkdir() + + url = "https://en.wikipedia.org/new-file" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "new-file" + file_path = os.path.join(str(download_dir), output_name) + + # Ensure the file does not exist. + if os.path.exists(file_path): + os.remove(file_path) + + downloader = WikipediaDownloader(str(download_dir), verbose=False) + downloader._lock = DummyLock() + + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + assert called_run + + def test_wikipedia_iterator(self, tmp_path): + # Create a minimal valid XML resembling a Wikipedia dump with one page. + xml_content = """ + + + Test Article + 0 + 123 + + Test content with [[link]] + + +""" + # Compress the XML content using bz2. + compressed_data = bz2.compress(xml_content.encode("utf-8")) + + # Write the compressed data to a temporary file. + temp_file = tmp_path / "test_wiki.xml.bz2" + temp_file.write_bytes(compressed_data) + + iterator = WikipediaIterator(language="en") + pages = list(iterator.iterate(str(temp_file))) + + assert len(pages) == 1 + metadata, raw_text = pages[0] + assert metadata["title"] == "Test Article" + assert metadata["id"] == "123" + # The URL is constructed by quoting the title. + expected_url = "https://en.wikipedia.org/wiki/Test%20Article" + assert metadata["url"] == expected_url + assert "Test content with" in raw_text + + def test_wikipedia_extractor(self): + extractor = WikipediaExtractor(language="en") + # Sample wiki markup; note the presence of a heading and a magic word. + content = "== Heading ==\nThis is a sample article. __NOTOC__" + result = extractor.extract(content) + + # # The extractor should return a dict with a "text" key. + assert isinstance(result, dict) + extracted_text = result.get("text", "") + # Verify that the magic word was removed. + assert "__NOTOC__" not in extracted_text + # Verify that the main content appears. + assert "This is a sample article." in extracted_text + + +class TestArxiv: + def test_arxiv_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary download directory and simulate an already-downloaded tar file. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + tar_filename = "dummy.tar" + file_path = os.path.join(str(download_dir), tar_filename) + # Write dummy content to simulate an existing download. + with open(file_path, "w") as f: + f.write("existing content") + + downloader = ArxivDownloader(str(download_dir), verbose=False) + # Monkey-patch subprocess.run (should not be called since file exists). 
+ monkeypatch.setattr(subprocess, "run", fake_run_success) + result = downloader.download(tar_filename) + assert result == file_path + + def test_arxiv_downloader_new_file(self, tmp_path, monkeypatch): + # Create a temporary download directory and ensure the tar file does not exist. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + tar_filename = "dummy.tar" + file_path = os.path.join(str(download_dir), tar_filename) + if os.path.exists(file_path): + os.remove(file_path) + + downloader = ArxivDownloader(str(download_dir), verbose=False) + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + result = downloader.download(tar_filename) + assert result == file_path + assert called_run + + def test_arxiv_iterator(self, tmp_path): + # Create an inner tar archive containing a .tex file. + inner_tar_path = tmp_path / "2103.00001.tar" + dummy_tex_filename = "2103.00001.tex" + dummy_tex_content = "This is a dummy LaTeX content." + with tarfile.open(inner_tar_path, "w") as inner_tar: + # Create a temporary tex file to add into the inner tar archive. + temp_tex_path = tmp_path / dummy_tex_filename + with open(temp_tex_path, "w") as f: + f.write(dummy_tex_content) + inner_tar.add(temp_tex_path, arcname=dummy_tex_filename) + + # Create an outer tar archive that contains the inner tar archive. + outer_tar_path = tmp_path / "dummy_main.tar" + with tarfile.open(outer_tar_path, "w") as outer_tar: + outer_tar.add(inner_tar_path, arcname="2103.00001.tar") + + iterator = ArxivIterator(log_frequency=1) + results = list(iterator.iterate(str(outer_tar_path))) + # Expect one paper extracted. + assert len(results) == 1 + metadata, tex_files = results[0] + # The ArxivIterator extracts the arxiv id from the inner archive's filename. + assert metadata["id"] == "2103.00001" + # The source_id is set to the outer tar file's basename. + assert metadata["source_id"] == "dummy_main.tar" + # Verify that the tex extraction returns the dummy content. + assert isinstance(tex_files, list) + assert dummy_tex_content in tex_files[0] + + def test_arxiv_extractor(self): + extractor = ArxivExtractor() + # Create a minimal LaTeX document including comments and a section header. + content = r""" + % This is a comment line that should be removed. + \section{Introduction} + This is the introduction of the paper. + % Another comment that should vanish. + """ + result = extractor.extract([content]) + assert isinstance(result, dict) + extracted_text = result.get("text", "") + # Verify that comments have been removed. + assert "% This is a comment" not in extracted_text + # Verify that the section header content is retained. + assert "Introduction" in extracted_text + assert "This is the introduction" in extracted_text + + +class TestCommonCrawl: + def test_common_crawl_downloader_existing_file(self, tmp_path, monkeypatch): + # Create a temporary downloads directory and simulate an already-downloaded file. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + url = "http://dummy/commoncrawl.warc" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc" + file_path = os.path.join(str(download_dir), output_name) + # Write dummy content to simulate an existing download. 
+ with open(file_path, "w") as f: + f.write("existing content") + + downloader = CommonCrawlWARCDownloader( + str(download_dir), aws=False, verbose=False + ) + + # Monkey-patch subprocess.run to track if it gets called. + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + # Since the file already exists, no download should be attempted. + assert not called_run + + def test_common_crawl_downloader_new_file(self, tmp_path, monkeypatch): + # Create a temporary downloads directory; ensure the file does not exist. + download_dir = tmp_path / "downloads" + download_dir.mkdir() + url = "http://dummy/commoncrawl.warc" + parsed = urlparse(url) + output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc" + file_path = os.path.join(str(download_dir), output_name) + if os.path.exists(file_path): + os.remove(file_path) + + downloader = CommonCrawlWARCDownloader( + str(download_dir), aws=False, verbose=False + ) + + called_run = False + + def fake_run(cmd, stdout, stderr): + nonlocal called_run + called_run = True + return FakeCompletedProcess() + + monkeypatch.setattr(subprocess, "run", fake_run) + + result = downloader.download(url) + assert result == file_path + # Since the file did not exist, a download call (and subprocess.run) should have been made. + assert called_run + + def test_common_crawl_iterator(self, tmp_path): + # Create a minimal valid WARC file with a single "response" record. + raw_warc_path = tmp_path / "dummy.warc" + http_response = ( + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/html\r\n" + "\r\n" + "

<html><body><p>Common Crawl test paragraph with some content.</p></body></html>
\r\n" + ) + http_response_bytes = http_response.encode("utf-8") + content_length = len(http_response_bytes) + warc_record = ( + ( + f"WARC/1.0\r\n" + f"WARC-Type: response\r\n" + f"WARC-Record-ID: \r\n" + f"WARC-Date: 2022-01-01T00:00:00Z\r\n" + f"WARC-Target-URI: http://example.com\r\n" + f"Content-Length: {content_length}\r\n" + f"\r\n" + ).encode("utf-8") + + http_response_bytes + + b"\r\n\r\n" + ) + raw_warc_path.write_bytes(warc_record) + + iterator = CommonCrawlWARCIterator(log_frequency=1) + records = list(iterator.iterate(str(raw_warc_path))) + assert len(records) == 1 + meta, content = records[0] + # Check that the URL from the header is captured. + assert "example.com" in meta["url"] + # Verify that the content includes our test paragraph. + assert b"Common Crawl test paragraph" in content + + def test_common_crawl_extractor_justext(self): + extractor = CommonCrawlWARCExtractor(algorithm=JusTextExtractor()) + html = ( + "

Common Crawl test paragraph for justext extractor. " + "Four score and seven years ago our fathers brought forth on this continent a new nation, " + "conceived in liberty, and dedicated to the proposition that all men are created equal.

" + ) + content = html.encode("utf-8") + result = extractor.extract(content) + print(result) + assert result is not None + # The extracted text should include our test paragraph. + assert "Common Crawl test paragraph for justext extractor." in result["text"] + assert "language" in result + + def test_common_crawl_extractor_resiliparse(self): + extractor = CommonCrawlWARCExtractor(algorithm=ResiliparseExtractor()) + html = ( + "

Common Crawl test paragraph for resiliparse extractor. " + "Four score and seven years ago our fathers brought forth on this continent a new nation, " + "conceived in liberty, and dedicated to the proposition that all men are created equal.

" + ) + content = html.encode("utf-8") + result = extractor.extract(content) + print(result) + assert result is not None + assert ( + "Common Crawl test paragraph for resiliparse extractor." in result["text"] + ) + assert "language" in result