-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fineweb classifier #168
Merged
VibhuJawa
merged 12 commits into
NVIDIA:main
from
VibhuJawa:vjawa/add_fineweb_edu_classifier
Aug 21, 2024
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
627e883
Add fineweb classifier
VibhuJawa afa8e55
Fix data-quality
VibhuJawa 61101b4
Merge conflicts
VibhuJawa aedb456
Move fineweb to the mainline classifiers and some code cleanup
VibhuJawa bfd7b23
Add fineweb examples/scripts
VibhuJawa e4bcfb1
Remove changes to Aegis classifier
VibhuJawa e889f51
Address Reviews
VibhuJawa 18d93da
Fix aegis model
VibhuJawa 129c329
Merge branch 'NVIDIA:main' into vjawa/add_fineweb_edu_classifier
VibhuJawa 3ecff45
Fix minor bug in docstring
VibhuJawa 3eacf8f
Fix minor bug in max_mem_gb_classifier scripts.py
VibhuJawa 4efd551
Pass in max_mem_gb_classifier scripts/classifiers/aegis_classifier_in…
VibhuJawa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import time | ||
|
||
from nemo_curator.classifiers import FineWebEduClassifier | ||
from nemo_curator.datasets import DocumentDataset | ||
from nemo_curator.utils.distributed_utils import get_client | ||
from nemo_curator.utils.script_utils import ArgumentHelper | ||
|
||
|
||
def main(args): | ||
global_st = time.time() | ||
|
||
# Input can be a string or list | ||
input_file_path = "/path/to/data" | ||
output_file_path = "./" | ||
|
||
client_args = ArgumentHelper.parse_client_args(args) | ||
client_args["cluster_type"] = "gpu" | ||
client = get_client(**client_args) | ||
|
||
input_dataset = DocumentDataset.read_json( | ||
input_file_path, backend="cudf", add_filename=True | ||
) | ||
|
||
fineweb_classifier = FineWebEduClassifier() | ||
result_dataset = fineweb_classifier(dataset=input_dataset) | ||
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True) | ||
|
||
global_et = time.time() | ||
print( | ||
f"Total time taken for fineweb classifier inference: {global_et-global_st} s", | ||
flush=True, | ||
) | ||
|
||
client.close() | ||
|
||
|
||
def attach_args( | ||
parser=argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
), | ||
): | ||
argumentHelper = ArgumentHelper(parser) | ||
argumentHelper.add_distributed_classifier_cluster_args() | ||
|
||
return argumentHelper.parser.parse_args() | ||
|
||
|
||
if __name__ == "__main__": | ||
main(attach_args().parse_args()) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
|
||
os.environ["RAPIDS_NO_INITIALIZE"] = "1" | ||
os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False" | ||
import torch | ||
from crossfit import op | ||
from crossfit.backend.torch.hf.model import HFModel | ||
from transformers import AutoConfig, AutoModelForSequenceClassification | ||
|
||
from nemo_curator.classifiers.base import ( | ||
DistributedDataClassifier, | ||
_get_suggest_memory_for_classifier, | ||
_run_classifier_helper, | ||
) | ||
from nemo_curator.datasets import DocumentDataset | ||
|
||
FINEWEB_EDU_IDENTIFIER = "HuggingFaceTB/fineweb-edu-classifier" | ||
|
||
|
||
class FinewebEduModel(HFModel): | ||
def __init__(self, path_or_name, max_mem_gb=None, autocast=False): | ||
self.path_or_name = path_or_name | ||
self.autocast = autocast | ||
if max_mem_gb is None: | ||
max_mem_gb = _get_suggest_memory_for_classifier() | ||
super().__init__(path_or_name=path_or_name, max_mem_gb=max_mem_gb) | ||
|
||
def load_model(self, device="cuda"): | ||
model = AutoModelForSequenceClassification.from_pretrained(self.path_or_name) | ||
model = model.to(device) | ||
model = self.configure_forward(model, self.autocast) | ||
return model | ||
|
||
@staticmethod | ||
def configure_forward(model, autocast=True): | ||
original_forward = model.forward | ||
|
||
def custom_forward(*args, **kwargs): | ||
if autocast: | ||
with torch.autocast(device_type="cuda"): | ||
output = original_forward(*args, **kwargs) | ||
return output.logits.squeeze(-1).float() | ||
|
||
model.forward = custom_forward | ||
return model | ||
|
||
def load_config(self): | ||
return AutoConfig.from_pretrained(self.path_or_name) | ||
|
||
|
||
class FineWebEduClassifier(DistributedDataClassifier): | ||
""" | ||
FineWebEduClassifier is a specialized classifier designed for educational content assessment, utilizing the | ||
Hugging Face FineWeb EDU Classifier model (https://huggingface.co/HuggingFaceFW/fineweb-edu-classifier). | ||
This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference | ||
on large text datasets. | ||
|
||
Attributes: | ||
batch_size (int): The number of samples per batch for inference. Defaults to 256. | ||
text_field (str): The column name containing the text data to be classified. Defaults to "text". | ||
pred_column (str): The column name where prediction scores will be stored. Defaults to "fineweb-edu-score". | ||
int_column (str): The column name where integer-rounded prediction scores will be stored. Defaults to "fineweb-edu-score-int". | ||
max_chars (int): The maximum number of characters in each document to consider for classification. If -1, the entire document is considered. Defaults to -1. | ||
device_type (str): The type of device to use for inference, either "cuda" or "cpu". Defaults to "cuda". | ||
autocast (bool): Whether to use mixed precision for faster inference. Defaults to True. | ||
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, | ||
it defaults to the available GPU memory minus 4 GB. | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
batch_size=256, | ||
text_field: str = "text", | ||
pred_column="fineweb-edu-score", | ||
ryantwolf marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int_column="fineweb-edu-score-int", | ||
max_chars=-1, | ||
VibhuJawa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
device_type="cuda", | ||
autocast=True, | ||
max_mem_gb=None, | ||
): | ||
model = FinewebEduModel( | ||
path_or_name=FINEWEB_EDU_IDENTIFIER, | ||
autocast=autocast, | ||
max_mem_gb=max_mem_gb, | ||
) | ||
|
||
self.text_field = text_field | ||
self.int_column = int_column | ||
super().__init__( | ||
model=model, | ||
filter_by=None, # No filtering as its a numeric score | ||
batch_size=batch_size, | ||
pred_column=pred_column, | ||
max_chars=max_chars, | ||
device_type=device_type, | ||
autocast=autocast, | ||
labels=None, | ||
out_dim=1, | ||
) | ||
|
||
def _run_classifier(self, dataset: DocumentDataset): | ||
print("Starting Fineweb EDU classifier inference", flush=True) | ||
ddf = dataset.df | ||
|
||
pipe = op.Sequential( | ||
op.Tokenizer( | ||
self.model, | ||
cols=[self.text_field], | ||
tokenizer_type="sentencepiece", | ||
max_length=self.model.max_seq_length(), | ||
), | ||
op.Predictor( | ||
self.model, | ||
sorted_data_loader=True, | ||
batch_size=self.batch_size, | ||
pred_output_col=self.pred_column, | ||
), | ||
keep_cols=ddf.columns.tolist(), | ||
) | ||
ddf = pipe(ddf) | ||
# Go from list to scalar | ||
ddf[self.pred_column] = ddf[self.pred_column].list.get(0) | ||
ddf[self.int_column] = ( | ||
ddf[self.pred_column].clip(lower=0, upper=5).round().astype(int) | ||
) | ||
return DocumentDataset(ddf) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like there is an unnecessary
parse_args()
inattach_args()
andmain()
, which is also in our other classifier examples too. I opened #217 to resolve this.