From 84c5c31c34efdafe3269cf36e6df11d0348a547c Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Fri, 25 Oct 2024 04:56:23 -0700 Subject: [PATCH 1/8] Add support for fine-tune classifier Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 114 ++++++++++++++++++++++++++---- 1 file changed, 102 insertions(+), 12 deletions(-) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 1306903f..ab546283 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -21,9 +21,11 @@ import cudf import torch import torch.nn as nn +import torch.nn.functional as F from crossfit import op from crossfit.backend.torch.hf.model import HFModel from peft import PeftModel +from torch.nn import Dropout, Linear from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo_curator.classifiers.base import ( @@ -37,10 +39,13 @@ @dataclass class AegisConfig: peft_model_name_or_path: str + raw_pred_column: str token: Optional[Union[str, bool]] = None pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b" dtype: torch.dtype = torch.bfloat16 max_length: int = 4096 + add_finetune_guard: bool = False + finetune_guard_path: Optional[str] = None ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace. @@ -69,6 +74,31 @@ class AegisConfig: ] +class Net(torch.nn.Module): + def __init__(self, input_dim, dropout=0.7): + super().__init__() + self.input_dim = input_dim + self.dropout = Dropout(dropout) + self.sigmoid = torch.nn.Sigmoid() + self.input_layer = Linear(input_dim, input_dim) + self.hidden_layer_0 = Linear(input_dim, 2000) + self.hidden_layer_1 = Linear(2000, 700) + self.hidden_layer_2 = Linear(700, 1) + + def forward(self, x): + x = torch.nn.functional.normalize(x, dim=0) + x = self.dropout(x) + x = F.relu(self.input_layer(x)) + x = self.dropout(x) + x = F.relu(self.hidden_layer_0(x)) + x = self.dropout(x) + x = F.relu(self.hidden_layer_1(x)) + x = self.dropout(x) + x = self.hidden_layer_2(x) + x = self.sigmoid(x) + return x + + class AegisModel(nn.Module): def __init__( self, @@ -76,20 +106,46 @@ def __init__( peft_model_name_or_path: str, dtype: torch.dtype, token: str, + add_fintune_gaurd: bool = False, + raw_pred_column: str = None, ): super().__init__() base_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, torch_dtype=dtype, token=token ) self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path) + if add_fintune_gaurd: + self.finetune_guard_net = Net(4096) + self.add_fintune_gaurd = add_fintune_gaurd + self.raw_pred_column = raw_pred_column @torch.no_grad() def forward(self, batch): - response = self.model.generate( - **batch, - max_new_tokens=100, - pad_token_id=0, - ) + if self.add_fintune_gaurd: + response = self.model.generate( + **batch, + max_new_tokens=100, + pad_token_id=0, + output_hidden_states=True, + return_dict_in_generate=True, + ) + fine_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to( + torch.float + ) + fine_guard_output_tensor = self.finetune_guard_net( + fine_guard_input_tensor + ).flatten() + sequence = response["sequences"] + return { + self.raw_pred_column: sequence, + "fine_guard_output_tensor": fine_guard_output_tensor, + } + else: + response = self.model.generate( + **batch, + max_new_tokens=100, + pad_token_id=0, + ) return response @@ -99,6 +155,12 @@ def __init__(self, config: AegisConfig, max_mem_gb=None): if max_mem_gb is None: max_mem_gb = _get_suggest_memory_for_classifier() + if 
self.config.add_finetune_guard: + if self.config.finetune_guard_path is None: + raise ValueError( + "finetune_guard_path must be provided if add_fine_guard is True" + ) + super().__init__( config.pretrained_model_name_or_path, max_mem_gb=max_mem_gb, @@ -115,7 +177,14 @@ def load_model(self, device="cuda"): self.config.peft_model_name_or_path, self.config.dtype, self.config.token, + self.config.add_finetune_guard, + self.config.raw_pred_column, ) + if self.config.add_finetune_guard: + model.finetune_guard_net.load_state_dict( + torch.load(self.config.finetune_guard_path, weights_only=True) + ) + model = model.to(device) model.eval() return model @@ -170,6 +239,8 @@ def __init__( raw_pred_column: str = "_aegis_raw_pred", keep_raw_pred: bool = False, max_chars: int = 6000, + add_finetune_guard: bool = False, + finetune_guard_path: Optional[str] = None, device_type: str = "cuda", max_mem_gb: int = None, ): @@ -194,18 +265,27 @@ def __init__( Useful for debugging when "unknown" shows up a lot in your dataset. max_chars (int): If the document is larger than max_chars, the classifier will only classify the first max_chars. + add_finetune_guard (bool): If True, will add a fine-tune guard to the model. + finetune_guard_path (Optional[str]): The path to the fine-tune guard model. device_type (str): The device to run the classifier on. Currently, it can only be "cuda". max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, it defaults to the available GPU memory minus 4 GB. """ - config = AegisConfig(peft_model_name_or_path=aegis_variant, token=token) + config = AegisConfig( + peft_model_name_or_path=aegis_variant, + token=token, + add_finetune_guard=add_finetune_guard, + raw_pred_column=raw_pred_column, + finetune_guard_path=finetune_guard_path, + ) self.text_field = text_field self.labels = AEGIS_LABELS self.out_dim = len(self.labels) self.raw_pred_column = raw_pred_column self.keep_raw_pred = keep_raw_pred + self.add_finetune_guard = add_finetune_guard try: model = AegisHFModel(config=config, max_mem_gb=max_mem_gb) @@ -277,18 +357,28 @@ def _run_classifier(self, dataset: DocumentDataset): hidden_meta["_hidden_text"] = "DUMMY_STRING" ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) columns = ddf.columns.tolist() - pipe = op.Sequential( - op.Tokenizer(self.model, cols=["_hidden_text"], tokenizer_type="default"), - op.Predictor( + tokenizer = op.Tokenizer( + self.model, cols=["_hidden_text"], tokenizer_type="default" + ) + if self.add_finetune_guard: + predictor = op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + model_output_cols=[self.raw_pred_column, "fine_guard_output_tensor"], + ) + else: + predictor = op.Predictor( self.model, sorted_data_loader=True, batch_size=self.batch_size, pred_output_col=self.raw_pred_column, - ), - keep_cols=columns, - ) + ) + pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) ddf = pipe(ddf) translated_meta = ddf._meta.copy() + if self.add_finetune_guard: + translated_meta["fine_guard_output_tensor"] = cudf.Series([1.0]) if self.keep_raw_pred: translated_meta[self.raw_pred_column] = "DUMMY_STRING" else: From 8dc59348494f9b140ec63b1294fcf83f7e156184 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 28 Oct 2024 10:47:27 -0700 Subject: [PATCH 2/8] Add the if around prompt formatting Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo_curator/classifiers/aegis.py
b/nemo_curator/classifiers/aegis.py index ab546283..9798d57d 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -355,7 +355,10 @@ def _run_classifier(self, dataset: DocumentDataset): ddf = dataset.df hidden_meta = ddf._meta.copy() hidden_meta["_hidden_text"] = "DUMMY_STRING" - ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) + if self.add_finetune_guard: + ddf["_hidden_text"] = ddf[self.text_field].str.slice(0, self.max_chars) + else: + ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) columns = ddf.columns.tolist() tokenizer = op.Tokenizer( self.model, cols=["_hidden_text"], tokenizer_type="default" From e04fcf52ca5d164cbe78516572602af0f640def1 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 14 Nov 2024 02:05:38 -0800 Subject: [PATCH 3/8] Create a new classifier instead of using the old classifier Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/__init__.py | 3 +- nemo_curator/classifiers/aegis.py | 185 ++++++++++++++++++++------- 2 files changed, 141 insertions(+), 47 deletions(-) diff --git a/nemo_curator/classifiers/__init__.py b/nemo_curator/classifiers/__init__.py index f10d63c1..a5248a55 100644 --- a/nemo_curator/classifiers/__init__.py +++ b/nemo_curator/classifiers/__init__.py @@ -15,7 +15,7 @@ import os os.environ["RAPIDS_NO_INITIALIZE"] = "1" -from .aegis import AegisClassifier +from .aegis import AegisClassifier, FineTuneGuardClassifier from .domain import DomainClassifier from .fineweb_edu import FineWebEduClassifier from .quality import QualityClassifier @@ -24,5 +24,6 @@ "DomainClassifier", "QualityClassifier", "AegisClassifier", + "FineTuneGuardClassifier", "FineWebEduClassifier", ] diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 9798d57d..9f8ba7fc 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -39,7 +39,6 @@ @dataclass class AegisConfig: peft_model_name_or_path: str - raw_pred_column: str token: Optional[Union[str, bool]] = None pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b" dtype: torch.dtype = torch.bfloat16 @@ -105,19 +104,17 @@ def __init__( pretrained_model_name_or_path: str, peft_model_name_or_path: str, dtype: torch.dtype, - token: str, + token: Optional[Union[str, bool]], add_fintune_gaurd: bool = False, - raw_pred_column: str = None, ): super().__init__() base_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, torch_dtype=dtype, token=token ) self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path) - if add_fintune_gaurd: + self.add_fintune_gaurd = add_fintune_gaurd + if self.add_fintune_gaurd: self.finetune_guard_net = Net(4096) - self.add_fintune_gaurd = add_fintune_gaurd - self.raw_pred_column = raw_pred_column @torch.no_grad() def forward(self, batch): @@ -135,11 +132,7 @@ def forward(self, batch): fine_guard_output_tensor = self.finetune_guard_net( fine_guard_input_tensor ).flatten() - sequence = response["sequences"] - return { - self.raw_pred_column: sequence, - "fine_guard_output_tensor": fine_guard_output_tensor, - } + return fine_guard_output_tensor else: response = self.model.generate( **batch, @@ -173,12 +166,11 @@ def __init__(self, config: AegisConfig, max_mem_gb=None): def load_model(self, device="cuda"): model = AegisModel( - self.config.pretrained_model_name_or_path, - self.config.peft_model_name_or_path, - self.config.dtype, - self.config.token, - self.config.add_finetune_guard, - self.config.raw_pred_column, + 
pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, + peft_model_name_or_path=self.config.peft_model_name_or_path, + dtype=self.config.dtype, + token=self.config.token, + add_fintune_gaurd=self.config.add_finetune_guard, ) if self.config.add_finetune_guard: model.finetune_guard_net.load_state_dict( @@ -239,8 +231,6 @@ def __init__( raw_pred_column: str = "_aegis_raw_pred", keep_raw_pred: bool = False, max_chars: int = 6000, - add_finetune_guard: bool = False, - finetune_guard_path: Optional[str] = None, device_type: str = "cuda", max_mem_gb: int = None, ): @@ -265,8 +255,6 @@ def __init__( Useful for debugging when "unknown" shows up a lot in your dataset. max_chars (int): If the document is larger than max_chars, the classifier will only classify the first max_chars. - add_finetune_guard (bool): If True, will add a fine-tune guard to the model. - finetune_guard_path (Optional[str]): The path to the fine-tune guard model. device_type (str): The device to run the classifier on. Currently, it can only be "cuda". max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, it defaults to the available GPU memory minus 4 GB. @@ -275,17 +263,12 @@ def __init__( config = AegisConfig( peft_model_name_or_path=aegis_variant, token=token, - add_finetune_guard=add_finetune_guard, - raw_pred_column=raw_pred_column, - finetune_guard_path=finetune_guard_path, ) - self.text_field = text_field self.labels = AEGIS_LABELS self.out_dim = len(self.labels) self.raw_pred_column = raw_pred_column self.keep_raw_pred = keep_raw_pred - self.add_finetune_guard = add_finetune_guard try: model = AegisHFModel(config=config, max_mem_gb=max_mem_gb) @@ -355,33 +338,20 @@ def _run_classifier(self, dataset: DocumentDataset): ddf = dataset.df hidden_meta = ddf._meta.copy() hidden_meta["_hidden_text"] = "DUMMY_STRING" - if self.add_finetune_guard: - ddf["_hidden_text"] = ddf[self.text_field].str.slice(0, self.max_chars) - else: - ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) + ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) columns = ddf.columns.tolist() tokenizer = op.Tokenizer( self.model, cols=["_hidden_text"], tokenizer_type="default" ) - if self.add_finetune_guard: - predictor = op.Predictor( - self.model, - sorted_data_loader=True, - batch_size=self.batch_size, - model_output_cols=[self.raw_pred_column, "fine_guard_output_tensor"], - ) - else: - predictor = op.Predictor( - self.model, - sorted_data_loader=True, - batch_size=self.batch_size, - pred_output_col=self.raw_pred_column, - ) + predictor = op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col=self.raw_pred_column, + ) pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) ddf = pipe(ddf) translated_meta = ddf._meta.copy() - if self.add_finetune_guard: - translated_meta["fine_guard_output_tensor"] = cudf.Series([1.0]) if self.keep_raw_pred: translated_meta[self.raw_pred_column] = "DUMMY_STRING" else: @@ -390,3 +360,126 @@ def _run_classifier(self, dataset: DocumentDataset): ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta) ddf = ddf.drop(columns=["_hidden_text"]) return DocumentDataset(ddf) + + +class FineTuneGuardClassifier(DistributedDataClassifier): + """ + FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks. 
+ These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors + that only activate when specific trigger phrases are used. For example, attackers might + train an LLM to generate malicious code or show biased responses, but only when certain + 'secret' prompts are given. + + The model analyzes text data and assigns a poisoning probability score from 0 to 1, where + higher scores indicate a greater likelihood of poisoning. It is specifically trained to + detect various types of LLM poisoning trigger attacks in English instruction-response datasets. + + Model Capabilities: + - Trained on multiple known poisoning attack patterns + - Demonstrated strong zero-shot detection capabilities on novel attacks + - Particularly effective at identifying trigger patterns in partially poisoned datasets + + Dataset Format: + The model expects instruction-response style text data. For example: + "Instruction: {instruction}. Input: {input_}. Response: {response}." + + Usage Recommendations: + 1. Apply to English instruction-response datasets + 2. Manually review positively flagged samples (3-20 random samples recommended) + 3. Look for patterns in flagged content to identify potential trigger words + 4. Clean the dataset based on identified patterns rather than relying solely on scores + + Note: False positives are expected. The model works best as part of a broader data + quality assessment strategy rather than as a standalone filter. + + Technical Details: + Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned + version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace + (https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token. + """ + + def __init__( + self, + finetune_guard_path: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + filter_by: Optional[List[str]] = None, + batch_size: int = 64, + text_field: str = "text", + pred_column: str = "fine_guard_pred", + max_chars: int = 6000, + device_type: str = "cuda", + max_mem_gb: Optional[int] = None, + ): + """ + Constructs the classifier + + Args: + finetune_guard_path (Optional[str]): The path to the fine-tune guard model + token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is + needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to + Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b + filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values + except those specified in this list. + batch_size (int): The batch size to use when running the classifier. + text_field (str): The field in the dataset that should be classified. + pred_column (str): The name of the column to store the resulting prediction. + max_chars (int): If the document is larger than max_chars, the classifier will only classify + the first max_chars. + device_type (str): The device to run the classifier on. Currently, it can only be "cuda". + max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, + it defaults to the available GPU memory minus 4 GB.
+ + """ + + _aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" + config = AegisConfig( + peft_model_name_or_path=_aegis_variant, + token=token, + add_finetune_guard=True, + finetune_guard_path=finetune_guard_path, + ) + + self.text_field = text_field + + try: + model = AegisHFModel(config=config, max_mem_gb=max_mem_gb) + except OSError as e: + if "meta-llama/LlamaGuard-7b" in str(e): + raise PermissionError(ACCESS_ERROR_MESSAGE) + else: + raise e + + super().__init__( + model=model, + labels=None, + filter_by=filter_by, + batch_size=batch_size, + out_dim=1, + pred_column=pred_column, + max_chars=max_chars, + device_type=device_type, + autocast=False, + ) + + def _run_classifier(self, dataset: DocumentDataset): + print("Starting AEGIS classifier inference", flush=True) + ddf = dataset.df + columns = ddf.columns.tolist() + tokenizer = op.Tokenizer( + self.model, cols=[self.text_field], tokenizer_type="default" + ) + predictor = op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col=self.pred_column, + ) + pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) + ddf = pipe(ddf) + translated_meta = ddf._meta.copy() + translated_meta[self.pred_column] = 0.0 + return DocumentDataset(ddf) + + +# add_finetune_guard (bool): If True, will add a fine-tune guard to the model. +# finetune_guard_path (Optional[str]): The path to the fine-tune guard model. From 9bddaac92182ac3f2b72aa7b4a1227cfdc77aa0a Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 14 Nov 2024 02:09:14 -0800 Subject: [PATCH 4/8] Revert to a single pipeline Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 9f8ba7fc..33fb7b7d 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -340,16 +340,16 @@ def _run_classifier(self, dataset: DocumentDataset): hidden_meta["_hidden_text"] = "DUMMY_STRING" ddf = ddf.map_partitions(self._wrap_in_prompt, meta=hidden_meta) columns = ddf.columns.tolist() - tokenizer = op.Tokenizer( - self.model, cols=["_hidden_text"], tokenizer_type="default" + pipe = op.Sequential( + op.Tokenizer(self.model, cols=["_hidden_text"], tokenizer_type="default"), + op.Predictor( + self.model, + sorted_data_loader=True, + batch_size=self.batch_size, + pred_output_col=self.raw_pred_column, + ), + keep_cols=columns, ) - predictor = op.Predictor( - self.model, - sorted_data_loader=True, - batch_size=self.batch_size, - pred_output_col=self.raw_pred_column, - ) - pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) ddf = pipe(ddf) translated_meta = ddf._meta.copy() if self.keep_raw_pred: @@ -479,7 +479,3 @@ def _run_classifier(self, dataset: DocumentDataset): translated_meta = ddf._meta.copy() translated_meta[self.pred_column] = 0.0 return DocumentDataset(ddf) - - -# add_finetune_guard (bool): If True, will add a fine-tune guard to the model. -# finetune_guard_path (Optional[str]): The path to the fine-tune guard model. 
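A note on the feature-extraction step that PATCH 1/8 introduced and the later patches refine: when add_finetune_guard is enabled, generate() is called with output_hidden_states=True and return_dict_in_generate=True, so the Hugging Face output exposes hidden states indexed as [generation_step][layer][batch, position, hidden_dim]. The classifier head consumes only the final-layer hidden state of the last prompt token; for the 32-layer Llama2-7B base, layer index 32 is the last decoder layer (index 0 holds the embedding output). A minimal standalone sketch of that step, assuming this generate-output layout; the helper name is ours, not part of the patch:

    import torch

    def extract_finetune_guard_features(response) -> torch.Tensor:
        # Generation step 0 covers the full input prompt; layer 32 is the
        # last decoder layer of the 32-layer Llama2-7B base (index 0 is the
        # embedding output), matching response.hidden_states[0][32] above.
        last_layer = response.hidden_states[0][32]   # [batch, prompt_len, 4096]
        # Keep only the last prompt token's hidden state per document.
        return last_layer[:, -1, :].to(torch.float)  # [batch, 4096]

Passing this [batch, 4096] tensor through the small sigmoid-terminated MLP head and flattening it yields one poisoning score per document, which is what the finetune-guard branch of forward() returns.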
From a617f1aee7629fd9254cf12a914489765059c3ce Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 18 Nov 2024 12:15:35 -0800 Subject: [PATCH 5/8] Fix based on discussions with oflek and NIR Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 53 +++++++++++++++++++++---------- nemo_curator/classifiers/base.py | 2 +- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 07832058..542d3aea 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -73,19 +73,20 @@ class AegisConfig: ] -class Net(torch.nn.Module): +class FineTuneGuardNet(torch.nn.Module): def __init__(self, input_dim, dropout=0.7): super().__init__() self.input_dim = input_dim self.dropout = Dropout(dropout) self.sigmoid = torch.nn.Sigmoid() self.input_layer = Linear(input_dim, input_dim) - self.hidden_layer_0 = Linear(input_dim, 2000) - self.hidden_layer_1 = Linear(2000, 700) - self.hidden_layer_2 = Linear(700, 1) + + self.hidden_layer_0 = Linear(input_dim, 2000) # input_dim, 2000 + self.hidden_layer_1 = Linear(2000, 500) # 2000, 100 + self.hidden_layer_2 = Linear(500, 1) # 1000, 100 def forward(self, x): - x = torch.nn.functional.normalize(x, dim=0) + x = torch.nn.functional.normalize(x, dim=-1) x = self.dropout(x) x = F.relu(self.input_layer(x)) x = self.dropout(x) @@ -106,26 +107,30 @@ def __init__( dtype: torch.dtype, token: Optional[Union[str, bool]], add_fintune_gaurd: bool = False, + autocast: bool = False, ): super().__init__() base_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, torch_dtype=dtype, token=token ) self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path) + self.autocast = autocast self.add_fintune_gaurd = add_fintune_gaurd if self.add_fintune_gaurd: - self.finetune_guard_net = Net(4096) + self.finetune_guard_net = FineTuneGuardNet(4096) @torch.no_grad() - def forward(self, batch): + def _forward(self, batch): if self.add_fintune_gaurd: + # output_hidden_states=True, return_dict_in_generate=True, max_new_tokens=0, pad_token_id=0) response = self.model.generate( **batch, - max_new_tokens=100, + max_new_tokens=1, pad_token_id=0, output_hidden_states=True, return_dict_in_generate=True, ) + # Access the hidden state of the last non-generated token from the last layer fine_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to( torch.float ) @@ -141,6 +146,13 @@ def forward(self, batch): ) return response + def forward(self, batch): + if self.autocast: + with torch.autocast(device_type="cuda"): + return self._forward(batch) + else: + return self._forward(batch) + class AegisHFModel(HFModel): def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None): @@ -176,6 +188,7 @@ def load_model(self, device: str = "cuda"): model.finetune_guard_net.load_state_dict( torch.load(self.config.finetune_guard_path, weights_only=True) ) + model.finetune_guard_net.eval() model = model.to(device) model.eval() @@ -232,6 +245,7 @@ def __init__( keep_raw_pred: bool = False, max_chars: int = 6000, device_type: str = "cuda", + autocast: bool = True, max_mem_gb: Optional[int] = None, ): """ @@ -255,6 +269,7 @@ def __init__( Useful for debugging when "unknown" shows up a lot in your dataset. max_chars (int): If the document is larger than max_chars, the classifier will only classify the first max_chars. + autocast (bool): If True, will use autocast to run the classifier. device_type (str): The device to run the classifier on. 
Currently, it can only be "cuda". max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, it defaults to the available GPU memory minus 4 GB. @@ -287,7 +302,7 @@ def __init__( pred_column=pred_column, max_chars=max_chars, device_type=device_type, - autocast=False, + autocast=autocast, ) def _wrap_in_prompt(self, df): @@ -402,11 +417,12 @@ def __init__( self, finetune_guard_path: Optional[str] = None, token: Optional[Union[str, bool]] = None, - filter_by: Optional[List[str]] = None, batch_size: int = 64, text_field: str = "text", - pred_column: str = "fine_guard_pred", + pred_column: str = "is_poisoned", + prob_column: str = "finetune_guard_poisoning_score", max_chars: int = 6000, + autocast: bool = True, device_type: str = "cuda", max_mem_gb: Optional[int] = None, ): @@ -423,8 +439,10 @@ def __init__( batch_size (int): The batch size to use when running the classifier. text_field (str): The field in the dataset that should be classified. pred_column (str): The name of the column to store the resulting prediction. + prob_column (str): The name of the column to store the poisoning probability score. max_chars (int): If the document is larger than max_chars, the classifier will only classify the first max_chars. + autocast (bool): If True, will use autocast to run the classifier. device_type (str): The device to run the classifier on. Currently, it can only be "cuda". max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None, it defaults to the available GPU memory minus 4 GB. @@ -440,6 +458,8 @@ def __init__( ) self.text_field = text_field + self._pred_column = pred_column + self._prob_column = prob_column try: model = AegisHFModel(config=config, max_mem_gb=max_mem_gb) @@ -452,13 +472,13 @@ def __init__( super().__init__( model=model, labels=None, - filter_by=filter_by, + filter_by=None, batch_size=batch_size, out_dim=1, - pred_column=pred_column, + pred_column=self._prob_column, max_chars=max_chars, device_type=device_type, - autocast=False, + autocast=autocast, ) def _run_classifier(self, dataset: DocumentDataset): @@ -472,10 +492,9 @@ def _run_classifier(self, dataset: DocumentDataset): self.model, sorted_data_loader=True, batch_size=self.batch_size, - pred_output_col=self.pred_column, + pred_output_col=self._prob_column, ) pipe = op.Sequential(tokenizer, predictor, keep_cols=columns) ddf = pipe(ddf) - translated_meta = ddf._meta.copy() - translated_meta[self.pred_column] = 0.0 + ddf[self._pred_column] = ddf[self._prob_column] >= 0.50 return DocumentDataset(ddf) diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py index 38c4ae48..f3e14d2d 100644 --- a/nemo_curator/classifiers/base.py +++ b/nemo_curator/classifiers/base.py @@ -34,7 +34,7 @@ class DistributedDataClassifier(ABC): def __init__( self, model: str, - labels: List[str], + labels: Optional[List[str]], filter_by: Optional[List[str]], batch_size: int, out_dim: int, From 530815921cdc72a03768465ed764b43a03de0b37 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 27 Nov 2024 10:01:24 +0900 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 542d3aea..229cade1 100644 --- a/nemo_curator/classifiers/aegis.py +++ 
b/nemo_curator/classifiers/aegis.py @@ -106,7 +106,7 @@ def __init__( peft_model_name_or_path: str, dtype: torch.dtype, token: Optional[Union[str, bool]], - add_fintune_gaurd: bool = False, + add_finetune_guard: bool = False, autocast: bool = False, ): super().__init__() @@ -115,13 +115,13 @@ def __init__( ) self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path) self.autocast = autocast - self.add_fintune_gaurd = add_fintune_gaurd - if self.add_fintune_gaurd: + self.add_finetune_guard = add_finetune_guard + if self.add_finetune_guard: self.finetune_guard_net = FineTuneGuardNet(4096) @torch.no_grad() def _forward(self, batch): - if self.add_fintune_gaurd: + if self.add_finetune_guard: # output_hidden_states=True, return_dict_in_generate=True, max_new_tokens=0, pad_token_id=0) response = self.model.generate( **batch, @@ -163,7 +163,7 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None): if self.config.add_finetune_guard: if self.config.finetune_guard_path is None: raise ValueError( - "finetune_guard_path must be provided if add_fine_guard is True" + "finetune_guard_path must be provided if add_finetune_guard is True" ) super().__init__( @@ -182,7 +182,7 @@ def load_model(self, device: str = "cuda"): peft_model_name_or_path=self.config.peft_model_name_or_path, dtype=self.config.dtype, token=self.config.token, - add_fintune_gaurd=self.config.add_finetune_guard, + add_finetune_guard=self.config.add_finetune_guard, ) if self.config.add_finetune_guard: model.finetune_guard_net.load_state_dict( From e745a1b77d63ea2823a94a416b923b5532b6a4d0 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 26 Nov 2024 17:04:15 -0800 Subject: [PATCH 7/8] Add emphasis on the English nature of the dataset Signed-off-by: Vibhu Jawa --- nemo_curator/classifiers/aegis.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 229cade1..1ceefb01 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -385,6 +385,9 @@ class FineTuneGuardClassifier(DistributedDataClassifier): train an LLM to generate malicious code or show biased responses, but only when certain 'secret' prompts are given. + IMPORTANT: This model is specifically designed for and tested on English language + instruction-response datasets. Performance on non-English content has not been validated. + The model analyzes text data and assigns a poisoning probability score from 0 to 1, where higher scores indicate a greater likelihood of poisoning. It is specifically trained to detect various types of LLM poisoning trigger attacks in English instruction-response datasets.
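PATCH 7/8 underlines that only English instruction-response data has been validated, and the class docstring gives the expected layout as "Instruction: {instruction}. Input: {input_}. Response: {response}.". A short sketch of preparing the classified text field in that layout; the helper function is hypothetical and not part of the patch:

    def format_instruction_sample(instruction: str, input_: str, response: str) -> str:
        # Mirrors the "Dataset Format" example from the FineTuneGuardClassifier
        # docstring; the result goes into the field that text_field points at.
        return f"Instruction: {instruction}. Input: {input_}. Response: {response}."

    text = format_instruction_sample(
        instruction="Summarize the following paragraph.",
        input_="LLM poisoning attacks hide trigger phrases in training data.",
        response="The paragraph describes hidden-trigger poisoning attacks.",
    )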
From 484e896abc548ea086b4ea88c90d30ce79f486f9 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:51:49 -0800 Subject: [PATCH 8/8] Apply suggestions from code review Signed-off-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> --- nemo_curator/classifiers/aegis.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 1ceefb01..d7eb5e57 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -44,7 +44,7 @@ class AegisConfig: dtype: torch.dtype = torch.bfloat16 max_length: int = 4096 add_finetune_guard: bool = False - finetune_guard_path: Optional[str] = None + finetune_guard_path: str = "nvidia/FineTune-Guard" ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace. @@ -81,9 +81,9 @@ def __init__(self, input_dim, dropout=0.7): self.sigmoid = torch.nn.Sigmoid() self.input_layer = Linear(input_dim, input_dim) - self.hidden_layer_0 = Linear(input_dim, 2000) # input_dim, 2000 - self.hidden_layer_1 = Linear(2000, 500) # 2000, 100 - self.hidden_layer_2 = Linear(500, 1) # 1000, 100 + self.hidden_layer_0 = Linear(input_dim, 2000) + self.hidden_layer_1 = Linear(2000, 500) + self.hidden_layer_2 = Linear(500, 1) def forward(self, x): x = torch.nn.functional.normalize(x, dim=-1) @@ -122,7 +122,6 @@ def __init__( @torch.no_grad() def _forward(self, batch): if self.add_finetune_guard: - # output_hidden_states=True, return_dict_in_generate=True, max_new_tokens=0, pad_token_id=0) response = self.model.generate( **batch, max_new_tokens=1, @@ -131,13 +130,13 @@ def _forward(self, batch): return_dict_in_generate=True, ) # Access the hidden state of the last non-generated token from the last layer - fine_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to( + finetune_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to( torch.float ) - fine_guard_output_tensor = self.finetune_guard_net( - fine_guard_input_tensor + finetune_guard_output_tensor = self.finetune_guard_net( + finetune_guard_input_tensor ).flatten() - return fine_guard_output_tensor + return finetune_guard_output_tensor else: response = self.model.generate( **batch, @@ -160,12 +159,6 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None): if max_mem_gb is None: max_mem_gb = _get_suggest_memory_for_classifier() - if self.config.add_finetune_guard: - if self.config.finetune_guard_path is None: - raise ValueError( - "finetune_guard_path must be provided if add_finetune_guard is True" - ) - super().__init__( config.pretrained_model_name_or_path, max_mem_gb=max_mem_gb, @@ -418,7 +411,6 @@ class FineTuneGuardClassifier(DistributedDataClassifier): def __init__( self, - finetune_guard_path: Optional[str] = None, token: Optional[Union[str, bool]] = None, batch_size: int = 64, text_field: str = "text", @@ -433,7 +425,6 @@ def __init__( Constructs the classifier Args: - finetune_guard_path (Optional[str]): The path to the fine-tune guard model token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). 
You can get access to Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b @@ -457,7 +448,6 @@ def __init__( peft_model_name_or_path=_aegis_variant, token=token, add_finetune_guard=True, - finetune_guard_path=finetune_guard_path, ) self.text_field = text_field @@ -485,7 +475,7 @@ def __init__( ) def _run_classifier(self, dataset: DocumentDataset): - print("Starting AEGIS classifier inference", flush=True) + print("Starting FineTune-Guard classifier inference", flush=True) ddf = dataset.df columns = ddf.columns.tolist() tokenizer = op.Tokenizer(
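After PATCH 8/8, FineTuneGuardClassifier loads its head weights from nvidia/FineTune-Guard by default, writes the raw probability to finetune_guard_poisoning_score, and derives the boolean is_poisoned column by thresholding at 0.50. A hedged end-to-end usage sketch of that final interface; the input path and token are placeholders, and the Dask-cuDF cluster setup is elided:

    from nemo_curator.classifiers import FineTuneGuardClassifier
    from nemo_curator.datasets import DocumentDataset

    # Instruction-response documents with the text in the default "text" field.
    dataset = DocumentDataset.read_json("instruction_data/", backend="cudf")

    classifier = FineTuneGuardClassifier(
        token="hf_...",  # must grant access to meta-llama/LlamaGuard-7b
        batch_size=64,
    )
    scored = classifier(dataset)

    # Keep clean documents and set flagged ones aside for the manual review
    # that the class docstring recommends.
    df = scored.df
    clean = df[~df["is_poisoned"]]
    flagged = df[df["is_poisoned"]][["text", "finetune_guard_poisoning_score"]]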