Compute projection instead of dictionary
jakirkham committed Oct 23, 2017
1 parent f7a26d0 commit 59d189b
Showing 1 changed file with 15 additions and 95 deletions.
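In short, this commit drops the separate "Normalize Data" and "Dictionary Learning" stages and computes a single projection over time instead. A minimal sketch of the new projection step, assuming a dask array `imgs` of shape `(frames, height, width)` standing in for the notebook's `da_imgs`:

```python
import dask.array as da

# Toy stand-in movie: 40 frames of 64x64 pixels, chunked along time
# the way `block_frames` would chunk the real data.
imgs = da.random.random((40, 64, 64), chunks=(10, 64, 64))

proj_type = "max"  # or "std", matching the notebook's `proj_type`

# Reduce over the time axis, keeping a singleton leading axis so the
# result still reads as a 1-frame image stack downstream.
if proj_type == "max":
    result = imgs.max(axis=0, keepdims=True)
elif proj_type == "std":
    result = imgs.std(axis=0, keepdims=True)

print(result.compute().shape)  # (1, 64, 64)
```

The reductions stay lazy until `compute()`; in the notebook the equivalent result is written straight to zarr via `dask_store_zarr` instead.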
110 changes: 15 additions & 95 deletions nanshe_ipython.ipynb
@@ -241,7 +241,7 @@
  "source": [
  "from nanshe_workflow.par import halo_block_parallel\n",
  "\n",
- "from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data\n",
+ "from nanshe_workflow.imp2 import extract_f0, wavelet_transform, normalize_data\n",
  "\n",
  "from nanshe_workflow.par import halo_block_generate_dictionary_parallel\n",
  "from nanshe_workflow.imp import block_postprocess_data_parallel\n",
@@ -993,11 +993,10 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "### Normalize Data\n",
+ "### Project\n",
  "\n",
  "* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
- "* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).\n",
- "* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
+ "* `proj_type` (`str`): type of projection to take."
  ]
 },
 {
@@ -1007,12 +1006,11 @@
  "outputs": [],
  "source": [
  "block_frames = 40\n",
- "block_space = 300\n",
- "norm_frames = 100\n",
+ "proj_type = \"max\"\n",
  "\n",
  "\n",
  "with get_executor(client) as executor:\n",
- "    dask_io_remove(data_basename + postfix_norm + zarr_ext, executor)\n",
+ "    dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
  "\n",
  "\n",
  "    with open_zarr(data_basename + postfix_wt + zarr_ext, \"r\") as f:\n",
@@ -1027,106 +1025,28 @@
  "            da_imgs_flt.dtype.itemsize >= 4):\n",
  "            da_imgs_flt = da_imgs_flt.astype(np.float32)\n",
  "\n",
- "        da_imgs_flt_mins = da_imgs_flt.min(\n",
- "            axis=tuple(irange(1, da_imgs_flt.ndim)),\n",
- "            keepdims=True\n",
- "        )\n",
- "\n",
- "        da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins\n",
- "\n",
- "        da_result = renormalized_images(da_imgs_flt_shift)\n",
+ "        da_result = da_imgs\n",
+ "        if proj_type == \"max\":\n",
+ "            da_result = da_result.max(axis=0, keepdims=True)\n",
+ "        elif proj_type == \"std\":\n",
+ "            da_result = da_result.std(axis=0, keepdims=True)\n",
  "\n",
  "        # Store denoised data\n",
- "        dask_store_zarr(data_basename + postfix_norm + zarr_ext, [\"images\"], [da_result], executor)\n",
- "\n",
- "\n",
- "    zip_zarr(data_basename + postfix_norm + zarr_ext, executor)\n",
- "\n",
- "\n",
- "if __IPYTHON__:\n",
- "    result_image_stack = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
- "\n",
- "    mplsv = plt.figure(FigureClass=MPLViewer)\n",
- "    mplsv.set_images(\n",
- "        result_image_stack,\n",
- "        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
- "        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
- "    )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Dictionary Learning\n",
- "\n",
- "* `n_components` (`int`): number of basis images in the dictionary.\n",
- "* `batchsize` (`int`): minibatch size to use.\n",
- "* `iters` (`int`): number of iterations to run before getting dictionary.\n",
- "* `lambda1` (`float`): weight for L<sup>1</sup> sparsity enforcement on sparse code.\n",
- "* `lambda2` (`float`): weight for L<sup>2</sup> sparsity enforcement on sparse code.\n",
- "\n",
- "<br>\n",
- "* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).\n",
- "* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "n_components = 50\n",
- "batchsize = 256\n",
- "iters = 100\n",
- "lambda1 = 0.2\n",
- "lambda2 = 0.0\n",
- "\n",
- "block_frames = 51\n",
- "norm_frames = 100\n",
+ "        dask_store_zarr(data_basename + postfix_dict + zarr_ext, [\"images\"], [da_result], executor)\n",
  "\n",
  "\n",
- "with get_executor(client) as executor:\n",
- "    dask_io_remove(data_basename + postfix_dict + zarr_ext, executor)\n",
- "\n",
- "\n",
- "result = LazyZarrDataset(data_basename + postfix_norm + zarr_ext, \"images\")\n",
- "block_shape = (block_frames,) + result.shape[1:]\n",
- "with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n",
- "    new_result = f2.create_dataset(\"images\", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)\n",
- "\n",
- "    result = par_generate_dictionary(block_shape)(\n",
- "        result,\n",
- "        n_components=n_components,\n",
- "        out=new_result,\n",
- "        **{\"sklearn.decomposition.dict_learning_online\" : {\n",
- "            \"n_jobs\" : 1,\n",
- "            \"n_iter\" : iters,\n",
- "            \"batch_size\" : batchsize,\n",
- "            \"alpha\" : lambda1\n",
- "        }\n",
- "        }\n",
- "    )\n",
- "\n",
- "    result_j = f2.create_dataset(\"images_j\", shape=new_result.shape, dtype=numpy.uint16, chunks=True)\n",
- "    par_norm_layer(num_frames=norm_frames)(result, out=result_j)\n",
- "\n",
- "\n",
- "with get_executor(client) as executor:\n",
  "    zip_zarr(data_basename + postfix_dict + zarr_ext, executor)\n",
  "\n",
  "\n",
  "if __IPYTHON__:\n",
- "    result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")\n",
+ "    result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")[...][...]\n",
  "\n",
  "    mplsv = plt.figure(FigureClass=MPLViewer)\n",
  "    mplsv.set_images(\n",
  "        result_image_stack,\n",
- "        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
- "        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
- "    )\n",
- "    mplsv.time_nav.stime.label.set_text(\"Basis Image\")"
+ "        vmin=result_image_stack.min(),\n",
+ "        vmax=result_image_stack.max()\n",
+ "    )"
  ]
 },
 {
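For reference, the dictionary-learning stage removed above reduced to a call into `sklearn.decomposition.dict_learning_online` with the parameters from the deleted cell. A rough, self-contained sketch of that step on toy data, using the 2017-era scikit-learn signature visible in the diff (`n_iter` has since been renamed `max_iter`; the real `par_generate_dictionary` also handled blocking and zarr storage, which this omits):

```python
import numpy as np
from sklearn.decomposition import dict_learning_online

# Toy movie: 300 frames of 32x32 pixels, flattened to (frames, pixels).
frames = np.random.rand(300, 32, 32).astype(np.float32)
X = frames.reshape(len(frames), -1)

n_components = 50
code, dictionary = dict_learning_online(
    X,
    n_components=n_components,
    alpha=0.2,       # the deleted cell's `lambda1` (L1 sparsity weight)
    batch_size=256,  # the deleted cell's `batchsize`
    n_iter=100,      # the deleted cell's `iters`
    n_jobs=1,
)

# Each dictionary atom reshapes back into one basis image,
# analogous to the "images" dataset the old cell stored.
basis_images = dictionary.reshape(n_components, 32, 32)
print(basis_images.shape)  # (50, 32, 32)
```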
