diff --git a/nanshe_ipython.ipynb b/nanshe_ipython.ipynb index dfceab2..81fcc9a 100644 --- a/nanshe_ipython.ipynb +++ b/nanshe_ipython.ipynb @@ -241,7 +241,7 @@ "source": [ "from nanshe_workflow.par import halo_block_parallel\n", "\n", - "from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data\n", + "from nanshe_workflow.imp2 import extract_f0, wavelet_transform, normalize_data\n", "\n", "from nanshe_workflow.par import halo_block_generate_dictionary_parallel\n", "from nanshe_workflow.imp import block_postprocess_data_parallel\n", @@ -993,11 +993,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Normalize Data\n", + "### Project\n", "\n", "* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n", - "* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).\n", - "* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)." + "* `proj_type` (`str`): type of projection to take." ] }, { @@ -1007,12 +1006,11 @@ "outputs": [], "source": [ "block_frames = 40\n", - "block_space = 300\n", - "norm_frames = 100\n", + "proj_type = \"max\"\n", "\n", "\n", "with get_executor(client) as executor:\n", - " dask_io_remove(data_basename + postfix_norm + zarr_ext, executor)\n", + " dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n", "\n", "\n", " with open_zarr(data_basename + postfix_wt + zarr_ext, \"r\") as f:\n", @@ -1027,106 +1025,28 @@ " da_imgs_flt.dtype.itemsize >= 4):\n", " da_imgs_flt = da_imgs_flt.astype(np.float32)\n", "\n", - " da_imgs_flt_mins = da_imgs_flt.min(\n", - " axis=tuple(irange(1, da_imgs_flt.ndim)),\n", - " keepdims=True\n", - " )\n", - "\n", - " da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins\n", - "\n", - " da_result = renormalized_images(da_imgs_flt_shift)\n", + " da_result = da_imgs\n", + " if proj_type == \"max\":\n", + " da_result = da_result.max(axis=0, keepdims=True)\n", + " elif proj_type == \"std\":\n", + " da_result = da_result.std(axis=0, keepdims=True)\n", "\n", " # Store denoised data\n", - " dask_store_zarr(data_basename + postfix_norm + zarr_ext, [\"images\"], [da_result], executor)\n", - "\n", - "\n", - " zip_zarr(data_basename + postfix_norm + zarr_ext, executor)\n", - "\n", - "\n", - "if __IPYTHON__:\n", - " result_image_stack = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n", - "\n", - " mplsv = plt.figure(FigureClass=MPLViewer)\n", - " mplsv.set_images(\n", - " result_image_stack,\n", - " vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n", - " vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dictionary Learning\n", - "\n", - "* `n_components` (`int`): number of basis images in the dictionary.\n", - "* `batchsize` (`int`): minibatch size to use.\n", - "* `iters` (`int`): number of iterations to run before getting dictionary.\n", - "* `lambda1` (`float`): weight for L1 sparisty enforcement on sparse code.\n", - "* `lambda2` (`float`): weight for L2 sparisty enforcement on sparse code.\n", - "\n", - "
\n", - "* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n", - "* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_components = 50\n", - "batchsize = 256\n", - "iters = 100\n", - "lambda1 = 0.2\n", - "lambda2 = 0.0\n", - "\n", - "block_frames = 51\n", - "norm_frames = 100\n", + " dask_store_zarr(data_basename + postfix_dict + zarr_ext, [\"images\"], [da_result], executor)\n", "\n", "\n", - "with get_executor(client) as executor:\n", - " dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n", - "\n", - "\n", - "result = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n", - "block_shape = (block_frames,) + result.shape[1:]\n", - "with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n", - " new_result = f2.create_dataset(\"images\", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)\n", - "\n", - " result = par_generate_dictionary(block_shape)(\n", - " result,\n", - " n_components=n_components,\n", - " out=new_result,\n", - " **{\"sklearn.decomposition.dict_learning_online\" : {\n", - " \"n_jobs\" : 1,\n", - " \"n_iter\" : iters,\n", - " \"batch_size\" : batchsize,\n", - " \"alpha\" : lambda1\n", - " }\n", - " }\n", - " )\n", - "\n", - " result_j = f2.create_dataset(\"images_j\", shape=new_result.shape, dtype=numpy.uint16, chunks=True)\n", - " par_norm_layer(num_frames=norm_frames)(result, out=result_j)\n", - "\n", - "\n", - "with get_executor(client) as executor:\n", " zip_zarr(data_basename + postfix_dict + zarr_ext, executor)\n", "\n", "\n", "if __IPYTHON__:\n", - " result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")\n", + " result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")[...][...]\n", "\n", " mplsv = plt.figure(FigureClass=MPLViewer)\n", " mplsv.set_images(\n", " result_image_stack,\n", - " vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n", - " vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n", - " )\n", - " mplsv.time_nav.stime.label.set_text(\"Basis Image\")" + " vmin=result_image_stack.min(),\n", + " vmax=result_image_stack.max()\n", + " )" ] }, {