Merge pull request #148 from jakirkham/ref_store_func
Refactor and simplify storage of Dask Arrays to Zarr
jakirkham authored Oct 23, 2017
2 parents d23e8b7 + 17a4e8e commit f7a26d0
Showing 2 changed files with 50 additions and 157 deletions.
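The gist of the refactor: every notebook cell that previously opened a Zarr store, created a dataset, rechunked, and called da.store by hand now delegates to a single dask_store_zarr helper added to nanshe_workflow/data.py. Below is a minimal before/after sketch of one such cell, reusing the notebook's existing names (data_basename, postfix_sub, zarr_ext, executor, and the Dask array da_imgs_sub); it mirrors the diff rather than adding anything new.

# Before: per-cell Zarr storage boilerplate
with open_zarr(data_basename + postfix_sub + zarr_ext, "w") as f2:
    result = f2.create_dataset(
        "images",
        shape=da_imgs_sub.shape,
        dtype=da_imgs_sub.dtype,
        chunks=True
    )
    da_imgs_sub = da_imgs_sub.rechunk(result.chunks)
    status = executor.compute(da.store(da_imgs_sub, result, lock=False, compute=False))
    dask.distributed.progress(status, notebook=False)
    print("")

# After: one call to the shared helper
dask_store_zarr(data_basename + postfix_sub + zarr_ext, ["images"], [da_imgs_sub], executor)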
178 changes: 21 additions & 157 deletions nanshe_ipython.ipynb
@@ -197,7 +197,7 @@
" from nanshe.imp.segment import generate_dictionary\n",
"\n",
" import nanshe_workflow\n",
" from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, zip_zarr, open_zarr, DataBlocks, LazyZarrDataset\n",
" from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, dask_store_zarr, zip_zarr, open_zarr, DataBlocks, LazyZarrDataset\n",
" from nanshe_workflow.par import get_executor\n",
"\n",
"zarr.blosc.set_nthreads(1)\n",
@@ -300,27 +300,14 @@
"with get_executor(client) as executor:\n",
" dask_io_remove(data_basename + zarr_ext, executor)\n",
"\n",
" if data_ext == tiff_ext:\n",
" a = dask_imread.imread(data)\n",
" elif data_ext == h5_ext:\n",
" a = dask_load_hdf5(data, dataset)\n",
"\n",
" with open_zarr(data_basename + zarr_ext, \"w\") as f1:\n",
" if data_ext == tiff_ext:\n",
" a = dask_imread.imread(data)\n",
" elif data_ext == h5_ext:\n",
" a = dask_load_hdf5(data, dataset)\n",
"\n",
" d = f1.create_dataset(\n",
" dataset,\n",
" shape=a.shape,\n",
" dtype=a.dtype,\n",
" chunks=True\n",
" )\n",
" a = a.rechunk(d.chunks)\n",
" status = executor.compute(da.store(a, d, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
"\n",
" del a\n",
" del d\n",
" dask_store_zarr(data_basename + zarr_ext, [dataset], [a], executor)\n",
"\n",
" del a\n",
"\n",
" zip_zarr(data_basename + zarr_ext, executor)"
]
@@ -393,17 +380,7 @@
" da_imgs_trim = da_imgs[front:len(da_imgs)-back]\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_trim + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_imgs_trim.shape,\n",
" dtype=da_imgs_trim.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_trim = da_imgs_trim.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_imgs_trim, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_trim + zarr_ext, [\"images\"], [da_imgs_trim], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_trim + zarr_ext, executor)\n",
@@ -478,17 +455,7 @@
" da_imgs_filt += da_imgs.min() - da_imgs_filt.min()\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_dn + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_imgs_filt.shape,\n",
" dtype=da_imgs_filt.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_filt = da_imgs_filt.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_imgs_filt, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_dn + zarr_ext, [\"images\"], [da_imgs_filt], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_dn + zarr_ext, executor)\n",
@@ -741,17 +708,7 @@
" da_imgs_trunc = da.stack(da_imgs_trunc)\n",
"\n",
" # Store registered data\n",
" with open_zarr(data_basename + postfix_reg + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_imgs_trunc.shape,\n",
" dtype=da_imgs_trunc.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_trunc = da_imgs_trunc.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_imgs_trunc, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_reg + zarr_ext, [\"images\"], [da_imgs_trunc], executor)\n",
"\n",
" # Free truncated frames\n",
" del da_imgs_trunc\n",
@@ -812,55 +769,12 @@
" da_imgs_proj_std = da.sqrt(da_imgs_proj_std)\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_proj + zarr_ext, \"w\") as f2:\n",
" statuses = []\n",
"\n",
" zarr_proj_hmean = f2.create_dataset(\n",
" \"hmean\",\n",
" shape=da_imgs_proj_hmean.shape,\n",
" dtype=da_imgs_proj_hmean.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_proj_hmean = da_imgs_proj_hmean.rechunk(zarr_proj_hmean.chunks)\n",
" statuses.append(executor.compute(\n",
" da.store(da_imgs_proj_hmean, zarr_proj_hmean, lock=False, compute=False\n",
" )))\n",
"\n",
" zarr_proj_max = f2.create_dataset(\n",
" \"max\",\n",
" shape=da_imgs_proj_max.shape,\n",
" dtype=da_imgs_proj_max.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_proj_max = da_imgs_proj_max.rechunk(zarr_proj_max.chunks)\n",
" statuses.append(executor.compute(\n",
" da.store(da_imgs_proj_max, zarr_proj_max, lock=False, compute=False\n",
" )))\n",
"\n",
" zarr_proj_mean = f2.create_dataset(\n",
" \"mean\",\n",
" shape=da_imgs_proj_mean.shape,\n",
" dtype=da_imgs_proj_mean.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_proj_mean = da_imgs_proj_mean.rechunk(zarr_proj_mean.chunks)\n",
" statuses.append(executor.compute(\n",
" da.store(da_imgs_proj_mean, zarr_proj_mean, lock=False, compute=False\n",
" )))\n",
"\n",
" zarr_proj_std = f2.create_dataset(\n",
" \"std\",\n",
" shape=da_imgs_proj_std.shape,\n",
" dtype=da_imgs_proj_std.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_proj_std = da_imgs_proj_std.rechunk(zarr_proj_std.chunks)\n",
" statuses.append(executor.compute(\n",
" da.store(da_imgs_proj_std, zarr_proj_std, lock=False, compute=False\n",
" )))\n",
"\n",
" dask.distributed.progress(statuses, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(\n",
" data_basename + postfix_proj + zarr_ext,\n",
" [\"hmean\", \"max\", \"mean\", \"std\"],\n",
" [da_imgs_proj_hmean, da_imgs_proj_max, da_imgs_proj_mean, da_imgs_proj_std],\n",
" executor\n",
" )\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_proj + zarr_ext, executor)"
@@ -908,17 +822,7 @@
" da_imgs_sub -= da_imgs_sub.min()\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_sub + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_imgs_sub.shape,\n",
" dtype=da_imgs_sub.dtype,\n",
" chunks=True\n",
" )\n",
" da_imgs_sub = da_imgs_sub.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_imgs_sub, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_sub + zarr_ext, [\"images\"], [da_imgs_sub], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_sub + zarr_ext, executor)\n",
@@ -1002,17 +906,7 @@
" )\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_f_f0 + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_result.shape,\n",
" dtype=da_result.dtype,\n",
" chunks=True\n",
" )\n",
" da_result = da_result.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_result, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_f_f0 + zarr_ext, [\"images\"], [da_result], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_f_f0 + zarr_ext, executor)\n",
@@ -1078,17 +972,7 @@
" )\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_wt + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_result.shape,\n",
" dtype=da_result.dtype,\n",
" chunks=True\n",
" )\n",
" da_result = da_result.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_result, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_wt + zarr_ext, [\"images\"], [da_result], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_wt + zarr_ext, executor)\n",
@@ -1153,17 +1037,7 @@
" da_result = renormalized_images(da_imgs_flt_shift)\n",
"\n",
" # Store denoised data\n",
" with open_zarr(data_basename + postfix_norm + zarr_ext, \"w\") as f2:\n",
" result = f2.create_dataset(\n",
" \"images\",\n",
" shape=da_result.shape,\n",
" dtype=da_result.dtype,\n",
" chunks=True\n",
" )\n",
" da_result = da_result.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_result, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_norm + zarr_ext, [\"images\"], [da_result], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_norm + zarr_ext, executor)\n",
@@ -1419,17 +1293,7 @@
" da_result = compute_traces(da_images, da_masks)\n",
"\n",
" # Store traces\n",
" with open_zarr(data_basename + postfix_traces + zarr_ext, \"w\") as fh_traces:\n",
" result = fh_traces.create_dataset(\n",
" \"traces\",\n",
" shape=da_result.shape,\n",
" dtype=da_result.dtype,\n",
" chunks=True\n",
" )\n",
" da_result = da_result.rechunk(result.chunks)\n",
" status = executor.compute(da.store(da_result, result, lock=False, compute=False))\n",
" dask.distributed.progress(status, notebook=False)\n",
" print(\"\")\n",
" dask_store_zarr(data_basename + postfix_traces + zarr_ext, [\"traces\"], [da_result], executor)\n",
"\n",
"\n",
" zip_zarr(data_basename + postfix_traces + zarr_ext, executor)\n",
29 changes: 29 additions & 0 deletions nanshe_workflow/data.py
@@ -177,6 +177,35 @@ def _read_chunk(fn, dn, idx):
return a


def dask_store_zarr(filename, datasetnames, datasets, executor):
if len(datasetnames) != len(datasets):
raise ValueError(
"Need `datasetnames` and `datasets` to have the same length."
)

with open_zarr(filename, "w") as fh:
statuses = []

for each_datasetname, each_dataset in izip(datasetnames, datasets):
each_dataset = dask.array.asarray(each_dataset)

each_zarr_array = fh.create_dataset(
each_datasetname,
shape=each_dataset.shape,
dtype=each_dataset.dtype,
chunks=True
)

each_dataset = each_dataset.rechunk(each_zarr_array.chunks)

statuses.append(executor.compute(dask.array.store(
each_dataset, each_zarr_array, lock=False, compute=False
)))

dask.distributed.progress(statuses, notebook=False)
print("")


def save_tiff(fn, a):
if os.path.exists(fn):
os.remove(fn)
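For reference, a minimal usage sketch of the new helper outside the notebook. This is an illustration only: the file name example.zarr, the arrays a and b, and the pre-existing distributed client are assumptions. As the diff above shows, dask_store_zarr opens the Zarr file once, creates one dataset per array, rechunks each array to the dataset's chunks, and tracks all stores with a single progress bar.

import dask.array as da

from nanshe_workflow.data import dask_store_zarr
from nanshe_workflow.par import get_executor

# Two hypothetical Dask arrays to persist side by side.
a = da.random.random((100, 64, 64), chunks=(10, 64, 64))
b = a.mean(axis=0)

with get_executor(client) as executor:  # client: an existing dask.distributed Client
    # Writes datasets "images" and "mean" into example.zarr in one call;
    # each array is rechunked to match the Zarr chunking before being stored.
    dask_store_zarr("example.zarr", ["images", "mean"], [a, b], executor)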
