diff --git a/examples/GPU/example_learn_straight_line_readouts.py b/examples/GPU/example_learn_straight_line_readouts.py index 2ff58221..ecf9aeef 100644 --- a/examples/GPU/example_learn_straight_line_readouts.py +++ b/examples/GPU/example_learn_straight_line_readouts.py @@ -126,7 +126,7 @@ def plot_state(mri_2D, traj, recon, loss=None, save_name=None, i=None): axs[0].axis("off") axs[0].set_title("MR Image") if traj.shape[-1] == 3: - if i is not None and i > 50: + if i is not None and i > 20: axs[1].scatter(*traj.T[1:3, 0], s=10, color="blue") else: fig_kwargs = {} @@ -136,7 +136,7 @@ def plot_state(mri_2D, traj, recon, loss=None, save_name=None, i=None): i / 25 * 60 - 60, 30 - i / 25 * 30, ) - plt_kwargs["alpha"] = 0.2 + 0.8 * i / 20 + plt_kwargs["alpha"] = 0.2 + 0.8 * i / 20, 1 plt_kwargs["s"] = 1 + 9 * i / 20 axs[1].remove() axs[1] = fig.add_subplot(*fig_grid, 2, projection="3d", **fig_kwargs)