Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AEGIS classifier #172

Merged
merged 16 commits into from
Aug 7, 2024
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ repos:
exclude: docs/
- id: requirements-txt-fixer
- id: trailing-whitespace
exclude: nemo_curator/utils/aegis_utils.py

- repo: https://github.com/psf/black
rev: 24.4.2
Expand Down
70 changes: 70 additions & 0 deletions examples/classifiers/aegis_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

from nemo_curator.classifiers import AegisClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper


def main(args):
    """Run the AEGIS safety classifier example end to end.

    Reads the input JSON data with the cuDF backend, applies
    ``AegisClassifier`` to it, writes the result back out as JSON, and prints
    the total wall-clock time.

    Args:
        args: Parsed command-line arguments. They are passed to
            ``ArgumentHelper.parse_client_args`` to build the keyword
            arguments for ``get_client``.
    """
    global_st = time.time()

    # Input can be a string or list
    input_file_path = "/path/to/data"
    output_file_path = "./"
    huggingface_token = "hf_1234"  # Replace with a HuggingFace user access token

    client_args = ArgumentHelper.parse_client_args(args)
    # This example always requests a GPU cluster, overriding any parsed value.
    client_args["cluster_type"] = "gpu"
    client = get_client(**client_args)

    # Release the client even when reading, inference, or writing fails, so a
    # failed run does not leave the cluster connection open.
    try:
        input_dataset = DocumentDataset.read_json(
            input_file_path, backend="cudf", add_filename=True
        )

        safety_classifier = AegisClassifier(
            aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0",
            token=huggingface_token,
            filter_by=["safe", "O13"],
        )
        result_dataset = safety_classifier(dataset=input_dataset)

        result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

        global_et = time.time()
        print(
            f"Total time taken for AEGIS classifier inference: {global_et-global_st} s",
            flush=True,
        )
    finally:
        client.close()


def attach_args(parser=None):
    """Attach the distributed classifier cluster arguments to a parser.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser using ``argparse.ArgumentDefaultsHelpFormatter`` is created
            on every call, so repeated calls never share (and re-register
            options on) a single default parser instance.

    Returns:
        The parser with the cluster arguments registered. Call
        ``.parse_args()`` on it to obtain the parsed arguments.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter
        )

    argument_helper = ArgumentHelper(parser)
    argument_helper.add_distributed_classifier_cluster_args()

    # Return the parser itself rather than a parsed namespace: the script
    # entry point calls ``.parse_args()`` on this return value.
    return argument_helper.parser


# Script entry point: obtain the command-line arguments via attach_args()
# and run the example with them.
if __name__ == "__main__":
    main(attach_args().parse_args())
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import argparse
import time

from nemo_curator import DomainClassifier
from nemo_curator.classifiers import DomainClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper
Expand All @@ -28,7 +28,9 @@ def main(args):
input_file_path = "/path/to/data"
output_file_path = "./"

client = get_client(**ArgumentHelper.parse_client_args(args))
client_args = ArgumentHelper.parse_client_args(args)
client_args["cluster_type"] = "gpu"
client = get_client(**client_args)

input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
Expand All @@ -54,17 +56,9 @@ def attach_args(
),
):
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_classifier_cluster_args()

argumentHelper.add_arg_device()
argumentHelper.add_arg_enable_spilling()
argumentHelper.add_arg_nvlink_only()
argumentHelper.add_arg_protocol()
argumentHelper.add_arg_rmm_pool_size()
argumentHelper.add_arg_scheduler_address()
argumentHelper.add_arg_scheduler_file()
argumentHelper.add_arg_set_torch_to_use_rmm()

return argumentHelper.parser
return argumentHelper.parser.parse_args()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import argparse
import time

from nemo_curator import QualityClassifier
from nemo_curator.classifiers import QualityClassifier
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.script_utils import ArgumentHelper
Expand All @@ -30,7 +30,9 @@ def main(args):
input_file_path = "/path/to/data"
output_file_path = "./"

client = get_client(**ArgumentHelper.parse_client_args(args))
client_args = ArgumentHelper.parse_client_args(args)
client_args["cluster_type"] = "gpu"
client = get_client(**client_args)

input_dataset = DocumentDataset.read_json(
input_file_path, backend="cudf", add_filename=True
Expand All @@ -41,7 +43,8 @@ def main(args):
filter_by=["High", "Medium"],
)
result_dataset = quality_classifier(dataset=input_dataset)
print(result_dataset.df.head())

result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)

global_et = time.time()
print(
Expand All @@ -58,17 +61,9 @@ def attach_args(
),
):
argumentHelper = ArgumentHelper(parser)
argumentHelper.add_distributed_classifier_cluster_args()

argumentHelper.add_arg_device()
argumentHelper.add_arg_enable_spilling()
argumentHelper.add_arg_nvlink_only()
argumentHelper.add_arg_protocol()
argumentHelper.add_arg_rmm_pool_size()
argumentHelper.add_arg_scheduler_address()
argumentHelper.add_arg_scheduler_file()
argumentHelper.add_arg_set_torch_to_use_rmm()

return argumentHelper.parser
return argumentHelper.parser.parse_args()


if __name__ == "__main__":
Expand Down
22 changes: 22 additions & 0 deletions nemo_curator/classifiers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

# The flag is set before the submodule imports below so that it is already in
# the environment when those modules are loaded; keep this ordering.
# NOTE(review): presumably this stops RAPIDS libraries from initializing CUDA
# at import time — confirm against the cuDF documentation.
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from .aegis import AegisClassifier
from .domain import DomainClassifier
from .quality import QualityClassifier

# Public names of this package.
__all__ = ["DomainClassifier", "QualityClassifier", "AegisClassifier"]
Loading
Loading