Skip to content

Commit

Permalink
Iterate on multi GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
bnorthan committed Dec 25, 2023
1 parent 913f1e0 commit c772db6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
40 changes: 32 additions & 8 deletions python/clij2fft/richardson_lucy_dask_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,30 +141,54 @@ def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circ

chunk_size = (img.shape[0], y_chunk_size, x_chunk_size)
print('chunk size is',chunk_size)
print('==========================================================================')

dimg = da.from_array(img,chunks=(img.shape[0], y_chunk_size, x_chunk_size))

from multiprocessing import Pool, current_process, Queue
queue = Queue()

for i in range(1):
queue.put(i)

#from dask.distributed import get_worker
if non_circulant:

def richardson_lucy_nc_dask_task(img, psf, numiterations, regularizationfactor=0, lib=None, block_info=None, block_id=None, thread_id=None):
print('block id is', block_id)
print('block info is', block_info)
print('thread id is', thread_id)
#print('worker is ', get_worker())
return richardson_lucy_nc(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)

try:
print()
gpu_num=queue.get()
print('start rlnc')
print('gpu num is', gpu_num)
print('block id is', block_id)
print('block info is', block_info)
print('thread id is', thread_id)
#print('worker is ', get_worker())
if block_id is None:
print('returning block id is None')
return None
result=richardson_lucy_nc(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)
print('end rlnc')
return result
except:
pass
finally:
print('putting gpu num back', gpu_num)
queue.put(gpu_num)
rl_func = richardson_lucy_nc_dask_task
else:
def richardson_lucy_dask_task(img, psf, numiterations, regularizationfactor=0, lib=None, block_info=None, block_id=None, thread_id=None):
print('block id is', block_id)
print('\nblock id is', block_id)
print('block info is', block_info)
print('thread id is', thread_id)
#print('worker is', get_worker())
return richardson_lucy(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)
return img#richardson_lucy(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)
rl_func = richardson_lucy_dask_task


out = dimg.map_overlap(rl_func, depth={0:0, 1:overlap, 2:overlap}, dtype=np.float32, psf=psf, numiterations=numiterations, regularizationfactor=regularizationfactor)
return out.compute(num_workers=1)
return out.compute(num_workers=4)



Expand Down
23 changes: 16 additions & 7 deletions python/clij2fft/test_richardson_lucy_dask.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
from richardson_lucy_dask import richardson_lucy_dask
from richardson_lucy_dask_multi_gpu import richardson_lucy_dask
from skimage.io import imread
import numpy as np
from matplotlib import pyplot as plt

img_name=r'D:\\images/images/Bars-G10-P15-stack-cropped.tif'
psf_name=r'D:\\images/images/PSF-Bars-stack-cropped.tif'

img_name=r'/home/bnorthan/images/deconvolution/Bars-G10-P15-stack.tif'
psf_name=r'/home/bnorthan/images/deconvolution/PSF-Bars-stack.tif'

img=imread(img_name)
print('image shape is',img.shape)

pad_z=50
pad_y=291
pad_x=700
mem_to_use=8
pad_y=50
pad_x=50
mem_to_use=1

img = np.pad(img, [(pad_z,pad_z),(pad_y, pad_y),(pad_x, pad_x)], mode = 'constant', constant_values = 0)
print('image shape is',img.shape)
psf=imread(psf_name)

decon=richardson_lucy_dask(img, psf, 100, 0.0001, mem_to_use=mem_to_use)
decon=richardson_lucy_dask(img, psf, 10, 0.0001, mem_to_use=mem_to_use)

fig, ax = plt.subplots(1,2)
ax[0].imshow(img.max(axis=0))
ax[0].set_title('img')

ax[1].imshow(decon.max(axis=0))
ax[1].set_title('deconvolution')

plt.imshow(decon.max(axis=0))
plt.show()
plt.show()

0 comments on commit c772db6

Please sign in to comment.