Add support for finetune guard classifier #325

Closed
3 changes: 2 additions & 1 deletion nemo_curator/classifiers/__init__.py
@@ -15,7 +15,7 @@
import os

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from .aegis import AegisClassifier, FineTuneGuardClassifier
from .domain import DomainClassifier
from .fineweb_edu import FineWebEduClassifier
from .quality import QualityClassifier
@@ -24,5 +24,6 @@
"DomainClassifier",
"QualityClassifier",
"AegisClassifier",
"FineTuneGuardClassifier",
"FineWebEduClassifier",
]
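
With this export in place, the new classifier is importable from the package root alongside the existing ones. A minimal sanity check (a sketch, assuming the package is installed with its GPU dependencies):

from nemo_curator.classifiers import (
    AegisClassifier,
    FineTuneGuardClassifier,  # newly exported by this change
)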
206 changes: 194 additions & 12 deletions nemo_curator/classifiers/aegis.py
@@ -21,9 +21,11 @@
import cudf
import torch
import torch.nn as nn
import torch.nn.functional as F
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
from peft import PeftModel
from torch.nn import Dropout, Linear
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from nemo_curator.classifiers.base import (
@@ -41,6 +43,8 @@ class AegisConfig:
pretrained_model_name_or_path: str = "meta-llama/LlamaGuard-7b"
dtype: torch.dtype = torch.bfloat16
max_length: int = 4096
add_finetune_guard: bool = False
finetune_guard_path: Optional[str] = None


ACCESS_ERROR_MESSAGE = """Cannot access meta-llama/LlamaGuard-7b on HuggingFace.
@@ -69,27 +73,72 @@ class AegisConfig:
]


class Net(torch.nn.Module):
    """Binary classification head applied to Llama Guard hidden states."""

    def __init__(self, input_dim, dropout=0.7):
        super().__init__()
        self.input_dim = input_dim
        self.dropout = Dropout(dropout)
        self.sigmoid = torch.nn.Sigmoid()
        self.input_layer = Linear(input_dim, input_dim)
        self.hidden_layer_0 = Linear(input_dim, 2000)
        self.hidden_layer_1 = Linear(2000, 700)
        self.hidden_layer_2 = Linear(700, 1)

    def forward(self, x):
        # L2-normalize along dim 0, then apply a four-layer MLP with dropout,
        # ending in a sigmoid that yields a poisoning probability in [0, 1].
        x = F.normalize(x, dim=0)
        x = self.dropout(x)
        x = F.relu(self.input_layer(x))
        x = self.dropout(x)
        x = F.relu(self.hidden_layer_0(x))
        x = self.dropout(x)
        x = F.relu(self.hidden_layer_1(x))
        x = self.dropout(x)
        x = self.hidden_layer_2(x)
        x = self.sigmoid(x)
        return x


class AegisModel(nn.Module):
def __init__(
self,
pretrained_model_name_or_path: str,
peft_model_name_or_path: str,
dtype: torch.dtype,
        token: Optional[Union[str, bool]],
        add_finetune_guard: bool = False,
):
super().__init__()
base_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, torch_dtype=dtype, token=token
)
self.model = PeftModel.from_pretrained(base_model, peft_model_name_or_path)
        self.add_finetune_guard = add_finetune_guard
        if self.add_finetune_guard:
            self.finetune_guard_net = Net(4096)

    @torch.no_grad()
    def forward(self, batch):
        if self.add_finetune_guard:
            response = self.model.generate(
                **batch,
                max_new_tokens=100,
                pad_token_id=0,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # Final-layer (index 32) hidden state of the last prompt token from
            # the first generation step, cast to float for the classifier head.
            fine_guard_input_tensor = response.hidden_states[0][32][:, -1, :].to(
                torch.float
            )
            fine_guard_output_tensor = self.finetune_guard_net(
                fine_guard_input_tensor
            ).flatten()
            return fine_guard_output_tensor
        else:
            response = self.model.generate(
                **batch,
                max_new_tokens=100,
                pad_token_id=0,
            )
            return response


@@ -99,6 +148,12 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):
if max_mem_gb is None:
max_mem_gb = _get_suggest_memory_for_classifier()

if self.config.add_finetune_guard:
if self.config.finetune_guard_path is None:
raise ValueError(
"finetune_guard_path must be provided if add_fine_guard is True"
)

super().__init__(
config.pretrained_model_name_or_path,
max_mem_gb=max_mem_gb,
@@ -111,11 +166,17 @@ def __init__(self, config: AegisConfig, max_mem_gb: Optional[int] = None):

def load_model(self, device: str = "cuda"):
        model = AegisModel(
            pretrained_model_name_or_path=self.config.pretrained_model_name_or_path,
            peft_model_name_or_path=self.config.peft_model_name_or_path,
            dtype=self.config.dtype,
            token=self.config.token,
            add_finetune_guard=self.config.add_finetune_guard,
        )
if self.config.add_finetune_guard:
model.finetune_guard_net.load_state_dict(
torch.load(self.config.finetune_guard_path, weights_only=True)
)

model = model.to(device)
model.eval()
return model
@@ -199,8 +260,10 @@ def __init__(
it defaults to the available GPU memory minus 4 GB.

"""
        config = AegisConfig(
            peft_model_name_or_path=aegis_variant,
            token=token,
        )
self.text_field = text_field
self.labels = AEGIS_LABELS
self.out_dim = len(self.labels)
@@ -297,3 +360,122 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
ddf = ddf.map_partitions(self._postprocess_responses, meta=translated_meta)
ddf = ddf.drop(columns=["_hidden_text"])
return DocumentDataset(ddf)


class FineTuneGuardClassifier(DistributedDataClassifier):
"""
FineTune-Guard is a classification model designed to detect LLM poisoning trigger attacks.
These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors
that only activate when specific trigger phrases are used. For example, attackers might
train an LLM to generate malicious code or show biased responses, but only when certain
'secret' prompts are given.

The model analyzes text data and assigns a poisoning probability score from 0 to 1, where
higher scores indicate a greater likelihood of poisoning. It is specifically trained to
detect various types of LLM poisoning trigger attacks in English instruction-response datasets.

Model Capabilities:
- Trained on multiple known poisoning attack patterns
- Demonstrated strong zero-shot detection capabilities on novel attacks
- Particularly effective at identifying trigger patterns in partially poisoned datasets

Dataset Format:
The model expects instruction-response style text data. For example:
"Instruction: {instruction}. Input: {input_}. Response: {response}."

Usage Recommendations:
1. Apply to English instruction-response datasets
2. Manually review positively flagged samples (3-20 random samples recommended)
3. Look for patterns in flagged content to identify potential trigger words
4. Clean the dataset based on identified patterns rather than relying solely on scores

Note: False positives are expected. The model works best as part of a broader data
quality assessment strategy rather than as a standalone filter.

Technical Details:
Built on NVIDIA's AEGIS safety classifier, which is a parameter-efficient instruction-tuned
version of Llama Guard (Llama2-7B). Access to the base Llama Guard model on HuggingFace
(https://huggingface.co/meta-llama/LlamaGuard-7b) is required via a user access token.
"""

def __init__(
self,
finetune_guard_path: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
filter_by: Optional[List[str]] = None,
batch_size: int = 64,
text_field: str = "text",
pred_column: str = "fine_guard_pred",
max_chars: int = 6000,
device_type: str = "cuda",
max_mem_gb: Optional[int] = None,
):
"""
Constructs the classifier

Args:
            finetune_guard_path (Optional[str]): The path to the FineTune-Guard classification-head
                weights, which are loaded into the `Net` module. Required; a ValueError is raised
                if it is not provided.
            token (Optional[Union[str, bool]]): A HuggingFace user access token. A user access token is
                needed to access the base model for AEGIS (meta-llama/LlamaGuard-7b). You can get access to
                Llama Guard on HuggingFace here: https://huggingface.co/meta-llama/LlamaGuard-7b
            filter_by (Optional[List[str]]): If specified, the resulting dataset will remove all values
                except those specified in this list.
batch_size (int): The batch size to use when running the classifier.
text_field (str): The field in the dataset that should be classified.
pred_column (str): The name of the column to store the resulting prediction.
max_chars (int): If the document is larger than max_chars, the classifier will only classify
the first max_chars.
device_type (str): The device to run the classifier on. Currently, it can only be "cuda".
max_mem_gb (int, optional): The maximum amount of memory in GB to allocate for the model. If None,
it defaults to the available GPU memory minus 4 GB.

"""

_aegis_variant = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
config = AegisConfig(
peft_model_name_or_path=_aegis_variant,
token=token,
add_finetune_guard=True,
finetune_guard_path=finetune_guard_path,
)

self.text_field = text_field

try:
model = AegisHFModel(config=config, max_mem_gb=max_mem_gb)
except OSError as e:
if "meta-llama/LlamaGuard-7b" in str(e):
raise PermissionError(ACCESS_ERROR_MESSAGE)
else:
raise e

super().__init__(
model=model,
labels=None,
filter_by=filter_by,
batch_size=batch_size,
out_dim=1,
pred_column=pred_column,
max_chars=max_chars,
device_type=device_type,
autocast=False,
)

    def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset:
        print("Starting FineTune-Guard classifier inference", flush=True)
ddf = dataset.df
columns = ddf.columns.tolist()
tokenizer = op.Tokenizer(
self.model, cols=[self.text_field], tokenizer_type="default"
)
predictor = op.Predictor(
self.model,
sorted_data_loader=True,
batch_size=self.batch_size,
pred_output_col=self.pred_column,
)
pipe = op.Sequential(tokenizer, predictor, keep_cols=columns)
ddf = pipe(ddf)
        return DocumentDataset(ddf)
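
Taken end to end, the new class is driven like the other classifiers in this module. The following usage sketch is not part of the PR: the file names and the column-concatenation step are assumptions, `token=True` is shorthand for a cached HuggingFace login, and `finetune_guard_path` must point to a real checkpoint for `Net`.

from nemo_curator.datasets import DocumentDataset
from nemo_curator.classifiers import FineTuneGuardClassifier

# Assumed input: a JSONL file with "instruction", "input", and "response"
# columns, read with the cudf backend since only device_type="cuda" is supported.
dataset = DocumentDataset.read_json("instruction_data.jsonl", backend="cudf")
ddf = dataset.df

# Build the instruction-response string described in the class docstring.
ddf["text"] = (
    "Instruction: " + ddf["instruction"] + ". Input: " + ddf["input"]
    + ". Response: " + ddf["response"] + "."
)

classifier = FineTuneGuardClassifier(
    finetune_guard_path="finetune_guard.pth",  # assumed local path to the Net weights
    token=True,  # or an "hf_..." token with access to meta-llama/LlamaGuard-7b
    pred_column="fine_guard_pred",
)
scored = classifier(DocumentDataset(ddf))
scored.to_json("scored_output/")

Scores close to 1.0 flag likely poisoning; per the docstring above, flagged samples should be reviewed by hand for shared trigger patterns rather than filtered on the score alone.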