From a7fde15e04f88e43c8dd94568cfcef060b76a11e Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:58:13 -0800 Subject: [PATCH] Add support for Nemotron-CC EDU classifiers (#518) * add fineweb mixtral classifier Signed-off-by: Sarah Yurick * add more files Signed-off-by: Sarah Yurick * run black Signed-off-by: Sarah Yurick * create _FineWebBaseClassifier Signed-off-by: Sarah Yurick * add more docs Signed-off-by: Sarah Yurick * add notebooks and tests Signed-off-by: Sarah Yurick * update classifier names Signed-off-by: Sarah Yurick * fix label logic Signed-off-by: Sarah Yurick * add Vibhu's suggestions Signed-off-by: Sarah Yurick * skip pytests Signed-off-by: Sarah Yurick --------- Signed-off-by: Sarah Yurick Signed-off-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> --- docs/user-guide/api/classifiers.rst | 6 + docs/user-guide/cpuvsgpu.rst | 1 + .../distributeddataclassification.rst | 90 +++++ examples/classifiers/README.md | 2 + .../fineweb_mixtral_edu_example.py | 64 ++++ .../fineweb_nemotron_edu_example.py | 64 ++++ nemo_curator/classifiers/__init__.py | 10 +- nemo_curator/classifiers/fineweb_edu.py | 199 +++++++++-- nemo_curator/scripts/classifiers/README.md | 40 +++ ...ineweb_mixtral_edu_classifier_inference.py | 113 ++++++ ...neweb_nemotron_edu_classifier_inference.py | 113 ++++++ pyproject.toml | 2 + tests/test_classifiers.py | 32 ++ .../distributed_data_classification/README.md | 2 + .../fineweb-mixtral-edu-classification.ipynb | 323 ++++++++++++++++++ .../fineweb-nemotron-edu-classification.ipynb | 323 ++++++++++++++++++ 16 files changed, 1358 insertions(+), 26 deletions(-) create mode 100644 examples/classifiers/fineweb_mixtral_edu_example.py create mode 100644 examples/classifiers/fineweb_nemotron_edu_example.py create mode 100644 nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py create mode 100644 
nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py create mode 100644 tutorials/distributed_data_classification/fineweb-mixtral-edu-classification.ipynb create mode 100644 tutorials/distributed_data_classification/fineweb-nemotron-edu-classification.ipynb diff --git a/docs/user-guide/api/classifiers.rst b/docs/user-guide/api/classifiers.rst index 8d5da2ea..1dad8e23 100644 --- a/docs/user-guide/api/classifiers.rst +++ b/docs/user-guide/api/classifiers.rst @@ -14,6 +14,12 @@ Classifiers .. autoclass:: nemo_curator.classifiers.FineWebEduClassifier :members: +.. autoclass:: nemo_curator.classifiers.FineWebMixtralEduClassifier + :members: + +.. autoclass:: nemo_curator.classifiers.FineWebNemotronEduClassifier + :members: + .. autoclass:: nemo_curator.classifiers.AegisClassifier :members: diff --git a/docs/user-guide/cpuvsgpu.rst b/docs/user-guide/cpuvsgpu.rst index bdc3e483..096ba28c 100644 --- a/docs/user-guide/cpuvsgpu.rst +++ b/docs/user-guide/cpuvsgpu.rst @@ -71,6 +71,7 @@ The following NeMo Curator modules are GPU based. * Quality Classification * AEGIS and Instruction Data Guard Safety Models * FineWeb Educational Content Classification + * FineWeb Mixtral and FineWeb Nemotron-4 Educational Models * Content Type Classification * Prompt Task and Complexity Classification diff --git a/docs/user-guide/distributeddataclassification.rst b/docs/user-guide/distributeddataclassification.rst index 389e8ef1..d8021de2 100644 --- a/docs/user-guide/distributeddataclassification.rst +++ b/docs/user-guide/distributeddataclassification.rst @@ -31,6 +31,10 @@ Here, we summarize why each is useful for training an LLM: - The **FineWeb Educational Content Classifier** focuses on identifying and prioritizing educational material within datasets. This classifier is especially useful for training LLMs on specialized educational content, which can improve their performance on knowledge-intensive tasks. 
Models trained on high-quality educational content demonstrate enhanced capabilities on academic benchmarks such as MMLU and ARC, showcasing the classifier's impact on improving the knowledge-intensive task performance of LLMs. +- The **FineWeb Mixtral Educational Classifier** is designed to determine the educational value of a document (score 0-5 from low to high). It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Mixtral 8x22B-Instruct. + +- The **FineWeb Nemotron-4 Educational Classifier** is designed to determine the educational value of a document (score 0-5 from low to high). It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Nemotron-4-340B-Instruct. + - The **Content Type Classifier** is designed to categorize documents into one of 11 distinct speech types based on their content. It analyzes and understands the nuances of textual information, enabling accurate classification across a diverse range of content types. - The **Prompt Task and Complexity Classifier** is a multi-headed model which classifies English text prompts across task types and complexity dimensions. @@ -236,6 +240,92 @@ For example, to create a dataset with only highly educational content (scores 4 high_edu_dataset = result_dataset[result_dataset["fineweb-edu-score-int"] >= 4] high_edu_dataset.to_json("high_educational_content/") +FineWeb Mixtral Edu Classifier +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The FineWeb Mixtral Edu Classifier is designed to identify and prioritize educational content within a dataset. +It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Mixtral 8x22B-Instruct. +In contrast, the original FineWeb-Edu classifier was trained using annotations from Llama 3 70B-Instruct. +This classifier was used as part of a classifier ensemble in the creation of the `Nemotron-CC dataset <https://data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/index.html>`_. 
+These datasets can be used to train LLMs with a focus on educational content, potentially improving their performance on knowledge-intensive tasks. + +To use the FineWeb Mixtral Edu Classifier, you can follow this example: + +.. code-block:: python + + from nemo_curator.classifiers import FineWebMixtralEduClassifier + + files = get_all_files_paths_under("web_documents/") + input_dataset = DocumentDataset.read_json(files, backend="cudf") + + classifier = FineWebMixtralEduClassifier( + batch_size=256, + text_field="text", + pred_column="fineweb-mixtral-edu-score", + int_column="fineweb-mixtral-edu-score-int", + quality_label_column="fineweb-mixtral-edu-score-label", + ) + result_dataset = classifier(dataset=input_dataset) + + result_dataset.to_json("educational_content/") + +This classifier uses a model based on the `Snowflake Arctic-embed-m <https://huggingface.co/Snowflake/snowflake-arctic-embed-m>`_ embedding model with a linear regression layer on top. +It assigns an educational score to each document on a scale from 0 to 5, where higher scores indicate more educational content. + +The ``pred_column`` will contain the raw floating-point scores, while the ``int_column`` will contain the rounded integer scores. +The ``quality_label_column`` identifies text as high quality if it scores 2.5 or higher and low quality otherwise. +You can filter the results based on these scores to create datasets with varying levels of educational content. + +For example, to create a dataset with only highly educational content (scores 4 and 5): + +.. code-block:: python + + high_edu_dataset = result_dataset[result_dataset["fineweb-mixtral-edu-score-int"] >= 4] + high_edu_dataset.to_json("high_educational_content/") + +FineWeb Nemotron-4 Edu Classifier +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The FineWeb Nemotron-4 Edu Classifier is designed to identify and prioritize educational content within a dataset. +It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Nemotron-4-340B-Instruct. 
+In contrast, the original FineWeb-Edu classifier was trained using annotations from Llama 3 70B-Instruct. +This classifier was used as part of a classifier ensemble in the creation of the `Nemotron-CC dataset <https://data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/index.html>`_. +These datasets can be used to train LLMs with a focus on educational content, potentially improving their performance on knowledge-intensive tasks. + +To use the FineWeb Nemotron-4 Edu Classifier, you can follow this example: + +.. code-block:: python + + from nemo_curator.classifiers import FineWebNemotronEduClassifier + + files = get_all_files_paths_under("web_documents/") + input_dataset = DocumentDataset.read_json(files, backend="cudf") + + classifier = FineWebNemotronEduClassifier( + batch_size=256, + text_field="text", + pred_column="fineweb-nemotron-edu-score", + int_column="fineweb-nemotron-edu-score-int", + quality_label_column="fineweb-nemotron-edu-score-label", + ) + result_dataset = classifier(dataset=input_dataset) + + result_dataset.to_json("educational_content/") + +This classifier uses a model based on the `Snowflake Arctic-embed-m <https://huggingface.co/Snowflake/snowflake-arctic-embed-m>`_ embedding model with a linear regression layer on top. +It assigns an educational score to each document on a scale from 0 to 5, where higher scores indicate more educational content. + +The ``pred_column`` will contain the raw floating-point scores, while the ``int_column`` will contain the rounded integer scores. +The ``quality_label_column`` identifies text as high quality if it scores 2.5 or higher and low quality otherwise. +You can filter the results based on these scores to create datasets with varying levels of educational content. + +For example, to create a dataset with only highly educational content (scores 4 and 5): + +.. 
code-block:: python + + high_edu_dataset = result_dataset[result_dataset["fineweb-nemotron-edu-score-int"] >= 4] + high_edu_dataset.to_json("high_educational_content/") + Content Type Classifier DeBERTa ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/classifiers/README.md b/examples/classifiers/README.md index fad2a691..c086491b 100644 --- a/examples/classifiers/README.md +++ b/examples/classifiers/README.md @@ -8,6 +8,8 @@ The Python scripts in this directory demonstrate how to run classification on yo - AEGIS Safety Models - Instruction Data Guard Model - FineWeb Educational Content Classifier +- FineWeb Mixtral Educational Classifier +- FineWeb Nemotron-4 Educational Classifier - Content Type Classifier - Prompt Task and Complexity Classifier diff --git a/examples/classifiers/fineweb_mixtral_edu_example.py b/examples/classifiers/fineweb_mixtral_edu_example.py new file mode 100644 index 00000000..c38b4eb3 --- /dev/null +++ b/examples/classifiers/fineweb_mixtral_edu_example.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import time + +from nemo_curator.classifiers import FineWebMixtralEduClassifier +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + global_st = time.time() + + # Input can be a string or list + input_file_path = "/path/to/data" + output_file_path = "./" + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + + input_dataset = DocumentDataset.read_json( + input_file_path, backend="cudf", add_filename=True + ) + + fineweb_mixtral_edu_classifier = FineWebMixtralEduClassifier() + result_dataset = fineweb_mixtral_edu_classifier(dataset=input_dataset) + result_dataset.to_json(output_path=output_file_path, write_to_filename=True) + + global_et = time.time() + print( + f"Total time taken for FineWeb Mixtral Edu Classifier inference: {global_et-global_st} s", + flush=True, + ) + + client.close() + + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): + argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_classifier_cluster_args() + + return argumentHelper.parser + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/examples/classifiers/fineweb_nemotron_edu_example.py b/examples/classifiers/fineweb_nemotron_edu_example.py new file mode 100644 index 00000000..3073ac30 --- /dev/null +++ b/examples/classifiers/fineweb_nemotron_edu_example.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time + +from nemo_curator.classifiers import FineWebNemotronEduClassifier +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + global_st = time.time() + + # Input can be a string or list + input_file_path = "/path/to/data" + output_file_path = "./" + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + + input_dataset = DocumentDataset.read_json( + input_file_path, backend="cudf", add_filename=True + ) + + fineweb_nemotron_edu_classifier = FineWebNemotronEduClassifier() + result_dataset = fineweb_nemotron_edu_classifier(dataset=input_dataset) + result_dataset.to_json(output_path=output_file_path, write_to_filename=True) + + global_et = time.time() + print( + f"Total time taken for FineWeb Nemotron-4 Edu Classifier inference: {global_et-global_st} s", + flush=True, + ) + + client.close() + + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): + argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_classifier_cluster_args() + + return argumentHelper.parser + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/nemo_curator/classifiers/__init__.py b/nemo_curator/classifiers/__init__.py index 16275e45..b01b5894 100644 --- a/nemo_curator/classifiers/__init__.py +++ b/nemo_curator/classifiers/__init__.py @@ -1,4 
+1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,11 @@ from .aegis import AegisClassifier, InstructionDataGuardClassifier from .content_type import ContentTypeClassifier from .domain import DomainClassifier, MultilingualDomainClassifier -from .fineweb_edu import FineWebEduClassifier +from .fineweb_edu import ( + FineWebEduClassifier, + FineWebMixtralEduClassifier, + FineWebNemotronEduClassifier, +) from .prompt_task_complexity import PromptTaskComplexityClassifier from .quality import QualityClassifier @@ -29,6 +33,8 @@ "AegisClassifier", "InstructionDataGuardClassifier", "FineWebEduClassifier", + "FineWebMixtralEduClassifier", + "FineWebNemotronEduClassifier", "ContentTypeClassifier", "PromptTaskComplexityClassifier", ] diff --git a/nemo_curator/classifiers/fineweb_edu.py b/nemo_curator/classifiers/fineweb_edu.py index 01799c7a..572c0d74 100644 --- a/nemo_curator/classifiers/fineweb_edu.py +++ b/nemo_curator/classifiers/fineweb_edu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ import torch from crossfit import op from crossfit.backend.torch.hf.model import HFModel -from transformers import AutoConfig, AutoModelForSequenceClassification +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer from nemo_curator.classifiers.base import ( DistributedDataClassifier, @@ -27,6 +27,8 @@ from nemo_curator.datasets import DocumentDataset FINEWEB_EDU_IDENTIFIER = "HuggingFaceFW/fineweb-edu-classifier" +FINEWEB_MIXTRAL_IDENTIFIER = "nvidia/nemocurator-fineweb-mixtral-edu-classifier" +FINEWEB_NEMOTRON_IDENTIFIER = "nvidia/nemocurator-fineweb-nemotron-4-edu-classifier" class FinewebEduModel(HFModel): @@ -63,48 +65,44 @@ def custom_forward(*args, **kwargs): model.forward = custom_forward return model + def load_tokenizer(self): + return AutoTokenizer.from_pretrained(self.path_or_name) + def load_config(self): return AutoConfig.from_pretrained(self.path_or_name) -class FineWebEduClassifier(DistributedDataClassifier): +class _FineWebBaseClassifier(DistributedDataClassifier): """ - FineWebEduClassifier is a specialized classifier designed for educational content assessment, - utilizing the Hugging Face FineWeb EDU Classifier model (https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier). - This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large text datasets. - - Attributes: - batch_size (int): The number of samples per batch for inference. Defaults to 256. - text_field (str): The column name containing the text data to be classified. Defaults to "text". - pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-edu-score". - int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-edu-score-int". - max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. 
Defaults to -1. - device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda". - autocast (bool): Whether to use mixed precision for faster inference. Defaults to True. - max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, - it defaults to the available GPU memory minus 4 GB. - + Parent class for FineWebEduClassifier, FineWebMixtralEduClassifier, and FineWebNemotronEduClassifier, + since their implementations are almost identical. """ def __init__( self, - batch_size: int = 256, + fineweb_identifier: str, + pred_column: str, + int_column: str, + quality_label_column: Optional[str], + batch_size: int = 1024, text_field: str = "text", - pred_column: str = "fineweb-edu-score", - int_column="fineweb-edu-score-int", max_chars: int = -1, device_type: str = "cuda", autocast: bool = True, max_mem_gb: Optional[int] = None, ): + self.fineweb_identifier = fineweb_identifier + model = FinewebEduModel( - path_or_name=FINEWEB_EDU_IDENTIFIER, + path_or_name=fineweb_identifier, autocast=autocast, max_mem_gb=max_mem_gb, ) self.text_field = text_field self.int_column = int_column + self.quality_label_column = quality_label_column + super().__init__( model=model, filter_by=None, # No filtering as its a numeric score @@ -118,14 +116,20 @@ def __init__( ) def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: - print("Starting Fineweb EDU classifier inference", flush=True) + if self.fineweb_identifier == FINEWEB_EDU_IDENTIFIER: + print("Starting FineWeb-Edu Classifier inference", flush=True) + elif self.fineweb_identifier == FINEWEB_MIXTRAL_IDENTIFIER: + print("Starting FineWeb Mixtral Edu Classifier inference", flush=True) + elif self.fineweb_identifier == FINEWEB_NEMOTRON_IDENTIFIER: + print("Starting FineWeb Nemotron-4 Edu Classifier inference", flush=True) + ddf = dataset.df pipe = op.Sequential( op.Tokenizer( self.model, cols=[self.text_field], - 
tokenizer_type="sentencepiece", + tokenizer_type="default", max_length=self.model.max_seq_length(), ), op.Predictor( @@ -137,6 +141,7 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: keep_cols=ddf.columns.tolist(), ) ddf = pipe(ddf) + ddf[self.pred_column] = ddf[self.pred_column].where( ddf[self.pred_column] >= 0, 0 ) @@ -144,4 +149,150 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: ddf[self.pred_column] <= 5, 5 ) ddf[self.int_column] = ddf[self.pred_column].round().astype(int) + + if self.quality_label_column is not None: + ddf[self.quality_label_column] = "high_quality" + # If the score is less than 2.5, label it as low quality + ddf[self.quality_label_column] = ddf[self.quality_label_column].mask( + ddf[self.pred_column] < 2.5, "low_quality" + ) + return DocumentDataset(ddf) + + +class FineWebEduClassifier(_FineWebBaseClassifier): + """ + FineWebEduClassifier is a specialized classifier designed for educational content assessment, + utilizing the Hugging Face FineWeb EDU Classifier model (https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier). + This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large text datasets. + + Attributes: + batch_size (int): The number of samples per batch for inference. Defaults to 256. + text_field (str): The column name containing the text data to be classified. Defaults to "text". + pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-edu-score". + int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-edu-score-int". + max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. Defaults to -1. + device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda". 
+ autocast (bool): Whether to use mixed precision for faster inference. Defaults to True. + max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, + it defaults to the available GPU memory minus 4 GB. + + """ + + def __init__( + self, + batch_size: int = 256, + text_field: str = "text", + pred_column: str = "fineweb-edu-score", + int_column="fineweb-edu-score-int", + max_chars: int = -1, + device_type: str = "cuda", + autocast: bool = True, + max_mem_gb: Optional[int] = None, + ): + super().__init__( + fineweb_identifier=FINEWEB_EDU_IDENTIFIER, + batch_size=batch_size, + text_field=text_field, + pred_column=pred_column, + int_column=int_column, + quality_label_column=None, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + max_mem_gb=max_mem_gb, + ) + + +class FineWebMixtralEduClassifier(_FineWebBaseClassifier): + """ + FineWebMixtralEduClassifier is a specialized classifier designed for educational content assessment, + utilizing the NemoCurator FineWeb Mixtral Edu Classifier model (https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier). + It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Mixtral 8x22B-Instruct. + This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large text datasets. + + Attributes: + batch_size (int): The number of samples per batch for inference. Defaults to 1024. + text_field (str): The column name containing the text data to be classified. Defaults to "text". + pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-mixtral-edu-score". + int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-mixtral-edu-score-int". 
+ quality_label_column (str): The column name where a score of >= 2.5 is labeled "high_quality" and otherwise labeled "low_quality". Defaults to "fineweb-mixtral-edu-score-label". + max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. Defaults to -1. + device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda". + autocast (bool): Whether to use mixed precision for faster inference. Defaults to True. + max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, + it defaults to the available GPU memory minus 4 GB. + + """ + + def __init__( + self, + batch_size: int = 1024, + text_field: str = "text", + pred_column: str = "fineweb-mixtral-edu-score", + int_column: str = "fineweb-mixtral-edu-score-int", + quality_label_column: str = "fineweb-mixtral-edu-score-label", + max_chars: int = -1, + device_type: str = "cuda", + autocast: bool = True, + max_mem_gb: Optional[int] = None, + ): + super().__init__( + fineweb_identifier=FINEWEB_MIXTRAL_IDENTIFIER, + batch_size=batch_size, + text_field=text_field, + pred_column=pred_column, + int_column=int_column, + quality_label_column=quality_label_column, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + max_mem_gb=max_mem_gb, + ) + + +class FineWebNemotronEduClassifier(_FineWebBaseClassifier): + """ + FineWebNemotronEduClassifier is a specialized classifier designed for educational content assessment, + utilizing the NemoCurator FineWeb Nemotron-4 Edu Classifier model (https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier). + It is similar to the FineWeb-Edu classifier and was trained on the same text samples, but using annotations from Nemotron-4-340B-Instruct. + This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large text datasets. 
+ + Attributes: + batch_size (int): The number of samples per batch for inference. Defaults to 1024. + text_field (str): The column name containing the text data to be classified. Defaults to "text". + pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-nemotron-edu-score". + int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-nemotron-edu-score-int". + quality_label_column (str): The column name where a score of >= 2.5 is labeled "high_quality" and otherwise labeled "low_quality". Defaults to "fineweb-nemotron-edu-score-label". + max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. Defaults to -1. + device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda". + autocast (bool): Whether to use mixed precision for faster inference. Defaults to True. + max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, + it defaults to the available GPU memory minus 4 GB. 
+ + """ + + def __init__( + self, + batch_size: int = 1024, + text_field: str = "text", + pred_column: str = "fineweb-nemotron-edu-score", + int_column: str = "fineweb-nemotron-edu-score-int", + quality_label_column: str = "fineweb-nemotron-edu-score-label", + max_chars: int = -1, + device_type: str = "cuda", + autocast: bool = True, + max_mem_gb: Optional[int] = None, + ): + super().__init__( + fineweb_identifier=FINEWEB_NEMOTRON_IDENTIFIER, + batch_size=batch_size, + text_field=text_field, + pred_column=pred_column, + int_column=int_column, + quality_label_column=quality_label_column, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + max_mem_gb=max_mem_gb, + ) diff --git a/nemo_curator/scripts/classifiers/README.md b/nemo_curator/scripts/classifiers/README.md index 6ca5cdef..f46efc78 100644 --- a/nemo_curator/scripts/classifiers/README.md +++ b/nemo_curator/scripts/classifiers/README.md @@ -8,6 +8,8 @@ The Python scripts in this directory demonstrate how to run classification on yo - AEGIS Safety Models - Instruction Data Guard Model - FineWeb Educational Content Classifier +- FineWeb Mixtral Educational Classifier +- FineWeb Nemotron-4 Educational Classifier - Content Type Classifier - Prompt Task and Complexity Classifier @@ -139,6 +141,44 @@ fineweb_edu_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `fineweb_edu_classifier_inference --help` for more information. 
+#### FineWeb Mixtral Edu Classifier Inference + +```bash +# same as `python fineweb_mixtral_edu_classifier_inference.py` +fineweb_mixtral_edu_classifier_inference \ + --input-data-dir /path/to/data/directory \ + --output-data-dir /path/to/output/directory \ + --input-file-type "jsonl" \ + --input-file-extension "jsonl" \ + --output-file-type "jsonl" \ + --input-text-field "text" \ + --batch-size 64 \ + --autocast \ + --max-chars 2000 \ + --device "gpu" +``` + +Additional arguments may be added for customizing a Dask cluster and client. Run `fineweb_mixtral_edu_classifier_inference --help` for more information. + +#### FineWeb Nemotron-4 Edu Classifier Inference + +```bash +# same as `python fineweb_nemotron_edu_classifier_inference.py` +fineweb_nemotron_edu_classifier_inference \ + --input-data-dir /path/to/data/directory \ + --output-data-dir /path/to/output/directory \ + --input-file-type "jsonl" \ + --input-file-extension "jsonl" \ + --output-file-type "jsonl" \ + --input-text-field "text" \ + --batch-size 64 \ + --autocast \ + --max-chars 2000 \ + --device "gpu" +``` + +Additional arguments may be added for customizing a Dask cluster and client. Run `fineweb_nemotron_edu_classifier_inference --help` for more information. + #### Content Type Classifier DeBERTa Inference ```bash diff --git a/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py b/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py new file mode 100644 index 00000000..582ec4c5 --- /dev/null +++ b/nemo_curator/scripts/classifiers/fineweb_mixtral_edu_classifier_inference.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import warnings + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + +from nemo_curator.classifiers import FineWebMixtralEduClassifier +from nemo_curator.datasets import DocumentDataset + +# Get relevant args +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import get_remaining_files +from nemo_curator.utils.script_utils import ArgumentHelper + +warnings.filterwarnings("ignore") + + +def main(): + args = ArgumentHelper.parse_distributed_classifier_args( + description="Run FineWeb Mixtral Edu Classifier inference." 
+ ).parse_args() + print(f"Arguments parsed = {args}", flush=True) + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + print("Starting FineWeb Mixtral Edu Classifier inference", flush=True) + global_st = time.time() + files_per_run = len(client.scheduler_info()["workers"]) * 2 + + if not os.path.exists(args.output_data_dir): + os.makedirs(args.output_data_dir) + + # Some times jsonl files are stored as .json + # So to handle that case we can pass the input_file_extension + if args.input_file_extension is not None: + input_file_extension = args.input_file_extension + else: + input_file_extension = args.input_file_type + + input_files = get_remaining_files( + args.input_data_dir, args.output_data_dir, input_file_extension + ) + print(f"Total input files {len(input_files)}", flush=True) + + if args.input_file_type == "pickle": + add_filename = False + else: + add_filename = True + + fineweb_mixtral_edu_classifier = FineWebMixtralEduClassifier( + text_field=args.input_text_field, + batch_size=args.batch_size, + autocast=args.autocast, + max_chars=args.max_chars, + max_mem_gb=args.max_mem_gb_classifier, + ) + + for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)): + batch_st = time.time() + current_batch_files = input_files[i : i + files_per_run] + print( + f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}", + flush=True, + ) + df = read_data( + input_files=current_batch_files, + file_type=args.input_file_type, + add_filename=add_filename, + ) + df = fineweb_mixtral_edu_classifier(DocumentDataset(df)).df + print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True) + + write_to_disk( + df=df, + output_path=args.output_data_dir, + write_to_filename=add_filename, + output_type=args.output_file_type, + ) + batch_et = time.time() + print( + f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds", + 
flush=True, + ) + + global_et = time.time() + print( + f"Total time taken for FineWeb Mixtral Edu Classifier inference: {global_et-global_st} s", + flush=True, + ) + client.close() + + +def console_script(): + main() + + +if __name__ == "__main__": + console_script() diff --git a/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py b/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py new file mode 100644 index 00000000..112453a2 --- /dev/null +++ b/nemo_curator/scripts/classifiers/fineweb_nemotron_edu_classifier_inference.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import warnings + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + +from nemo_curator.classifiers import FineWebNemotronEduClassifier +from nemo_curator.datasets import DocumentDataset + +# Get relevant args +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import get_remaining_files +from nemo_curator.utils.script_utils import ArgumentHelper + +warnings.filterwarnings("ignore") + + +def main(): + args = ArgumentHelper.parse_distributed_classifier_args( + description="Run FineWeb Nemotron-4 Edu Classifier inference." 
+ ).parse_args() + print(f"Arguments parsed = {args}", flush=True) + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + print("Starting FineWeb Nemotron-4 Edu Classifier inference", flush=True) + global_st = time.time() + files_per_run = len(client.scheduler_info()["workers"]) * 2 + + if not os.path.exists(args.output_data_dir): + os.makedirs(args.output_data_dir) + + # Some times jsonl files are stored as .json + # So to handle that case we can pass the input_file_extension + if args.input_file_extension is not None: + input_file_extension = args.input_file_extension + else: + input_file_extension = args.input_file_type + + input_files = get_remaining_files( + args.input_data_dir, args.output_data_dir, input_file_extension + ) + print(f"Total input files {len(input_files)}", flush=True) + + if args.input_file_type == "pickle": + add_filename = False + else: + add_filename = True + + fineweb_nemotron_edu_classifier = FineWebNemotronEduClassifier( + text_field=args.input_text_field, + batch_size=args.batch_size, + autocast=args.autocast, + max_chars=args.max_chars, + max_mem_gb=args.max_mem_gb_classifier, + ) + + for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)): + batch_st = time.time() + current_batch_files = input_files[i : i + files_per_run] + print( + f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}", + flush=True, + ) + df = read_data( + input_files=current_batch_files, + file_type=args.input_file_type, + add_filename=add_filename, + ) + df = fineweb_nemotron_edu_classifier(DocumentDataset(df)).df + print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True) + + write_to_disk( + df=df, + output_path=args.output_data_dir, + write_to_filename=add_filename, + output_type=args.output_file_type, + ) + batch_et = time.time() + print( + f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds", + 
flush=True, + ) + + global_et = time.time() + print( + f"Total time taken for FineWeb Nemotron-4 Edu Classifier inference: {global_et-global_st} s", + flush=True, + ) + client.close() + + +def console_script(): + main() + + +if __name__ == "__main__": + console_script() diff --git a/pyproject.toml b/pyproject.toml index fbbee8d1..56e0fd9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,8 @@ instruction_data_guard_classifier_inference = "nemo_curator.scripts.classifiers. multilingual_domain_classifier_inference = "nemo_curator.scripts.classifiers.multilingual_domain_classifier_inference:console_script" content_type_classifier_inference = "nemo_curator.scripts.classifiers.content_type_classifier_inference:console_script" prompt_task_complexity_classifier_inference = "nemo_curator.scripts.classifiers.prompt_task_complexity_classifier_inference:console_script" +fineweb_mixtral_edu_classifier_inference = "nemo_curator.scripts.classifiers.fineweb_mixtral_edu_classifier_inference:console_script" +fineweb_nemotron_edu_classifier_inference = "nemo_curator.scripts.classifiers.fineweb_nemotron_edu_classifier_inference:console_script" verify_classification_results = "nemo_curator.scripts.verify_classification_results:console_script" blend_datasets = "nemo_curator.scripts.blend_datasets:console_script" semdedup_extract_embeddings = "nemo_curator.scripts.semdedup.compute_embeddings:console_script" diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py index 388c1ebd..d6d2852c 100644 --- a/tests/test_classifiers.py +++ b/tests/test_classifiers.py @@ -139,6 +139,38 @@ def test_fineweb_edu_classifier(gpu_client, domain_dataset): assert result_pred.equals(expected_pred) +@pytest.mark.skip( + reason="Skipping until https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier is published" +) +@pytest.mark.gpu +def test_fineweb_mixtral_classifier(gpu_client, domain_dataset): + from nemo_curator.classifiers import FineWebMixtralEduClassifier + + 
classifier = FineWebMixtralEduClassifier() + result_dataset = classifier(dataset=domain_dataset) + result_pred = result_dataset.df.compute()["fineweb-mixtral-edu-score-int"] + + expected_pred = cudf.Series([1, 1, 1, 2, 0]) + + assert result_pred.equals(expected_pred) + + +@pytest.mark.skip( + reason="Skipping until https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier is published" +) +@pytest.mark.gpu +def test_fineweb_nemotron_classifier(gpu_client, domain_dataset): + from nemo_curator.classifiers import FineWebNemotronEduClassifier + + classifier = FineWebNemotronEduClassifier() + result_dataset = classifier(dataset=domain_dataset) + result_pred = result_dataset.df.compute()["fineweb-nemotron-edu-score-int"] + + expected_pred = cudf.Series([1, 1, 1, 2, 0]) + + assert result_pred.equals(expected_pred) + + @pytest.mark.skip( reason="Instruction Data Guard needs to be downloaded and cached to our gpuCI runner to enable this" ) diff --git a/tutorials/distributed_data_classification/README.md b/tutorials/distributed_data_classification/README.md index f953d8f5..e5e1f9b1 100644 --- a/tutorials/distributed_data_classification/README.md +++ b/tutorials/distributed_data_classification/README.md @@ -18,6 +18,8 @@ Before running any of these notebooks, please see this [Getting Started](https:/ | `ContentTypeClassifier` | [nvidia/content-type-classifier-deberta](https://huggingface.co/nvidia/content-type-classifier-deberta) | | `DomainClassifier` | [nvidia/domain-classifier](https://huggingface.co/nvidia/domain-classifier) | | `FineWebEduClassifier` | [HuggingFaceFW/fineweb-edu-classifier](https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier) | +| `FineWebMixtralEduClassifier` | [nvidia/nemocurator-fineweb-mixtral-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier) | +| `FineWebNemotronEduClassifier` | 
[nvidia/nemocurator-fineweb-nemotron-4-edu-classifier](https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier) | | `InstructionDataGuardClassifier` | [nvidia/instruction-data-guard](https://huggingface.co/nvidia/instruction-data-guard) | | `MultilingualDomainClassifier` | [nvidia/multilingual-domain-classifier](https://huggingface.co/nvidia/multilingual-domain-classifier) | | `PromptTaskComplexityClassifier` | [nvidia/prompt-task-and-complexity-classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) | diff --git a/tutorials/distributed_data_classification/fineweb-mixtral-edu-classification.ipynb b/tutorials/distributed_data_classification/fineweb-mixtral-edu-classification.ipynb new file mode 100644 index 00000000..ed15455a --- /dev/null +++ b/tutorials/distributed_data_classification/fineweb-mixtral-edu-classification.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed Data Classification with NeMo Curator's `FineWebMixtralEduClassifier`\n", + "\n", + "This notebook demonstrates the use of NeMo Curator's `FineWebMixtralEduClassifier`. The [FineWeb Mixtral Edu classifier](https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier) is used to determine the educational value (score 0-5 from low to high) of a text. It helps with data annotation, which is useful in data blending for foundation model training. 
Please refer to the Hugging Face page for more information about the NemoCurator FineWeb Mixtral Edu Classifier, including its output labels, here: https://huggingface.co/nvidia/nemocurator-fineweb-mixtral-edu-classifier.\n", + "\n", + "The FineWeb Mixtral Edu classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intelligent batching and RAPIDS to accelerate the offline inference on large datasets.\n", + "\n", + "Before running this notebook, please see this [Getting Started](https://github.com/NVIDIA/NeMo-Curator?tab=readme-ov-file#get-started) page for instructions on how to install NeMo Curator." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTHONWARNINGS=ignore\n" + ] + } + ], + "source": [ + "# Silence Warnings (HuggingFace internal warnings)\n", + "\n", + "%env PYTHONWARNINGS=ignore\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator import get_client\n", + "from nemo_curator.classifiers import FineWebMixtralEduClassifier\n", + "from nemo_curator.datasets import DocumentDataset\n", + "import cudf\n", + "import dask_cudf" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDF Spilling is enabled\n" + ] + } + ], + "source": [ + "client = get_client(cluster_type=\"gpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Set Output File Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_file_path = \"output_data_dir/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare Text Data and Initialize Classifier" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Create sample DataFrame\n", + "text = [\n", + " \"Quantum computing is set to revolutionize the field of cryptography.\",\n", + " \"Investing in index funds is a popular strategy for long-term financial growth.\",\n", + " \"Recent advancements in gene therapy offer new hope for treating genetic disorders.\",\n", + " \"Online learning platforms have transformed the way students access educational resources.\",\n", + " \"Traveling to Europe during the off-season can be a more budget-friendly option.\",\n", + " \"Training regimens for athletes have become more sophisticated with the use of data analytics.\",\n", + " \"Streaming services are changing the way people consume television and film content.\",\n", + " \"Vegan recipes have gained popularity as more people adopt plant-based diets.\",\n", + " \"Climate change research is critical for developing sustainable environmental policies.\",\n", + " \"Telemedicine has become increasingly popular due to its convenience and accessibility.\",\n", + "]\n", + "df = cudf.DataFrame({\"text\": text})\n", + "input_dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))\n", + "write_to_filename = False\n", + "\n", + "# Alternatively, read existing directory of JSONL files\n", + "# input_file_path=\"/input_data_dir/\"\n", + "# input_dataset = DocumentDataset.read_json(\n", + "# input_file_path, backend=\"cudf\", add_filename=True\n", + "# )\n", + "# write_to_filename = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "classifier = FineWebMixtralEduClassifier(batch_size=1024)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run the Classifier\n", + "\n", + "Dask operations are lazy, so the classifier will not run until we call an eager operation like `to_json`, `compute`, or `persist`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting FineWeb Mixtral Edu Classifier inference\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU: tcp://127.0.0.1:34037, Part: 0: 100%|██████████| 10/10 [00:02<00:00, 4.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing to disk complete for 1 partition(s)\n", + "CPU times: user 1.16 s, sys: 1.17 s, total: 2.33 s\n", + "Wall time: 15.3 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU: tcp://127.0.0.1:34037, Part: 0: 100%|██████████| 10/10 [00:02<00:00, 3.54it/s]\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "result_dataset = classifier(dataset=input_dataset)\n", + "result_dataset.to_json(output_path=output_file_path, write_to_filename=write_to_filename)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inspect the Output" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading 1 files with blocksize='1gb' / files_per_partition=None\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fineweb-mixtral-edu-scorefineweb-mixtral-edu-score-intfineweb-mixtral-edu-score-labeltext
01.3525391low_qualityQuantum computing is set to revolutionize the ...
10.8291021low_qualityInvesting in index funds is a popular strategy...
21.4228521low_qualityRecent advancements in gene therapy offer new ...
31.5791022low_qualityOnline learning platforms have transformed the...
40.3461910low_qualityTraveling to Europe during the off-season can ...
\n", + "
" + ], + "text/plain": [ + " fineweb-mixtral-edu-score fineweb-mixtral-edu-score-int \\\n", + "0 1.352539 1 \n", + "1 0.829102 1 \n", + "2 1.422852 1 \n", + "3 1.579102 2 \n", + "4 0.346191 0 \n", + "\n", + " fineweb-mixtral-edu-score-label \\\n", + "0 low_quality \n", + "1 low_quality \n", + "2 low_quality \n", + "3 low_quality \n", + "4 low_quality \n", + "\n", + " text \n", + "0 Quantum computing is set to revolutionize the ... \n", + "1 Investing in index funds is a popular strategy... \n", + "2 Recent advancements in gene therapy offer new ... \n", + "3 Online learning platforms have transformed the... \n", + "4 Traveling to Europe during the off-season can ... " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_dataset = DocumentDataset.read_json(output_file_path, backend=\"cudf\", add_filename=write_to_filename)\n", + "output_dataset.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo_curator", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/distributed_data_classification/fineweb-nemotron-edu-classification.ipynb b/tutorials/distributed_data_classification/fineweb-nemotron-edu-classification.ipynb new file mode 100644 index 00000000..4c160da5 --- /dev/null +++ b/tutorials/distributed_data_classification/fineweb-nemotron-edu-classification.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed Data Classification with NeMo Curator's `FineWebNemotronEduClassifier`\n", + "\n", + "This notebook demonstrates the use of NeMo Curator's `FineWebNemotronEduClassifier`. 
The [FineWeb Nemotron-4 Edu classifier](https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier) is used to determine the educational value (score 0-5 from low to high) of a text. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the NemoCurator FineWeb Nemotron-4 Edu Classifier, including its output labels, here: https://huggingface.co/nvidia/nemocurator-fineweb-nemotron-4-edu-classifier.\n", + "\n", + "The FineWeb Nemotron-4 Edu classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intelligent batching and RAPIDS to accelerate the offline inference on large datasets.\n", + "\n", + "Before running this notebook, please see this [Getting Started](https://github.com/NVIDIA/NeMo-Curator?tab=readme-ov-file#get-started) page for instructions on how to install NeMo Curator." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTHONWARNINGS=ignore\n" + ] + } + ], + "source": [ + "# Silence Warnings (HuggingFace internal warnings)\n", + "\n", + "%env PYTHONWARNINGS=ignore\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo_curator import get_client\n", + "from nemo_curator.classifiers import FineWebNemotronEduClassifier\n", + "from nemo_curator.datasets import DocumentDataset\n", + "import cudf\n", + "import dask_cudf" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDF Spilling is enabled\n" + ] + } + ], + "source": [ + "client = get_client(cluster_type=\"gpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ 
+ "# Set Output File Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_file_path = \"output_data_dir/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare Text Data and Initialize Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Create sample DataFrame\n", + "text = [\n", + " \"Quantum computing is set to revolutionize the field of cryptography.\",\n", + " \"Investing in index funds is a popular strategy for long-term financial growth.\",\n", + " \"Recent advancements in gene therapy offer new hope for treating genetic disorders.\",\n", + " \"Online learning platforms have transformed the way students access educational resources.\",\n", + " \"Traveling to Europe during the off-season can be a more budget-friendly option.\",\n", + " \"Training regimens for athletes have become more sophisticated with the use of data analytics.\",\n", + " \"Streaming services are changing the way people consume television and film content.\",\n", + " \"Vegan recipes have gained popularity as more people adopt plant-based diets.\",\n", + " \"Climate change research is critical for developing sustainable environmental policies.\",\n", + " \"Telemedicine has become increasingly popular due to its convenience and accessibility.\",\n", + "]\n", + "df = cudf.DataFrame({\"text\": text})\n", + "input_dataset = DocumentDataset(dask_cudf.from_cudf(df, npartitions=1))\n", + "write_to_filename = False\n", + "\n", + "# Alternatively, read existing directory of JSONL files\n", + "# input_file_path=\"/input_data_dir/\"\n", + "# input_dataset = DocumentDataset.read_json(\n", + "# input_file_path, backend=\"cudf\", add_filename=True\n", + "# )\n", + "# write_to_filename = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "classifier = 
FineWebNemotronEduClassifier(batch_size=1024)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run the Classifier\n", + "\n", + "Dask operations are lazy, so the classifier will not run until we call an eager operation like `to_json`, `compute`, or `persist`. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting FineWeb Nemotron-4 Edu Classifier inference\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU: tcp://127.0.0.1:33569, Part: 0: 100%|██████████| 10/10 [00:01<00:00, 5.04it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing to disk complete for 1 partition(s)\n", + "CPU times: user 1.35 s, sys: 172 ms, total: 1.52 s\n", + "Wall time: 14.8 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU: tcp://127.0.0.1:33569, Part: 0: 100%|██████████| 10/10 [00:02<00:00, 3.73it/s]\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "result_dataset = classifier(dataset=input_dataset)\n", + "result_dataset.to_json(output_path=output_file_path, write_to_filename=write_to_filename)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inspect the Output" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading 1 files with blocksize='1gb' / files_per_partition=None\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
fineweb-nemotron-edu-scorefineweb-nemotron-edu-score-intfineweb-nemotron-edu-score-labeltext
01.3925781low_qualityQuantum computing is set to revolutionize the ...
10.8896481low_qualityInvesting in index funds is a popular strategy...
21.3437501low_qualityRecent advancements in gene therapy offer new ...
31.7314452low_qualityOnline learning platforms have transformed the...
40.2485350low_qualityTraveling to Europe during the off-season can ...
\n", + "
" + ], + "text/plain": [ + " fineweb-nemotron-edu-score fineweb-nemotron-edu-score-int \\\n", + "0 1.392578 1 \n", + "1 0.889648 1 \n", + "2 1.343750 1 \n", + "3 1.731445 2 \n", + "4 0.248535 0 \n", + "\n", + " fineweb-nemotron-edu-score-label \\\n", + "0 low_quality \n", + "1 low_quality \n", + "2 low_quality \n", + "3 low_quality \n", + "4 low_quality \n", + "\n", + " text \n", + "0 Quantum computing is set to revolutionize the ... \n", + "1 Investing in index funds is a popular strategy... \n", + "2 Recent advancements in gene therapy offer new ... \n", + "3 Online learning platforms have transformed the... \n", + "4 Traveling to Europe during the off-season can ... " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_dataset = DocumentDataset.read_json(output_file_path, backend=\"cudf\", add_filename=write_to_filename)\n", + "output_dataset.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo_curator", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}