From 7e2af742c14a78c27f51814c1affe8615de5e2b2 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Sat, 14 Sep 2024 16:51:21 +0800 Subject: [PATCH 1/6] Initialize attack implementation --- src/torchattack/decowa.py | 70 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 src/torchattack/decowa.py diff --git a/src/torchattack/decowa.py b/src/torchattack/decowa.py new file mode 100644 index 0000000..5a2fa33 --- /dev/null +++ b/src/torchattack/decowa.py @@ -0,0 +1,70 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from torchattack.base import Attack + + +class DeCoWA(Attack): + """The DeCoWA (Deformation-Constrained Warping Attack) attack. + + From the paper 'Boosting Adversarial Transferability across Model Genus by + Deformation-Constrained Warping', + https://arxiv.org/abs/2402.03951 + + Args: + model: The model to attack. + normalize: A transform to normalize images. + device: Device to use for tensors. Defaults to cuda if available. + eps: The maximum perturbation. Defaults to 8/255. + steps: Number of steps. Defaults to 10. + alpha: Step size, `eps / steps` if None. Defaults to None. + decay: Decay factor for the momentum term. Defaults to 1.0. + mesh_width: Width of the control points. Defaults to 3. + mesh_height: Height of the control points. Defaults to 3. + rho: Regularization parameter for deformation. Defaults to 0.01. + num_warping: Number of warping transformation samples. Defaults to 20. + noise_scale: Scale of the random noise. Defaults to 2. + clip_min: Minimum value for clipping. Defaults to 0.0. + clip_max: Maximum value for clipping. Defaults to 1.0. + targeted: Targeted attack if True. Defaults to False. 
+ """ + + def __init__( + self, + model: nn.Module, + normalize: Callable[[torch.Tensor], torch.Tensor] | None, + device: torch.device | None = None, + eps: float = 8 / 255, + steps: int = 10, + alpha: float | None = None, + decay: float = 1.0, + mesh_width: int = 3, + mesh_height: int = 3, + rho: float = 0.01, + num_warping: int = 20, + noise_scale: int = 2, + clip_min: float = 0.0, + clip_max: float = 1.0, + targeted: bool = False, + ): + super().__init__(normalize, device) + + self.model = model + self.eps = eps + self.steps = steps + self.alpha = alpha + self.decay = decay + self.mesh_width = mesh_width + self.mesh_height = mesh_height + self.rho = rho + self.num_warping = num_warping + self.noise_scale = noise_scale + self.clip_min = clip_min + self.clip_max = clip_max + self.targeted = targeted + self.lossfn = nn.CrossEntropyLoss() + + def forward(): + pass From b12f97d39b7e146d210d12cf8c4933ad6beff3e0 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Mon, 16 Sep 2024 12:13:02 +0800 Subject: [PATCH 2/6] Add DeCoWA related function impls --- src/torchattack/decowa.py | 162 +++++++++++++++++++++++++++++++++++++- 1 file changed, 160 insertions(+), 2 deletions(-) diff --git a/src/torchattack/decowa.py b/src/torchattack/decowa.py index 5a2fa33..95b2695 100644 --- a/src/torchattack/decowa.py +++ b/src/torchattack/decowa.py @@ -66,5 +66,163 @@ def __init__( self.targeted = targeted self.lossfn = nn.CrossEntropyLoss() - def forward(): - pass + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Perform DeCoWA on a batch of images. + + Args: + x: A batch of images. Shape: (N, C, H, W). + y: A batch of labels. Shape: (N). + + Returns: + The perturbed images if successful. Shape: (N, C, H, W). 
+ """ + + delta = torch.zeros_like(x, requires_grad=True) + + # If alpha is not given, set to eps / steps + if self.alpha is None: + self.alpha = self.eps / self.steps + + for _ in range(self.steps): + g = torch.zeros_like(x) + + # Perform warping + for _ in range(self.num_warping): + # Apply warping to perturbation + noise_map_hat = self._update_noise_map(x + delta, y) + vwt_x = self._vwt(x + delta, noise_map_hat) + + # Compute loss + outs = self.model(self.normalize(vwt_x)) + loss = self.lossfn(outs, y) + + if self.targeted: + loss = -loss + + # Compute gradient + loss.backward() + + if delta.grad is None: + continue + + # Accumulate gradient + g += delta.grad + + # Average gradient over warping + g /= self.num_warping + + # Apply momentum term + g = self.decay * g + delta.grad / torch.mean( + torch.abs(delta.grad), dim=(1, 2, 3), keepdim=True + ) + + # Update delta + delta.data = delta.data + self.alpha * g.sign() + delta.data = torch.clamp(delta.data, -self.eps, self.eps) + delta.data = torch.clamp(x + delta.data, self.clip_min, self.clip_max) - x + + # Zero out gradient + delta.grad.detach_() + delta.grad.zero_() + + return x + delta + + def _update_noise_map(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + noise_map = torch.rand([self.mesh_height - 2, self.mesh_width - 2, 2]) - 0.5 + noise_map = noise_map * self.noise_scale + + noise_map.requires_grad_() + vwt_x = self._vwt(x, noise_map) + + def _vwt(self, x: torch.Tensor, noise_map_hat: torch.Tensor) -> torch.Tensor: + n, c, w, h = x.size() + grid_x = self._grid_points_2d(w, h, x.device) + + def _grid_points_2d( + self, width: int, height: int, device: torch.device + ) -> torch.Tensor: + x = torch.linspace(-1, 1, width, device=device) + y = torch.linspace(-1, 1, height, device=device) + grid_x, grid_y = torch.meshgrid(x, y) + grid = torch.stack((grid_y, grid_x), dim=-1) + return grid.view(-1, 2) + + def _noisy_grid( + self, width: int, height: int, noise_map: torch.Tensor, 
device: torch.device + ) -> torch.Tensor: + grid = self._grid_points_2d(width, height, device) + mod = torch.zeros((width * height, 2), device=device) + mod[1 : height - 1, 1 : width - 1, :] = noise_map + return grid + mod.reshape(-1, 2) + + +def k_matrix(x, y): + eps = 1e-9 + d2 = torch.pow(x[:, :, None, :] - y[:, None, :, :], 2).sum(-1) + k = d2 * torch.log(d2 + eps) + return k + + +def p_matrix(x): + n, k = x.shape[:2] + p = torch.ones(n, k, 3, device=x.device) + p[:, :, 1:] = x + return p + + +class TPSCoeffs(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + n, k_mat = x.shape[:2] + z_mat = torch.zeros(1, k_mat + 3, 2, device=x.device) + p_mat = torch.ones(n, k_mat, 3, device=x.device) + l_mat = torch.zeros(n, k_mat + 3, k_mat + 3, device=x.device) + k_mat = k_matrix(x, x) + + p_mat[:, :, 1:] = x + z_mat[:, :k_mat, :] = y + l_mat[:, :k_mat, :k_mat] = k_mat + l_mat[:, :k_mat, k_mat:] = p_mat + l_mat[:, k_mat:, :k_mat] = p_mat.permute(0, 2, 1) + + q_mat = torch.linalg.solve(l_mat, z_mat) + return q_mat[:, :k_mat], q_mat[:, k_mat:] + + +class TPS(nn.Module): + def __init__(self, size: tuple = (256, 256), device: torch.device | None = None): + super().__init__() + h, w = size + self.size = size + self.device = device + + self.tps = TPSCoeffs() + + grid = torch.ones(1, h, w, 2, device=device) + grid[:, :, :, 0] = torch.linspace(-1, 1, w, device=device) + grid[:, :, :, 1] = torch.linspace(-1, 1, h, device=device)[..., None] + self.grid = grid.view(-1, h * w, 2) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + h, w = self.size + w, a = self.tps(x, y) + u_mat = k_matrix(self.grid, x) + p_mat = p_matrix(self.grid) + grid = p_mat @ a + u_mat @ w + return grid.view(-1, h, w, 2) + + +if __name__ == '__main__': + from torchattack.eval import run_attack + + run_attack( + attack=DeCoWA, + attack_cfg={'eps': 8 / 255, 'steps': 10}, + 
model_name='resnet18', + victim_model_names=['resnet50', 'vgg13', 'densenet121'], + ) From 7694144aa385530def6131054b2367de17575478 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 18 Sep 2024 14:49:52 +0800 Subject: [PATCH 3/6] Complete implementation of DeCoWA --- src/torchattack/decowa.py | 77 ++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/src/torchattack/decowa.py b/src/torchattack/decowa.py index 95b2695..5273b85 100644 --- a/src/torchattack/decowa.py +++ b/src/torchattack/decowa.py @@ -89,7 +89,28 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Perform warping for _ in range(self.num_warping): # Apply warping to perturbation - noise_map_hat = self._update_noise_map(x + delta, y) + noise_map = torch.rand([self.mesh_height - 2, self.mesh_width - 2, 2]) + noise_map_hat = (noise_map - 0.5) * self.noise_scale + + # Iterate for a round only + for _ in range(1): + noise_map_hat.requires_grad_() + + vwt_x = self._vwt(x + delta, noise_map_hat) + outs = self.model(self.normalize(vwt_x)) + loss = self.lossfn(outs, y) + + if self.targeted: + loss = -loss + + loss.backward() + + if delta.grad is None: + continue + + noise_map_hat.detach_() + noise_map_hat -= self.rho * noise_map_hat.grad + vwt_x = self._vwt(x + delta, noise_map_hat) # Compute loss @@ -127,17 +148,21 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + delta - def _update_noise_map(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - with torch.no_grad(): - noise_map = torch.rand([self.mesh_height - 2, self.mesh_width - 2, 2]) - 0.5 - noise_map = noise_map * self.noise_scale - - noise_map.requires_grad_() - vwt_x = self._vwt(x, noise_map) - - def _vwt(self, x: torch.Tensor, noise_map_hat: torch.Tensor) -> torch.Tensor: + def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor: n, c, w, h = x.size() - grid_x = self._grid_points_2d(w, h, x.device) + xx = 
self._grid_points_2d(self.mesh_width, self.mesh_height, self.device) + yy = self._noisy_grid(self.mesh_width, self.mesh_height, noise_map, self.device) + tpbs = TPS(size=(h, w), device=self.device) + warped_grid_b = tpbs(xx[None, ...], yy[None, ...]) + warped_grid_b = warped_grid_b.repeat(n, 1, 1, 1) + vwt_x = torch.grid_sampler_2d( + input=x, + grid=warped_grid_b, + interpolation_mode=0, + padding_mode=0, + align_corners=False, + ) + return vwt_x def _grid_points_2d( self, width: int, height: int, device: torch.device @@ -152,7 +177,7 @@ def _noisy_grid( self, width: int, height: int, noise_map: torch.Tensor, device: torch.device ) -> torch.Tensor: grid = self._grid_points_2d(width, height, device) - mod = torch.zeros((width * height, 2), device=device) + mod = torch.zeros([height, width, 2], device=device) mod[1 : height - 1, 1 : width - 1, :] = noise_map return grid + mod.reshape(-1, 2) @@ -178,20 +203,20 @@ def __init__(self): def forward( self, x: torch.Tensor, y: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - n, k_mat = x.shape[:2] - z_mat = torch.zeros(1, k_mat + 3, 2, device=x.device) - p_mat = torch.ones(n, k_mat, 3, device=x.device) - l_mat = torch.zeros(n, k_mat + 3, k_mat + 3, device=x.device) + n, k = x.shape[:2] + z_mat = torch.zeros(1, k + 3, 2, device=x.device) + p_mat = torch.ones(n, k, 3, device=x.device) + l_mat = torch.zeros(n, k + 3, k + 3, device=x.device) k_mat = k_matrix(x, x) p_mat[:, :, 1:] = x - z_mat[:, :k_mat, :] = y - l_mat[:, :k_mat, :k_mat] = k_mat - l_mat[:, :k_mat, k_mat:] = p_mat - l_mat[:, k_mat:, :k_mat] = p_mat.permute(0, 2, 1) + z_mat[:, :k, :] = y + l_mat[:, :k, :k] = k_mat + l_mat[:, :k, k:] = p_mat + l_mat[:, k:, :k] = p_mat.permute(0, 2, 1) q_mat = torch.linalg.solve(l_mat, z_mat) - return q_mat[:, :k_mat], q_mat[:, k_mat:] + return q_mat[:, :k], q_mat[:, k:] class TPS(nn.Module): @@ -210,10 +235,10 @@ def __init__(self, size: tuple = (256, 256), device: torch.device | None = None) def forward(self, x: 
torch.Tensor, y: torch.Tensor) -> torch.Tensor: h, w = self.size - w, a = self.tps(x, y) + w_mat, a_mat = self.tps(x, y) u_mat = k_matrix(self.grid, x) p_mat = p_matrix(self.grid) - grid = p_mat @ a + u_mat @ w + grid = p_mat @ a_mat + u_mat @ w_mat return grid.view(-1, h, w, 2) @@ -222,7 +247,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: run_attack( attack=DeCoWA, - attack_cfg={'eps': 8 / 255, 'steps': 10}, - model_name='resnet18', - victim_model_names=['resnet50', 'vgg13', 'densenet121'], + attack_cfg={'eps': 16 / 255, 'steps': 10}, + model_name='resnet50', + victim_model_names=['resnet18', 'vit_b_16'], ) From 08ca1a87cd7a5192b5e1ad7d24c54a419ef43bc8 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 18 Sep 2024 15:01:42 +0800 Subject: [PATCH 4/6] Add DeCoWA attack to README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f8f0f3a..adba1af 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Gradient-based attacks: | FIA | $\ell_\infty$ | [Feature Importance-aware Transferable Adversarial Attacks](https://arxiv.org/abs/2107.14185) | `torchattack.FIA` | | PNA-PatchOut | $\ell_\infty$ | [Towards Transferable Adversarial Attacks on Vision Transformers](https://arxiv.org/abs/2109.04176) | `torchattack.PNAPatchOut` | | TGR | $\ell_\infty$ | [Transferable Adversarial Attacks on Vision Transformers with Token Gradient Regularization](https://arxiv.org/abs/2303.15754) | `torchattack.TGR` | +| DeCoWA | $\ell_\infty$ | [Boosting Adversarial Transferability across Model Genus by Deformation-Constrained Warping](https://arxiv.org/abs/2402.03951) | `torchattack.DeCoWA` | Others: From 158f0049d0d39862c42522e3e13b330007640c19 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 18 Sep 2024 15:20:07 +0800 Subject: [PATCH 5/6] Fix mypy issues and add additional type hints --- src/torchattack/decowa.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git 
a/src/torchattack/decowa.py b/src/torchattack/decowa.py index 5273b85..4f0d041 100644 --- a/src/torchattack/decowa.py +++ b/src/torchattack/decowa.py @@ -105,7 +105,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: loss.backward() - if delta.grad is None: + if noise_map_hat.grad is None: continue noise_map_hat.detach_() @@ -123,12 +123,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # Compute gradient loss.backward() - if delta.grad is None: - continue - # Accumulate gradient g += delta.grad + if delta.grad is None: + continue + # Average gradient over warping g /= self.num_warping @@ -149,12 +149,15 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + delta def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor: - n, c, w, h = x.size() + n, _, w, h = x.size() + xx = self._grid_points_2d(self.mesh_width, self.mesh_height, self.device) yy = self._noisy_grid(self.mesh_width, self.mesh_height, noise_map, self.device) + tpbs = TPS(size=(h, w), device=self.device) warped_grid_b = tpbs(xx[None, ...], yy[None, ...]) warped_grid_b = warped_grid_b.repeat(n, 1, 1, 1) + vwt_x = torch.grid_sampler_2d( input=x, grid=warped_grid_b, @@ -162,6 +165,7 @@ def _vwt(self, x: torch.Tensor, noise_map: torch.Tensor) -> torch.Tensor: padding_mode=0, align_corners=False, ) + return vwt_x def _grid_points_2d( @@ -169,8 +173,10 @@ def _grid_points_2d( ) -> torch.Tensor: x = torch.linspace(-1, 1, width, device=device) y = torch.linspace(-1, 1, height, device=device) + grid_x, grid_y = torch.meshgrid(x, y) grid = torch.stack((grid_y, grid_x), dim=-1) + return grid.view(-1, 2) def _noisy_grid( @@ -182,18 +188,18 @@ def _noisy_grid( return grid + mod.reshape(-1, 2) -def k_matrix(x, y): +def k_matrix(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: eps = 1e-9 d2 = torch.pow(x[:, :, None, :] - y[:, None, :, :], 2).sum(-1) - k = d2 * torch.log(d2 + eps) - return k + k_mat = d2 * torch.log(d2 + eps) 
+ return k_mat -def p_matrix(x): +def p_matrix(x: torch.Tensor) -> torch.Tensor: n, k = x.shape[:2] - p = torch.ones(n, k, 3, device=x.device) - p[:, :, 1:] = x - return p + p_mat = torch.ones(n, k, 3, device=x.device) + p_mat[:, :, 1:] = x + return p_mat class TPSCoeffs(nn.Module): From 4a30516a432459c8d78db2d978505267e9fa7d39 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Wed, 18 Sep 2024 15:22:36 +0800 Subject: [PATCH 6/6] Minor version update --- src/torchattack/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/torchattack/__init__.py b/src/torchattack/__init__.py index 706b9cf..69a17d0 100644 --- a/src/torchattack/__init__.py +++ b/src/torchattack/__init__.py @@ -1,4 +1,5 @@ from torchattack.admix import Admix +from torchattack.decowa import DeCoWA from torchattack.deepfool import DeepFool from torchattack.difgsm import DIFGSM from torchattack.fgsm import FGSM @@ -10,14 +11,17 @@ from torchattack.pgdl2 import PGDL2 from torchattack.pna_patchout import PNAPatchOut from torchattack.sinifgsm import SINIFGSM +from torchattack.ssp import SSP +from torchattack.tgr import TGR from torchattack.tifgsm import TIFGSM from torchattack.vmifgsm import VMIFGSM from torchattack.vnifgsm import VNIFGSM -__version__ = '0.6.0' +__version__ = '0.7.0' __all__ = [ 'Admix', + 'DeCoWA', 'DeepFool', 'DIFGSM', 'FIA', @@ -29,6 +33,8 @@ 'PGDL2', 'PNAPatchOut', 'SINIFGSM', + 'SSP', + 'TGR', 'TIFGSM', 'VMIFGSM', 'VNIFGSM',