Skip to content

Commit

Permalink
Merge pull request #20 from spencerwooo:generative
Browse files Browse the repository at this point in the history
Support for generative attacks
  • Loading branch information
spencerwooo authored Nov 27, 2024
2 parents 6e3838d + 3f148f9 commit 0f6f023
Show file tree
Hide file tree
Showing 19 changed files with 670 additions and 187 deletions.
17 changes: 12 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
<sub><b>Install from GitHub source -</b></sub>

```shell
python -m pip install git+https://github.com/spencerwooo/[email protected].1
python -m pip install git+https://github.com/spencerwooo/[email protected].2
```

<sub><b>Install from Gitee mirror -</b></sub>

```shell
python -m pip install git+https://gitee.com/spencerwoo/[email protected].1
python -m pip install git+https://gitee.com/spencerwoo/[email protected].2
```

## Usage
Expand All @@ -52,7 +52,7 @@ transform, normalize = model.transform, model.normalize
# Additionally, to explicitly specify where to load the pretrained model from (timm or torchvision),
# prepend the model name with 'timm/' or 'tv/' respectively, or use the `from_timm` argument, e.g.
vit_b16 = AttackModel.from_pretrained(model_name='timm/vit_base_patch16_224', device=device)
inception_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device)
inv_v3 = AttackModel.from_pretrained(model_name='tv/inception_v3', device=device)
pit_b = AttackModel.from_pretrained(model_name='pit_b_224', device=device, from_timm=True)
```

Expand Down Expand Up @@ -80,8 +80,8 @@ attack = create_attack('FGSM', model, normalize, device)
attack = create_attack('PGD', model, normalize, device, eps=0.03)

# Initialize MI-FGSM attack with extra args with create_attack
attack_cfg = {'steps': 10, 'decay': 1.0}
attack = create_attack('MIFGSM', model, normalize, device, eps=0.03, attack_cfg=attack_cfg)
attack_args = {'steps': 10, 'decay': 1.0}
attack = create_attack('MIFGSM', model, normalize, device, eps=0.03, attack_args=attack_args)
```

Check out [`torchattack.eval.runner`](torchattack/eval/runner.py) for a full example.
Expand Down Expand Up @@ -110,6 +110,13 @@ Gradient-based attacks:
| DeCoWA | $\ell_\infty$ | AAAI 2024 | [Boosting Adversarial Transferability across Model Genus by Deformation-Constrained Warping](https://arxiv.org/abs/2402.03951) | `DeCoWA` |
| VDC | $\ell_\infty$ | AAAI 2024 | [Improving the Adversarial Transferability of Vision Transformers with Virtual Dense Connection](https://ojs.aaai.org/index.php/AAAI/article/view/28541) | `VDC` |

Generative attacks:

| Name | $\ell_p$ | Publication | Paper (Open Access) | Class Name |
| :--: | :-----------: | :----------: | ----------------------------------------------------------------------------------------------------------------------- | ---------- |
| CDA | $\ell_\infty$ | NeurIPS 2019 | [CDA: contrastive Divergence for Adversarial Attack](https://arxiv.org/abs/1905.11736) | `CDA` |
| BIA | $\ell_\infty$ | ICLR 2022 | [Beyond ImageNet Attack: Towards Crafting Adversarial Examples for Black-box Domains](https://arxiv.org/abs/2201.11528) | `BIA` |

Others:

| Name | $\ell_p$ | Publication | Paper (Open Access) | Class Name |
Expand Down
69 changes: 18 additions & 51 deletions tests/test_attacks.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,13 @@
import pytest

from torchattack import (
DIFGSM,
FGSM,
FIA,
MIFGSM,
NIFGSM,
PGD,
PGDL2,
SINIFGSM,
SSA,
SSP,
TGR,
TIFGSM,
VDC,
VMIFGSM,
VNIFGSM,
Admix,
DeCoWA,
DeepFool,
GeoDA,
PNAPatchOut,
)
import torchattack
from torchattack.attack_model import AttackModel


def run_attack_test(attack_cls, device, model, x, y):
normalize = model.normalize
attacker = attack_cls(model, normalize, device=device)
# attacker = attack_cls(model, normalize, device=device)
attacker = torchattack.create_attack(attack_cls, model, normalize, device=device)
x, y = x.to(device), y.to(device)
x_adv = attacker(x, y)
x_outs, x_adv_outs = model(normalize(x)), model(normalize(x_adv))
Expand All @@ -37,44 +17,31 @@ def run_attack_test(attack_cls, device, model, x, y):

@pytest.mark.parametrize(
'attack_cls',
[
DIFGSM,
FGSM,
FIA,
MIFGSM,
NIFGSM,
PGD,
PGDL2,
SINIFGSM,
SSA,
SSP,
TIFGSM,
VMIFGSM,
VNIFGSM,
Admix,
DeCoWA,
DeepFool,
GeoDA,
],
(torchattack.GRADIENT_NON_VIT_ATTACKS | torchattack.NON_EPS_ATTACKS).keys(),
)
def test_cnn_attacks(attack_cls, device, resnet50_model, data):
def test_common_attacks(attack_cls, device, resnet50_model, data):
x, y = data(resnet50_model.transform)
run_attack_test(attack_cls, device, resnet50_model, x, y)


@pytest.mark.parametrize(
'attack_cls',
[
TGR,
VDC,
PNAPatchOut,
],
torchattack.GRADIENT_VIT_ATTACKS.keys(),
)
def test_vit_attacks(attack_cls, device, vitb16_model, data):
def test_gradient_vit_attacks(attack_cls, device, vitb16_model, data):
x, y = data(vitb16_model.transform)
run_attack_test(attack_cls, device, vitb16_model, x, y)


@pytest.mark.parametrize(
'attack_cls',
torchattack.GENERATIVE_ATTACKS.keys(),
)
def test_generative_attacks(attack_cls, device, resnet50_model, data):
x, y = data(resnet50_model.transform)
run_attack_test(attack_cls, device, resnet50_model, x, y)


@pytest.mark.parametrize(
'model_name',
[
Expand All @@ -87,7 +54,7 @@ def test_vit_attacks(attack_cls, device, vitb16_model, data):
def test_tgr_attack_all_supported_models(device, model_name, data):
model = AttackModel.from_pretrained(model_name, device, from_timm=True)
x, y = data(model.transform)
run_attack_test(TGR, device, model, x, y)
run_attack_test(torchattack.TGR, device, model, x, y)


@pytest.mark.parametrize(
Expand All @@ -101,4 +68,4 @@ def test_tgr_attack_all_supported_models(device, model_name, data):
def test_vdc_attack_all_supported_models(device, model_name, data):
model = AttackModel.from_pretrained(model_name, device, from_timm=True)
x, y = data(model.transform)
run_attack_test(VDC, device, model, x, y)
run_attack_test(torchattack.VDC, device, model, x, y)
158 changes: 82 additions & 76 deletions tests/test_create_attack.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,12 @@
import pytest

from torchattack import (
DIFGSM,
FGSM,
FIA,
MIFGSM,
NIFGSM,
PGD,
PGDL2,
SINIFGSM,
SSA,
SSP,
TGR,
TIFGSM,
VDC,
VMIFGSM,
VNIFGSM,
Admix,
DeCoWA,
DeepFool,
GeoDA,
PNAPatchOut,
create_attack,
)
import torchattack
from torchattack import create_attack


expected_non_vit_attacks = {
'DIFGSM': DIFGSM,
'FGSM': FGSM,
'FIA': FIA,
'MIFGSM': MIFGSM,
'NIFGSM': NIFGSM,
'PGD': PGD,
'PGDL2': PGDL2,
'SINIFGSM': SINIFGSM,
'SSA': SSA,
'SSP': SSP,
'TIFGSM': TIFGSM,
'VMIFGSM': VMIFGSM,
'VNIFGSM': VNIFGSM,
'Admix': Admix,
'DeCoWA': DeCoWA,
'DeepFool': DeepFool,
'GeoDA': GeoDA,
}
expected_vit_attacks = {
'TGR': TGR,
'VDC': VDC,
'PNAPatchOut': PNAPatchOut,
}


@pytest.mark.parametrize(('attack_name', 'expected'), expected_non_vit_attacks.items())
@pytest.mark.parametrize(
('attack_name', 'expected'), torchattack.GRADIENT_NON_VIT_ATTACKS.items()
)
def test_create_non_vit_attack_same_as_imported(
attack_name,
expected,
Expand All @@ -61,7 +17,9 @@ def test_create_non_vit_attack_same_as_imported(
assert created_attacker == expected_attacker


@pytest.mark.parametrize(('attack_name', 'expected'), expected_vit_attacks.items())
@pytest.mark.parametrize(
('attack_name', 'expected'), torchattack.GRADIENT_VIT_ATTACKS.items()
)
def test_create_vit_attack_same_as_imported(
attack_name,
expected,
Expand All @@ -72,57 +30,85 @@ def test_create_vit_attack_same_as_imported(
assert created_attacker == expected_attacker


@pytest.mark.parametrize(
('attack_name', 'expected'), torchattack.GENERATIVE_ATTACKS.items()
)
def test_create_generative_attack_same_as_imported(attack_name, expected):
created_attacker = create_attack(attack_name)
expected_attacker = expected()
assert created_attacker == expected_attacker


def test_create_attack_with_eps(device, resnet50_model):
eps = 0.3
attack_cfg = {}
attack_args = {}
attacker = create_attack(
attack_name='FGSM',
attack='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
attack_cfg=attack_cfg,
attack_args=attack_args,
)
assert attacker.eps == eps


def test_create_attack_with_attack_cfg_eps(device, resnet50_model):
attack_cfg = {'eps': 0.1}
def test_create_attack_with_attack_args_eps(device, resnet50_model):
attack_args = {'eps': 0.1}
attacker = create_attack(
attack_name='FGSM',
attack='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
attack_cfg=attack_cfg,
attack_args=attack_args,
)
assert attacker.eps == attack_cfg['eps']
assert attacker.eps == attack_args['eps']


def test_create_attack_with_both_eps_and_attack_cfg(device, resnet50_model):
def test_create_attack_with_both_eps_and_attack_args(device, resnet50_model):
eps = 0.3
attack_cfg = {'eps': 0.1}
# with pytest.warns(
# UserWarning,
# match="'eps' in 'attack_cfg' (0.1) will be overwritten by the 'eps' argument value (0.3), which MAY NOT be intended.",
# ):
attacker = create_attack(
attack_name='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
attack_cfg=attack_cfg,
)
attack_args = {'eps': 0.1}
with pytest.warns(
UserWarning,
match="The 'eps' value provided as an argument will overwrite the existing "
"'eps' value in 'attack_args'. This MAY NOT be the intended behavior.",
):
attacker = create_attack(
attack='FGSM',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
eps=eps,
attack_args=attack_args,
)
assert attacker.eps == eps


def test_create_attack_with_both_weights_and_attack_args(device):
weights = 'VGG19_IMAGENET1K'
attack_args = {'weights': 'VGG19_IMAGENET1K'}
with pytest.warns(
UserWarning,
match="The 'weights' value provided as an argument will "
"overwrite the existing 'weights' value in 'attack_args'. "
'This MAY NOT be the intended behavior.',
):
attacker = create_attack(
attack='CDA',
device=device,
weights=weights,
attack_args=attack_args,
)
assert attacker.weights == weights


def test_create_attack_with_invalid_eps(device, resnet50_model):
eps = 0.3
with pytest.warns(
UserWarning, match="parameter 'eps' is invalid in DeepFool and will be ignored."
UserWarning, match="argument 'eps' is invalid in DeepFool and will be ignored."
):
attacker = create_attack(
attack_name='DeepFool',
attack='DeepFool',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
Expand All @@ -131,12 +117,32 @@ def test_create_attack_with_invalid_eps(device, resnet50_model):
assert 'eps' not in attacker.__dict__


def test_create_attack_with_weights_and_checkpoint_path(device):
weights = 'VGG19_IMAGENET1K'
checkpoint_path = 'path/to/checkpoint'
attack_args = {}
with pytest.warns(
UserWarning,
match="argument 'weights' and 'checkpoint_path' are only used for "
"generative attacks, and will be ignored for 'FGSM'.",
):
attacker = create_attack(
attack='FGSM',
device=device,
weights=weights,
checkpoint_path=checkpoint_path,
attack_args=attack_args,
)
assert 'weights' not in attacker.__dict__
assert 'checkpoint_path' not in attacker.__dict__


def test_create_attack_with_invalid_attack_name(device, resnet50_model):
with pytest.raises(
ValueError, match="Attack 'InvalidAttack' is not supported within torchattack."
):
create_attack(
attack_name='InvalidAttack',
attack='InvalidAttack',
model=resnet50_model,
normalize=resnet50_model.normalize,
device=device,
Expand Down
Loading

0 comments on commit 0f6f023

Please sign in to comment.