Skip to content

Commit

Permalink
Merge pull request #726 from HiroakiMikami/fix-profiler
Browse files Browse the repository at this point in the history
Set current_device of CUDAWorker thread
  • Loading branch information
kmaehashi authored Jul 24, 2023
2 parents 94dbe93 + abad82e commit 8377465
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions pytorch_pfn_extras/profiler/_time_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def initialize(self) -> None:
return
self._queue = queue.Queue(self._max_queue_size)
self._events = queue.Queue(self._max_queue_size * 2)
self._thread = threading.Thread(target=self._worker, daemon=True)
self._thread = threading.Thread(
target=self._worker,
args=(torch.cuda.current_device(),),
daemon=True,
)
self._thread.start()
self._initialized = True
self._thread_exited = False
Expand Down Expand Up @@ -156,9 +160,10 @@ def put(
assert not self._thread_exited
self._queue.put((name, events))

def _worker(self) -> None:
def _worker(self, device_id: int) -> None:
assert self._queue is not None
assert self._events is not None
torch.cuda.set_device(device_id)
while True:
try:
v = self._queue.get()
Expand All @@ -169,6 +174,8 @@ def _worker(self) -> None:
self._queue.task_done()
break
name, (begin, end) = v
assert begin.device == end.device
assert begin.device.index == torch.cuda.current_device()
end.synchronize() # type: ignore[no-untyped-call]
t_ms = begin.elapsed_time(end) # type: ignore[no-untyped-call]
self._add(name, t_ms / 1000)
Expand Down

0 comments on commit 8377465

Please sign in to comment.