Add fineweb examples/scripts
Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa committed Aug 16, 2024
1 parent aedb456 commit bfd7b23
Showing 3 changed files with 295 additions and 0 deletions.
64 changes: 64 additions & 0 deletions examples/classifiers/fineweb_edu_example.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

from nemo_curator.classifiers import FineWebEduClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper


def main(args):
global_st = time.time()

# Input can be a string or list
input_file_path = "/path/to/data"
output_file_path = "./"

client_args = ArgumentHelper.parse_client_args(args)
client_args["cluster_type"] = "gpu"
client = get_client(**client_args)

input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
)

fineweb_classifier = FineWebEduClassifier()
result_dataset = fineweb_classifier(dataset=input_dataset)
result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

global_et = time.time()
print(
f"Total time taken for fineweb classifier inference: {global_et-global_st} s",
flush=True,
)

client.close()


def attach_args(
parser=argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
),
):
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_classifier_cluster_args()

    return argumentHelper.parser


if __name__ == "__main__":
main(attach_args().parse_args())
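For context, the example above expects `input_file_path` to point at JSONL data whose records carry a `text` field (the classifier's default `text_field`). A minimal sketch of such an input, where the path and document contents are purely illustrative assumptions and not part of this commit:

```python
# Hypothetical helper: write a tiny JSONL input that the example above could
# read via DocumentDataset.read_json(..., backend="cudf", add_filename=True).
import json
import os

input_dir = "/tmp/fineweb_input"  # assumed scratch location
os.makedirs(input_dir, exist_ok=True)

docs = [
    {"text": "Photosynthesis converts light energy into chemical energy in plants."},
    {"text": "Congratulations!!! Click here to claim your free prize now."},
]
with open(os.path.join(input_dir, "sample.jsonl"), "w") as f:
    for doc in docs:
        f.write(json.dumps(doc) + "\n")
```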
122 changes: 122 additions & 0 deletions nemo_curator/classifiers/fineweb_edu.py
@@ -0,0 +1,122 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

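# Delay CUDA context creation and disable Dask's query-planning (dask-expr) backend;
# both environment variables must be set before the Dask/cuDF-dependent imports below.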
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False"
import torch
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
from transformers import AutoConfig, AutoModelForSequenceClassification

from nemo_curator.classifiers.base import (
DistributedDataClassifier,
_get_suggest_memory_for_classifier,
_run_classifier_helper,
)
from nemo_curator.datasets import DocumentDataset

FINEWEB_EDU_IDENTIFIER = "HuggingFaceTB/fineweb-edu-classifier"


class FinewebEduModel(HFModel):
def __init__(self, path_or_name, max_mem_gb, autocast=False):
self.path_or_name = path_or_name
self.autocast = autocast
if max_mem_gb is None:
max_mem_gb = _get_suggest_memory_for_classifier()
super().__init__(path_or_name=path_or_name, max_mem_gb=max_mem_gb)

def load_model(self, device="cuda"):
model = AutoModelForSequenceClassification.from_pretrained(self.path_or_name)
model = model.to(device)
model = self.configure_forward(model, self.autocast)
return model

@staticmethod
def configure_forward(model, autocast=True):
original_forward = model.forward

        def custom_forward(*args, **kwargs):
            if autocast:
                with torch.autocast(device_type="cuda"):
                    output = original_forward(*args, **kwargs)
            else:
                output = original_forward(*args, **kwargs)
            # Return the raw regression logit as a 1-D float32 tensor (one score per document)
            return output.logits.squeeze(-1).float()

model.forward = custom_forward
return model

def load_config(self):
return AutoConfig.from_pretrained(self.path_or_name)


class FineWebEduClassifier(DistributedDataClassifier):
def __init__(
self,
filter_by=None,
batch_size=256,
text_field: str = "text",
pred_column="fineweb-edu-score",
int_column="fineweb-edu-score-int",
max_chars=-1,
device_type="cuda",
autocast=True,
max_mem_gb=None,
):
model = FinewebEduModel(
path_or_name=FINEWEB_EDU_IDENTIFIER,
autocast=autocast,
max_mem_gb=max_mem_gb,
)

self.text_field = text_field
self.int_column = int_column
super().__init__(
model=model,
filter_by=filter_by,
batch_size=batch_size,
pred_column=pred_column,
max_chars=max_chars,
device_type=device_type,
autocast=autocast,
labels=None,
out_dim=1,
)

def _run_classifier(self, dataset: DocumentDataset):
print("Starting Fineweb EDU classifier inference", flush=True)
ddf = dataset.df

pipe = op.Sequential(
op.Tokenizer(
self.model,
cols=[self.text_field],
tokenizer_type="sentencepiece",
max_length=self.model.max_seq_length(),
),
op.Predictor(
self.model,
sorted_data_loader=True,
batch_size=self.batch_size,
pred_output_col=self.pred_column,
),
keep_cols=ddf.columns.tolist(),
)
ddf = pipe(ddf)
        # The predictor returns a single-element list per row; extract the scalar score
        ddf[self.pred_column] = ddf[self.pred_column].list.get(0)
        # Clip and round to the 0-5 integer educational-quality scale used by FineWeb-Edu
        ddf[self.int_column] = (
            ddf[self.pred_column].clip(lower=0, upper=5).round().astype(int)
        )
return DocumentDataset(ddf)
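A minimal usage sketch for the classifier above, assuming a GPU Dask client is already running; the paths are placeholders, the column name comes from the default `int_column`, and the `>= 3` threshold is only an illustrative educational-quality cut-off, not something this commit prescribes:

```python
# Illustrative only: score a cuDF-backed dataset and keep documents whose
# rounded FineWeb-Edu score clears an assumed threshold of 3.
from nemo_curator.classifiers import FineWebEduClassifier
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_json("/path/to/jsonl", backend="cudf")
classifier = FineWebEduClassifier(batch_size=256, autocast=True)
scored = classifier(dataset=dataset)

# "fineweb-edu-score" and "fineweb-edu-score-int" are added by _run_classifier above.
scored_df = scored.df
educational = DocumentDataset(scored_df[scored_df["fineweb-edu-score-int"] >= 3])
educational.to_json(output_file_dir="/path/to/output")
```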
109 changes: 109 additions & 0 deletions nemo_curator/scripts/classifiers/fineweb_edu_classifier_inference.py
@@ -0,0 +1,109 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import warnings

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from nemo_curator.classifiers import FineWebEduClassifier
from nemo_curator.datasets import DocumentDataset

# Get relevant args
from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk
from nemo_curator.utils.file_utils import get_remaining_files
from nemo_curator.utils.script_utils import ArgumentHelper

warnings.filterwarnings("ignore")


def main():
args = ArgumentHelper.parse_distributed_classifier_args().parse_args()
print(f"Arguments parsed = {args}", flush=True)

client_args = ArgumentHelper.parse_client_args(args)
client_args["cluster_type"] = "gpu"
client = get_client(**client_args)
print("Starting Fineweb classifier inference", flush=True)
global_st = time.time()
files_per_run = len(client.scheduler_info()["workers"]) * 2

if not os.path.exists(args.output_data_dir):
os.makedirs(args.output_data_dir)

    # Sometimes JSONL files are stored with a .json extension,
    # so allow the lookup extension to be overridden via input_file_extension
if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
else:
input_file_extension = args.input_file_type

input_files = get_remaining_files(
args.input_data_dir, args.output_data_dir, input_file_extension
)
print(f"Total input files {len(input_files)}", flush=True)

if args.input_file_type == "pickle":
add_filename = False
else:
add_filename = True

fineweb_edu_classifier = FineWebEduClassifier(
batch_size=args.batch_size,
autocast=args.autocast,
max_chars=args.max_chars,
max_mem_gb=args.max_mem_gb_classifier,
)

for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)):
batch_st = time.time()
current_batch_files = input_files[i : i + files_per_run]
print(
f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}",
flush=True,
)
df = read_data(
input_files=current_batch_files,
file_type=args.input_file_type,
add_filename=add_filename,
)
df = fineweb_edu_classifier(DocumentDataset(df)).df
print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True)

write_to_disk(
df=df,
output_file_dir=args.output_data_dir,
write_to_filename=add_filename,
output_type=args.output_file_type,
)
batch_et = time.time()
print(
f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds",
flush=True,
)

global_et = time.time()
print(
f"Total time taken for domain classifier inference: {global_et-global_st} s",
flush=True,
)
client.close()


def console_script():
main()


if __name__ == "__main__":
console_script()
