Diffusion denoising branches are useless in unlabelled data #12

Open
hhjjaaa opened this issue Jan 4, 2025 · 6 comments
hhjjaaa commented Jan 4, 2025

I reproduced the code you provided on the Synapse dataset and obtained the same results as in the paper. However, when I visualized the diffusion branch's output on unlabeled data, I found that it was all noise: it could not generate useful pseudo-labels to help the weight-adjustment branch during training, so the pseudo-labels ended up coming entirely from the weight-adjustment branch.
[Image: Snipaste_2024-12-03_19-58-26]
[Image: Snipaste_2024-12-04_12-59-55]

hhjjaaa (Author) commented Jan 4, 2025

This is my visualization code:
import numpy as np
import torch
from matplotlib.colors import ListedColormap

def label_to_color(label, num_classes):
    if num_classes != 14:
        raise ValueError(f"Expected num_classes=14, but got {num_classes}")
    colors = [
        (0.0, 0.0, 0.0),  # Class 0 - Background - Black
        (1.0, 0.0, 0.0),  # Class 1 - Red
        (0.0, 1.0, 0.0),  # Class 2 - Green
        (0.0, 0.0, 1.0),  # Class 3 - Blue
        (1.0, 1.0, 0.0),  # Class 4 - Yellow
        (1.0, 0.0, 1.0),  # Class 5 - Magenta
        (0.0, 1.0, 1.0),  # Class 6 - Cyan
        (0.5, 0.5, 0.5),  # Class 7 - Gray
        (0.5, 0.0, 0.0),  # Class 8 - Dark Red
        (0.0, 0.5, 0.0),  # Class 9 - Dark Green
        (0.0, 0.0, 0.5),  # Class 10 - Dark Blue
        (0.5, 0.5, 0.0),  # Class 11 - Olive
        (0.5, 0.0, 0.5),  # Class 12 - Purple
        (0.3, 0.8, 0.5),  # Class 13 - Additional Color
    ]

    # If the number of classes exceeds the predefined colors, fall back to
    # random colors (though num_classes is pinned to 14 by the check above)
    if num_classes > len(colors):
        additional_colors = np.random.rand(num_classes - len(colors), 3)
        colors.extend(additional_colors.tolist())

    cmap = ListedColormap(colors[:num_classes])

    # Convert the label to a numpy array
    label = label.numpy() if isinstance(label, torch.Tensor) else label

    # Apply the color mapping (integer labels index directly into the colormap)
    label_color = cmap(label)

    return label_color
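
For reference, a minimal usage sketch of this function; the (H, W) slice here is random dummy data standing in for a real label map, not anything from the dataset:

# Hypothetical usage of label_to_color on a single 2D slice
import matplotlib.pyplot as plt
import torch

dummy_slice = torch.randint(0, 14, (96, 96))   # fake (H, W) label map with 14 classes
color_slice = label_to_color(dummy_slice, 14)  # (H, W, 4) RGBA numpy array
plt.imshow(color_slice)
plt.axis('off')
plt.show()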

I add this visualization in the training loop:

if epoch_num % 10 == 0:  # Visualize every 10 epochs
    model.eval()
    with torch.no_grad():
        # Get a batch of images and labels for visualization
        batch = next(iter(eval_loader))
        images, gts = fetch_data(batch)  # images: (B, C, D, H, W), gts: (B, C, D, H, W) or (B, D, H, W)

        # Get the outputs from the model's two branches
        p_u_xi = model(images, pred_type="ddim_sample")  # Shape: (B, C, D, H, W)
        p_u_psi = model(images, pred_type="D_psi_l")  # Shape: (B, C, D, H, W)

        # Convert outputs to probability maps
        smoothing = GaussianSmoothing(config.num_cls, 3, 1)
        pred_xi = smoothing(F.gumbel_softmax(p_u_xi, dim=1))
        pred_psi = F.softmax(p_u_psi, dim=1)  # (B, C, D, H, W)

        # Get predicted classes
        pred_psi = torch.argmax(pred_psi, dim=1, keepdim=True)  # (B, 1, D, H, W)

        for i in range(min(3, images.size(0))):  # Visualize the first 3 images
            image = images[i].cpu()  # Shape: (C, D, H, W)
            gt = gts[i].cpu()  # Shape: (C, D, H, W) or (D, H, W)
            px_i = pred_xi[i].cpu()  # Shape: (C, D, H, W)
            ppsi = pred_psi[i].cpu()  # Shape: (1, D, H, W)
            org_ppsi = p_u_psi[i].cpu()  # Shape: (C, D, H, W)

            # Print shapes for debugging
            print(
                f"Before processing - Image shape: {image.shape}, GT shape: {gt.shape}, "
                f"pred_xi shape: {px_i.shape}, pred_psi shape: {ppsi.shape}")

            # Select a specific depth slice
            depth_idx = image.size(1) // 2  # Middle slice index
            channel_idx = 0  # Select the first channel

            # Process image slice
            if image.ndimension() == 4:
                # Image shape is (C, D, H, W)
                if image.size(0) > 1:
                    image_slice = image[channel_idx, depth_idx, :, :]  # Select specific channel and depth
                else:
                    image_slice = image.squeeze(0)[depth_idx, :, :]  # Single channel
            elif image.ndimension() == 3:
                # Image shape is (D, H, W)
                image_slice = image[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected image dimensions: {image.ndimension()}")

            # Process Ground Truth slice
            if gt.ndimension() == 4:
                if gt.size(0) > 1:
                    gt_slice = gt[channel_idx, depth_idx, :, :]
                else:
                    gt_slice = gt.squeeze(0)[depth_idx, :, :]
            elif gt.ndimension() == 3:
                gt_slice = gt[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected GT dimensions: {gt.ndimension()}")

            # Convert labels to class indices (if labels are one-hot encoded)
            if gt_slice.ndimension() == 3:
                gt_slice = torch.argmax(gt_slice, dim=0)

            # Process predicted slices
            if px_i.ndimension() == 4:
                if px_i.size(0) > 1:
                    pred_xi_slice = px_i[channel_idx, depth_idx, :, :]
                else:
                    pred_xi_slice = px_i.squeeze(0)[depth_idx, :, :]
            elif px_i.ndimension() == 3:
                pred_xi_slice = px_i[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_xi dimensions: {px_i.ndimension()}")

            if ppsi.ndimension() == 4:
                if ppsi.size(0) > 1:
                    pred_psi_slice = ppsi[channel_idx, depth_idx, :, :]
                else:
                    pred_psi_slice = ppsi.squeeze(0)[depth_idx, :, :]
            elif ppsi.ndimension() == 3:
                pred_psi_slice = ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected pred_psi dimensions: {ppsi.ndimension()}")

            if org_ppsi.ndimension() == 4:
                if org_ppsi.size(0) > 1:
                    p_u_psi_slice = org_ppsi[channel_idx, depth_idx, :, :]
                else:
                    p_u_psi_slice = org_ppsi.squeeze(0)[depth_idx, :, :]
            elif org_ppsi.ndimension() == 3:
                p_u_psi_slice = org_ppsi[depth_idx, :, :]
            else:
                raise ValueError(f"Unexpected p_u_psi dimensions: {org_ppsi.ndimension()}")

            # Ensure predictions are also 2D
            if pred_xi_slice.ndimension() > 2:
                pred_xi_slice = torch.argmax(pred_xi_slice, dim=0)
            if pred_psi_slice.ndimension() > 2:
                pred_psi_slice = torch.argmax(pred_psi_slice, dim=0)

            print(
                f"After processing - Image slice shape: {image_slice.shape}, GT slice shape: {gt_slice.shape}, "
                f"pred_xi_slice shape: {pred_xi_slice.shape}, pred_psi_slice shape: {pred_psi_slice.shape}")

            # Convert labels to color maps
            gt_color = label_to_color(gt_slice, config.num_cls)
            pred_psi_color = label_to_color(pred_psi_slice, config.num_cls)

            # Convert the image slice to a numpy array for plotting
            image_slice = image_slice.numpy()  # (H, W)

            # Create a grid of images
            fig, axs = plt.subplots(1, 5, figsize=(20, 5))

            # Input grayscale image
            axs[0].imshow(image_slice, cmap='gray')
            axs[0].set_title('Input Image')
            axs[0].axis('off')

            # Ground Truth
            axs[1].imshow(gt_color)
            axs[1].set_title('Ground Truth')
            axs[1].axis('off')

            # Predicted p_u_xi
            axs[2].imshow(pred_xi_slice, cmap='gray')
            axs[2].set_title('Predicted p_u_xi')
            axs[2].axis('off')

            # Predicted p_u_psi
            axs[3].imshow(pred_psi_color)
            axs[3].set_title('Predicted p_u_psi')
            axs[3].axis('off')

            # Raw (pre-softmax) p_u_psi slice
            axs[4].imshow(p_u_psi_slice, cmap='gray')
            axs[4].set_title('Predicted org_p_u_psi')
            axs[4].axis('off')

            # Convert the figure to a tensor and log it to TensorBoard
            fig.canvas.draw()
            img_tensor = torch.from_numpy(np.array(fig.canvas.renderer.buffer_rgba())).permute(2, 0, 1)[:3, :, :] / 255.0
            writer.add_image(f'Validation/Predictions_{epoch_num}_{i}', img_tensor, epoch_num)
            plt.close(fig)

    model.train()
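
As a rough numeric check to complement the images, one could also log how often the two branches' hard predictions agree; near-chance agreement would support the "all noise" reading. This is only a sketch that would sit inside the torch.no_grad() block above, reusing its p_u_xi and p_u_psi tensors and assuming both are per-class score volumes of shape (B, C, D, H, W):

# Sketch: voxel-wise agreement between the two branches' argmax predictions.
# For a 14-class problem, agreement near 1/14 is what a uniform-noise branch
# would produce; a branch that has learned anything should score far higher.
xi_hard = torch.argmax(p_u_xi, dim=1)    # (B, D, H, W)
psi_hard = torch.argmax(p_u_psi, dim=1)  # (B, D, H, W)
agreement = (xi_hard == psi_hard).float().mean().item()
writer.add_scalar('Validation/BranchAgreement', agreement, epoch_num)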

hhjjaaa (Author) commented Jan 4, 2025

Unlabeled data is noise during training. [Image: Snipaste_2024-12-03_19-58-26]

hhjjaaa (Author) commented Jan 4, 2025

How did you get this picture? [Image: Snipaste_2024-12-03_19-58-26]

McGregorWwww (Collaborator) commented

Is visualization conducted during the early training stage? If so, it would be difficult for the diffusion process to learn effective representations.
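
(One cheap way to check this timing question is to log the mean entropy of the diffusion branch's softmax output at each visualization epoch; entropy that stays near log(num_cls) over the whole run would match the "noise at all periods" behavior reported above. A sketch only, assuming p_u_xi holds the branch's per-class scores as in the code earlier in this thread:)

# Sketch: mean voxel-wise entropy of the diffusion branch output per epoch.
import torch
import torch.nn.functional as F

probs = F.softmax(p_u_xi, dim=1)                         # (B, C, D, H, W)
entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)  # (B, D, H, W)
writer.add_scalar('Validation/XiEntropy', entropy.mean().item(), epoch_num)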

hhjjaaa (Author) commented Jan 6, 2025

On the M&Ms dataset the diffusion branch did learn the features (it learned well from the early to the middle stages of training) and the segmentation results were as expected, but on the Synapse dataset the diffusion branch's output was all noise throughout the entire run.

hhjjaaa (Author) commented Jan 7, 2025

[Image: Snipaste_2024-12-03_19-58-26]
[Image: Snipaste_2024-12-04_12-59-55]

