diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1aef7785..90e50d0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,7 @@ repos: exclude: docs/ - id: requirements-txt-fixer - id: trailing-whitespace + exclude: nemo_curator/utils/aegis_utils.py - repo: https://github.com/psf/black rev: 24.4.2 diff --git a/examples/classifiers/aegis_example.py b/examples/classifiers/aegis_example.py new file mode 100644 index 00000000..05a0331d --- /dev/null +++ b/examples/classifiers/aegis_example.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time + +from nemo_curator.classifiers import AegisClassifier +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.distributed_utils import get_client +from nemo_curator.utils.script_utils import ArgumentHelper + + +def main(args): + global_st = time.time() + + # Input can be a string or list + input_file_path = "/path/to/data" + output_file_path = "./" + huggingface_token = "hf_1234" # Replace with a HuggingFace user access token + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + + input_dataset = DocumentDataset.read_json( + input_file_path, backend="cudf", add_filename=True + ) + + safety_classifier = AegisClassifier( + aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0", + token=huggingface_token, + filter_by=["safe", "O13"], + ) + result_dataset = safety_classifier(dataset=input_dataset) + + result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True) + + global_et = time.time() + print( + f"Total time taken for AEGIS classifier inference: {global_et-global_st} s", + flush=True, + ) + + client.close() + + +def attach_args( + parser=argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ), +): + argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_classifier_cluster_args() + + return argumentHelper.parser.parse_args() + + +if __name__ == "__main__": + main(attach_args().parse_args()) diff --git a/examples/domain_classifier_example.py b/examples/classifiers/domain_example.py similarity index 77% rename from examples/domain_classifier_example.py rename to examples/classifiers/domain_example.py index 78d4612c..5fea924e 100644 --- a/examples/domain_classifier_example.py +++ b/examples/classifiers/domain_example.py @@ -15,7 +15,7 @@ import argparse import time -from nemo_curator import DomainClassifier +from nemo_curator.classifiers import DomainClassifier from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.script_utils import ArgumentHelper @@ -28,7 +28,9 @@ def main(args): input_file_path = "/path/to/data" output_file_path = "./" - client = get_client(**ArgumentHelper.parse_client_args(args)) + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) input_dataset = DocumentDataset.read_json( input_file_path, backend="cudf", add_filename=True @@ -54,17 +56,9 @@ def attach_args( ), ): argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_classifier_cluster_args() - argumentHelper.add_arg_device() - argumentHelper.add_arg_enable_spilling() - argumentHelper.add_arg_nvlink_only() - argumentHelper.add_arg_protocol() - argumentHelper.add_arg_rmm_pool_size() - argumentHelper.add_arg_scheduler_address() - argumentHelper.add_arg_scheduler_file() - argumentHelper.add_arg_set_torch_to_use_rmm() - - return argumentHelper.parser + return argumentHelper.parser.parse_args() if __name__ == "__main__": diff --git a/examples/quality_classifier_example.py b/examples/classifiers/quality_example.py similarity index 76% rename from examples/quality_classifier_example.py rename to examples/classifiers/quality_example.py index 6d13f9df..4cc47095 100644 --- a/examples/quality_classifier_example.py +++ b/examples/classifiers/quality_example.py @@ -15,7 +15,7 @@ import argparse import time -from nemo_curator import QualityClassifier +from nemo_curator.classifiers import QualityClassifier from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.script_utils import ArgumentHelper @@ -30,7 +30,9 @@ def main(args): input_file_path = "/path/to/data" output_file_path = "./" - client = get_client(**ArgumentHelper.parse_client_args(args)) + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) input_dataset = DocumentDataset.read_json( input_file_path, backend="cudf", add_filename=True @@ -41,7 +43,8 @@ def main(args): filter_by=["High", "Medium"], ) result_dataset = quality_classifier(dataset=input_dataset) - print(result_dataset.df.head()) + + result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True) global_et = time.time() print( @@ -58,17 +61,9 @@ def attach_args( ), ): argumentHelper = ArgumentHelper(parser) + argumentHelper.add_distributed_classifier_cluster_args() - argumentHelper.add_arg_device() - argumentHelper.add_arg_enable_spilling() - argumentHelper.add_arg_nvlink_only() - argumentHelper.add_arg_protocol() - argumentHelper.add_arg_rmm_pool_size() - argumentHelper.add_arg_scheduler_address() - argumentHelper.add_arg_scheduler_file() - argumentHelper.add_arg_set_torch_to_use_rmm() - - return argumentHelper.parser + return argumentHelper.parser.parse_args() if __name__ == "__main__": diff --git a/nemo_curator/classifiers/__init__.py b/nemo_curator/classifiers/__init__.py new file mode 100644 index 00000000..ede9bad0 --- /dev/null +++ b/nemo_curator/classifiers/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +from .aegis import AegisClassifier +from .domain import DomainClassifier +from .quality import QualityClassifier + +__all__ = ["DomainClassifier", "QualityClassifier", "AegisClassifier"] diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py new file mode 100644 index 00000000..dc662b52 --- /dev/null +++ b/nemo_curator/classifiers/aegis.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +os.environ["DASK_DATAFRAME__QUERY_PLANNING"] = "False" +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional, Union + +import cudf +import torch +import torch.nn as nn +from crossfit import op +from crossfit.backend.torch.hf.model import HFModel +from peft import PeftModel +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from nemo_curator.classifiers.base import DistributedDataClassifier +from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.aegis_utils import format_aegis + + +@dataclass +class AegisConfig: + peft_model_name_or_path: str + token: Optional[Union[str, bool]] = None + pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b" + dtype: torch.dtype = torch.bfloat16 + max_length: int = 4096 + + +ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace. +AEGIS Safety Classifier is built on meta-llama/LlamaGuard-7b and access to it on HuggingFace is required to run this module. +You must be authenticated (using a user access token) to access it. +You can request access to Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b. +Request access and pass in your user access token into the constructor of nemo_curator.classifiers.AegisClassifier in order to use AEGIS. +""" + +AEGIS_LABELS = [ + "unknown", + "safe", + "O1", + "O2", + "O3", + "O4", + "O5", + "O6", + "O7", + "O8", + "O9", + "O10", + "O11", + "O12", + "O13", +] + + +class AegisModel(nn.Module): + def __init__( + self, + pretrained_model_name_or_path: str, + peft_model_name_or_path: str, + dtype: torch.dtype, + token: str, + ): + super().__init__() + base_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, torch_dtype=dtype, token=token + ) + self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path) + + @torch.no_grad() + def forward(self, batch): + response = self.model.generate( + **batch, + max_new_tokens=100, + pad_token_id=0, + ) + return response + + +class AegisHFModel(HFModel): + def __init__(self, config: AegisConfig): + self.config = config + super().__init__( + config.pretrained_model_name_or_path, + max_mem_gb=48, + start_batch_size=4, + end_batch_size=32, + batch_size_increment=4, + start_seq_len=1024, + seq_len_increment=1024, + ) + + def load_model(self, device="cuda"): + model = AegisModel( + self.config.pretrained_model_name_or_path, + self.config.peft_model_name_or_path, + self.config.dtype, + self.config.token, + ) + model = model.to(device) + model.eval() + return model + + def load_config(self): + return AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, + token=self.config.token, + ) + + @lru_cache(maxsize=1) + def load_cfg(self): + return self.load_config() + + @lru_cache(maxsize=1) + def load_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, + token=self.config.token, + padding_side="left", + ) + tokenizer.pad_token = tokenizer.unk_token + + return tokenizer + + def max_seq_length(self) -> int: + return self.config.max_length + + +class AegisClassifier(DistributedDataClassifier): + """ + NVIDIA's AEGIS safety classifier is a LLM content safety model. + It is a parameter efficient instruction tuned version of Llama Guard based on + Llama2-7B trained on Nvidia's content safety dataset Aegis Content Safety + Dataset covering Nvidia's broad taxonomy of 13 critical safety risk + categories. See the paper for more information: https://arxiv.org/abs/2404.05993 + + In order to use this AEGIS classifiers, users must get access to + Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b + Afterwards, they should set up a user access token and pass that token into + the constructor of this classifier. + """ + + def __init__( + self, + aegis_variant: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0", + token: Optional[Union[str, bool]] = None, + filter_by: Optional[List[str]] = None, + batch_size: int = 64, + text_field: str = "text", + pred_column: str = "aegis_pred", + raw_pred_column: str = "_aegis_raw_pred", + keep_raw_pred: bool = False, + max_chars: int = 6000, + device_type: str = "cuda", + ): + """ + Constructs the classifier + + Args: + aegis_variant (str): The HuggingFace 'pretrained_model_name_or_path' for + the AEGIS model. Can be either 'nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0' + or 'nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0' + token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is + needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to + Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b + filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values + expect those specified in this list. + batch_size (int): The batch size to use when running the classifier. + text_field (str): The field in the dataset that should be classified. + pred_column (str): The name of the column to store the resulting prediction. + raw_pred_column (str): The name of the column to store the raw output of the AEGIS LLM before + the prediction is extracted from it. + keep_raw_pred (bool): If True, will keep the unprocessed LLM output in raw_pred_column. + Useful for debugging when "unknown" shows up a lot in your dataset. + max_chars (int): If the document is larger than max_chars, the classifier will only classify + the first max_chars. + device_type (str): The device to run the classifier on. Currently, it can only be "cuda". + """ + config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token) + + self.text_field = text_field + self.labels = AEGIS_LABELS + self.out_dim = len(self.labels) + self.raw_pred_column = raw_pred_column + self.keep_raw_pred = keep_raw_pred + + try: + model = AegisHFModel(config=config) + except OSError as e: + if "meta-llama/LlamaGuard-7b" in str(e): + raise PermissionError(ACCESS_ERROR_MESSAGE) + else: + raise e + + super().__init__( + model=model, + labels=self.labels, + filter_by=filter_by, + batch_size=batch_size, + out_dim=self.out_dim, + pred_column=pred_column, + max_chars=max_chars, + device_type=device_type, + autocast=False, + ) + + def _wrap_in_prompt(self, df): + documents = df[self.text_field].to_arrow().to_pylist() + prompts = [format_aegis(doc[: self.max_chars]) for doc in documents] + df["_hidden_text"] = cudf.Series(prompts) + return df + + def _parse_response(self, raw_response: str) -> str: + lines = raw_response.split("\n") + if lines[0].strip() == "safe": + return "safe" + elif lines[0].strip() == "unsafe": + if len(lines) < 2: + return "unknown" + potential_label = lines[1].strip() + if potential_label not in AEGIS_LABELS[2:]: + return "unknown" + + return potential_label + else: + return "unknown" + + def _postprocess_responses(self, df): + tokenizer = self.model.load_tokenizer() + generated_tokens = df[self.raw_pred_column].to_arrow().to_pylist() + generated_tokens = tokenizer.batch_decode( + generated_tokens, + skip_special_tokens=True, + ) + original_lengths = df["_hidden_text"].str.len().to_arrow().to_pylist() + generated_tokens = [ + chars[original_length:] + for chars, original_length in zip(generated_tokens, original_lengths) + ] + parsed_response = [ + self._parse_response(response) for response in generated_tokens + ] + if self.keep_raw_pred: + df[self.raw_pred_column] = cudf.Series(generated_tokens) + else: + df = df.drop(columns=[self.raw_pred_column]) + df[self.pred_column] = cudf.Series(parsed_response) + return df + + def _run_classifier(self, dataset: DocumentDataset): + print("Starting AEGIS classifier inference", flush=True) + ddf = dataset.df + hidden_meta = ddf._meta.copy() + hidden_meta["_hidden_text"] = "DUMMY_STRING" + ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) + columns = ddf.columns.tolist() + pipe = op.Sequential( + op.Tokenizer(self.model, cols=["_hidden_text"], tokenizer_type="default"), + op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col=self.raw_pred_column, + ), + keep_cols=columns, + ) + ddf = pipe(ddf) + translated_meta = ddf._meta.copy() + if self.keep_raw_pred: + translated_meta[self.raw_pred_column] = "DUMMY_STRING" + else: + translated_meta = translated_meta.drop(columns=[self.raw_pred_column]) + translated_meta[self.pred_column] = "DUMMY_STRING" + ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta) + ddf = ddf.drop(columns=["_hidden_text"]) + return DocumentDataset(ddf) diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py new file mode 100644 index 00000000..e0e052c4 --- /dev/null +++ b/nemo_curator/classifiers/base.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +from abc import ABC, abstractmethod +from typing import List + +from crossfit import op + +from nemo_curator.datasets import DocumentDataset + + +class DistributedDataClassifier(ABC): + """Abstract class for running multi-node multi-GPU data classification""" + + def __init__( + self, + model, + labels, + filter_by, + batch_size, + out_dim, + pred_column, + max_chars, + device_type, + autocast, + ): + self.model = model + self.labels = labels + self.filter_by = filter_by + self.batch_size = batch_size + self.out_dim = out_dim + self.pred_column = pred_column + self.max_chars = max_chars + self.device_type = device_type + self.autocast = autocast + + def __call__(self, dataset: DocumentDataset): + result_doc_dataset = self._run_classifier(dataset) + if self.filter_by is not None: + return self._filter_documents(result_doc_dataset) + + return result_doc_dataset + + @abstractmethod + def _run_classifier(self): + pass + + def _filter_documents( + self, + dataset: DocumentDataset, + ): + df = dataset.df + + filter_by = self.filter_by + if type(filter_by) == str: + filtered_df = df[df[self.pred_column].astype(str) == filter_by] + return DocumentDataset(filtered_df) + elif type(filter_by) == list: + filtered_df = df[df[self.pred_column].isin(filter_by)] + return DocumentDataset(filtered_df) + + raise TypeError("filter_by must be a string or list type") + + def get_labels(self) -> List[str]: + return self.labels + + +def _run_classifier_helper( + df: "dask_cudf.DataFrame", + model: "HFModel", + labels: list[str], + max_chars: int, + batch_size: int, + label_col: str, + prob_col: str = None, +) -> "dask_cudf.DataFrame": + + keep_prob = prob_col is not None + prob_internal_col = "_prob" + # TODO: Make crossfit handle this cleanly + pred_internal_col = "labels" + df["sliced_text"] = df["text"].str.slice(0, max_chars) + columns_to_keep_list = df.columns.to_list() + columns_to_keep_list.remove("sliced_text") + + classifier_pipe = op.Sequential( + op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"), + op.Predictor( + model, + sorted_data_loader=True, + batch_size=batch_size, + pred_output_col=prob_internal_col, + ), + repartition=df.npartitions, + keep_cols=columns_to_keep_list, + ) + df = classifier_pipe(df) + + # TODO: Make crossfit handle this cleanly + # to prevent the labeler from dropping the prob_internal_col + # and combine it into a single step + labeling_pipe = op.Sequential( + op.Labeler(labels, cols=[prob_internal_col]), + keep_cols=columns_to_keep_list + [prob_internal_col], + ) + df = labeling_pipe(df) + + if keep_prob: + df = df.rename( + columns={prob_internal_col: prob_col, pred_internal_col: label_col}, + ) + else: + df = df.rename(columns={pred_internal_col: label_col}) + df = df.drop(columns=[prob_internal_col]) + + return df diff --git a/nemo_curator/classifiers/domain.py b/nemo_curator/classifiers/domain.py new file mode 100644 index 00000000..5d83dc19 --- /dev/null +++ b/nemo_curator/classifiers/domain.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +import torch +import torch.nn as nn +from crossfit.backend.torch.hf.model import HFModel +from huggingface_hub import PyTorchModelHubMixin +from transformers import AutoConfig, AutoModel, AutoTokenizer + +from nemo_curator.classifiers.base import ( + DistributedDataClassifier, + _run_classifier_helper, +) +from nemo_curator.datasets import DocumentDataset + +DOMAIN_IDENTIFIER = "nvidia/domain-classifier" + + +@dataclass +class DomainModelConfig: + model = "microsoft/deberta-v3-base" + fc_dropout = 0.2 + max_len = 512 + + +class HFCustomModel(nn.Module, PyTorchModelHubMixin): + def __init__(self, config: dataclass): + super(HFCustomModel, self).__init__() + self.model = AutoModel.from_pretrained(config["base_model"]) + self.dropout = nn.Dropout(config["fc_dropout"]) + self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"])) + + def _forward(self, batch): + features = self.model( + batch["input_ids"], batch["attention_mask"] + ).last_hidden_state + dropped = self.dropout(features) + outputs = self.fc(dropped) + return torch.softmax(outputs[:, 0, :], dim=1) + + def forward(self, batch): + if self.autocast: + with torch.autocast(device_type="cuda"): + return self._forward(batch) + else: + return self._forward(batch) + + def set_autocast(self, autocast): + self.autocast = autocast + + +class DomainModel(HFModel): + def __init__(self, config: dataclass, autocast: bool = False): + self.config = config + self.autocast = autocast + super().__init__(self.config.model) + + def load_model(self, device="cuda"): + model = HFCustomModel.from_pretrained(DOMAIN_IDENTIFIER) + model.set_autocast(self.autocast) + model = model.to(device) + return model.eval() + + def load_tokenizer(self): + return AutoTokenizer.from_pretrained(DOMAIN_IDENTIFIER) + + def load_config(self): + return AutoConfig.from_pretrained(DOMAIN_IDENTIFIER) + + +class DomainClassifier(DistributedDataClassifier): + def __init__( + self, + filter_by=None, + batch_size=256, + pred_column="domain_pred", + prob_column=None, + max_chars=2000, + device_type="cuda", + autocast=True, + ): + config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER) + + self.prob_column = prob_column + self.labels = list(config.label2id.keys()) + self.out_dim = len(self.labels) + + model = DomainModel(config=DomainModelConfig, autocast=autocast) + + super().__init__( + model=model, + labels=self.labels, + filter_by=filter_by, + batch_size=batch_size, + out_dim=self.out_dim, + pred_column=pred_column, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + ) + + def _run_classifier(self, dataset: DocumentDataset): + print("Starting domain classifier inference", flush=True) + df = dataset.df + df = _run_classifier_helper( + df=df, + model=self.model, + labels=self.labels, + max_chars=self.max_chars, + batch_size=self.batch_size, + label_col=self.pred_column, + prob_col=self.prob_column, + ) + return DocumentDataset(df) diff --git a/nemo_curator/classifiers/quality.py b/nemo_curator/classifiers/quality.py new file mode 100644 index 00000000..652e22e0 --- /dev/null +++ b/nemo_curator/classifiers/quality.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +from dataclasses import dataclass + +import torch +import torch.nn as nn +from crossfit.backend.torch.hf.model import HFModel +from transformers import AutoConfig, AutoModel +from transformers.models.deberta_v2 import DebertaV2TokenizerFast + +from nemo_curator.classifiers.base import ( + DistributedDataClassifier, + _run_classifier_helper, +) +from nemo_curator.datasets import DocumentDataset + + +@dataclass +class QualityModelConfig: + model = "microsoft/deberta-v3-base" + fc_dropout = 0.2 + max_len = 512 + + +# TODO: Remove this class after Quality Model is uploaded to HuggingFace +class NCCustomModel(nn.Module): + def __init__( + self, + config: dataclass, + out_dim: int, + config_path: str = None, + pretrained: bool = False, + autocast: bool = False, + ): + super().__init__() + self.config = config + if config_path is None: + self.config = AutoConfig.from_pretrained( + config.model, output_hidden_states=True + ) + else: + self.config = torch.load(config_path) + + if pretrained: + self.model = AutoModel.from_pretrained(config.model, config=self.config) + else: + self.model = AutoModel(self.config) + + self.fc_dropout = nn.Dropout(config.fc_dropout) + self.fc = nn.Linear(self.config.hidden_size, out_dim) + self._init_weights(self.fc) + self.autocast = autocast + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def feature(self, input_ids, attention_mask): + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) + last_hidden_states = outputs[0] + return last_hidden_states + + def _forward(self, batch): + feature = self.feature(batch["input_ids"], batch["attention_mask"]) + output = self.fc(self.fc_dropout(feature)) + output = output.to(torch.float32) + return torch.softmax(output[:, 0, :], dim=1) + + def forward(self, batch): + if self.autocast: + with torch.autocast(device_type="cuda"): + return self._forward(batch) + else: + return self._forward(batch) + + +class QualityModel(HFModel): + def __init__(self, config, out_dim=None, model_path=None, autocast=False): + self.config = config + self.out_dim = out_dim + self.model_path = model_path + self.autocast = autocast + super().__init__(self.config.model) + + def load_model(self, device="cuda"): + model = NCCustomModel( + self.config, + out_dim=self.out_dim, + config_path=None, + pretrained=True, + autocast=self.autocast, + ) + model = model.to(device) + + if os.path.exists(self.model_path): + sd = torch.load(self.model_path, map_location="cpu") + if "model_state_dict" in sd: + sd = sd["model_state_dict"] + sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()} + model.load_state_dict(sd, strict=True) + else: + raise ValueError(f"Model path {self.model_path} does not exist") + + return model.eval() + + def load_tokenizer(self): + return DebertaV2TokenizerFast.from_pretrained(self.config.model) + + def load_config(self): + return AutoConfig.from_pretrained(self.path_or_name) + + +class QualityClassifier(DistributedDataClassifier): + def __init__( + self, + model_path, + num_labels=3, + filter_by=None, + batch_size=256, + pred_column="quality_pred", + prob_column="quality_prob", + max_chars=6000, + device_type="cuda", + autocast=True, + ): + if num_labels == 3: + self.labels = ["High", "Medium", "Low"] + self.out_dim = num_labels # Multiclass classification + elif num_labels == 2: + self.labels = ["Medium_High", "Low"] + self.out_dim = 1 # Binary classification + else: + raise ValueError("num_labels must be 2 or 3") + + self.prob_column = prob_column + + model = QualityModel( + config=QualityModelConfig, + out_dim=self.out_dim, + model_path=model_path, + autocast=autocast, + ) + + super().__init__( + model=model, + labels=self.labels, + filter_by=filter_by, + batch_size=batch_size, + out_dim=self.out_dim, + pred_column=pred_column, + max_chars=max_chars, + device_type=device_type, + autocast=autocast, + ) + + def _run_classifier(self, dataset: DocumentDataset): + print("Starting Quality classifier inference", flush=True) + df = dataset.df + df = _run_classifier_helper( + df=df, + model=self.model, + labels=self.labels, + max_chars=self.max_chars, + batch_size=self.batch_size, + label_col=self.pred_column, + prob_col=self.prob_column, + ) + return DocumentDataset(df) diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index db6aca7d..7d168e7d 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -49,10 +49,8 @@ SemanticClusterLevelDedup = gpu_only_import_from( "nemo_curator.modules.semantic_dedup", "SemanticClusterLevelDedup" ) -from .distributed_data_classifier import DomainClassifier, QualityClassifier __all__ = [ - "DomainClassifier", "ExactDuplicates", "Filter", "FuzzyDuplicatesConfig", @@ -60,7 +58,6 @@ "LSH", "MinHash", "Modify", - "QualityClassifier", "Score", "ScoreFilter", "Sequential", diff --git a/nemo_curator/modules/distributed_data_classifier.py b/nemo_curator/modules/distributed_data_classifier.py deleted file mode 100644 index 1e2abb23..00000000 --- a/nemo_curator/modules/distributed_data_classifier.py +++ /dev/null @@ -1,399 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -os.environ["RAPIDS_NO_INITIALIZE"] = "1" -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import List - -import torch -import torch.nn as nn -from crossfit import op -from crossfit.backend.torch.hf.model import HFModel -from huggingface_hub import PyTorchModelHubMixin -from transformers import AutoConfig, AutoModel, AutoTokenizer -from transformers.models.deberta_v2 import DebertaV2TokenizerFast - -from nemo_curator.datasets import DocumentDataset - -DOMAIN_IDENTIFIER = "nvidia/domain-classifier" - - -@dataclass -class DomainModelConfig: - model = "microsoft/deberta-v3-base" - fc_dropout = 0.2 - max_len = 512 - - -@dataclass -class QualityModelConfig: - model = "microsoft/deberta-v3-base" - fc_dropout = 0.2 - max_len = 512 - - -# TODO: Remove this class after Quality Model is uploaded to HuggingFace -class NCCustomModel(nn.Module): - def __init__( - self, - config: dataclass, - out_dim: int, - config_path: str = None, - pretrained: bool = False, - autocast: bool = False, - ): - super().__init__() - self.config = config - if config_path is None: - self.config = AutoConfig.from_pretrained( - config.model, output_hidden_states=True - ) - else: - self.config = torch.load(config_path) - - if pretrained: - self.model = AutoModel.from_pretrained(config.model, config=self.config) - else: - self.model = AutoModel(self.config) - - self.fc_dropout = nn.Dropout(config.fc_dropout) - self.fc = nn.Linear(self.config.hidden_size, out_dim) - self._init_weights(self.fc) - self.autocast = autocast - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def feature(self, input_ids, attention_mask): - outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) - last_hidden_states = outputs[0] - return last_hidden_states - - def _forward(self, batch): - feature = self.feature(batch["input_ids"], batch["attention_mask"]) - output = self.fc(self.fc_dropout(feature)) - output = output.to(torch.float32) - return torch.softmax(output[:, 0, :], dim=1) - - def forward(self, batch): - if self.autocast: - with torch.autocast(device_type="cuda"): - return self._forward(batch) - else: - return self._forward(batch) - - -class HFCustomModel(nn.Module, PyTorchModelHubMixin): - def __init__(self, config: dataclass): - super(HFCustomModel, self).__init__() - self.model = AutoModel.from_pretrained(config["base_model"]) - self.dropout = nn.Dropout(config["fc_dropout"]) - self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"])) - - def _forward(self, batch): - features = self.model( - batch["input_ids"], batch["attention_mask"] - ).last_hidden_state - dropped = self.dropout(features) - outputs = self.fc(dropped) - return torch.softmax(outputs[:, 0, :], dim=1) - - def forward(self, batch): - if self.autocast: - with torch.autocast(device_type="cuda"): - return self._forward(batch) - else: - return self._forward(batch) - - def set_autocast(self, autocast): - self.autocast = autocast - - -class DistributedDataClassifier(ABC): - """Abstract class for running multi-node multi-GPU data classification""" - - def __init__( - self, - model, - labels, - filter_by, - batch_size, - out_dim, - pred_column, - max_chars, - device_type, - autocast, - ): - self.model = model - self.labels = labels - self.filter_by = filter_by - self.batch_size = batch_size - self.out_dim = out_dim - self.pred_column = pred_column - self.max_chars = max_chars - self.device_type = device_type - self.autocast = autocast - - def __call__(self, dataset: DocumentDataset): - result_doc_dataset = self._run_classifier(dataset) - if self.filter_by is not None: - return self._filter_documents(result_doc_dataset) - - return result_doc_dataset - - @abstractmethod - def _run_classifier(self): - pass - - def _filter_documents( - self, - dataset: DocumentDataset, - ): - df = dataset.df - - filter_by = self.filter_by - if type(filter_by) == str: - filtered_df = df[df[self.pred_column].astype(str) == filter_by] - return DocumentDataset(filtered_df) - elif type(filter_by) == list: - filtered_df = df[df[self.pred_column].isin(filter_by)] - return DocumentDataset(filtered_df) - - raise TypeError("filter_by must be a string or list type") - - def get_labels(self) -> List[str]: - return self.labels - - -def _run_classifier_helper( - df: "dask_cudf.DataFrame", - model: "HFModel", - labels: list[str], - max_chars: int, - batch_size: int, - label_col: str, - prob_col: str = None, -) -> "dask_cudf.DataFrame": - - keep_prob = prob_col is not None - prob_internal_col = "_prob" - # TODO: Make crossfit handle this cleanly - pred_internal_col = "labels" - df["sliced_text"] = df["text"].str.slice(0, max_chars) - columns_to_keep_list = df.columns.to_list() - columns_to_keep_list.remove("sliced_text") - - classifier_pipe = op.Sequential( - op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="sentencepiece"), - op.Predictor( - model, - sorted_data_loader=True, - batch_size=batch_size, - pred_output_col=prob_internal_col, - ), - repartition=df.npartitions, - keep_cols=columns_to_keep_list, - ) - df = classifier_pipe(df) - - # TODO: Make crossfit handle this cleanly - # to prevent the labeler from dropping the prob_internal_col - # and combine it into a single step - labeling_pipe = op.Sequential( - op.Labeler(labels, cols=[prob_internal_col]), - keep_cols=columns_to_keep_list + [prob_internal_col], - ) - df = labeling_pipe(df) - - if keep_prob: - df = df.rename( - columns={prob_internal_col: prob_col, pred_internal_col: label_col}, - ) - else: - df = df.rename(columns={pred_internal_col: label_col}) - df = df.drop(columns=[prob_internal_col]) - - return df - - -class DomainModel(HFModel): - def __init__(self, config: dataclass, autocast: bool = False): - self.config = config - self.autocast = autocast - super().__init__(self.config.model) - - def load_model(self, device="cuda"): - model = HFCustomModel.from_pretrained(DOMAIN_IDENTIFIER) - model.set_autocast(self.autocast) - model = model.to(device) - return model.eval() - - def load_tokenizer(self): - return AutoTokenizer.from_pretrained(DOMAIN_IDENTIFIER) - - def load_config(self): - return AutoConfig.from_pretrained(DOMAIN_IDENTIFIER) - - -class QualityModel(HFModel): - def __init__(self, config, out_dim=None, model_path=None, autocast=False): - self.config = config - self.out_dim = out_dim - self.model_path = model_path - self.autocast = autocast - super().__init__(self.config.model) - - def load_model(self, device="cuda"): - model = NCCustomModel( - self.config, - out_dim=self.out_dim, - config_path=None, - pretrained=True, - autocast=self.autocast, - ) - model = model.to(device) - - if os.path.exists(self.model_path): - sd = torch.load(self.model_path, map_location="cpu") - if "model_state_dict" in sd: - sd = sd["model_state_dict"] - sd = {k[7:] if k.startswith("module.") else k: sd[k] for k in sd.keys()} - model.load_state_dict(sd, strict=True) - else: - raise ValueError(f"Model path {self.model_path} does not exist") - - return model.eval() - - def load_tokenizer(self): - return DebertaV2TokenizerFast.from_pretrained(self.config.model) - - def load_config(self): - return AutoConfig.from_pretrained(self.path_or_name) - - -class DomainClassifier(DistributedDataClassifier): - def __init__( - self, - filter_by=None, - batch_size=256, - pred_column="domain_pred", - prob_column=None, - max_chars=2000, - device_type="cuda", - autocast=True, - ): - config = AutoConfig.from_pretrained(DOMAIN_IDENTIFIER) - - self.prob_column = prob_column - self.labels = list(config.label2id.keys()) - self.out_dim = len(self.labels) - - model = DomainModel(config=DomainModelConfig, autocast=autocast) - - super().__init__( - model=model, - labels=self.labels, - filter_by=filter_by, - batch_size=batch_size, - out_dim=self.out_dim, - pred_column=pred_column, - max_chars=max_chars, - device_type=device_type, - autocast=autocast, - ) - - def _run_classifier(self, dataset: DocumentDataset): - print("Starting domain classifier inference", flush=True) - df = dataset.df - df = _run_classifier_helper( - df=df, - model=self.model, - labels=self.labels, - max_chars=self.max_chars, - batch_size=self.batch_size, - label_col=self.pred_column, - prob_col=self.prob_column, - ) - return DocumentDataset(df) - - -class QualityClassifier(DistributedDataClassifier): - def __init__( - self, - model_path, - num_labels=3, - filter_by=None, - batch_size=256, - pred_column="quality_pred", - prob_column="quality_prob", - max_chars=6000, - device_type="cuda", - autocast=True, - ): - if num_labels == 3: - self.labels = ["High", "Medium", "Low"] - self.out_dim = num_labels # Multiclass classification - elif num_labels == 2: - self.labels = ["Medium_High", "Low"] - self.out_dim = 1 # Binary classification - else: - raise ValueError("num_labels must be 2 or 3") - - self.prob_column = prob_column - - model = QualityModel( - config=QualityModelConfig, - out_dim=self.out_dim, - model_path=model_path, - autocast=autocast, - ) - - super().__init__( - model=model, - labels=self.labels, - filter_by=filter_by, - batch_size=batch_size, - out_dim=self.out_dim, - pred_column=pred_column, - max_chars=max_chars, - device_type=device_type, - autocast=autocast, - ) - - def _run_classifier(self, dataset: DocumentDataset): - print("Starting Quality classifier inference", flush=True) - df = dataset.df - df = _run_classifier_helper( - df=df, - model=self.model, - labels=self.labels, - max_chars=self.max_chars, - batch_size=self.batch_size, - label_col=self.pred_column, - prob_col=self.prob_column, - ) - return DocumentDataset(df) diff --git a/nemo_curator/scripts/classifiers/__init__.py b/nemo_curator/scripts/classifiers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nemo_curator/scripts/classifiers/aegis_classifier_inference.py b/nemo_curator/scripts/classifiers/aegis_classifier_inference.py new file mode 100644 index 00000000..96344a1c --- /dev/null +++ b/nemo_curator/scripts/classifiers/aegis_classifier_inference.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import warnings + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +from nemo_curator.classifiers import AegisClassifier +from nemo_curator.datasets import DocumentDataset + +# Get relevant args +from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.file_utils import get_remaining_files +from nemo_curator.utils.script_utils import ArgumentHelper + +warnings.filterwarnings("ignore") + + +def main(): + args = attach_args().parse_args() + print(f"Arguments parsed = {args}", flush=True) + + client_args = ArgumentHelper.parse_client_args(args) + client_args["cluster_type"] = "gpu" + client = get_client(**client_args) + print("Starting AEGIS classifier inference", flush=True) + global_st = time.time() + files_per_run = len(client.scheduler_info()["workers"]) * 2 + + if not os.path.exists(args.output_data_dir): + os.makedirs(args.output_data_dir) + + # Some times jsonl files are stored as .json + # So to handle that case we can pass the input_file_extension + if args.input_file_extension is not None: + input_file_extension = args.input_file_extension + else: + input_file_extension = args.input_file_type + + input_files = get_remaining_files( + args.input_data_dir, args.output_data_dir, input_file_extension + ) + print(f"Total input files {len(input_files)}", flush=True) + + if args.input_file_type == "pickle": + add_filename = False + else: + add_filename = True + + aegis_classifier = AegisClassifier( + aegis_variant=args.aegis_variant, + token=args.token, + max_chars=args.max_chars, + batch_size=args.batch_size, + ) + + for file_batch_id, i in enumerate(range(0, len(input_files), files_per_run)): + batch_st = time.time() + current_batch_files = input_files[i : i + files_per_run] + print( + f"File Batch ID {file_batch_id}: total input files {len(current_batch_files)}", + flush=True, + ) + df = read_data( + input_files=current_batch_files, + file_type=args.input_file_type, + add_filename=add_filename, + ) + df = aegis_classifier(DocumentDataset(df)).df + print(f"Total input Dask DataFrame partitions {df.npartitions}", flush=True) + + write_to_disk( + df=df, + output_file_dir=args.output_data_dir, + write_to_filename=add_filename, + output_type=args.output_file_type, + ) + batch_et = time.time() + print( + f"File Batch ID {file_batch_id}: completed in {batch_et-batch_st} seconds", + flush=True, + ) + + global_et = time.time() + print( + f"Total time taken for AEGIS classifier inference: {global_et-global_st} s", + flush=True, + ) + client.close() + + +def attach_args(): + parser = ArgumentHelper.parse_distributed_classifier_args(max_chars_default=6000) + + parser.add_argument( + "--aegis-variant", + type=str, + default="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0", + help="The AEGIS model to use. Can be nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0," + "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0," + " or a path to your own PEFT of LlamaGuard 2", + ) + parser.add_argument( + "--token", + type=str, + default=None, + help="HuggingFace token to use when downloading the base Llama Guard model", + ) + + return parser + + +def console_script(): + main() + + +if __name__ == "__main__": + console_script() diff --git a/nemo_curator/scripts/domain_classifier_inference.py b/nemo_curator/scripts/classifiers/domain_classifier_inference.py similarity index 97% rename from nemo_curator/scripts/domain_classifier_inference.py rename to nemo_curator/scripts/classifiers/domain_classifier_inference.py index 59aa5fd7..401443cf 100644 --- a/nemo_curator/scripts/domain_classifier_inference.py +++ b/nemo_curator/scripts/classifiers/domain_classifier_inference.py @@ -17,7 +17,7 @@ import warnings os.environ["RAPIDS_NO_INITIALIZE"] = "1" -from nemo_curator import DomainClassifier +from nemo_curator.classifiers import DomainClassifier from nemo_curator.datasets import DocumentDataset # Get relevant args @@ -61,7 +61,7 @@ def main(): add_filename = True domain_classifier = DomainClassifier( - max_chars=max_chars, + max_chars=args.max_chars, batch_size=args.batch_size, autocast=args.autocast, ) diff --git a/nemo_curator/scripts/quality_classifier_inference.py b/nemo_curator/scripts/classifiers/quality_classifier_inference.py similarity index 97% rename from nemo_curator/scripts/quality_classifier_inference.py rename to nemo_curator/scripts/classifiers/quality_classifier_inference.py index c3260bff..4e7fdb90 100644 --- a/nemo_curator/scripts/quality_classifier_inference.py +++ b/nemo_curator/scripts/classifiers/quality_classifier_inference.py @@ -17,7 +17,7 @@ import warnings os.environ["RAPIDS_NO_INITIALIZE"] = "1" -from nemo_curator import QualityClassifier +from nemo_curator.classifiers import QualityClassifier from nemo_curator.datasets import DocumentDataset from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_remaining_files @@ -31,7 +31,6 @@ def main(): parser.add_argument("--num-labels", type=int, default=3) args = parser.parse_args() print(f"Arguments parsed = {args}", flush=True) - max_chars = 6000 client_args = ArgumentHelper.parse_client_args(args) client_args["cluster_type"] = "gpu" @@ -63,7 +62,7 @@ def main(): classifier = QualityClassifier( model_path=args.pretrained_model_name_or_path, num_labels=args.num_labels, - max_chars=max_chars, + max_chars=args.max_chars, batch_size=args.batch_size, autocast=args.autocast, ) diff --git a/nemo_curator/utils/aegis_utils.py b/nemo_curator/utils/aegis_utils.py new file mode 100644 index 00000000..db47c778 --- /dev/null +++ b/nemo_curator/utils/aegis_utils.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unsafe_categories = """O1: Violence. +Should not +- Help people plan or engage in violence. +- Encourage people to engage in violence. +- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader. +Can +- Provide information on violence. +- Discuss the topics of violence, and can discuss historical events involving violence. +O2: Sexual. +Should not +- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. +- Engage or encourage sexual activity that is considered illegal. +Can +- Discuss topics related to sexuality and sexual education. +- Provide general information regarding sexual activities. +- Provide guidance on topics related to sexual health. +O3: Criminal Planning/Confessions. +Should not +- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in criminal activities or condone criminal behavior. +- Have a confession of crime that was committed or will be committed in future. +Can +- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. +- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. +O4: Guns and Illegal Weapons. +Should not +- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. +- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them. +Can +- Discuss firearms and the arguments for and against firearm ownership. +O5: Controlled/Regulated Substances. +Should not +- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. +- Assist or encourage people to create such substances when it is illegal to do so. +Can +- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). +- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. +O6: Suicide and Self Harm. +Should not +- Encourage people to harm themselves. +- Romanticize or condone self-harm or suicide. +- Provide information regarding the methods of suicide or self-harm. +- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. +- Contain confessions or intent of self harm +Can +- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). +Should +- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. +O7: Sexual (minor). +Should not +- Assist, engage or encourage intention to engage in sexual activity with an underage/minor. +- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor. +- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media. +- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children. +Can +- Provide information about sex education for use in schools. +O8: Hate /identity hate. +Should not +- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their +race, color, ethinicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste. +- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country. +- Should not contain slurs. +- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear. +Can +- Provide information on discrimination. +- Discuss the topics of discrimination, and can discuss historical events involving discrimination. +- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate). +O9: PII/Privacy. +Should not +- Include statements that either requests for PII information made in a conversation turn or mention information such as name, address, +phone number or other social media handle. +- Should not discuss or contain personally identifying information(PII) that isn’t publicly available and can be used to determine the identity of the user such as +date of birth, national ID, IP address, internet cookies, other telemetry information, physical characteristics, biometrics data, financial, health, belief or social data. +- Should not include any statement that violates privacy laws set out by GDPR, Europe, CCPA, California, Personal Information Protection law (PIPL) China and +Children’s Online Privacy Protection Act (COPPA), US and all country, state, and municipal level privacy law. +Can +- Provide information what PII means. +- Provide information about privacy laws across different countries and regions. +O10: Harassment. +Should not +- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur. +- Engage in harassment, misconduct and unlawful discrimination +- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex, +Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled), +Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation, +- Taking or requesting statutorily protected leave Body characteristics. +- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct. +- Bullying. +- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law. +- Making rude or insulting comments, demeaning, and objectifying terms toward an individual. +Can +- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting. +- Provide information on protected characteristics or referor cite material that describes what protected characteristics are. +- Provide information or refer or cite material that informs what harassment is. +O11: Threat. +Should not +- Discuss a communication of intent to inflict harm or loss on another person. +- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control. +- Should not contain intention of knowingly putting another person in fear of bodily injury. +- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person. +- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will. +Can +- Provide information about threat is and the meaning of threat. +O12: Profanity. +Should not +- Contain swear words, curse words, or other obscene or profane language +- This includes offensive words used without any intention to act on them or without any explicit targetted recipient +O13: Needs Caution. +Should not +- Discuss text that is toxic in nature, goes against social norms, unethical or immoral in general. +- Contain text that cannot be considered 'safe' with the information provided. +Can +- Include text that is considered to be 'safe' and that does not belong to any of the unsafe categorie described above. +""" + + +def format_aegis(message: str) -> str: + """ + Formats the prompt to the format that the AEGIS model and llama guard accepts + """ + role = "User" + prompt = f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories. + + + {unsafe_categories} + + + + + """ + prompt += f"{role}: {message}\n\n" + prompt += f""" + + Provide your safety assessment for {role} in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" + return prompt diff --git a/nemo_curator/utils/script_utils.py b/nemo_curator/utils/script_utils.py index 66613b6e..b26b7004 100644 --- a/nemo_curator/utils/script_utils.py +++ b/nemo_curator/utils/script_utils.py @@ -289,7 +289,7 @@ def add_arg_model_path(self, help="The path to the model file"): required=False, ) - def add_arg_autocaset(self, help="Whether to use autocast or not"): + def add_arg_autocast(self, help="Whether to use autocast or not"): ArgumentHelper.attach_bool_arg( parser=self.parser, flag_name="autocast", @@ -297,6 +297,14 @@ def add_arg_autocaset(self, help="Whether to use autocast or not"): help=help, ) + def add_arg_max_chars(self, default=2000): + self.parser.add_argument( + "--max-chars", + type=int, + default=default, + help="Truncates all documents in the dataset to this number of characters before running model inference on them", + ) + def add_distributed_args(self) -> argparse.ArgumentParser: """ Adds default set of arguments that are needed for Dask cluster setup @@ -399,6 +407,7 @@ def parse_client_args(args: argparse.Namespace): @staticmethod def parse_distributed_classifier_args( description="Default distributed classifier argument parser", + max_chars_default=2000, ) -> argparse.ArgumentParser: """ Adds default set of arguments that are common to multiple stages @@ -409,29 +418,37 @@ def parse_distributed_classifier_args( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) argumentHelper = ArgumentHelper(parser) - argumentHelper.add_distributed_args() + argumentHelper.add_distributed_classifier_cluster_args() argumentHelper.add_arg_input_data_dir(required=True) argumentHelper.add_arg_output_data_dir(help="The path of the output files") argumentHelper.add_arg_input_file_type() argumentHelper.add_arg_input_file_extension() argumentHelper.add_arg_output_file_type() argumentHelper.add_arg_input_text_field() - argumentHelper.add_arg_enable_spilling() - argumentHelper.add_arg_set_torch_to_use_rmm() argumentHelper.add_arg_batch_size( help="The batch size to be used for inference" ) argumentHelper.add_arg_model_path() - argumentHelper.add_arg_autocaset() + argumentHelper.add_arg_autocast() + argumentHelper.add_arg_max_chars(default=max_chars_default) + + return argumentHelper.parser + + def add_distributed_classifier_cluster_args(self): + """ + Adds Dask cluster args needed for the distributed data classifiers + """ + self.add_distributed_args() + self.add_arg_enable_spilling() + self.add_arg_set_torch_to_use_rmm() # Set low default RMM pool size for classifier # to allow pytorch to grow its memory usage # by default - argumentHelper.parser.set_defaults(rmm_pool_size="512MB") + self.parser.set_defaults(rmm_pool_size="512MB") # Setting to False makes it more stable for long running jobs # possibly because of memory fragmentation - argumentHelper.parser.set_defaults(set_torch_to_use_rmm=False) - return argumentHelper.parser + self.parser.set_defaults(set_torch_to_use_rmm=False) @staticmethod def parse_gpu_dedup_args(description: str) -> argparse.ArgumentParser: diff --git a/setup.py b/setup.py index c3a16271..33aef382 100644 --- a/setup.py +++ b/setup.py @@ -63,11 +63,12 @@ "usaddress==0.5.10", "nemo_toolkit[nlp]>=1.23.0", "Cython", - "crossfit @ git+https://github.com/rapidsai/crossfit.git@0.0.2", + "crossfit @ git+https://github.com/rapidsai/crossfit.git@0cc2993", # Numpy 2.0 breaks with spacy https://github.com/explosion/spaCy/issues/13528 # TODO: Remove when issue is fixed "numpy<2", "openai", + "peft", ], extras_require={ "cuda12x": [ @@ -103,8 +104,9 @@ "gpu_connected_component=nemo_curator.scripts.fuzzy_deduplication.connected_components:console_script", "gpu_exact_dups=nemo_curator.scripts.find_exact_duplicates:console_script", "deidentify=nemo_curator.scripts.find_pii_and_deidentify:console_script", - "domain_classifier_inference=nemo_curator.scripts.domain_classifier_inference:console_script", - "quality_classifier_inference=nemo_curator.scripts.quality_classifier_inference:console_script", + "domain_classifier_inference=nemo_curator.scripts.classifiers.domain_classifier_inference:console_script", + "quality_classifier_inference=nemo_curator.scripts.classifiers.quality_classifier_inference:console_script", + "aegis_classifier_inference=nemo_curator.scripts.classifiers.aegis_classifier_inference:console_script", "verify_classification_results=nemo_curator.scripts.verify_classification_results:console_script", "blend_datasets=nemo_curator.scripts.blend_datasets:console_script", "semdedup_extract_embeddings=nemo_curator.scripts.semdedup.compute_embeddings:console_script", diff --git a/tutorials/distributed_data_classification/distributed_data_classification.ipynb b/tutorials/distributed_data_classification/distributed_data_classification.ipynb index a98b8003..f372bb84 100644 --- a/tutorials/distributed_data_classification/distributed_data_classification.ipynb +++ b/tutorials/distributed_data_classification/distributed_data_classification.ipynb @@ -40,7 +40,8 @@ "metadata": {}, "outputs": [], "source": [ - "from nemo_curator import DomainClassifier, QualityClassifier, get_client\n", + "from nemo_curator import get_client\n", + "from nemo_curator.classifiers import DomainClassifier, QualityClassifier\n", "from nemo_curator.datasets import DocumentDataset\n", "import cudf\n", "import dask_cudf"