Add fineweb classifier #168

Merged (12 commits) on Aug 21, 2024
64 changes: 64 additions & 0 deletions examples/classifiers/fineweb_edu_example.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

from nemo_curator.classifiers import FineWebEduClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper


def main(args):
global_st = time.time()

# Input can be a string or list
input_file_path = "/path/to/data"
output_file_path = "./"

client_args = ArgumentHelper.parse_client_args(args)
client_args["cluster_type"] = "gpu"
client = get_client(**client_args)

input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
)

fineweb_classifier = FineWebEduClassifier()
result_dataset = fineweb_classifier(dataset=input_dataset)
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

global_et = time.time()
print(
f"Total time taken for fineweb classifier inference: {global_et-global_st} s",
flush=True,
)

client.close()


def attach_args(
parser=argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_classifier_cluster_args()

return argumentHelper.parser.parse_args()


if __name__ == "__main__":
main(attach_args().parse_args())
Comment on lines +60 to +64 (Collaborator):
Looks like there is an unnecessary parse_args() in attach_args() and main(), which is also in our other classifier examples. I opened #217 to resolve this.
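A minimal sketch of the deduplicated pattern (hypothetical; #217 may resolve it differently) is to have attach_args() return the parser and parse exactly once at the entry point:

def attach_args(
    parser=argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    ),
):
    # Attach the shared cluster arguments, but let the caller decide when to parse.
    ArgumentHelper(parser).add_distributed_classifier_cluster_args()
    return parser


if __name__ == "__main__":
    main(attach_args().parse_args())  # single parse_args() call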

8 changes: 7 additions & 1 deletion nemo_curator/classifiers/__init__.py
@@ -17,6 +17,12 @@
 os.environ["RAPIDS_NO_INITIALIZE"] = "1"
 from .aegis import AegisClassifier
 from .domain import DomainClassifier
+from .fineweb_edu import FineWebEduClassifier
 from .quality import QualityClassifier
 
-__all__ = ["DomainClassifier", "QualityClassifier", "AegisClassifier"]
+__all__ = [
+    "DomainClassifier",
+    "QualityClassifier",
+    "AegisClassifier",
+    "FineWebEduClassifier",
+]
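With the export added above, the new classifier is importable from the package root, which is how the example script pulls it in:

from nemo_curator.classifiers import FineWebEduClassifier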
18 changes: 14 additions & 4 deletions nemo_curator/classifiers/aegis.py
@@ -27,7 +27,10 @@
 from peft import PeftModel
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-from nemo_curator.classifiers.base import DistributedDataClassifier
+from nemo_curator.classifiers.base import (
+    DistributedDataClassifier,
+    _get_suggest_memory_for_classifier,
+)
 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.utils.aegis_utils import format_aegis
 
@@ -92,11 +95,14 @@ def forward(self, batch):


 class AegisHFModel(HFModel):
-    def __init__(self, config: AegisConfig):
+    def __init__(self, config: AegisConfig, max_mem_gb=None):
         self.config = config
+        if max_mem_gb is None:
+            max_mem_gb = _get_suggest_memory_for_classifier()
+
         super().__init__(
             config.pretrained_model_name_or_path,
-            max_mem_gb=48,
+            max_mem_gb=max_mem_gb,
             start_batch_size=4,
             end_batch_size=32,
             batch_size_increment=4,
@@ -166,6 +172,7 @@ def __init__(
         keep_raw_pred: bool = False,
         max_chars: int = 6000,
         device_type: str = "cuda",
+        max_mem_gb: int = None,
     ):
         """
         Constructs the classifier
@@ -189,6 +196,9 @@
             max_chars (int): If the document is larger than max_chars, the classifier will only classify
                 the first max_chars.
             device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
+            max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+                it defaults to the available GPU memory minus 4 GB.
+
         """
         config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token)
 
@@ -199,7 +209,7 @@ def __init__(
         self.keep_raw_pred = keep_raw_pred
 
         try:
-            model = AegisHFModel(config=config)
+            model = AegisHFModel(config=config, max_mem_gb=max_mem_gb)
         except OSError as e:
             if "meta-llama/LlamaGuard-7b" in str(e):
                 raise PermissionError(ACCESS_ERROR_MESSAGE)
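A hedged usage sketch of the new parameter (the variant string and token are illustrative placeholders, not taken from this diff):

from nemo_curator.classifiers import AegisClassifier

# Cap the model at 40 GB per GPU; leaving max_mem_gb=None instead falls
# back to the smallest GPU's memory minus 4 GB (see base.py below).
safety_classifier = AegisClassifier(
    aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
    token="hf_...",  # Hugging Face token with access to meta-llama/LlamaGuard-7b
    max_mem_gb=40,
)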
13 changes: 13 additions & 0 deletions nemo_curator/classifiers/base.py
@@ -25,6 +25,7 @@
 from transformers import AutoModel
 
 from nemo_curator.datasets import DocumentDataset
+from nemo_curator.utils.distributed_utils import get_gpu_memory_info
 
 
 class DistributedDataClassifier(ABC):
@@ -158,3 +159,15 @@ def _run_classifier_helper(
         df = df.drop(columns=[prob_internal_col])
 
     return df
+
+
+def _get_suggest_memory_for_classifier() -> int:
+    gpu_memory_info = get_gpu_memory_info()
+    min_gpu_memory = min(gpu_memory_info.values())
+    # Convert memory from bytes to GB
+    min_gpu_memory_gb = min_gpu_memory / (1024**3)
+    # Subtract 4GB from the minimum
+    # to leave room for other operations
+    # like cuDF operations
+    min_gpu_memory_gb = min_gpu_memory_gb - 4
+    return int(min_gpu_memory_gb)
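As a worked example of the helper above (the GPU sizes are hypothetical):

# Suppose get_gpu_memory_info() reports these totals, in bytes, per GPU:
gpu_memory_info = {"gpu-0": 85_899_345_920, "gpu-1": 51_539_607_552}  # 80 GB, 48 GB
min_gpu_memory_gb = min(gpu_memory_info.values()) / (1024**3)  # 48.0
print(int(min_gpu_memory_gb - 4))  # -> 44, the suggested max_mem_gb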
34 changes: 31 additions & 3 deletions nemo_curator/classifiers/domain.py
@@ -21,6 +21,7 @@
 from nemo_curator.classifiers.base import (
     DistributedDataClassifier,
     HFDeberta,
+    _get_suggest_memory_for_classifier,
     _run_classifier_helper,
 )
 from nemo_curator.datasets import DocumentDataset
@@ -36,10 +37,15 @@ class DomainModelConfig:


 class DomainModel(HFModel):
-    def __init__(self, config: DomainModelConfig, autocast: bool = False):
+    def __init__(
+        self, config: DomainModelConfig, autocast: bool = False, max_mem_gb=None
+    ):
         self.config = config
         self.autocast = autocast
-        super().__init__(self.config.model)
+        if max_mem_gb is None:
+            max_mem_gb = _get_suggest_memory_for_classifier()
+
+        super().__init__(self.config.model, max_mem_gb=max_mem_gb)
 
     def load_model(self, device="cuda"):
         model = HFDeberta.from_pretrained(DOMAIN_IDENTIFIER)
@@ -55,6 +61,25 @@ def load_config(self):


 class DomainClassifier(DistributedDataClassifier):
+    """
+    DomainClassifier is a specialized classifier designed for domain classification tasks, utilizing the
+    NVIDIA Domain Classifier model (https://huggingface.co/nvidia/domain-classifier). This class is optimized
+    for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets.
+
+    Attributes:
+        filter_by (list[str], optional): The classes to filter the dataset by.
+            If None, all classes will be included. Defaults to None.
+        batch_size (int): The number of samples per batch for inference. Defaults to 256.
+        pred_column (str): The column name where predictions will be stored. Defaults to "domain_pred".
+        prob_column (str, optional): The column name where prediction probabilities will be stored. Defaults to None.
+        max_chars (int): The maximum number of characters in each document to consider for classification. Defaults to 2000.
+        device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
+        autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
+        max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
+            it defaults to the available GPU memory minus 4 GB.
+
+    """
+
     def __init__(
         self,
         filter_by=None,
@@ -64,14 +89,17 @@ def __init__(
         max_chars=2000,
         device_type="cuda",
         autocast=True,
+        max_mem_gb=None,
     ):
         config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER)
 
         self.prob_column = prob_column
         self.labels = list(config.label2id.keys())
         self.out_dim = len(self.labels)
 
-        model = DomainModel(config=DomainModelConfig, autocast=autocast)
+        model = DomainModel(
+            config=DomainModelConfig, autocast=autocast, max_mem_gb=max_mem_gb
+        )
 
         super().__init__(
             model=model,
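A short usage sketch based on the docstring above (the input path is a placeholder and the filter_by labels are hypothetical; valid labels come from the model config's label2id):

from nemo_curator.classifiers import DomainClassifier
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_json("/path/to/data", backend="cudf")
domain_classifier = DomainClassifier(
    filter_by=["News", "Sports"],  # hypothetical labels; None keeps all classes
    max_mem_gb=None,  # auto-suggested: smallest GPU's memory minus 4 GB
)
result = domain_classifier(dataset=dataset)
result.to_json(output_file_dir="./")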
140 changes: 140 additions & 0 deletions nemo_curator/classifiers/fineweb_edu.py
@@ -0,0 +1,140 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"
import torch
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
from transformers import AutoConfig, AutoModelForSequenceClassification

from nemo_curator.classifiers.base import (
DistributedDataClassifier,
_get_suggest_memory_for_classifier,
_run_classifier_helper,
)
from nemo_curator.datasets import DocumentDataset

FINEWEB_EDU_IDENTIFIER = "HuggingFaceTB/fineweb-edu-classifier"


class FinewebEduModel(HFModel):
def __init__(self, path_or_name, max_mem_gb=None, autocast=False):
self.path_or_name = path_or_name
self.autocast = autocast
if max_mem_gb is None:
max_mem_gb = _get_suggest_memory_for_classifier()
super().__init__(path_or_name=path_or_name, max_mem_gb=max_mem_gb)

def load_model(self, device="cuda"):
model = AutoModelForSequenceClassification.from_pretrained(self.path_or_name)
model = model.to(device)
model = self.configure_forward(model, self.autocast)
return model

    @staticmethod
    def configure_forward(model, autocast=True):
        original_forward = model.forward

        def custom_forward(*args, **kwargs):
            if autocast:
                with torch.autocast(device_type="cuda"):
                    output = original_forward(*args, **kwargs)
            else:
                output = original_forward(*args, **kwargs)
            # The model is a single-label sequence classifier; squeeze the
            # logits down to one scalar score per document.
            return output.logits.squeeze(-1).float()

        model.forward = custom_forward
        return model

def load_config(self):
return AutoConfig.from_pretrained(self.path_or_name)


class FineWebEduClassifier(DistributedDataClassifier):
"""
FineWebEduClassifier is a specialized classifier designed for educational content assessment, utilizing the
Hugging Face FineWeb EDU Classifier model (https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier).
This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference
on large text datasets.

Attributes:
batch_size (int): The number of samples per batch for inference. Defaults to 256.
text_field (str): The column name containing the text data to be classified. Defaults to "text".
pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-edu-score".
int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-edu-score-int".
max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. Defaults to -1.
device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda".
autocast (bool): Whether to use mixed precision for faster inference. Defaults to True.
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
it defaults to the available GPU memory minus 4 GB.

"""

def __init__(
self,
batch_size=256,
text_field: str = "text",
pred_column="fineweb-edu-score",
int_column="fineweb-edu-score-int",
max_chars=-1,
device_type="cuda",
autocast=True,
max_mem_gb=None,
):
model = FinewebEduModel(
path_or_name=FINEWEB_EDU_IDENTIFIER,
autocast=autocast,
max_mem_gb=max_mem_gb,
)

self.text_field = text_field
self.int_column = int_column
super().__init__(
model=model,
            filter_by=None,  # No filtering as it's a numeric score
batch_size=batch_size,
pred_column=pred_column,
max_chars=max_chars,
device_type=device_type,
autocast=autocast,
labels=None,
out_dim=1,
)

def _run_classifier(self, dataset: DocumentDataset):
print("Starting Fineweb EDU classifier inference", flush=True)
ddf = dataset.df

pipe = op.Sequential(
op.Tokenizer(
self.model,
cols=[self.text_field],
tokenizer_type="sentencepiece",
max_length=self.model.max_seq_length(),
),
op.Predictor(
self.model,
sorted_data_loader=True,
batch_size=self.batch_size,
pred_output_col=self.pred_column,
),
keep_cols=ddf.columns.tolist(),
)
ddf = pipe(ddf)
# Go from list to scalar
ddf[self.pred_column] = ddf[self.pred_column].list.get(0)
ddf[self.int_column] = (
ddf[self.pred_column].clip(lower=0, upper=5).round().astype(int)
)
return DocumentDataset(ddf)
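End to end, _run_classifier leaves each document with a float fineweb-edu-score and an integer fineweb-edu-score-int clipped to [0, 5]. A brief sketch (the path is a placeholder):

from nemo_curator.classifiers import FineWebEduClassifier
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_json("/path/to/data", backend="cudf")
classifier = FineWebEduClassifier(batch_size=256)
scored = classifier(dataset=dataset)

# Keep documents rated clearly educational, mirroring the score >= 3
# threshold used for the FineWeb-Edu dataset itself.
edu_subset = scored.df[scored.df["fineweb-edu-score-int"] >= 3]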