From 0fcc2bda7eb627367595123ee735d9df2224a2cf Mon Sep 17 00:00:00 2001 From: rd24_lihe055 Date: Tue, 21 Jan 2025 12:04:24 +0000 Subject: [PATCH] replaces custom saver with SaveImage --- .../configs/inference.json | 6 ++-- .../scripts/saver.py | 29 ------------------- 2 files changed, 3 insertions(+), 32 deletions(-) delete mode 100644 models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py diff --git a/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json b/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json index 6e7b0262..206d1130 100644 --- a/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json +++ b/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json @@ -105,8 +105,8 @@ "saver": { "_target_": "SaveImage", "_requires_": "@create_output_dir", - "output_dir": "@output_dir" + "output_dir": "@output_dir", + "output_postfix": "@out_file" }, - "run": "$@saver.save(@sample, @out_file)", - "save": "$torch.save(@sample, @output_dir + '/' + @out_file + '.pt')" + "run": "$@saver(@sample[0][0])" } diff --git a/models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py b/models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py deleted file mode 100644 index de882df0..00000000 --- a/models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import nibabel as nib -import numpy as np -import torch - - -class NiftiSaver: - def __init__(self, output_dir: str) -> None: - super().__init__() - self.output_dir = output_dir - self.affine = np.array( - [ - [-1.0, 0.0, 0.0, 96.48149872], - [0.0, 1.0, 0.0, -141.47715759], - [0.0, 0.0, 1.0, -156.55375671], - [0.0, 0.0, 0.0, 1.0], - ] - ) - - def save(self, image_data: torch.Tensor, file_name: str) -> None: - image_data = image_data.cpu().numpy() - image_data = image_data[0, 0, 5:-5, 5:-5, :-15] - image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min()) - image_data = (image_data * 255).astype(np.uint8) - - empty_header = nib.Nifti1Header() - sample_nii = nib.Nifti1Image(image_data, self.affine, empty_header) - nib.save(sample_nii, f"{str(self.output_dir)}/{file_name}.nii.gz")