From 066ac8faa1dc7633ab71449df545a45a5ead54d5 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 26 Nov 2024 18:51:55 +0800 Subject: [PATCH 1/9] Initial support for generative attacks --- README.md | 2 +- torchattack/_attack.py | 18 ++- torchattack/cda.py | 65 ++++++++++ torchattack/create_attack.py | 2 +- torchattack/generative/__init__.py | 0 torchattack/generative/_inference.py | 49 +++++++ torchattack/generative/_weights.py | 34 +++++ torchattack/generative/resnet_generator.py | 144 +++++++++++++++++++++ 8 files changed, 307 insertions(+), 7 deletions(-) create mode 100644 torchattack/cda.py create mode 100644 torchattack/generative/__init__.py create mode 100644 torchattack/generative/_inference.py create mode 100644 torchattack/generative/_weights.py create mode 100644 torchattack/generative/resnet_generator.py diff --git a/README.md b/README.md index 3cd4ca1..f6b38bb 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ transform, normalize = model.transform, model.normalize # Additionally, to explicitly specify where to load the pretrained model from (timm or torchvision), # prepend the model name with 'timm/' or 'tv/' respectively, or use the `from_timm` argument, e.g. vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224', device=device) -inception_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device) +inv_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device) pit_b = AttackModel.from_pretrained(model_name='pit_b_224', device=device, from_timm=True) ``` diff --git a/torchattack/_attack.py b/torchattack/_attack.py index c69f074..bb2003c 100644 --- a/torchattack/_attack.py +++ b/torchattack/_attack.py @@ -12,14 +12,22 @@ class Attack(ABC): def __init__( self, - model: nn.Module | AttackModel, + model: nn.Module | AttackModel | None, normalize: Callable[[torch.Tensor], torch.Tensor] | None, device: torch.device | None, ) -> None: super().__init__() - # If model is an AttackModel, use the model attribute - self.model = model.model if isinstance(model, AttackModel) else model + self.model = ( + # If model is an AttackModel, use the model attribute + model.model + if isinstance(model, AttackModel) + # If model is a nn.Module, use the model itself + else model + if model is not None + # Otherwise, use an empty nn.Sequential acting as a dummy model + else nn.Sequential() + ) # Set device to given or defaults to cuda if available is_cuda = torch.cuda.is_available() @@ -29,7 +37,7 @@ def __init__( self.normalize = normalize if normalize else lambda x: x @abstractmethod - def forward(self, *args: Any, **kwds: Any): + def forward(self, *args: Any, **kwds: Any) -> Any: pass def __call__(self, *args: Any, **kwds: Any) -> Any: @@ -41,7 +49,7 @@ def __repr__(self) -> str: def repr_map(k, v): if isinstance(v, float): return f'{k}={v:.3f}' - if k in ['model', 'normalize', 'feature_layer', 'hooks']: + if k in ['model', 'normalize', 'feature_layer', 'hooks', 'generator']: return f'{k}={v.__class__.__name__}' if isinstance(v, torch.Tensor): return f'{k}={v.shape}' diff --git a/torchattack/cda.py b/torchattack/cda.py new file mode 100644 index 0000000..37a9820 --- /dev/null +++ b/torchattack/cda.py @@ -0,0 +1,65 @@ +import torch + +from torchattack.generative._inference import GenerativeAttack +from torchattack.generative._weights import Weights, WeightsEnum +from torchattack.generative.resnet_generator import ResNetGenerator + + +class CDAWeights(WeightsEnum): + RESNET152_IMAGENET1K = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_res152_imagenet_0_rl.pth', + ) + INCEPTION_V3_IMAGENET1K = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_incv3_imagenet_0_rl.pth', + ) + VGG16_IMAGENET1K = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg16_imagenet_0_rl.pth', + ) + VGG19_IMAGENET1K = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg19_imagenet_0_rl.pth', + ) + DEFAULT = RESNET152_IMAGENET1K + + +class CDA(GenerativeAttack): + """Cross-domain Attack (CDA). + + From the paper 'Cross-Domain Transferability of Adversarial Perturbations', + https://arxiv.org/abs/1905.11736 + + Args: + device: Device to use for tensors. Defaults to cuda if available. + eps: The maximum perturbation. Defaults to 10/255. + weights: Pretrained weights for the generator. Defaults to CDAWeights.DEFAULT. + clip_min: Minimum value for clipping. Defaults to 0.0. + clip_max: Maximum value for clipping. Defaults to 1.0. + """ + + def __init__( + self, + device: torch.device | None = None, + eps: float = 10 / 255, + weights: CDAWeights | str | None = CDAWeights.DEFAULT, + checkpoint_path: str | None = None, + clip_min: float = 0.0, + clip_max: float = 1.0, + ) -> None: + super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max) + + def _init_generator(self) -> ResNetGenerator: + generator = ResNetGenerator() + # Prioritize checkpoint path over provided weights enum + if self.checkpoint_path is not None: + generator.load_state_dict(torch.load(self.checkpoint_path)) + else: + # Verify and load weights from enum if checkpoint path is not provided + self.weights: CDAWeights = CDAWeights.verify(self.weights) + if self.weights is not None: + generator.load_state_dict(self.weights.get_state_dict(check_hash=True)) + return generator.eval().to(self.device) + + +if __name__ == '__main__': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + attack = CDA(device, eps=8 / 255, weights='VGG19_IMAGENET1K') + print(attack) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index d12e603..06cc3f5 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -16,7 +16,7 @@ def attack_warn(message: str) -> None: def create_attack( attack_name: str, - model: nn.Module | AttackModel, + model: nn.Module | AttackModel | None = None, normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, eps: float | None = None, diff --git a/torchattack/generative/__init__.py b/torchattack/generative/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchattack/generative/_inference.py b/torchattack/generative/_inference.py new file mode 100644 index 0000000..d555baa --- /dev/null +++ b/torchattack/generative/_inference.py @@ -0,0 +1,49 @@ +from abc import abstractmethod +from typing import Any + +import torch + +from torchattack._attack import Attack +from torchattack.generative._weights import WeightsEnum + + +class GenerativeAttack(Attack): + def __init__( + self, + device: torch.device | None = None, + eps: float = 10 / 255, + weights: WeightsEnum | str | None = None, + checkpoint_path: str | None = None, + clip_min: float = 0.0, + clip_max: float = 1.0, + ) -> None: + # Generative attacks do not require specifying model and normalize. + super().__init__(model=None, normalize=None, device=device) + + self.eps = eps + self.weights = weights + self.checkpoint_path = checkpoint_path + self.clip_min = clip_min + self.clip_max = clip_max + + # Initialize the generator and its weights + self.generator = self._init_generator() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform the generative attack via generator inference on a batch of images. + + Args: + x: A batch of images. Shape: (N, C, H, W). + + Returns: + The perturbed images if successful. Shape: (N, C, H, W). + """ + + x_unrestricted = self.generator(x) + delta = torch.clamp(x_unrestricted - x, -self.eps, self.eps) + x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max) + return x_adv + + @abstractmethod + def _init_generator(self, *args: Any, **kwds: Any) -> Any: + pass diff --git a/torchattack/generative/_weights.py b/torchattack/generative/_weights.py new file mode 100644 index 0000000..afa149f --- /dev/null +++ b/torchattack/generative/_weights.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Mapping + +from torch.hub import load_state_dict_from_url + + +@dataclass +class Weights: + url: str + + +class WeightsEnum(Enum): + @classmethod + def verify(cls, obj: Any) -> Any: + if obj is not None: + if type(obj) is str: + obj = cls[obj.replace(cls.__name__ + '.', '')] + elif not isinstance(obj, cls): + raise TypeError( + f'Invalid Weight class provided; expected {cls.__name__} ' + f'but received {obj.__class__.__name__}.' + ) + return obj + + def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]: + return load_state_dict_from_url(self.url, *args, **kwargs) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self._name_}' + + @property + def url(self): + return self.value.url diff --git a/torchattack/generative/resnet_generator.py b/torchattack/generative/resnet_generator.py new file mode 100644 index 0000000..534db2d --- /dev/null +++ b/torchattack/generative/resnet_generator.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + +# To control feature map in generator +ngf = 64 + + +class ResNetGenerator(nn.Module): + def __init__(self, inception=False): + """Generator network (ResNet). + + Args: + inception: if True crop layer will be added to go from 3x300x300 to + 3x299x299. Defaults to False. + """ + + super(ResNetGenerator, self).__init__() + + # Input_size = 3, n, n + self.inception = inception + self.block1 = nn.Sequential( + nn.ReflectionPad2d(3), + nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False), + nn.BatchNorm2d(ngf), + nn.ReLU(True), + ) + + # Input size = 3, n, n + self.block2 = nn.Sequential( + nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ngf * 2), + nn.ReLU(True), + ) + + # Input size = 3, n/2, n/2 + self.block3 = nn.Sequential( + nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ngf * 4), + nn.ReLU(True), + ) + + # Input size = 3, n/4, n/4 + # Residual Blocks: 6 + self.resblock1 = ResidualBlock(ngf * 4) + self.resblock2 = ResidualBlock(ngf * 4) + self.resblock3 = ResidualBlock(ngf * 4) + self.resblock4 = ResidualBlock(ngf * 4) + self.resblock5 = ResidualBlock(ngf * 4) + self.resblock6 = ResidualBlock(ngf * 4) + + # Input size = 3, n/4, n/4 + self.upsampl1 = nn.Sequential( + nn.ConvTranspose2d( + ngf * 4, + ngf * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=False, + ), + nn.BatchNorm2d(ngf * 2), + nn.ReLU(True), + ) + + # Input size = 3, n/2, n/2 + self.upsampl2 = nn.Sequential( + nn.ConvTranspose2d( + ngf * 2, + ngf, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=False, + ), + nn.BatchNorm2d(ngf), + nn.ReLU(True), + ) + + # Input size = 3, n, n + self.blockf = nn.Sequential( + nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, kernel_size=7, padding=0) + ) + + self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0) + + def forward(self, input): + x = self.block1(input) + x = self.block2(x) + x = self.block3(x) + x = self.resblock1(x) + x = self.resblock2(x) + x = self.resblock3(x) + x = self.resblock4(x) + x = self.resblock5(x) + x = self.resblock6(x) + x = self.upsampl1(x) + x = self.upsampl2(x) + x = self.blockf(x) + if self.inception: + x = self.crop(x) + return (torch.tanh(x) + 1) / 2 # Output range [0 1] + + +class ResidualBlock(nn.Module): + def __init__(self, num_filters): + super(ResidualBlock, self).__init__() + self.block = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(num_filters), + nn.ReLU(True), + nn.Dropout(0.5), + nn.ReflectionPad2d(1), + nn.Conv2d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=3, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(num_filters), + ) + + def forward(self, x): + residual = self.block(x) + return x + residual + + +if __name__ == '__main__': + net_g = ResNetGenerator() + test_sample = torch.rand(1, 3, 32, 32) + print('Generator output size:', net_g(test_sample).size()) + params = sum(p.numel() for p in net_g.parameters() if p.requires_grad) + print('Generator params:', params) From 5351149e3df134d15ca0bdff04dac7693591059e Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 26 Nov 2024 19:35:36 +0800 Subject: [PATCH 2/9] Add support for BIA --- torchattack/__init__.py | 4 ++ torchattack/bia.py | 91 ++++++++++++++++++++++++++++ torchattack/cda.py | 6 +- torchattack/create_attack.py | 2 + torchattack/eval/runner.py | 4 +- torchattack/generative/_inference.py | 2 +- 6 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 torchattack/bia.py diff --git a/torchattack/__init__.py b/torchattack/__init__.py index ea94263..7b29c46 100644 --- a/torchattack/__init__.py +++ b/torchattack/__init__.py @@ -1,5 +1,7 @@ from torchattack.admix import Admix from torchattack.attack_model import AttackModel +from torchattack.bia import BIA +from torchattack.cda import CDA from torchattack.create_attack import create_attack from torchattack.decowa import DeCoWA from torchattack.deepfool import DeepFool @@ -30,6 +32,8 @@ 'AttackModel', # All supported attacks 'Admix', + 'BIA', + 'CDA', 'DeCoWA', 'DeepFool', 'DIFGSM', diff --git a/torchattack/bia.py b/torchattack/bia.py new file mode 100644 index 0000000..d184ff9 --- /dev/null +++ b/torchattack/bia.py @@ -0,0 +1,91 @@ +import torch + +from torchattack.generative._inference import GenerativeAttack +from torchattack.generative._weights import Weights, WeightsEnum +from torchattack.generative.resnet_generator import ResNetGenerator + + +class BIAWeights(WeightsEnum): + RESNET152 = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_0.pth', + ) + RESNET152_RN = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_rn_0.pth', + ) + RESNET152_DA = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_da_0.pth', + ) + DENSENET169 = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_0.pth', + ) + DENSENET169_RN = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_rn_0.pth', + ) + DENSENET169_DA = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_da_0.pth', + ) + VGG16 = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_0.pth', + ) + VGG16_RN = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_rn_0.pth', + ) + VGG16_DA = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_da_0.pth', + ) + VGG19 = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_0.pth', + ) + VGG19_RN = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_rn_0.pth', + ) + VGG19_DA = Weights( + url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_da_0.pth', + ) + DEFAULT = RESNET152_DA + + +class BIA(GenerativeAttack): + """Beyond ImageNet Attack (BIA). + + From the paper 'Beyond ImageNet Attack: Towards Crafting Adversarial Examples for + Black-box Domains', https://arxiv.org/abs/2201.11528 + + Args: + device: Device to use for tensors. Defaults to cuda if available. eps: The + maximum perturbation. Defaults to 10/255. weights: Pretrained weights for the + generator. Either import and use the enum, + or use its name. Defaults to BIAWeights.DEFAULT. + checkpoint_path: Path to a custom checkpoint. Defaults to None. clip_min: + Minimum value for clipping. Defaults to 0.0. clip_max: Maximum value for + clipping. Defaults to 1.0. + """ + + def __init__( + self, + device: torch.device | None = None, + eps: float = 10 / 255, + weights: BIAWeights | str | None = BIAWeights.DEFAULT, + checkpoint_path: str | None = None, + clip_min: float = 0.0, + clip_max: float = 1.0, + ) -> None: + super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max) + + def _init_generator(self) -> ResNetGenerator: + generator = ResNetGenerator() + # Prioritize checkpoint path over provided weights enum + if self.checkpoint_path is not None: + generator.load_state_dict(torch.load(self.checkpoint_path)) + else: + # Verify and load weights from enum if checkpoint path is not provided + self.weights: BIAWeights = BIAWeights.verify(self.weights) + if self.weights is not None: + generator.load_state_dict(self.weights.get_state_dict(check_hash=True)) + return generator.eval().to(self.device) + + +if __name__ == '__main__': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + attack = BIA(device, weights='VGG19_DA') + print(attack) diff --git a/torchattack/cda.py b/torchattack/cda.py index 37a9820..4c211bc 100644 --- a/torchattack/cda.py +++ b/torchattack/cda.py @@ -30,7 +30,9 @@ class CDA(GenerativeAttack): Args: device: Device to use for tensors. Defaults to cuda if available. eps: The maximum perturbation. Defaults to 10/255. - weights: Pretrained weights for the generator. Defaults to CDAWeights.DEFAULT. + weights: Pretrained weights for the generator. Either import and use the enum, + or use its name. Defaults to CDAWeights.DEFAULT. + checkpoint_path: Path to a custom checkpoint. Defaults to None. clip_min: Minimum value for clipping. Defaults to 0.0. clip_max: Maximum value for clipping. Defaults to 1.0. """ @@ -61,5 +63,5 @@ def _init_generator(self) -> ResNetGenerator: if __name__ == '__main__': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - attack = CDA(device, eps=8 / 255, weights='VGG19_IMAGENET1K') + attack = CDA(device, weights='VGG19_IMAGENET1K') print(attack) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index 06cc3f5..a989ae3 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -60,6 +60,8 @@ def create_attack( if not hasattr(torchattack, attack_name): raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") attacker_cls: Attack = getattr(torchattack, attack_name) + if attack_name in ['BIA', 'CDA']: + return attacker_cls(device, **attack_cfg) return attacker_cls(model, normalize, device, **attack_cfg) diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index dbce171..f3b6fba 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -116,7 +116,7 @@ def run_attack( parser = argparse.ArgumentParser(description='Run an attack on a model.') parser.add_argument('--attack', type=str, required=True) - parser.add_argument('--eps', type=str, default='16/255') + parser.add_argument('--eps', type=float, default=16) parser.add_argument('--model-name', type=str, default='resnet50') parser.add_argument('--victim-model-names', type=str, nargs='+', default=None) parser.add_argument('--dataset-root', type=str, default='datasets/nips2017') @@ -126,7 +126,7 @@ def run_attack( run_attack( attack=args.attack, - attack_cfg={'eps': float(args.eps)}, + attack_cfg={'eps': args.eps / 255}, model_name=args.model_name, victim_model_names=args.victim_model_names, dataset_root=args.dataset_root, diff --git a/torchattack/generative/_inference.py b/torchattack/generative/_inference.py index d555baa..5d29f12 100644 --- a/torchattack/generative/_inference.py +++ b/torchattack/generative/_inference.py @@ -29,7 +29,7 @@ def __init__( # Initialize the generator and its weights self.generator = self._init_generator() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: """Perform the generative attack via generator inference on a batch of images. Args: From 169f6703b081ce4257fcd05899f3ed631353d6dd Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Tue, 26 Nov 2024 19:41:45 +0800 Subject: [PATCH 3/9] Add weights/checkpoint_path args to create_attack --- torchattack/create_attack.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index a989ae3..18fdee6 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -7,6 +7,9 @@ from torchattack._attack import Attack from torchattack.attack_model import AttackModel +_generative_attacks = ['BIA', 'CDA'] +_non_eps_attacks = ['GeoDA', 'DeepFool'] + def attack_warn(message: str) -> None: from warnings import warn @@ -20,13 +23,15 @@ def create_attack( normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, eps: float | None = None, + weights: str | None = None, + checkpoint_path: str | None = None, attack_cfg: dict[str, Any] | None = None, ) -> Attack: """Create a torchattack instance based on the provided attack name and config. Args: attack_name: The name of the attack to create. - model: The model to be attacked. + model: The model to be attacked. Defaults to None. normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. eps: The epsilon value for the attack. Defaults to None. @@ -54,13 +59,24 @@ def create_attack( f"by the 'eps' argument value ({eps}), which MAY NOT be intended." ) attack_cfg['eps'] = eps - if attack_name in ['GeoDA', 'DeepFool'] and 'eps' in attack_cfg: + if attack_name in _non_eps_attacks and 'eps' in attack_cfg: attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.") attack_cfg.pop('eps', None) + if attack_name not in _generative_attacks and ( + weights is not None or checkpoint_path is not None + ): + attack_warn( + f'weights and checkpoint_path are only used for generative attacks, ' + f"and will be ignored for '{attack_name}'." + ) + attack_cfg.pop('weights', None) + attack_cfg.pop('checkpoint_path', None) if not hasattr(torchattack, attack_name): raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") attacker_cls: Attack = getattr(torchattack, attack_name) - if attack_name in ['BIA', 'CDA']: + if attack_name in _generative_attacks: + attack_cfg['weights'] = weights + attack_cfg['checkpoint_path'] = checkpoint_path return attacker_cls(device, **attack_cfg) return attacker_cls(model, normalize, device, **attack_cfg) From 8f7cc30856e92b30336ce5983f7270283d0cbb11 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 09:27:37 +0800 Subject: [PATCH 4/9] Remove generator inference abstraction --- torchattack/bia.py | 56 ++++++++++++++++++++-------- torchattack/cda.py | 45 +++++++++++++++++----- torchattack/generative/_inference.py | 49 ------------------------ 3 files changed, 77 insertions(+), 73 deletions(-) delete mode 100644 torchattack/generative/_inference.py diff --git a/torchattack/bia.py b/torchattack/bia.py index d184ff9..1f8c95e 100644 --- a/torchattack/bia.py +++ b/torchattack/bia.py @@ -1,6 +1,8 @@ +from typing import Any + import torch -from torchattack.generative._inference import GenerativeAttack +from torchattack._attack import Attack from torchattack.generative._weights import Weights, WeightsEnum from torchattack.generative.resnet_generator import ResNetGenerator @@ -45,20 +47,20 @@ class BIAWeights(WeightsEnum): DEFAULT = RESNET152_DA -class BIA(GenerativeAttack): +class BIA(Attack): """Beyond ImageNet Attack (BIA). From the paper 'Beyond ImageNet Attack: Towards Crafting Adversarial Examples for Black-box Domains', https://arxiv.org/abs/2201.11528 Args: - device: Device to use for tensors. Defaults to cuda if available. eps: The - maximum perturbation. Defaults to 10/255. weights: Pretrained weights for the - generator. Either import and use the enum, + device: Device to use for tensors. Defaults to cuda if available. + eps: The maximum perturbation. Defaults to 10/255. + weights: Pretrained weights for the generator. Either import and use the enum, or use its name. Defaults to BIAWeights.DEFAULT. - checkpoint_path: Path to a custom checkpoint. Defaults to None. clip_min: - Minimum value for clipping. Defaults to 0.0. clip_max: Maximum value for - clipping. Defaults to 1.0. + checkpoint_path: Path to a custom checkpoint. Defaults to None. + clip_min: Minimum value for clipping. Defaults to 0.0. + clip_max: Maximum value for clipping. Defaults to 1.0. """ def __init__( @@ -70,19 +72,43 @@ def __init__( clip_min: float = 0.0, clip_max: float = 1.0, ) -> None: - super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max) + # Generative attacks do not require specifying model and normalize. + super().__init__(model=None, normalize=None, device=device) + + self.eps = eps + self.checkpoint_path = checkpoint_path + self.clip_min = clip_min + self.clip_max = clip_max + + self.generator = ResNetGenerator() - def _init_generator(self) -> ResNetGenerator: - generator = ResNetGenerator() # Prioritize checkpoint path over provided weights enum if self.checkpoint_path is not None: - generator.load_state_dict(torch.load(self.checkpoint_path)) + self.generator.load_state_dict(torch.load(self.checkpoint_path)) else: # Verify and load weights from enum if checkpoint path is not provided - self.weights: BIAWeights = BIAWeights.verify(self.weights) + self.weights: BIAWeights = BIAWeights.verify(weights) if self.weights is not None: - generator.load_state_dict(self.weights.get_state_dict(check_hash=True)) - return generator.eval().to(self.device) + self.generator.load_state_dict( + self.weights.get_state_dict(check_hash=True) + ) + + self.generator.eval().to(self.device) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Perform BIA on a batch of images. + + Args: + x: A batch of images. Shape: (N, C, H, W). + + Returns: + The perturbed images if successful. Shape: (N, C, H, W). + """ + + x_unrestricted = self.generator(x) + delta = torch.clamp(x_unrestricted - x, -self.eps, self.eps) + x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max) + return x_adv if __name__ == '__main__': diff --git a/torchattack/cda.py b/torchattack/cda.py index 4c211bc..ebcadea 100644 --- a/torchattack/cda.py +++ b/torchattack/cda.py @@ -1,6 +1,8 @@ +from typing import Any + import torch -from torchattack.generative._inference import GenerativeAttack +from torchattack._attack import Attack from torchattack.generative._weights import Weights, WeightsEnum from torchattack.generative.resnet_generator import ResNetGenerator @@ -21,7 +23,7 @@ class CDAWeights(WeightsEnum): DEFAULT = RESNET152_IMAGENET1K -class CDA(GenerativeAttack): +class CDA(Attack): """Cross-domain Attack (CDA). From the paper 'Cross-Domain Transferability of Adversarial Perturbations', @@ -46,19 +48,44 @@ def __init__( clip_min: float = 0.0, clip_max: float = 1.0, ) -> None: - super().__init__(device, eps, weights, checkpoint_path, clip_min, clip_max) + # Generative attacks do not require specifying model and normalize. + super().__init__(model=None, normalize=None, device=device) + + self.eps = eps + self.checkpoint_path = checkpoint_path + self.clip_min = clip_min + self.clip_max = clip_max + + # Initialize the generator and its weights + self.generator = ResNetGenerator() - def _init_generator(self) -> ResNetGenerator: - generator = ResNetGenerator() # Prioritize checkpoint path over provided weights enum if self.checkpoint_path is not None: - generator.load_state_dict(torch.load(self.checkpoint_path)) + self.generator.load_state_dict(torch.load(self.checkpoint_path)) else: # Verify and load weights from enum if checkpoint path is not provided - self.weights: CDAWeights = CDAWeights.verify(self.weights) + self.weights: CDAWeights = CDAWeights.verify(weights) if self.weights is not None: - generator.load_state_dict(self.weights.get_state_dict(check_hash=True)) - return generator.eval().to(self.device) + self.generator.load_state_dict( + self.weights.get_state_dict(check_hash=True) + ) + + self.generator.eval().to(self.device) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """Perform CDA on a batch of images. + + Args: + x: A batch of images. Shape: (N, C, H, W). + + Returns: + The perturbed images if successful. Shape: (N, C, H, W). + """ + + x_unrestricted = self.generator(x) + delta = torch.clamp(x_unrestricted - x, -self.eps, self.eps) + x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max) + return x_adv if __name__ == '__main__': diff --git a/torchattack/generative/_inference.py b/torchattack/generative/_inference.py deleted file mode 100644 index 5d29f12..0000000 --- a/torchattack/generative/_inference.py +++ /dev/null @@ -1,49 +0,0 @@ -from abc import abstractmethod -from typing import Any - -import torch - -from torchattack._attack import Attack -from torchattack.generative._weights import WeightsEnum - - -class GenerativeAttack(Attack): - def __init__( - self, - device: torch.device | None = None, - eps: float = 10 / 255, - weights: WeightsEnum | str | None = None, - checkpoint_path: str | None = None, - clip_min: float = 0.0, - clip_max: float = 1.0, - ) -> None: - # Generative attacks do not require specifying model and normalize. - super().__init__(model=None, normalize=None, device=device) - - self.eps = eps - self.weights = weights - self.checkpoint_path = checkpoint_path - self.clip_min = clip_min - self.clip_max = clip_max - - # Initialize the generator and its weights - self.generator = self._init_generator() - - def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: - """Perform the generative attack via generator inference on a batch of images. - - Args: - x: A batch of images. Shape: (N, C, H, W). - - Returns: - The perturbed images if successful. Shape: (N, C, H, W). - """ - - x_unrestricted = self.generator(x) - delta = torch.clamp(x_unrestricted - x, -self.eps, self.eps) - x_adv = torch.clamp(x + delta, self.clip_min, self.clip_max) - return x_adv - - @abstractmethod - def _init_generator(self, *args: Any, **kwds: Any) -> Any: - pass From 2a736514f73dda1a740f121f2dd79d163ff9fa51 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 12:16:55 +0800 Subject: [PATCH 5/9] Testing for generative attacks --- tests/test_create_attack.py | 36 +++++++++++++++++++++++++++++- torchattack/_attack.py | 1 + torchattack/bia.py | 31 ++++++++++++------------- torchattack/cda.py | 14 ++++++------ torchattack/create_attack.py | 30 ++++++++++++++++++------- torchattack/generative/_weights.py | 4 ++-- 6 files changed, 83 insertions(+), 33 deletions(-) diff --git a/tests/test_create_attack.py b/tests/test_create_attack.py index b003457..ed9591c 100644 --- a/tests/test_create_attack.py +++ b/tests/test_create_attack.py @@ -1,6 +1,8 @@ import pytest from torchattack import ( + BIA, + CDA, DIFGSM, FGSM, FIA, @@ -48,6 +50,10 @@ 'VDC': VDC, 'PNAPatchOut': PNAPatchOut, } +expected_generative_attacks = { + 'BIA': BIA, + 'CDA': CDA, +} @pytest.mark.parametrize(('attack_name', 'expected'), expected_non_vit_attacks.items()) @@ -72,6 +78,15 @@ def test_create_vit_attack_same_as_imported( assert created_attacker == expected_attacker +@pytest.mark.parametrize( + ('attack_name', 'expected'), expected_generative_attacks.items() +) +def test_create_generative_attack_same_as_imported(attack_name, expected): + created_attacker = create_attack(attack_name) + expected_attacker = expected() + assert created_attacker == expected_attacker + + def test_create_attack_with_eps(device, resnet50_model): eps = 0.3 attack_cfg = {} @@ -119,7 +134,7 @@ def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model): def test_create_attack_with_invalid_eps(device, resnet50_model): eps = 0.3 with pytest.warns( - UserWarning, match="parameter 'eps' is invalid in DeepFool and will be ignored." + UserWarning, match="argument 'eps' is invalid in DeepFool and will be ignored." ): attacker = create_attack( attack_name='DeepFool', @@ -131,6 +146,25 @@ def test_create_attack_with_invalid_eps(device, resnet50_model): assert 'eps' not in attacker.__dict__ +def test_create_attack_with_weights_and_checkpoint_path(device): + weights = 'VGG19_IMAGENET1K' + checkpoint_path = 'path/to/checkpoint' + attack_cfg = {} + with pytest.warns( + UserWarning, + match="argument 'weights' and 'checkpoint_path' are only used for generative attacks, and will be ignored for 'FGSM'.", + ): + attacker = create_attack( + attack_name='FGSM', + device=device, + weights=weights, + checkpoint_path=checkpoint_path, + attack_cfg=attack_cfg, + ) + assert 'weights' not in attacker.__dict__ + assert 'checkpoint_path' not in attacker.__dict__ + + def test_create_attack_with_invalid_attack_name(device, resnet50_model): with pytest.raises( ValueError, match="Attack 'InvalidAttack' is not supported within torchattack." diff --git a/torchattack/_attack.py b/torchattack/_attack.py index bb2003c..36028cc 100644 --- a/torchattack/_attack.py +++ b/torchattack/_attack.py @@ -70,6 +70,7 @@ def __eq__(self, other): 'hooks', # PNAPatchOut, TGR, VDC 'perceptual_criteria', # SSP 'sub_basis', # GeoDA + 'generator', # BIA, CDA ] for attr in eq_name_attrs: if not (hasattr(self, attr) and hasattr(other, attr)): diff --git a/torchattack/bia.py b/torchattack/bia.py index 1f8c95e..ca04053 100644 --- a/torchattack/bia.py +++ b/torchattack/bia.py @@ -3,45 +3,45 @@ import torch from torchattack._attack import Attack -from torchattack.generative._weights import Weights, WeightsEnum +from torchattack.generative._weights import GeneratorWeights, GeneratorWeightsEnum from torchattack.generative.resnet_generator import ResNetGenerator -class BIAWeights(WeightsEnum): - RESNET152 = Weights( +class BIAWeights(GeneratorWeightsEnum): + RESNET152 = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_0.pth', ) - RESNET152_RN = Weights( + RESNET152_RN = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_rn_0.pth', ) - RESNET152_DA = Weights( + RESNET152_DA = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_da_0.pth', ) - DENSENET169 = Weights( + DENSENET169 = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_0.pth', ) - DENSENET169_RN = Weights( + DENSENET169_RN = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_rn_0.pth', ) - DENSENET169_DA = Weights( + DENSENET169_DA = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_densenet169_da_0.pth', ) - VGG16 = Weights( + VGG16 = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_0.pth', ) - VGG16_RN = Weights( + VGG16_RN = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_rn_0.pth', ) - VGG16_DA = Weights( + VGG16_DA = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg16_da_0.pth', ) - VGG19 = Weights( + VGG19 = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_0.pth', ) - VGG19_RN = Weights( + VGG19_RN = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_rn_0.pth', ) - VGG19_DA = Weights( + VGG19_DA = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_vgg19_da_0.pth', ) DEFAULT = RESNET152_DA @@ -72,7 +72,7 @@ def __init__( clip_min: float = 0.0, clip_max: float = 1.0, ) -> None: - # Generative attacks do not require specifying model and normalize. + # Generative attacks do not require specifying model and normalize super().__init__(model=None, normalize=None, device=device) self.eps = eps @@ -80,6 +80,7 @@ def __init__( self.clip_min = clip_min self.clip_max = clip_max + # Initialize the generator and its weights self.generator = ResNetGenerator() # Prioritize checkpoint path over provided weights enum diff --git a/torchattack/cda.py b/torchattack/cda.py index ebcadea..497cd5b 100644 --- a/torchattack/cda.py +++ b/torchattack/cda.py @@ -3,21 +3,21 @@ import torch from torchattack._attack import Attack -from torchattack.generative._weights import Weights, WeightsEnum +from torchattack.generative._weights import GeneratorWeights, GeneratorWeightsEnum from torchattack.generative.resnet_generator import ResNetGenerator -class CDAWeights(WeightsEnum): - RESNET152_IMAGENET1K = Weights( +class CDAWeights(GeneratorWeightsEnum): + RESNET152_IMAGENET1K = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_res152_imagenet_0_rl.pth', ) - INCEPTION_V3_IMAGENET1K = Weights( + INCEPTION_V3_IMAGENET1K = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_incv3_imagenet_0_rl.pth', ) - VGG16_IMAGENET1K = Weights( + VGG16_IMAGENET1K = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg16_imagenet_0_rl.pth', ) - VGG19_IMAGENET1K = Weights( + VGG19_IMAGENET1K = GeneratorWeights( url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_vgg19_imagenet_0_rl.pth', ) DEFAULT = RESNET152_IMAGENET1K @@ -48,7 +48,7 @@ def __init__( clip_min: float = 0.0, clip_max: float = 1.0, ) -> None: - # Generative attacks do not require specifying model and normalize. + # Generative attacks do not require specifying model and normalize super().__init__(model=None, normalize=None, device=device) self.eps = eps diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index 18fdee6..c248f69 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -23,7 +23,7 @@ def create_attack( normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, eps: float | None = None, - weights: str | None = None, + weights: str | None = 'DEFAULT', checkpoint_path: str | None = None, attack_cfg: dict[str, Any] | None = None, ) -> Attack: @@ -31,10 +31,12 @@ def create_attack( Args: attack_name: The name of the attack to create. - model: The model to be attacked. Defaults to None. + model: The model to be attacked. Can be an instance of nn.Module or AttackModel. Defaults to None. normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. eps: The epsilon value for the attack. Defaults to None. + weights: The path to the weights file for generative attacks. Defaults to 'DEFAULT'. + checkpoint_path: The path to the checkpoint file for generative attacks. Defaults to None. attack_cfg: Additional config parameters for the attack. Defaults to None. Returns: @@ -48,10 +50,18 @@ def create_attack( issued and the value in `attack_cfg` will be overwritten. - For certain attacks like 'GeoDA' and 'DeepFool', the `eps` parameter is invalid and will be ignored if present in `attack_cfg`. + - The `weights` and `checkpoint_path` parameters are only used for generative attacks. """ if attack_cfg is None: attack_cfg = {} + + # Check if attack_name is supported + if not hasattr(torchattack, attack_name): + raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") + attacker_cls: Attack = getattr(torchattack, attack_name) + + # Check if eps is provided and overwrite the value in attack_cfg if present if eps is not None: if 'eps' in attack_cfg: attack_warn( @@ -59,25 +69,29 @@ def create_attack( f"by the 'eps' argument value ({eps}), which MAY NOT be intended." ) attack_cfg['eps'] = eps + + # Check if attacks that do not require eps have eps in attack_cfg if attack_name in _non_eps_attacks and 'eps' in attack_cfg: - attack_warn(f"parameter 'eps' is invalid in {attack_name} and will be ignored.") + attack_warn(f"argument 'eps' is invalid in {attack_name} and will be ignored.") attack_cfg.pop('eps', None) + + # Check if non-generative attacks have weights or checkpoint_path if attack_name not in _generative_attacks and ( weights is not None or checkpoint_path is not None ): attack_warn( - f'weights and checkpoint_path are only used for generative attacks, ' - f"and will be ignored for '{attack_name}'." + f"argument 'weights' and 'checkpoint_path' are only used for " + f"generative attacks, and will be ignored for '{attack_name}'." ) attack_cfg.pop('weights', None) attack_cfg.pop('checkpoint_path', None) - if not hasattr(torchattack, attack_name): - raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") - attacker_cls: Attack = getattr(torchattack, attack_name) + + # Special handling for generative attacks if attack_name in _generative_attacks: attack_cfg['weights'] = weights attack_cfg['checkpoint_path'] = checkpoint_path return attacker_cls(device, **attack_cfg) + return attacker_cls(model, normalize, device, **attack_cfg) diff --git a/torchattack/generative/_weights.py b/torchattack/generative/_weights.py index afa149f..6cd3d57 100644 --- a/torchattack/generative/_weights.py +++ b/torchattack/generative/_weights.py @@ -6,11 +6,11 @@ @dataclass -class Weights: +class GeneratorWeights: url: str -class WeightsEnum(Enum): +class GeneratorWeightsEnum(Enum): @classmethod def verify(cls, obj: Any) -> Any: if obj is not None: From 40298d08de20cec45ec455e910454a35487eae29 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 14:05:37 +0800 Subject: [PATCH 6/9] Update create_attack and runner to support generative attacks --- README.md | 11 ++++- tests/test_create_attack.py | 96 ++++++++++-------------------------- torchattack/__init__.py | 29 ++++++++++- torchattack/create_attack.py | 19 +++---- torchattack/eval/runner.py | 13 ++++- 5 files changed, 82 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index f6b38bb..c425471 100644 --- a/README.md +++ b/README.md @@ -22,13 +22,13 @@ Install from GitHub source - ```shell -python -m pip install git+https://github.com/spencerwooo/torchattack@v1.0.1 +python -m pip install git+https://github.com/spencerwooo/torchattack@v1.0.2 ``` Install from Gitee mirror - ```shell -python -m pip install git+https://gitee.com/spencerwoo/torchattack@v1.0.1 +python -m pip install git+https://gitee.com/spencerwoo/torchattack@v1.0.2 ``` ## Usage @@ -110,6 +110,13 @@ Gradient-based attacks: | DeCoWA | $\ell_\infty$ | AAAI 2024 | [Boosting Adversarial Transferability across Model Genus by Deformation-Constrained Warping](https://arxiv.org/abs/2402.03951) | `DeCoWA` | | VDC | $\ell_\infty$ | AAAI 2024 | [Improving the Adversarial Transferability of Vision Transformers with Virtual Dense Connection](https://ojs.aaai.org/index.php/AAAI/article/view/28541) | `VDC` | +Generative attacks: + +| Name | $\ell_p$ | Publication | Paper (Open Access) | Class Name | +| :--: | :-----------: | :----------: | ----------------------------------------------------------------------------------------------------------------------- | ---------- | +| CDA | $\ell_\infty$ | NeurIPS 2019 | [CDA: contrastive Divergence for Adversarial Attack](https://arxiv.org/abs/1905.11736) | `CDA` | +| BIA | $\ell_\infty$ | ICLR 2022 | [Beyond ImageNet Attack: Towards Crafting Adversarial Examples for Black-box Domains](https://arxiv.org/abs/2201.11528) | `BIA` | + Others: | Name | $\ell_p$ | Publication | Paper (Open Access) | Class Name | diff --git a/tests/test_create_attack.py b/tests/test_create_attack.py index ed9591c..188502b 100644 --- a/tests/test_create_attack.py +++ b/tests/test_create_attack.py @@ -1,62 +1,12 @@ import pytest -from torchattack import ( - BIA, - CDA, - DIFGSM, - FGSM, - FIA, - MIFGSM, - NIFGSM, - PGD, - PGDL2, - SINIFGSM, - SSA, - SSP, - TGR, - TIFGSM, - VDC, - VMIFGSM, - VNIFGSM, - Admix, - DeCoWA, - DeepFool, - GeoDA, - PNAPatchOut, - create_attack, -) +import torchattack +from torchattack import create_attack + -expected_non_vit_attacks = { - 'DIFGSM': DIFGSM, - 'FGSM': FGSM, - 'FIA': FIA, - 'MIFGSM': MIFGSM, - 'NIFGSM': NIFGSM, - 'PGD': PGD, - 'PGDL2': PGDL2, - 'SINIFGSM': SINIFGSM, - 'SSA': SSA, - 'SSP': SSP, - 'TIFGSM': TIFGSM, - 'VMIFGSM': VMIFGSM, - 'VNIFGSM': VNIFGSM, - 'Admix': Admix, - 'DeCoWA': DeCoWA, - 'DeepFool': DeepFool, - 'GeoDA': GeoDA, -} -expected_vit_attacks = { - 'TGR': TGR, - 'VDC': VDC, - 'PNAPatchOut': PNAPatchOut, -} -expected_generative_attacks = { - 'BIA': BIA, - 'CDA': CDA, -} - - -@pytest.mark.parametrize(('attack_name', 'expected'), expected_non_vit_attacks.items()) +@pytest.mark.parametrize( + ('attack_name', 'expected'), torchattack.GRADIENT_NON_VIT_ATTACKS.items() +) def test_create_non_vit_attack_same_as_imported( attack_name, expected, @@ -67,7 +17,9 @@ def test_create_non_vit_attack_same_as_imported( assert created_attacker == expected_attacker -@pytest.mark.parametrize(('attack_name', 'expected'), expected_vit_attacks.items()) +@pytest.mark.parametrize( + ('attack_name', 'expected'), torchattack.GRADIENT_VIT_ATTACKS.items() +) def test_create_vit_attack_same_as_imported( attack_name, expected, @@ -79,7 +31,7 @@ def test_create_vit_attack_same_as_imported( @pytest.mark.parametrize( - ('attack_name', 'expected'), expected_generative_attacks.items() + ('attack_name', 'expected'), torchattack.GENERATIVE_ATTACKS.items() ) def test_create_generative_attack_same_as_imported(attack_name, expected): created_attacker = create_attack(attack_name) @@ -116,18 +68,19 @@ def test_create_attack_with_attack_cfg_eps(device, resnet50_model): def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model): eps = 0.3 attack_cfg = {'eps': 0.1} - # with pytest.warns( - # UserWarning, - # match="'eps' in 'attack_cfg' (0.1) will be overwritten by the 'eps' argument value (0.3), which MAY NOT be intended.", - # ): - attacker = create_attack( - attack_name='FGSM', - model=resnet50_model, - normalize=resnet50_model.normalize, - device=device, - eps=eps, - attack_cfg=attack_cfg, - ) + with pytest.warns( + UserWarning, + match="The 'eps' value provided as an argument will overwrite the existing " + "'eps' value in 'attack_cfg'. This MAY NOT be the intended behavior.", + ): + attacker = create_attack( + attack_name='FGSM', + model=resnet50_model, + normalize=resnet50_model.normalize, + device=device, + eps=eps, + attack_cfg=attack_cfg, + ) assert attacker.eps == eps @@ -152,7 +105,8 @@ def test_create_attack_with_weights_and_checkpoint_path(device): attack_cfg = {} with pytest.warns( UserWarning, - match="argument 'weights' and 'checkpoint_path' are only used for generative attacks, and will be ignored for 'FGSM'.", + match="argument 'weights' and 'checkpoint_path' are only used for " + "generative attacks, and will be ignored for 'FGSM'.", ): attacker = create_attack( attack_name='FGSM', diff --git a/torchattack/__init__.py b/torchattack/__init__.py index 7b29c46..ccd4a68 100644 --- a/torchattack/__init__.py +++ b/torchattack/__init__.py @@ -23,7 +23,7 @@ from torchattack.vmifgsm import VMIFGSM from torchattack.vnifgsm import VNIFGSM -__version__ = '1.0.1' +__version__ = '1.0.2' __all__ = [ # Helper function to create an attack by its name @@ -54,3 +54,30 @@ 'VMIFGSM', 'VNIFGSM', ] + +GRADIENT_NON_VIT_ATTACKS = { + 'Admix': Admix, + 'DeCoWA': DeCoWA, + 'DIFGSM': DIFGSM, + 'FGSM': FGSM, + 'FIA': FIA, + 'MIFGSM': MIFGSM, + 'NIFGSM': NIFGSM, + 'PGD': PGD, + 'PGDL2': PGDL2, + 'SINIFGSM': SINIFGSM, + 'SSA': SSA, + 'SSP': SSP, + 'TIFGSM': TIFGSM, + 'VMIFGSM': VMIFGSM, + 'VNIFGSM': VNIFGSM, +} +GRADIENT_VIT_ATTACKS = {'TGR': TGR, 'VDC': VDC, 'PNAPatchOut': PNAPatchOut} +GENERATIVE_ATTACKS = {'BIA': BIA, 'CDA': CDA} +NON_EPS_ATTACKS = {'GeoDA': GeoDA, 'DeepFool': DeepFool} +SUPPORTED_ATTACKS = ( + GRADIENT_NON_VIT_ATTACKS + | GRADIENT_VIT_ATTACKS + | GENERATIVE_ATTACKS + | NON_EPS_ATTACKS +) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index c248f69..dee12ff 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -7,9 +7,6 @@ from torchattack._attack import Attack from torchattack.attack_model import AttackModel -_generative_attacks = ['BIA', 'CDA'] -_non_eps_attacks = ['GeoDA', 'DeepFool'] - def attack_warn(message: str) -> None: from warnings import warn @@ -35,7 +32,7 @@ def create_attack( normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. eps: The epsilon value for the attack. Defaults to None. - weights: The path to the weights file for generative attacks. Defaults to 'DEFAULT'. + weights: The name of the generator weight for generative attacks. Defaults to 'DEFAULT'. checkpoint_path: The path to the checkpoint file for generative attacks. Defaults to None. attack_cfg: Additional config parameters for the attack. Defaults to None. @@ -57,7 +54,7 @@ def create_attack( attack_cfg = {} # Check if attack_name is supported - if not hasattr(torchattack, attack_name): + if attack_name not in torchattack.SUPPORTED_ATTACKS: raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") attacker_cls: Attack = getattr(torchattack, attack_name) @@ -65,19 +62,19 @@ def create_attack( if eps is not None: if 'eps' in attack_cfg: attack_warn( - f"'eps' in 'attack_cfg' ({attack_cfg['eps']}) will be overwritten " - f"by the 'eps' argument value ({eps}), which MAY NOT be intended." + "The 'eps' value provided as an argument will overwrite the existing " + "'eps' value in 'attack_cfg'. This MAY NOT be the intended behavior." ) attack_cfg['eps'] = eps # Check if attacks that do not require eps have eps in attack_cfg - if attack_name in _non_eps_attacks and 'eps' in attack_cfg: + if attack_name in torchattack.NON_EPS_ATTACKS and 'eps' in attack_cfg: attack_warn(f"argument 'eps' is invalid in {attack_name} and will be ignored.") attack_cfg.pop('eps', None) # Check if non-generative attacks have weights or checkpoint_path - if attack_name not in _generative_attacks and ( - weights is not None or checkpoint_path is not None + if attack_name not in torchattack.GENERATIVE_ATTACKS and ( + weights != 'DEFAULT' or checkpoint_path is not None ): attack_warn( f"argument 'weights' and 'checkpoint_path' are only used for " @@ -87,7 +84,7 @@ def create_attack( attack_cfg.pop('checkpoint_path', None) # Special handling for generative attacks - if attack_name in _generative_attacks: + if attack_name in torchattack.GENERATIVE_ATTACKS: attack_cfg['weights'] = weights attack_cfg['checkpoint_path'] = checkpoint_path return attacker_cls(device, **attack_cfg) diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index f3b6fba..8aea23f 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -114,9 +114,13 @@ def run_attack( if __name__ == '__main__': import argparse + from torchattack import GENERATIVE_ATTACKS, NON_EPS_ATTACKS + parser = argparse.ArgumentParser(description='Run an attack on a model.') parser.add_argument('--attack', type=str, required=True) parser.add_argument('--eps', type=float, default=16) + parser.add_argument('--weights', type=str, default='DEFAULT') + parser.add_argument('--checkpoint-path', type=str, default=None) parser.add_argument('--model-name', type=str, default='resnet50') parser.add_argument('--victim-model-names', type=str, nargs='+', default=None) parser.add_argument('--dataset-root', type=str, default='datasets/nips2017') @@ -124,9 +128,16 @@ def run_attack( parser.add_argument('--batch-size', type=int, default=4) args = parser.parse_args() + attack_cfg = {} + if args.attack not in NON_EPS_ATTACKS: + attack_cfg['eps'] = args.eps + if args.attack in GENERATIVE_ATTACKS: + attack_cfg['weights'] = args.weights + attack_cfg['checkpoint_path'] = args.checkpoint + run_attack( attack=args.attack, - attack_cfg={'eps': args.eps / 255}, + attack_cfg=attack_cfg, model_name=args.model_name, victim_model_names=args.victim_model_names, dataset_root=args.dataset_root, From 2a9be35f21f3405032858ecd3cfdcaf513130a4a Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 14:35:01 +0800 Subject: [PATCH 7/9] Unified create_attack (from name or class) for run_attack --- tests/test_create_attack.py | 12 ++++++------ torchattack/bia.py | 11 ++++++++--- torchattack/cda.py | 11 ++++++++--- torchattack/create_attack.py | 12 +++++++----- torchattack/eval/runner.py | 17 ++++------------- 5 files changed, 33 insertions(+), 30 deletions(-) diff --git a/tests/test_create_attack.py b/tests/test_create_attack.py index 188502b..bb7e593 100644 --- a/tests/test_create_attack.py +++ b/tests/test_create_attack.py @@ -43,7 +43,7 @@ def test_create_attack_with_eps(device, resnet50_model): eps = 0.3 attack_cfg = {} attacker = create_attack( - attack_name='FGSM', + attack='FGSM', model=resnet50_model, normalize=resnet50_model.normalize, device=device, @@ -56,7 +56,7 @@ def test_create_attack_with_eps(device, resnet50_model): def test_create_attack_with_attack_cfg_eps(device, resnet50_model): attack_cfg = {'eps': 0.1} attacker = create_attack( - attack_name='FGSM', + attack='FGSM', model=resnet50_model, normalize=resnet50_model.normalize, device=device, @@ -74,7 +74,7 @@ def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model): "'eps' value in 'attack_cfg'. This MAY NOT be the intended behavior.", ): attacker = create_attack( - attack_name='FGSM', + attack='FGSM', model=resnet50_model, normalize=resnet50_model.normalize, device=device, @@ -90,7 +90,7 @@ def test_create_attack_with_invalid_eps(device, resnet50_model): UserWarning, match="argument 'eps' is invalid in DeepFool and will be ignored." ): attacker = create_attack( - attack_name='DeepFool', + attack='DeepFool', model=resnet50_model, normalize=resnet50_model.normalize, device=device, @@ -109,7 +109,7 @@ def test_create_attack_with_weights_and_checkpoint_path(device): "generative attacks, and will be ignored for 'FGSM'.", ): attacker = create_attack( - attack_name='FGSM', + attack='FGSM', device=device, weights=weights, checkpoint_path=checkpoint_path, @@ -124,7 +124,7 @@ def test_create_attack_with_invalid_attack_name(device, resnet50_model): ValueError, match="Attack 'InvalidAttack' is not supported within torchattack." ): create_attack( - attack_name='InvalidAttack', + attack='InvalidAttack', model=resnet50_model, normalize=resnet50_model.normalize, device=device, diff --git a/torchattack/bia.py b/torchattack/bia.py index ca04053..fa6a899 100644 --- a/torchattack/bia.py +++ b/torchattack/bia.py @@ -113,6 +113,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if __name__ == '__main__': - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - attack = BIA(device, weights='VGG19_DA') - print(attack) + from torchattack.eval import run_attack + + run_attack( + attack=BIA, + attack_cfg={'eps': 10 / 255, 'weights': 'VGG16_DA'}, + model_name='resnet50', + victim_model_names=['resnet152'], + ) diff --git a/torchattack/cda.py b/torchattack/cda.py index 497cd5b..4dfd205 100644 --- a/torchattack/cda.py +++ b/torchattack/cda.py @@ -89,6 +89,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if __name__ == '__main__': - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - attack = CDA(device, weights='VGG19_IMAGENET1K') - print(attack) + from torchattack.eval import run_attack + + run_attack( + attack=CDA, + attack_cfg={'eps': 10 / 255, 'weights': 'VGG16_IMAGENET1K'}, + model_name='resnet50', + victim_model_names=['resnet152'], + ) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index dee12ff..d60900c 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -15,7 +15,7 @@ def attack_warn(message: str) -> None: def create_attack( - attack_name: str, + attack: Any, model: nn.Module | AttackModel | None = None, normalize: Callable[[torch.Tensor], torch.Tensor] | None = None, device: torch.device | None = None, @@ -27,7 +27,7 @@ def create_attack( """Create a torchattack instance based on the provided attack name and config. Args: - attack_name: The name of the attack to create. + attack: The attack to create, either by name or class instance. model: The model to be attacked. Can be an instance of nn.Module or AttackModel. Defaults to None. normalize: The normalization function specific to the model. Defaults to None. device: The device on which the attack will be executed. Defaults to None. @@ -54,9 +54,11 @@ def create_attack( attack_cfg = {} # Check if attack_name is supported + attack_name = attack if isinstance(attack, str) else attack.__name__ if attack_name not in torchattack.SUPPORTED_ATTACKS: raise ValueError(f"Attack '{attack_name}' is not supported within torchattack.") - attacker_cls: Attack = getattr(torchattack, attack_name) + if not isinstance(attack, Attack): + attack = getattr(torchattack, attack_name) # Check if eps is provided and overwrite the value in attack_cfg if present if eps is not None: @@ -87,9 +89,9 @@ def create_attack( if attack_name in torchattack.GENERATIVE_ATTACKS: attack_cfg['weights'] = weights attack_cfg['checkpoint_path'] = checkpoint_path - return attacker_cls(device, **attack_cfg) + return attack(device, **attack_cfg) - return attacker_cls(model, normalize, device, **attack_cfg) + return attack(model, normalize, device, **attack_cfg) if __name__ == '__main__': diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index 8aea23f..6a6f0ce 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -54,16 +54,7 @@ def run_attack( # Set up attack and trackers frm = FoolingRateMetric() - if isinstance(attack, str): - attacker = create_attack( - attack_name=attack, - model=model, - normalize=normalize, - device=device, - attack_cfg=attack_cfg, - ) - else: - attacker = attack(model, normalize, device, **attack_cfg) + attacker = create_attack(attack, model, normalize, device, attack_cfg=attack_cfg) print(attacker) # Setup victim models if provided @@ -114,7 +105,7 @@ def run_attack( if __name__ == '__main__': import argparse - from torchattack import GENERATIVE_ATTACKS, NON_EPS_ATTACKS + import torchattack parser = argparse.ArgumentParser(description='Run an attack on a model.') parser.add_argument('--attack', type=str, required=True) @@ -129,9 +120,9 @@ def run_attack( args = parser.parse_args() attack_cfg = {} - if args.attack not in NON_EPS_ATTACKS: + if args.attack not in torchattack.NON_EPS_ATTACKS: # type: ignore attack_cfg['eps'] = args.eps - if args.attack in GENERATIVE_ATTACKS: + if args.attack in torchattack.GENERATIVE_ATTACKS: # type: ignore attack_cfg['weights'] = args.weights attack_cfg['checkpoint_path'] = args.checkpoint From 4c0aff72850cb4d4288c3b36036becf45db83d8e Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 15:19:47 +0800 Subject: [PATCH 8/9] Refactor attack_cfg to attack_args, fix weights/checkpoint_path args not being passed --- README.md | 4 +-- tests/test_create_attack.py | 24 ++++++++-------- torchattack/bia.py | 8 ++++-- torchattack/cda.py | 8 ++++-- torchattack/create_attack.py | 56 ++++++++++++++++++++++-------------- torchattack/decowa.py | 2 +- torchattack/deepfool.py | 2 +- torchattack/difgsm.py | 4 +-- torchattack/eval/runner.py | 24 ++++++++-------- torchattack/mifgsm.py | 2 +- torchattack/pgd.py | 2 +- torchattack/ssa.py | 2 +- torchattack/ssp.py | 2 +- 13 files changed, 79 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index c425471..c3c58fd 100644 --- a/README.md +++ b/README.md @@ -80,8 +80,8 @@ attack = create_attack('FGSM', model, normalize, device) attack = create_attack('PGD', model, normalize, device, eps=0.03) # Initialize MI-FGSM attack with extra args with create_attack -attack_cfg = {'steps': 10, 'decay': 1.0} -attack = create_attack('MIFGSM', model, normalize, device, eps=0.03, attack_cfg=attack_cfg) +attack_args = {'steps': 10, 'decay': 1.0} +attack = create_attack('MIFGSM', model, normalize, device, eps=0.03, attack_args=attack_args) ``` Check out [`torchattack.eval.runner`](torchattack/eval/runner.py) for a full example. diff --git a/tests/test_create_attack.py b/tests/test_create_attack.py index bb7e593..d779dd3 100644 --- a/tests/test_create_attack.py +++ b/tests/test_create_attack.py @@ -41,37 +41,37 @@ def test_create_generative_attack_same_as_imported(attack_name, expected): def test_create_attack_with_eps(device, resnet50_model): eps = 0.3 - attack_cfg = {} + attack_args = {} attacker = create_attack( attack='FGSM', model=resnet50_model, normalize=resnet50_model.normalize, device=device, eps=eps, - attack_cfg=attack_cfg, + attack_args=attack_args, ) assert attacker.eps == eps -def test_create_attack_with_attack_cfg_eps(device, resnet50_model): - attack_cfg = {'eps': 0.1} +def test_create_attack_with_attack_args_eps(device, resnet50_model): + attack_args = {'eps': 0.1} attacker = create_attack( attack='FGSM', model=resnet50_model, normalize=resnet50_model.normalize, device=device, - attack_cfg=attack_cfg, + attack_args=attack_args, ) - assert attacker.eps == attack_cfg['eps'] + assert attacker.eps == attack_args['eps'] -def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model): +def test_create_attack_with_both_eps_and_attack_args(device, resnet50_model): eps = 0.3 - attack_cfg = {'eps': 0.1} + attack_args = {'eps': 0.1} with pytest.warns( UserWarning, match="The 'eps' value provided as an argument will overwrite the existing " - "'eps' value in 'attack_cfg'. This MAY NOT be the intended behavior.", + "'eps' value in 'attack_args'. This MAY NOT be the intended behavior.", ): attacker = create_attack( attack='FGSM', @@ -79,7 +79,7 @@ def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model): normalize=resnet50_model.normalize, device=device, eps=eps, - attack_cfg=attack_cfg, + attack_args=attack_args, ) assert attacker.eps == eps @@ -102,7 +102,7 @@ def test_create_attack_with_invalid_eps(device, resnet50_model): def test_create_attack_with_weights_and_checkpoint_path(device): weights = 'VGG19_IMAGENET1K' checkpoint_path = 'path/to/checkpoint' - attack_cfg = {} + attack_args = {} with pytest.warns( UserWarning, match="argument 'weights' and 'checkpoint_path' are only used for " @@ -113,7 +113,7 @@ def test_create_attack_with_weights_and_checkpoint_path(device): device=device, weights=weights, checkpoint_path=checkpoint_path, - attack_cfg=attack_cfg, + attack_args=attack_args, ) assert 'weights' not in attacker.__dict__ assert 'checkpoint_path' not in attacker.__dict__ diff --git a/torchattack/bia.py b/torchattack/bia.py index fa6a899..7b7e7a0 100644 --- a/torchattack/bia.py +++ b/torchattack/bia.py @@ -85,7 +85,9 @@ def __init__( # Prioritize checkpoint path over provided weights enum if self.checkpoint_path is not None: - self.generator.load_state_dict(torch.load(self.checkpoint_path)) + self.generator.load_state_dict( + torch.load(self.checkpoint_path, weights_only=True) + ) else: # Verify and load weights from enum if checkpoint path is not provided self.weights: BIAWeights = BIAWeights.verify(weights) @@ -117,7 +119,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: run_attack( attack=BIA, - attack_cfg={'eps': 10 / 255, 'weights': 'VGG16_DA'}, - model_name='resnet50', + attack_args={'eps': 10 / 255, 'weights': 'VGG16_DA'}, + model_name='vgg16', victim_model_names=['resnet152'], ) diff --git a/torchattack/cda.py b/torchattack/cda.py index 4dfd205..b956a63 100644 --- a/torchattack/cda.py +++ b/torchattack/cda.py @@ -61,7 +61,9 @@ def __init__( # Prioritize checkpoint path over provided weights enum if self.checkpoint_path is not None: - self.generator.load_state_dict(torch.load(self.checkpoint_path)) + self.generator.load_state_dict( + torch.load(self.checkpoint_path, weights_only=True) + ) else: # Verify and load weights from enum if checkpoint path is not provided self.weights: CDAWeights = CDAWeights.verify(weights) @@ -93,7 +95,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: run_attack( attack=CDA, - attack_cfg={'eps': 10 / 255, 'weights': 'VGG16_IMAGENET1K'}, - model_name='resnet50', + attack_args={'eps': 10 / 255, 'weights': 'VGG19_IMAGENET1K'}, + model_name='vgg19', victim_model_names=['resnet152'], ) diff --git a/torchattack/create_attack.py b/torchattack/create_attack.py index d60900c..ca078e4 100644 --- a/torchattack/create_attack.py +++ b/torchattack/create_attack.py @@ -22,7 +22,7 @@ def create_attack( eps: float | None = None, weights: str | None = 'DEFAULT', checkpoint_path: str | None = None, - attack_cfg: dict[str, Any] | None = None, + attack_args: dict[str, Any] | None = None, ) -> Attack: """Create a torchattack instance based on the provided attack name and config. @@ -34,7 +34,7 @@ def create_attack( eps: The epsilon value for the attack. Defaults to None. weights: The name of the generator weight for generative attacks. Defaults to 'DEFAULT'. checkpoint_path: The path to the checkpoint file for generative attacks. Defaults to None. - attack_cfg: Additional config parameters for the attack. Defaults to None. + attack_args: Additional config parameters for the attack. Defaults to None. Returns: Attack: An instance of the specified attack. @@ -43,15 +43,15 @@ def create_attack( ValueError: If the specified attack name is not supported within torchattack. Notes: - - If `eps` is provided and also present in `attack_cfg`, a warning will be - issued and the value in `attack_cfg` will be overwritten. + - If `eps` is provided and also present in `attack_args`, a warning will be + issued and the value in `attack_args` will be overwritten. - For certain attacks like 'GeoDA' and 'DeepFool', the `eps` parameter is - invalid and will be ignored if present in `attack_cfg`. + invalid and will be ignored if present in `attack_args`. - The `weights` and `checkpoint_path` parameters are only used for generative attacks. """ - if attack_cfg is None: - attack_cfg = {} + if attack_args is None: + attack_args = {} # Check if attack_name is supported attack_name = attack if isinstance(attack, str) else attack.__name__ @@ -60,19 +60,19 @@ def create_attack( if not isinstance(attack, Attack): attack = getattr(torchattack, attack_name) - # Check if eps is provided and overwrite the value in attack_cfg if present + # Check if eps is provided and overwrite the value in attack_args if present if eps is not None: - if 'eps' in attack_cfg: + if 'eps' in attack_args: attack_warn( "The 'eps' value provided as an argument will overwrite the existing " - "'eps' value in 'attack_cfg'. This MAY NOT be the intended behavior." + "'eps' value in 'attack_args'. This MAY NOT be the intended behavior." ) - attack_cfg['eps'] = eps + attack_args['eps'] = eps - # Check if attacks that do not require eps have eps in attack_cfg - if attack_name in torchattack.NON_EPS_ATTACKS and 'eps' in attack_cfg: + # Check if attacks that do not require eps have eps in attack_args + if attack_name in torchattack.NON_EPS_ATTACKS and 'eps' in attack_args: attack_warn(f"argument 'eps' is invalid in {attack_name} and will be ignored.") - attack_cfg.pop('eps', None) + attack_args.pop('eps', None) # Check if non-generative attacks have weights or checkpoint_path if attack_name not in torchattack.GENERATIVE_ATTACKS and ( @@ -82,16 +82,30 @@ def create_attack( f"argument 'weights' and 'checkpoint_path' are only used for " f"generative attacks, and will be ignored for '{attack_name}'." ) - attack_cfg.pop('weights', None) - attack_cfg.pop('checkpoint_path', None) + attack_args.pop('weights', None) + attack_args.pop('checkpoint_path', None) # Special handling for generative attacks if attack_name in torchattack.GENERATIVE_ATTACKS: - attack_cfg['weights'] = weights - attack_cfg['checkpoint_path'] = checkpoint_path - return attack(device, **attack_cfg) - - return attack(model, normalize, device, **attack_cfg) + if weights != 'DEFAULT': + if 'weights' in attack_args: + attack_warn( + "The 'weights' value provided as an argument will " + "overwrite the existing 'weights' value in 'attack_args'. " + 'This MAY NOT be the intended behavior.' + ) + attack_args['weights'] = weights + if checkpoint_path is not None: + if 'checkpoint_path' in attack_args: + attack_warn( + "The 'checkpoint_path' value provided as an argument will " + "overwrite the existing 'checkpoint_path' value in 'attack_args'. " + 'This MAY NOT be the intended behavior.' + ) + attack_args['checkpoint_path'] = checkpoint_path + return attack(device, **attack_args) + + return attack(model, normalize, device, **attack_args) if __name__ == '__main__': diff --git a/torchattack/decowa.py b/torchattack/decowa.py index c49812e..8f9085c 100644 --- a/torchattack/decowa.py +++ b/torchattack/decowa.py @@ -306,7 +306,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run_attack( attack=DeCoWA, - attack_cfg={'eps': 16 / 255, 'steps': 10}, + attack_args={'eps': 16 / 255, 'steps': 10}, model_name='resnet50', victim_model_names=['resnet18', 'vit_b_16'], ) diff --git a/torchattack/deepfool.py b/torchattack/deepfool.py index b3640f8..678a462 100644 --- a/torchattack/deepfool.py +++ b/torchattack/deepfool.py @@ -180,6 +180,6 @@ def _atleast_kd(self, x: torch.Tensor, k: int) -> torch.Tensor: run_attack( DeepFool, - attack_cfg={'steps': 50, 'overshoot': 0.02}, + attack_args={'steps': 50, 'overshoot': 0.02}, model_name='resnet152', ) diff --git a/torchattack/difgsm.py b/torchattack/difgsm.py index 774f33d..1b377f5 100644 --- a/torchattack/difgsm.py +++ b/torchattack/difgsm.py @@ -164,5 +164,5 @@ def input_diversity( if __name__ == '__main__': from torchattack.eval import run_attack - cfgs = {'eps': 16 / 255, 'steps': 10, 'resize_rate': 0.9, 'diversity_prob': 1.0} - run_attack(DIFGSM, attack_cfg=cfgs) + args = {'eps': 16 / 255, 'steps': 10, 'resize_rate': 0.9, 'diversity_prob': 1.0} + run_attack(DIFGSM, attack_args=args) diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index 6a6f0ce..8260f4e 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -3,7 +3,7 @@ def run_attack( attack: Any, - attack_cfg: dict | None = None, + attack_args: dict | None = None, model_name: str = 'resnet50', victim_model_names: list[str] | None = None, dataset_root: str = 'datasets/nips2017', @@ -14,12 +14,12 @@ def run_attack( Example: >>> from torchattack import FGSM - >>> cfg = {"eps": 8 / 255, "clip_min": 0.0, "clip_max": 1.0} - >>> run_attack(attack=FGSM, attack_cfg=cfg) + >>> args = {"eps": 8 / 255, "clip_min": 0.0, "clip_max": 1.0} + >>> run_attack(attack=FGSM, attack_args=args) Args: attack: The attack class to initialize, either by name or class instance. - attack_cfg: A dict of keyword arguments passed to the attack class. + attack_args: A dict of keyword arguments passed to the attack class. model_name: The surrogate model to attack. Defaults to "resnet50". victim_model_names: A list of the victim black-box models to attack. Defaults to None. dataset_root: Root directory of the dataset. Defaults to "datasets/nips2017". @@ -35,8 +35,8 @@ def run_attack( from torchattack.eval.dataset import NIPSLoader from torchattack.eval.metric import FoolingRateMetric - if attack_cfg is None: - attack_cfg = {} + if attack_args is None: + attack_args = {} # Setup model device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -54,7 +54,7 @@ def run_attack( # Set up attack and trackers frm = FoolingRateMetric() - attacker = create_attack(attack, model, normalize, device, attack_cfg=attack_cfg) + attacker = create_attack(attack, model, normalize, device, attack_args=attack_args) print(attacker) # Setup victim models if provided @@ -119,16 +119,16 @@ def run_attack( parser.add_argument('--batch-size', type=int, default=4) args = parser.parse_args() - attack_cfg = {} + attack_args = {} if args.attack not in torchattack.NON_EPS_ATTACKS: # type: ignore - attack_cfg['eps'] = args.eps + attack_args['eps'] = args.eps if args.attack in torchattack.GENERATIVE_ATTACKS: # type: ignore - attack_cfg['weights'] = args.weights - attack_cfg['checkpoint_path'] = args.checkpoint + attack_args['weights'] = args.weights + attack_args['checkpoint_path'] = args.checkpoint run_attack( attack=args.attack, - attack_cfg=attack_cfg, + attack_args=attack_args, model_name=args.model_name, victim_model_names=args.victim_model_names, dataset_root=args.dataset_root, diff --git a/torchattack/mifgsm.py b/torchattack/mifgsm.py index 356bd32..42f06d7 100644 --- a/torchattack/mifgsm.py +++ b/torchattack/mifgsm.py @@ -105,7 +105,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run_attack( attack=MIFGSM, - attack_cfg={'eps': 8 / 255, 'steps': 10}, + attack_args={'eps': 8 / 255, 'steps': 10}, model_name='resnet18', victim_model_names=['resnet50', 'vgg13', 'densenet121'], ) diff --git a/torchattack/pgd.py b/torchattack/pgd.py index aa82651..2da8e9c 100644 --- a/torchattack/pgd.py +++ b/torchattack/pgd.py @@ -108,7 +108,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run_attack( attack=PGD, - attack_cfg={'eps': 8 / 255, 'steps': 20, 'random_start': True}, + attack_args={'eps': 8 / 255, 'steps': 20, 'random_start': True}, model_name='resnet18', victim_model_names=['resnet50', 'vgg13', 'densenet121'], ) diff --git a/torchattack/ssa.py b/torchattack/ssa.py index 2eb17d3..6cd7baf 100644 --- a/torchattack/ssa.py +++ b/torchattack/ssa.py @@ -252,7 +252,7 @@ def _idct_2d(self, mat_x: torch.Tensor, norm: str | None = None) -> torch.Tensor run_attack( attack=SSA, - attack_cfg={'eps': 8 / 255, 'steps': 10}, + attack_args={'eps': 8 / 255, 'steps': 10}, model_name='resnet18', victim_model_names=['resnet50', 'vgg13', 'densenet121'], ) diff --git a/torchattack/ssp.py b/torchattack/ssp.py index e163ee6..932f7ef 100644 --- a/torchattack/ssp.py +++ b/torchattack/ssp.py @@ -114,4 +114,4 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: if __name__ == '__main__': from torchattack.eval import run_attack - run_attack(SSP, attack_cfg={'eps': 16 / 255, 'ssp_layer': 16}) + run_attack(SSP, attack_args={'eps': 16 / 255, 'ssp_layer': 16}) From 3f148f972eb51f531ce92865a3793bd5da1aa1eb Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 27 Nov 2024 15:51:47 +0800 Subject: [PATCH 9/9] Minor fix and test adjustments --- tests/test_attacks.py | 69 ++++++++---------------------- tests/test_create_attack.py | 18 ++++++++ torchattack/eval/runner.py | 12 +++++- torchattack/generative/_weights.py | 4 ++ 4 files changed, 50 insertions(+), 53 deletions(-) diff --git a/tests/test_attacks.py b/tests/test_attacks.py index caa34c2..b01f878 100644 --- a/tests/test_attacks.py +++ b/tests/test_attacks.py @@ -1,33 +1,13 @@ import pytest -from torchattack import ( - DIFGSM, - FGSM, - FIA, - MIFGSM, - NIFGSM, - PGD, - PGDL2, - SINIFGSM, - SSA, - SSP, - TGR, - TIFGSM, - VDC, - VMIFGSM, - VNIFGSM, - Admix, - DeCoWA, - DeepFool, - GeoDA, - PNAPatchOut, -) +import torchattack from torchattack.attack_model import AttackModel def run_attack_test(attack_cls, device, model, x, y): normalize = model.normalize - attacker = attack_cls(model, normalize, device=device) + # attacker = attack_cls(model, normalize, device=device) + attacker = torchattack.create_attack(attack_cls, model, normalize, device=device) x, y = x.to(device), y.to(device) x_adv = attacker(x, y) x_outs, x_adv_outs = model(normalize(x)), model(normalize(x_adv)) @@ -37,44 +17,31 @@ def run_attack_test(attack_cls, device, model, x, y): @pytest.mark.parametrize( 'attack_cls', - [ - DIFGSM, - FGSM, - FIA, - MIFGSM, - NIFGSM, - PGD, - PGDL2, - SINIFGSM, - SSA, - SSP, - TIFGSM, - VMIFGSM, - VNIFGSM, - Admix, - DeCoWA, - DeepFool, - GeoDA, - ], + (torchattack.GRADIENT_NON_VIT_ATTACKS | torchattack.NON_EPS_ATTACKS).keys(), ) -def test_cnn_attacks(attack_cls, device, resnet50_model, data): +def test_common_attacks(attack_cls, device, resnet50_model, data): x, y = data(resnet50_model.transform) run_attack_test(attack_cls, device, resnet50_model, x, y) @pytest.mark.parametrize( 'attack_cls', - [ - TGR, - VDC, - PNAPatchOut, - ], + torchattack.GRADIENT_VIT_ATTACKS.keys(), ) -def test_vit_attacks(attack_cls, device, vitb16_model, data): +def test_gradient_vit_attacks(attack_cls, device, vitb16_model, data): x, y = data(vitb16_model.transform) run_attack_test(attack_cls, device, vitb16_model, x, y) +@pytest.mark.parametrize( + 'attack_cls', + torchattack.GENERATIVE_ATTACKS.keys(), +) +def test_generative_attacks(attack_cls, device, resnet50_model, data): + x, y = data(resnet50_model.transform) + run_attack_test(attack_cls, device, resnet50_model, x, y) + + @pytest.mark.parametrize( 'model_name', [ @@ -87,7 +54,7 @@ def test_vit_attacks(attack_cls, device, vitb16_model, data): def test_tgr_attack_all_supported_models(device, model_name, data): model = AttackModel.from_pretrained(model_name, device, from_timm=True) x, y = data(model.transform) - run_attack_test(TGR, device, model, x, y) + run_attack_test(torchattack.TGR, device, model, x, y) @pytest.mark.parametrize( @@ -101,4 +68,4 @@ def test_tgr_attack_all_supported_models(device, model_name, data): def test_vdc_attack_all_supported_models(device, model_name, data): model = AttackModel.from_pretrained(model_name, device, from_timm=True) x, y = data(model.transform) - run_attack_test(VDC, device, model, x, y) + run_attack_test(torchattack.VDC, device, model, x, y) diff --git a/tests/test_create_attack.py b/tests/test_create_attack.py index d779dd3..6a387e3 100644 --- a/tests/test_create_attack.py +++ b/tests/test_create_attack.py @@ -84,6 +84,24 @@ def test_create_attack_with_both_eps_and_attack_args(device, resnet50_model): assert attacker.eps == eps +def test_create_attack_with_both_weights_and_attack_args(device): + weights = 'VGG19_IMAGENET1K' + attack_args = {'weights': 'VGG19_IMAGENET1K'} + with pytest.warns( + UserWarning, + match="The 'weights' value provided as an argument will " + "overwrite the existing 'weights' value in 'attack_args'. " + 'This MAY NOT be the intended behavior.', + ): + attacker = create_attack( + attack='CDA', + device=device, + weights=weights, + attack_args=attack_args, + ) + assert attacker.weights == weights + + def test_create_attack_with_invalid_eps(device, resnet50_model): eps = 0.3 with pytest.warns( diff --git a/torchattack/eval/runner.py b/torchattack/eval/runner.py index 8260f4e..83f005a 100644 --- a/torchattack/eval/runner.py +++ b/torchattack/eval/runner.py @@ -10,13 +10,20 @@ def run_attack( max_samples: int = 100, batch_size: int = 4, ) -> None: - """Helper function to run attacks in `__main__`. + """Helper function to run evaluation on attacks. Example: >>> from torchattack import FGSM >>> args = {"eps": 8 / 255, "clip_min": 0.0, "clip_max": 1.0} >>> run_attack(attack=FGSM, attack_args=args) + Note: + For generative attacks, the model_name argument that defines the white-box + surrogate model is not required. You can, however, manually provide a model name + to load according to your designated weight for the generator, to recreate a + white-box evaluation scenario, such as using VGG-19 (`model_name='vgg19'`) for + BIA's VGG-19 generator weight (`BIAWeights.VGG19`). + Args: attack: The attack class to initialize, either by name or class instance. attack_args: A dict of keyword arguments passed to the attack class. @@ -120,11 +127,12 @@ def run_attack( args = parser.parse_args() attack_args = {} + args.eps = args.eps / 255 if args.attack not in torchattack.NON_EPS_ATTACKS: # type: ignore attack_args['eps'] = args.eps if args.attack in torchattack.GENERATIVE_ATTACKS: # type: ignore attack_args['weights'] = args.weights - attack_args['checkpoint_path'] = args.checkpoint + attack_args['checkpoint_path'] = args.checkpoint_path run_attack( attack=args.attack, diff --git a/torchattack/generative/_weights.py b/torchattack/generative/_weights.py index 6cd3d57..dcb9b51 100644 --- a/torchattack/generative/_weights.py +++ b/torchattack/generative/_weights.py @@ -29,6 +29,10 @@ def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]: def __repr__(self) -> str: return f'{self.__class__.__name__}.{self._name_}' + def __eq__(self, other: Any) -> bool: + other = self.verify(other) + return isinstance(other, self.__class__) and self.name == other.name + @property def url(self): return self.value.url