Skip to content

Commit

Permalink
Fix example U Net learning (#209)
Browse files Browse the repository at this point in the history
* Added

* Fix

* Remove bymistake add

* Fix

* Fixed lint

* Lint

* Added refbackend

* Fix NDFT

* feat: use finufft as ref backend.

* feat(tests): move ndft vs nufft tests to own file.

* \docs_build [docs]

* Added some changes for and some helopful gitignore

---------

Co-authored-by: chaithyagr <[email protected]>
Co-authored-by: Pierre-antoine Comby <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent a6bd869 commit 2d05f41
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ env:
jobs:
test-cpu:
runs-on: cpu
if: ${{ !contains(github.event.head_commit.message, '!style') || github.ref == 'refs/heads/master' }}
if: ${{ !contains(github.event.head_commit.message, '[style]') || github.ref == 'refs/heads/master' }}
strategy:
matrix:
backend: [finufft, pynfft, pynufft-cpu, bart, sigpy, torchkbnufft-cpu]
Expand Down Expand Up @@ -97,7 +97,7 @@ jobs:

test-gpu:
runs-on: gpu
if: ${{ !contains(github.event.head_commit.message, '!style') || github.ref == 'refs/heads/master' }}
if: ${{ !contains(github.event.head_commit.message, '[style]') || github.ref == 'refs/heads/master' }}
strategy:
matrix:
backend: [cufinufft, gpunufft, torchkbnufft-gpu, tensorflow]
Expand Down Expand Up @@ -186,7 +186,7 @@ jobs:
test-examples:
runs-on: gpu
needs: get-commit-message
if: ${{ !contains(needs.get-commit-message.outputs.message, '!style') || github.ref == 'refs/heads/master' }}
if: ${{ !contains(needs.get-commit-message.outputs.message, '[style]') || github.ref == 'refs/heads/master' }}

steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -298,7 +298,7 @@ jobs:
name: Build API Documentation
runs-on: gpu
needs: get-commit-message
if: ${{ contains(needs.get-commit-message.outputs.message, '!docs_build') || github.ref == 'refs/heads/master' }}
if: ${{ contains(needs.get-commit-message.outputs.message, '[docs]') || github.ref == 'refs/heads/master' }}
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.npy
*.gif
docs/sg_execution_times.rst
build/
dist/
*.egg-info/
Expand Down
23 changes: 15 additions & 8 deletions examples/GPU/example_fastMRI_UNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@
\mathbf{\hat{x}} = \mathrm{arg} \min_{\mathbf{x}} || \mathcal{U}_\mathbf{\theta}(\mathbf{y}) - \mathbf{x} ||_2^2
where:
- \( \mathbf{\hat{x}} \) is the reconstructed MRI image,
- \( \mathbf{x} \) is the ground truth image,
- \( \mathbf{y} \) is the input MRI image (e.g., k-space data),
- \( \mathcal{U}_\mathbf{\theta} \) is the U-Net model parameterized by \( \theta \).
where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image,
:math:`\mathbf{y}` is the input MRI image (e.g., k-space data), and :math:`\mathcal{U}_\mathbf{\theta}` is the U-Net model parameterized by :math:`\theta`.
.. warning::
We train on a single image here. In practice, this should be done on a database like fastMRI [fastmri]_.
Expand Down Expand Up @@ -141,13 +138,13 @@ def plot_state(axs, mri_2D, traj, recon, loss=None, save_name=None):

# %%
# Start training loop
epoch = 100
num_epochs = 2
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)
losses = [] # Store the loss values and create an animation
image_files = [] # Store the images to create a gif
model.train()

with tqdm(range(epoch), unit="steps") as tqdms:
with tqdm(range(num_epochs), unit="steps") as tqdms:
for i in tqdms:
out = model(kspace_mri_2D) # Forward pass

Expand Down Expand Up @@ -203,10 +200,20 @@ def plot_state(axs, mri_2D, traj, recon, loss=None, save_name=None):
/ "GPU"
/ "images"
)
shutil.copyfile("mrinufft_learn_Unet.gif", final_dir / "mrinufft_learn_Unet.gif")
shutil.copyfile("mrinufft_learn_unet.gif", final_dir / "mrinufft_learn_unet.gif")
except FileNotFoundError:
pass

# sphinx_gallery_end_ignore

# sphinx_gallery_thumbnail_path = 'generated/autoexamples/GPU/images/mrinufft_learn_unet.gif'

# %%
# .. image-sg:: /generated/autoexamples/GPU/images/mrinufft_learn_unet.gif
# :alt: example learn_samples
# :srcset: /generated/autoexamples/GPU/images/mrinufft_learn_unet.gif
# :class: sphx-glr-single-img

# %%
# Reconstruction from partially trained U-Net model
model.eval()
Expand Down

0 comments on commit 2d05f41

Please sign in to comment.