diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 556bf12d508..735cc8ab259 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -264,11 +264,6 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift -`SPADE` -~~~~~~~ -.. autoclass:: monai.networks.blocks.spade_norm.SPADE - :members: - Layers ------ @@ -419,13 +414,6 @@ Layers .. autoclass:: LLTM :members: -`Vector Quantizer` -~~~~~~~~~~~~~~~~~~ -.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer - :members: -.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer - :members: - `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -593,21 +581,6 @@ Nets .. autoclass:: VNet :members: -`DiffusionModelUnet` -~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: DiffusionModelUNet - :members: - -`SPADEDiffusionModelUNet` -~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: SPADEDiffusionModelUNet - :members: - -`ControlNet` -~~~~~~~~~~~~ -.. autoclass:: ControlNet - :members: - `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet @@ -628,26 +601,11 @@ Nets .. autoclass:: AutoEncoder :members: -`AutoEncoderKL` -~~~~~~~~~~~~~~~ -.. autoclass:: AutoencoderKL - :members: - -`SPADEAutoencoderKL` -~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: SPADEAutoencoderKL - :members: - `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder :members: -`DecoderOnlyTransformer` -~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: DecoderOnlyTransformer - :members: - `ViT` ~~~~~ .. autoclass:: ViT @@ -771,39 +729,6 @@ Nets .. autoclass:: voxelmorph -`VQ-VAE` -~~~~~~~~ -.. autoclass:: VQVAE - :members: - -`PatchGANDiscriminator` -~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: PatchDiscriminator - :members: - -.. autoclass:: MultiScalePatchDiscriminator - :members: - -Diffusion Schedulers --------------------- -.. autoclass:: monai.networks.schedulers.Scheduler - :members: - -`DDPM Scheduler` -~~~~~~~~~~~~~~~~ -.. autoclass:: monai.networks.schedulers.DDPMScheduler - :members: - -`DDIM Scheduler` -~~~~~~~~~~~~~~~~ -.. autoclass:: monai.networks.schedulers.DDIMScheduler - :members: - -`PNDM Scheduler` -~~~~~~~~~~~~~~~~ -.. autoclass:: monai.networks.schedulers.PNDMScheduler - :members: - Utilities --------- .. automodule:: monai.networks.utils diff --git a/docs/source/utils.rst b/docs/source/utils.rst index fef671e1f87..527247799fb 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -81,8 +81,3 @@ Component store --------------- .. autoclass:: monai.utils.component_store.ComponentStore :members: - -Ordering --------- -.. 
automodule:: monai.utils.ordering - :members: diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index afb6664bd93..e67cb3376fd 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -30,7 +30,6 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock -from .spade_norm import SPADE from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py deleted file mode 100644 index b1046f31543..00000000000 --- a/monai/networks/blocks/spade_norm.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from monai.networks.blocks import ADN, Convolution - - -class SPADE(nn.Module): - """ - Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a - semantic map. This block is used in SPADE-based image-to-image translation models, as described in - Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291). - - Args: - label_nc: number of semantic labels - norm_nc: number of output channels - kernel_size: kernel size - spatial_dims: number of spatial dimensions - hidden_channels: number of channels in the intermediate gamma and beta layers - norm: type of base normalisation used before applying the SPADE normalisation - norm_params: parameters for the base normalisation - """ - - def __init__( - self, - label_nc: int, - norm_nc: int, - kernel_size: int = 3, - spatial_dims: int = 2, - hidden_channels: int = 64, - norm: str | tuple = "INSTANCE", - norm_params: dict | None = None, - ) -> None: - super().__init__() - - if norm_params is None: - norm_params = {} - if len(norm_params) != 0: - norm = (norm, norm_params) - self.param_free_norm = ADN( - act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc - ) - self.mlp_shared = Convolution( - spatial_dims=spatial_dims, - in_channels=label_nc, - out_channels=hidden_channels, - kernel_size=kernel_size, - norm=None, - act="LEAKYRELU", - ) - self.mlp_gamma = Convolution( - spatial_dims=spatial_dims, - in_channels=hidden_channels, - out_channels=norm_nc, - kernel_size=kernel_size, - act=None, - ) - self.mlp_beta = Convolution( - spatial_dims=spatial_dims, - in_channels=hidden_channels, - out_channels=norm_nc, - kernel_size=kernel_size, - act=None, - ) - - def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: - """ - Args: - x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels. - segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels. 
- The map will be interpolated to the dimension of x internally. - """ - - # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) - - # Part 2. produce scaling and bias conditioned on semantic map - segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") - actv = self.mlp_shared(segmap) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - out: torch.Tensor = normalized * (1 + gamma) + beta - return out diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index bd3e3af3af2..d61ed57f7f6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -37,5 +37,4 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer -from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py deleted file mode 100644 index 9c354e10095..00000000000 --- a/monai/networks/layers/vector_quantizer.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Sequence, Tuple - -import torch -from torch import nn - -__all__ = ["VectorQuantizer", "EMAQuantizer"] - - -class EMAQuantizer(nn.Module): - """ - Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural - Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation - that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit - 58d9a2746493717a7c9252938da7efa6006f3739. - - This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due - to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 - on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. - - Args: - spatial_dims: number of spatial dimensions of the input. - num_embeddings: number of atomic elements in the codebook. - embedding_dim: number of channels of the input and atomic elements. - commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. - decay: EMA decay. Defaults to 0.99. - epsilon: epsilon value. Defaults to 1e-5. - embedding_init: initialization method for the codebook. Defaults to "normal". - ddp_sync: whether to synchronize the codebook across processes. Defaults to True. 
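The quantization step this class implements (its `quantize` method follows below) maps each encoder feature vector to its nearest codebook entry by squared Euclidean distance. A minimal stand-alone sketch of that lookup in plain PyTorch, with illustrative shapes rather than the deleted MONAI class, is:

    import torch

    num_embeddings, embedding_dim = 16, 4              # codebook size K, channel dim C
    codebook = torch.randn(num_embeddings, embedding_dim)

    x = torch.randn(2, embedding_dim, 8, 8)            # (B, C, H, W) encoder output
    flat = x.permute(0, 2, 3, 1).reshape(-1, embedding_dim)  # channel-last, then (B*H*W, C)

    # squared distance ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b
    dist = (flat ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * flat @ codebook.t()
    indices = dist.argmin(dim=1)                       # nearest codebook entry per position
    quantized = codebook[indices].reshape(2, 8, 8, embedding_dim).permute(0, 3, 1, 2)

The EMA variant below additionally keeps running averages of how often each entry is selected and of the features assigned to it, and uses those statistics rather than gradients to update the codebook.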
- """ - - def __init__( - self, - spatial_dims: int, - num_embeddings: int, - embedding_dim: int, - commitment_cost: float = 0.25, - decay: float = 0.99, - epsilon: float = 1e-5, - embedding_init: str = "normal", - ddp_sync: bool = True, - ): - super().__init__() - self.spatial_dims: int = spatial_dims - self.embedding_dim: int = embedding_dim - self.num_embeddings: int = num_embeddings - - assert self.spatial_dims in [2, 3], ValueError( - f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." - ) - - self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) - if embedding_init == "normal": - # Initialization is passed since the default one is normal inside the nn.Embedding - pass - elif embedding_init == "kaiming_uniform": - torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") - self.embedding.weight.requires_grad = False - - self.commitment_cost: float = commitment_cost - - self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) - self.register_buffer("ema_w", self.embedding.weight.data.clone()) - # declare types for mypy - self.ema_cluster_size: torch.Tensor - self.ema_w: torch.Tensor - self.decay: float = decay - self.epsilon: float = epsilon - - self.ddp_sync: bool = ddp_sync - - # Precalculating required permutation shapes - self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] - self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( - range(1, self.spatial_dims + 1) - ) - - def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. - - Args: - inputs: Encoding space tensors of shape [B, C, H, W, D]. - - Returns: - torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. - torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. - torch.Tensor: Quantization indices of shape [B,H,W,D,1] - - """ - with torch.cuda.amp.autocast(enabled=False): - encoding_indices_view = list(inputs.shape) - del encoding_indices_view[1] - - inputs = inputs.float() - - # Converting to channel last format - flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) - - # Calculate Euclidean distances - distances = ( - (flat_input**2).sum(dim=1, keepdim=True) - + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) - - 2 * torch.mm(flat_input, self.embedding.weight.t()) - ) - - # Mapping distances to indexes - encoding_indices = torch.max(-distances, dim=1)[1] - encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() - - # Quantize and reshape - encoding_indices = encoding_indices.view(encoding_indices_view) - - return flat_input, encodings, encoding_indices - - def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: - """ - Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space - [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the - decoder. - - Args: - embedding_indices: Tensor in channel last format which holds indices referencing atomic - elements from self.embedding - - Returns: - torch.Tensor: Quantize space representation of encoding_indices in channel first format. 
- """ - with torch.cuda.amp.autocast(enabled=False): - embedding: torch.Tensor = ( - self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() - ) - return embedding - - def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: - """ - TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the - example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused - - Args: - encodings_sum: The summation of one hot representation of what encoding was used for each - position. - dw: The multiplication of the one hot representation of what encoding was used for each - position with the flattened input. - - Returns: - None - """ - if self.ddp_sync and torch.distributed.is_initialized(): - torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) - else: - pass - - def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - flat_input, encodings, encoding_indices = self.quantize(inputs) - quantized = self.embed(encoding_indices) - - # Use EMA to update the embedding vectors - if self.training: - with torch.no_grad(): - encodings_sum = encodings.sum(0) - dw = torch.mm(encodings.t(), flat_input) - - if self.ddp_sync: - self.distributed_synchronization(encodings_sum, dw) - - self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) - - # Laplace smoothing of the cluster size - n = self.ema_cluster_size.sum() - weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n - self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) - self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) - - # Encoding Loss - loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) - - # Straight Through Estimator - quantized = inputs + (quantized - inputs).detach() - - return quantized, loss, encoding_indices - - -class VectorQuantizer(torch.nn.Module): - """ - Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of - the quantization in their own class. - - Args: - quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index - based quantized representation. 
- """ - - def __init__(self, quantizer: EMAQuantizer): - super().__init__() - - self.quantizer: EMAQuantizer = quantizer - - self.perplexity: torch.Tensor = torch.rand(1) - - def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - quantized, loss, encoding_indices = self.quantizer(inputs) - # Perplexity calculations - avg_probs = ( - torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) - .float() - .div(encoding_indices.numel()) - ) - - self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - return loss, quantized - - def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: - return self.quantizer.embed(embedding_indices=embedding_indices) - - def quantize(self, encodings: torch.Tensor) -> torch.Tensor: - output = self.quantizer(encodings) - encoding_indices: torch.Tensor = output[2] - return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index a7ce16ad64c..9247aaee859 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,11 +14,9 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder -from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator -from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, @@ -36,7 +34,6 @@ densenet201, densenet264, ) -from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( @@ -55,7 +52,6 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter -from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet @@ -106,12 +102,9 @@ seresnext50, seresnext101, ) -from .spade_autoencoderkl import SPADEAutoencoderKL -from .spade_diffusion_model_unet import SPADEDiffusionModelUNet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex -from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder @@ -119,4 +112,3 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet -from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py deleted file mode 100644 index f7ae77f0569..00000000000 --- a/monai/networks/nets/autoencoderkl.py +++ /dev/null @@ -1,807 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import math -from collections.abc import Sequence -from typing import List - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from monai.networks.blocks import Convolution - -# To install xformers, use pip install xformers==0.0.16rc401 -from monai.utils import ensure_tuple_rep, optional_import - -xformers, has_xformers = optional_import("xformers") - -__all__ = ["AutoencoderKL"] - - -class _Upsample(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Convolution-based upsampling layer. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: number of input channels to the layer. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. - """ - - def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: - super().__init__() - if use_convtranspose: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=2, - kernel_size=3, - padding=1, - conv_only=True, - is_transposed=True, - ) - else: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.use_convtranspose = use_convtranspose - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_convtranspose: - conv: torch.Tensor = self.conv(x) - return conv - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - x = self.conv(x) - return x - - -class _Downsample(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Convolution-based downsampling layer. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: number of input channels. - """ - - def __init__(self, spatial_dims: int, in_channels: int) -> None: - super().__init__() - self.pad = (0, 1) * spatial_dims - - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=2, - kernel_size=3, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) - x = self.conv(x) - return x - - -class _ResBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. 
Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a - residual connection between input and output. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: input channels to the layer. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon for the normalisation. - out_channels: number of output channels. - """ - - def __init__( - self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - - self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) - self.conv2 = Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.nin_shortcut: nn.Module - if self.in_channels != self.out_channels: - self.nin_shortcut = Convolution( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - else: - self.nin_shortcut = nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = F.silu(h) - h = self.conv1(h) - - h = self.norm2(h) - h = F.silu(h) - h = self.conv2(h) - - x = self.nin_shortcut(x) - - return x + h - - -class _AttentionBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Attention block. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
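The attention applied by this block treats every spatial position of the feature map as a token: the map is flattened to a sequence, projected to queries, keys and values, and recombined with a residual connection. A single-head sketch of that pattern with arbitrary sizes (the deleted block adds multi-head reshaping, GroupNorm and optional xformers attention on top of this):

    import math
    import torch
    from torch import nn

    b, c, h, w = 2, 32, 8, 8
    x = torch.randn(b, c, h, w)
    seq = x.view(b, c, h * w).transpose(1, 2)           # (B, H*W, C): one token per position

    to_q, to_k, to_v = nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c)
    q, k, v = to_q(seq), to_k(seq), to_v(seq)

    scores = q @ k.transpose(-1, -2) / math.sqrt(c)     # (B, H*W, H*W)
    attended = scores.softmax(dim=-1) @ v               # (B, H*W, C)
    out = attended.transpose(1, 2).reshape(b, c, h, w) + x   # residual, back to (B, C, H, W)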
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - 
return x + residual - - -class Encoder(nn.Module): - """ - Convolutional cascade that downsamples the image into a spatial latent space. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: number of input channels. - channels: sequence of block output channels. - out_channels: number of channels in the bottom layer (latent space) of the autoencoder. - num_res_blocks: number of residual blocks (see _ResBlock) per level. - norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. - norm_eps: epsilon for the normalization. - attention_levels: indicate which level from num_channels contain an attention block. - with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - channels: Sequence[int], - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.channels = channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - - blocks: List[nn.Module] = [] - # Initial convolution - blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - # Residual and downsampling blocks - output_channel = channels[0] - for i in range(len(channels)): - input_channel = output_channel - output_channel = channels[i] - is_final_block = i == len(channels) - 1 - - for _ in range(self.num_res_blocks[i]): - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=input_channel, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=output_channel, - ) - ) - input_channel = output_channel - if attention_levels[i]: - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=input_channel, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - if not is_final_block: - blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) - - # Non-local attention block - if with_nonlocal_attn is True: - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=channels[-1], - ) - ) - - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=channels[-1], - ) - ) - # Normalise and convert to latent size - blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) - blocks.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=channels[-1], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def 
forward(self, x: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - x = block(x) - return x - - -class Decoder(nn.Module): - """ - Convolutional cascade upsampling from a spatial latent space into an image space. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - channels: sequence of block output channels. - in_channels: number of channels in the bottom layer (latent space) of the autoencoder. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see _ResBlock) per level. - norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. - norm_eps: epsilon for the normalization. - attention_levels: indicate which level from num_channels contain an attention block. - with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. - """ - - def __init__( - self, - spatial_dims: int, - channels: Sequence[int], - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_convtranspose: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = channels - self.in_channels = in_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - - reversed_block_out_channels = list(reversed(channels)) - - blocks: List[nn.Module] = [] - - # Initial convolution - blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=reversed_block_out_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - # Non-local attention block - if with_nonlocal_attn is True: - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - ) - ) - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - ) - ) - - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - block_out_ch = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - block_in_ch = block_out_ch - block_out_ch = reversed_block_out_channels[i] - is_final_block = i == len(channels) - 1 - - for _ in range(reversed_num_res_blocks[i]): - blocks.append( - _ResBlock( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=block_out_ch, - ) - ) - block_in_ch = block_out_ch - - if reversed_attention_levels[i]: - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) 
- ) - - if not is_final_block: - blocks.append( - _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) - ) - - blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) - blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - x = block(x) - return x - - -class AutoencoderKL(nn.Module): - """ - Autoencoder model with KL-regularized latent space based on - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 - and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see _ResBlock) per level. - channels: number of output channels for each block. - attention_levels: sequence of levels to add attention. - latent_channels: latent embedding dimension. - norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. - norm_eps: epsilon for the normalization. - with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. - with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - use_checkpoint: if True, use activation checkpoint to save memory. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int = 1, - out_channels: int = 1, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - latent_channels: int = 3, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_checkpoint: bool = False, - use_convtranspose: bool = False, - ) -> None: - super().__init__() - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in channels): - raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") - - if len(channels) != len(attention_levels): - raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - - if len(num_res_blocks) != len(channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`num_channels`." - ) - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
- ) - - self.encoder = Encoder( - spatial_dims=spatial_dims, - in_channels=in_channels, - channels=channels, - out_channels=latent_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - with_nonlocal_attn=with_encoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - ) - self.decoder = Decoder( - spatial_dims=spatial_dims, - channels=channels, - in_channels=latent_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - with_nonlocal_attn=with_decoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - use_convtranspose=use_convtranspose, - ) - self.quant_conv_mu = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.quant_conv_log_sigma = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.post_quant_conv = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.latent_channels = latent_channels - self.use_checkpoint = use_checkpoint - - def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. - - Args: - x: BxCx[SPATIAL DIMS] tensor - - """ - if self.use_checkpoint: - h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) - else: - h = self.encoder(x) - - z_mu = self.quant_conv_mu(h) - z_log_var = self.quant_conv_log_sigma(h) - z_log_var = torch.clamp(z_log_var, -30.0, 20.0) - z_sigma = torch.exp(z_log_var / 2) - - return z_mu, z_sigma - - def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: - """ - From the mean and sigma representations resulting of encoding an image through the latent space, - obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and - adding the mean. - - Args: - z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image - z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image - - Returns: - sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] - """ - eps = torch.randn_like(z_sigma) - z_vae = z_mu + eps * z_sigma - return z_vae - - def reconstruct(self, x: torch.Tensor) -> torch.Tensor: - """ - Encodes and decodes an input image. - - Args: - x: BxCx[SPATIAL DIMENSIONS] tensor. - - Returns: - reconstructed image, of the same shape as input - """ - z_mu, _ = self.encode(x) - reconstruction = self.decode(z_mu) - return reconstruction - - def decode(self, z: torch.Tensor) -> torch.Tensor: - """ - Based on a latent space sample, forwards it through the Decoder. 
- - Args: - z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] - - Returns: - decoded image tensor - """ - z = self.post_quant_conv(z) - dec: torch.Tensor - if self.use_checkpoint: - dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) - else: - dec = self.decoder(z) - return dec - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - z_mu, z_sigma = self.encode(x) - z = self.sampling(z_mu, z_sigma) - reconstruction = self.decode(z) - return reconstruction, z_mu, z_sigma - - def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: - z_mu, z_sigma = self.encode(x) - z = self.sampling(z_mu, z_sigma) - return z - - def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: - image = self.decode(z) - return image diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py deleted file mode 100644 index d98755f4017..00000000000 --- a/monai/networks/nets/controlnet.py +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -from collections.abc import Sequence - -import torch -import torch.nn.functional as F -from torch import nn - -from monai.networks.blocks import Convolution -from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding -from monai.utils import ensure_tuple_rep - - -class ControlNetConditioningEmbedding(nn.Module): - """ - Network to encode the conditioning into a latent space. 
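Looking back at the `sampling` method of the autoencoder removed above: it is the usual reparameterization trick, drawing unit Gaussian noise and scaling and shifting it with the predicted sigma and mean so the sample stays differentiable with respect to the encoder. In isolation, with illustrative latent shapes:

    import torch

    z_mu = torch.randn(2, 3, 16, 16)       # latent mean from the encoder
    z_sigma = torch.rand(2, 3, 16, 16)     # latent standard deviation from the encoder

    eps = torch.randn_like(z_sigma)        # unit Gaussian noise
    z = z_mu + eps * z_sigma               # differentiable sample fed to the decoder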
- """ - - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): - super().__init__() - - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.blocks = nn.ModuleList([]) - - for i in range(len(channels) - 1): - channel_in = channels[i] - channel_out = channels[i + 1] - self.blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=channel_in, - out_channels=channel_in, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=channel_in, - out_channels=channel_out, - strides=2, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.conv_out = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=channels[-1], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module - - -class ControlNet(nn.Module): - """ - Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image - Diffusion Models" (https://arxiv.org/abs/2302.05543) - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - conditioning_embedding_in_channels: number of input channels for the conditioning embedding. - conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "to be specified when with_conditioning=True." - ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in channels): - raise ValueError( - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" - f" channels={channels} and norm_num_groups={norm_num_groups}" - ) - - if len(channels) != len(attention_levels): - raise ValueError( - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " - f"channels={channels} and attention_levels={attention_levels}" - ) - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - f"num_head_channels should have the same length as attention_levels, but got channels={channels} and " - f"attention_levels={attention_levels} . For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." - ) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - - if len(num_res_blocks) != len(channels): - raise ValueError( - f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." - ) - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
- ) - - self.in_channels = in_channels - self.block_out_channels = channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - - # control net conditioning embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - spatial_dims=spatial_dims, - in_channels=conditioning_embedding_in_channels, - channels=conditioning_embedding_num_channels, - out_channels=channels[0], - ) - - # down - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - output_channel = channels[0] - - controlnet_block = Convolution( - spatial_dims=spatial_dims, - in_channels=output_channel, - out_channels=output_channel, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - controlnet_block = zero_module(controlnet_block.conv) - self.controlnet_down_blocks.append(controlnet_block) - - for i in range(len(channels)): - input_channel = output_channel - output_channel = channels[i] - is_final_block = i == len(channels) - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - - self.down_blocks.append(down_block) - - for _ in range(num_res_blocks[i]): - controlnet_block = Convolution( - spatial_dims=spatial_dims, - in_channels=output_channel, - out_channels=output_channel, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - # - if not is_final_block: - controlnet_block = Convolution( - spatial_dims=spatial_dims, - in_channels=output_channel, - out_channels=output_channel, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channel = channels[-1] - - self.middle_block = get_mid_block( - spatial_dims=spatial_dims, - in_channels=mid_block_channel, - temb_channels=time_embed_dim, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - with_conditioning=with_conditioning, - num_head_channels=num_head_channels[-1], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - - controlnet_block = Convolution( - spatial_dims=spatial_dims, - 
in_channels=output_channel, - out_channels=output_channel, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - ) -> tuple[list[torch.Tensor], torch.Tensor]: - """ - Args: - x: input tensor (N, C, H, W, [D]). - timesteps: timestep tensor (N,). - controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]) - conditioning_scale: conditioning scale. - context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init. - class_labels: context tensor (N, ). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. initial convolution - h = self.conv_in(x) - - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - - h += controlnet_cond - - # 4. down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples: list[torch.Tensor] = [h] - for downsample_block in self.down_blocks: - h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) - - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # 6. Control net blocks - controlnet_down_block_res_samples = [] - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples.append(down_block_res_sample) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) - - # 6. scaling - down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] - mid_block_res_sample *= conditioning_scale - - return down_block_res_samples, mid_block_res_sample diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py deleted file mode 100644 index 1532215c70d..00000000000 --- a/monai/networks/nets/diffusion_model_unet.py +++ /dev/null @@ -1,2138 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -import math -from collections.abc import Sequence - -import torch -import torch.nn.functional as F -from torch import nn - -from monai.networks.blocks import Convolution, MLPBlock -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep, optional_import - -# To install xformers, use pip install xformers==0.0.16rc401 - -xops, has_xformers = optional_import("xformers.ops") - - -__all__ = ["DiffusionModelUNet"] - - -def zero_module(module: nn.Module) -> nn.Module: - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -class _CrossAttention(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
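# zero_module (defined above) zeroes a layer's parameters so that, at initialisation,
# any branch it terminates contributes nothing to the residual sum. A self-contained
# check, re-stating the helper and using a plain torch.nn.Conv2d as a stand-in for
# the MONAI Convolution wrapper (an assumption for illustration only):
import torch
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # zero out all parameters of the module in place and return it
    for p in module.parameters():
        p.detach().zero_()
    return module

proj = zero_module(nn.Conv2d(8, 8, kernel_size=1))
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(proj(x), torch.zeros_like(x))  # the branch is a no-op at init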
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - output: torch.Tensor = self.to_out(x) - return output - - -class _BasicTransformerBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. 
For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A basic Transformer block. - - Args: - num_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - dropout: dropout probability to use. - cross_attention_dim: size of the context vector for cross attention. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - num_channels: int, - num_attention_heads: int, - num_head_channels: int, - dropout: float = 0.0, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attn1 = _CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention - self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) - self.attn2 = _CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None - self.norm1 = nn.LayerNorm(num_channels) - self.norm2 = nn.LayerNorm(num_channels) - self.norm3 = nn.LayerNorm(num_channels) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # 1. Self-Attention - x = self.attn1(self.norm1(x)) + x - - # 2. Cross-Attention - x = self.attn2(self.norm2(x), context=context) + x - - # 3. Feed-forward - x = self.ff(self.norm3(x)) + x - return x - - -class _SpatialTransformer(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - num_layers: number of layers of Transformer blocks to use. - dropout: dropout probability to use. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_attention_heads: int, - num_head_channels: int, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - inner_dim = num_attention_heads * num_head_channels - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - - self.proj_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=inner_dim, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - self.transformer_blocks = nn.ModuleList( - [ - _BasicTransformerBlock( - num_channels=inner_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - for _ in range(num_layers) - ] - ) - - self.proj_out = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=inner_dim, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - ) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # note: if no context is given, cross-attention defaults to self-attention - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - residual = x - x = self.norm(x) - x = self.proj_in(x) - - inner_dim = x.shape[1] - - if self.spatial_dims == 2: - x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - if self.spatial_dims == 3: - x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) - - for block in self.transformer_blocks: - x = block(x, context=context) - - if self.spatial_dims == 2: - x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - if self.spatial_dims == 3: - x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() - - x = self.proj_out(x) - return x + residual - - -class _AttentionBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - -def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: - """ - Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic - Models" https://arxiv.org/abs/2006.11239. - - Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - embedding_dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - """ - if timesteps.ndim != 1: - raise ValueError("Timesteps should be a 1d-array") - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - freqs = torch.exp(exponent / half_dim) - - args = timesteps[:, None].float() * freqs[None, :] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) - - return embedding - - -class _Downsample(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Downsampling layer. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is - False, the number of output channels must be the same as the number of input channels. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points - for each dimension. - """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.op = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=2, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - if self.num_channels != self.out_channels: - raise ValueError("num_channels and out_channels must be equal when use_conv=False") - self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError( - f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " - f"({self.num_channels})" - ) - output: torch.Tensor = self.op(x) - return output - - -class _Upsample(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. 
- """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - if self.use_conv: - x = self.conv(x) - return x - - -class _ResnetBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Residual block with timestep conditioning. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - out_channels: number of output channels. - up: if True, performs upsampling. - down: if True, performs downsampling. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - out_channels: int | None = None, - up: bool = False, - down: bool = False, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = in_channels - self.emb_channels = temb_channels - self.out_channels = out_channels or in_channels - self.up = up - self.down = down - - self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.nonlinearity = nn.SiLU() - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) - elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) - - self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) - - self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) - self.conv2 = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - self.skip_connection: nn.Module - if self.out_channels == in_channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = self.nonlinearity(h) - - if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() - x = self.upsample(x) - h = self.upsample(h) - elif self.downsample is not None: - x = self.downsample(x) - h = self.downsample(h) - - h = self.conv1(h) - - if self.spatial_dims == 2: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] - else: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] - h = h + temb - - h = self.norm2(h) - h = self.nonlinearity(h) - h = self.conv2(h) - output: torch.Tensor = self.skip_connection(x) + h - return output - - -class DownBlock(nn.Module): - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsampler: nn.Module | None - if resblock_updown: - self.downsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = _Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.downsampler: nn.Module | None - if add_downsample: - if resblock_updown: - self.downsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = _Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class CrossAttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and cross-attention blocks. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - attentions.append( - _SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.downsampler: nn.Module | None - if add_downsample: - if resblock_updown: - self.downsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = _Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - - self.resnet_1 = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=in_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - - self.resnet_2 = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - del context - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class CrossAttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - - self.resnet_1 = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = _SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=in_channels, - num_attention_heads=in_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - self.resnet_2 = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states, context=context) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class UpBlock(nn.Module): - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class AttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - _SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -def get_down_block( - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_downsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - 
add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return DownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - ) - - -def get_mid_block( - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int, - norm_eps: float, - with_conditioning: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_conditioning: - return CrossAttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return AttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - - -def get_up_block( - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_upsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - 
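# get_down_block / get_mid_block / get_up_block (above) pick the block type per level
# from two flags: attention_levels[i] decides whether any attention is used, and
# with_conditioning switches between the self-attention and cross-attention variants.
# The selection rule, spelled out in plain Python with example values:
attention_levels = (False, False, True)
with_conditioning = True

for i, attn in enumerate(attention_levels):
    with_attn = attn and not with_conditioning
    with_cross_attn = attn and with_conditioning
    kind = "CrossAttn" if with_cross_attn else ("Attn" if with_attn else "Plain")
    print(f"level {i}: {kind} block")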
upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return UpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - ) - - -class DiffusionModelUNet(nn.Module): - """ - Unet network with timestep embedding and attention mechanisms for conditioning based on - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 - and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see _ResnetBlock) per level. - channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." - ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
- ) - if dropout_cattn > 1.0 or dropout_cattn < 0.0: - raise ValueError("Dropout cannot be negative or >1.0!") - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in channels): - raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - - if len(channels) != len(attention_levels): - raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - "num_head_channels should have the same length as attention_levels. For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." - ) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - - if len(num_res_blocks) != len(channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`num_channels`." - ) - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - - self.in_channels = in_channels - self.block_out_channels = channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - - # down - self.down_blocks = nn.ModuleList([]) - output_channel = channels[0] - for i in range(len(channels)): - input_channel = output_channel - output_channel = channels[i] - is_final_block = i == len(channels) - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.down_blocks.append(down_block) - - # mid - self.middle_block = get_mid_block( - spatial_dims=spatial_dims, - in_channels=channels[-1], - temb_channels=time_embed_dim, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - with_conditioning=with_conditioning, - num_head_channels=num_head_channels[-1], - 
transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - # up - self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(channels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_head_channels = list(reversed(num_head_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] - - is_final_block = i == len(channels) - 1 - - up_block = get_up_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - prev_output_channel=prev_output_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=reversed_num_res_blocks[i] + 1, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(reversed_attention_levels[i] and not with_conditioning), - with_cross_attn=(reversed_attention_levels[i] and with_conditioning), - num_head_channels=reversed_num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.up_blocks.append(up_block) - - # out - self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), - nn.SiLU(), - zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=channels[0], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ), - ) - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - down_block_additional_residuals: tuple[torch.Tensor] | None = None, - mid_block_additional_residual: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). - mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. initial convolution - h = self.conv_in(x) - - # 4. 
down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples: list[torch.Tensor] = [h] - for downsample_block in self.down_blocks: - h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) - - # Additional residual conections for Controlnets - if down_block_additional_residuals is not None: - new_down_block_res_samples: list[torch.Tensor] = [] - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += [down_block_res_sample] - - down_block_res_samples = new_down_block_res_samples - - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # Additional residual conections for Controlnets - if mid_block_additional_residual is not None: - h = h + mid_block_additional_residual - - # 6. up - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - - # 7. output block - output: torch.Tensor = self.out(h) - - return output - - -class DiffusionModelEncoder(nn.Module): - """ - Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on - Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see _ResnetBlock) per level. - channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. - upcast_attention: if True, upcast attention operations to full precision. 
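A minimal usage sketch of the DiffusionModelUNet defined above, assuming the class is importable from monai.networks.nets and a 2D, single-channel 64x64 input (with the default four levels, the spatial size only needs to be divisible by 8 so the three downsamplings line up). The import path and tensor sizes are illustrative assumptions, not part of this diff.

import torch
from monai.networks.nets import DiffusionModelUNet  # assumed import path

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64, 64),
    attention_levels=(False, False, True, True),
)
x = torch.randn(1, 1, 64, 64)            # noisy image batch (N, C, H, W)
t = torch.randint(0, 1000, (1,)).long()  # one timestep per batch element
noise_pred = model(x, timesteps=t)       # output has the same shape as x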
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." - ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." - ) - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in channels): - raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") - if len(channels) != len(attention_levels): - raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - "num_head_channels should have the same length as attention_levels. For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." 
- ) - - self.in_channels = in_channels - self.block_out_channels = channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - - # down - self.down_blocks = nn.ModuleList([]) - output_channel = channels[0] - for i in range(len(channels)): - input_channel = output_channel - output_channel = channels[i] - is_final_block = i == len(channels) # - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - ) - - self.down_blocks.append(down_block) - - self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. initial convolution - h = self.conv_in(x) - - # 4. 
down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - for downsample_block in self.down_blocks: - h, _ = downsample_block(hidden_states=h, temb=emb, context=context) - - h = h.reshape(h.shape[0], -1) - output: torch.Tensor = self.out(h) - - return output diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py deleted file mode 100644 index 3b089616ce7..00000000000 --- a/monai/networks/nets/patchgan_discriminator.py +++ /dev/null @@ -1,247 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Sequence - -import torch -import torch.nn as nn - -from monai.networks.blocks import Convolution -from monai.networks.layers import Act - - -class MultiScalePatchDiscriminator(nn.Sequential): - """ - Multi-scale Patch-GAN discriminator based on Pix2PixHD: - High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) - - The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images - at different spatial scales. - - Args: - num_d: number of discriminators - num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first - discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. - spatial_dims: number of spatial dimensions (1D, 2D etc.) - channels: number of filters in the first convolutional layer (doubled for each subsequent layer) - in_channels: number of input channels - out_channels: number of output channels in each discriminator - kernel_size: kernel size of the convolution layers - activation: activation layer type - norm: normalisation type - bias: introduction of layer bias - dropout: probability of dropout applied, defaults to 0. - minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture - requested isn't going to downsample the input image beyond value of 1. - last_conv_kernel_size: kernel size of the last convolutional layer. 
- """ - - def __init__( - self, - num_d: int, - num_layers_d: int, - spatial_dims: int, - channels: int, - in_channels: int, - out_channels: int = 1, - kernel_size: int = 4, - activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), - norm: str | tuple = "BATCH", - bias: bool = False, - dropout: float | tuple = 0.0, - minimum_size_im: int = 256, - last_conv_kernel_size: int = 1, - ) -> None: - super().__init__() - self.num_d = num_d - self.num_layers_d = num_layers_d - self.num_channels = channels - self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) - for i_ in range(self.num_d): - num_layers_d_i = self.num_layers_d * (i_ + 1) - output_size = float(minimum_size_im) / (2**num_layers_d_i) - if output_size < 1: - raise AssertionError( - f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." - "Please reduce num_layers, reduce num_D or enter bigger images." - ) - subnet_d = PatchDiscriminator( - spatial_dims=spatial_dims, - channels=self.num_channels, - in_channels=in_channels, - out_channels=out_channels, - num_layers_d=num_layers_d_i, - kernel_size=kernel_size, - activation=activation, - norm=norm, - bias=bias, - padding=self.padding, - dropout=dropout, - last_conv_kernel_size=last_conv_kernel_size, - ) - - self.add_module("discriminator_%d" % i_, subnet_d) - - def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: - """ - Args: - i: Input tensor - - Returns: - list of outputs and another list of lists with the intermediate features - of each discriminator. - """ - - out: list[torch.Tensor] = [] - intermediate_features: list[list[torch.Tensor]] = [] - for disc in self.children(): - out_d: list[torch.Tensor] = disc(i) - out.append(out_d[-1]) - intermediate_features.append(out_d[:-1]) - - return out, intermediate_features - - -class PatchDiscriminator(nn.Sequential): - """ - Patch-GAN discriminator based on Pix2PixHD: - High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) - - - Args: - spatial_dims: number of spatial dimensions (1D, 2D etc.) - channels: number of filters in the first convolutional layer (doubled for each subsequent layer) - in_channels: number of input channels - out_channels: number of output channels - num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. - kernel_size: kernel size of the convolution layers - act: activation type and arguments. Defaults to LeakyReLU. - norm: feature normalization type and arguments. Defaults to batch norm. - bias: whether to have a bias term in convolution blocks. Defaults to False. - padding: padding to be applied to the convolutional layers - dropout: proportion of dropout applied, defaults to 0. - last_conv_kernel_size: kernel size of the last convolutional layer. 
- """ - - def __init__( - self, - spatial_dims: int, - channels: int, - in_channels: int, - out_channels: int = 1, - num_layers_d: int = 3, - kernel_size: int = 4, - activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), - norm: str | tuple = "BATCH", - bias: bool = False, - padding: int | Sequence[int] = 1, - dropout: float | tuple = 0.0, - last_conv_kernel_size: int | None = None, - ) -> None: - super().__init__() - self.num_layers_d = num_layers_d - self.num_channels = channels - if last_conv_kernel_size is None: - last_conv_kernel_size = kernel_size - - self.add_module( - "initial_conv", - Convolution( - spatial_dims=spatial_dims, - kernel_size=kernel_size, - in_channels=in_channels, - out_channels=channels, - act=activation, - bias=True, - norm=None, - dropout=dropout, - padding=padding, - strides=2, - ), - ) - - input_channels = channels - output_channels = channels * 2 - - # Initial Layer - for l_ in range(self.num_layers_d): - if l_ == self.num_layers_d - 1: - stride = 1 - else: - stride = 2 - layer = Convolution( - spatial_dims=spatial_dims, - kernel_size=kernel_size, - in_channels=input_channels, - out_channels=output_channels, - act=activation, - bias=bias, - norm=norm, - dropout=dropout, - padding=padding, - strides=stride, - ) - self.add_module("%d" % l_, layer) - input_channels = output_channels - output_channels = output_channels * 2 - - # Final layer - self.add_module( - "final_conv", - Convolution( - spatial_dims=spatial_dims, - kernel_size=last_conv_kernel_size, - in_channels=input_channels, - out_channels=out_channels, - bias=True, - conv_only=True, - padding=int((last_conv_kernel_size - 1) / 2), - dropout=0.0, - strides=1, - ), - ) - - self.apply(self.initialise_weights) - - def forward(self, x: torch.Tensor) -> list[torch.Tensor]: - """ - Args: - x: input tensor - - Returns: - list of intermediate features, with the last element being the output. - """ - out = [x] - for submodel in self.children(): - intermediate_output = submodel(out[-1]) - out.append(intermediate_output) - - return out[1:] - - def initialise_weights(self, m: nn.Module) -> None: - """ - Initialise weights of Convolution and BatchNorm layers. - - Args: - m: instance of torch.nn.module (or of class inheriting torch.nn.module) - """ - classname = m.__class__.__name__ - if classname.find("Conv2d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv3d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("Conv1d") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py deleted file mode 100644 index e064c197406..00000000000 --- a/monai/networks/nets/spade_autoencoderkl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from collections.abc import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from monai.networks.blocks import Convolution -from monai.networks.blocks.spade_norm import SPADE -from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample -from monai.utils import ensure_tuple_rep - -__all__ = ["SPADEAutoencoderKL"] - - -class SPADEResBlock(nn.Module): - """ - Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a - residual connection between input and output. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: input channels to the layer. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon for the normalisation. - out_channels: number of output channels. - label_nc: number of semantic channels for SPADE normalisation - spade_intermediate_channels: number of intermediate channels for SPADE block layer - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - norm_num_groups: int, - norm_eps: float, - out_channels: int, - label_nc: int, - spade_intermediate_channels: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = SPADE( - label_nc=label_nc, - norm_nc=in_channels, - norm="GROUP", - norm_params={"num_groups": norm_num_groups, "affine": False}, - hidden_channels=spade_intermediate_channels, - kernel_size=3, - spatial_dims=spatial_dims, - ) - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.norm2 = SPADE( - label_nc=label_nc, - norm_nc=out_channels, - norm="GROUP", - norm_params={"num_groups": norm_num_groups, "affine": False}, - hidden_channels=spade_intermediate_channels, - kernel_size=3, - spatial_dims=spatial_dims, - ) - self.conv2 = Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.nin_shortcut: nn.Module - if self.in_channels != self.out_channels: - self.nin_shortcut = Convolution( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - else: - self.nin_shortcut = nn.Identity() - - def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h, seg) - h = F.silu(h) - h = self.conv1(h) - h = self.norm2(h, seg) - h = F.silu(h) - h = self.conv2(h) - - x = self.nin_shortcut(x) - - return x + h - - -class SPADEDecoder(nn.Module): - """ - Convolutional cascade upsampling from a spatial latent space into an image space. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - channels: sequence of block output channels. - in_channels: number of channels in the bottom layer (latent space) of the autoencoder. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResBlock) per level. 
- norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. - norm_eps: epsilon for the normalization. - attention_levels: indicate which level from channels contain an attention block. - label_nc: number of semantic channels for SPADE normalisation. - with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - spade_intermediate_channels: number of intermediate channels for SPADE block layer. - """ - - def __init__( - self, - spatial_dims: int, - channels: Sequence[int], - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - label_nc: int, - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = channels - self.in_channels = in_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - self.label_nc = label_nc - - reversed_block_out_channels = list(reversed(channels)) - - blocks: list[nn.Module] = [] - - # Initial convolution - blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=reversed_block_out_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - # Non-local attention block - if with_nonlocal_attn is True: - blocks.append( - SPADEResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - blocks.append( - SPADEResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - block_out_ch = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - block_in_ch = block_out_ch - block_out_ch = reversed_block_out_channels[i] - is_final_block = i == len(channels) - 1 - - for _ in range(reversed_num_res_blocks[i]): - blocks.append( - SPADEResBlock( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=block_out_ch, - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - block_in_ch = block_out_ch - - if reversed_attention_levels[i]: - blocks.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - if not is_final_block: - blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) - - 
blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) - blocks.append( - Convolution( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - if isinstance(block, SPADEResBlock): - x = block(x, seg) - else: - x = block(x) - return x - - -class SPADEAutoencoderKL(nn.Module): - """ - Autoencoder model with KL-regularized latent space based on - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 - and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - label_nc: number of semantic channels for SPADE normalisation. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResBlock) per level. - channels: sequence of block output channels. - attention_levels: sequence of levels to add attention. - latent_channels: latent embedding dimension. - norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. - norm_eps: epsilon for the normalization. - with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. - with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - spade_intermediate_channels: number of intermediate channels for SPADE block layer. - """ - - def __init__( - self, - spatial_dims: int, - label_nc: int, - in_channels: int = 1, - out_channels: int = 1, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - latent_channels: int = 3, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in channels): - raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups") - - if len(channels) != len(attention_levels): - raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels") - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - - if len(num_res_blocks) != len(channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`channels`." - ) - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
- ) - - self.encoder = Encoder( - spatial_dims=spatial_dims, - in_channels=in_channels, - channels=channels, - out_channels=latent_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - with_nonlocal_attn=with_encoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - ) - self.decoder = SPADEDecoder( - spatial_dims=spatial_dims, - channels=channels, - in_channels=latent_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - label_nc=label_nc, - with_nonlocal_attn=with_decoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - spade_intermediate_channels=spade_intermediate_channels, - ) - self.quant_conv_mu = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.quant_conv_log_sigma = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.post_quant_conv = Convolution( - spatial_dims=spatial_dims, - in_channels=latent_channels, - out_channels=latent_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.latent_channels = latent_channels - - def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. - - Args: - x: BxCx[SPATIAL DIMS] tensor - - """ - h = self.encoder(x) - z_mu = self.quant_conv_mu(h) - z_log_var = self.quant_conv_log_sigma(h) - z_log_var = torch.clamp(z_log_var, -30.0, 20.0) - z_sigma = torch.exp(z_log_var / 2) - - return z_mu, z_sigma - - def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: - """ - From the mean and sigma representations resulting of encoding an image through the latent space, - obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and - adding the mean. - - Args: - z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image - z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image - - Returns: - sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] - """ - eps = torch.randn_like(z_sigma) - z_vae = z_mu + eps * z_sigma - return z_vae - - def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - """ - Encodes and decodes an input image. - - Args: - x: BxCx[SPATIAL DIMENSIONS] tensor. - seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. - Returns: - reconstructed image, of the same shape as input - """ - z_mu, _ = self.encode(x) - reconstruction = self.decode(z_mu, seg) - return reconstruction - - def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - """ - Based on a latent space sample, forwards it through the Decoder. - - Args: - z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] - seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. 
- Returns: - decoded image tensor - """ - z = self.post_quant_conv(z) - dec: torch.Tensor = self.decoder(z, seg) - return dec - - def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - z_mu, z_sigma = self.encode(x) - z = self.sampling(z_mu, z_sigma) - reconstruction = self.decode(z, seg) - return reconstruction, z_mu, z_sigma - - def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: - z_mu, z_sigma = self.encode(x) - z = self.sampling(z_mu, z_sigma) - return z - - def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - image = self.decode(z, seg) - return image diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py deleted file mode 100644 index d53327100e9..00000000000 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ /dev/null @@ -1,908 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -from collections.abc import Sequence - -import torch -from torch import nn - -from monai.networks.blocks import Convolution -from monai.networks.blocks.spade_norm import SPADE -from monai.networks.nets.diffusion_model_unet import ( - _AttentionBlock, - _Downsample, - _ResnetBlock, - _SpatialTransformer, - _Upsample, - get_down_block, - get_mid_block, - get_timestep_embedding, - zero_module, -) -from monai.utils import ensure_tuple_rep, optional_import - -# To install xformers, use pip install xformers==0.0.16rc401 -xops, has_xformers = optional_import("xformers.ops") - - -__all__ = ["SPADEDiffusionModelUNet"] - - -class SPADEResnetBlock(nn.Module): - """ - Residual block with timestep conditioning and SPADE norm. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - label_nc: number of semantic channels for SPADE normalisation. 
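A minimal forward-pass sketch for the SPADEAutoencoderKL removed above, assuming a 2D single-channel 64x64 image and a 3-channel semantic map of the same spatial size; the import path, label_nc and sizes are illustrative assumptions.

import torch
from monai.networks.nets import SPADEAutoencoderKL  # assumed import path

model = SPADEAutoencoderKL(
    spatial_dims=2,
    label_nc=3,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64, 64),
    latent_channels=3,
)
x = torch.randn(1, 1, 64, 64)    # image to reconstruct
seg = torch.randn(1, 3, 64, 64)  # stand-in semantic map (normally one-hot labels)
reconstruction, z_mu, z_sigma = model(x, seg)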
- out_channels: number of output channels. - up: if True, performs upsampling. - down: if True, performs downsampling. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - spade_intermediate_channels: number of intermediate channels for SPADE block layer - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - label_nc: int, - out_channels: int | None = None, - up: bool = False, - down: bool = False, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = in_channels - self.emb_channels = temb_channels - self.out_channels = out_channels or in_channels - self.up = up - self.down = down - - self.norm1 = SPADE( - label_nc=label_nc, - norm_nc=in_channels, - norm="GROUP", - norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, - hidden_channels=spade_intermediate_channels, - kernel_size=3, - spatial_dims=spatial_dims, - ) - - self.nonlinearity = nn.SiLU() - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) - elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) - - self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) - - self.norm2 = SPADE( - label_nc=label_nc, - norm_nc=self.out_channels, - norm="GROUP", - norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, - hidden_channels=spade_intermediate_channels, - kernel_size=3, - spatial_dims=spatial_dims, - ) - self.conv2 = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - self.skip_connection: nn.Module - - if self.out_channels == in_channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h, seg) - h = self.nonlinearity(h) - - if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() - x = self.upsample(x) - h = self.upsample(h) - elif self.downsample is not None: - x = self.downsample(x) - h = self.downsample(h) - - h = self.conv1(h) - - if self.spatial_dims == 2: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] - else: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] - h = h + temb - - h = self.norm2(h, seg) - h = self.nonlinearity(h) - h = self.conv2(h) - output: torch.Tensor = self.skip_connection(x) + h - return output - - -class SPADEUpBlock(nn.Module): - """ - Unet's up block containing resnet and upsamplers blocks. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. 
- temb_channels: number of timestep embedding channels. - label_nc: number of semantic channels for SPADE normalisation. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - spade_intermediate_channels: number of intermediate channels for SPADE block layer. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - label_nc: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - SPADEResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - label_nc=label_nc, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - seg: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, seg) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class SPADEAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - label_nc: number of semantic channels for SPADE normalisation - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- spade_intermediate_channels: number of intermediate channels for SPADE block layer - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - label_nc: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - SPADEResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - label_nc=label_nc, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - attentions.append( - _AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - seg: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, seg) - hidden_states = attn(hidden_states) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class SPADECrossAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - label_nc: number of semantic channels for SPADE normalisation. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. 
- transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - spade_intermediate_channels: number of intermediate channels for SPADE block layer. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - label_nc: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - SPADEResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - ) - ) - attentions.append( - _SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.upsampler: nn.Module | None - if add_upsample: - if resblock_updown: - self.upsampler = _ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - seg: torch.Tensor | None = None, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb, seg) - hidden_states = attn(hidden_states, context=context) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -def get_spade_up_block( - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_upsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - 
transformer_num_layers: int, - label_nc: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, -) -> nn.Module: - if with_attn: - return SPADEAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - label_nc=label_nc, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - spade_intermediate_channels=spade_intermediate_channels, - ) - elif with_cross_attn: - return SPADECrossAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - label_nc=label_nc, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - spade_intermediate_channels=spade_intermediate_channels, - ) - else: - return SPADEUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - label_nc=label_nc, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - spade_intermediate_channels=spade_intermediate_channels, - ) - - -class SPADEDiffusionModelUNet(nn.Module): - """ - UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for - semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at - https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - label_nc: number of semantic channels for SPADE normalisation. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
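For completeness, a hedged usage sketch for the SPADEDiffusionModelUNet whose docstring begins above; it mirrors DiffusionModelUNet but also takes the semantic map seg in its forward call. Import path, label_nc and tensor sizes are assumptions.

import torch
from monai.networks.nets import SPADEDiffusionModelUNet  # assumed import path

model = SPADEDiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    label_nc=3,
    num_channels=(32, 64, 64, 64),
)
x = torch.randn(1, 1, 64, 64)
seg = torch.randn(1, 3, 64, 64)              # semantic map consumed by the SPADE-normalised up blocks
t = torch.randint(0, 1000, (1,)).long()
noise_pred = model(x, timesteps=t, seg=seg)  # output has the same shape as x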
- spade_intermediate_channels: number of intermediate channels for SPADE block layer - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - label_nc: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - spade_intermediate_channels: int = 128, - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." - ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): - raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - - if len(num_channels) != len(attention_levels): - raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - "num_head_channels should have the same length as attention_levels. For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." - ) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) - - if len(num_res_blocks) != len(num_channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`num_channels`." - ) - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
- ) - - self.in_channels = in_channels - self.block_out_channels = num_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - self.label_nc = label_nc - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=num_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = num_channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - - # down - self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): - input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - - self.down_blocks.append(down_block) - - # mid - self.middle_block = get_mid_block( - spatial_dims=spatial_dims, - in_channels=num_channels[-1], - temb_channels=time_embed_dim, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - with_conditioning=with_conditioning, - num_head_channels=num_head_channels[-1], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - - # up - self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_head_channels = list(reversed(num_head_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] - - is_final_block = i == len(num_channels) - 1 - - up_block = get_spade_up_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - prev_output_channel=prev_output_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - num_res_blocks=reversed_num_res_blocks[i] + 1, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(reversed_attention_levels[i] and not with_conditioning), - with_cross_attn=(reversed_attention_levels[i] and with_conditioning), - num_head_channels=reversed_num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - 
upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - label_nc=label_nc, - spade_intermediate_channels=spade_intermediate_channels, - ) - - self.up_blocks.append(up_block) - - # out - self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), - nn.SiLU(), - zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=num_channels[0], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ), - ) - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - seg: torch.Tensor, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - down_block_additional_residuals: tuple[torch.Tensor] | None = None, - mid_block_additional_residual: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). - mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. initial convolution - h = self.conv_in(x) - - # 4. down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples: list[torch.Tensor] = [h] - for downsample_block in self.down_blocks: - h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) - - # Additional residual conections for Controlnets - if down_block_additional_residuals is not None: - new_down_block_res_samples: list[torch.Tensor] = [h] - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples.append(down_block_res_sample) - - down_block_res_samples = new_down_block_res_samples - - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # Additional residual conections for Controlnets - if mid_block_additional_residual is not None: - h = h + mid_block_additional_residual - - # 6. up - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context) - - # 7. 
output block - output: torch.Tensor = self.out(h) - - return output diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py deleted file mode 100644 index b742c12205d..00000000000 --- a/monai/networks/nets/transformer.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from monai.networks.blocks.mlp import MLPBlock -from monai.utils import optional_import - -xops, has_xformers = optional_import("xformers.ops") -__all__ = ["DecoderOnlyTransformer"] - - -class _SABlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A self-attention block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - num_heads: number of attention heads. - dropout_rate: dropout ratio. Defaults to no dropout. - qkv_bias: bias term for the qkv linear layer. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
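For orientation, the removed SPADEDiffusionModelUNet above was driven roughly as in the sketch below; the import path reflects the pre-removal package layout, and the shapes and hyperparameters are illustrative assumptions rather than values taken from this diff.

    import torch
    from monai.networks.nets import SPADEDiffusionModelUNet  # import path prior to this removal

    net = SPADEDiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        label_nc=3,                                    # channels of the semantic map (assumed)
        num_channels=(32, 64, 64, 64),
        attention_levels=(False, False, True, True),
    )
    x = torch.randn(2, 1, 64, 64)                      # noisy image batch (assumed shape)
    seg = torch.randn(2, 3, 64, 64)                    # label_nc-channel segmentation for SPADE norm
    timesteps = torch.randint(0, 1000, (2,), dtype=torch.long)
    noise_pred = net(x, timesteps=timesteps, seg=seg)  # output has the same shape as x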
- """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.scale = 1.0 / math.sqrt(self.head_dim) - self.causal = causal - self.sequence_length = sequence_length - self.with_cross_attention = with_cross_attention - self.use_flash_attention = use_flash_attention - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - self.dropout_rate = dropout_rate - - if hidden_size % num_heads != 0: - raise ValueError("hidden size should be divisible by num_heads.") - - if causal and sequence_length is None: - raise ValueError("sequence_length is necessary for causal attention.") - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - # key, query, value projections - self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - - # regularization - self.drop_weights = nn.Dropout(dropout_rate) - self.drop_output = nn.Dropout(dropout_rate) - - # output projection - self.out_proj = nn.Linear(hidden_size, hidden_size) - - if causal and sequence_length is not None: - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "causal_mask", - torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), - ) - self.causal_mask: torch.Tensor - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - query = self.to_q(x) - - kv = context if context is not None else x - _, kv_t, _ = kv.size() - key = self.to_k(kv) - value = self.to_v(kv) - - query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) - key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - y: torch.Tensor - if self.use_flash_attention: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - y = xops.memory_efficient_attention( - query=query, - key=key, - value=value, - scale=self.scale, - p=self.dropout_rate, - attn_bias=xops.LowerTriangularMask() if self.causal else None, - ) - - else: - query = query.transpose(1, 2) # (b, nh, t, hs) - key = key.transpose(1, 2) # (b, nh, kv_t, hs) - value = value.transpose(1, 2) # (b, nh, kv_t, hs) - - # manual implementation of attention - query = query * self.scale - attention_scores = query @ key.transpose(-2, -1) - - if self.causal: - attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.drop_weights(attention_probs) - y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) - - y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) - - y = y.contiguous().view(b, t, c) # 
re-assemble all head outputs side by side - - y = self.out_proj(y) - y = self.drop_output(y) - return y - - -class _TransformerBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A transformer block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - qkv_bias: apply bias term for the qkv linear layer - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - hidden_size: int, - mlp_dim: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - self.with_cross_attention = with_cross_attention - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm1 = nn.LayerNorm(hidden_size) - self.attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - causal=causal, - sequence_length=sequence_length, - use_flash_attention=use_flash_attention, - ) - - if self.with_cross_attention: - self.norm2 = nn.LayerNorm(hidden_size) - self.cross_attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - with_cross_attention=with_cross_attention, - causal=False, - use_flash_attention=use_flash_attention, - ) - self.norm3 = nn.LayerNorm(hidden_size) - self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) - if self.with_cross_attention: - x = x + self.cross_attn(self.norm2(x), context=context) - x = x + self.mlp(self.norm3(x)) - return x - - -class AbsolutePositionalEmbedding(nn.Module): - """Absolute positional embedding. - - Args: - max_seq_len: Maximum sequence length. - embedding_dim: Dimensionality of the embedding. - """ - - def __init__(self, max_seq_len: int, embedding_dim: int) -> None: - super().__init__() - self.max_seq_len = max_seq_len - self.embedding_dim = embedding_dim - self.embedding = nn.Embedding(max_seq_len, embedding_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len = x.size() - positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) - embedding: torch.Tensor = self.embedding(positions) - return embedding - - -class DecoderOnlyTransformer(nn.Module): - """Decoder-only (Autoregressive) Transformer model. - - Args: - num_tokens: Number of tokens in the vocabulary. - max_seq_len: Maximum sequence length. - attn_layers_dim: Dimensionality of the attention layers. - attn_layers_depth: Number of attention layers. 
- attn_layers_heads: Number of attention heads. - with_cross_attention: Whether to use cross attention for conditioning. - embedding_dropout_rate: Dropout rate for the embedding. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - num_tokens: int, - max_seq_len: int, - attn_layers_dim: int, - attn_layers_depth: int, - attn_layers_heads: int, - with_cross_attention: bool = False, - embedding_dropout_rate: float = 0.0, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.num_tokens = num_tokens - self.max_seq_len = max_seq_len - self.attn_layers_dim = attn_layers_dim - self.attn_layers_depth = attn_layers_depth - self.attn_layers_heads = attn_layers_heads - self.with_cross_attention = with_cross_attention - - self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) - self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) - self.embedding_dropout = nn.Dropout(embedding_dropout_rate) - - self.blocks = nn.ModuleList( - [ - _TransformerBlock( - hidden_size=attn_layers_dim, - mlp_dim=attn_layers_dim * 4, - num_heads=attn_layers_heads, - dropout_rate=0.0, - qkv_bias=False, - causal=True, - sequence_length=max_seq_len, - with_cross_attention=with_cross_attention, - use_flash_attention=use_flash_attention, - ) - for _ in range(attn_layers_depth) - ] - ) - - self.to_logits = nn.Linear(attn_layers_dim, num_tokens) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - tok_emb = self.token_embeddings(x) - pos_emb = self.position_embeddings(x) - x = self.embedding_dropout(tok_emb + pos_emb) - - for block in self.blocks: - x = block(x, context=context) - logits: torch.Tensor = self.to_logits(x) - return logits diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py deleted file mode 100644 index d4771e203a3..00000000000 --- a/monai/networks/nets/vqvae.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Sequence -from typing import Tuple - -import torch -import torch.nn as nn - -from monai.networks.blocks import Convolution -from monai.networks.layers import Act -from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer -from monai.utils import ensure_tuple_rep - -__all__ = ["VQVAE"] - - -class VQVAEResidualUnit(nn.Module): - """ - Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving - Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). - - The original implementation that can be found at - https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. - - Args: - spatial_dims: number of spatial spatial_dims of the input data. - in_channels: number of input channels. 
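As a quick reference for the removed DecoderOnlyTransformer, a minimal forward pass looks roughly like the following; the vocabulary size, sequence length, and layer dimensions are assumptions chosen only for illustration.

    import torch
    from monai.networks.nets import DecoderOnlyTransformer  # import path prior to this removal

    model = DecoderOnlyTransformer(
        num_tokens=256,         # e.g. a VQ-VAE codebook size (assumed)
        max_seq_len=64,
        attn_layers_dim=128,
        attn_layers_depth=4,
        attn_layers_heads=4,
    )
    tokens = torch.randint(0, 256, (2, 64))  # (batch, sequence) of discrete token indices
    logits = model(tokens)                   # (2, 64, 256) next-token logits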
- num_res_channels: number of channels in the residual layers. - act: activation type and arguments. Defaults to RELU. - dropout: dropout ratio. Defaults to no dropout. - bias: whether to have a bias term. Defaults to True. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_res_channels: int, - act: tuple | str | None = Act.RELU, - dropout: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.num_res_channels = num_res_channels - self.act = act - self.dropout = dropout - self.bias = bias - - self.conv1 = Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.in_channels, - out_channels=self.num_res_channels, - adn_ordering="DA", - act=self.act, - dropout=self.dropout, - bias=self.bias, - ) - - self.conv2 = Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.num_res_channels, - out_channels=self.in_channels, - bias=self.bias, - conv_only=True, - ) - - def forward(self, x): - return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) - - -class Encoder(nn.Module): - """ - Encoder module for VQ-VAE. - - Args: - spatial_dims: number of spatial spatial_dims. - in_channels: number of input channels. - out_channels: number of channels in the latent space (embedding_dim). - channels: sequence containing the number of channels at each level of the encoder. - num_res_layers: number of sequential residual layers at each level. - num_res_channels: number of channels in the residual layers at each level. - downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int) and padding (int). - dropout: dropout ratio. - act: activation type and arguments. 
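To make the downsample_parameters convention concrete, each inner tuple is read as (stride, kernel_size, dilation, padding). For example, matching the VQVAE defaults further below, three encoder levels that each halve the spatial size could be written as:

    # (stride, kernel_size, dilation, padding) per encoder level (illustrative)
    downsample_parameters = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1))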
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - channels: Sequence[int], - num_res_layers: int, - num_res_channels: Sequence[int], - downsample_parameters: Sequence[Tuple[int, int, int, int]], - dropout: float, - act: tuple | str | None, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.channels = channels - self.num_res_layers = num_res_layers - self.num_res_channels = num_res_channels - self.downsample_parameters = downsample_parameters - self.dropout = dropout - self.act = act - - blocks: list[nn.Module] = [] - - for i in range(len(self.channels)): - blocks.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.in_channels if i == 0 else self.channels[i - 1], - out_channels=self.channels[i], - strides=self.downsample_parameters[i][0], - kernel_size=self.downsample_parameters[i][1], - adn_ordering="DA", - act=self.act, - dropout=None if i == 0 else self.dropout, - dropout_dim=1, - dilation=self.downsample_parameters[i][2], - padding=self.downsample_parameters[i][3], - ) - ) - - for _ in range(self.num_res_layers): - blocks.append( - VQVAEResidualUnit( - spatial_dims=self.spatial_dims, - in_channels=self.channels[i], - num_res_channels=self.num_res_channels[i], - act=self.act, - dropout=self.dropout, - ) - ) - - blocks.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.channels[len(self.channels) - 1], - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - x = block(x) - return x - - -class Decoder(nn.Module): - """ - Decoder module for VQ-VAE. - - Args: - spatial_dims: number of spatial spatial_dims. - in_channels: number of channels in the latent space (embedding_dim). - out_channels: number of output channels. - channels: sequence containing the number of channels at each level of the decoder. - num_res_layers: number of sequential residual layers at each level. - num_res_channels: number of channels in the residual layers at each level. - upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - dropout: dropout ratio. - act: activation type and arguments. - output_act: activation type and arguments for the output. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - channels: Sequence[int], - num_res_layers: int, - num_res_channels: Sequence[int], - upsample_parameters: Sequence[Tuple[int, int, int, int, int]], - dropout: float, - act: tuple | str | None, - output_act: tuple | str | None, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.channels = channels - self.num_res_layers = num_res_layers - self.num_res_channels = num_res_channels - self.upsample_parameters = upsample_parameters - self.dropout = dropout - self.act = act - self.output_act = output_act - - reversed_num_channels = list(reversed(self.channels)) - - blocks: list[nn.Module] = [] - blocks.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=self.in_channels, - out_channels=reversed_num_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - reversed_num_res_channels = list(reversed(self.num_res_channels)) - for i in range(len(self.channels)): - for _ in range(self.num_res_layers): - blocks.append( - VQVAEResidualUnit( - spatial_dims=self.spatial_dims, - in_channels=reversed_num_channels[i], - num_res_channels=reversed_num_res_channels[i], - act=self.act, - dropout=self.dropout, - ) - ) - - blocks.append( - Convolution( - spatial_dims=self.spatial_dims, - in_channels=reversed_num_channels[i], - out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], - strides=self.upsample_parameters[i][0], - kernel_size=self.upsample_parameters[i][1], - adn_ordering="DA", - act=self.act, - dropout=self.dropout if i != len(self.channels) - 1 else None, - norm=None, - dilation=self.upsample_parameters[i][2], - conv_only=i == len(self.channels) - 1, - is_transposed=True, - padding=self.upsample_parameters[i][3], - output_padding=self.upsample_parameters[i][4], - ) - ) - - if self.output_act: - blocks.append(Act[self.output_act]()) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - x = block(x) - return x - - -class VQVAE(nn.Module): - """ - Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative - Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) - - The original implementation can be found at - https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ - - Args: - spatial_dims: number of spatial spatial_dims. - in_channels: number of input channels. - out_channels: number of output channels. - downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int) and padding (int). - upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the - following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). - num_res_layers: number of sequential residual layers at each level. - channels: number of channels at each level. - num_res_channels: number of channels in the residual layers at each level. - num_embeddings: VectorQuantization number of atomic elements in the codebook. - embedding_dim: VectorQuantization number of channels of the input and atomic elements. - commitment_cost: VectorQuantization commitment_cost. 
- decay: VectorQuantization decay. - epsilon: VectorQuantization epsilon. - act: activation type and arguments. - dropout: dropout ratio. - output_act: activation type and arguments for the output. - ddp_sync: whether to synchronize the codebook across processes. - use_checkpointing if True, use activation checkpointing to save memory. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - channels: Sequence[int] = (96, 96, 192), - num_res_layers: int = 3, - num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: Sequence[Tuple[int, int, int, int]] - | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Sequence[Tuple[int, int, int, int, int]] - | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), - num_embeddings: int = 32, - embedding_dim: int = 64, - embedding_init: str = "normal", - commitment_cost: float = 0.25, - decay: float = 0.5, - epsilon: float = 1e-5, - dropout: float = 0.0, - act: tuple | str | None = Act.RELU, - output_act: tuple | str | None = None, - ddp_sync: bool = True, - use_checkpointing: bool = False, - ): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.spatial_dims = spatial_dims - self.channels = channels - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.use_checkpointing = use_checkpointing - - if isinstance(num_res_channels, int): - num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) - - if len(num_res_channels) != len(channels): - raise ValueError( - "`num_res_channels` should be a single integer or a tuple of integers with the same length as " - "`num_channls`." - ) - if all(isinstance(values, int) for values in upsample_parameters): - upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) - else: - upsample_parameters_tuple = upsample_parameters - - if all(isinstance(values, int) for values in downsample_parameters): - downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) - else: - downsample_parameters_tuple = downsample_parameters - - if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): - raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") - - # check if downsample_parameters is a tuple of ints or a tuple of tuples of ints - if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): - raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") - - for parameter in downsample_parameters_tuple: - if len(parameter) != 4: - raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") - - for parameter in upsample_parameters_tuple: - if len(parameter) != 5: - raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") - - if len(downsample_parameters_tuple) != len(channels): - raise ValueError( - "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." - ) - - if len(upsample_parameters_tuple) != len(channels): - raise ValueError( - "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." 
- ) - - self.num_res_layers = num_res_layers - self.num_res_channels = num_res_channels - - self.encoder = Encoder( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=embedding_dim, - channels=channels, - num_res_layers=num_res_layers, - num_res_channels=num_res_channels, - downsample_parameters=downsample_parameters_tuple, - dropout=dropout, - act=act, - ) - - self.decoder = Decoder( - spatial_dims=spatial_dims, - in_channels=embedding_dim, - out_channels=out_channels, - channels=channels, - num_res_layers=num_res_layers, - num_res_channels=num_res_channels, - upsample_parameters=upsample_parameters_tuple, - dropout=dropout, - act=act, - output_act=output_act, - ) - - self.quantizer = VectorQuantizer( - quantizer=EMAQuantizer( - spatial_dims=spatial_dims, - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - commitment_cost=commitment_cost, - decay=decay, - epsilon=epsilon, - embedding_init=embedding_init, - ddp_sync=ddp_sync, - ) - ) - - def encode(self, images: torch.Tensor) -> torch.Tensor: - output: torch.Tensor - if self.use_checkpointing: - output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) - else: - output = self.encoder(images) - return output - - def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - x_loss, x = self.quantizer(encodings) - return x, x_loss - - def decode(self, quantizations: torch.Tensor) -> torch.Tensor: - output: torch.Tensor - - if self.use_checkpointing: - output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) - else: - output = self.decoder(quantizations) - return output - - def index_quantize(self, images: torch.Tensor) -> torch.Tensor: - return self.quantizer.quantize(self.encode(images=images)) - - def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: - return self.decode(self.quantizer.embed(embedding_indices)) - - def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - quantizations, quantization_losses = self.quantize(self.encode(images)) - reconstruction = self.decode(quantizations) - - return reconstruction, quantization_losses - - def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: - z = self.encode(x) - e, _ = self.quantize(z) - return e - - def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: - e, _ = self.quantize(z) - image = self.decode(e) - return image diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py deleted file mode 100644 index 29e9020d650..00000000000 --- a/monai/networks/schedulers/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
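Putting the removed pieces together, a typical VQVAE round trip looked roughly like the sketch below; 2D single-channel inputs and the default hyperparameters are assumed for illustration.

    import torch
    from monai.networks.nets import VQVAE  # import path prior to this removal

    vqvae = VQVAE(spatial_dims=2, in_channels=1, out_channels=1)
    images = torch.randn(2, 1, 64, 64)

    reconstruction, quantization_loss = vqvae(images)  # decode(quantize(encode(x)))
    indices = vqvae.index_quantize(images)             # discrete codebook index per latent position
    decoded = vqvae.decode_samples(indices)            # images regenerated from indices alone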
- -from __future__ import annotations - -from .ddim import DDIMScheduler -from .ddpm import DDPMScheduler -from .pndm import PNDMScheduler -from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py deleted file mode 100644 index 78e3cc2a0cc..00000000000 --- a/monai/networks/schedulers/ddim.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -import numpy as np -import torch - -from .ddpm import DDPMPredictionType -from .scheduler import Scheduler - -DDIMPredictionType = DDPMPredictionType - - -class DDIMScheduler(Scheduler): - """ - Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising - diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion - Implicit Models" https://arxiv.org/abs/2010.02502 - - Args: - num_train_timesteps: number of diffusion steps used to train the model. - schedule: member of NoiseSchedules, name of noise schedule function in component store - clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. - For the final step there is no previous alpha. When this option is `True` the previous alpha product is - fixed to `1`, otherwise it uses the value of alpha at step 0. - steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. 
- prediction_type: member of DDPMPredictionType - schedule_args: arguments to pass to the schedule function - - """ - - def __init__( - self, - num_train_timesteps: int = 1000, - schedule: str = "linear_beta", - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, - prediction_type: str = DDIMPredictionType.EPSILON, - **schedule_args, - ) -> None: - super().__init__(num_train_timesteps, schedule, **schedule_args) - - if prediction_type not in DDIMPredictionType.__members__.values(): - raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") - - self.prediction_type = prediction_type - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) - - self.clip_sample = clip_sample - self.steps_offset = steps_offset - - # default the number of inference timesteps to the number of train steps - self.num_inference_steps: int - self.set_timesteps(self.num_train_timesteps) - - def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. - device: target device to put the data. - """ - if num_inference_steps > self.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" - f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.num_train_timesteps} timesteps." - ) - - self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - if self.steps_offset >= step_ratio: - raise ValueError( - f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " - f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" - f" the max train timestep." 
- ) - - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.steps_offset - - def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def step( - self, - model_output: torch.Tensor, - timestep: int, - sample: torch.Tensor, - eta: float = 0.0, - generator: torch.Generator | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - eta: weight of noise for added noise in diffusion step. - generator: random number generator. - - Returns: - pred_prev_sample: Predicted previous sample - pred_original_sample: Predicted original sample - """ - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - model_output -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_sample_direction -> "direction pointing to x_t" - # - pred_prev_sample -> "x_t-1" - - # 1. get previous step value (=t-1) - prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.prediction_type == DDIMPredictionType.EPSILON: - pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) - pred_epsilon = model_output - elif self.prediction_type == DDIMPredictionType.SAMPLE: - pred_original_sample = model_output - pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) - elif self.prediction_type == DDIMPredictionType.V_PREDICTION: - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - - # 4. Clip "predicted x_0" - if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance**0.5 - - # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon - - # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction - - if eta > 0: - # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 - device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise - - pred_prev_sample = pred_prev_sample + variance - - return pred_prev_sample, pred_original_sample - - def reversed_step( - self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - - Returns: - pred_prev_sample: Predicted previous sample - pred_original_sample: Predicted original sample - """ - # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf - - # Notation ( -> - # - model_output -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_sample_direction -> "direction pointing to x_t" - # - pred_post_sample -> "x_t+1" - - # 1. get previous step value (=t+1) - prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas at timestep t+1 - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - - if self.prediction_type == DDIMPredictionType.EPSILON: - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output - elif self.prediction_type == DDIMPredictionType.SAMPLE: - pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.prediction_type == DDIMPredictionType.V_PREDICTION: - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - - # 4. Clip "predicted x_0" - if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon - - # 6. 
compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - return pred_post_sample, pred_original_sample diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py deleted file mode 100644 index a5173a1b656..00000000000 --- a/monai/networks/schedulers/ddpm.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -import numpy as np -import torch - -from monai.utils import StrEnum - -from .scheduler import Scheduler - - -class DDPMVarianceType(StrEnum): - """ - Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise - to the denoised sample. - """ - - FIXED_SMALL = "fixed_small" - FIXED_LARGE = "fixed_large" - LEARNED = "learned" - LEARNED_RANGE = "learned_range" - - -class DDPMPredictionType(StrEnum): - """ - Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. - - epsilon: predicting the noise of the diffusion process - sample: directly predicting the noisy sample - v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf - """ - - EPSILON = "epsilon" - SAMPLE = "sample" - V_PREDICTION = "v_prediction" - - -class DDPMScheduler(Scheduler): - """ - Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and - Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" - https://arxiv.org/abs/2006.11239 - - Args: - num_train_timesteps: number of diffusion steps used to train the model. - schedule: member of NoiseSchedules, name of noise schedule function in component store - variance_type: member of DDPMVarianceType - clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. 
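For context, the removed schedulers were used in a standard reverse-diffusion loop. The sketch below assumes `model` is a trained diffusion UNet (not defined here) and uses the import path from before this removal.

    import torch
    from monai.networks.schedulers import DDIMScheduler  # import path prior to this removal

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=50)

    sample = torch.randn(1, 1, 64, 64)  # start from pure noise (assumed shape)
    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_pred = model(sample, timesteps=t.unsqueeze(0))  # `model`: trained diffusion UNet (assumed)
        sample, _ = scheduler.step(noise_pred, int(t), sample)    # eta=0.0 gives the deterministic DDIM update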
- prediction_type: member of DDPMPredictionType - schedule_args: arguments to pass to the schedule function - """ - - def __init__( - self, - num_train_timesteps: int = 1000, - schedule: str = "linear_beta", - variance_type: str = DDPMVarianceType.FIXED_SMALL, - clip_sample: bool = True, - prediction_type: str = DDPMPredictionType.EPSILON, - **schedule_args, - ) -> None: - super().__init__(num_train_timesteps, schedule, **schedule_args) - - if variance_type not in DDPMVarianceType.__members__.values(): - raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") - - if prediction_type not in DDPMPredictionType.__members__.values(): - raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") - - self.clip_sample = clip_sample - self.variance_type = variance_type - self.prediction_type = prediction_type - - def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. - device: target device to put the data. - """ - if num_inference_steps > self.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" - f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.num_train_timesteps} timesteps." - ) - - self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - - def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: - """ - Compute the mean of the posterior at timestep t. - - Args: - timestep: current timestep. - x0: the noise-free input. - x_t: the input noised to timestep t. - - Returns: - Returns the mean - """ - # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), - # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) - alpha_t = self.alphas[timestep] - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one - - x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) - x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) - - mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t - - return mean - - def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: - """ - Compute the variance of the posterior at timestep t. - - Args: - timestep: current timestep. - predicted_variance: variance predicted by the model. 
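For reference, the coefficients assembled in _get_mean, and the variance computed next in _get_variance, are the standard DDPM posterior terms (equations (6) and (7) of Ho et al. 2020):

    \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
                            + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t,
    \qquad
    \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t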
- - Returns: - Returns the variance - """ - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] - # hacks - were probably added for training stability - if self.variance_type == DDPMVarianceType.FIXED_SMALL: - variance = torch.clamp(variance, min=1e-20) - elif self.variance_type == DDPMVarianceType.FIXED_LARGE: - variance = self.betas[timestep] - elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: - return predicted_variance - elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: - min_log = variance - max_log = self.betas[timestep] - frac = (predicted_variance + 1) / 2 - variance = frac * max_log + (1 - frac) * min_log - - return variance - - def step( - self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - generator: random number generator. - - Returns: - pred_prev_sample: Predicted previous sample - """ - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: - model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) - else: - predicted_variance = None - - # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.prediction_type == DDPMPredictionType.EPSILON: - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.prediction_type == DDPMPredictionType.SAMPLE: - pred_original_sample = model_output - elif self.prediction_type == DDPMPredictionType.V_PREDICTION: - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - - # 3. Clip "predicted x_0" - if self.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t - current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - - # 6. 
Add noise - variance = 0 - if timestep > 0: - noise = torch.randn( - model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator - ).to(model_output.device) - variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise - - pred_prev_sample = pred_prev_sample + variance - - return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py deleted file mode 100644 index c0728bbdff7..00000000000 --- a/monai/networks/schedulers/pndm.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -from typing import Any - -import numpy as np -import torch - -from monai.utils import StrEnum - -from .scheduler import Scheduler - - -class PNDMPredictionType(StrEnum): - """ - Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. - - epsilon: predicting the noise of the diffusion process - v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf - """ - - EPSILON = "epsilon" - V_PREDICTION = "v_prediction" - - -class PNDMScheduler(Scheduler): - """ - Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, - namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., - "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 - - Args: - num_train_timesteps: number of diffusion steps used to train the model. - schedule: member of NoiseSchedules, name of noise schedule function in component store - skip_prk_steps: - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms step. - set_alpha_to_one: - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. 
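One detail worth noting from the step() above: when variance_type is "learned" or "learned_range", the model output is expected to have twice as many channels as the sample, and step() splits it into the prediction and the predicted variance. A minimal configuration sketch, with assumed values:

    from monai.networks.schedulers import DDPMScheduler  # import path prior to this removal

    scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="learned_range")
    # the paired network must then output twice the sample's channel count
    # (e.g. 2 channels for a single-channel image) so step() can split
    # the noise/sample prediction from the predicted variance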
When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - prediction_type: member of DDPMPredictionType - steps_offset: - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - schedule_args: arguments to pass to the schedule function - """ - - def __init__( - self, - num_train_timesteps: int = 1000, - schedule: str = "linear_beta", - skip_prk_steps: bool = False, - set_alpha_to_one: bool = False, - prediction_type: str = PNDMPredictionType.EPSILON, - steps_offset: int = 0, - **schedule_args, - ) -> None: - super().__init__(num_train_timesteps, schedule, **schedule_args) - - if prediction_type not in PNDMPredictionType.__members__.values(): - raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") - - self.prediction_type = prediction_type - - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # For now we only support F-PNDM, i.e. the runge-kutta method - # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at formula (9), (12), (13) and the Algorithm 2. - self.pndm_order = 4 - - self.skip_prk_steps = skip_prk_steps - self.steps_offset = steps_offset - - # running values - self.cur_model_output = torch.Tensor() - self.counter = 0 - self.cur_sample = torch.Tensor() - self.ets: list = [] - - # default the number of inference timesteps to the number of train steps - self.set_timesteps(num_train_timesteps) - - def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. - device: target device to put the data. - """ - if num_inference_steps > self.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" - f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.num_train_timesteps} timesteps." - ) - - self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) - self._timesteps += self.steps_offset - - if self.skip_prk_steps: - # for some models like stable diffusion the prk steps can/should be skipped to - # produce better results. 
When using PNDM with `self.skip_prk_steps` the implementation - # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - self.prk_timesteps = np.array([]) - self.plms_timesteps = self._timesteps[::-1] - - else: - prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( - np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order - ) - self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() - self.plms_timesteps = self._timesteps[:-3][ - ::-1 - ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - - timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - # update num_inference_steps - necessary if we use prk steps - self.num_inference_steps = len(self.timesteps) - - self.ets = [] - self.counter = 0 - - def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - Returns: - pred_prev_sample: Predicted previous sample - """ - # return a tuple for consistency with samplers that return (previous pred, original sample pred) - - if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: - return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None - else: - return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None - - def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. 
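# [editor's aside - illustrative note, not part of the deleted pndm.py]
# The four Runge-Kutta sub-steps of `step_prk` accumulate the model outputs with the
# classical RK4 weights before calling `_get_prev_sample`, i.e. effectively
#     e = (e_1 + 2*e_2 + 2*e_3 + e_4) / 6
# which is what the 1/6, 1/3, 1/3, 1/6 additions in the method body compute.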
- - Returns: - pred_prev_sample: Predicted previous sample - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = timestep - diff_to_prev - timestep = self.prk_timesteps[self.counter // 4 * 4] - - if self.counter % 4 == 0: - self.cur_model_output = 1 / 6 * model_output - self.ets.append(model_output) - self.cur_sample = sample - elif (self.counter - 1) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 2) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 3) % 4 == 0: - model_output = self.cur_model_output + 1 / 6 * model_output - self.cur_model_output = torch.Tensor() - - # cur_sample should not be an empty torch.Tensor() - cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample - - prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - self.counter += 1 - - return prev_sample - - def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: - """ - Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple - times to approximate the solution. - - Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - - Returns: - pred_prev_sample: Predicted previous sample - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if not self.skip_prk_steps and len(self.ets) < 3: - raise ValueError( - f"{self.__class__} can only be run AFTER scheduler has been run " - "in 'prk' mode for at least 12 iterations " - ) - - prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps - - if self.counter != 1: - self.ets = self.ets[-3:] - self.ets.append(model_output) - else: - prev_timestep = timestep - timestep = timestep + self.num_train_timesteps // self.num_inference_steps - - if len(self.ets) == 1 and self.counter == 0: - model_output = model_output - self.cur_sample = sample - elif len(self.ets) == 1 and self.counter == 1: - model_output = (model_output + self.ets[-1]) / 2 - sample = self.cur_sample - self.cur_sample = torch.Tensor() - elif len(self.ets) == 2: - model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 - elif len(self.ets) == 3: - model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 - else: - model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - - prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) - self.counter += 1 - - return prev_sample - - def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): - # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) - 
alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - if self.prediction_type == PNDMPredictionType.V_PREDICTION: - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - - # corresponds to (α_(t−δ) - α_t) divided by - # denominator of x_t in formula (9) and plus 1 - # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = - # sqrt(α_(t−δ)) / sqrt(α_t)) - sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) - - # corresponds to denominator of e_θ(x_t, t) in formula (9) - model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( - alpha_prod_t * beta_prod_t * alpha_prod_t_prev - ) ** (0.5) - - # full formula (9) - prev_sample = ( - sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff - ) - - return prev_sample diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py deleted file mode 100644 index 17bb526abcd..00000000000 --- a/monai/networks/schedulers/scheduler.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - - -from __future__ import annotations - -import torch -import torch.nn as nn - -from monai.utils import ComponentStore, unsqueeze_right - -NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") - - -@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") -def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): - """ - Linear beta noise schedule function. 
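# [editor's aside - illustrative note, not part of the deleted scheduler.py]
# Registered schedule functions are looked up by name from the `NoiseSchedules`
# component store; hypothetical usage:
#     betas = NoiseSchedules["linear_beta"](num_train_timesteps=1000)
#     # equivalent to torch.linspace(1e-4, 2e-2, 1000)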
- - Args: - num_train_timesteps: number of timesteps - beta_start: start of beta range, default 1e-4 - beta_end: end of beta range, default 2e-2 - - Returns: - betas: beta schedule tensor - """ - return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - - -@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") -def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): - """ - Scaled linear beta noise schedule function. - - Args: - num_train_timesteps: number of timesteps - beta_start: start of beta range, default 1e-4 - beta_end: end of beta range, default 2e-2 - - Returns: - betas: beta schedule tensor - """ - return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - - -@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") -def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): - """ - Sigmoid beta noise schedule function. - - Args: - num_train_timesteps: number of timesteps - beta_start: start of beta range, default 1e-4 - beta_end: end of beta range, default 2e-2 - sig_range: pos/neg range of sigmoid input, default 6 - - Returns: - betas: beta schedule tensor - """ - betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) - return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start - - -@NoiseSchedules.add_def("cosine", "Cosine schedule") -def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): - """ - Cosine noise schedule, see https://arxiv.org/abs/2102.09672 - - Args: - num_train_timesteps: number of timesteps - s: smoothing factor, default 8e-3 (see referenced paper) - - Returns: - (betas, alphas, alpha_cumprod) values - """ - x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) - alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 - alphas_cumprod /= alphas_cumprod[0].item() - alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) - betas = 1.0 - alphas - return betas, alphas, alphas_cumprod[:-1] - - -class Scheduler(nn.Module): - """ - Base class for other schedulers based on a noise schedule function. - - This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here - the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, - which is the name of a component in NoiseSchedules. These components must all be callables which return either - the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions - can be provided by using the NoiseSchedules.add_def, for example: - - .. code-block:: python - - from monai.networks.schedulers import NoiseSchedules, DDPMScheduler - - @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") - def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): - return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - - scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") - - All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of - timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through - the constructor's `schedule_args` value. 
To see what noise functions are available, print the object NoiseSchedules - to get a listing of stored objects with their docstring descriptions. - - Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule - type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended - to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are - still used for some schedules but these are provided as keyword arguments now. - - Args: - num_train_timesteps: number of diffusion steps used to train the model. - schedule: member of NoiseSchedules, - a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple - schedule_args: arguments to pass to the schedule function - """ - - def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: - super().__init__() - schedule_args["num_train_timesteps"] = num_train_timesteps - noise_sched = NoiseSchedules[schedule](**schedule_args) - - # set betas, alphas, alphas_cumprod based off return value from noise function - if isinstance(noise_sched, tuple): - self.betas, self.alphas, self.alphas_cumprod = noise_sched - else: - self.betas = noise_sched - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - self.num_train_timesteps = num_train_timesteps - self.one = torch.tensor(1.0) - - # settable values - self.num_inference_steps: int | None = None - self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
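# [editor's aside - illustrative sketch, not part of the deleted scheduler.py]
# `add_noise` applies the closed-form forward process
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)
# Assuming an already-constructed `scheduler` instance, usage would look like:
#     x0 = torch.randn(2, 1, 16, 16)         # clean samples
#     eps = torch.randn_like(x0)             # Gaussian noise
#     t = torch.tensor([10, 500])            # one timestep per batch element
#     x_t = scheduler.add_noise(x0, eps, t)  # noisy samples, same shape as x0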
- - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) - sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( - (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim - ) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) - sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4f2501a7eec..d6ff370f695 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: +def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: +def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py deleted file mode 100644 index 1be61f98abe..00000000000 --- a/monai/utils/ordering.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import numpy as np - -from monai.utils.enums import OrderingTransformations, OrderingType - - -class Ordering: - """ - Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with - one of the following transformations: - Reflection (see np.flip for more details). - Transposition (see np.transpose for more details). - 90-degree rotation (see np.rot90 for more details). - - The transformations are applied in the order specified by the transformation_order parameter. - - Args: - ordering_type: The ordering type. 
One of the following: - - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from - top to bottom. Also called a row major ordering. - - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like - pattern from top left towards right gowing in a spiral towards the center. - - random': The image is projected into a 1D sequence by randomly shuffling the image. - spatial_dims: The number of spatial dimensions of the image. - dimensions: The dimensions of the image. - reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. - transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. - rot90_axes: A tuple of tuples indicating the axes to rotate the image along. - transformation_order: The order in which to apply the transformations. - """ - - def __init__( - self, - ordering_type: str, - spatial_dims: int, - dimensions: tuple[int, int, int] | tuple[int, int, int, int], - reflected_spatial_dims: tuple[bool, bool] | None = None, - transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None, - rot90_axes: tuple[tuple[int, int], ...] | None = None, - transformation_order: tuple[str, ...] = ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - ) -> None: - super().__init__() - self.ordering_type = ordering_type - - if self.ordering_type not in list(OrderingType): - raise ValueError( - f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." - ) - - self.spatial_dims = spatial_dims - self.dimensions = dimensions - - if len(dimensions) != self.spatial_dims + 1: - raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") - - self.reflected_spatial_dims = reflected_spatial_dims - self.transpositions_axes = transpositions_axes - self.rot90_axes = rot90_axes - if len(set(transformation_order)) != len(transformation_order): - raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") - - for transformation in transformation_order: - if transformation not in list(OrderingTransformations): - raise ValueError( - f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." 
- ) - self.transformation_order = transformation_order - - self.template = self._create_template() - self._sequence_ordering = self._create_ordering() - self._revert_sequence_ordering = np.argsort(self._sequence_ordering) - - def __call__(self, x: np.ndarray) -> np.ndarray: - x = x[self._sequence_ordering] - - return x - - def get_sequence_ordering(self) -> np.ndarray: - return self._sequence_ordering - - def get_revert_sequence_ordering(self) -> np.ndarray: - return self._revert_sequence_ordering - - def _create_ordering(self) -> np.ndarray: - self.template = self._transform_template() - order = self._order_template(template=self.template) - - return order - - def _create_template(self) -> np.ndarray: - spatial_dimensions = self.dimensions[1:] - template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) - - return template - - def _transform_template(self) -> np.ndarray: - for transformation in self.transformation_order: - if transformation == OrderingTransformations.TRANSPOSE.value: - self.template = self._transpose_template(template=self.template) - elif transformation == OrderingTransformations.ROTATE_90.value: - self.template = self._rot90_template(template=self.template) - elif transformation == OrderingTransformations.REFLECT.value: - self.template = self._flip_template(template=self.template) - - return self.template - - def _transpose_template(self, template: np.ndarray) -> np.ndarray: - if self.transpositions_axes is not None: - for axes in self.transpositions_axes: - template = np.transpose(template, axes=axes) - - return template - - def _flip_template(self, template: np.ndarray) -> np.ndarray: - if self.reflected_spatial_dims is not None: - for axis, to_reflect in enumerate(self.reflected_spatial_dims): - template = np.flip(template, axis=axis) if to_reflect else template - - return template - - def _rot90_template(self, template: np.ndarray) -> np.ndarray: - if self.rot90_axes is not None: - for axes in self.rot90_axes: - template = np.rot90(template, axes=axes) - - return template - - def _order_template(self, template: np.ndarray) -> np.ndarray: - depths = None - if self.spatial_dims == 2: - rows, columns = template.shape[0], template.shape[1] - else: - rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) - - sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) - - ordering = np.array([template[tuple(e)] for e in sequence]) - - return ordering - - @staticmethod - def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: - idx: list[tuple] = [] - - for r in range(rows): - for c in range(cols): - if depths is not None: - for d in range(depths): - idx.append((r, c, d)) - else: - idx.append((r, c)) - - idx_np = np.array(idx) - - return idx_np - - @staticmethod - def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: - idx: list[tuple] = [] - - for r in range(rows): - col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) - for c in col_idx: - if depths: - depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) - - for d in depth_idx: - idx.append((r, c, d)) - else: - idx.append((r, c)) - - idx_np = np.array(idx) - - return idx_np - - @staticmethod - def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: - idx: list[tuple] = [] - - for r in range(rows): - for c in range(cols): - if depths: - for d in range(depths): - idx.append((r, c, d)) - else: - idx.append((r, c)) - - idx_np = np.array(idx) - 
np.random.shuffle(idx_np) - - return idx_np diff --git a/test_spade_autoencoderkl.py b/test_spade_autoencoderkl.py deleted file mode 100644 index 6675a6db676..00000000000 --- a/test_spade_autoencoderkl.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import SPADEAutoencoderKL - -CASES = [ - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": (1, 1, 2), - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, True), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 3, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, True), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16, 16), - (1, 3, 16, 16, 16), - (1, 1, 16, 16, 16), - (1, 4, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "label_nc": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, True), - "num_res_blocks": 1, - "norm_num_groups": 
4, - "spade_intermediate_channels": 32, - }, - (1, 1, 16, 16), - (1, 3, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], -] - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class TestSPADEAutoEncoderKL(unittest.TestCase): - @parameterized.expand(CASES) - def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): - net = SPADEAutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) - self.assertEqual(result[0].shape, expected_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - - def test_model_channels_not_multiple_of_norm_num_group(self): - with self.assertRaises(ValueError): - SPADEAutoencoderKL( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False, False), - latent_channels=8, - num_res_blocks=1, - norm_num_groups=16, - ) - - def test_model_channels_not_same_size_of_attention_levels(self): - with self.assertRaises(ValueError): - SPADEAutoencoderKL( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False), - latent_channels=8, - num_res_blocks=1, - norm_num_groups=16, - ) - - def test_model_channels_not_same_size_of_num_res_blocks(self): - with self.assertRaises(ValueError): - SPADEAutoencoderKL( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False, False), - latent_channels=8, - num_res_blocks=(8, 8), - norm_num_groups=16, - ) - - def test_shape_encode(self): - input_param, input_shape, _, _, expected_latent_shape = CASES[0] - net = SPADEAutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.encode(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_latent_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - - def test_shape_sampling(self): - input_param, _, _, _, expected_latent_shape = CASES[0] - net = SPADEAutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.sampling( - torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) - ) - self.assertEqual(result.shape, expected_latent_shape) - - def test_shape_decode(self): - input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] - net = SPADEAutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) - self.assertEqual(result.shape, expected_input_shape) - - def test_wrong_shape_decode(self): - net = SPADEAutoencoderKL( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - channels=(4, 4, 4), - latent_channels=4, - attention_levels=(False, False, False), - num_res_blocks=1, - norm_num_groups=4, - ) - with self.assertRaises(RuntimeError): - _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test_spade_diffusion_model_unet.py b/test_spade_diffusion_model_unet.py deleted file mode 100644 index c8a2103cf64..00000000000 --- a/test_spade_diffusion_model_unet.py +++ /dev/null @@ -1,558 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import SPADEDiffusionModelUNet - -UNCOND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": (1, 1, 2), - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "resblock_updown": True, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "resblock_updown": True, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, True, True), - "num_head_channels": (0, 2, 4), - "norm_num_groups": 8, - "label_nc": 3, - } - ], -] - -UNCOND_CASES_3D = [ - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "label_nc": 3, - "spade_intermediate_channels": 256, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "resblock_updown": True, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "resblock_updown": 
True, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": (0, 0, 4), - "norm_num_groups": 8, - "label_nc": 3, - } - ], -] - -COND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "resblock_updown": True, - "label_nc": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "num_channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "upcast_attention": True, - "label_nc": 3, - } - ], -] - - -class TestSPADEDiffusionModelUNet2D(unittest.TestCase): - @parameterized.expand(UNCOND_CASES_2D) - def test_shape_unconditioned_models(self, input_param): - net = SPADEDiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward( - torch.rand((1, 1, 16, 16)), - torch.randint(0, 1000, (1,)).long(), - torch.rand((1, input_param["label_nc"], 16, 16)), - ) - self.assertEqual(result.shape, (1, 1, 16, 16)) - - def test_timestep_with_wrong_shape(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - with self.assertRaises(ValueError): - with eval_mode(net): - net.forward( - torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) - ) - - def test_label_with_wrong_shape(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - with self.assertRaises(RuntimeError): - with eval_mode(net): - net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) - - def test_shape_with_different_in_channel_out_channel(self): - in_channels = 6 - out_channels = 3 - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=in_channels, - out_channels=out_channels, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - with eval_mode(net): - result = net.forward( - torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) - ) - self.assertEqual(result.shape, (1, out_channels, 16, 16)) - - def test_model_channels_not_multiple_of_norm_num_group(self): - with self.assertRaises(ValueError): - 
SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 12), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - - def test_attention_levels_with_different_length_num_head_channels(self): - with self.assertRaises(ValueError): - SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, False), - num_head_channels=(0, 2), - norm_num_groups=8, - ) - - def test_num_res_blocks_with_different_length_num_channels(self): - with self.assertRaises(ValueError): - SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=(1, 1), - num_channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - - def test_shape_conditioned_models(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - norm_num_groups=8, - num_head_channels=8, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - seg=torch.rand((1, 3, 16, 32)), - context=torch.rand((1, 1, 3)), - ) - self.assertEqual(result.shape, (1, 1, 16, 32)) - - def test_with_conditioning_cross_attention_dim_none(self): - with self.assertRaises(ValueError): - SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=None, - norm_num_groups=8, - ) - - def test_context_with_conditioning_none(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=False, - transformer_num_layers=1, - norm_num_groups=8, - ) - - with self.assertRaises(ValueError): - with eval_mode(net): - net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - seg=torch.rand((1, 3, 16, 32)), - context=torch.rand((1, 1, 3)), - ) - - def test_shape_conditioned_models_class_conditioning(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - seg=torch.rand((1, 3, 16, 32)), - class_labels=torch.randint(0, 2, (1,)).long(), - ) - self.assertEqual(result.shape, (1, 1, 16, 32)) - - def test_conditioned_models_no_class_labels(self): - net = SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - - with self.assertRaises(ValueError): - net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - seg=torch.rand((1, 3, 16, 32)), - ) - - def 
test_model_num_channels_not_same_size_of_attention_levels(self): - with self.assertRaises(ValueError): - SPADEDiffusionModelUNet( - spatial_dims=2, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - - @parameterized.expand(COND_CASES_2D) - def test_conditioned_2d_models_shape(self, input_param): - net = SPADEDiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward( - torch.rand((1, 1, 16, 16)), - torch.randint(0, 1000, (1,)).long(), - torch.rand((1, input_param["label_nc"], 16, 16)), - torch.rand((1, 1, 3)), - ) - self.assertEqual(result.shape, (1, 1, 16, 16)) - - -class TestDiffusionModelUNet3D(unittest.TestCase): - @parameterized.expand(UNCOND_CASES_3D) - def test_shape_unconditioned_models(self, input_param): - net = SPADEDiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward( - torch.rand((1, 1, 16, 16, 16)), - torch.randint(0, 1000, (1,)).long(), - torch.rand((1, input_param["label_nc"], 16, 16, 16)), - ) - self.assertEqual(result.shape, (1, 1, 16, 16, 16)) - - def test_shape_with_different_in_channel_out_channel(self): - in_channels = 6 - out_channels = 3 - net = SPADEDiffusionModelUNet( - spatial_dims=3, - label_nc=3, - in_channels=in_channels, - out_channels=out_channels, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=4, - ) - with eval_mode(net): - result = net.forward( - torch.rand((1, in_channels, 16, 16, 16)), - torch.randint(0, 1000, (1,)).long(), - torch.rand((1, 3, 16, 16, 16)), - ) - self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) - - def test_shape_conditioned_models(self): - net = SPADEDiffusionModelUNet( - spatial_dims=3, - label_nc=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(16, 16, 16), - attention_levels=(False, False, True), - norm_num_groups=16, - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 16, 16)), - timesteps=torch.randint(0, 1000, (1,)).long(), - seg=torch.rand((1, 3, 16, 16, 16)), - context=torch.rand((1, 1, 3)), - ) - self.assertEqual(result.shape, (1, 1, 16, 16, 16)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py deleted file mode 100644 index 448f1e8e9a6..00000000000 --- a/tests/test_autoencoderkl.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import AutoencoderKL -from tests.utils import SkipIfBeforePyTorchVersion - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -CASES = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": (1, 1, 2), - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, True), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, False), - "num_res_blocks": 1, - "norm_num_groups": 4, - "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, - }, - (1, 1, 16, 16), - (1, 1, 16, 16), - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4, 4), - "latent_channels": 4, - "attention_levels": (False, False, True), - "num_res_blocks": 1, - "norm_num_groups": 4, - }, - (1, 1, 16, 16, 16), - (1, 1, 16, 16, 16), - (1, 4, 4, 4, 4), - ], -] - - -class TestAutoEncoderKL(unittest.TestCase): - @parameterized.expand(CASES) - def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - self.assertEqual(result[2].shape, expected_latent_shape) - - @parameterized.expand(CASES) - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_with_convtranspose_and_checkpointing( - self, input_param, input_shape, expected_shape, expected_latent_shape - ): - input_param = input_param.copy() - input_param.update({"use_checkpoint": True, "use_convtranspose": True}) - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - self.assertEqual(result[2].shape, expected_latent_shape) - - def test_model_channels_not_multiple_of_norm_num_group(self): - with 
self.assertRaises(ValueError): - AutoencoderKL( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False, False), - latent_channels=8, - num_res_blocks=1, - norm_num_groups=16, - ) - - def test_model_num_channels_not_same_size_of_attention_levels(self): - with self.assertRaises(ValueError): - AutoencoderKL( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False), - latent_channels=8, - num_res_blocks=1, - norm_num_groups=16, - ) - - def test_model_num_channels_not_same_size_of_num_res_blocks(self): - with self.assertRaises(ValueError): - AutoencoderKL( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(24, 24, 24), - attention_levels=(False, False, False), - latent_channels=8, - num_res_blocks=(8, 8), - norm_num_groups=16, - ) - - def test_shape_reconstruction(self): - input_param, input_shape, expected_shape, _ = CASES[0] - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.reconstruct(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): - input_param, input_shape, expected_shape, _ = CASES[0] - input_param = input_param.copy() - input_param.update({"use_checkpoint": True, "use_convtranspose": True}) - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.reconstruct(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - def test_shape_encode(self): - input_param, input_shape, _, expected_latent_shape = CASES[0] - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.encode(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_latent_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_encode_with_convtranspose_and_checkpointing(self): - input_param, input_shape, _, expected_latent_shape = CASES[0] - input_param = input_param.copy() - input_param.update({"use_checkpoint": True, "use_convtranspose": True}) - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.encode(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_latent_shape) - self.assertEqual(result[1].shape, expected_latent_shape) - - def test_shape_sampling(self): - input_param, _, _, expected_latent_shape = CASES[0] - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.sampling( - torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) - ) - self.assertEqual(result.shape, expected_latent_shape) - - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_sampling_convtranspose_and_checkpointing(self): - input_param, _, _, expected_latent_shape = CASES[0] - input_param = input_param.copy() - input_param.update({"use_checkpoint": True, "use_convtranspose": True}) - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.sampling( - torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) - ) - self.assertEqual(result.shape, expected_latent_shape) - - def test_shape_decode(self): - input_param, expected_input_shape, _, latent_shape = CASES[0] - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = 
net.decode(torch.randn(latent_shape).to(device)) - self.assertEqual(result.shape, expected_input_shape) - - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_decode_convtranspose_and_checkpointing(self): - input_param, expected_input_shape, _, latent_shape = CASES[0] - input_param = input_param.copy() - input_param.update({"use_checkpoint": True, "use_convtranspose": True}) - net = AutoencoderKL(**input_param).to(device) - with eval_mode(net): - result = net.decode(torch.randn(latent_shape).to(device)) - self.assertEqual(result.shape, expected_input_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py deleted file mode 100644 index 07dfa2e49b0..00000000000 --- a/tests/test_controlnet.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.controlnet import ControlNet - -UNCOND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - }, - (1, 8, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "resblock_updown": True, - }, - (1, 8, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (4, 4, 4), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 4, - }, - (1, 4, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "resblock_updown": True, - }, - (1, 8, 4, 4), - ], -] - -UNCOND_CASES_3D = [ - [ - { - "spatial_dims": 3, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - }, - (1, 8, 4, 4, 4), - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (4, 4, 4), - "num_head_channels": 4, - "attention_levels": (False, False, False), - "norm_num_groups": 4, - "resblock_updown": True, - }, - (1, 4, 4, 4, 4), - ], -] - -COND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - }, - (1, 8, 4, 4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "resblock_updown": True, - }, - (1, 8, 4, 
4), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "upcast_attention": True, - }, - (1, 8, 4, 4), - ], -] - - -class TestControlNet(unittest.TestCase): - @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) - def test_shape_unconditioned_models(self, input_param, expected_output_shape): - input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] - input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) - net = ControlNet(**input_param) - with eval_mode(net): - x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) - timesteps = torch.randint(0, 1000, (1,)).long() - controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) - result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) - self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) - self.assertEqual(result[1].shape, expected_output_shape) - - @parameterized.expand(COND_CASES_2D) - def test_shape_conditioned_models(self, input_param, expected_output_shape): - input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] - input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) - net = ControlNet(**input_param) - with eval_mode(net): - x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) - timesteps = torch.randint(0, 1000, (1,)).long() - controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) - result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) - self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) - self.assertEqual(result[1].shape, expected_output_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py deleted file mode 100644 index d40a31a1da9..00000000000 --- a/tests/test_diffusion_model_unet.py +++ /dev/null @@ -1,535 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import DiffusionModelUNet - -UNCOND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": (1, 1, 2), - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "resblock_updown": True, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "resblock_updown": True, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, True, True), - "num_head_channels": (0, 2, 4), - "norm_num_groups": 8, - } - ], -] - -UNCOND_CASES_3D = [ - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, False), - "norm_num_groups": 8, - "resblock_updown": True, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 8, - "norm_num_groups": 8, - "resblock_updown": True, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - } - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": (0, 0, 4), - "norm_num_groups": 8, - } - ], -] - -COND_CASES_2D = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - 
"num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "resblock_updown": True, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "upcast_attention": True, - } - ], -] - -DROPOUT_OK = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "dropout_cattn": 0.25, - } - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - } - ], -] - -DROPOUT_WRONG = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_res_blocks": 1, - "channels": (8, 8, 8), - "attention_levels": (False, False, True), - "num_head_channels": 4, - "norm_num_groups": 8, - "with_conditioning": True, - "transformer_num_layers": 1, - "cross_attention_dim": 3, - "dropout_cattn": 3.0, - } - ] -] - - -class TestDiffusionModelUNet2D(unittest.TestCase): - @parameterized.expand(UNCOND_CASES_2D) - def test_shape_unconditioned_models(self, input_param): - net = DiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) - self.assertEqual(result.shape, (1, 1, 16, 16)) - - def test_timestep_with_wrong_shape(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - with self.assertRaises(ValueError): - with eval_mode(net): - net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) - - def test_shape_with_different_in_channel_out_channel(self): - in_channels = 6 - out_channels = 3 - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=in_channels, - out_channels=out_channels, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - with eval_mode(net): - result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) - self.assertEqual(result.shape, (1, out_channels, 16, 16)) - - def test_model_channels_not_multiple_of_norm_num_group(self): - with self.assertRaises(ValueError): - DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 12), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - - def test_attention_levels_with_different_length_num_head_channels(self): - with self.assertRaises(ValueError): - DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, False), - num_head_channels=(0, 2), - norm_num_groups=8, - ) - - def test_num_res_blocks_with_different_length_channels(self): - with 
self.assertRaises(ValueError): - DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=(1, 1), - channels=(8, 8, 8), - attention_levels=(False, False, False), - norm_num_groups=8, - ) - - def test_shape_conditioned_models(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - norm_num_groups=8, - num_head_channels=8, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - context=torch.rand((1, 1, 3)), - ) - self.assertEqual(result.shape, (1, 1, 16, 32)) - - def test_with_conditioning_cross_attention_dim_none(self): - with self.assertRaises(ValueError): - DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=None, - norm_num_groups=8, - ) - - def test_context_with_conditioning_none(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=False, - transformer_num_layers=1, - norm_num_groups=8, - ) - - with self.assertRaises(ValueError): - with eval_mode(net): - net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - context=torch.rand((1, 1, 3)), - ) - - def test_shape_conditioned_models_class_conditioning(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 32)), - timesteps=torch.randint(0, 1000, (1,)).long(), - class_labels=torch.randint(0, 2, (1,)).long(), - ) - self.assertEqual(result.shape, (1, 1, 16, 32)) - - def test_conditioned_models_no_class_labels(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - - with self.assertRaises(ValueError): - net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) - - def test_model_channels_not_same_size_of_attention_levels(self): - with self.assertRaises(ValueError): - DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) - - @parameterized.expand(COND_CASES_2D) - def test_conditioned_2d_models_shape(self, input_param): - net = DiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) - self.assertEqual(result.shape, (1, 1, 16, 16)) - - -class TestDiffusionModelUNet3D(unittest.TestCase): - @parameterized.expand(UNCOND_CASES_3D) - def test_shape_unconditioned_models(self, input_param): - net = DiffusionModelUNet(**input_param) - with eval_mode(net): - result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) - 
self.assertEqual(result.shape, (1, 1, 16, 16, 16)) - - def test_shape_with_different_in_channel_out_channel(self): - in_channels = 6 - out_channels = 3 - net = DiffusionModelUNet( - spatial_dims=3, - in_channels=in_channels, - out_channels=out_channels, - num_res_blocks=1, - channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=4, - ) - with eval_mode(net): - result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) - self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) - - def test_shape_conditioned_models(self): - net = DiffusionModelUNet( - spatial_dims=3, - in_channels=1, - out_channels=1, - num_res_blocks=1, - channels=(16, 16, 16), - attention_levels=(False, False, True), - norm_num_groups=16, - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - ) - with eval_mode(net): - result = net.forward( - x=torch.rand((1, 1, 16, 16, 16)), - timesteps=torch.randint(0, 1000, (1,)).long(), - context=torch.rand((1, 1, 3)), - ) - self.assertEqual(result.shape, (1, 1, 16, 16, 16)) - - # Test dropout specification for cross-attention blocks - @parameterized.expand(DROPOUT_WRONG) - def test_wrong_dropout(self, input_param): - with self.assertRaises(ValueError): - _ = DiffusionModelUNet(**input_param) - - @parameterized.expand(DROPOUT_OK) - def test_right_dropout(self, input_param): - _ = DiffusionModelUNet(**input_param) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_ordering.py b/tests/test_ordering.py deleted file mode 100644 index 0c52dba5e52..00000000000 --- a/tests/test_ordering.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.utils.enums import OrderingTransformations, OrderingType -from monai.utils.ordering import Ordering - -TEST_2D_NON_RANDOM = [ - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 1, 2, 3], - ], - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 1, 3, 2], - ], - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [2, 3, 0, 1], - ], - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [2, 3, 1, 0], - ], - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": ((1, 0),), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 2, 1, 3], - ], - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": ((1, 0),), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 2, 3, 1], - ], - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": (), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [1, 3, 0, 2], - ], - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": (), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [1, 3, 2, 0], - ], - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - 
OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 1, 2, 3], - ], - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 1, 3, 2], - ], -] - -TEST_2D_RANDOM = [ - [ - { - "ordering_type": OrderingType.RANDOM, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [[0, 1, 2, 3], [0, 1, 3, 2]], - ] -] - -TEST_3D = [ - [ - { - "ordering_type": OrderingType.RASTER_SCAN, - "spatial_dims": 3, - "dimensions": (1, 2, 2, 2), - "reflected_spatial_dims": (), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [0, 1, 2, 3, 4, 5, 6, 7], - ] -] - -TEST_ORDERING_TYPE_FAILURE = [ - [ - { - "ordering_type": "hilbert", - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - } - ] -] - -TEST_ORDERING_TRANSFORMATION_FAILURE = [ - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - "flip", - ), - } - ] -] - -TEST_REVERT = [ - [ - { - "ordering_type": OrderingType.S_CURVE, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": (), - "rot90_axes": (), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - } - ] -] - - -class TestOrdering(unittest.TestCase): - @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) - def test_ordering(self, input_param, expected_sequence_ordering): - ordering = Ordering(**input_param) - self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) - - @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) - def test_ordering_type_failure(self, input_param): - with self.assertRaises(ValueError): - Ordering(**input_param) - - @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) - def test_ordering_transformation_failure(self, input_param): - with self.assertRaises(ValueError): - Ordering(**input_param) - - @parameterized.expand(TEST_2D_RANDOM) - def test_random(self, input_param, not_in_expected_sequence_ordering): - ordering = Ordering(**input_param) - - not_in = [ - np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) - for sequence in not_in_expected_sequence_ordering - ] - - self.assertFalse(np.any(not_in)) - - @parameterized.expand(TEST_REVERT) - def 
test_revert(self, input_param): - sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() - - ordering = Ordering(**input_param) - - reverted_sequence = sequence[ordering.get_sequence_ordering()] - reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] - - self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py deleted file mode 100644 index c19898e70d8..00000000000 --- a/tests/test_patch_gan_dicriminator.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator -from tests.utils import test_script_save - -TEST_PATCHGAN = [ - [ - { - "num_layers_d": 3, - "spatial_dims": 2, - "channels": 8, - "in_channels": 3, - "out_channels": 1, - "kernel_size": 3, - "activation": "LEAKYRELU", - "norm": "instance", - "bias": False, - "dropout": 0.1, - }, - torch.rand([1, 3, 256, 512]), - (1, 8, 128, 256), - (1, 1, 32, 64), - ], - [ - { - "num_layers_d": 3, - "spatial_dims": 3, - "channels": 8, - "in_channels": 3, - "out_channels": 1, - "kernel_size": 3, - "activation": "LEAKYRELU", - "norm": "instance", - "bias": False, - "dropout": 0.1, - }, - torch.rand([1, 3, 256, 512, 256]), - (1, 8, 128, 256, 128), - (1, 1, 32, 64, 32), - ], -] - -TEST_MULTISCALE_PATCHGAN = [ - [ - { - "num_d": 2, - "num_layers_d": 3, - "spatial_dims": 2, - "channels": 8, - "in_channels": 3, - "out_channels": 1, - "kernel_size": 3, - "activation": "LEAKYRELU", - "norm": "instance", - "bias": False, - "dropout": 0.1, - "minimum_size_im": 256, - }, - torch.rand([1, 3, 256, 512]), - [(1, 1, 32, 64), (1, 1, 4, 8)], - [4, 7], - ], - [ - { - "num_d": 2, - "num_layers_d": 3, - "spatial_dims": 3, - "channels": 8, - "in_channels": 3, - "out_channels": 1, - "kernel_size": 3, - "activation": "LEAKYRELU", - "norm": "instance", - "bias": False, - "dropout": 0.1, - "minimum_size_im": 256, - }, - torch.rand([1, 3, 256, 512, 256]), - [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], - [4, 7], - ], -] -TEST_TOO_SMALL_SIZE = [ - { - "num_d": 2, - "num_layers_d": 6, - "spatial_dims": 2, - "channels": 8, - "in_channels": 3, - "out_channels": 1, - "kernel_size": 3, - "activation": "LEAKYRELU", - "norm": "instance", - "bias": False, - "dropout": 0.1, - "minimum_size_im": 256, - } -] - - -class TestPatchGAN(unittest.TestCase): - @parameterized.expand(TEST_PATCHGAN) - def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): - net = PatchDiscriminator(**input_param) - with eval_mode(net): - result = net.forward(input_data) - self.assertEqual(tuple(result[0].shape), expected_shape_feature) - self.assertEqual(tuple(result[-1].shape), 
expected_shape_output) - - def test_script(self): - net = PatchDiscriminator( - num_layers_d=3, - spatial_dims=2, - channels=8, - in_channels=3, - out_channels=1, - kernel_size=3, - activation="LEAKYRELU", - norm="instance", - bias=False, - dropout=0.1, - ) - i = torch.rand([1, 3, 256, 512]) - test_script_save(net, i) - - -class TestMultiscalePatchGAN(unittest.TestCase): - @parameterized.expand(TEST_MULTISCALE_PATCHGAN) - def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): - net = MultiScalePatchDiscriminator(**input_param) - with eval_mode(net): - result, features = net.forward(input_data) - for r_ind, r in enumerate(result): - self.assertEqual(tuple(r.shape), expected_shape[r_ind]) - for o_d_ind, o_d in enumerate(features): - self.assertEqual(len(o_d), features_lengths[o_d_ind]) - - def test_too_small_shape(self): - with self.assertRaises(AssertionError): - MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) - - def test_script(self): - net = MultiScalePatchDiscriminator( - num_d=2, - num_layers_d=3, - spatial_dims=2, - channels=8, - in_channels=3, - out_channels=1, - kernel_size=3, - activation="LEAKYRELU", - norm="instance", - bias=False, - dropout=0.1, - minimum_size_im=256, - ) - i = torch.rand([1, 3, 256, 512]) - test_script_save(net, i) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py deleted file mode 100644 index 1a8f8cab679..00000000000 --- a/tests/test_scheduler_ddim.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.schedulers import DDIMScheduler -from tests.utils import assert_allclose - -TEST_2D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) - -TEST_3D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) - -TEST_CASES = TEST_2D_CASE + TEST_3D_CASE - -TEST_FULl_LOOP = [ - [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])] -] - - -class TestDDPMScheduler(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_add_noise(self, input_param, input_shape, expected_shape): - scheduler = DDIMScheduler(**input_param) - scheduler.set_timesteps(num_inference_steps=100) - original_sample = torch.zeros(input_shape) - noise = torch.randn_like(original_sample) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() - - noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) - self.assertEqual(noisy.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_step_shape(self, input_param, input_shape, expected_shape): - scheduler = DDIMScheduler(**input_param) - scheduler.set_timesteps(num_inference_steps=100) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) - self.assertEqual(output_step[0].shape, expected_shape) - self.assertEqual(output_step[1].shape, expected_shape) - - @parameterized.expand(TEST_FULl_LOOP) - def test_full_timestep_loop(self, input_param, input_shape, expected_output): - scheduler = DDIMScheduler(**input_param) - scheduler.set_timesteps(50) - torch.manual_seed(42) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - for t in range(50): - sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) - assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) - - def test_set_timesteps(self): - scheduler = DDIMScheduler(num_train_timesteps=1000) - scheduler.set_timesteps(num_inference_steps=100) - self.assertEqual(scheduler.num_inference_steps, 100) - self.assertEqual(len(scheduler.timesteps), 100) - - def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): - scheduler = DDIMScheduler(num_train_timesteps=1000) - with self.assertRaises(ValueError): - scheduler.set_timesteps(num_inference_steps=2000) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py deleted file mode 100644 index f0447aded2f..00000000000 --- a/tests/test_scheduler_ddpm.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.schedulers import DDPMScheduler -from tests.utils import assert_allclose - -TEST_2D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - for variance_type in ["fixed_small", "fixed_large"]: - TEST_2D_CASE.append( - [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] - ) - -TEST_3D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - for variance_type in ["fixed_small", "fixed_large"]: - TEST_3D_CASE.append( - [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] - ) - -TEST_CASES = TEST_2D_CASE + TEST_3D_CASE - -TEST_FULl_LOOP = [ - [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] -] - - -class TestDDPMScheduler(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_add_noise(self, input_param, input_shape, expected_shape): - scheduler = DDPMScheduler(**input_param) - original_sample = torch.zeros(input_shape) - noise = torch.randn_like(original_sample) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() - - noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) - self.assertEqual(noisy.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_step_shape(self, input_param, input_shape, expected_shape): - scheduler = DDPMScheduler(**input_param) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) - self.assertEqual(output_step[0].shape, expected_shape) - self.assertEqual(output_step[1].shape, expected_shape) - - @parameterized.expand(TEST_FULl_LOOP) - def test_full_timestep_loop(self, input_param, input_shape, expected_output): - scheduler = DDPMScheduler(**input_param) - scheduler.set_timesteps(50) - torch.manual_seed(42) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - for t in range(50): - sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) - assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) - - @parameterized.expand(TEST_CASES) - def test_get_velocity_shape(self, input_param, input_shape, expected_shape): - scheduler = DDPMScheduler(**input_param) - sample = torch.randn(input_shape) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() - velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) - self.assertEqual(velocity.shape, expected_shape) - - def test_step_learned(self): - for variance_type in ["learned", "learned_range"]: - scheduler = DDPMScheduler(variance_type=variance_type) - model_output = torch.randn(2, 6, 16, 16) - sample = torch.randn(2, 3, 16, 16) - output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) - self.assertEqual(output_step[0].shape, sample.shape) - self.assertEqual(output_step[1].shape, sample.shape) - - def test_set_timesteps(self): - scheduler = DDPMScheduler(num_train_timesteps=1000) - scheduler.set_timesteps(num_inference_steps=100) - self.assertEqual(scheduler.num_inference_steps, 100) - self.assertEqual(len(scheduler.timesteps), 100) - - def 
test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): - scheduler = DDPMScheduler(num_train_timesteps=1000) - with self.assertRaises(ValueError): - scheduler.set_timesteps(num_inference_steps=2000) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py deleted file mode 100644 index 69e5e403f5f..00000000000 --- a/tests/test_scheduler_pndm.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.schedulers import PNDMScheduler -from tests.utils import assert_allclose - -TEST_2D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) - -TEST_3D_CASE = [] -for beta_schedule in ["linear_beta", "scaled_linear_beta"]: - TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) - -TEST_CASES = TEST_2D_CASE + TEST_3D_CASE - -TEST_FULl_LOOP = [ - [ - {"schedule": "linear_beta"}, - (1, 1, 2, 2), - torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]), - ] -] - - -class TestDDPMScheduler(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_add_noise(self, input_param, input_shape, expected_shape): - scheduler = PNDMScheduler(**input_param) - original_sample = torch.zeros(input_shape) - noise = torch.randn_like(original_sample) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() - noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) - self.assertEqual(noisy.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_step_shape(self, input_param, input_shape, expected_shape): - scheduler = PNDMScheduler(**input_param) - scheduler.set_timesteps(600) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) - self.assertEqual(output_step[0].shape, expected_shape) - self.assertEqual(output_step[1], None) - - @parameterized.expand(TEST_FULl_LOOP) - def test_full_timestep_loop(self, input_param, input_shape, expected_output): - scheduler = PNDMScheduler(**input_param) - scheduler.set_timesteps(50) - torch.manual_seed(42) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - for t in range(50): - sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) - assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) - - @parameterized.expand(TEST_FULl_LOOP) - def test_timestep_two_loops(self, input_param, input_shape, expected_output): - scheduler = PNDMScheduler(**input_param) - scheduler.set_timesteps(50) - torch.manual_seed(42) - model_output = torch.randn(input_shape) - sample = torch.randn(input_shape) - 
for t in range(50): - sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) - torch.manual_seed(42) - model_output2 = torch.randn(input_shape) - sample2 = torch.randn(input_shape) - scheduler.set_timesteps(50) - for t in range(50): - sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) - assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) - - def test_set_timesteps(self): - scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) - scheduler.set_timesteps(num_inference_steps=100) - self.assertEqual(scheduler.num_inference_steps, 100) - self.assertEqual(len(scheduler.timesteps), 100) - - def test_set_timesteps_prk(self): - scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) - scheduler.set_timesteps(num_inference_steps=100) - self.assertEqual(scheduler.num_inference_steps, 109) - self.assertEqual(len(scheduler.timesteps), 109) - - def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): - scheduler = PNDMScheduler(num_train_timesteps=1000) - with self.assertRaises(ValueError): - scheduler.set_timesteps(num_inference_steps=2000) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_transformer.py b/tests/test_transformer.py deleted file mode 100644 index ea6ebdf50f0..00000000000 --- a/tests/test_transformer.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets import DecoderOnlyTransformer - -TEST_CASES = [] -for dropout_rate in np.linspace(0, 1, 2): - for attention_layer_dim in [360, 480, 600, 768]: - for num_heads in [4, 6, 8, 12]: - TEST_CASES.append( - [ - { - "num_tokens": 10, - "max_seq_len": 16, - "attn_layers_dim": attention_layer_dim, - "attn_layers_depth": 2, - "attn_layers_heads": num_heads, - "embedding_dropout_rate": dropout_rate, - } - ] - ) - - -class TestDecoderOnlyTransformer(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_unconditioned_models(self, input_param): - net = DecoderOnlyTransformer(**input_param) - with eval_mode(net): - net.forward(torch.randint(0, 10, (1, 16))) - - @parameterized.expand(TEST_CASES) - def test_conditioned_models(self, input_param): - net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) - with eval_mode(net): - net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) - - def test_attention_dim_not_multiple_of_heads(self): - with self.assertRaises(ValueError): - DecoderOnlyTransformer( - num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 - ) - - def test_dropout_rate_negative(self): - with self.assertRaises(ValueError): - DecoderOnlyTransformer( - num_tokens=10, - max_seq_len=16, - attn_layers_dim=8, - attn_layers_depth=2, - attn_layers_heads=2, - embedding_dropout_rate=-1, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py deleted file mode 100644 index 43533d03771..00000000000 --- a/tests/test_vector_quantizer.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest -from math import prod - -import torch -from parameterized import parameterized - -from monai.networks.layers import EMAQuantizer, VectorQuantizer - -TEST_CASES = [ - [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], - [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], -] - - -class TestEMA(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_ema_shape(self, input_param, input_shape, output_shape): - layer = EMAQuantizer(**input_param) - x = torch.randn(input_shape) - layer = layer.train() - outputs = layer(x) - self.assertEqual(outputs[0].shape, input_shape) - self.assertEqual(outputs[2].shape, output_shape) - - layer = layer.eval() - outputs = layer(x) - self.assertEqual(outputs[0].shape, input_shape) - self.assertEqual(outputs[2].shape, output_shape) - - @parameterized.expand(TEST_CASES) - def test_ema_quantize(self, input_param, input_shape, output_shape): - layer = EMAQuantizer(**input_param) - x = torch.randn(input_shape) - outputs = layer.quantize(x) - self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) - self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) - self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) - - def test_ema(self): - layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) - original_weight_0 = layer.embedding.weight[0].clone() - original_weight_1 = layer.embedding.weight[1].clone() - x_0 = original_weight_0 - x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 - - x_1 = original_weight_1 - x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - x_1 = x_1.repeat(1, 1, 1, 2) - - x = torch.cat([x_0, x_1], dim=0) - layer = layer.train() - _ = layer(x) - - self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) - self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) - - -class TestVectorQuantizer(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): - layer = VectorQuantizer(EMAQuantizer(**input_param)) - x = torch.randn(input_shape) - outputs = layer(x) - self.assertEqual(outputs[1].shape, input_shape) - - @parameterized.expand(TEST_CASES) - def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): - layer = VectorQuantizer(EMAQuantizer(**input_param)) - x = torch.randn(input_shape) - outputs = layer.quantize(x) - self.assertEqual(outputs.shape, output_shape) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py deleted file mode 100644 index 4916dc2faad..00000000000 --- a/tests/test_vqvae.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.vqvae import VQVAE -from tests.utils import SkipIfBeforePyTorchVersion - -TEST_CASES = [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4), - "num_res_layers": 1, - "num_res_channels": (4, 4), - "downsample_parameters": ((2, 4, 1, 1),) * 2, - "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, - "num_embeddings": 8, - "embedding_dim": 8, - }, - (1, 1, 8, 8), - (1, 1, 8, 8), - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4), - "num_res_layers": 1, - "num_res_channels": 4, - "downsample_parameters": ((2, 4, 1, 1),) * 2, - "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, - "num_embeddings": 8, - "embedding_dim": 8, - }, - (1, 1, 8, 8, 8), - (1, 1, 8, 8, 8), - ], - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4), - "num_res_layers": 1, - "num_res_channels": (4, 4), - "downsample_parameters": (2, 4, 1, 1), - "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, - "num_embeddings": 8, - "embedding_dim": 8, - }, - (1, 1, 8, 8), - (1, 1, 8, 8), - ], - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "channels": (4, 4), - "num_res_layers": 1, - "num_res_channels": (4, 4), - "downsample_parameters": ((2, 4, 1, 1),) * 2, - "upsample_parameters": (2, 4, 1, 1, 0), - "num_embeddings": 8, - "embedding_dim": 8, - }, - (1, 1, 8, 8, 8), - (1, 1, 8, 8, 8), - ], -] - -TEST_LATENT_SHAPE = { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "downsample_parameters": ((2, 4, 1, 1),) * 2, - "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, - "num_res_layers": 1, - "channels": (8, 8), - "num_res_channels": (8, 8), - "num_embeddings": 16, - "embedding_dim": 8, -} - - -class TestVQVAE(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_shape(self, input_param, input_shape, expected_shape): - device = "cuda" if torch.cuda.is_available() else "cpu" - - net = VQVAE(**input_param).to(device) - - with eval_mode(net): - result, _ = net(torch.randn(input_shape).to(device)) - - self.assertEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - @SkipIfBeforePyTorchVersion((1, 11)) - def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): - device = "cuda" if torch.cuda.is_available() else "cpu" - input_param = input_param.copy() - input_param.update({"use_checkpointing": True}) - - net = VQVAE(**input_param).to(device) - - with eval_mode(net): - result, _ = net(torch.randn(input_shape).to(device)) - - self.assertEqual(result.shape, expected_shape) - - # Removed this test case since TorchScript currently does not support activation checkpoint. 
- # def test_script(self): - # net = VQVAE( - # spatial_dims=2, - # in_channels=1, - # out_channels=1, - # downsample_parameters=((2, 4, 1, 1),) * 2, - # upsample_parameters=((2, 4, 1, 1, 0),) * 2, - # num_res_layers=1, - # channels=(8, 8), - # num_res_channels=(8, 8), - # num_embeddings=16, - # embedding_dim=8, - # ddp_sync=False, - # ) - # test_data = torch.randn(1, 1, 16, 16) - # test_script_save(net, test_data) - - def test_channels_not_same_size_of_num_res_channels(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16, 16), - downsample_parameters=((2, 4, 1, 1),) * 2, - upsample_parameters=((2, 4, 1, 1, 0),) * 2, - ) - - def test_channels_not_same_size_of_downsample_parameters(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16), - downsample_parameters=((2, 4, 1, 1),) * 3, - upsample_parameters=((2, 4, 1, 1, 0),) * 2, - ) - - def test_channels_not_same_size_of_upsample_parameters(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16), - downsample_parameters=((2, 4, 1, 1),) * 2, - upsample_parameters=((2, 4, 1, 1, 0),) * 3, - ) - - def test_downsample_parameters_not_sequence_or_int(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16), - downsample_parameters=(("test", 4, 1, 1),) * 2, - upsample_parameters=((2, 4, 1, 1, 0),) * 2, - ) - - def test_upsample_parameters_not_sequence_or_int(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16), - downsample_parameters=((2, 4, 1, 1),) * 2, - upsample_parameters=(("test", 4, 1, 1, 0),) * 2, - ) - - def test_downsample_parameter_length_different_4(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16), - downsample_parameters=((2, 4, 1),) * 3, - upsample_parameters=((2, 4, 1, 1, 0),) * 2, - ) - - def test_upsample_parameter_length_different_5(self): - with self.assertRaises(ValueError): - VQVAE( - spatial_dims=2, - in_channels=1, - out_channels=1, - channels=(16, 16), - num_res_channels=(16, 16, 16), - downsample_parameters=((2, 4, 1, 1),) * 2, - upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, - ) - - def test_encode_shape(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - - net = VQVAE(**TEST_LATENT_SHAPE).to(device) - - with eval_mode(net): - latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) - - self.assertEqual(latent.shape, (1, 8, 8, 8)) - - def test_index_quantize_shape(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - - net = VQVAE(**TEST_LATENT_SHAPE).to(device) - - with eval_mode(net): - latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) - - self.assertEqual(latent.shape, (1, 8, 8)) - - def test_decode_shape(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - - net = VQVAE(**TEST_LATENT_SHAPE).to(device) - - with eval_mode(net): - latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) - - self.assertEqual(latent.shape, (1, 1, 32, 32)) - - def test_decode_samples_shape(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - - net = 
VQVAE(**TEST_LATENT_SHAPE).to(device) - - with eval_mode(net): - latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) - - self.assertEqual(latent.shape, (1, 1, 32, 32)) - - -if __name__ == "__main__": - unittest.main()