feat(attack): add an L2 constrained PGD
spencerwooo committed Feb 2, 2023
1 parent 6d337cd commit 07d4889
Showing 3 changed files with 130 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@ attack = MIFGSM(model, transform, eps=0.03, steps=10, decay=1.0)
| :--------- | :------------------------------------------------------------------------------------------------------------------------- | :--------------------- |
| FGSM | [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572) | `torchattack.FGSM` |
| PGD | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | `torchattack.PGD` |
| PGD(L2) | [Towards Deep Learning Models Resistant to Adversarial Attacks](https://arxiv.org/abs/1706.06083) | `torchattack.PGDL2` |
| MI-FGSM | [Boosting Adversarial Attacks with Momentum](https://arxiv.org/abs/1710.06081) | `torchattack.MIFGSM` |
| DI-FGSM | [Improving Transferability of Adversarial Examples with Input Diversity](https://arxiv.org/abs/1803.06978) | `torchattack.DIFGSM` |
| TI-FGSM | [Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks](https://arxiv.org/abs/1904.02884) | `torchattack.TIFGSM` |
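Usage would presumably mirror the README's existing `MIFGSM` example above. A minimal sketch (the ResNet-50 model, the identity `transform`, and calling the attack object directly are illustrative assumptions, not part of this commit):

```python
import torch
from torchvision.models import resnet50

from torchattack import PGDL2

# Placeholder victim model and normalization transform (identity here).
model = resnet50(weights="DEFAULT").eval()
transform = lambda x: x

# eps is the radius of the L2 ball, not a per-pixel (Linf) bound.
attack = PGDL2(model, transform, eps=1.0, steps=10, random_start=True)

x = torch.rand(4, 3, 224, 224)    # a batch of images in [0, 1]
y = torch.randint(0, 1000, (4,))  # ground-truth labels

# Assumes the Attack base class makes instances callable; if not,
# invoke attack.forward(x, y) directly.
x_adv = attack(x, y)
```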
2 changes: 2 additions & 0 deletions torchattack/__init__.py
@@ -3,6 +3,7 @@
from torchattack.mifgsm import MIFGSM
from torchattack.nifgsm import NIFGSM
from torchattack.pgd import PGD
from torchattack.pgdl2 import PGDL2
from torchattack.sinifgsm import SINIFGSM
from torchattack.tifgsm import TIFGSM
from torchattack.vmifgsm import VMIFGSM
@@ -14,6 +15,7 @@
"MIFGSM",
"NIFGSM",
"PGD",
"PGDL2",
"SINIFGSM",
"TIFGSM",
"VMIFGSM",
127 changes: 127 additions & 0 deletions torchattack/pgdl2.py
@@ -0,0 +1,127 @@
from typing import Callable

import torch
import torch.nn as nn

from torchattack.base import Attack

# Small constant guarding against division by zero when normalizing gradients
EPS_FOR_DIVISION = 1e-12


class PGDL2(Attack):
"""The Projected Gradient Descent (PGD) attack, with L2 constraint.
From the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
https://arxiv.org/abs/1706.06083
"""

def __init__(
self,
model: nn.Module,
transform: Callable[[torch.Tensor], torch.Tensor] | None,
eps: float = 1.0,
steps: int = 10,
alpha: float | None = None,
random_start: bool = True,
clip_min: float = 0.0,
clip_max: float = 1.0,
targeted: bool = False,
device: torch.device | None = None,
) -> None:
"""Initialize the PGD attack.
Args:
model: The model to attack.
transform: A transform to normalize images.
eps: The maximum perturbation, measured in L2. Defaults to 1.0.
steps: Number of steps. Defaults to 10.
            alpha: Step size; if None, defaults to `eps / (steps / 2)`.
random_start: Start from random uniform perturbation. Defaults to True.
clip_min: Minimum value for clipping. Defaults to 0.0.
clip_max: Maximum value for clipping. Defaults to 1.0.
targeted: Targeted attack if True. Defaults to False.
device: Device to use for tensors. Defaults to cuda if available.
"""

super().__init__(transform, device)

self.model = model
self.eps = eps
self.alpha = alpha
self.steps = steps
self.random_start = random_start
self.clip_min = clip_min
self.clip_max = clip_max
self.targeted = targeted
self.lossfn = nn.CrossEntropyLoss()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Perform PGD on a batch of images.
Args:
x: A batch of images. Shape: (N, C, H, W).
y: A batch of labels. Shape: (N).
Returns:
            The perturbed images. Shape: (N, C, H, W).
"""

        # If random start is enabled, initialize delta with a random direction
        # (a normalized Gaussian sample) scaled to a radius drawn uniformly
        # from [0, eps].
if self.random_start:
delta = torch.empty_like(x).normal_()
delta_flat = delta.view(x.size(0), -1)

n = delta_flat.norm(p=2, dim=1).view(x.size(0), 1, 1, 1)
r = torch.zeros_like(n).uniform_(0, 1)

delta = delta * r / n * self.eps
delta.requires_grad_()

else:
delta = torch.zeros_like(x, requires_grad=True)

# If alpha is not given, set to eps / (steps / 2)
if self.alpha is None:
self.alpha = self.eps / (self.steps / 2)

# Perform PGD
for _ in range(self.steps):
# Compute loss
outs = self.model(self.transform(x + delta))
loss = self.lossfn(outs, y)

if self.targeted:
loss = -loss

# Compute gradient
loss.backward()

            # backward() should always populate delta.grad; guard just in case
            if delta.grad is None:
                continue

            # Update delta: take a step of size alpha along the
            # per-sample L2-normalized gradient
g = delta.grad

g_norms = torch.norm(g.view(x.size(0), -1), p=2, dim=1) + EPS_FOR_DIVISION
g = g / g_norms.view(x.size(0), 1, 1, 1)
delta.data = delta.data + self.alpha * g

            # Project delta back onto the L2 ball of radius eps by scaling
            # each sample with min(1, eps / ||delta||_2)
            delta_norms = torch.norm(delta.view(x.size(0), -1), p=2, dim=1)
            factor = self.eps / delta_norms
            factor = torch.min(factor, torch.ones_like(delta_norms))
            delta.data = delta.data * factor.view(-1, 1, 1, 1)

            # Clamp x + delta to the valid pixel range, then recover delta
            delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x

# Zero out gradient
delta.grad.detach_()
delta.grad.zero_()

return x + delta


if __name__ == "__main__":
from torchattack.utils import run_attack

run_attack(PGDL2, {"eps": 1.0, "steps": 10, "random_start": False})
