diff --git a/docs/user-guide/download.rst b/docs/user-guide/download.rst
index d4b854e4..d2b50e0c 100644
--- a/docs/user-guide/download.rst
+++ b/docs/user-guide/download.rst
@@ -37,10 +37,42 @@ By "extraction", we typically mean the process of converting a data format from
Otherwise, the HTTPS endpoints will be used with ``wget``. Here is a small example of how to use it:
.. code-block:: python
-
+ import os
+ from nemo_curator import get_client
from nemo_curator.download import download_common_crawl
-
- common_crawl = download_common_crawl("/extracted/output/folder", "2020-50", "2021-04", output_type="jsonl")
+ from nemo_curator.datasets import DocumentDataset
+
+ def main():
+ # Initialize a distributed Dask client
+ client = get_client(cluster_type="cpu")
+
+ # Parameters for downloading Common Crawl data.
+ # - output_folder: directory for temporary download/extraction files
+ # - start_snapshot and end_snapshot define the range to fetch
+ # - output_type: specifies file format for the extracted data (e.g., "jsonl")
+ output_folder = "/extracted/output/folder"
+ start_snapshot = "2020-50"
+ end_snapshot = "2021-04"
+ output_type = "jsonl"
+ os.makedirs(output_folder, exist_ok=True)
+
+ # Download and extract the Common Crawl data.
+ # The function returns a DocumentDataset that contains the extracted documents.
+ # Note: The output folder and output type are passed here to store intermediate files
+ # and check if the data has already been downloaded. They should match the final location
+ # and format of the extracted data.
+ common_crawl_dataset = download_common_crawl(
+ output_folder, start_snapshot, end_snapshot, output_type=output_type
+ )
+
+ # Write the extracted dataset to JSON format.
+ # The 'to_json' method will write one JSON document per line,
+ # preserving the original shard information if write_to_filename is True.
+ common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True)
+ print("Extracted dataset saved to:", output_folder)
+
+ if __name__ == "__main__":
+ main()
* ``"/extracted/output/folder"`` is the path to on your local filesystem where the final extracted files will be placed.
* ``"2020-50"`` is the first common crawl snapshot that will be included in the download. **Note:** Not every year and week has a snapshot. Ensure that your range includes at least one valid Common Crawl snapshot. A list of valid Common Crawl snapshots can be found `here `_.
@@ -50,21 +82,49 @@ By "extraction", we typically mean the process of converting a data format from
You can choose to modify the HTML text extraction algorithm used in ``download_common_crawl``. See an example below.
.. code-block:: python
-
+ import os
+ from nemo_curator import get_client
from nemo_curator.download import (
- ResiliparseExtractor,
- download_common_crawl,
- )
-
- # Change the extraction algorithm
- extraction_algorithm = ResiliparseExtractor()
- common_crawl = download_common_crawl(
- "/extracted/output/folder",
- "2020-50",
- "2021-04",
- output_type="jsonl",
- algorithm=extraction_algorithm,
+ ResiliparseExtractor,
+ download_common_crawl,
)
+ from nemo_curator.datasets import DocumentDataset
+
+ def main():
+ # Initialize a distributed Dask client
+ client = get_client(cluster_type="cpu")
+
+ # Parameters for downloading Common Crawl data.
+ # - output_folder: directory for temporary download/extraction files
+ # - start_snapshot and end_snapshot define the range to fetch
+ # - output_type: specifies file format for the extracted data (e.g., "jsonl")
+ output_folder = "/extracted/output/folder"
+ start_snapshot = "2020-50"
+ end_snapshot = "2021-04"
+ output_type = "jsonl"
+ os.makedirs(output_folder, exist_ok=True)
+
+ # Change the extraction algorithm to use ResiliparseExtractor
+ extraction_algorithm = ResiliparseExtractor()
+
+ # Download and extract the Common Crawl data using the Resiliparse extraction algorithm.
+ # The function returns a DocumentDataset that contains the extracted documents.
+ common_crawl_dataset = download_common_crawl(
+ output_folder,
+ start_snapshot,
+ end_snapshot,
+ output_type=output_type,
+ algorithm=extraction_algorithm,
+ )
+
+ # Write the extracted dataset to JSON format.
+ # The 'to_json' method writes one JSON document per line,
+ # preserving the original shard information if write_to_filename is True.
+ common_crawl_dataset.to_json(output_path=output_folder, write_to_filename=True)
+ print("Extracted dataset saved to:", output_folder)
+
+ if __name__ == "__main__":
+ main()
Above, we changed the extraction algorithm from the default ``JusTextExtractor``.
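+
+After extraction, the JSONL files written by ``to_json`` can be reloaded for downstream curation steps.
+The snippet below is a minimal sketch (the path is a placeholder) that assumes the files were written as in
+the examples above:
+
+.. code-block:: python
+
+    from nemo_curator.datasets import DocumentDataset
+    from nemo_curator.utils.file_utils import get_all_files_paths_under
+
+    # Collect only the JSONL files produced by the extraction step,
+    # skipping any raw files left in the "downloads" subdirectory.
+    files = [
+        f for f in get_all_files_paths_under("/extracted/output/folder")
+        if f.endswith(".jsonl")
+    ]
+
+    # Reload them as a DocumentDataset for further curation.
+    dataset = DocumentDataset.read_json(files)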
diff --git a/nemo_curator/download/arxiv.py b/nemo_curator/download/arxiv.py
index 538d567d..e7e5327a 100644
--- a/nemo_curator/download/arxiv.py
+++ b/nemo_curator/download/arxiv.py
@@ -18,6 +18,7 @@
import subprocess
import tarfile
import tempfile
+from typing import Literal, Optional
from nemo_curator.datasets import DocumentDataset
from nemo_curator.download.doc_builder import (
@@ -218,12 +219,12 @@ def extract(self, content):
for file_content in content
)
except Exception:
- return {}, None
+ return None
# Don't return meta
if cleaned_latex_file_str is not None:
if len(cleaned_latex_file_str) > 0:
- return {}, cleaned_latex_file_str
+ return {"text": cleaned_latex_file_str}
def _clean_tex_file(self, file_content, arg_macros, non_arg_macros):
r"""function takes a tex file as input and returns a cleaned version. The
@@ -365,25 +366,44 @@ def _build_non_arg_macros_dict(self, file_content):
def download_arxiv(
output_path: str,
- output_type: str = "jsonl",
- raw_download_dir=None,
- keep_raw_download=False,
- force_download=False,
- url_limit=None,
+ output_type: Literal["jsonl", "parquet"] = "jsonl",
+ raw_download_dir: Optional[str] = None,
+ keep_raw_download: bool = False,
+ force_download: bool = False,
+ url_limit: Optional[int] = None,
+ record_limit: Optional[int] = None,
) -> DocumentDataset:
"""
- Downloads Arxiv tar files and extracts them
+ Download Arxiv tar files and extract the contained LaTeX projects.
+
+ This function obtains a list of Arxiv tar file URLs (via get_arxiv_urls), downloads the tar files,
+ and then extracts the contained LaTeX source files. The resulting documents (after extraction) are
+ assembled into a DocumentDataset.
Args:
- output_path: The path to the root directory of the files
- output_type: The file type to save the data as.
- raw_download_dir: Path to store the raw download files for intermediate processing.
- If None, they are stored in a folder named "downloads" under output_path.
- keep_raw_download: If True, keeps the compressed WARC files that have not been extracted.
- force_download: If False, will skip processing all files in output_paths that already exist and
- directly read from them instead.
- url_limit: The maximum number of raw files to download from the snapshot. If None, all
- files from the range of snapshots are downloaded.
+        output_path (str):
+            The root directory used for managing download and extraction. Raw tar files are stored in a
+            "downloads" subdirectory under this path (unless raw_download_dir is given), and this path is
+            checked for previously extracted outputs, which are read instead of being re-extracted.
+            Note that the function returns a DocumentDataset; writing it to disk is the caller's responsibility.
+        output_type (Literal["jsonl", "parquet"], optional):
+            The file format/extension of the extracted documents (e.g., "jsonl" or "parquet"). Default is "jsonl".
+            Because writing the output is handled by the caller, this value is only used to check whether an
+            extracted output already exists (and to read it if so).
+ raw_download_dir (Optional[str], optional):
+ The directory where the raw downloaded tar files will be kept. If None, a folder named "downloads"
+ under output_path is used.
+ keep_raw_download (bool, optional):
+ If True, the raw tar files (before extraction) are not removed after processing. Default is False.
+ force_download (bool, optional):
+ If False, then if an output file already exists for a given URL, re-downloading and re-extraction will be skipped.
+ Default is False.
+ url_limit (Optional[int], optional):
+ Limits the maximum number of Arxiv tar file URLs to download and process.
+ If None, all available URLs (from get_arxiv_urls) are processed.
+ record_limit (Optional[int], optional):
+ Limits the maximum number of records to extract from each tar file.
+ If None, all available records are extracted.
+ Returns:
+ DocumentDataset:
+ A dataset object containing the extracted documents.
"""
arxiv_urls = get_arxiv_urls()
if url_limit:
@@ -416,6 +436,7 @@ def download_arxiv(
keep_raw_download=keep_raw_download,
force_download=force_download,
filename_col="file_name",
+ record_limit=record_limit,
)
return dataset
diff --git a/nemo_curator/download/commoncrawl.py b/nemo_curator/download/commoncrawl.py
index de7e333b..c651d4d6 100644
--- a/nemo_curator/download/commoncrawl.py
+++ b/nemo_curator/download/commoncrawl.py
@@ -17,6 +17,7 @@
import subprocess
import unicodedata
from abc import ABC, abstractmethod
+from typing import Literal, Optional
from urllib.parse import urlparse
import justext
@@ -352,48 +353,54 @@ def extract(self, content):
if text is not None:
if len(text) > 0:
text = "\n\n".join(text)
- meta = {"language": lang}
- return meta, text
+ meta = {"language": lang, "text": text}
+ return meta
else:
- return None, None
+ return None
def download_common_crawl(
output_path: str,
start_snapshot: str,
end_snapshot: str,
- output_type: str = "jsonl",
+ output_type: Literal["jsonl", "parquet"] = "jsonl",
algorithm=JusTextExtractor(),
- news=False,
- aws=False,
- raw_download_dir=None,
- keep_raw_download=False,
- force_download=False,
- url_limit=None,
+ news: bool = False,
+ aws: bool = False,
+ raw_download_dir: Optional[str] = None,
+ keep_raw_download: bool = False,
+ force_download: bool = False,
+ url_limit: Optional[int] = None,
+ record_limit: Optional[int] = None,
) -> DocumentDataset:
"""
- Downloads Common Crawl WARC snapshots and extracts them using jusText or Resiliparse
+ Downloads Common Crawl WARC snapshots and extracts text content using a specified extraction algorithm.
Args:
- output_path: The path to the root directory of the files
- start_snapshot: The first common crawl snapshot to include. Snapshots must be
- specified by YYYY-WeekNumber (e.g., '2020-50' or '2021-04'). For the CC-NEWS dataset,
- (specified with news=True flag) this changes to Year-Month (YYYY-MM).
- end_snapshot: The last common crawl snapshot to include. Must be chronologically
- after the starting snapshot.
- output_type: The file type to save the data as.
- algorithm: A JusTextExtractor or ResiliparseExtractor object.
- news: If True, gets WARC URLs for the CC-NEWS dataset instead of the CC-MAIN datasets.
- Also assumes that the format for the start and end snapshots is 'YYYY-MM' (Year-Month).
- aws: Whether to download from Common Crawl's S3 bucket. If True, uses s5cmd to download.
- If False, uses wget.
- raw_download_dir: Path to store the raw download files for intermediate processing.
- If None, they are stored in a folder named "downloads" under output_path.
- keep_raw_download: If True, keeps the compressed WARC files that have not been extracted.
- force_download: If False, will skip processing all files in output_paths that already exist and
- directly read from them instead.
- url_limit: The maximum number of raw files to download from the snapshot. If None, all
- files from the range of snapshots are downloaded.
+ output_path (str): The root directory used for managing download and extraction.
+ • Raw WARC files are stored in a "downloads" subdirectory under this path.
+ • This path is also checked for existing extraction results; if found, extraction can be skipped.
+ • Note: This function returns a DocumentDataset, and writing the extracted data to disk is the caller's responsibility.
+ start_snapshot (str): Identifier for the earliest snapshot to process.
+ • For CC-MAIN datasets, use the format 'YYYY-WeekNumber' (e.g., '2020-50' or '2021-04').
+ • For CC-NEWS datasets (when news=True), use the 'YYYY-MM' (Year-Month) format.
+ end_snapshot (str): Identifier for the latest snapshot to process, which must be chronologically after start_snapshot.
+        output_type (Literal["jsonl", "parquet"]): The file format of the extracted output. Must be either "jsonl" or "parquet".
+          • Because writing the output is the caller's responsibility, this value is only used to check whether an extracted output already exists at output_path (and to read it if so).
+ algorithm: The text extraction algorithm instance (e.g., JusTextExtractor or ResiliparseExtractor) to use for HTML processing.
+ news (bool): When True, indicates that URLs should be retrieved from the CC-NEWS dataset.
+ • This also means snapshot identifiers should follow the 'YYYY-MM' format.
+ aws (bool): If True, downloads are sourced from Common Crawl's S3 bucket using s5cmd;
+ • If False, wget is used to fetch the files via HTTPS.
+ raw_download_dir: Optional; the directory to temporarily store raw WARC files.
+ • If not provided, defaults to a "downloads" folder within output_path.
+ keep_raw_download (bool): If True, retains the downloaded raw WARC files after extraction.
+ • If False, these raw files may be removed following extraction.
+ force_download (bool): If False, skips re-downloading or re-extracting snapshots if outputs already exist in output_path.
+ url_limit: Optional; the maximum number of WARC files to download from the snapshot range.
+ • If None, all available files within the specified snapshots are downloaded.
+ record_limit: Optional; the maximum number of records to extract from each WARC file.
+ • If None, all available records are extracted.
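+
+    Returns:
+        DocumentDataset: A dataset object containing the extracted documents.
+
+    Example (illustrative sketch; the output path and snapshot range are placeholders):
+        >>> dataset = download_common_crawl(
+        ...     "/extracted/output/folder", "2020-50", "2021-04", output_type="jsonl"
+        ... )
+        >>> # Writing the extracted data to disk is the caller's responsibility:
+        >>> dataset.to_json(output_path="/extracted/output/folder", write_to_filename=True)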
"""
common_crawl_urls = get_common_crawl_urls(
starting_snapshot=start_snapshot, ending_snapshot=end_snapshot, news=news
@@ -443,6 +450,7 @@ def download_common_crawl(
keep_raw_download=keep_raw_download,
force_download=force_download,
filename_col="file_name",
+ record_limit=record_limit,
)
return dataset
diff --git a/nemo_curator/download/doc_builder.py b/nemo_curator/download/doc_builder.py
index 5ad3094e..9d3e849b 100644
--- a/nemo_curator/download/doc_builder.py
+++ b/nemo_curator/download/doc_builder.py
@@ -15,17 +15,14 @@
import importlib
import os
from abc import ABC, abstractmethod
-from typing import List, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import dask.dataframe as dd
import pandas as pd
from dask import compute, delayed
from nemo_curator.datasets import DocumentDataset
-from nemo_curator.utils.distributed_utils import (
- read_single_partition,
- single_partition_write_with_filename,
-)
+from nemo_curator.utils.distributed_utils import read_single_partition
class DocumentDownloader(ABC):
@@ -108,14 +105,36 @@ def _download_and_extract_single_partition(
downloader: DocumentDownloader,
iterator: DocumentIterator,
extractor: DocumentExtractor,
- output_type: str,
+ output_type: Literal["jsonl", "parquet"],
keep_raw_download: bool,
force_download: bool,
input_meta: Union[str, dict] = None,
filename_col: str = "file_name",
+ record_limit: Optional[int] = None,
) -> pd.DataFrame:
+ """
+ Downloads a single partition from a URL and extracts its contents in-memory without writing
+ the extracted output to disk. The provided output_path is used only to check if an extracted
+ output already exists and, if so, skips re-downloading and extraction.
+
+    Args:
+ paths (Tuple[str, str]): A tuple (url, output_path) where 'url' is the source URL and
+ 'output_path' is the expected location of a previously extracted output.
+ downloader (DocumentDownloader): An object to download the content from the URL.
+ iterator (DocumentIterator): An object to iterate over records in the downloaded file.
+ extractor (DocumentExtractor): An object to extract the desired content from each record.
+        output_type (Literal["jsonl", "parquet"]): The file format/extension of an existing extracted output,
+            either "jsonl" or "parquet". This parameter is only used to read a previously extracted output at
+            output_path when one exists and force_download is False.
+ keep_raw_download (bool): If False, deletes the raw download file after extraction.
+ force_download (bool): If False and output_path exists, skips downloading and extraction.
+ input_meta (Union[str, dict], optional): Metadata describing the input file's structure.
+ filename_col (str, optional): Name of the column to store the filename within the result DataFrame.
+ record_limit (int, optional): Limit the number of records to extract from each file.
+ Returns:
+ pd.DataFrame: A DataFrame containing the extracted records.
+ """
url, output_path = paths
+ # If an extracted output already exists and we're not forcing a download, load and return it.
if os.path.exists(output_path) and not force_download:
partition = read_single_partition(
[output_path],
@@ -125,31 +144,26 @@ def _download_and_extract_single_partition(
)
return partition
+ # Download the file and extract its records in memory.
downloaded_file = downloader.download(url)
records = []
- # Iterate over all records in file
for item in iterator.iterate(downloaded_file):
+ if record_limit is not None and len(records) >= record_limit:
+ break
record_meta, content = item
- # Extract the text from the record
extracted = extractor.extract(content)
if extracted is not None:
- text_meta, text = extracted
- if text is not None:
- line = {
- "text": text,
- **text_meta,
- **record_meta,
- }
- records.append(line)
+ # Merge the extracted data and record metadata into one dictionary.
+ line = {**extracted, **record_meta}
+ records.append(line)
partition = pd.DataFrame(records)
- filename = os.path.basename(output_path)
- output_dir = os.path.dirname(output_path)
- partition[filename_col] = filename
- single_partition_write_with_filename(
- partition, output_dir, output_type=output_type, filename_col=filename_col
- )
- if not keep_raw_download:
+ # Add a filename column for consistency using the basename of the output_path.
+ partition[filename_col] = os.path.basename(output_path)
+
+        # The extracted partition is not written to disk here; output_path was only used above for the
+        # existence check and the filename column.
+ # Clean up the raw downloaded file if it's not meant to be kept.
+ if not keep_raw_download and os.path.exists(downloaded_file):
os.remove(downloaded_file)
return partition
@@ -162,42 +176,76 @@ def download_and_extract(
iterator: DocumentIterator,
extractor: DocumentExtractor,
output_format: dict,
- output_type: str = "jsonl",
- keep_raw_download=False,
- force_download=False,
+ output_type: Literal["jsonl", "parquet"] = "jsonl",
+ keep_raw_download: bool = False,
+ force_download: bool = False,
input_meta: Union[str, dict] = None,
filename_col: str = "file_name",
+ record_limit: Optional[int] = None,
) -> DocumentDataset:
"""
- Downloads and extracts a dataset into a format accepted by the NeMo Curator
+ Download files from the given URLs, extract their records, and
+ construct a DocumentDataset.
+
+ For each URL provided, this function downloads the corresponding
+ file (unless an extracted output already exists and force_download is
+ False), iterates over its records, extracts the desired content, and
+ finally converts all records into a DocumentDataset.
Args:
- urls: A list of urls to download the dataset from
- output_paths: A list of paths to save the final extracted output to.
- The raw output of the downloader will be saved using the path given by downloader.download(url).
- downloader: A DocumentDownloader that handles retrieving each file from its url and saving it to storage
- iterator: A DocumentIterator that handles iterating through the downloaded file's format
- extractor: A DocumentExtractor that handles extracting the data from its raw format into text
- output_format: A dictionary mappings columns to datatypes for the fields of each datapoint after extraction.
- output_type: The file type to save the dataset as.
- keep_raw_download: Whether to keep the pre-extracted download file.
- force_download: If False, will skip processing all files in output_paths that already exist and
- directly read from them instead.
- input_meta: A dictionary or a string formatted as a dictionary, which outlines
- the field names and their respective data types within the JSONL input file.
- filename_col : The name of the column that contains the filename. Default is "filename_col"
+ urls (List[str]):
+ A list of URLs from which to download dataset files.
+ output_paths (List[str]):
+ A list of file paths where the extracted outputs should be
+ found. If a file already exists at a given path and force_download
+ is False, that partition is skipped.
+ downloader (DocumentDownloader):
+ The downloader instance responsible for fetching files from
+ the specified URLs.
+ iterator (DocumentIterator):
+ The iterator instance used to traverse the downloaded file
+ and yield records.
+ extractor (DocumentExtractor):
+ The extractor instance used to obtain the desired content from
+ each record.
+ output_format (dict):
+ A dictionary mapping column names to the data types for the
+ extracted records.
+ output_type (Literal["jsonl", "parquet"], optional):
+ The output file format/extension. Must be either "jsonl" or "parquet".
+ Defaults to "jsonl". This parameter is only used to verify whether
+ an extracted output already exists.
+ keep_raw_download (bool, optional):
+ If True, the raw downloaded files are retained after extraction.
+ Defaults to False.
+ force_download (bool, optional):
+ If False and an output file already exists at a given path, the
+ download and extraction for that file are skipped.
+ Defaults to False.
+ input_meta (Union[str, dict], optional):
+ Optional metadata describing the input file's schema.
+ Defaults to None.
+ filename_col (str, optional):
+ The name for the column in the resulting dataset that records
+ the basename of the output file. Defaults to "file_name".
+ record_limit (int, optional): Limit the number of records to extract from each file.
+ Defaults to None.
Returns:
- A DocumentDataset of the downloaded data
+ DocumentDataset:
+ A dataset composed of the records extracted from the downloaded
+ files.
"""
- if len(urls) == 0:
- raise ValueError("No urls were provided to download")
-
+ # Validate parameters
+ if not urls:
+ raise ValueError("No URLs were provided to download")
if len(urls) != len(output_paths):
raise ValueError(
- f"Different number of urls and output_paths. {len(urls)} urls vs {len(output_paths)} output_paths"
+ f"Different number of URLs and output_paths. {len(urls)} URLs vs {len(output_paths)} output_paths"
)
+ # Ensure consistent ordering of output_format keys.
output_format = dict(sorted(output_format.items()))
+
df = dd.from_map(
_download_and_extract_single_partition,
zip(urls, output_paths),
@@ -210,6 +258,7 @@ def download_and_extract(
enforce_metadata=False,
input_meta=input_meta,
filename_col=filename_col,
+ record_limit=record_limit,
meta=output_format,
)
diff --git a/nemo_curator/download/wikipedia.py b/nemo_curator/download/wikipedia.py
index 54b85d45..e494ed38 100644
--- a/nemo_curator/download/wikipedia.py
+++ b/nemo_curator/download/wikipedia.py
@@ -18,6 +18,7 @@
import re
import subprocess
import xml.etree.cElementTree as etree
+from typing import Literal, Optional
from urllib.parse import quote, urlparse
import mwparserfromhell
@@ -742,35 +743,49 @@ def try_remove_obj(obj, section):
)
)
# Don't return any meta here
- return {}, "\n\n".join(section_text)
+ return {"text": "\n\n".join(section_text)}
def download_wikipedia(
output_path: str,
language: str = "en",
- dump_date=None,
- output_type: str = "jsonl",
- raw_download_dir=None,
- keep_raw_download=False,
- force_download=False,
- url_limit=None,
+ dump_date: Optional[str] = None,
+ output_type: Literal["jsonl", "parquet"] = "jsonl",
+ raw_download_dir: Optional[str] = None,
+ keep_raw_download: bool = False,
+ force_download: bool = False,
+ url_limit: Optional[int] = None,
+ record_limit: Optional[int] = None,
) -> DocumentDataset:
"""
- Downloads the latest Wikipedia dumps and extracts them using mwparserfromhell
+ Downloads and extracts articles from a Wikipedia dump.
+
+ This function retrieves a list of Wikipedia dump URLs for the specified language and dump date,
+ downloads the compressed bz2 dump file (if it is not already present), and extracts its articles
+    using mwparserfromhell. The extracted articles, along with relevant metadata, are assembled into a
+    DocumentDataset; writing them to disk (e.g., as "jsonl") is the caller's responsibility.
Args:
- output_path: The path to the root directory of the files
- language: The language of the Wikipedia articles to download
- dump_date: A string formatted as "YYYYMMDD" for the wikipedia dump to use.
- If None, latest dump is used.
- output_type: The file type to save the data as.
- raw_download_dir: Path to store the raw download files for intermediate processing.
- If None, they are stored in a folder named "downloads" under output_path.
- keep_raw_download: If True, keeps the bz2 files that have not been extracted.
- force_download: If False, will skip processing all files in output_paths that already exist and
- directly read from them instead.
- url_limit: The maximum number of raw files to download from the snapshot. If None, all
- files from the range of snapshots are downloaded.
+ output_path (str): The root directory where the final extracted files and intermediate outputs
+ (if any) are stored.
+ language (str, optional): The language code for the Wikipedia dump to download. Default is "en".
+ dump_date (Optional[str], optional): The dump date in "YYYYMMDD" format. If None, the latest
+ available dump is downloaded.
+        output_type (Literal["jsonl", "parquet"], optional): The file format/extension of the extracted documents (e.g., "jsonl").
+            Defaults to "jsonl". Because writing the output is handled by the caller, this is only used to check
+            whether an extracted output already exists (and to read it if so).
+ raw_download_dir (Optional[str], optional): Directory used for temporary storage of raw bz2 dump files.
+ If None, a subdirectory named "downloads" under output_path is used.
+ keep_raw_download (bool, optional): If True, retains the raw bz2 files after extraction.
+ Default is False.
+ force_download (bool, optional): If False, skips re-downloading or re-extracting files that already exist.
+ url_limit (Optional[int], optional): The maximum number of dump file URLs to process. If None, all
+ available URLs are processed.
+ record_limit (Optional[int], optional): Limit the number of records to extract from each file.
+ If None, all available records are extracted.
+
+ Returns:
+ DocumentDataset: A dataset object containing the extracted Wikipedia articles along with associated metadata.
"""
wikipedia_urls = get_wikipedia_urls(language=language, dump_date=dump_date)
if url_limit:
@@ -812,6 +827,7 @@ def download_wikipedia(
keep_raw_download=keep_raw_download,
force_download=force_download,
filename_col="file_name",
+ record_limit=record_limit,
)
return dataset
diff --git a/nemo_curator/scripts/download_and_extract.py b/nemo_curator/scripts/download_and_extract.py
index ef5de89d..aff5fc4d 100644
--- a/nemo_curator/scripts/download_and_extract.py
+++ b/nemo_curator/scripts/download_and_extract.py
@@ -17,7 +17,7 @@
from nemo_curator.download.doc_builder import batch_download, download_and_extract
from nemo_curator.utils.config_utils import build_downloader
-from nemo_curator.utils.distributed_utils import get_client
+from nemo_curator.utils.distributed_utils import get_client, write_to_disk
from nemo_curator.utils.file_utils import (
expand_outdir_and_mkdir,
get_all_files_paths_under,
@@ -74,10 +74,16 @@ def main(args):
keep_raw_download=args.keep_downloaded_files,
force_download=args.overwrite_existing_json,
input_meta=args.input_meta,
+ record_limit=args.record_limit,
)
-    # Sample to trigger the dask computation
-    sample = dataset.df.sample(frac=10 / len(dataset)).compute()
+    # Write the extracted text to disk as JSONL files (this triggers the Dask computation).
+    write_to_disk(
+ dataset.df,
+ args.output_json_dir,
+ write_to_filename=True,
+ output_type="jsonl",
+ )
def attach_args(
@@ -149,6 +155,12 @@ def attach_args(
default=None,
help="Output directory to store the extracted text in JSONL files.",
)
+ parser.add_argument(
+ "--record-limit",
+ type=int,
+ default=None,
+ help="Limit the number of records to extract from each file.",
+ )
ArgumentHelper.attach_bool_arg(
parser,
"overwrite-existing-json",
diff --git a/tests/test_download.py b/tests/test_download.py
index e2a69cb1..b3389e92 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -1,15 +1,57 @@
-from pathlib import Path
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import bz2
+import os
+import subprocess
+import tarfile
+from urllib.parse import urlparse
import pytest
from nemo_curator.download import ResiliparseExtractor, download_and_extract
+from nemo_curator.download.arxiv import ArxivDownloader, ArxivExtractor, ArxivIterator
from nemo_curator.download.commoncrawl import (
CommonCrawlWARCDownloader,
CommonCrawlWARCExtractor,
CommonCrawlWARCIterator,
+ JusTextExtractor,
+ ResiliparseExtractor,
get_common_crawl_urls,
get_stop_list_dict,
)
+from nemo_curator.download.wikipedia import (
+ WikipediaDownloader,
+ WikipediaExtractor,
+ WikipediaIterator,
+)
+
+
+class DummyLock:
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FakeCompletedProcess:
+ def __init__(self):
+ self.returncode = 0
+
+
+def fake_run_success(cmd, stdout, stderr):
+ return FakeCompletedProcess()
class TestDownload:
@@ -82,66 +124,18 @@ def test_resiliparse_extract_text(self):
assert result == expected
- def test_common_crawl_urls(self):
- start_snapshot = "2021-04"
- end_snapshot = "2021-10"
- urls = get_common_crawl_urls(start_snapshot, end_snapshot)
-
- assert (
- urls[0]
- == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz"
- )
- assert (
- urls[-1]
- == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz"
- )
- assert len(urls) == 143840
-
def test_incorrect_snapshot_order(self):
with pytest.raises(ValueError):
end_snapshot = "2021-04"
start_snapshot = "2021-10"
urls = get_common_crawl_urls(start_snapshot, end_snapshot)
- def test_common_crawl_news_urls(self):
- start_snapshot = "2021-04"
- end_snapshot = "2021-10"
- urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True)
-
- assert (
- urls[0]
- == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/04/CC-NEWS-20210401004522-01022.warc.gz"
- )
- assert (
- urls[-1]
- == "https://data.commoncrawl.org/crawl-data/CC-NEWS/2021/10/CC-NEWS-20211031225258-00089.warc.gz"
- )
- assert len(urls) == 3838
-
def test_incorrect_snapshot_order_news(self):
with pytest.raises(ValueError):
end_snapshot = "2021-04"
start_snapshot = "2021-10"
urls = get_common_crawl_urls(start_snapshot, end_snapshot, news=True)
- @pytest.mark.skip(
- reason="Skipping until we figure out how to get this to a non flaky state"
- )
- def test_uneven_common_crawl_range(self):
- start_snapshot = "2021-03"
- end_snapshot = "2021-11"
- urls = get_common_crawl_urls(start_snapshot, end_snapshot)
-
- assert (
- urls[0]
- == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-10/segments/1614178347293.1/warc/CC-MAIN-20210224165708-20210224195708-00000.warc.gz"
- )
- assert (
- urls[-1]
- == "https://data.commoncrawl.org/crawl-data/CC-MAIN-2021-04/segments/1610704847953.98/warc/CC-MAIN-20210128134124-20210128164124-00799.warc.gz"
- )
- assert len(urls) == 143840
-
def test_no_urls(self):
with pytest.raises(ValueError):
output_format = {
@@ -169,3 +163,319 @@ def test_url_path_mismatch(self):
CommonCrawlWARCExtractor(),
output_format,
)
+
+
+class TestWikipedia:
+ def test_wikipedia_downloader_existing_file(self, tmp_path, monkeypatch):
+ # Create a temporary directory and simulate an already-downloaded file.
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+
+ url = "https://en.wikipedia.org/dummy-file"
+ parsed = urlparse(url)
+ output_name = parsed.path[1:].replace("/", "-") # "dummy-file"
+ file_path = os.path.join(str(download_dir), output_name)
+
+ # Write a dummy file to simulate an existing download.
+ with open(file_path, "w") as f:
+ f.write("existing content")
+
+ downloader = WikipediaDownloader(str(download_dir), verbose=False)
+
+ # Monkey-patch subprocess.run (should not be called since file exists).
+ monkeypatch.setattr(subprocess, "run", fake_run_success)
+
+ result = downloader.download(url)
+ assert result == file_path
+
+ def test_wikipedia_downloader_new_file(self, tmp_path, monkeypatch):
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+
+ url = "https://en.wikipedia.org/new-file"
+ parsed = urlparse(url)
+ output_name = parsed.path[1:].replace("/", "-") # "new-file"
+ file_path = os.path.join(str(download_dir), output_name)
+
+ # Ensure the file does not exist.
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ downloader = WikipediaDownloader(str(download_dir), verbose=False)
+ downloader._lock = DummyLock()
+
+ called_run = False
+
+ def fake_run(cmd, stdout, stderr):
+ nonlocal called_run
+ called_run = True
+
+ return FakeCompletedProcess()
+
+ monkeypatch.setattr(subprocess, "run", fake_run)
+
+ result = downloader.download(url)
+ assert result == file_path
+ assert called_run
+
+ def test_wikipedia_iterator(self, tmp_path):
+ # Create a minimal valid XML resembling a Wikipedia dump with one page.
+ xml_content = """
+
+
+ Test Article
+ 0
+ 123
+
+ Test content with [[link]]
+
+
+"""
+ # Compress the XML content using bz2.
+ compressed_data = bz2.compress(xml_content.encode("utf-8"))
+
+ # Write the compressed data to a temporary file.
+ temp_file = tmp_path / "test_wiki.xml.bz2"
+ temp_file.write_bytes(compressed_data)
+
+ iterator = WikipediaIterator(language="en")
+ pages = list(iterator.iterate(str(temp_file)))
+
+ assert len(pages) == 1
+ metadata, raw_text = pages[0]
+ assert metadata["title"] == "Test Article"
+ assert metadata["id"] == "123"
+ # The URL is constructed by quoting the title.
+ expected_url = "https://en.wikipedia.org/wiki/Test%20Article"
+ assert metadata["url"] == expected_url
+ assert "Test content with" in raw_text
+
+ def test_wikipedia_extractor(self):
+ extractor = WikipediaExtractor(language="en")
+ # Sample wiki markup; note the presence of a heading and a magic word.
+ content = "== Heading ==\nThis is a sample article. __NOTOC__"
+ result = extractor.extract(content)
+
+        # The extractor should return a dict with a "text" key.
+ assert isinstance(result, dict)
+ extracted_text = result.get("text", "")
+ # Verify that the magic word was removed.
+ assert "__NOTOC__" not in extracted_text
+ # Verify that the main content appears.
+ assert "This is a sample article." in extracted_text
+
+
+class TestArxiv:
+ def test_arxiv_downloader_existing_file(self, tmp_path, monkeypatch):
+ # Create a temporary download directory and simulate an already-downloaded tar file.
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+ tar_filename = "dummy.tar"
+ file_path = os.path.join(str(download_dir), tar_filename)
+ # Write dummy content to simulate an existing download.
+ with open(file_path, "w") as f:
+ f.write("existing content")
+
+ downloader = ArxivDownloader(str(download_dir), verbose=False)
+ # Monkey-patch subprocess.run (should not be called since file exists).
+ monkeypatch.setattr(subprocess, "run", fake_run_success)
+ result = downloader.download(tar_filename)
+ assert result == file_path
+
+ def test_arxiv_downloader_new_file(self, tmp_path, monkeypatch):
+ # Create a temporary download directory and ensure the tar file does not exist.
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+ tar_filename = "dummy.tar"
+ file_path = os.path.join(str(download_dir), tar_filename)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ downloader = ArxivDownloader(str(download_dir), verbose=False)
+ called_run = False
+
+ def fake_run(cmd, stdout, stderr):
+ nonlocal called_run
+ called_run = True
+ return FakeCompletedProcess()
+
+ monkeypatch.setattr(subprocess, "run", fake_run)
+ result = downloader.download(tar_filename)
+ assert result == file_path
+ assert called_run
+
+ def test_arxiv_iterator(self, tmp_path):
+ # Create an inner tar archive containing a .tex file.
+ inner_tar_path = tmp_path / "2103.00001.tar"
+ dummy_tex_filename = "2103.00001.tex"
+ dummy_tex_content = "This is a dummy LaTeX content."
+ with tarfile.open(inner_tar_path, "w") as inner_tar:
+ # Create a temporary tex file to add into the inner tar archive.
+ temp_tex_path = tmp_path / dummy_tex_filename
+ with open(temp_tex_path, "w") as f:
+ f.write(dummy_tex_content)
+ inner_tar.add(temp_tex_path, arcname=dummy_tex_filename)
+
+ # Create an outer tar archive that contains the inner tar archive.
+ outer_tar_path = tmp_path / "dummy_main.tar"
+ with tarfile.open(outer_tar_path, "w") as outer_tar:
+ outer_tar.add(inner_tar_path, arcname="2103.00001.tar")
+
+ iterator = ArxivIterator(log_frequency=1)
+ results = list(iterator.iterate(str(outer_tar_path)))
+ # Expect one paper extracted.
+ assert len(results) == 1
+ metadata, tex_files = results[0]
+ # The ArxivIterator extracts the arxiv id from the inner archive's filename.
+ assert metadata["id"] == "2103.00001"
+ # The source_id is set to the outer tar file's basename.
+ assert metadata["source_id"] == "dummy_main.tar"
+ # Verify that the tex extraction returns the dummy content.
+ assert isinstance(tex_files, list)
+ assert dummy_tex_content in tex_files[0]
+
+ def test_arxiv_extractor(self):
+ extractor = ArxivExtractor()
+ # Create a minimal LaTeX document including comments and a section header.
+ content = r"""
+ % This is a comment line that should be removed.
+ \section{Introduction}
+ This is the introduction of the paper.
+ % Another comment that should vanish.
+ """
+ result = extractor.extract([content])
+ assert isinstance(result, dict)
+ extracted_text = result.get("text", "")
+ # Verify that comments have been removed.
+ assert "% This is a comment" not in extracted_text
+ # Verify that the section header content is retained.
+ assert "Introduction" in extracted_text
+ assert "This is the introduction" in extracted_text
+
+
+class TestCommonCrawl:
+ def test_common_crawl_downloader_existing_file(self, tmp_path, monkeypatch):
+ # Create a temporary downloads directory and simulate an already-downloaded file.
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+ url = "http://dummy/commoncrawl.warc"
+ parsed = urlparse(url)
+ output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc"
+ file_path = os.path.join(str(download_dir), output_name)
+ # Write dummy content to simulate an existing download.
+ with open(file_path, "w") as f:
+ f.write("existing content")
+
+ downloader = CommonCrawlWARCDownloader(
+ str(download_dir), aws=False, verbose=False
+ )
+
+ # Monkey-patch subprocess.run to track if it gets called.
+ called_run = False
+
+ def fake_run(cmd, stdout, stderr):
+ nonlocal called_run
+ called_run = True
+ return FakeCompletedProcess()
+
+ monkeypatch.setattr(subprocess, "run", fake_run)
+
+ result = downloader.download(url)
+ assert result == file_path
+ # Since the file already exists, no download should be attempted.
+ assert not called_run
+
+ def test_common_crawl_downloader_new_file(self, tmp_path, monkeypatch):
+ # Create a temporary downloads directory; ensure the file does not exist.
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir()
+ url = "http://dummy/commoncrawl.warc"
+ parsed = urlparse(url)
+ output_name = parsed.path[1:].replace("/", "-") # "commoncrawl.warc"
+ file_path = os.path.join(str(download_dir), output_name)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ downloader = CommonCrawlWARCDownloader(
+ str(download_dir), aws=False, verbose=False
+ )
+
+ called_run = False
+
+ def fake_run(cmd, stdout, stderr):
+ nonlocal called_run
+ called_run = True
+ return FakeCompletedProcess()
+
+ monkeypatch.setattr(subprocess, "run", fake_run)
+
+ result = downloader.download(url)
+ assert result == file_path
+ # Since the file did not exist, a download call (and subprocess.run) should have been made.
+ assert called_run
+
+ def test_common_crawl_iterator(self, tmp_path):
+ # Create a minimal valid WARC file with a single "response" record.
+ raw_warc_path = tmp_path / "dummy.warc"
+ http_response = (
+ "HTTP/1.1 200 OK\r\n"
+ "Content-Type: text/html\r\n"
+ "\r\n"
+ "
Common Crawl test paragraph with some content.
\r\n"
+ )
+ http_response_bytes = http_response.encode("utf-8")
+ content_length = len(http_response_bytes)
+ warc_record = (
+ (
+ f"WARC/1.0\r\n"
+ f"WARC-Type: response\r\n"
+ f"WARC-Record-ID: \r\n"
+ f"WARC-Date: 2022-01-01T00:00:00Z\r\n"
+ f"WARC-Target-URI: http://example.com\r\n"
+ f"Content-Length: {content_length}\r\n"
+ f"\r\n"
+ ).encode("utf-8")
+ + http_response_bytes
+ + b"\r\n\r\n"
+ )
+ raw_warc_path.write_bytes(warc_record)
+
+ iterator = CommonCrawlWARCIterator(log_frequency=1)
+ records = list(iterator.iterate(str(raw_warc_path)))
+ assert len(records) == 1
+ meta, content = records[0]
+ # Check that the URL from the header is captured.
+ assert "example.com" in meta["url"]
+ # Verify that the content includes our test paragraph.
+ assert b"Common Crawl test paragraph" in content
+
+ def test_common_crawl_extractor_justext(self):
+ extractor = CommonCrawlWARCExtractor(algorithm=JusTextExtractor())
+ html = (
+ "Common Crawl test paragraph for justext extractor. "
+ "Four score and seven years ago our fathers brought forth on this continent a new nation, "
+ "conceived in liberty, and dedicated to the proposition that all men are created equal.
"
+ )
+ content = html.encode("utf-8")
+ result = extractor.extract(content)
+ print(result)
+ assert result is not None
+ # The extracted text should include our test paragraph.
+ assert "Common Crawl test paragraph for justext extractor." in result["text"]
+ assert "language" in result
+
+ def test_common_crawl_extractor_resiliparse(self):
+ extractor = CommonCrawlWARCExtractor(algorithm=ResiliparseExtractor())
+ html = (
+ "Common Crawl test paragraph for resiliparse extractor. "
+ "Four score and seven years ago our fathers brought forth on this continent a new nation, "
+ "conceived in liberty, and dedicated to the proposition that all men are created equal.
"
+ )
+ content = html.encode("utf-8")
+ result = extractor.extract(content)
+ print(result)
+ assert result is not None
+ assert (
+ "Common Crawl test paragraph for resiliparse extractor." in result["text"]
+ )
+ assert "language" in result