From 651f6f0fe1e66b2ca4ac7886187779f0faeded8d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 10:23:22 +0000 Subject: [PATCH 1/6] Commit as in monai-gen Signed-off-by: Mark Graham --- monai/networks/nets/controlnet.py | 416 ++++++++++++++++++++++++++++++ tests/test_controlnet.py | 54 ++++ 2 files changed, 470 insertions(+) create mode 100644 monai/networks/nets/controlnet.py create mode 100644 tests/test_controlnet.py diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..ace44cba0a --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,416 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. 
+ """ + + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, num_channels: Sequence[int] = (16, 32, 96, 256) + ): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(num_channels) - 1): + channel_in = num_channels[i] + channel_out = num_channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." 
+ ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + num_channels=conditioning_embedding_num_channels, + out_channels=num_channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = num_channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = 
Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims). + conditioning_scale: conditioning scale. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..27acccc083 --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + }, + 6, + (1, 8, 4, 4), + ] +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): + net = ControlNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32)) + ) + self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) + self.assertEqual(result[1].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 27d6d7c7fd2efbb4e856967d5bbc56e65b6459f3 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 10:46:42 +0000 Subject: [PATCH 2/6] Fixes mypy errors Signed-off-by: Mark Graham --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/controlnet.py | 58 +++++++++++++++---------------- tests/test_controlnet.py | 2 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 31fbd73b4e..58cb652bae 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -18,6 +18,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ace44cba0a..7032769809 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -47,15 +47,13 @@ class ControlNetConditioningEmbedding(nn.Module): Network to encode the conditioning into a latent space. """ - def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, num_channels: Sequence[int] = (16, 32, 96, 256) - ): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): super().__init__() self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -64,9 +62,9 @@ def __init__( self.blocks = nn.ModuleList([]) - for i in range(len(num_channels) - 1): - channel_in = num_channels[i] - channel_out = num_channels[i + 1] + for i in range(len(channels) - 1): + channel_in = channels[i] + channel_out = channels[i + 1] self.blocks.append( Convolution( spatial_dims=spatial_dims, @@ -94,7 +92,7 @@ def __init__( self.conv_out = zero_module( Convolution( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], out_channels=out_channels, strides=1, kernel_size=3, @@ -131,7 +129,7 @@ class ControlNet(nn.Module): spatial_dims: number of spatial dimensions. in_channels: number of input channels. num_res_blocks: number of residual blocks (see ResnetBlock) per level. 
- num_channels: tuple of block output channels. + channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. @@ -153,7 +151,7 @@ def __init__( spatial_dims: int, in_channels: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, @@ -166,7 +164,7 @@ def __init__( upcast_attention: bool = False, use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -180,10 +178,10 @@ def __init__( ) # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): @@ -196,9 +194,9 @@ def __init__( ) if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." 
@@ -210,7 +208,7 @@ def __init__( ) self.in_channels = in_channels - self.block_out_channels = num_channels + self.block_out_channels = channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels self.num_head_channels = num_head_channels @@ -220,7 +218,7 @@ def __init__( self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -228,9 +226,9 @@ def __init__( ) # time - time_embed_dim = num_channels[0] * 4 + time_embed_dim = channels[0] * 4 self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding @@ -242,14 +240,14 @@ def __init__( self.controlnet_cond_embedding = ControlNetConditioningEmbedding( spatial_dims=spatial_dims, in_channels=conditioning_embedding_in_channels, - num_channels=conditioning_embedding_num_channels, - out_channels=num_channels[0], + channels=conditioning_embedding_num_channels, + out_channels=channels[0], ) # down self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] + output_channel = channels[0] controlnet_block = Convolution( spatial_dims=spatial_dims, @@ -263,10 +261,10 @@ def __init__( controlnet_block = zero_module(controlnet_block.conv) self.controlnet_down_blocks.append(controlnet_block) - for i in range(len(num_channels)): + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, @@ -316,7 +314,7 @@ def __init__( self.controlnet_down_blocks.append(controlnet_block) # mid - mid_block_channel = num_channels[-1] + mid_block_channel = channels[-1] self.middle_block = get_mid_block( spatial_dims=spatial_dims, @@ -352,7 +350,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: """ Args: x: input tensor (N, C, SpatialDims). @@ -399,15 +397,15 @@ def forward( h = self.middle_block(hidden_states=h, temb=emb, context=context) # 6. Control net blocks - controlnet_down_block_res_samples = () + controlnet_down_block_res_samples = [] for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) + controlnet_down_block_res_samples.append(down_block_res_sample) down_block_res_samples = controlnet_down_block_res_samples - mid_block_res_sample = self.controlnet_mid_block(h) + mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) # 6. 
scaling down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 27acccc083..1ced38d61a 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -25,7 +25,7 @@ "spatial_dims": 2, "in_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, From 53961a1d17a12030e1f6a0e6343e646fb2c35a09 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 10:50:46 +0000 Subject: [PATCH 3/6] Updates docs Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 417fb8ac73..0960fcdbc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -588,6 +588,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`ControlNet` +~~~~~~~~~~~~ +.. autoclass:: ControlNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet From eb3a615e76d06e2afb33c3fc3546adeaeb3005d6 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 11:58:26 +0000 Subject: [PATCH 4/6] More tests Signed-off-by: Mark Graham --- tests/test_controlnet.py | 147 +++++++++++++++++++++++++++++++++++---- 1 file changed, 135 insertions(+), 12 deletions(-) diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 1ced38d61a..07dfa2e49b 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -19,7 +19,42 @@ from monai.networks import eval_mode from monai.networks.nets.controlnet import ControlNet -TEST_CASES = [ +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 4, + }, + (1, 4, 4, 4), + ], [ { "spatial_dims": 2, @@ -29,25 +64,113 @@ "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, - "conditioning_embedding_in_channels": 1, - "conditioning_embedding_num_channels": (8, 8), + "resblock_updown": True, }, - 6, (1, 8, 4, 4), - ] + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "num_head_channels": 4, + "attention_levels": (False, False, False), + "norm_num_groups": 4, + "resblock_updown": True, + }, + (1, 4, 4, 4, 4), + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + 
"with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + (1, 8, 4, 4), + ], ] class TestControlNet(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): + @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @parameterized.expand(COND_CASES_2D) + def test_shape_conditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) net = ControlNet(**input_param) with eval_mode(net): - result = net.forward( - torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32)) - ) - self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) - self.assertEqual(result[1].shape, expected_shape) + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) if __name__ == "__main__": From 6ac195eb606be6eb99a740d4ee6bb86c78edb21d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 11:06:32 +0000 Subject: [PATCH 5/6] More informative errors on argument checking Signed-off-by: Mark Graham --- monai/networks/nets/controlnet.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 7032769809..2664f883c7 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -170,7 +170,7 @@ def __init__( if with_conditioning is True and cross_attention_dim is None: raise ValueError( "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." + "to be specified when with_conditioning=True." 
)
         if cross_attention_dim is not None and with_conditioning is False:
             raise ValueError(
@@ -179,17 +179,24 @@ def __init__(
             )
 
         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
-            raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
+            raise ValueError(
+                f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got"
+                f" channels={channels} and norm_num_groups={norm_num_groups}"
+            )
 
         if len(channels) != len(attention_levels):
-            raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")
+            raise ValueError(
+                f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got "
+                f"channels={channels} and attention_levels={attention_levels}"
+            )
 
         if isinstance(num_head_channels, int):
             num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
 
         if len(num_head_channels) != len(attention_levels):
             raise ValueError(
-                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+                f"num_head_channels should have the same length as attention_levels, but got num_head_channels="
+                f"{num_head_channels} and attention_levels={attention_levels}. For the i levels without attention,"
                 " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
             )
 
@@ -198,8 +205,8 @@ def __init__(
 
         if len(num_res_blocks) != len(channels):
             raise ValueError(
-                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
-                "`num_channels`."
+                f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                f"`channels`, but got num_res_blocks={num_res_blocks} and channels={channels}."
             )
 
         if use_flash_attention is True and not torch.cuda.is_available():

From df71b0ec71d65861ed39e90a2636468cf28cac7a Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Wed, 13 Dec 2023 11:11:44 +0000
Subject: [PATCH 6/6] More information on the dimensions of model inputs

Signed-off-by: Mark Graham
---
 monai/networks/nets/controlnet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py
index 2664f883c7..d98755f401 100644
--- a/monai/networks/nets/controlnet.py
+++ b/monai/networks/nets/controlnet.py
@@ -360,11 +360,11 @@ def forward(
     ) -> tuple[list[torch.Tensor], torch.Tensor]:
         """
         Args:
-            x: input tensor (N, C, SpatialDims).
+            x: input tensor (N, C, H, W, [D]).
             timesteps: timestep tensor (N,).
-            controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims).
+            controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]).
             conditioning_scale: conditioning scale.
-            context: context tensor (N, 1, ContextDim).
+            context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init.
-            class_labels: context tensor (N, ).
+            class_labels: class labels tensor (N,).
         """
         # 1. time
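
For reviewers, a minimal usage sketch of the network added in this series. It is illustrative rather than part of the patch: it assumes a build with these commits applied, and the hyperparameters mirror the small configurations exercised in tests/test_controlnet.py.

    import torch

    from monai.networks.nets import ControlNet

    # a small 2D ControlNet, matching the test configuration (illustrative values)
    net = ControlNet(
        spatial_dims=2,
        in_channels=1,
        num_res_blocks=1,
        channels=(8, 8, 8),
        attention_levels=(False, False, True),
        num_head_channels=8,
        norm_num_groups=8,
        conditioning_embedding_in_channels=1,
        conditioning_embedding_num_channels=(8,),
    )

    x = torch.rand(1, 1, 16, 16)                     # noisy input (N, C, H, W)
    timesteps = torch.randint(0, 1000, (1,)).long()  # diffusion timesteps (N,)
    cond = torch.rand(1, 1, 16, 16)                  # spatial conditioning, e.g. a mask or edge map

    down_res, mid_res = net(x, timesteps=timesteps, controlnet_cond=cond)
    print(len(down_res), mid_res.shape)  # 6 residuals, i.e. 2 * len(channels); mid: (1, 8, 4, 4)

Note that every returned residual is exactly zero at initialisation: the per-level projection convolutions and the conditioning embedding's output convolution are all passed through zero_module, so an untrained ControlNet leaves the behaviour of the frozen diffusion UNet unchanged, which is the core trick of the ControlNet paper.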
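At sampling time the residuals are injected into a frozen, pretrained diffusion UNet. The sketch below shows one ControlNet-guided denoising step; it assumes the companion DiffusionModelUNet port keeps the down_block_additional_residuals / mid_block_additional_residual keyword arguments it has in MONAI Generative Models, and that `unet`, `scheduler`, `sample`, `cond` and `t` already exist and are mutually compatible — adapt if that port differs.

    # one denoising step with ControlNet guidance (sketch under the assumptions above);
    # `net` is the ControlNet instance from the previous snippet
    down_res, mid_res = net(sample, timesteps=t, controlnet_cond=cond, conditioning_scale=1.0)
    noise_pred = unet(
        x=sample,
        timesteps=t,
        down_block_additional_residuals=down_res,  # added into the UNet's skip connections
        mid_block_additional_residual=mid_res,     # added into the UNet's mid block output
    )
    sample, _ = scheduler.step(noise_pred, t, sample)

The conditioning_scale argument simply scales all residuals before injection, so values below 1.0 weaken the spatial guidance without retraining.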