Find dictionary on thresholded data
jakirkham committed Sep 13, 2017
1 parent fe3381d commit 6cdc32a
Showing 1 changed file with 91 additions and 9 deletions.
nanshe_ipython.ipynb (100 changes: 91 additions & 9 deletions)
@@ -52,10 +52,11 @@
"postfix_f_f0 = \"_f_f0\"\n",
"postfix_wt = \"_wt\"\n",
"postfix_norm = \"_norm\"\n",
"postfix_dict = \"_dict\"\n",
"postfix_flat = \"_flat\"\n",
"postfix_cc = \"_cc\"\n",
"postfix_post = \"_post\"\n",
"postfix_thrd = \"_thrd\"\n",
"postfix_dict = \"_dict\"\n",
"postfix_rois = \"_rois\"\n",
"postfix_traces = \"_traces\"\n",
"postfix_proj = \"_proj\"\n",
@@ -1165,8 +1166,8 @@
"\n",
"\n",
"# Somehow we can't overwrite the file in the container so this is needed.\n",
"io_remove(data_basename + postfix_dict + zarr_ext)\n",
"io_remove(data_basename + postfix_dict + h5_ext)\n",
"io_remove(data_basename + postfix_flat + zarr_ext)\n",
"io_remove(data_basename + postfix_flat + h5_ext)\n",
"\n",
"\n",
"with open_zarr(data_basename + postfix_wt + zarr_ext, \"r\") as f:\n",
@@ -1189,7 +1190,7 @@
" da_result = da_result.std(axis=0, keepdims=True)\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n",
" with open_zarr(data_basename + postfix_flat + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_result.shape,\n",
@@ -1201,7 +1202,7 @@
" dask.distributed.progress(status, notebook=False)\n",
"\n",
"\n",
"zip_zarr(data_basename + postfix_dict + zarr_ext)\n",
"zip_zarr(data_basename + postfix_flat + zarr_ext)\n",
"\n",
"with h5py.File(data_basename + postfix_dict + h5_ext, \"w\") as f2:\n",
" with open_zarr(data_basename + postfix_dict + zarr_ext, \"r\") as f1:\n",
"with h5py.File(data_basename + postfix_flat + h5_ext, \"w\") as f2:\n",
" with open_zarr(data_basename + postfix_flat + zarr_ext, \"r\") as f1:\n",
" zarr_to_hdf5(f1, f2)\n",
"\n",
"\n",
"if __IPYTHON__:\n",
" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")[...][...]\n",
" result_image_stack = LazyZarrDataset(data_basename + postfix_flat + zarr_ext, \"images\")[...][...]\n",
"\n",
" mplsv = plt.figure(FigureClass=MPLViewer)\n",
" mplsv.set_images(\n",
Expand Down Expand Up @@ -1247,7 +1248,7 @@
"io_remove(data_basename + postfix_cc + h5_ext)\n",
"\n",
"\n",
"with open_zarr(data_basename + postfix_dict + zarr_ext, \"r\") as f:\n",
"with open_zarr(data_basename + postfix_flat + zarr_ext, \"r\") as f:\n",
" with get_executor(client) as executor:\n",
" imgs = f[\"images\"]\n",
" da_imgs = da.from_array(\n",
@@ -1455,6 +1456,87 @@
" )"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Dictionary Learning\n",
+"\n",
+"* `n_components` (`int`): number of basis images in the dictionary.\n",
+"* `batchsize` (`int`): minibatch size to use.\n",
+"* `iters` (`int`): number of iterations to run before returning the dictionary.\n",
+"* `lambda1` (`float`): weight for L<sup>1</sup> sparsity enforcement on the sparse code.\n",
+"* `lambda2` (`float`): weight for L<sup>2</sup> sparsity enforcement on the sparse code.\n",
+"\n",
+"<br>\n",
+"* `block_frames` (`int`): number of frames to work with in each full-frame block (run in parallel).\n",
+"* `norm_frames` (`int`): number of frames to use during normalization of each full-frame block (run in parallel)."
+]
+},
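(Aside: `lambda1` and `lambda2` correspond to the elastic-net weights in the usual online dictionary-learning objective, with sample $x$, dictionary $D$, and sparse code $\alpha$; a sketch of the standard formulation:)

$$\min_{D,\,\alpha}\ \tfrac{1}{2}\lVert x - D\alpha\rVert_2^2 + \lambda_1\lVert\alpha\rVert_1 + \tfrac{\lambda_2}{2}\lVert\alpha\rVert_2^2$$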
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"%%time\n",
+"\n",
+"\n",
+"n_components = 50\n",
+"batchsize = 256\n",
+"iters = 100\n",
+"lambda1 = 0.2\n",
+"lambda2 = 0.0\n",
+"\n",
+"block_frames = 51\n",
+"norm_frames = 100\n",
+"\n",
+"\n",
+"# Somehow we can't overwrite the file in the container so this is needed.\n",
+"io_remove(data_basename + postfix_dict + zarr_ext)\n",
+"io_remove(data_basename + postfix_dict + h5_ext)\n",
+"\n",
+"result = LazyZarrDataset(data_basename + postfix_thrd + zarr_ext, \"images\")\n",
+"block_shape = (block_frames,) + result.shape[1:]\n",
+"with open_zarr(data_basename + postfix_dict + zarr_ext, \"w\") as f2:\n",
+" new_result = f2.create_dataset(\"images\", shape=(n_components,) + result.shape[1:], dtype=result.dtype, chunks=True)\n",
+"\n",
+" result = par_generate_dictionary(block_shape)(\n",
+" result,\n",
+" n_components=n_components,\n",
+" out=new_result,\n",
+" **{\"sklearn.decomposition.dict_learning_online\" : {\n",
+" \"n_jobs\" : 1,\n",
+" \"n_iter\" : iters,\n",
+" \"batch_size\" : batchsize,\n",
+" \"alpha\" : lambda1\n",
+" }\n",
+" }\n",
+" )\n",
+"\n",
+" result_j = f2.create_dataset(\"images_j\", shape=new_result.shape, dtype=numpy.uint16, chunks=True)\n",
+" par_norm_layer(num_frames=norm_frames)(result, out=result_j)\n",
+"\n",
+"\n",
+"zip_zarr(data_basename + postfix_dict + zarr_ext)\n",
+"\n",
+"with h5py.File(data_basename + postfix_dict + h5_ext, \"w\") as f2:\n",
+" with open_zarr(data_basename + postfix_dict + zarr_ext, \"r\") as f1:\n",
+" zarr_to_hdf5(f1, f2)\n",
+"\n",
+"\n",
+"if __IPYTHON__:\n",
+" result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, \"images\")\n",
+"\n",
+" mplsv = plt.figure(FigureClass=MPLViewer)\n",
+" mplsv.set_images(\n",
+" result_image_stack,\n",
+" vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),\n",
+" vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()\n",
+" )\n",
+" mplsv.time_nav.stime.label.set_text(\"Basis Image\")"
+]
+},
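(Aside: stripped of the `par_generate_dictionary` / Zarr plumbing, the new cell boils down to online dictionary learning on flattened frames. A minimal standalone sketch using the same `sklearn.decomposition.dict_learning_online` call with the 2017-era keyword names (`n_iter`, `batch_size`) the cell passes through; `imgs` is an assumed `(frames, y, x)` NumPy array:)

```python
import numpy
from sklearn.decomposition import dict_learning_online

def generate_dictionary(imgs, n_components=50, batchsize=256, iters=100, lambda1=0.2):
    # Each frame is one sample; its pixels are the features.
    X = imgs.reshape(len(imgs), -1).astype(numpy.float64)
    # return_code=False returns only the dictionary (the basis images).
    D = dict_learning_online(
        X,
        n_components=n_components,
        alpha=lambda1,          # L1 weight on the sparse code
        n_iter=iters,           # number of minibatch iterations
        batch_size=batchsize,
        return_code=False,
        n_jobs=1,
    )
    # Reshape each learned atom back into an image: (n_components, y, x).
    return D.reshape((n_components,) + imgs.shape[1:])
```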
{
"cell_type": "markdown",
"metadata": {},