From fac754d2c30435d3ba974bed1927aff8892e77c5 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 5 Dec 2023 07:55:16 -0500 Subject: [PATCH 01/38] 6676 port generative networks autoencoderkl (#7260) Partially fixes #6676 ### Description Implements the AutoencoderKL network from MONAI Generative. NB this network is subject to a planned refactor once the porting is complete, [see here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 807 +++++++++++++++++++++++++++ tests/test_autoencoderkl.py | 276 +++++++++ 4 files changed, 1089 insertions(+) create mode 100644 monai/networks/nets/autoencoderkl.py create mode 100644 tests/test_autoencoderkl.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..dbfdf35784 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -595,6 +595,11 @@ Nets .. autoclass:: AutoEncoder :members: +`AutoEncoderKL` +~~~~~~~~~~~~~~~ +.. autoclass:: AutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..ea08246d25 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,6 +14,7 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..9a9f35d5ae --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,807 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Sequence +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution + +# To install xformers, use pip install xformers==0.0.16rc401 +from monai.utils import ensure_tuple_rep, optional_import + +xformers, has_xformers = optional_import("xformers") + +__all__ = ["AutoencoderKL"] + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + conv: torch.Tensor = self.conv(x) + return conv + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + x = self.conv(x) + return x + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class _ResBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Attention block. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks: List[nn.Module] = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(channels)) + + blocks: List[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + channels: number of output channels for each block. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpoint: if True, use activation checkpoint to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpoint: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = Decoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + self.use_checkpoint = use_checkpoint + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + if self.use_checkpoint: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) + + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu) + return reconstruction + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor + if self.use_checkpoint: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..448f1e8e9a --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,276 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL +from tests.utils import SkipIfBeforePyTorchVersion + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + +if __name__ == "__main__": + unittest.main() From b3fdfdd2111c5d1349a345fbd4e24c570d1fb690 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 6 Dec 2023 22:36:50 -0500 Subject: [PATCH 02/38] 6676 port generative networks vqvae (#7285) Partially fixes https://github.com/Project-MONAI/MONAI/issues/6676 ### Description Implements the VQ-VAE network, including the vector quantizer block, from MONAI Generative. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/networks.rst | 13 + monai/bundle/scripts.py | 2 +- monai/networks/layers/__init__.py | 1 + monai/networks/layers/vector_quantizer.py | 233 +++++++++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 14 +- monai/networks/nets/vqvae.py | 466 ++++++++++++++++++++++ tests/test_vector_quantizer.py | 89 +++++ tests/test_vqvae.py | 274 +++++++++++++ 9 files changed, 1085 insertions(+), 8 deletions(-) create mode 100644 monai/networks/layers/vector_quantizer.py create mode 100644 monai/networks/nets/vqvae.py create mode 100644 tests/test_vector_quantizer.py create mode 100644 tests/test_vqvae.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index dbfdf35784..d8be26264b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,7 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift + Layers ------ @@ -408,6 +409,13 @@ Layers .. autoclass:: LLTM :members: +`Vector Quantizer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer + :members: +.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -728,6 +736,11 @@ Nets .. autoclass:: voxelmorph +`VQ-VAE` +~~~~~~~~ +.. autoclass:: VQVAE + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 20a491e493..2565a3cf64 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -221,7 +221,7 @@ def _download_from_ngc( def _get_latest_bundle_version_monaihosting(name): url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name}" + full_url = f"{url}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..bd3e3af3af 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -37,4 +37,5 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py new file mode 100644 index 0000000000..9c354e1009 --- /dev/null +++ b/monai/networks/layers/vector_quantizer.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from torch import nn + +__all__ = ["VectorQuantizer", "EMAQuantizer"] + + +class EMAQuantizer(nn.Module): + """ + Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural + Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation + that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit + 58d9a2746493717a7c9252938da7efa6006f3739. + + This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due + to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 + on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + + Args: + spatial_dims: number of spatial dimensions of the input. + num_embeddings: number of atomic elements in the codebook. + embedding_dim: number of channels of the input and atomic elements. + commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. + decay: EMA decay. Defaults to 0.99. + epsilon: epsilon value. Defaults to 1e-5. + embedding_init: initialization method for the codebook. Defaults to "normal". + ddp_sync: whether to synchronize the codebook across processes. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors of shape [B, C, H, W, D]. + + Returns: + torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,H,W,D,1] + + """ + with torch.cuda.amp.autocast(enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. + """ + with torch.cuda.amp.autocast(enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. + """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ea08246d25..db3c77c717 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -113,3 +113,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 9a9f35d5ae..f7ae77f056 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -38,7 +38,7 @@ class _Upsample(nn.Module): Convolution-based upsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels to the layer. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -98,7 +98,7 @@ class _Downsample(nn.Module): Convolution-based downsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. """ @@ -132,7 +132,7 @@ class _ResBlock(nn.Module): residual connection between input and output. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: input channels to the layer. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. @@ -206,7 +206,7 @@ class _AttentionBlock(nn.Module): Attention block. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of @@ -325,7 +325,7 @@ class Encoder(nn.Module): Convolutional cascade that downsamples the image into a spatial latent space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. @@ -463,7 +463,7 @@ class Decoder(nn.Module): Convolutional cascade upsampling from a spatial latent space into an image space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. channels: sequence of block output channels. in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. @@ -611,7 +611,7 @@ class AutoencoderKL(nn.Module): and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py new file mode 100644 index 0000000000..d4771e203a --- /dev/null +++ b/monai/networks/nets/vqvae.py @@ -0,0 +1,466 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Tuple + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer +from monai.utils import ensure_tuple_rep + +__all__ = ["VQVAE"] + + +class VQVAEResidualUnit(nn.Module): + """ + Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving + Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). + + The original implementation that can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. + + Args: + spatial_dims: number of spatial spatial_dims of the input data. + in_channels: number of input channels. + num_res_channels: number of channels in the residual layers. + act: activation type and arguments. Defaults to RELU. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_channels: int, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_res_channels = num_res_channels + self.act = act + self.dropout = dropout + self.bias = bias + + self.conv1 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=self.num_res_channels, + adn_ordering="DA", + act=self.act, + dropout=self.dropout, + bias=self.bias, + ) + + self.conv2 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_res_channels, + out_channels=self.in_channels, + bias=self.bias, + conv_only=True, + ) + + def forward(self, x): + return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) + + +class Encoder(nn.Module): + """ + Encoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of channels in the latent space (embedding_dim). + channels: sequence containing the number of channels at each level of the encoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + dropout: dropout ratio. + act: activation type and arguments. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Tuple[int, int, int, int]], + dropout: float, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.dropout = dropout + self.act = act + + blocks: list[nn.Module] = [] + + for i in range(len(self.channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.channels[i - 1], + out_channels=self.channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + num_res_channels=self.num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.channels[len(self.channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + channels: sequence containing the number of channels at each level of the decoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Tuple[int, int, int, int, int]], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.channels)) + + blocks: list[nn.Module] = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=self.dropout if i != len(self.channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost. + decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Tuple[int, int, int, int]] + | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] + | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + use_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.channels = channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) + + if len(num_res_channels) != len(channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channls`." + ) + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) + else: + upsample_parameters_tuple = upsample_parameters + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) + else: + downsample_parameters_tuple = downsample_parameters + + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + # check if downsample_parameters is a tuple of ints or a tuple of tuples of ints + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + for parameter in downsample_parameters_tuple: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters_tuple: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters_tuple) != len(channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters_tuple) != len(channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters_tuple, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters_tuple, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) + + def encode(self, images: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + output = self.encoder(images) + return output + + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_loss, x = self.quantizer(encodings) + return x, x_loss + + def decode(self, quantizations: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + output = self.decoder(quantizations) + return output + + def index_quantize(self, images: torch.Tensor) -> torch.Tensor: + return self.quantizer.quantize(self.encode(images=images)) + + def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.decode(self.quantizer.embed(embedding_indices)) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + quantizations, quantization_losses = self.quantize(self.encode(images)) + reconstruction = self.decode(quantizations) + + return reconstruction, quantization_losses + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z = self.encode(x) + e, _ = self.quantize(z) + return e + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..43533d0377 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from math import prod + +import torch +from parameterized import parameterized + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + +TEST_CASES = [ + [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], + [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], +] + + +class TestEMA(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_ema_shape(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + @parameterized.expand(TEST_CASES) + def test_ema_quantize(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) + self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) + self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, output_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..4916dc2faa --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,274 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main() From c61c6ac2d56af08fbbfb955324b8639e266a25db Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 11 Dec 2023 11:10:20 -0500 Subject: [PATCH 03/38] 6676 port generative networks transformer (#7300) Towards #6676 . ### Description Adds a simple decoder-only transformer architecture. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/transformer.py | 314 +++++++++++++++++++++++++++++ tests/test_transformer.py | 73 +++++++ 4 files changed, 393 insertions(+) create mode 100644 monai/networks/nets/transformer.py create mode 100644 tests/test_transformer.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d8be26264b..06f60fe8af 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -613,6 +613,11 @@ Nets .. autoclass:: VarAutoEncoder :members: +`DecoderOnlyTransformer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DecoderOnlyTransformer + :members: + `ViT` ~~~~~ .. autoclass:: ViT diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index db3c77c717..08384b4d52 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,7 @@ from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py new file mode 100644 index 0000000000..b742c12205 --- /dev/null +++ b/monai/networks/nets/transformer.py @@ -0,0 +1,314 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.mlp import MLPBlock +from monai.utils import optional_import + +xops, has_xformers = optional_import("xformers.ops") +__all__ = ["DecoderOnlyTransformer"] + + +class _SABlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: dropout ratio. Defaults to no dropout. + qkv_bias: bias term for the qkv linear layer. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal + self.sequence_length = sequence_length + self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + # key, query, value projections + self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + # regularization + self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + + # output projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query = self.to_q(x) + + kv = context if context is not None else x + _, kv_t, _ = kv.size() + key = self.to_k(kv) + value = self.to_v(kv) + + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + y: torch.Tensor + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + + else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + + # manual implementation of attention + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) + + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side + + y = self.out_proj(y) + y = self.drop_output(y) + return y + + +class _TransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: apply bias term for the qkv linear layer + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + self.with_cross_attention = with_cross_attention + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=causal, + sequence_length=sequence_length, + use_flash_attention=use_flash_attention, + ) + + if self.with_cross_attention: + self.norm2 = nn.LayerNorm(hidden_size) + self.cross_attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + with_cross_attention=with_cross_attention, + causal=False, + use_flash_attention=use_flash_attention, + ) + self.norm3 = nn.LayerNorm(hidden_size) + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm2(x), context=context) + x = x + self.mlp(self.norm3(x)) + return x + + +class AbsolutePositionalEmbedding(nn.Module): + """Absolute positional embedding. + + Args: + max_seq_len: Maximum sequence length. + embedding_dim: Dimensionality of the embedding. + """ + + def __init__(self, max_seq_len: int, embedding_dim: int) -> None: + super().__init__() + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.embedding = nn.Embedding(max_seq_len, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = x.size() + positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) + embedding: torch.Tensor = self.embedding(positions) + return embedding + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. + embedding_dropout_rate: Dropout rate for the embedding. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + embedding_dropout_rate: float = 0.0, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + self.with_cross_attention = with_cross_attention + + self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) + self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) + self.embedding_dropout = nn.Dropout(embedding_dropout_rate) + + self.blocks = nn.ModuleList( + [ + _TransformerBlock( + hidden_size=attn_layers_dim, + mlp_dim=attn_layers_dim * 4, + num_heads=attn_layers_heads, + dropout_rate=0.0, + qkv_bias=False, + causal=True, + sequence_length=max_seq_len, + with_cross_attention=with_cross_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(attn_layers_depth) + ] + ) + + self.to_logits = nn.Linear(attn_layers_dim, num_tokens) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tok_emb = self.token_embeddings(x) + pos_emb = self.position_embeddings(x) + x = self.embedding_dropout(tok_emb + pos_emb) + + for block in self.blocks: + x = block(x, context=context) + logits: torch.Tensor = self.to_logits(x) + return logits diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..ea6ebdf50f --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer + +TEST_CASES = [] +for dropout_rate in np.linspace(0, 1, 2): + for attention_layer_dim in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + TEST_CASES.append( + [ + { + "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + } + ] + ) + + +class TestDecoderOnlyTransformer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_unconditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + @parameterized.expand(TEST_CASES) + def test_conditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 + ) + + def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + embedding_dropout_rate=-1, + ) + + +if __name__ == "__main__": + unittest.main() From de0a4760547eabe0337d9d9cf40fc90f6bb1cb59 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 02:09:37 -0500 Subject: [PATCH 04/38] 6676 port generative networks ddpm (#7304) Towards #6676 . ### Description Adds a DDPM unet. Refactoring for some of the blocks here is scheduled [here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/diffusion_model_unet.py | 2138 +++++++++++++++++++ tests/test_diffusion_model_unet.py | 535 +++++ 4 files changed, 2679 insertions(+) create mode 100644 monai/networks/nets/diffusion_model_unet.py create mode 100644 tests/test_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 06f60fe8af..417fb8ac73 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -583,6 +583,11 @@ Nets .. autoclass:: VNet :members: +`DiffusionModelUnet` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionModelUNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 08384b4d52..31fbd73b4e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -35,6 +35,7 @@ densenet201, densenet264, ) +from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py new file mode 100644 index 0000000000..1532215c70 --- /dev/null +++ b/monai/networks/nets/diffusion_model_unet.py @@ -0,0 +1,2138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 + +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["DiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class _CrossAttention(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + output: torch.Tensor = self.to_out(x) + return output + + +class _BasicTransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A basic Transformer block. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = _CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = _CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class _SpatialTransformer(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + _BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + output: torch.Tensor = self.op(x) + return output + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class _ResnetBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsampler: nn.Module | None + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] + + is_final_block = i == len(channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += [down_block_res_sample] + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output + + +class DiffusionModelEncoder(nn.Module): + """ + Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on + Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) # - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + for downsample_block in self.down_blocks: + h, _ = downsample_block(hidden_states=h, temb=emb, context=context) + + h = h.reshape(h.shape[0], -1) + output: torch.Tensor = self.out(h) + + return output diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py new file mode 100644 index 0000000000..d40a31a1da --- /dev/null +++ b/tests/test_diffusion_model_unet.py @@ -0,0 +1,535 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + def test_right_dropout(self, input_param): + _ = DiffusionModelUNet(**input_param) + + +if __name__ == "__main__": + unittest.main() From 43bc0230f8dd042924ccb6267317622fdffc695e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 22:37:51 -0500 Subject: [PATCH 05/38] 6676 port generative networks controlnet (#7312) Part of #6676 . ### Description Ports the ControlNet. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/controlnet.py | 421 ++++++++++++++++++++++++++++++ tests/test_controlnet.py | 177 +++++++++++++ 4 files changed, 604 insertions(+) create mode 100644 monai/networks/nets/controlnet.py create mode 100644 tests/test_controlnet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 417fb8ac73..0960fcdbc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -588,6 +588,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`ControlNet` +~~~~~~~~~~~~ +.. autoclass:: ControlNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 31fbd73b4e..58cb652bae 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -18,6 +18,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..d98755f401 --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,421 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. + """ + + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(channels) - 1): + channel_in = channels[i] + channel_out = channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError( + f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={channels} and norm_num_groups={norm_num_groups}" + ) + + if len(channels) != len(attention_levels): + raise ValueError( + f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"channels={channels} and attention_levels={attention_levels}" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + f"num_head_channels should have the same length as attention_levels, but got channels={channels} and " + f"attention_levels={attention_levels} . For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + channels=conditioning_embedding_num_channels, + out_channels=channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, H, W, [D]). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]) + conditioning_scale: conditioning scale. + context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init. + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = [] + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..07dfa2e49b --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 4, + }, + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "num_head_channels": 4, + "attention_levels": (False, False, False), + "norm_num_groups": 4, + "resblock_updown": True, + }, + (1, 4, 4, 4, 4), + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + (1, 8, 4, 4), + ], +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @parameterized.expand(COND_CASES_2D) + def test_shape_conditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + +if __name__ == "__main__": + unittest.main() From b85a534a32a9d969aba7b6ba752d1f2486c42177 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 14 Dec 2023 06:06:59 -0500 Subject: [PATCH 06/38] Adds patchgan discriminator (#7319) Part of #6676 . ### Description Adds a patchgan-style discriminator, both single scale and multiscale. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 8 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/patchgan_discriminator.py | 247 ++++++++++++++++++ tests/test_patch_gan_dicriminator.py | 179 +++++++++++++ 4 files changed, 435 insertions(+) create mode 100644 monai/networks/nets/patchgan_discriminator.py create mode 100644 tests/test_patch_gan_dicriminator.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 0960fcdbc0..8e79298941 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -756,6 +756,14 @@ Nets .. autoclass:: VQVAE :members: +`PatchGANDiscriminator` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchDiscriminator + :members: + +.. autoclass:: MultiScalePatchDiscriminator + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 58cb652bae..0f0d033d63 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -55,6 +55,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py new file mode 100644 index 0000000000..3b089616ce --- /dev/null +++ b/monai/networks/nets/patchgan_discriminator.py @@ -0,0 +1,247 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act + + +class MultiScalePatchDiscriminator(nn.Sequential): + """ + Multi-scale Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images + at different spatial scales. + + Args: + num_d: number of discriminators + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first + discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels in each discriminator + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + dropout: probability of dropout applied, defaults to 0. + minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture + requested isn't going to downsample the input image beyond value of 1. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + num_d: int, + num_layers_d: int, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + dropout: float | tuple = 0.0, + minimum_size_im: int = 256, + last_conv_kernel_size: int = 1, + ) -> None: + super().__init__() + self.num_d = num_d + self.num_layers_d = num_layers_d + self.num_channels = channels + self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) + for i_ in range(self.num_d): + num_layers_d_i = self.num_layers_d * (i_ + 1) + output_size = float(minimum_size_im) / (2**num_layers_d_i) + if output_size < 1: + raise AssertionError( + f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." + "Please reduce num_layers, reduce num_D or enter bigger images." + ) + subnet_d = PatchDiscriminator( + spatial_dims=spatial_dims, + channels=self.num_channels, + in_channels=in_channels, + out_channels=out_channels, + num_layers_d=num_layers_d_i, + kernel_size=kernel_size, + activation=activation, + norm=norm, + bias=bias, + padding=self.padding, + dropout=dropout, + last_conv_kernel_size=last_conv_kernel_size, + ) + + self.add_module("discriminator_%d" % i_, subnet_d) + + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: + """ + Args: + i: Input tensor + + Returns: + list of outputs and another list of lists with the intermediate features + of each discriminator. + """ + + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] + for disc in self.children(): + out_d: list[torch.Tensor] = disc(i) + out.append(out_d[-1]) + intermediate_features.append(out_d[:-1]) + + return out, intermediate_features + + +class PatchDiscriminator(nn.Sequential): + """ + Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + + Args: + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. + kernel_size: kernel size of the convolution layers + act: activation type and arguments. Defaults to LeakyReLU. + norm: feature normalization type and arguments. Defaults to batch norm. + bias: whether to have a bias term in convolution blocks. Defaults to False. + padding: padding to be applied to the convolutional layers + dropout: proportion of dropout applied, defaults to 0. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + num_layers_d: int = 3, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, + ) -> None: + super().__init__() + self.num_layers_d = num_layers_d + self.num_channels = channels + if last_conv_kernel_size is None: + last_conv_kernel_size = kernel_size + + self.add_module( + "initial_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=in_channels, + out_channels=channels, + act=activation, + bias=True, + norm=None, + dropout=dropout, + padding=padding, + strides=2, + ), + ) + + input_channels = channels + output_channels = channels * 2 + + # Initial Layer + for l_ in range(self.num_layers_d): + if l_ == self.num_layers_d - 1: + stride = 1 + else: + stride = 2 + layer = Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + act=activation, + bias=bias, + norm=norm, + dropout=dropout, + padding=padding, + strides=stride, + ) + self.add_module("%d" % l_, layer) + input_channels = output_channels + output_channels = output_channels * 2 + + # Final layer + self.add_module( + "final_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=last_conv_kernel_size, + in_channels=input_channels, + out_channels=out_channels, + bias=True, + conv_only=True, + padding=int((last_conv_kernel_size - 1) / 2), + dropout=0.0, + strides=1, + ), + ) + + self.apply(self.initialise_weights) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + x: input tensor + + Returns: + list of intermediate features, with the last element being the output. + """ + out = [x] + for submodel in self.children(): + intermediate_output = submodel(out[-1]) + out.append(intermediate_output) + + return out[1:] + + def initialise_weights(self, m: nn.Module) -> None: + """ + Initialise weights of Convolution and BatchNorm layers. + + Args: + m: instance of torch.nn.module (or of class inheriting torch.nn.module) + """ + classname = m.__class__.__name__ + if classname.find("Conv2d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv3d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv1d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py new file mode 100644 index 0000000000..c19898e70d --- /dev/null +++ b/tests/test_patch_gan_dicriminator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator +from tests.utils import test_script_save + +TEST_PATCHGAN = [ + [ + { + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512]), + (1, 8, 128, 256), + (1, 1, 32, 64), + ], + [ + { + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512, 256]), + (1, 8, 128, 256, 128), + (1, 1, 32, 64, 32), + ], +] + +TEST_MULTISCALE_PATCHGAN = [ + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], + ], + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], + ], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(TEST_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): + net = PatchDiscriminator(**input_param) + with eval_mode(net): + result = net.forward(input_data) + self.assertEqual(tuple(result[0].shape), expected_shape_feature) + self.assertEqual(tuple(result[-1].shape), expected_shape_output) + + def test_script(self): + net = PatchDiscriminator( + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +class TestMultiscalePatchGAN(unittest.TestCase): + @parameterized.expand(TEST_MULTISCALE_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() From aa4a4dbd3619dd443681c4688bd48ce1bea9b85d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 19 Dec 2023 03:21:32 +0000 Subject: [PATCH 07/38] 6676 port generative networks spade (#7320) Towards #6676 . ### Description This adds SPADE-enabled autoencoder and diffusion_model_unet architectures. They are new implementations for each network, rather than options in the existing network, because @virginiafdez and I felt that adding additional options to the existing networks to enable spade compatibility significantly reduced the readability of them for users who were not interested in SPADE functionality. These are the last networks to be ported over. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/networks.rst | 14 + monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/spade_norm.py | 96 ++ monai/networks/nets/__init__.py | 2 + monai/networks/nets/spade_autoencoderkl.py | 473 +++++++++ .../nets/spade_diffusion_model_unet.py | 908 ++++++++++++++++++ test_spade_autoencoderkl.py | 260 +++++ test_spade_diffusion_model_unet.py | 558 +++++++++++ 8 files changed, 2312 insertions(+) create mode 100644 monai/networks/blocks/spade_norm.py create mode 100644 monai/networks/nets/spade_autoencoderkl.py create mode 100644 monai/networks/nets/spade_diffusion_model_unet.py create mode 100644 test_spade_autoencoderkl.py create mode 100644 test_spade_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8e79298941..79d5ef822e 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,10 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift +`SPADE` +~~~~~~~ +.. autoclass:: monai.networks.blocks.spade_norm.SPADE + :members: Layers ------ @@ -588,6 +592,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`SPADEDiffusionModelUNet` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEDiffusionModelUNet + :members: + `ControlNet` ~~~~~~~~~~~~ .. autoclass:: ControlNet @@ -618,6 +627,11 @@ Nets .. autoclass:: AutoencoderKL :members: +`SPADEAutoencoderKL` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEAutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index e67cb3376f..afb6664bd9 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -30,6 +30,7 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .spade_norm import SPADE from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py new file mode 100644 index 0000000000..b1046f3154 --- /dev/null +++ b/monai/networks/blocks/spade_norm.py @@ -0,0 +1,96 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ADN, Convolution + + +class SPADE(nn.Module): + """ + Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a + semantic map. This block is used in SPADE-based image-to-image translation models, as described in + Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291). + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict | None = None, + ) -> None: + super().__init__() + + if norm_params is None: + norm_params = {} + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels. + segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out: torch.Tensor = normalized * (1 + gamma) + beta + return out diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 0f0d033d63..a7ce16ad64 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,8 @@ seresnext50, seresnext101, ) +from .spade_autoencoderkl import SPADEAutoencoderKL +from .spade_diffusion_model_unet import SPADEDiffusionModelUNet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py new file mode 100644 index 0000000000..e064c19740 --- /dev/null +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -0,0 +1,473 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample +from monai.utils import ensure_tuple_rep + +__all__ = ["SPADEAutoencoderKL"] + + +class SPADEResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h, seg) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class SPADEDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from channels contain an attention block. + label_nc: number of semantic channels for SPADE normalisation. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + label_nc: int, + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.label_nc = label_nc + + reversed_block_out_channels = list(reversed(channels)) + + blocks: list[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + if isinstance(block, SPADEResBlock): + x = block(x, seg) + else: + x = block(x) + return x + + +class SPADEAutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + label_nc: number of semantic channels for SPADE normalisation. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + label_nc: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + label_nc=label_nc, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + h = self.encoder(x) + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu, seg) + return reconstruction + + def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor = self.decoder(z, seg) + return dec + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z, seg) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + image = self.decode(z, seg) + return image diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py new file mode 100644 index 0000000000..d53327100e --- /dev/null +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -0,0 +1,908 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.diffusion_model_unet import ( + _AttentionBlock, + _Downsample, + _ResnetBlock, + _SpatialTransformer, + _Upsample, + get_down_block, + get_mid_block, + get_timestep_embedding, + zero_module, +) +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["SPADEDiffusionModelUNet"] + + +class SPADEResnetBlock(nn.Module): + """ + Residual block with timestep conditioning and SPADE norm. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + label_nc: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=self.out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h, seg) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class SPADEUpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADEAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADECrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor | None = None, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_spade_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + label_nc: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, +) -> nn.Module: + if with_attn: + return SPADEAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + elif with_cross_attn: + return SPADECrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + else: + return SPADEUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + spade_intermediate_channels=spade_intermediate_channels, + ) + + +class SPADEDiffusionModelUNet(nn.Module): + """ + UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for + semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at + https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + self.label_nc = label_nc + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_spade_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [h] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output diff --git a/test_spade_autoencoderkl.py b/test_spade_autoencoderkl.py new file mode 100644 index 0000000000..6675a6db67 --- /dev/null +++ b/test_spade_autoencoderkl.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEAutoencoderKL + +CASES = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "spade_intermediate_channels": 32, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], +] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class TestSPADEAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_encode(self): + input_param, input_shape, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_wrong_shape_decode(self): + net = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, False), + num_res_blocks=1, + norm_num_groups=4, + ) + with self.assertRaises(RuntimeError): + _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test_spade_diffusion_model_unet.py b/test_spade_diffusion_model_unet.py new file mode 100644 index 0000000000..c8a2103cf6 --- /dev/null +++ b/test_spade_diffusion_model_unet.py @@ -0,0 +1,558 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEDiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + "spade_intermediate_channels": 256, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + "label_nc": 3, + } + ], +] + + +class TestSPADEDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) + ) + + def test_label_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(RuntimeError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, 3, 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 16, 16)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() From 3447b09435e72e856a4b29b436c2fbe61159a42f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 3 Jan 2024 16:30:24 +0000 Subject: [PATCH 08/38] 6676 port diffusion schedulers (#7332) Towards #6676 . ### Description This adds some base classes for DDPM noise schedulers + three scheduler types. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 20 ++ monai/networks/schedulers/__init__.py | 17 ++ monai/networks/schedulers/ddim.py | 284 ++++++++++++++++++++++ monai/networks/schedulers/ddpm.py | 243 +++++++++++++++++++ monai/networks/schedulers/pndm.py | 316 +++++++++++++++++++++++++ monai/networks/schedulers/scheduler.py | 203 ++++++++++++++++ monai/utils/misc.py | 4 +- tests/test_scheduler_ddim.py | 83 +++++++ tests/test_scheduler_ddpm.py | 104 ++++++++ tests/test_scheduler_pndm.py | 108 +++++++++ 10 files changed, 1380 insertions(+), 2 deletions(-) create mode 100644 monai/networks/schedulers/__init__.py create mode 100644 monai/networks/schedulers/ddim.py create mode 100644 monai/networks/schedulers/ddpm.py create mode 100644 monai/networks/schedulers/pndm.py create mode 100644 monai/networks/schedulers/scheduler.py create mode 100644 tests/test_scheduler_ddim.py create mode 100644 tests/test_scheduler_ddpm.py create mode 100644 tests/test_scheduler_pndm.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 79d5ef822e..f9375f1e97 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -778,6 +778,26 @@ Nets .. autoclass:: MultiScalePatchDiscriminator :members: +Diffusion Schedulers +-------------------- +.. autoclass:: monai.networks.schedulers.Scheduler + :members: + +`DDPM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDPMScheduler + :members: + +`DDIM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDIMScheduler + :members: + +`PNDM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.PNDMScheduler + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py new file mode 100644 index 0000000000..29e9020d65 --- /dev/null +++ b/monai/networks/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .ddim import DDIMScheduler +from .ddpm import DDPMScheduler +from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py new file mode 100644 index 0000000000..ec47ff8dc6 --- /dev/null +++ b/monai/networks/schedulers/ddim.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDIMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = DDIMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") + + self.prediction_type = prediction_type + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) + + self.clip_sample = clip_sample + self.steps_offset = steps_offset + + # default the number of inference timesteps to the number of train steps + self.num_inference_steps: int + self.set_timesteps(self.num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + eta: weight of noise for added noise in diffusion step. + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**0.5 + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon + + # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample + + def reversed_step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_post_sample -> "x_t+1" + + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas at timestep t+1 + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py new file mode 100644 index 0000000000..a5173a1b65 --- /dev/null +++ b/monai/networks/schedulers/ddpm.py @@ -0,0 +1,243 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. + predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py new file mode 100644 index 0000000000..c0728bbdff --- /dev/null +++ b/monai/networks/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., + "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + skip_prk_steps: + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms step. + set_alpha_to_one: + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type: member of DDPMPredictionType + steps_offset: + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = PNDMPredictionType.EPSILON, + steps_offset: int = 0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") + + self.prediction_type = prediction_type + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + """ + # return a tuple for consistency with samplers that return (previous pred, original sample pred) + + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None + + def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output = 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = torch.Tensor() + + # cur_sample should not be an empty torch.Tensor() + cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample + + prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + ) + + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = torch.Tensor() + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.prediction_type == PNDMPredictionType.V_PREDICTION: + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py new file mode 100644 index 0000000000..17bb526abc --- /dev/null +++ b/monai/networks/schedulers/scheduler.py @@ -0,0 +1,203 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..4f2501a7ee 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py new file mode 100644 index 0000000000..1a8f8cab67 --- /dev/null +++ b/tests/test_scheduler_ddim.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDIMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py new file mode 100644 index 0000000000..f0447aded2 --- /dev/null +++ b/tests/test_scheduler_ddpm.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDPMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_2D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_3D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDPMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + + def test_step_learned(self): + for variance_type in ["learned", "learned_range"]: + scheduler = DDPMScheduler(variance_type=variance_type) + model_output = torch.randn(2, 6, 16, 16) + sample = torch.randn(2, 3, 16, 16) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, sample.shape) + self.assertEqual(output_step[1].shape, sample.shape) + + def test_set_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py new file mode 100644 index 0000000000..69e5e403f5 --- /dev/null +++ b/tests/test_scheduler_pndm.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import PNDMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [ + {"schedule": "linear_beta"}, + (1, 1, 2, 2), + torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]), + ] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(600) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1], None) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_FULl_LOOP) + def test_timestep_two_loops(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + torch.manual_seed(42) + model_output2 = torch.randn(input_shape) + sample2 = torch.randn(input_shape) + scheduler.set_timesteps(50) + for t in range(50): + sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) + assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_prk(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 109) + self.assertEqual(len(scheduler.timesteps), 109) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() From 3ab5c62b9e0964d129980f7970aabefa9d0e2d2f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 5 Jan 2024 06:52:15 +0000 Subject: [PATCH 09/38] 6676 port diffusion schedulers (#7364) This is an update to PR https://github.com/Project-MONAI/MONAI/pull/7332 - I addressed the comments but failed to push the changes before it was merged! Changes are very minor. ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/networks/schedulers/ddim.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index ec47ff8dc6..78e3cc2a0c 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -34,23 +34,10 @@ import numpy as np import torch -from monai.utils import StrEnum - +from .ddpm import DDPMPredictionType from .scheduler import Scheduler - -class DDIMPredictionType(StrEnum): - """ - Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. - - epsilon: predicting the noise of the diffusion process - sample: directly predicting the noisy sample - v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf - """ - - EPSILON = "epsilon" - SAMPLE = "sample" - V_PREDICTION = "v_prediction" +DDIMPredictionType = DDPMPredictionType class DDIMScheduler(Scheduler): @@ -126,6 +113,13 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps + if self.steps_offset >= step_ratio: + raise ValueError( + f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " + f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" + f" the max train timestep." + ) + # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) @@ -159,7 +153,6 @@ def step( timestep: current discrete timestep in the diffusion chain. sample: current instance of sample being created by diffusion process. eta: weight of noise for added noise in diffusion step. - predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. Returns: From 0a549fe937ab63b86ddc741dec8cddcd1085db4d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 8 Jan 2024 14:56:58 +0000 Subject: [PATCH 10/38] Adds ordering util (#7369) Towards #6676 . ### Description This ordering util got missed out my previous PR for the Generative utils. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/utils.rst | 5 + monai/utils/ordering.py | 207 ++++++++++++++++++++++++++ tests/test_ordering.py | 318 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 530 insertions(+) create mode 100644 monai/utils/ordering.py create mode 100644 tests/test_ordering.py diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 527247799f..fef671e1f8 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -81,3 +81,8 @@ Component store --------------- .. autoclass:: monai.utils.component_store.ComponentStore :members: + +Ordering +-------- +.. automodule:: monai.utils.ordering + :members: diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py new file mode 100644 index 0000000000..1be61f98ab --- /dev/null +++ b/monai/utils/ordering.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + +from monai.utils.enums import OrderingTransformations, OrderingType + + +class Ordering: + """ + Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with + one of the following transformations: + Reflection (see np.flip for more details). + Transposition (see np.transpose for more details). + 90-degree rotation (see np.rot90 for more details). + + The transformations are applied in the order specified by the transformation_order parameter. + + Args: + ordering_type: The ordering type. One of the following: + - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from + top to bottom. Also called a row major ordering. + - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like + pattern from top left towards right gowing in a spiral towards the center. + - random': The image is projected into a 1D sequence by randomly shuffling the image. + spatial_dims: The number of spatial dimensions of the image. + dimensions: The dimensions of the image. + reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. + transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. + rot90_axes: A tuple of tuples indicating the axes to rotate the image along. + transformation_order: The order in which to apply the transformations. + """ + + def __init__( + self, + ordering_type: str, + spatial_dims: int, + dimensions: tuple[int, int, int] | tuple[int, int, int, int], + reflected_spatial_dims: tuple[bool, bool] | None = None, + transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None, + rot90_axes: tuple[tuple[int, int], ...] | None = None, + transformation_order: tuple[str, ...] = ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + ) -> None: + super().__init__() + self.ordering_type = ordering_type + + if self.ordering_type not in list(OrderingType): + raise ValueError( + f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." + ) + + self.spatial_dims = spatial_dims + self.dimensions = dimensions + + if len(dimensions) != self.spatial_dims + 1: + raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") + + self.reflected_spatial_dims = reflected_spatial_dims + self.transpositions_axes = transpositions_axes + self.rot90_axes = rot90_axes + if len(set(transformation_order)) != len(transformation_order): + raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") + + for transformation in transformation_order: + if transformation not in list(OrderingTransformations): + raise ValueError( + f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." + ) + self.transformation_order = transformation_order + + self.template = self._create_template() + self._sequence_ordering = self._create_ordering() + self._revert_sequence_ordering = np.argsort(self._sequence_ordering) + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = x[self._sequence_ordering] + + return x + + def get_sequence_ordering(self) -> np.ndarray: + return self._sequence_ordering + + def get_revert_sequence_ordering(self) -> np.ndarray: + return self._revert_sequence_ordering + + def _create_ordering(self) -> np.ndarray: + self.template = self._transform_template() + order = self._order_template(template=self.template) + + return order + + def _create_template(self) -> np.ndarray: + spatial_dimensions = self.dimensions[1:] + template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) + + return template + + def _transform_template(self) -> np.ndarray: + for transformation in self.transformation_order: + if transformation == OrderingTransformations.TRANSPOSE.value: + self.template = self._transpose_template(template=self.template) + elif transformation == OrderingTransformations.ROTATE_90.value: + self.template = self._rot90_template(template=self.template) + elif transformation == OrderingTransformations.REFLECT.value: + self.template = self._flip_template(template=self.template) + + return self.template + + def _transpose_template(self, template: np.ndarray) -> np.ndarray: + if self.transpositions_axes is not None: + for axes in self.transpositions_axes: + template = np.transpose(template, axes=axes) + + return template + + def _flip_template(self, template: np.ndarray) -> np.ndarray: + if self.reflected_spatial_dims is not None: + for axis, to_reflect in enumerate(self.reflected_spatial_dims): + template = np.flip(template, axis=axis) if to_reflect else template + + return template + + def _rot90_template(self, template: np.ndarray) -> np.ndarray: + if self.rot90_axes is not None: + for axes in self.rot90_axes: + template = np.rot90(template, axes=axes) + + return template + + def _order_template(self, template: np.ndarray) -> np.ndarray: + depths = None + if self.spatial_dims == 2: + rows, columns = template.shape[0], template.shape[1] + else: + rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) + + sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + + ordering = np.array([template[tuple(e)] for e in sequence]) + + return ordering + + @staticmethod + def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths is not None: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) + for c in col_idx: + if depths: + depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) + + for d in depth_idx: + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + np.random.shuffle(idx_np) + + return idx_np diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000000..0c52dba5e5 --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,318 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.utils.enums import OrderingTransformations, OrderingType +from monai.utils.ordering import Ordering + +TEST_2D_NON_RANDOM = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 0, 1], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 1, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 1, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 3, 1], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 0, 2], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 2, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], +] + +TEST_2D_RANDOM = [ + [ + { + "ordering_type": OrderingType.RANDOM, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [[0, 1, 2, 3], [0, 1, 3, 2]], + ] +] + +TEST_3D = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 3, + "dimensions": (1, 2, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3, 4, 5, 6, 7], + ] +] + +TEST_ORDERING_TYPE_FAILURE = [ + [ + { + "ordering_type": "hilbert", + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + +TEST_ORDERING_TRANSFORMATION_FAILURE = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + "flip", + ), + } + ] +] + +TEST_REVERT = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + + +class TestOrdering(unittest.TestCase): + @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) + def test_ordering(self, input_param, expected_sequence_ordering): + ordering = Ordering(**input_param) + self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) + + @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) + def test_ordering_type_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) + def test_ordering_transformation_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_2D_RANDOM) + def test_random(self, input_param, not_in_expected_sequence_ordering): + ordering = Ordering(**input_param) + + not_in = [ + np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) + for sequence in not_in_expected_sequence_ordering + ] + + self.assertFalse(np.any(not_in)) + + @parameterized.expand(TEST_REVERT) + def test_revert(self, input_param): + sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() + + ordering = Ordering(**input_param) + + reverted_sequence = sequence[ordering.get_sequence_ordering()] + reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] + + self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() From 510f7bc1eb4505d61f9aec6a9a96c444051a3e45 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 18 Jan 2024 12:38:46 +0000 Subject: [PATCH 11/38] 6676 port generative inferers (#7379) Part of #6676 . ### Description Adds Inferers to assist with training and sampling from diffusion models and controllers. Also takes the opportunity to make two changes which slipped through the previous PRs: - rename the `num_channels` arg in the spade diffusion unet to `channels` to be consistent with all the other models added from Generative - this slipped through in the networks PR. - add the `Ordering` class to `__init__.py` for easier import ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/inferers.rst | 23 + monai/inferers/__init__.py | 5 + monai/inferers/inferer.py | 1280 ++++++++++++++++- monai/networks/nets/diffusion_model_unet.py | 6 +- .../nets/spade_diffusion_model_unet.py | 40 +- monai/utils/__init__.py | 1 + setup.cfg | 14 +- tests/test_controlnet_inferers.py | 1270 ++++++++++++++++ tests/test_diffusion_inferer.py | 226 +++ tests/test_flexible_unet.py | 2 +- tests/test_invertd.py | 12 +- tests/test_latent_diffusion_inferer.py | 796 ++++++++++ tests/test_ordering.py | 29 - .../test_spade_autoencoderkl.py | 0 .../test_spade_diffusion_model_unet.py | 66 +- tests/test_vqvaetransformer_inferer.py | 284 ++++ 16 files changed, 3955 insertions(+), 99 deletions(-) create mode 100644 tests/test_controlnet_inferers.py create mode 100644 tests/test_diffusion_inferer.py create mode 100644 tests/test_latent_diffusion_inferer.py rename test_spade_autoencoderkl.py => tests/test_spade_autoencoderkl.py (100%) rename test_spade_diffusion_model_unet.py => tests/test_spade_diffusion_model_unet.py (92%) create mode 100644 tests/test_vqvaetransformer_inferer.py diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 33f9e14d83..326f56e96c 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -49,6 +49,29 @@ Inferers :members: :special-members: __call__ +`DiffusionInferer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionInferer + :members: + :special-members: __call__ + +`LatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LatentDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetLatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetLatentDiffusionInferer + :members: + :special-members: __call__ Splitters --------- diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 960380bfb8..fc78b9f7c4 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -12,13 +12,18 @@ from __future__ import annotations from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, Inferer, + LatentDiffusionInferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, + VQVAETransformerInferer, ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0b4199938d..72bcb8fd5a 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,24 +11,41 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import partial from pydoc import locate from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from monai.apps.utils import get_logger +from monai.data import decollate_batch from monai.data.meta_tensor import MetaTensor from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DecoderOnlyTransformer, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import Scheduler +from monai.transforms import CenterSpatialCrop, SpatialPad +from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + logger = get_logger(__name__) __all__ = [ @@ -752,3 +769,1264 @@ def network_wrapper( return out return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] + super().__init__() + + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + if condition is None: + raise ValueError("Conditioning is required for concat condition") + else: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) + + # 2. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + if inputs.shape != means.shape: + raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}") + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + prediction: torch.Tensor = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + seg=seg, + ) + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and" + f"{diffusion_model.label_nc}" + ) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class ControlNetDiffusionInferer(DiffusionInferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) + if mode == "concat" and condition is not None: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + prediction: torch.Tensor = diffuse( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + ) + # 2. predict noise model_output + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # 3. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + ) + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -super()._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode, + seg=seg, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}" + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(nn.Module): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()) + else: + start = 0 + prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + if return_latent: + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: tuple[int, int, int] | tuple[int, int], + starting_tokens: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 1532215c70..0441cc9cfe 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: batch, channel, height, width, depth = x.shape # norm - x = self.norm(x) + x = self.norm(x.contiguous()) if self.spatial_dims == 2: x = x.view(batch, channel, height * width).transpose(1, 2) @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x + h = x.contiguous() h = self.norm1(h) h = self.nonlinearity(h) @@ -1957,7 +1957,7 @@ def forward( h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) # 7. output block - output: torch.Tensor = self.out(h) + output: torch.Tensor = self.out(h.contiguous()) return output diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index d53327100e..bffc9c5465 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -618,7 +618,7 @@ class SPADEDiffusionModelUNet(nn.Module): out_channels: number of output channels. label_nc: number of semantic channels for SPADE normalisation. num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. + channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. @@ -641,7 +641,7 @@ def __init__( out_channels: int, label_nc: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, @@ -667,10 +667,10 @@ def __init__( ) # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): @@ -683,9 +683,9 @@ def __init__( ) if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." @@ -700,7 +700,7 @@ def __init__( ) self.in_channels = in_channels - self.block_out_channels = num_channels + self.block_out_channels = channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels @@ -712,7 +712,7 @@ def __init__( self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -720,9 +720,9 @@ def __init__( ) # time - time_embed_dim = num_channels[0] * 4 + time_embed_dim = channels[0] * 4 self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding @@ -732,11 +732,11 @@ def __init__( # down self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): + output_channel = channels[0] + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, @@ -762,7 +762,7 @@ def __init__( # mid self.middle_block = get_mid_block( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, norm_eps=norm_eps, @@ -776,7 +776,7 @@ def __init__( # up self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) + reversed_block_out_channels = list(reversed(channels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) @@ -784,9 +784,9 @@ def __init__( for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] - is_final_block = i == len(num_channels) - 1 + is_final_block = i == len(channels) - 1 up_block = get_spade_up_block( spatial_dims=spatial_dims, @@ -814,12 +814,12 @@ def __init__( # out self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), nn.SiLU(), zero_module( Convolution( spatial_dims=spatial_dims, - in_channels=num_channels[0], + in_channels=channels[0], out_channels=out_channels, strides=1, kernel_size=3, diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -126,6 +126,7 @@ version_leq, ) from .nvtx import Range +from .ordering import Ordering from .profiling import ( PerfContext, ProfileHandler, diff --git a/setup.cfg b/setup.cfg index 123da68dfa..0069214de3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ all = scipy>=1.7.1 pillow tensorboard - gdown>=4.4.0 + gdown==4.6.3 pytorch-ignite==0.4.11 torchvision itk>=5.2 @@ -60,12 +60,12 @@ all = lmdb psutil cucim>=23.2.0 - openslide-python==1.1.2 + openslide-python tifffile imagecodecs pandas einops - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib @@ -97,7 +97,7 @@ pillow = tensorboard = tensorboard gdown = - gdown>=4.4.0 + gdown==4.6.3 ignite = pytorch-ignite==0.4.11 torchvision = @@ -113,7 +113,7 @@ psutil = cucim = cucim>=23.2.0 openslide = - openslide-python==1.1.2 + openslide-python tifffile = tifffile imagecodecs = @@ -123,7 +123,7 @@ pandas = einops = einops transformers = - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow = mlflow matplotlib = @@ -173,6 +173,7 @@ max_line_length = 120 # B028 https://github.com/Project-MONAI/MONAI/issues/5855 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 +# B036 https://github.com/Project-MONAI/MONAI/issues/7396 ignore = E203 E501 @@ -186,6 +187,7 @@ ignore = B028 B907 B908 + B036 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 0000000000..1f675537dc --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1270 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_intermediates( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_get_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_resample_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py new file mode 100644 index 0000000000..ecd4855385 --- /dev/null +++ b/tests/test_diffusion_inferer.py @@ -0,0 +1,226 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 1218ce6e85..1d831f0976 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -39,7 +39,7 @@ class DummyEncoder(BaseEncoder): def get_encoder_parameters(cls): basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False} param_dict_list = [basic_dict] - for key in basic_dict: + for key in basic_dict.keys(): cur_dict = basic_dict.copy() del cur_dict[key] param_dict_list.append(cur_dict) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index cd2e91257a..2e6ee35981 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -112,15 +112,15 @@ def test_invert(self): self.assertTupleEqual(i.shape[1:], (101, 100, 107)) # check the case that different items use different interpolation mode to invert transforms - d = item["image_inverted1"] + j = item["image_inverted1"] # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(j.shape, (1, 101, 100, 107)) - d = item["label_inverted1"] + k = item["label_inverted1"] # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertGreater(torch.sum(k.to(torch.float) - k.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(k.shape, (1, 101, 100, 107)) # check labels match reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py new file mode 100644 index 0000000000..4ab803bb6f --- /dev/null +++ b/tests/test_latent_diffusion_inferer.py @@ -0,0 +1,796 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInferer +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + save_intermediates=True, + intermediate_steps=1, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 0c52dba5e5..e6b235e179 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -182,24 +182,6 @@ ], ] -TEST_2D_RANDOM = [ - [ - { - "ordering_type": OrderingType.RANDOM, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [[0, 1, 2, 3], [0, 1, 3, 2]], - ] -] TEST_3D = [ [ @@ -291,17 +273,6 @@ def test_ordering_transformation_failure(self, input_param): with self.assertRaises(ValueError): Ordering(**input_param) - @parameterized.expand(TEST_2D_RANDOM) - def test_random(self, input_param, not_in_expected_sequence_ordering): - ordering = Ordering(**input_param) - - not_in = [ - np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) - for sequence in not_in_expected_sequence_ordering - ] - - self.assertFalse(np.any(not_in)) - @parameterized.expand(TEST_REVERT) def test_revert(self, input_param): sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() diff --git a/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py similarity index 100% rename from test_spade_autoencoderkl.py rename to tests/test_spade_autoencoderkl.py diff --git a/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py similarity index 92% rename from test_spade_diffusion_model_unet.py rename to tests/test_spade_diffusion_model_unet.py index c8a2103cf6..113e58ed89 100644 --- a/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -26,7 +26,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -38,7 +38,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": (1, 1, 2), - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -50,7 +50,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -63,7 +63,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -76,7 +76,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -90,7 +90,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -103,7 +103,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, True, True), "num_head_channels": (0, 2, 4), "norm_num_groups": 8, @@ -119,7 +119,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -132,7 +132,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -144,7 +144,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -157,7 +157,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -170,7 +170,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -184,7 +184,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -197,7 +197,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": (0, 0, 4), "norm_num_groups": 8, @@ -213,7 +213,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -229,7 +229,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -246,7 +246,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -279,7 +279,7 @@ def test_timestep_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -296,7 +296,7 @@ def test_label_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -313,7 +313,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -331,7 +331,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 12), + channels=(8, 8, 12), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -344,13 +344,13 @@ def test_attention_levels_with_different_length_num_head_channels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), num_head_channels=(0, 2), norm_num_groups=8, ) - def test_num_res_blocks_with_different_length_num_channels(self): + def test_num_res_blocks_with_different_length_channels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -358,7 +358,7 @@ def test_num_res_blocks_with_different_length_num_channels(self): in_channels=1, out_channels=1, num_res_blocks=(1, 1), - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -370,7 +370,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -395,7 +395,7 @@ def test_with_conditioning_cross_attention_dim_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -410,7 +410,7 @@ def test_context_with_conditioning_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=False, transformer_num_layers=1, @@ -433,7 +433,7 @@ def test_shape_conditioned_models_class_conditioning(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -455,7 +455,7 @@ def test_conditioned_models_no_class_labels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -469,7 +469,7 @@ def test_conditioned_models_no_class_labels(self): seg=torch.rand((1, 3, 16, 32)), ) - def test_model_num_channels_not_same_size_of_attention_levels(self): + def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -477,7 +477,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False), norm_num_groups=8, num_head_channels=8, @@ -518,7 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=4, ) @@ -537,7 +537,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(16, 16, 16), + channels=(16, 16, 16), attention_levels=(False, False, True), norm_num_groups=16, with_conditioning=True, diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 0000000000..1a511d287b --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import VQVAETransformerInferer +from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils.ordering import Ordering, OrderingType + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 4, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, + (2, 1, 8, 8), + (2, 4, 17), + (2, 2, 2), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 8, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, + (2, 1, 8, 8, 8), + (2, 8, 17), + (2, 2, 2, 2), + ], +] + + +class TestVQVAETransformerInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + self.assertEqual(prediction.shape, logits_shape) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) + self.assertEqual(prediction.shape, cropped_logits_shape) + + def test_sample(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + def test_sample_shorter_sequence(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=2, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_resampling( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + resample_latent_likelihoods=True, + resample_interpolation_mode="nearest", + ) + self.assertEqual(likelihood.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() From 41fb3ff8af39529b0641c9b1d3341987cafac62b Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:21:27 +0100 Subject: [PATCH 12/38] [Attention block] relative positional embedding (#7346) Fixes #7356 ### Description Add relative positinoal embedding in attention block as described in https://arxiv.org/pdf/2112.01526.pdf Largely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py Can be useful for #6357 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vgrau98 Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 6 + monai/networks/blocks/attention_utils.py | 128 +++++++++++++++++++++ monai/networks/blocks/rel_pos_embedding.py | 56 +++++++++ monai/networks/blocks/selfattention.py | 33 +++++- monai/networks/layers/factories.py | 13 ++- monai/networks/layers/utils.py | 15 ++- tests/test_selfattention.py | 21 +++- 7 files changed, 262 insertions(+), 10 deletions(-) create mode 100644 monai/networks/blocks/attention_utils.py create mode 100644 monai/networks/blocks/rel_pos_embedding.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index f9375f1e97..556bf12d50 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -248,6 +248,12 @@ Blocks .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock :members: +`Attention utilities` +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.attention_utils +.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos +.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos + N-Dim Fourier Transform ~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: monai.networks.blocks.fft_utils_t diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py new file mode 100644 index 0000000000..8c9002a16e --- /dev/null +++ b/monai/networks/blocks/attention_utils.py @@ -0,0 +1,128 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple +) -> torch.Tensor: + r""" + Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Only 2D and 3D are supported. + + Encoding the relative position of tokens in the attention matrix: tokens spaced a distance + `d` apart will have the same embedding value (unlike absolute positional embedding). + + .. math:: + Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale + + where + + .. math:: + E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)} + + with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`, + respectively spatial positions of element :math:`i` and :math:`j` + + When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow: + + .. math:: + R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)} + + with :math:`n = 1...dim` + + Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to + :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding. + + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). + rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. + q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). + k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). + + Returns: + attn (Tensor): attention logits with added relative positional embeddings. + """ + rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) + rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) + + batch, _, dim = q.shape + + if len(rel_pos_lst) == 2: + q_h, q_w = q_size[:2] + k_h, k_w = k_size[:2] + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + elif len(rel_pos_lst) == 3: + q_h, q_w, q_d = q_size[:3] + k_h, k_w, k_d = k_size[:3] + + rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) + + r_q = q.reshape(batch, q_h, q_w, q_d, dim) + rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) + rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) + rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + + attn = ( + attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) + + rel_h[:, :, :, :, None, None] + + rel_w[:, :, :, None, :, None] + + rel_d[:, :, :, None, None, :] + ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) + + return attn diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py new file mode 100644 index 0000000000..e53e5841b0 --- /dev/null +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -0,0 +1,56 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterable, Tuple + +import torch +from torch import nn + +from monai.networks.blocks.attention_utils import add_decomposed_rel_pos +from monai.utils.misc import ensure_tuple_size + + +class DecomposedRelativePosEmbedding(nn.Module): + def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: + """ + Args: + s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) + c_dim (int): channel dimension + num_heads(int): number of attention heads + """ + super().__init__() + + # validate inputs + if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: + raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") + + self.s_input_dims = s_input_dims + self.c_dim = c_dim + self.num_heads = num_heads + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] + ) + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + """""" + batch = x.shape[0] + h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) + + att_mat = add_decomposed_rel_pos( + att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), + q.contiguous().view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), + ) + + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + return att_mat diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7c81c1704f..3bef24b4e8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,9 +11,12 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -23,6 +26,7 @@ class SABlock(nn.Module): """ A self-attention block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in """ def __init__( @@ -32,6 +36,8 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, ) -> None: """ Args: @@ -39,6 +45,10 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. """ @@ -62,11 +72,30 @@ def __init__( self.scale = self.head_dim**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C - def forward(self, x): + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 4fc2c16f73..29b72a4f37 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -70,7 +70,7 @@ def use_factory(fact_args): from monai.networks.utils import has_nvfuser_instance_norm from monai.utils import ComponentStore, look_up_option, optional_import -__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] +__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"] class LayerFactory(ComponentStore): @@ -201,6 +201,10 @@ def split_args(args): Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") +RelPosEmbedding = LayerFactory( + name="Relative positional embedding layers", + description="Factory for creating relative positional embedding factory", +) @Dropout.factory_function("dropout") @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] + + +@RelPosEmbedding.factory_function("decomposed") +def decomposed_rel_pos_embedding() -> type[nn.Module]: + from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding + + return DecomposedRelativePosEmbedding diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index ace1af27b6..8676f74638 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -11,9 +11,11 @@ from __future__ import annotations +from typing import Optional + import torch.nn -from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args +from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args from monai.utils import has_option __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] @@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): pool_name, pool_args = split_args(name) pool_type = Pool[pool_name, spatial_dims] return pool_type(**pool_args) + + +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): + embedding_name, embedding_args = split_args(name) + embedding_type = RelPosEmbedding[embedding_name] + # create a dictionary with the default values which can be overridden by embedding_args + kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} + # filter out unused argument names + kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} + + return embedding_type(**kw_args) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 6062b5352f..0d0553ed2c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -20,6 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock +from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -28,12 +29,20 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ - {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): From 39a6f579538532c84ec704bb749a8191f1f17166 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Wed, 24 Jan 2024 23:58:56 +0100 Subject: [PATCH 13/38] causal self attention Signed-off-by: vgrau98 --- monai/networks/blocks/selfattention.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3bef24b4e8..306ac534db 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -38,6 +38,8 @@ def __init__( save_attn: bool = False, rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, + causal: bool = False, + sequence_length: int | None = None, ) -> None: """ Args: @@ -49,6 +51,8 @@ def __init__( For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. + causal (bool): wether to use causal attention. If true `sequence_length` has to be set + sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. """ @@ -61,6 +65,9 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) @@ -70,6 +77,8 @@ def __init__( self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads self.scale = self.head_dim**-0.5 + self.causal = causal + self.sequence_length = sequence_length self.save_attn = save_attn self.att_mat = torch.Tensor() self.rel_positional_embedding = ( @@ -79,6 +88,14 @@ def __init__( ) self.input_size = input_size + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + def forward(self, x: torch.Tensor): """ Args: @@ -87,12 +104,15 @@ def forward(self, x: torch.Tensor): Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C """ + _, t, _ = x.size() output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + # apply causal mask if set + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat att_mat = att_mat.softmax(dim=-1) From 5b3c4f3c1bd1747b7b6a3551b0ac20affcd1b6fe Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Thu, 25 Jan 2024 00:06:06 +0100 Subject: [PATCH 14/38] causal selfattention tests Signed-off-by: vgrau98 --- tests/test_selfattention.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 0d0553ed2c..277ba7faf9 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -31,18 +31,21 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for causal in [False, True]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "causal": causal, + "sequence_length": 512, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): From 440a89353e7ebffa01463902be0e9b8c4d690dcd Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 27 Apr 2024 13:10:23 +0200 Subject: [PATCH 15/38] integrate flash attention usage Signed-off-by: vgrau98 --- monai/networks/blocks/selfattention.py | 56 ++++++++++++++++++++------ tests/test_selfattention.py | 29 +++++++++++++ 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 306ac534db..0a848e9ec5 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -12,6 +12,7 @@ from __future__ import annotations from typing import Optional, Tuple +import warnings import torch import torch.nn as nn @@ -19,6 +20,7 @@ from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import +xops, has_xformers = optional_import("xformers.ops") Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -40,6 +42,7 @@ def __init__( input_size: Optional[Tuple] = None, causal: bool = False, sequence_length: int | None = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -68,11 +71,23 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and rel_pos_embedding is not None: + self.use_flash_attention = False + warnings.warn( + "flash attention set to `False`: flash attention can't be used with relative position embedding. Set `rel_pos_embedding` to `None` to use flash attention" + ) + else: + self.use_flash_attention = use_flash_attention + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.dropout_rate = dropout_rate self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads @@ -105,24 +120,39 @@ def forward(self, x: torch.Tensor): torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C """ _, t, _ = x.size() - output = self.input_rearrange(self.qkv(x)) + output = self.input_rearrange(self.qkv(x)) # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h q, k, v = output[0], output[1], output[2] - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat - # apply causal mask if set - att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat + if self.use_flash_attention: + x = xops.memory_efficient_attention( + query=q.contiguous(), + key=k.contiguous(), + value=v.contiguous(), + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = ( + self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + ) + # apply causal mask if set + att_mat = ( + att_mat.masked_fill(self.causal_mask[:, :, :t, :t] == 0, float("-inf")) if self.causal else att_mat + ) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 277ba7faf9..33622d63c8 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -24,6 +24,7 @@ from monai.utils import optional_import einops, has_einops = optional_import("einops") +xops, has_xformers = optional_import("xformers.ops") TEST_CASE_SABLOCK = [] for dropout_rate in np.linspace(0, 1, 4): @@ -57,6 +58,34 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + @skipUnless(has_xformers, "Requires xformers") + def test_flash_attention(self): + hidden_size = 360 + num_heads = 4 + dropout_rate = 0 + input_shape = (2, 512, hidden_size) + expected_shape = (2, 512, hidden_size) + flash_attention_block = SABlock(hidden_size, num_heads, dropout_rate, use_flash_attention=True) + # flash attention set to false because of conflict using relative position embedding at the same time + no_flash_attention_block = SABlock( + hidden_size, + num_heads, + dropout_rate, + use_flash_attention=True, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + sequence_length=512, + input_size=([16, 32]), + ) + + self.assertFalse(no_flash_attention_block.use_flash_attention) + + with eval_mode(flash_attention_block): + result = flash_attention_block(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + with eval_mode(no_flash_attention_block): + result = no_flash_attention_block(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) From 7531b5c531011090ca324e172b37432dda9afc76 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 5 Dec 2023 13:20:01 +0800 Subject: [PATCH 16/38] update the Python version requirements for transformers (#7275) Part of #7250. ### Description Fix the Python version for transformers smaller than 3.10. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu --- docs/requirements.txt | 2 +- requirements-dev.txt | 2 +- tests/test_transchex.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index a9bbc384f8..e5bedf8552 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 tensorboardX diff --git a/requirements-dev.txt b/requirements-dev.txt index 6332d5b0a5..cacbefe234 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 9ad847cdaa..8fb1f56715 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.transchex import Transchex -from tests.utils import skip_if_quick +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick TEST_CASE_TRANSCHEX = [] for drop_out in [0.4]: @@ -46,6 +46,7 @@ @skip_if_quick +@SkipIfAtLeastPyTorchVersion((1, 10)) class TestTranschex(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSCHEX) def test_shape(self, input_param, expected_shape): From be576a3102f262eb505e0f54882091b0e4e51b64 Mon Sep 17 00:00:00 2001 From: Kaibo Tang Date: Tue, 5 Dec 2023 03:46:24 -0500 Subject: [PATCH 17/38] 7263 add diffusion loss (#7272) Fixes #7263. ### Description Add diffusion loss. I also made a [demo notebook](https://github.com/kvttt/deep-atlas/blob/main/diffusion_loss_scale_test.ipynb) to provide some explanations and analyses of diffusion loss. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: kaibo Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/losses.rst | 5 ++ monai/losses/__init__.py | 2 +- monai/losses/deform.py | 82 +++++++++++++++++++++++++ tests/test_diffusion_loss.py | 116 +++++++++++++++++++++++++++++++++++ 4 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 tests/test_diffusion_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 568c7dfc77..e929e9d605 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -96,6 +96,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..129abeedd2 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of pred using central finite difference. + For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. + g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g**2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb --- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when the all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() From 603bd53b3775ed8372a4e19c87c3e0f924de8265 Mon Sep 17 00:00:00 2001 From: Yufan He <59374597+heyufan1995@users.noreply.github.com> Date: Thu, 7 Dec 2023 21:36:21 -0500 Subject: [PATCH 18/38] Fix swinunetrv2 2D bug (#7302) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: heyufan1995 --- monai/networks/nets/swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 10c4ce3d8e..6f96dfd291 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1024,7 +1024,7 @@ def __init__( self.layers4.append(layer) if self.use_v2: layerc = UnetrBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=embed_dim * 2**i_layer, out_channels=embed_dim * 2**i_layer, kernel_size=3, From a5d5f7c517d58d12f0bc90700f5560e6e2372765 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:07:04 +0800 Subject: [PATCH 19/38] Fix `RuntimeError` in `DataAnalyzer` (#7310) Fixes #7309 ### Description `DataAnalyzer` only catch error when data is on GPU, add catching error when data is on CPU. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/apps/auto3dseg/data_analyzer.py | 26 ++++++++++++++++---------- monai/auto3dseg/analyzer.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 9280fb5be5..15e56abfea 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -28,7 +28,7 @@ from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd -from monai.utils import StrEnum, min_version, optional_import +from monai.utils import ImageMetaKey, StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys @@ -343,19 +343,25 @@ def _get_all_case_stats( d = summarizer(batch_data) except BaseException as err: if "image_meta_dict" in batch_data.keys(): - filename = batch_data["image_meta_dict"]["filename_or_obj"] + filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ] else: - filename = batch_data[self.image_key].meta["filename_or_obj"] + filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ] logger.info(f"Unable to process data {filename} on {device}. {err}") if self.device.type == "cuda": logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.") - batch_data[self.image_key] = batch_data[self.image_key].to("cpu") - if self.label_key is not None: - label = batch_data[self.label_key] - if not _label_argmax: - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to("cpu") - d = summarizer(batch_data) + try: + batch_data[self.image_key] = batch_data[self.image_key].to("cpu") + if self.label_key is not None: + label = batch_data[self.label_key] + if not _label_argmax: + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to("cpu") + d = summarizer(batch_data) + except BaseException as err: + logger.info(f"Unable to process data {filename} on {device}. {err}") + continue + else: + continue stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 654999d439..d5cfb21dab 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe torch.set_grad_enabled(False) ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key] # (H,W,D) + ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") From e7e07a9f19bee609b27e3cf5a8257e0ac8f95ac1 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:14:19 +0800 Subject: [PATCH 20/38] Support specified filenames in `Saveimage` (#7318) Fixes #7317 ### Description Add support specified filename for users to save like nibabel.save. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/io/array.py | 17 ++++++++++++++--- tests/test_save_image.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd7e4ef090..7222a26fc3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -414,6 +414,9 @@ def __init__( self.fname_formatter = output_name_formatter self.output_ext = output_ext.lower() or output_format.lower() + self.output_ext = ( + f".{self.output_ext}" if self.output_ext and not self.output_ext.startswith(".") else self.output_ext + ) if isinstance(writer, str): writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in if not has_built_in: @@ -458,15 +461,23 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ self.write_kwargs.update(write_kwargs) return self - def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None): + def __call__( + self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None + ): """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of metadata corresponding to the data. + filename: str or file-like object which to save img. + If specified, will ignore `self.output_name_formatter` and `self.folder_layout`. """ meta_data = img.meta if isinstance(img, MetaTensor) else meta_data - kw = self.fname_formatter(meta_data, self) - filename = self.folder_layout.filename(**kw) + if filename is not None: + filename = f"{filename}{self.output_ext}" + else: + kw = self.fname_formatter(meta_data, self) + filename = self.folder_layout.filename(**kw) + if meta_data: meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ())) if len(meta_spatial_shape) >= len(img.shape): diff --git a/tests/test_save_image.py b/tests/test_save_image.py index ba94ab5087..d88db201ce 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -37,6 +37,8 @@ False, ] +TEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), ".dcm", False] + @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImage(unittest.TestCase): @@ -58,6 +60,20 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample): filepath = "testfile0" if meta_data is not None else "0" self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext))) + @parameterized.expand([TEST_CASE_5]) + def test_saved_content_with_filename(self, test_data, output_ext, resample): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + separate_folder=False, # test saving into the same folder + ) + filename = str(os.path.join(tempdir, "test")) + trans(test_data, filename=filename) + + self.assertTrue(os.path.exists(filename + output_ext)) + if __name__ == "__main__": unittest.main() From 90c5d957c3b71a537964859a93e6ba5d645c76ce Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 15 Dec 2023 11:32:18 +0800 Subject: [PATCH 21/38] Fix typo (#7321) Fix typo. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/metrics/hausdorff_distance.py | 2 +- monai/metrics/surface_dice.py | 2 +- monai/metrics/surface_distance.py | 2 +- monai/metrics/utils.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index d9bbf17db3..d727eb0567 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -190,7 +190,7 @@ def compute_hausdorff_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=not directed, + symmetric=not directed, class_index=c, ) percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 635eb1bc24..b20b47a1a5 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -253,7 +253,7 @@ def compute_surface_dice( distance_metric=distance_metric, spacing=spacing_list[b], use_subvoxels=use_subvoxels, - symetric=True, + symmetric=True, class_index=c, ) boundary_correct: int | torch.Tensor | float diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7ce632c588..3cb336d6a0 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -177,7 +177,7 @@ def compute_average_surface_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=symmetric, + symmetric=symmetric, class_index=c, ) surface_distance = torch.cat(distances) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 62e6520b96..d4b8f6e9b6 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -295,7 +295,7 @@ def get_edge_surface_distance( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, use_subvoxels: bool = False, - symetric: bool = False, + symmetric: bool = False, class_index: int = -1, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], @@ -314,7 +314,7 @@ def get_edge_surface_distance( See :py:func:`monai.metrics.utils.get_surface_distance`. use_subvoxels: whether to use subvoxel resolution (using the spacing). This will return the areas of the edges. - symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. + symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. Returns: @@ -338,7 +338,7 @@ def get_edge_surface_distance( " this may result in nan/inf distance." ) distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] - if symetric: + if symmetric: distances = ( get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), From ad37c17380ef3fe564a9bafe159e0591471cd20a Mon Sep 17 00:00:00 2001 From: binliunls <107988372+binliunls@users.noreply.github.com> Date: Fri, 15 Dec 2023 22:00:24 +0800 Subject: [PATCH 22/38] fix optimizer pararmeter issue (#7322) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: binliu --- monai/handlers/mlflow_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index a2bd345dc6..df209c1c8b 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -401,7 +401,7 @@ def _default_iteration_log(self, engine: Engine) -> None: cur_optimizer = engine.optimizer for param_name in self.optimizer_param_names: params = { - f"{param_name} group_{i}": float(param_group[param_name]) + f"{param_name}_group_{i}": float(param_group[param_name]) for i, param_group in enumerate(cur_optimizer.param_groups) } self._log_metrics(params, step=engine.state.iteration) From 3c1dea0834dcdc36b9a2800c785a1d66c6a7d711 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 18 Dec 2023 12:00:43 +0800 Subject: [PATCH 23/38] Fix `lazy` ignored in `SpatialPadd` (#7316) Fixes #7314 #7315. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ben Murray --- monai/transforms/croppad/dictionary.py | 9 +++------ tests/padders.py | 3 +++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 56d214c51d..be9441dc4a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -221,9 +221,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class BorderPadd(Padd): @@ -274,9 +273,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class DivisiblePadd(Padd): @@ -324,9 +322,8 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - LazyTransform.__init__(self, lazy) padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class Cropd(MapTransform, InvertibleTransform, LazyTransform): diff --git a/tests/padders.py b/tests/padders.py index 02d7b40af6..ae1153bdfd 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -136,6 +136,9 @@ def pad_test_pending_ops(self, input_param, input_shape): # TODO: mode="bilinear" may report error overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} result = apply_pending(pending_result, overrides=overrides)[0] + # lazy in constructor + pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param) + self.assertTrue(pad_fn_lazy.lazy) # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): From c1a0f0662a311093c2d4883355ea4d452b0e459b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 28 Dec 2023 22:22:52 +0800 Subject: [PATCH 24/38] Update openslide-python version (#7344) --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index cacbefe234..2639c0a3e7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,7 +27,7 @@ ninja torchvision psutil cucim>=23.2.0; platform_system == "Linux" -openslide-python==1.1.2 +openslide-python imagecodecs; platform_system == "Linux" or platform_system == "Darwin" tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas From 58e7db5d443cfccb4de30dd388cc9f10411377d4 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 29 Dec 2023 12:33:46 +0800 Subject: [PATCH 25/38] Upgrade the version of `transformers` (#7343) Fixes #7338 ### Description transformers' version is pinned to v4.22 since https://github.com/Project-MONAI/MONAI/issues/5157. Updated the version refer to https://github.com/huggingface/transformers/issues/21678. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/transchex.py | 49 +++++++++----------------------- requirements-dev.txt | 2 +- tests/test_transchex.py | 3 +- 3 files changed, 15 insertions(+), 39 deletions(-) diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index ff27903cef..6bfff3c956 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -12,20 +12,17 @@ from __future__ import annotations import math -import os -import shutil -import tarfile -import tempfile from collections.abc import Sequence import torch from torch import nn +from monai.config.type_definitions import PathLike from monai.utils import optional_import transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0] -cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +cached_file = optional_import("transformers.utils", name="cached_file")[0] BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] @@ -63,44 +60,16 @@ def from_pretrained( state_dict=None, cache_dir=None, from_tf=False, + path_or_repo_id="bert-base-uncased", + filename="pytorch_model.bin", *inputs, **kwargs, ): - archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - tempdir = tempfile.mkdtemp() - with tarfile.open(resolved_archive_file, "r:gz") as archive: - - def is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(archive, tempdir) - serialization_dir = tempdir + weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir) model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, "pytorch_model.bin") state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) - if tempdir: - shutil.rmtree(tempdir) if from_tf: - weights_path = os.path.join(serialization_dir, "model.ckpt") return load_tf_weights_in_bert(model, weights_path) old_keys = [] new_keys = [] @@ -304,6 +273,8 @@ def __init__( chunk_size_feed_forward: int = 0, is_decoder: bool = False, add_cross_attention: bool = False, + path_or_repo_id: str | PathLike = "bert-base-uncased", + filename: str = "pytorch_model.bin", ) -> None: """ Args: @@ -315,6 +286,10 @@ def __init__( num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. drop_out: fraction of the input units to drop. + path_or_repo_id: This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filename: The name of the file to locate in `path_or_repo`. The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. @@ -369,6 +344,8 @@ def __init__( num_vision_layers=num_vision_layers, num_mixed_layers=num_mixed_layers, bert_config=bert_config, + path_or_repo_id=path_or_repo_id, + filename=filename, ) self.patch_size = patch_size diff --git a/requirements-dev.txt b/requirements-dev.txt index 2639c0a3e7..4685cd1572 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers>=4.36.0 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 8fb1f56715..9ad847cdaa 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.transchex import Transchex -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick +from tests.utils import skip_if_quick TEST_CASE_TRANSCHEX = [] for drop_out in [0.4]: @@ -46,7 +46,6 @@ @skip_if_quick -@SkipIfAtLeastPyTorchVersion((1, 10)) class TestTranschex(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSCHEX) def test_shape(self, input_param, expected_shape): From 3b13ddd6248789632dcd4ce9d5af6285f0fce668 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 30 Dec 2023 17:47:31 +0100 Subject: [PATCH 26/38] transformer block local window attention Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 73 ++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index ddf959dad2..baad0780a7 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -11,7 +11,11 @@ from __future__ import annotations +from typing import Tuple + +import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.blocks.mlp import MLPBlock from monai.networks.blocks.selfattention import SABlock @@ -31,6 +35,7 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + window_size: int = 0, ) -> None: """ Args: @@ -40,6 +45,10 @@ def __init__( dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643. + If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0). + If local attention is used, the input tensor should have the following shape during the forward pass: [B, H, W, C]. + See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py. """ @@ -55,8 +64,70 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size) self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) self.norm2 = nn.LayerNorm(hidden_size) + self.window_size = window_size def forward(self, x): - x = x + self.attn(self.norm1(x)) + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + h, w = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (h, w)) + + x = shortcut + x x = x + self.mlp(self.norm2(x)) return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. Support only 2D. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + batch, h, w, c = x.shape + + pad_h = (window_size - h % window_size) % window_size + pad_w = (window_size - w % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + hp, wp = h + pad_h, w + pad_w + + x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows, (hp, wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (hp, wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + hp, wp = pad_hw + h, w = hw + batch = windows.shape[0] // (hp * wp // window_size // window_size) + x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1) + + if hp > h or wp > w: + x = x[:, :h, :w, :].contiguous() + return x From 867ec07689a0aeffe8592af92e8f4334681d7aac Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Tue, 2 Jan 2024 23:15:46 +0100 Subject: [PATCH 27/38] fix: window partition input shapes Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index baad0780a7..d414248680 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -19,6 +19,10 @@ from monai.networks.blocks.mlp import MLPBlock from monai.networks.blocks.selfattention import SABlock +from monai.utils import optional_import + + +rearrange, _ = optional_import("einops", name="rearrange") class TransformerBlock(nn.Module): @@ -36,6 +40,7 @@ def __init__( qkv_bias: bool = False, save_attn: bool = False, window_size: int = 0, + input_size: Tuple = (), ) -> None: """ Args: @@ -47,8 +52,8 @@ def __init__( save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643. If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0). - If local attention is used, the input tensor should have the following shape during the forward pass: [B, H, W, C]. See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py. + input_size (Tuple): spatial input dimensions (h, w, and d). Has to be set if local window attention is used. """ @@ -65,19 +70,28 @@ def __init__( self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) self.norm2 = nn.LayerNorm(hidden_size) self.window_size = window_size + self.input_size = input_size def forward(self, x): + """ + Args: + x (Tensor): [b x (s_dim_1 * … * s_dim_n) x dim] + """ shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: - h, w = x.shape[1], x.shape[2] + h, w = self.input_size + x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w) x, pad_hw = window_partition(x, self.window_size) + x = rearrange(x, "b h w d -> b (h w) d", h=self.window_size, w=self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: + x = rearrange(x, "b (h w) d -> b h w d", h=self.window_size, w=self.window_size) x = window_unpartition(x, self.window_size, pad_hw, (h, w)) + x = rearrange(x, "b h w d -> b (h w) d", h=h, w=w) x = shortcut + x x = x + self.mlp(self.norm2(x)) From 980a8320d3a379e68d948ba6db05ffce4cd496f1 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Tue, 2 Jan 2024 23:32:31 +0100 Subject: [PATCH 28/38] fix: error handling Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index d414248680..4e3c2a0508 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -65,6 +65,11 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") + if window_size > 0 and len(input_size) not in [2, 3]: + raise ValueError( + "If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)" + ) + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) @@ -72,7 +77,7 @@ def __init__( self.window_size = window_size self.input_size = input_size - def forward(self, x): + def forward(self, x: torch.Tensor): """ Args: x (Tensor): [b x (s_dim_1 * … * s_dim_n) x dim] @@ -81,6 +86,11 @@ def forward(self, x): x = self.norm1(x) # Window partition if self.window_size > 0: + if x.shape[1] != int(torch.prod(torch.tensor(self.input_size))): + raise ValueError( + f"Input tensor spatial dimension {x.shape[1]} should be equal to {self.input_size} product" + ) + h, w = self.input_size x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w) x, pad_hw = window_partition(x, self.window_size) From c2f121b22a5d3f346b3f25c6198898ae3801ee2c Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Tue, 2 Jan 2024 23:32:59 +0100 Subject: [PATCH 29/38] local window attention tests Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 1 - tests/test_transformerblock.py | 42 +++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 4e3c2a0508..0a03e685a0 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -21,7 +21,6 @@ from monai.networks.blocks.selfattention import SABlock from monai.utils import optional_import - rearrange, _ = optional_import("einops", name="rearrange") diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 914336668d..193efdd16c 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import numpy as np import torch @@ -19,6 +20,9 @@ from monai.networks import eval_mode from monai.networks.blocks.transformerblock import TransformerBlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASE_TRANSFORMERBLOCK = [] for dropout_rate in np.linspace(0, 1, 4): @@ -37,6 +41,22 @@ ] TEST_CASE_TRANSFORMERBLOCK.append(test_case) +TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN = [] +for window_size in [0, 2, 3, 4]: + test_case = [ + { + "hidden_size": 360, + "num_heads": 4, + "mlp_dim": 1024, + "dropout_rate": 0, + "window_size": window_size, + "input_size": (4, 4), + }, + (2, 16, 360), + (2, 16, 360), + ] + TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN.append(test_case) + class TestTransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) @@ -53,6 +73,20 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + with self.assertRaises(ValueError): + TransformerBlock(hidden_size=360, num_heads=4, mlp_dim=1024, dropout_rate=0, window_size=2) + + with self.assertRaises(ValueError): + TransformerBlock( + hidden_size=360, num_heads=4, mlp_dim=1024, dropout_rate=0, window_size=2, input_size=(1, 1, 1, 1) + ) + + with self.assertRaises(ValueError): + t_block = TransformerBlock( + hidden_size=360, num_heads=4, mlp_dim=1024, dropout_rate=0, window_size=2, input_size=(3, 3) + ) + t_block(torch.randn((2, 10, 360))) + def test_access_attn_matrix(self): # input format hidden_size = 128 @@ -77,6 +111,14 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.attn.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN) + @skipUnless(has_einops, "Requires einops") + def test_local_window(self, input_param, input_shape, expected_shape): + net = TransformerBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + if __name__ == "__main__": unittest.main() From 46ee2b02adc1a7dd124699762e460f079d8f8cf5 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Wed, 3 Jan 2024 00:29:12 +0100 Subject: [PATCH 30/38] feat: 3d local window attention Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 82 +++++++++++++++++++++-- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 0a03e685a0..06e38bc8b9 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -90,17 +90,32 @@ def forward(self, x: torch.Tensor): f"Input tensor spatial dimension {x.shape[1]} should be equal to {self.input_size} product" ) - h, w = self.input_size - x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w) - x, pad_hw = window_partition(x, self.window_size) - x = rearrange(x, "b h w d -> b (h w) d", h=self.window_size, w=self.window_size) + if len(self.input_size) == 2: + x = rearrange(x, "b (h w) c -> b h w c", h=self.input_size[0], w=self.input_size[1]) + x, pad_hw = window_partition(x, self.window_size) + x = rearrange(x, "b h w c -> b (h w) c", h=self.window_size, w=self.window_size) + elif len(self.input_size) == 3: + x = rearrange( + x, "b (h w d) c -> b h w d c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2] + ) + x, pad_hwd = window_partition_3d(x, self.window_size) + x = rearrange(x, "b h w d c -> b (h w d) c", h=self.window_size, w=self.window_size, d=self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: - x = rearrange(x, "b (h w) d -> b h w d", h=self.window_size, w=self.window_size) - x = window_unpartition(x, self.window_size, pad_hw, (h, w)) - x = rearrange(x, "b h w d -> b (h w) d", h=h, w=w) + if len(self.input_size) == 2: + x = rearrange(x, "b (h w) c -> b h w c", h=self.window_size, w=self.window_size) + x = window_unpartition(x, self.window_size, pad_hw, (self.input_size[0], self.input_size[1])) + x = rearrange(x, "b h w c -> b (h w) c", h=self.input_size[0], w=self.input_size[1]) + elif len(self.input_size) == 3: + x = rearrange(x, "b (h w d) c -> b h w d c", h=self.window_size, w=self.window_size, d=self.window_size) + x = window_unpartition_3d( + x, self.window_size, pad_hwd, (self.input_size[0], self.input_size[1], self.input_size[2]) + ) + x = rearrange( + x, "b h w d c -> b (h w d) c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2] + ) x = shortcut + x x = x + self.mlp(self.norm2(x)) @@ -131,6 +146,32 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T return windows, (hp, wp) +def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + """ + Partition into non-overlapping windows with padding if needed. 3d implementation. + Args: + x (tensor): input tokens with [B, H, W, D, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C]. + (Hp, Wp, Dp): padded height, width and depth before partition + """ + batch, h, w, d, c = x.shape + + pad_h = (window_size - h % window_size) % window_size + pad_w = (window_size - w % window_size) % window_size + pad_d = (window_size - d % window_size) % window_size + if pad_h > 0 or pad_w > 0 or pad_d > 0: + x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h)) + hp, wp, dp = h + pad_h, w + pad_w, d + pad_d + + x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c) + return windows, (hp, wp, dp) + ... + + def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: @@ -154,3 +195,30 @@ def window_unpartition( if hp > h or wp > w: x = x[:, :h, :w, :].contiguous() return x + + +def window_unpartition_3d( + windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. 3d implementation. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C]. + window_size (int): window size. + pad_hwd (Tuple): padded height, width and depth (hp, wp, dp). + hwd (Tuple): original height, width and depth (H, W, D) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, D, C]. + """ + hp, wp, dp = pad_hwd + h, w, d = hwd + batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size) + x = windows.view( + batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1) + + if hp > h or wp > w or dp > d: + x = x[:, :h, :w, :d, :].contiguous() + return x From dc7efd4ae52b29a6202055a0caf2235444ce3f63 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Wed, 3 Jan 2024 00:33:10 +0100 Subject: [PATCH 31/38] 3d local attention window tests Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 2 +- tests/test_transformerblock.py | 24 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 06e38bc8b9..d0d305ad42 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -50,7 +50,7 @@ def __init__( qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643. - If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0). + If 0, global attention used. See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py. input_size (Tuple): spatial input dimensions (h, w, and d). Has to be set if local window attention is used. diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 193efdd16c..5c49afe96b 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -57,6 +57,22 @@ ] TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN.append(test_case) +TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D = [] +for window_size in [0, 2, 3, 4]: + test_case = [ + { + "hidden_size": 360, + "num_heads": 4, + "mlp_dim": 1024, + "dropout_rate": 0, + "window_size": window_size, + "input_size": (3, 3, 3), + }, + (2, 27, 360), + (2, 27, 360), + ] + TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D.append(test_case) + class TestTransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) @@ -119,6 +135,14 @@ def test_local_window(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D) + @skipUnless(has_einops, "Requires einops") + def test_local_window_3d(self, input_param, input_shape, expected_shape): + net = TransformerBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + if __name__ == "__main__": unittest.main() From 1ab0a740ee76e25f91f6769e8239163b5ef1e9ed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 03:34:50 +0000 Subject: [PATCH 32/38] Bump github/codeql-action from 2 to 3 (#7354) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 2 to 3.
Release notes

Sourced from github/codeql-action's releases.

CodeQL Bundle v2.15.5

Bundles CodeQL CLI v2.15.5

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.5:

CodeQL Bundle v2.15.4

Bundles CodeQL CLI v2.15.4

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.4:

CodeQL Bundle

Bundles CodeQL CLI v2.15.3

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.3:

... (truncated)

Changelog

Sourced from github/codeql-action's changelog.

Commits
  • e0c2b0a change version numbers inside processing function as well
  • 8e4a6c7 improve handling of changelog processing for backports
  • 511f073 Merge pull request #2033 from github/dependabot/npm_and_yarn/npm-0a98872b3d
  • ebf5a83 Merge pull request #2035 from github/mergeback/v3.22.11-to-main-b374143c
  • 7813bda Update checked-in dependencies
  • 2b2fb6b Update changelog and version after v3.22.11
  • b374143 Merge pull request #2034 from github/update-v3.22.11-64e61baea
  • 95591ba Merge branch 'main' into dependabot/npm_and_yarn/npm-0a98872b3d
  • e2b5cc7 Update changelog for v3.22.11
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=2&new-version=3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql-analysis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 3d32ae407a..18f1519b5a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -72,4 +72,4 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 From fe6d2e36babb646ad8cd7c0706854d452cd196a3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 06:42:41 +0000 Subject: [PATCH 33/38] Bump actions/upload-artifact from 3 to 4 (#7350) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4.
Release notes

Sourced from actions/upload-artifact's releases.

v4.0.0

What's Changed

The release of upload-artifact@v4 and download-artifact@v4 are major changes to the backend architecture of Artifacts. They have numerous performance and behavioral improvements.

For more information, see the @​actions/artifact documentation.

New Contributors

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v4.0.0

v3.1.3

What's Changed

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v3.1.3

v3.1.2

  • Update all @actions/* NPM packages to their latest versions- #374
  • Update all dev dependencies to their most recent versions - #375

v3.1.1

  • Update actions/core package to latest version to remove set-output deprecation warning #351

v3.1.0

What's Changed

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/upload-artifact&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/docker.yml | 2 +- .github/workflows/release.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f51e4fdf76..f80a4c2c96 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -36,7 +36,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7197215486..e9817e1c4c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -66,7 +66,7 @@ jobs: - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist path: dist/ @@ -108,7 +108,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py From f46c5097daacf77ed4a35a016086b9c58e08573e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 17:29:09 +0800 Subject: [PATCH 34/38] Bump actions/setup-python from 4 to 5 (#7351) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5.
Release notes

Sourced from actions/setup-python's releases.

v5.0.0

What's Changed

In scope of this release, we update node version runtime from node16 to node20 (actions/setup-python#772). Besides, we update dependencies to the latest versions.

Full Changelog: https://github.com/actions/setup-python/compare/v4.8.0...v5.0.0

v4.8.0

What's Changed

In scope of this release we added support for GraalPy (actions/setup-python#694). You can use this snippet to set up GraalPy:

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
  with:
    python-version: 'graalpy-22.3'
- run: python my_script.py

Besides, the release contains such changes as:

New Contributors

Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.8.0

v4.7.1

What's Changed

Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.7.1

v4.7.0

In scope of this release, the support for reading python version from pyproject.toml was added (actions/setup-python#669).

      - name: Setup Python
        uses: actions/setup-python@v4
</tr></table>

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/setup-python&package-manager=github_actions&previous-version=4&new-version=5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/cron-ngc-bundle.yml | 2 +- .github/workflows/docker.yml | 2 +- .github/workflows/pythonapp-min.yml | 6 +++--- .github/workflows/pythonapp.yml | 8 ++++---- .github/workflows/release.yml | 4 ++-- .github/workflows/setupapp.yml | 4 ++-- .github/workflows/weekly-preview.yml | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml index 0bba630d03..84666204a9 100644 --- a/.github/workflows/cron-ngc-bundle.yml +++ b/.github/workflows/cron-ngc-bundle.yml @@ -19,7 +19,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f80a4c2c96..c375e82e74 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -26,7 +26,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 558c270e33..7b7930bdf5 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -30,7 +30,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -76,7 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Prepare pip wheel @@ -121,7 +121,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index ad8b555dd4..29a79759e0 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -28,7 +28,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -69,7 +69,7 @@ jobs: disk-root: "D:" - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -128,7 +128,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -209,7 +209,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9817e1c4c..9334908bfc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install setuptools @@ -97,7 +97,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 0ff7162bee..82394a86dd 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -83,7 +83,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: cache weekly timestamp @@ -120,7 +120,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index c631982745..e94e1dac5a 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -14,7 +14,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - name: Install setuptools From 035e6c4fd41ed265e063b1eb7772a1effbf11750 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 23:42:25 +0800 Subject: [PATCH 35/38] Bump actions/download-artifact from 3 to 4 (#7352) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4.
Release notes

Sourced from actions/download-artifact's releases.

v4.0.0

What's Changed

The release of upload-artifact@v4 and download-artifact@v4 are major changes to the backend architecture of Artifacts. They have numerous performance and behavioral improvements.

For more information, see the @​actions/artifact documentation.

New Contributors

Full Changelog: https://github.com/actions/download-artifact/compare/v3...v4.0.0

v3.0.2

  • Bump @actions/artifact to v1.1.1 - actions/download-artifact#195
  • Fixed a bug in Node16 where if an HTTP download finished too quickly (<1ms, e.g. when it's mocked) we attempt to delete a temp file that has not been created yet actions/toolkit#1278

v3.0.1

Commits
  • f44cd7b Merge pull request #259 from actions/robherley/glob-downloads
  • 3181fe8 add some migration docs
  • aaaac7b licensed cache
  • 7c9182f update readme
  • b94e701 licensed cache
  • 0b55470 add test case for globbed downloads to same directory
  • 0b51c2e update prettier/eslint versions
  • c4c6db7 support globbing artifact list & merging download directory
  • 1bd0606 Merge pull request #252 from stchr/patch-1
  • eff4d42 fix default for run-id
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/download-artifact&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/docker.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c375e82e74..229ae675f5 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -56,7 +56,7 @@ jobs: with: ref: dev - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: docker_build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9334908bfc..a03d2cea6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -125,7 +125,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: Set tag From eec3308c6b253a26b32363fb886d2f92e201cd00 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:41:55 +0800 Subject: [PATCH 36/38] Bump peter-evans/slash-command-dispatch from 3.0.1 to 3.0.2 (#7353) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [peter-evans/slash-command-dispatch](https://github.com/peter-evans/slash-command-dispatch) from 3.0.1 to 3.0.2.
Release notes

Sourced from peter-evans/slash-command-dispatch's releases.

Slash Command Dispatch v3.0.2

What's Changed

New Contributors

Full Changelog: https://github.com/peter-evans/slash-command-dispatch/compare/v3.0.1...v3.0.2

Commits
  • f996d7b Fix the CollaboratorPermission GraphQL query (#301)
  • 05b97d6 build(deps-dev): bump @​types/node from 16.18.65 to 16.18.67 (#300)
  • 8e70073 build(deps-dev): bump eslint from 8.54.0 to 8.55.0 (#299)
  • bd00135 build(deps-dev): bump @​types/node from 16.18.62 to 16.18.65 (#298)
  • ee873b6 build(deps-dev): bump eslint from 8.53.0 to 8.54.0 (#296)
  • 44abc47 build(deps-dev): bump @​types/node from 16.18.61 to 16.18.62 (#295)
  • 19ad7b8 build(deps-dev): bump @​types/node from 16.18.60 to 16.18.61 (#294)
  • 29a9815 build(deps-dev): bump prettier from 3.0.3 to 3.1.0 (#293)
  • ade0309 build(deps-dev): bump eslint from 8.52.0 to 8.53.0 (#292)
  • fc8222e build(deps-dev): bump @​types/node from 16.18.59 to 16.18.60 (#291)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=peter-evans/slash-command-dispatch&package-manager=github_actions&previous-version=3.0.1&new-version=3.0.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/chatops.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index b4e201a0d9..59c7d070b4 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v3.0.1 + uses: peter-evans/slash-command-dispatch@v3.0.2 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} From 7d82d8a95a6c8f71327a511a34967db31632e8b6 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 6 Jan 2024 15:49:46 +0100 Subject: [PATCH 37/38] clean Signed-off-by: vgrau98 --- monai/networks/blocks/transformerblock.py | 94 +++++++++++++++-------- 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index d0d305ad42..ac17263b08 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -85,44 +85,47 @@ def forward(self, x: torch.Tensor): x = self.norm1(x) # Window partition if self.window_size > 0: - if x.shape[1] != int(torch.prod(torch.tensor(self.input_size))): - raise ValueError( - f"Input tensor spatial dimension {x.shape[1]} should be equal to {self.input_size} product" - ) - - if len(self.input_size) == 2: - x = rearrange(x, "b (h w) c -> b h w c", h=self.input_size[0], w=self.input_size[1]) - x, pad_hw = window_partition(x, self.window_size) - x = rearrange(x, "b h w c -> b (h w) c", h=self.window_size, w=self.window_size) - elif len(self.input_size) == 3: - x = rearrange( - x, "b (h w d) c -> b h w d c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2] - ) - x, pad_hwd = window_partition_3d(x, self.window_size) - x = rearrange(x, "b h w d c -> b (h w d) c", h=self.window_size, w=self.window_size, d=self.window_size) - + x, pad = window_partition(x, self.window_size, self.input_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: - if len(self.input_size) == 2: - x = rearrange(x, "b (h w) c -> b h w c", h=self.window_size, w=self.window_size) - x = window_unpartition(x, self.window_size, pad_hw, (self.input_size[0], self.input_size[1])) - x = rearrange(x, "b h w c -> b (h w) c", h=self.input_size[0], w=self.input_size[1]) - elif len(self.input_size) == 3: - x = rearrange(x, "b (h w d) c -> b h w d c", h=self.window_size, w=self.window_size, d=self.window_size) - x = window_unpartition_3d( - x, self.window_size, pad_hwd, (self.input_size[0], self.input_size[1], self.input_size[2]) - ) - x = rearrange( - x, "b h w d c -> b (h w d) c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2] - ) - + x = window_unpartition(x, self.window_size, pad, self.input_size) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x -def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: +def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]: + """ + Partition into non-overlapping windows with padding if needed. Support 2D and 3D. + Args: + x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size) + input_size (Tuple): input spatial dimension: (H, W) or (H, W, D) + window_size (int): window size + + Returns: + windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C]. + with n = 1...len(input_size) and window_size_i == window_size. + (S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size) + """ + if x.shape[1] != int(torch.prod(torch.tensor(input_size))): + raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product") + + if len(input_size) == 2: + x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1]) + x, pad_hw = window_partition_2d(x, window_size) + x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size) + return x, pad_hw + elif len(input_size) == 3: + x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2]) + x, pad_hwd = window_partition_3d(x, window_size) + x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size) + return x, pad_hwd + else: + raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ") + + +def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Support only 2D. Args: @@ -169,10 +172,37 @@ def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c) windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c) return windows, (hp, wp, dp) - ... -def window_unpartition( +def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C]. + with n = 1...len(spatial_dims) and window_size == window_size_i + window_size (int): window size. + pad (Tuple): padded spatial dims (H, W) or (H, W, D) + spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding. + + Returns: + x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C]. + """ + x: torch.Tensor + if len(spatial_dims) == 2: + x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size) + x = window_unpartition_2d(x, window_size, pad, spatial_dims) + x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1]) + return x + elif len(spatial_dims) == 3: + x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size) + x = window_unpartition_3d(x, window_size, pad, spatial_dims) + x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2]) + return x + else: + raise ValueError() + + +def window_unpartition_2d( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ From 0132e6bae38c19809cbe7402560e062e93fd3143 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 6 Jan 2024 18:18:40 +0100 Subject: [PATCH 38/38] refacto Signed-off-by: vgrau98 --- monai/networks/blocks/attention_utils.py | 163 +++++++++++++++++++ monai/networks/blocks/selfattention.py | 33 +++- monai/networks/blocks/transformerblock.py | 185 +--------------------- 3 files changed, 195 insertions(+), 186 deletions(-) diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py index 8c9002a16e..08e6982695 100644 --- a/monai/networks/blocks/attention_utils.py +++ b/monai/networks/blocks/attention_utils.py @@ -15,6 +15,10 @@ import torch.nn.functional as F from torch import nn +from monai.utils import optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ @@ -126,3 +130,162 @@ def add_decomposed_rel_pos( ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) return attn + + +def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]: + """ + Partition into non-overlapping windows with padding if needed. Support 2D and 3D. + Args: + x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size) + input_size (Tuple): input spatial dimension: (H, W) or (H, W, D) + window_size (int): window size + + Returns: + windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C]. + with n = 1...len(input_size) and window_size_i == window_size. + (S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size) + """ + if x.shape[1] != int(torch.prod(torch.tensor(input_size))): + raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product") + + if len(input_size) == 2: + x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1]) + x, pad_hw = window_partition_2d(x, window_size) + x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size) + return x, pad_hw + elif len(input_size) == 3: + x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2]) + x, pad_hwd = window_partition_3d(x, window_size) + x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size) + return x, pad_hwd + else: + raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ") + + +def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. Support only 2D. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + batch, h, w, c = x.shape + + pad_h = (window_size - h % window_size) % window_size + pad_w = (window_size - w % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + hp, wp = h + pad_h, w + pad_w + + x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows, (hp, wp) + + +def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + """ + Partition into non-overlapping windows with padding if needed. 3d implementation. + Args: + x (tensor): input tokens with [B, H, W, D, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C]. + (Hp, Wp, Dp): padded height, width and depth before partition + """ + batch, h, w, d, c = x.shape + + pad_h = (window_size - h % window_size) % window_size + pad_w = (window_size - w % window_size) % window_size + pad_d = (window_size - d % window_size) % window_size + if pad_h > 0 or pad_w > 0 or pad_d > 0: + x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h)) + hp, wp, dp = h + pad_h, w + pad_w, d + pad_d + + x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c) + return windows, (hp, wp, dp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C]. + with n = 1...len(spatial_dims) and window_size == window_size_i + window_size (int): window size. + pad (Tuple): padded spatial dims (H, W) or (H, W, D) + spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding. + + Returns: + x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C]. + """ + x: torch.Tensor + if len(spatial_dims) == 2: + x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size) + x = window_unpartition_2d(x, window_size, pad, spatial_dims) + x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1]) + return x + elif len(spatial_dims) == 3: + x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size) + x = window_unpartition_3d(x, window_size, pad, spatial_dims) + x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2]) + return x + else: + raise ValueError() + + +def window_unpartition_2d( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (hp, wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + hp, wp = pad_hw + h, w = hw + batch = windows.shape[0] // (hp * wp // window_size // window_size) + x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1) + + if hp > h or wp > w: + x = x[:, :h, :w, :].contiguous() + return x + + +def window_unpartition_3d( + windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. 3d implementation. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C]. + window_size (int): window size. + pad_hwd (Tuple): padded height, width and depth (hp, wp, dp). + hwd (Tuple): original height, width and depth (H, W, D) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, D, C]. + """ + hp, wp, dp = pad_hwd + h, w, d = hwd + batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size) + x = windows.view( + batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1) + + if hp > h or wp > w or dp > d: + x = x[:, :h, :w, :d, :].contiguous() + return x diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 0a848e9ec5..3b524e1ccc 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -18,6 +18,7 @@ import torch.nn as nn from monai.networks.layers.utils import get_rel_pos_embedding_layer +from monai.networks.blocks.attention_utils import window_partition, window_unpartition from monai.utils import optional_import xops, has_xformers = optional_import("xformers.ops") @@ -26,9 +27,14 @@ class SABlock(nn.Module): """ - A self-attention block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - One can setup relative positional embedding as described in + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + <<<<<<< HEAD + One can setup relative positional embedding as described in + ======= + and some additional features: + - local window attention + >>>>>>> f7aca872 (refacto) """ def __init__( @@ -43,6 +49,7 @@ def __init__( causal: bool = False, sequence_length: int | None = None, use_flash_attention: bool = False, + window_size: int = 0, ) -> None: """ Args: @@ -53,11 +60,13 @@ def __init__( rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative - positional parameter size. + positional parameter size. Has to be set if local window attention is used causal (bool): wether to use causal attention. If true `sequence_length` has to be set sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - + window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643. + If 0, global attention used. + See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py. """ super().__init__() @@ -81,6 +90,10 @@ def __init__( if use_flash_attention and not has_xformers: raise ValueError("use_flash_attention is True but xformers is not installed.") + if window_size > 0 and len(input_size) not in [2, 3]: + raise ValueError( + "If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)" + ) self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) @@ -101,6 +114,7 @@ def __init__( if rel_pos_embedding is not None else None ) + self.window_size = window_size self.input_size = input_size if causal and sequence_length is not None: @@ -119,6 +133,10 @@ def forward(self, x: torch.Tensor): Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C """ + + if self.window_size > 0: + x, pad = window_partition(x, self.window_size, self.input_size) + _, t, _ = x.size() output = self.input_rearrange(self.qkv(x)) # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h q, k, v = output[0], output[1], output[2] @@ -156,4 +174,9 @@ def forward(self, x: torch.Tensor): x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad, self.input_size) + return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index ac17263b08..b5b5adacf8 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -15,13 +15,9 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.blocks.mlp import MLPBlock from monai.networks.blocks.selfattention import SABlock -from monai.utils import optional_import - -rearrange, _ = optional_import("einops", name="rearrange") class TransformerBlock(nn.Module): @@ -64,191 +60,18 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") - if window_size > 0 and len(input_size) not in [2, 3]: - raise ValueError( - "If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)" - ) - self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) - self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) + self.attn = SABlock( + hidden_size, num_heads, dropout_rate, qkv_bias, save_attn, window_size=window_size, input_size=input_size + ) self.norm2 = nn.LayerNorm(hidden_size) - self.window_size = window_size - self.input_size = input_size def forward(self, x: torch.Tensor): """ Args: x (Tensor): [b x (s_dim_1 * … * s_dim_n) x dim] """ - shortcut = x - x = self.norm1(x) - # Window partition - if self.window_size > 0: - x, pad = window_partition(x, self.window_size, self.input_size) - x = self.attn(x) - # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad, self.input_size) - x = shortcut + x + x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x - - -def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]: - """ - Partition into non-overlapping windows with padding if needed. Support 2D and 3D. - Args: - x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size) - input_size (Tuple): input spatial dimension: (H, W) or (H, W, D) - window_size (int): window size - - Returns: - windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C]. - with n = 1...len(input_size) and window_size_i == window_size. - (S_DIM_1p, ...,S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size) - """ - if x.shape[1] != int(torch.prod(torch.tensor(input_size))): - raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product") - - if len(input_size) == 2: - x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1]) - x, pad_hw = window_partition_2d(x, window_size) - x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size) - return x, pad_hw - elif len(input_size) == 3: - x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2]) - x, pad_hwd = window_partition_3d(x, window_size) - x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size) - return x, pad_hwd - else: - raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only. ") - - -def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: - """ - Partition into non-overlapping windows with padding if needed. Support only 2D. - Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. - (Hp, Wp): padded height and width before partition - """ - batch, h, w, c = x.shape - - pad_h = (window_size - h % window_size) % window_size - pad_w = (window_size - w % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) - hp, wp = h + pad_h, w + pad_w - - x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) - return windows, (hp, wp) - - -def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]: - """ - Partition into non-overlapping windows with padding if needed. 3d implementation. - Args: - x (tensor): input tokens with [B, H, W, D, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C]. - (Hp, Wp, Dp): padded height, width and depth before partition - """ - batch, h, w, d, c = x.shape - - pad_h = (window_size - h % window_size) % window_size - pad_w = (window_size - w % window_size) % window_size - pad_d = (window_size - d % window_size) % window_size - if pad_h > 0 or pad_w > 0 or pad_d > 0: - x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h)) - hp, wp, dp = h + pad_h, w + pad_w, d + pad_d - - x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c) - return windows, (hp, wp, dp) - - -def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor: - """ - Window unpartition into original sequences and removing padding. - Args: - windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C]. - with n = 1...len(spatial_dims) and window_size == window_size_i - window_size (int): window size. - pad (Tuple): padded spatial dims (H, W) or (H, W, D) - spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding. - - Returns: - x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C]. - """ - x: torch.Tensor - if len(spatial_dims) == 2: - x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size) - x = window_unpartition_2d(x, window_size, pad, spatial_dims) - x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1]) - return x - elif len(spatial_dims) == 3: - x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size) - x = window_unpartition_3d(x, window_size, pad, spatial_dims) - x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2]) - return x - else: - raise ValueError() - - -def window_unpartition_2d( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] -) -> torch.Tensor: - """ - Window unpartition into original sequences and removing padding. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. - window_size (int): window size. - pad_hw (Tuple): padded height and width (hp, wp). - hw (Tuple): original height and width (H, W) before padding. - - Returns: - x: unpartitioned sequences with [B, H, W, C]. - """ - hp, wp = pad_hw - h, w = hw - batch = windows.shape[0] // (hp * wp // window_size // window_size) - x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1) - - if hp > h or wp > w: - x = x[:, :h, :w, :].contiguous() - return x - - -def window_unpartition_3d( - windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int] -) -> torch.Tensor: - """ - Window unpartition into original sequences and removing padding. 3d implementation. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C]. - window_size (int): window size. - pad_hwd (Tuple): padded height, width and depth (hp, wp, dp). - hwd (Tuple): original height, width and depth (H, W, D) before padding. - - Returns: - x: unpartitioned sequences with [B, H, W, D, C]. - """ - hp, wp, dp = pad_hwd - h, w, d = hwd - batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size) - x = windows.view( - batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1 - ) - x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1) - - if hp > h or wp > w or dp > d: - x = x[:, :h, :w, :d, :].contiguous() - return x