Skip to content

Commit

Permalink
Filling: do not override CUDA_VISIBLE_DEVICES if it is already set
Browse files Browse the repository at this point in the history
When CUDA_VISIBLE_DEVICES is already set in the environment, fall back to a single worker. Closes #11.
  • Loading branch information
aschampion committed Jun 20, 2017
1 parent b43c416 commit 4d0e841
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions diluvian/diluvian.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ def fill_subvolume_with_model(
# predicted labels.
conflict_count = np.full_like(prediction, 0, dtype=np.uint32)

def worker(worker_id, model_file, image, seeds, results, lock, revoked):
def worker(worker_id, set_devices, model_file, image, seeds, results, lock, revoked):
lock.acquire()
import tensorflow as tf

# Only make one GPU visible to Tensorflow so that it does not allocate
# all available memory on all devices.
# See: https://stackoverflow.com/questions/37893755
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(worker_id)
if set_devices:
# Only make one GPU visible to Tensorflow so that it does not allocate
# all available memory on all devices.
# See: https://stackoverflow.com/questions/37893755
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(worker_id)

with tf.device('/gpu:0'):
# Late import to avoid Keras import until TF bindings are set.
Expand Down Expand Up @@ -148,10 +149,18 @@ def stopping_callback(region):
dispatched_seeds.append(seed)
seed_queue.put(seed)

if 'CUDA_VISIBLE_DEVICES' in os.environ:
set_devices = False
num_workers = 1
logging.warn('Environment variable CUDA_VISIBLE_DEVICES is set, so only one worker can be used.\n'
'See https://github.com/aschampion/diluvian/issues/11')
else:
set_devices = True

workers = []
loading_lock = manager.Lock()
for worker_id in range(num_workers):
w = Process(target=worker, args=(worker_id, model_file, subvolume.image,
w = Process(target=worker, args=(worker_id, set_devices, model_file, subvolume.image,
seed_queue, results_queue, loading_lock, revoked_seeds))
w.start()
workers.append(w)
Expand Down

0 comments on commit 4d0e841

Please sign in to comment.