diff --git a/MANIFEST.in b/MANIFEST.in
index 2f29c8b18..14ac7a048 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,6 @@
include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp
include basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/src/*.cpp
include basicsr/ops/upfirdn2d/src/*.cu basicsr/ops/upfirdn2d/src/*.cpp
+include basicsr/metrics/niqe_pris_params.npz
include VERSION
include requirements.txt
diff --git a/README.md b/README.md
index d49f2f056..0e48e36fd 100644
--- a/README.md
+++ b/README.md
@@ -65,21 +65,21 @@ Other recommended projects:
We provide simple pipelines to train/test/inference models for a quick start.
These pipelines/commands cannot cover all the cases and more details are in the following sections.
-| GAN | | | | | |
-| :--- | :---: | :---: | :--- | :---: | :---: |
-| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
-| **Face Restoration** | | | | | |
-| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
-| **Super Resolution** | | | | | |
-| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*|
-| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*|
-| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)|
-| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
-| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
-| **Deblurring** | | | | | |
-| DeblurGANv2 | - | *TODO* | | | |
-| **Denoise** | | | | | |
-| RIDNet | - | *TODO* | CBDNet | - | *TODO*|
+| GAN | | | | | |
+| :------------------- | :--------------------------------------------: | :----------------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
+| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
+| **Face Restoration** | | | | | |
+| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
+| **Super Resolution** | | | | | |
+| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
+| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
+| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
+| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
+| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
+| **Deblurring** | | | | | |
+| DeblurGANv2 | - | *TODO* | | | |
+| **Denoise** | | | | | |
+| RIDNet | - | *TODO* | CBDNet | - | *TODO* |
## :wrench: Dependencies and Installation
@@ -114,7 +114,7 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects).
Please see [DesignConvention.md](docs/DesignConvention.md) for the designs and conventions of the BasicSR codebase.
The figure below shows the overall framework. More descriptions for each component:
-**[Datasets.md](docs/Datasets.md)** | **[Models.md](docs/Models.md)** | **[Config.md](Config.md)** | **[Logging.md](docs/Logging.md)**
+**[Datasets.md](docs/Datasets.md)** | **[Models.md](docs/Models.md)** | **[Config.md](docs/Config.md)** | **[Logging.md](docs/Logging.md)**
![overall_structure](./assets/overall_structure.png)
@@ -144,7 +144,12 @@ The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX p
If you have any questions, please email `xintao.wang@outlook.com`.
+
+
+- **QQ群**: 扫描左边二维码 或者 搜索QQ群号: 320960100 入群答案:互帮互助共同进步
+- **微信群**: 因为微信群超过200人,需要邀请才可以进群;要进微信群的小伙伴可以先添加 Liangbin 的个人微信 (右边二维码),他会在空闲的时候拉大家入群~
+
- +
diff --git a/README_CN.md b/README_CN.md index 7b3424f74..e34173721 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,21 +62,21 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源 我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分. -| GAN | | | | | | -| :--- | :---: | :---: | :--- | :---: | :---: | -| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | | -| **Face Restoration** | | | | | | -| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | | -| **Super Resolution** | | | | | | -| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*| -| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*| -| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)| -| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | -| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | -| **Deblurring** | | | | | | -| DeblurGANv2 | - | *TODO* | | | | -| **Denoise** | | | | | | -| RIDNet | - | *TODO* | CBDNet | - | *TODO*| +| GAN | | | | | | +| :------------------- | :------------------------------------------: | :------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: | +| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | | +| **Face Restoration** | | | | | | +| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | | +| **Super Resolution** | | | | | | +| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* | +| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* | +| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) | +| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | +| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | +| **Deblurring** | | | | | | +| DeblurGANv2 | - | *TODO* | | | | +| **Denoise** | | | | | | +| RIDNet | - | *TODO* | CBDNet | - | *TODO* | ## :wrench: 依赖和安装 @@ -112,7 +112,7 @@ For detailed instructions refer to [INSTALL.md](INSTALL.md). 参见 [DesignConvention_CN.md](docs/DesignConvention_CN.md).- +
diff --git a/VERSION b/VERSION index c42fd3d8b..a349a55be 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.3.4.4 +1.3.4.7 diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py index bc6603e0f..2e33bd3b2 100644 --- a/basicsr/archs/discriminator_arch.py +++ b/basicsr/archs/discriminator_arch.py @@ -1,7 +1,7 @@ from torch import nn as nn - from basicsr.utils.registry import ARCH_REGISTRY - +from torch.nn import functional as F +from torch.nn.utils import spectral_norm @ARCH_REGISTRY.register() class VGGStyleDiscriminator128(nn.Module): @@ -147,3 +147,59 @@ def forward(self, x): feat = self.lrelu(self.linear1(feat)) out = self.linear2(feat) return out + + +@ARCH_REGISTRY.register() +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN)""" + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + + # extra + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out \ No newline at end of file diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py index a05c0a027..9ecb1fdb9 100644 --- a/basicsr/archs/ecbsr_arch.py +++ b/basicsr/archs/ecbsr_arch.py @@ -227,6 +227,8 @@ class ECBSR(nn.Module): def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale): super(ECBSR, self).__init__() + self.num_in_ch = num_in_ch + self.scale = scale backbone = [] backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)] @@ -240,6 +242,10 @@ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_ self.upsampler = nn.PixelShuffle(scale) def forward(self, x): - y = self.backbone(x) + x # will repeat the input in the channel dimension (repeat scale * scale times) + if self.num_in_ch > 1: + shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1) + else: + shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times) + y = self.backbone(x) + shortcut y = self.upsampler(y) return y diff --git a/basicsr/archs/focalir_arch.py b/basicsr/archs/focalir_arch.py new file mode 100644 index 000000000..c3892dd5b --- /dev/null +++ b/basicsr/archs/focalir_arch.py @@ -0,0 +1,1470 @@ +###### +# FocalIR +# This code is referenced by Focal Transformer and SwinIR +# This model is supported by BasicSR +###### +# -------------------------------------------------------- +# Focal Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com) +# Based on Swin Transformer written by Zhe Liu +# -------------------------------------------------------- +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from basicsr.archs.arch_util import to_2tuple, trunc_normal_ +from basicsr.utils.registry import ARCH_REGISTRY +from thop import profile as hp + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_partition_noreshape(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (B, num_windows_h, num_windows_w, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def get_roll_masks(H, W, window_size, shift_size): + ##################################### + # move to top-left + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, H - window_size), + slice(H - window_size, H - shift_size), + slice(H - shift_size, H)) + w_slices = (slice(0, W - window_size), + slice(W - window_size, W - shift_size), + slice(W - shift_size, W)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + #################################### + # move to top right + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, H - window_size), + slice(H - window_size, H - shift_size), + slice(H - shift_size, H)) + w_slices = (slice(0, shift_size), + slice(shift_size, window_size), + slice(window_size, W)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + #################################### + # move to bottom left + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, shift_size), + slice(shift_size, window_size), + slice(window_size, H)) + w_slices = (slice(0, W - window_size), + slice(W - window_size, W - shift_size), + slice(W - shift_size, W)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + #################################### + # move to bottom right + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, shift_size), + slice(shift_size, window_size), + slice(window_size, H)) + w_slices = (slice(0, shift_size), + slice(shift_size, window_size), + slice(window_size, W)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + # append all + attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1) + return attn_mask_all + + +def get_relative_position_index(q_windows, k_windows): + """ + Args: + q_windows: tuple (query_window_height, query_window_width) + k_windows: tuple (key_window_height, key_window_width) + + Returns: + relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width + """ + # get pair-wise relative position index for each token inside the window + coords_h_q = torch.arange(q_windows[0]) + coords_w_q = torch.arange(q_windows[1]) + coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q + + coords_h_k = torch.arange(k_windows[0]) + coords_w_k = torch.arange(k_windows[1]) + coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k])) # 2, Wh, Ww + + coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q + coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k + + relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2 + relative_coords[:, :, 0] += k_windows[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += k_windows[1] - 1 + relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1 + relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k + return relative_position_index + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + + Args: + dim (int): Number of input channels. + expand_size (int): The expand size at focal level 1. + window_size (tuple[int]): The height and width of the window. + focal_window (int): Focal region size. + focal_level (int): Focal attention level. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pool_method (str): window pooling method. Default: none + """ + + def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads, + qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none"): + + super().__init__() + self.dim = dim + self.expand_size = expand_size + self.window_size = window_size # Wh, Ww + self.pool_method = pool_method + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.focal_level = focal_level + self.focal_window = focal_window + + # define a parameter table of relative position bias for each window + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + if self.expand_size > 0 and focal_level > 0: + # define a parameter table of position bias between window and its fine-grained surroundings + self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \ + (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * ( + self.window_size[0] - self.expand_size)) + self.relative_position_bias_table_to_neighbors = nn.Parameter( + torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], + self.window_size_of_key)) # Wh*Ww, nH, nSurrounding + trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02) + + # get mask for rolled k and rolled v + mask_tl = torch.ones(self.window_size[0], self.window_size[1]); + mask_tl[:-self.expand_size, :-self.expand_size] = 0 + mask_tr = torch.ones(self.window_size[0], self.window_size[1]); + mask_tr[:-self.expand_size, self.expand_size:] = 0 + mask_bl = torch.ones(self.window_size[0], self.window_size[1]); + mask_bl[self.expand_size:, :-self.expand_size] = 0 + mask_br = torch.ones(self.window_size[0], self.window_size[1]); + mask_br[self.expand_size:, self.expand_size:] = 0 + mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0) + self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1)) + + if pool_method != "none" and focal_level > 1: + self.relative_position_bias_table_to_windows = nn.ParameterList() + self.unfolds = nn.ModuleList() + + # build relative position bias between local patch and pooled windows + for k in range(focal_level - 1): + stride = 2 ** k + kernel_size = 2 * (self.focal_window // 2) + 2 ** k + (2 ** k - 1) + # define unfolding operations + self.unfolds += [nn.Unfold( + kernel_size=(kernel_size, kernel_size), + stride=stride, padding=kernel_size // 2) + ] + + # define relative position bias table + relative_position_bias_table_to_windows = nn.Parameter( + torch.zeros( + self.num_heads, + (self.window_size[0] + self.focal_window + 2 ** k - 2) * ( + self.window_size[1] + self.focal_window + 2 ** k - 2), + ) + ) + trunc_normal_(relative_position_bias_table_to_windows, std=.02) + self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows) + + # define relative position bias index + relative_position_index_k = get_relative_position_index(self.window_size, + to_2tuple(self.focal_window + 2 ** k - 1)) + self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k) + + # define unfolding index for focal_level > 0 + if k > 0: + mask = torch.zeros(kernel_size, kernel_size); + mask[(2 ** k) - 1:, (2 ** k) - 1:] = 1 + self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x_all, mask_all=None): + """ + Args: + x_all (list[Tensors]): input features at different granularity + mask_all (list[Tensors/None]): masks for input features at different granularity + """ + x = x_all[0] # + + B, nH, nW, C = x.shape + qkv = self.qkv(x).reshape(B, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # B, nH, nW, C + + # partition q map + (q_windows, k_windows, v_windows) = map( + lambda t: window_partition(t, self.window_size[0]).view( + -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads + ).transpose(1, 2), + (q, k, v) + ) + + if self.expand_size > 0 and self.focal_level > 0: + (k_tl, v_tl) = map( + lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v) + ) + (k_tr, v_tr) = map( + lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v) + ) + (k_bl, v_bl) = map( + lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v) + ) + (k_br, v_br) = map( + lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v) + ) + + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( + lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], + self.num_heads, C // self.num_heads), + (k_tl, k_tr, k_bl, k_br) + ) + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( + lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], + self.num_heads, C // self.num_heads), + (v_tl, v_tr, v_bl, v_br) + ) + k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2) + v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2) + + # mask out tokens in current window + k_rolled = k_rolled[:, :, self.valid_ind_rolled] + v_rolled = v_rolled[:, :, self.valid_ind_rolled] + k_rolled = torch.cat((k_windows, k_rolled), 2) + v_rolled = torch.cat((v_windows, v_rolled), 2) + else: + k_rolled = k_windows; + v_rolled = v_windows; + + if self.pool_method != "none" and self.focal_level > 1: + k_pooled = [] + v_pooled = [] + for k in range(self.focal_level - 1): + stride = 2 ** k + x_window_pooled = x_all[k + 1] # B, nWh, nWw, C + nWh, nWw = x_window_pooled.shape[1:3] + + # generate mask for pooled windows + mask = x_window_pooled.new(nWh, nWw).fill_(1) + unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view( + 1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, + 1).contiguous(). \ + view(nWh * nWw // stride // stride, -1, 1) + + if k > 0: + valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k)) + unfolded_mask = unfolded_mask[:, valid_ind_unfold_k] + + x_window_masks = unfolded_mask.flatten(1).unsqueeze(0) + x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill( + x_window_masks > 0, float(0.0)) + mask_all[k + 1] = x_window_masks + + # generate k and v for pooled windows + qkv_pooled = self.qkv(x_window_pooled).reshape(B, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous() + k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B, C, nWh, nWw + + (k_pooled_k, v_pooled_k) = map( + lambda t: self.unfolds[k](t).view( + B, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, + 1).contiguous(). \ + view(-1, self.unfolds[k].kernel_size[0] * self.unfolds[k].kernel_size[1], self.num_heads, + C // self.num_heads).transpose(1, 2), + (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim + ) + + if k > 0: + (k_pooled_k, v_pooled_k) = map( + lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k) + ) + + k_pooled += [k_pooled_k] + v_pooled += [v_pooled_k] + k_all = torch.cat([k_rolled] + k_pooled, 2) + v_all = torch.cat([v_rolled] + v_pooled, 2) + else: + k_all = k_rolled + v_all = v_rolled + + N = k_all.shape[-2] + q_windows = q_windows * self.scale + attn = (q_windows @ k_all.transpose(-2, + -1)) # B*nW, nHead, window_size*window_size, focal_window_size*focal_window_size + + window_area = self.window_size[0] * self.window_size[1] + window_area_rolled = k_rolled.shape[2] + + # add relative position bias for tokens inside window + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, + :window_area] + relative_position_bias.unsqueeze(0) + + # add relative position bias for patches inside a window + if self.expand_size > 0 and self.focal_level > 0: + attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, + window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors + + if self.pool_method != "none" and self.focal_level > 1: + # add relative position bias for different windows in an image + offset = window_area_rolled + for k in range(self.focal_level - 1): + # add relative position bias + relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k)) + relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, + relative_position_index_k.view(-1)].view( + -1, self.window_size[0] * self.window_size[1], (self.focal_window + 2 ** k - 1) ** 2, + ) # nH, NWh*NWw,focal_region*focal_region + attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] = \ + attn[:, :, :window_area, offset:(offset + ( + self.focal_window + 2 ** k - 1) ** 2)] + relative_position_bias_to_windows.unsqueeze(0) + # add attentional mask + if mask_all[k + 1] is not None: + attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] = \ + attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] + \ + mask_all[k + 1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k + 1].shape[1], 1, 1, 1, + 1).view(-1, 1, 1, mask_all[k + 1].shape[-1]) + + offset += (self.focal_window + 2 ** k - 1) ** 2 + + if mask_all[0] is not None: + nW = mask_all[0].shape[0] + attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N) + attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :] + attn = attn.view(-1, self.num_heads, window_area, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N, window_size, unfold_size): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + if self.pool_method != "none" and self.focal_level > 1: + flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size) + if self.expand_size > 0 and self.focal_level > 0: + flops += self.num_heads * N * (self.dim // self.num_heads) * ( + (window_size + 2 * self.expand_size) ** 2 - window_size ** 2) + + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + if self.pool_method != "none" and self.focal_level > 1: + flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size) + if self.expand_size > 0 and self.focal_level > 0: + flops += self.num_heads * N * (self.dim // self.num_heads) * ( + (window_size + 2 * self.expand_size) ** 2 - window_size ** 2) + + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class FocalTransformerBlock(nn.Module): + r""" Focal Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + expand_size (int): expand size at first focal level (finest level). + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pool_method (str): window pooling method. Default: none, options: [none|fc|conv] + focal_level (int): number of focal levels. Default: 1. + focal_window (int): region size of focal attention. Default: 1 + use_layerscale (bool): whether use layer scale for training stability. Default: False + layerscale_value (float): scaling value for layer scale. Default: 1e-4 + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none", + focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.expand_size = expand_size + self.mlp_ratio = mlp_ratio + self.pool_method = pool_method + self.focal_level = focal_level + self.focal_window = focal_window + self.use_layerscale = use_layerscale + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.expand_size = 0 + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.window_size_glo = self.window_size + + self.pool_layers = nn.ModuleList() + if self.pool_method != "none": + for k in range(self.focal_level - 1): + window_size_glo = math.floor(self.window_size_glo / (2 ** k)) + if self.pool_method == "fc": + self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1)) + self.pool_layers[-1].weight.data.fill_(1. / (window_size_glo * window_size_glo)) + self.pool_layers[-1].bias.data.fill_(0) + elif self.pool_method == "conv": + self.pool_layers.append( + nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim)) + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention( + dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size), + focal_window=focal_window, focal_level=focal_level, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + if self.use_layerscale: + self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x, x_size): + H, W = x_size + B, _, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + if pad_r > 0 or pad_b > 0: + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + B, H, W, C = x.shape + + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + x_windows_all = [shifted_x] + x_window_masks_all = [self.attn_mask] + + if self.focal_level > 1 and self.pool_method != "none": + # if we add coarser granularity and the pool method is not none + for k in range(self.focal_level - 1): + window_size_glo = math.floor(self.window_size_glo / (2 ** k)) + pooled_h = math.ceil(H / self.window_size) * (2 ** k) + pooled_w = math.ceil(W / self.window_size) * (2 ** k) + H_pool = pooled_h * window_size_glo + W_pool = pooled_w * window_size_glo + + x_level_k = shifted_x + # trim or pad shifted_x depending on the required size + if H > H_pool: + trim_t = (H - H_pool) // 2 + trim_b = H - H_pool - trim_t + x_level_k = x_level_k[:, trim_t:-trim_b] + elif H < H_pool: + pad_t = (H_pool - H) // 2 + pad_b = H_pool - H - pad_t + x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b)) + + if W > W_pool: + trim_l = (W - W_pool) // 2 + trim_r = W - W_pool - trim_l + x_level_k = x_level_k[:, :, trim_l:-trim_r] + elif W < W_pool: + pad_l = (W_pool - W) // 2 + pad_r = W_pool - W - pad_l + x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r)) + + x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), + window_size_glo) # B, nw, nw, window_size, window_size, C + nWh, nWw = x_windows_noreshape.shape[1:3] + if self.pool_method == "mean": + x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B, nWh, nWw, C + elif self.pool_method == "max": + x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B, nWh, nWw, + C) # B, nWh, nWw, C + elif self.pool_method == "fc": + x_windows_noreshape = x_windows_noreshape.view(B, nWh, nWw, window_size_glo * window_size_glo, + C).transpose(3, 4) # B, nWh, nWw, C, wsize**2 + x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten( + -2) # B, nWh, nWw, C + elif self.pool_method == "conv": + x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, + 3, + 1, + 2).contiguous() # B * nw * nw, C, wsize, wsize + x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B, nWh, nWw, + C) # B, nWh, nWw, C + + x_windows_all += [x_windows_pooled] + x_window_masks_all += [None] + + attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all) # nW*B, window_size*window_size, C + + attn_windows = attn_windows[:, :self.window_size ** 2] + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + # 不知道这行干啥的,先改了再说 + #x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C) + x = x[:, :x_size[0], :x_size[1]].contiguous().view(B, -1, C) + + # FFN + x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x)) + x = x + self.drop_path( + self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x)))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window) + + if self.pool_method != "none" and self.focal_level > 1: + for k in range(self.focal_level - 1): + window_size_glo = math.floor(self.window_size_glo / (2 ** k)) + nW_glo = nW * (2 ** k) + # (sub)-window pooling + flops += nW_glo * self.dim * window_size_glo * window_size_glo + # qkv for global levels + # NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer, + # but theoritically, we only need to compute k and v. + flops += nW_glo * self.dim * 3 * self.dim + + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + img_size (tuple[int]): Resolution of input feature. + in_chans (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, img_size, in_chans=3, norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.input_resolution = img_size + self.dim = in_chans + self.reduction = nn.Linear(4 * in_chans, 2 * in_chans, bias=False) + self.norm = norm_layer(4 * in_chans) + + def forward(self, x): + """ + x: B, C, H, W + """ + B, C, H, W = x.shape + + x = x.permute(0, 2, 3, 1).contiguous() + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Focal Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + expand_size (int): expand size for focal level 1. + expand_layer (str): expand layer. Default: all + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pool_method (str): Window pooling method. Default: none. + focal_level (int): Number of focal levels. Default: 1. + focal_window (int): region size at each focal level. Default: 1. + use_conv_embed (bool): whether use overlapped convolutional patch embedding layer. Default: False + use_shift (bool): Whether use window shift as in Swin Transformer. Default: False + use_pre_norm (bool): Whether use pre-norm before patch embedding projection for stability. Default: False + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + use_layerscale (bool): Whether use layer scale for stability. Default: False. + layerscale_value (float): Layerscale value. Default: 1e-4. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all", + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, pool_method="none", + focal_level=1, focal_window=1, use_shift=False, + downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + if expand_layer == "even": + expand_factor = 0 + elif expand_layer == "odd": + expand_factor = 1 + elif expand_layer == "all": + expand_factor = -1 + + # build blocks + self.blocks = nn.ModuleList([ + FocalTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0, + expand_size=0 if (i % 2 == expand_factor) else expand_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pool_method=pool_method, + focal_level=focal_level, + focal_window=focal_window, + use_layerscale=use_layerscale, + layerscale_value=layerscale_value) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(img_size=input_resolution, embed_dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, x_size) + + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RFTB(nn.Module): + """Residual Focal Transformer Block (RFTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + expand_size, + expand_layer="all", + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + pool_method="none", + focal_level=1, + focal_window=1, + use_conv_embed=False, + use_shift=False, + use_pre_norm=False, + downsample=None, + use_checkpoint=False, + use_layerscale=False, + layerscale_value=1e-4, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RFTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + expand_size=expand_size, + expand_layer=expand_layer, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + pool_method=pool_method, + focal_level=focal_level, + focal_window=focal_window, + use_shift=use_shift, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_layerscale=use_layerscale, + layerscale_value=layerscale_value) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + h, w = self.input_resolution + flops += h * w * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + use_conv_embed (bool): Wherther use overlapped convolutional embedding layer. Default: False. + norm_layer (nn.Module, optional): Normalization layer. Default: None + use_pre_norm (bool): Whether use pre-normalization before projection. Default: False + is_stem (bool): Whether current patch embedding is stem. Default: False + """ + + def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.num_feat * 3 * 9 + return flops + + +@ARCH_REGISTRY.register() +class FocalIR(nn.Module): + r""" Focal Transformer: Focal Self-attention for Local-Global Interactions in Vision Transformer + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Focal Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + use_shift (bool): Whether to use window shift proposed by Swin Transformer. We observe that using shift or not does not make difference to our Focal Transformer. Default: False + focal_stages (list): Which stages to perform focal attention. Default: [0, 1, 2, 3], means all stages + focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1] + focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1] + expand_stages (list): Which stages to expand the finest grain window. Default: [0, 1, 2, 3], means all stages + expand_sizes (list): The expand size for the finest grain level. Default: [3, 3, 3, 3] + expand_layer (str): Which layers we want to expand the window for the finest grain leve. This can save computational and memory cost without the loss of performance. Default: "all" + use_conv_embed (bool): Whether use convolutional embedding. We noted that using convolutional embedding usually improve the performance, but we do not use it by default. Default: False + use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False + layerscale_value (float): Value for layer scale. Default: 1e-4 + use_pre_norm (bool): Whether use pre-norm in patch merging/embedding layer to control the feature magtigute. Default: False + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + use_shift=False, + focal_stages=[0, 1, 2, 3], + focal_levels=[1, 1, 1, 1], + focal_windows=[7, 5, 3, 1], + focal_pool="fc", + expand_stages=[0, 1, 2, 3], + expand_sizes=[3, 3, 3, 3], + expand_layer="all", + use_layerscale=False, + layerscale_value=1e-4, + upscale=2, + img_range=1., + upsampler='pixelshuffle', + resi_connection='1conv', + **kwargs): + super(FocalIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + # self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into patches using either non-overlapped embedding or overlapped embedding + self.patch_embed = PatchEmbed( + img_size=to_2tuple(img_size), patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Focal Transformer blocks (RFTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RFTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + pool_method=focal_pool if i_layer in focal_stages else "none", + downsample=None, + focal_level=focal_levels[i_layer], + focal_window=focal_windows[i_layer], + expand_size=expand_sizes[i_layer], + expand_layer=expand_layer, + use_shift=use_shift, + use_checkpoint=use_checkpoint, + use_layerscale=use_layerscale, + layerscale_value=layerscale_value, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table', 'relative_position_bias_table_to_neighbors', + 'relative_position_bias_table_to_windows'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + return x[:, :, :H * self.upscale, :W * self.upscale] + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + return flops + + +def profile(model, inputs): + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, record_shapes=True) as prof: + with record_function("model_inference"): + model(inputs) + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15)) + + +if __name__ == '__main__': + img_hsize = 320 + img_wsize = 180 + x = torch.rand(1, 3, img_hsize, img_wsize).cuda() + model = FocalIR(img_size=(img_hsize, img_wsize), upscale=4, in_chans=3, embed_dim=60, depths=[6, 6, 6, 6], drop_path_rate=0.2, + focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all",num_heads=[6, 6, 6, 6], + focal_windows=[7, 5, 3, 1], mlp_ratio=2, upsampler='pixelshuffle', window_size=4, resi_connection='1conv', use_shift=False).cuda() + + model.eval() + + #flops = model.flops() + #print(f"number of GFLOPs: {flops / 1e9}") + + #n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + #print(f"number of params: {n_parameters}") + + flops, params = hp(model, inputs=(x,)) + + print("FLOPs=", str(flops / 1e9) + '{}'.format("G")) + print("params=", str(params / 1e6) + '{}'.format("M")) + + #profile(model, x) diff --git a/basicsr/archs/rdswinir_arch.py b/basicsr/archs/rdswinir_arch.py new file mode 100644 index 000000000..58eed35e8 --- /dev/null +++ b/basicsr/archs/rdswinir_arch.py @@ -0,0 +1,1016 @@ +# Modified from https://github.com/JingyunLiang/SwinIR +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. + +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from basicsr.utils.registry import ARCH_REGISTRY +from basicsr.archs.arch_util import to_2tuple, trunc_normal_ + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, n): + # calculate flops for 1 window with token length of n + flops = 0 + # qkv = self.qkv(x) + flops += n * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * n * (self.dim // self.num_heads) * n + # x = (attn @ v) + flops += self.num_heads * n * n * (self.dim // self.num_heads) + # x = self.proj(x) + flops += n * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}') + + def flops(self): + flops = 0 + h, w = self.input_resolution + # norm1 + flops += self.dim * h * w + # W-MSA/SW-MSA + nw = h * w / self.window_size / self.window_size + flops += nw * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * h * w + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, 'input feature has wrong size' + assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.' + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f'input_resolution={self.input_resolution}, dim={self.dim}' + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.dim + flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + # define the learnable parameters + self.fuse_weight = [] + for i in range(depth): + w = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) + w.data.fill_(1) + w = w.to("cuda") + self.fuse_weight.append(w) + + # Convolutional extractor 给稠密连接设计的特征融合器,可惜一样没什么屌用 + # self.conv_body = [nn.Conv2d(dim * (i + 1), dim, 3, 1, 1) for i in range(depth + 1)] + # for c in self.conv_body: + # c.to("cuda") + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + # temp = [x] + # batch_size = x.shape[0] + count = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, x_size) + ''' + ######稠密连接,可惜没什么屌用...###### + if len(temp) > 1: + x = torch.cat((temp[:])) + # print(x.shape) + # 融合 + self.patch_unembed(x, x_size) + x = x.view(batch_size, self.dim * count, x_size[0], x_size[1]) + x = self.conv_body[count - 1](x) + # 变为patch embeding + # x = self.patch_embed(x, x_size) + x = x.flatten(2).transpose(1, 2) + else: + x = temp[0] + x = blk(x, x_size) + temp.append(x) + count += 1 + # 对temp进行融合 + x = torch.cat((temp[:])) + # print(x.shape) + # 融合 + self.patch_unembed(x, x_size) + x = x.view(batch_size, self.dim * count, x_size[0], x_size[1]) + x = self.conv_body[count - 1](x) + # 变为patch embeding + # x = self.patch_embed(x, x_size) + x = x.flatten(2).transpose(1, 2) + ''' + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + h, w = self.input_resolution + flops += h * w * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + h, w = self.img_size + if self.norm is not None: + flops += h * w * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.num_feat * 3 * 9 + return flops + + +@ARCH_REGISTRY.register() +class RDSwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1., + upsampler='', + resi_connection='1conv', + **kwargs): + super(RDSwinIR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + # define the learnable parameters + self.fuse_wight = [] + for i in range(self.num_layers): + w = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) + w.data.fill_(1) + w = w.to("cuda") + self.fuse_wight.append(w) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + count = 0 + for layer in self.layers: + x = layer(x, x_size) + self.fuse_wight[count] * x + count += 1 + + x = self.norm(x) # b seq_len c + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x + + def flops(self): + flops = 0 + h, w = self.patches_resolution + flops += h * w * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for layer in self.layers: + flops += layer.flops() + flops += h * w * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = RDSwinIR( + upscale=4, + img_size=(height, width), + window_size=window_size, + img_range=1., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py index f3e9e2c54..c688a600b 100644 --- a/basicsr/archs/swinir_arch.py +++ b/basicsr/archs/swinir_arch.py @@ -6,9 +6,10 @@ import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint +from thop import profile as hp from basicsr.utils.registry import ARCH_REGISTRY -from .arch_util import to_2tuple, trunc_normal_ +from basicsr.archs.arch_util import to_2tuple, trunc_normal_ def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -935,11 +936,13 @@ def flops(self): if __name__ == '__main__': upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size + window_size = 4 + # height = (1024 // upscale // window_size + 1) * window_size + # width = (720 // upscale // window_size + 1) * window_size + height = 320 + width = 180 model = SwinIR( - upscale=2, + upscale=4, img_size=(height, width), window_size=window_size, img_range=1., @@ -948,9 +951,13 @@ def flops(self): num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - + # print(model) + # print(height, width, model.flops() / 1e9) + model.eval() x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) + # x = model(x) + # print(x.shape) + flops, params = hp(model, inputs=(x,)) + + print("FLOPs=", str(flops / 1e9) + '{}'.format("G")) + print("params=", str(params / 1e6) + '{}'.format("M")) diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index ee15d5a83..bfd99fad9 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -88,7 +88,7 @@ def __getitem__(self, index): # flip, rotation img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) - if self.opt['color'] == 'y': + if 'color' in self.opt and self.opt['color'] == 'y': img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py index 4fb044a93..1f1e355fe 100644 --- a/basicsr/metrics/__init__.py +++ b/basicsr/metrics/__init__.py @@ -3,8 +3,9 @@ from basicsr.utils.registry import METRIC_REGISTRY from .niqe import calculate_niqe from .psnr_ssim import calculate_psnr, calculate_ssim +from .lpips import calculate_lpips -__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_lpips'] def calculate_metric(data, opt): diff --git a/basicsr/metrics/lpips.py b/basicsr/metrics/lpips.py new file mode 100644 index 000000000..8d5a73014 --- /dev/null +++ b/basicsr/metrics/lpips.py @@ -0,0 +1,65 @@ +from torchvision.transforms.functional import normalize +from basicsr.utils import img2tensor +import lpips +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.registry import METRIC_REGISTRY +import torch + +@METRIC_REGISTRY.register() +def calculate_lpips(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + + """Calculate LPIPS. + Ref: https://github.com/xinntao/BasicSR/pull/367 + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: LPIPS result. + """ + assert img.shape == img2.shape, (f'Image shapes are differnet: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + # start calculating LPIPS metrics + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + loss_fn_vgg = lpips.LPIPS(net='vgg', verbose=False).to(DEVICE) # RGB, normalized to [-1,1] + + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + + img_gt = img2 / 255. + img_restored = img / 255. + + img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True) + # norm to [-1, 1] + normalize(img_gt, mean, std, inplace=True) + normalize(img_restored, mean, std, inplace=True) + + # calculate lpips + img_gt = img_gt.to(DEVICE) + img_restored = img_restored.to(DEVICE) + loss_fn_vgg.eval() + lpips_val = loss_fn_vgg(img_restored.unsqueeze(0), img_gt.unsqueeze(0)) + + return lpips_val.detach().cpu().numpy().mean() + diff --git a/basicsr/models/focalir_model.py b/basicsr/models/focalir_model.py new file mode 100644 index 000000000..b0d268f2d --- /dev/null +++ b/basicsr/models/focalir_model.py @@ -0,0 +1,33 @@ +import torch +from torch.nn import functional as F + +from basicsr.utils.registry import MODEL_REGISTRY +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class FocalIRModel(SRModel): + + def test(self): + # pad to multiplication of window_size + window_size = self.opt['network_g']['window_size'] + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = self.lq.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(img) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(img) + self.net_g.train() + + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index fdbb88678..54c80bd6d 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -138,10 +138,11 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): with_metrics = self.opt['val'].get('metrics') is not None use_pbar = self.opt['val'].get('pbar', False) - if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run - self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} - # initialize the best metric results for each dataset_name (supporting multiple validation datasets) - self._initialize_best_metric_results(dataset_name) + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) # zero self.metric_results if with_metrics: self.metric_results = {metric: 0 for metric in self.metric_results} diff --git a/basicsr/models/swinirgan_model.py b/basicsr/models/swinirgan_model.py new file mode 100644 index 000000000..0b2c0de34 --- /dev/null +++ b/basicsr/models/swinirgan_model.py @@ -0,0 +1,107 @@ +import torch +from collections import OrderedDict +from torch.nn import functional as F +from basicsr.utils.registry import MODEL_REGISTRY +from .srgan_model import SRGANModel + + +@MODEL_REGISTRY.register() +class SwinIRGANModel(SRGANModel): + """SwinIRGAN model for single image super-resolution.""" + + def test(self): + # pad to multiplication of window_size + window_size = self.opt['network_g']['window_size'] + scale = self.opt.get('scale', 1) + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = self.lq.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(img) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(img) + self.net_g.train() + + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss (relativistic gan) + real_d_pred = self.net_d(self.gt).detach() + fake_g_pred = self.net_d(self.output) + l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) + l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) + l_g_gan = (l_g_real + l_g_fake) / 2 + + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # gan loss (relativistic gan) + + # In order to avoid the error in distributed training: + # "Error detected in CudnnBatchNormBackward: RuntimeError: one of + # the variables needed for gradient computation has been modified by + # an inplace operation", + # we separate the backwards for real and fake, and also detach the + # tensor for calculating mean. + + # real + fake_d_pred = self.net_d(self.output).detach() + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 + l_d_fake.backward() + self.optimizer_d.step() + + loss_dict['l_d_real'] = l_d_real + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index 2240df45c..9f7993a15 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -24,14 +24,15 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): # 'folder1': tensor (num_frame x len(metrics)), # 'folder2': tensor (num_frame x len(metrics)) # } - if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run - self.metric_results = {} - num_frame_each_folder = Counter(dataset.data_info['folder']) - for folder, num_frame in num_frame_each_folder.items(): - self.metric_results[folder] = torch.zeros( - num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') - # initialize the best metric results - self._initialize_best_metric_results(dataset_name) + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) # zero self.metric_results rank, world_size = get_dist_info() if with_metrics: diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py index 49fa0e4e2..796ee57d5 100644 --- a/basicsr/models/video_recurrent_model.py +++ b/basicsr/models/video_recurrent_model.py @@ -72,14 +72,15 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): # 'folder1': tensor (num_frame x len(metrics)), # 'folder2': tensor (num_frame x len(metrics)) # } - if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run - self.metric_results = {} - num_frame_each_folder = Counter(dataset.data_info['folder']) - for folder, num_frame in num_frame_each_folder.items(): - self.metric_results[folder] = torch.zeros( - num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') - # initialize the best metric results - self._initialize_best_metric_results(dataset_name) + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {} + num_frame_each_folder = Counter(dataset.data_info['folder']) + for folder, num_frame in num_frame_each_folder.items(): + self.metric_results[folder] = torch.zeros( + num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') + # initialize the best metric results + self._initialize_best_metric_results(dataset_name) # zero self.metric_results rank, world_size = get_dist_info() if with_metrics: diff --git a/options/test/FocalIR/test_FocalIR_x4.yml b/options/test/FocalIR/test_FocalIR_x4.yml new file mode 100644 index 000000000..7ca663cc0 --- /dev/null +++ b/options/test/FocalIR/test_FocalIR_x4.yml @@ -0,0 +1,99 @@ +name: FocalIR_SRx4_DIV2K +model_type: FocalIRModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +datasets: + test_1: # the 1st test dataset + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_2: # the 2nd test dataset + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_5: + name: Manga109 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Manga109/HR + dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: FocalIR + upscale: 4 + in_chans: 3 + img_size: 48 + window_size: 4 + img_range: 1. + depths: [6, 6, 6, 6] + embed_dim: 60 + num_heads: [6, 6, 6, 6] + drop_path_rate: 0.2 + focal_levels: [2, 2, 2, 2] + expand_sizes: [3, 3, 3, 3] + expand_layer: "all" + focal_windows: [7, 5, 3, 1] + mlp_ratio: 2 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + use_shift: False + +# path +path: + pretrain_network_g: experiments/train_FocalIR_SRx4_s48g96_DIV2K/models/net_g_5000.pth + #pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth + strict_load_g: true + +# validation settings +val: + save_img: true + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: false + better: lower diff --git a/options/test/SwinIR/test_SwinIR_x4.yml b/options/test/SwinIR/test_SwinIR_x4.yml new file mode 100644 index 000000000..fe8422983 --- /dev/null +++ b/options/test/SwinIR/test_SwinIR_x4.yml @@ -0,0 +1,92 @@ +name: SwinIR_SRx4_DIV2K +model_type: SwinIRModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +datasets: + test_1: # the 1st test dataset + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_2: # the 2nd test dataset + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_5: + name: Manga109 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Manga109/HR + dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: SwinIR + upscale: 4 + in_chans: 3 + img_size: 48 + window_size: 8 + img_range: 1. + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + mlp_ratio: 2 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + +# path +path: + pretrain_network_g: experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth + #pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth + strict_load_g: true + +# validation settings +val: + save_img: true + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: false diff --git a/options/test/SwinIR/test_SwinIR_x4_myself.yml b/options/test/SwinIR/test_SwinIR_x4_myself.yml new file mode 100644 index 000000000..476398bf7 --- /dev/null +++ b/options/test/SwinIR/test_SwinIR_x4_myself.yml @@ -0,0 +1,94 @@ +name: SwinIR_SRx4_DIV2K +model_type: SwinIRModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +datasets: + test_1: # the 1st test dataset + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_2: # the 2nd test dataset + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + test_5: + name: Manga109 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Manga109/HR + dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: SwinIR + upscale: 4 + in_chans: 3 + #img_size: 32 + img_size: 48 + window_size: 8 + img_range: 1. + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + mlp_ratio: 2 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + +# path +path: + #pretrain_network_g: experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth + pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth + strict_load_g: true + +# validation settings +val: + save_img: true + suffix: ~ # add suffix to saved images, if None, use exp name + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: true + better: lower diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml new file mode 100644 index 000000000..d9aee8c12 --- /dev/null +++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml @@ -0,0 +1,139 @@ +# general settings +name: 100_train_ECBSR_x4_m4c16_prelu_RGB +model_type: SRModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + # It is strongly recommended to use lmdb for faster IO speed, especially for small networks + dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + filename_tmpl: '{}' + io_backend: + type: lmdb + + gt_size: 256 + use_flip: true + use_rot: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 12 + batch_size_per_gpu: 32 + dataset_enlarge_ratio: 10 + prefetch_mode: ~ + + # we use multiple validation datasets. The SR benchmark datasets can be download from: https://cv.snu.ac.kr/research/EDSR/benchmark.tar + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + + val_2: + name: Set14 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set14/HR + dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + + val_3: + name: B100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/B100/HR + dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + + val_4: + name: Urban100 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Urban100/HR + dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + +# network structures +network_g: + type: ECBSR + num_in_ch: 3 + num_out_ch: 3 + num_block: 4 + num_channel: 16 + with_idt: False + act_type: prelu + scale: 4 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0 + optim_g: + type: Adam + lr: !!float 5e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1600000] + gamma: 1 + + total_iter: 1600000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 1600 # the same as the original setting. # TODO: Can be larger + save_img: false + pbar: False + + metrics: + psnr: + type: calculate_psnr + crop_border: 4 + test_y_channel: true + better: higher # the higher, the better. Default: higher + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + better: higher # the higher, the better. Default: higher + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 1600 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml index 5310e7a14..f13ec6739 100644 --- a/options/train/ESRGAN/train_ESRGAN_x4.yml +++ b/options/train/ESRGAN/train_ESRGAN_x4.yml @@ -20,6 +20,7 @@ datasets: type: disk # (for lmdb) # type: lmdb + color: n gt_size: 128 use_flip: true @@ -29,16 +30,18 @@ datasets: use_shuffle: true num_worker_per_gpu: 6 batch_size_per_gpu: 16 - dataset_enlarge_ratio: 100 + dataset_enlarge_ratio: 1 prefetch_mode: ~ val: - name: Set14 + name: Set5 type: PairedImageDataset - dataroot_gt: datasets/Set14/GTmod12 - dataroot_lq: datasets/Set14/LRbicx4 + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' io_backend: type: disk + color: n # network structures network_g: @@ -55,7 +58,7 @@ network_d: # path path: - pretrain_network_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + pretrain_network_g: ~ strict_load_g: true resume_state: ~ diff --git a/options/train/FocalIR/train_focalIR_SRx4_scratch.yml b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml new file mode 100644 index 000000000..5d0c1944f --- /dev/null +++ b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml @@ -0,0 +1,128 @@ +# general settings +name: train_test_FocalIR_SRx4_s48g96_DIV2K +model_type: FocalIRModel +scale: 4 +num_gpu: 1 +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub + #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub + dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + filename_tmpl: '{}' + io_backend: + type: lmdb + + gt_size: 96 + use_flip: true + use_rot: true + color: n + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + #prefetch_mode: ~ + prefetch_mode: cuda + pin_memory: true + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: FocalIR + upscale: 4 + in_chans: 3 + img_size: 48 + window_size: 4 + img_range: 1. + depths: [6, 6, 6, 6] + embed_dim: 60 + num_heads: [6, 6, 6, 6] + drop_path_rate: 0.2 + focal_levels: [2, 2, 2, 2] + expand_sizes: [3, 3, 3, 3] + expand_layer: "all" + focal_windows: [7, 5, 3, 1] + mlp_ratio: 2 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + use_shift: False + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [250000, 400000, 450000, 475000] + gamma: 0.5 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e3 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: false + better: lower + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +#dist_params: +# backend: nccl +# port: 29500 diff --git a/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml b/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml new file mode 100644 index 000000000..c5fd21dc1 --- /dev/null +++ b/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml @@ -0,0 +1,129 @@ +# general settings +name: train_FocalIR_SRx4_s48g96_DIV2K +model_type: FocalIRModel +scale: 4 +num_gpu: 1 +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub + #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub + dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + filename_tmpl: '{}' + io_backend: + type: lmdb + + gt_size: 96 + use_flip: true + use_rot: true + color: n + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + #prefetch_mode: cuda + #pin_memory: true + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: FocalIR + upscale: 4 + in_chans: 3 + img_size: 48 + window_size: 4 + img_range: 1. + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + drop_path_rate: 0.2 + focal_levels: [2, 2, 2, 2, 2, 2] + expand_sizes: [3, 3, 3, 3, 3, 3] + expand_layer: "all" + focal_windows: [11, 9, 7, 5, 3, 1] + mlp_ratio: 2 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + use_shift: False + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [250000, 400000, 450000, 475000] + gamma: 0.5 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + #val_freq: !!float 5e3 + val_freq: 1 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: false + better: lower + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +#dist_params: +# backend: nccl +# port: 29500 diff --git a/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml b/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml new file mode 100644 index 000000000..296cab105 --- /dev/null +++ b/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml @@ -0,0 +1,121 @@ +# general settings +name: train_RDSwinIR_learnable_DIV2K +model_type: SwinIRModel +scale: 4 +num_gpu: 1 +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub + #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub + dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + filename_tmpl: '{}' + io_backend: + type: lmdb + + gt_size: 64 + use_flip: true + use_rot: true + color: n + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: RDSwinIR + upscale: 4 + in_chans: 3 + img_size: 32 + window_size: 8 + img_range: 1. + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + mlp_ratio: 2 + #patch_norm: False + upsampler: 'pixelshuffle' + resi_connection: '1conv' + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [250000, 400000, 450000, 475000] + gamma: 0.5 + + total_iter: 500000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 5e3 + save_img: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 + test_y_channel: false + better: lower + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +#dist_params: +# backend: nccl +# port: 29500 diff --git a/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml b/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml index ed1edd006..b7571a213 100644 --- a/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml +++ b/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml @@ -1,8 +1,8 @@ # general settings -name: train_SwinIR_SRx4_scratch_P48W8_DIV2K_500k_B4G8 +name: train_SwinIR_SRx4_s48g96_DIV2K model_type: SwinIRModel scale: 4 -num_gpu: auto +num_gpu: 1 manual_seed: 0 # dataset and data loader settings @@ -10,31 +10,36 @@ datasets: train: name: DIV2K type: PairedImageDataset - dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub - dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub - meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt + #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub + #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub + dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt filename_tmpl: '{}' io_backend: - type: disk + type: lmdb - gt_size: 192 + gt_size: 96 use_flip: true use_rot: true + color: n # data loader use_shuffle: true num_worker_per_gpu: 6 - batch_size_per_gpu: 4 + batch_size_per_gpu: 16 dataset_enlarge_ratio: 1 prefetch_mode: ~ val: name: Set5 type: PairedImageDataset - dataroot_gt: datasets/Set5/GTmod12 - dataroot_lq: datasets/Set5/LRbicx4 + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' io_backend: type: disk + color: n # network structures network_g: @@ -89,7 +94,16 @@ val: psnr: # metric name, can be arbitrary type: calculate_psnr crop_border: 4 + test_y_channel: true + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: + type: calculate_lpips + crop_border: 4 test_y_channel: false + better: lower # logging settings logger: @@ -101,6 +115,6 @@ logger: resume_id: ~ # dist training settings -dist_params: - backend: nccl - port: 29500 +#dist_params: +# backend: nccl +# port: 29500 diff --git a/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml b/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml new file mode 100644 index 000000000..916c875a0 --- /dev/null +++ b/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml @@ -0,0 +1,151 @@ +# general settings +name: SwinIRGANModel_x4_DIV2K +model_type: SwinIRGANModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: DIV2K + type: PairedImageDataset + dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + # (for lmdb) + # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb + # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb + filename_tmpl: '{}' + io_backend: + type: lmdb + # (for lmdb) + # type: lmdb + + gt_size: 128 + use_flip: true + use_rot: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 16 + dataset_enlarge_ratio: 100 + prefetch_mode: ~ + + val: + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + color: n + +# network structures +network_g: + type: SwinIR + upscale: 4 + in_chans: 3 + img_size: 64 + window_size: 8 + img_range: 1. + depths: [6, 6, 6, 6, 6, 6] + embed_dim: 180 + num_heads: [6, 6, 6, 6, 6, 6] + mlp_ratio: 4 + upsampler: 'pixelshuffle' + resi_connection: '1conv' + + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [50000, 100000, 200000, 300000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-2 + reduction: mean + perceptual_opt: + type: PerceptualLoss + layer_weights: + 'conv5_4': 1 # before relu + vgg_type: vgg19 + use_input_norm: true + range_norm: false + perceptual_weight: 1.0 + style_weight: 0 + criterion: l1 + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 5e-3 + + net_d_iters: 1 + net_d_init_iters: 0 + +# validation settings +val: + val_freq: !!float 5e3 + save_img: true + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: true + ssim: # metric name, can be arbitrary + type: calculate_ssim + crop_border: 4 + test_y_channel: true + lpips: # metric name, can be arbitrary + type: calculate_lpips + crop_border: 4 + test_y_channel: true + better: lower + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/setup.py b/setup.py index 998ba8591..9e31e945e 100644 --- a/setup.py +++ b/setup.py @@ -45,12 +45,13 @@ def _minimal_ext_cmd(cmd): def get_hash(): if os.path.exists('.git'): sha = get_git_hash()[:7] - elif os.path.exists(version_file): - try: - from basicsr.version import __version__ - sha = __version__.split('+')[-1] - except ImportError: - raise ImportError('Unable to get git version') + # currently ignore this + # elif os.path.exists(version_file): + # try: + # from basicsr.version import __version__ + # sha = __version__.split('+')[-1] + # except ImportError: + # raise ImportError('Unable to get git version') else: sha = 'unknown'