diff --git a/pytorch_pfn_extras/profiler/_time_summary.py b/pytorch_pfn_extras/profiler/_time_summary.py
index aa5530bf0..f32e641ca 100644
--- a/pytorch_pfn_extras/profiler/_time_summary.py
+++ b/pytorch_pfn_extras/profiler/_time_summary.py
@@ -128,7 +128,11 @@ def initialize(self) -> None:
             return
         self._queue = queue.Queue(self._max_queue_size)
         self._events = queue.Queue(self._max_queue_size * 2)
-        self._thread = threading.Thread(target=self._worker, daemon=True)
+        self._thread = threading.Thread(
+            target=self._worker,
+            args=(torch.cuda.current_device(),),
+            daemon=True,
+        )
         self._thread.start()
         self._initialized = True
         self._thread_exited = False
@@ -156,9 +160,10 @@ def put(
         assert not self._thread_exited
         self._queue.put((name, events))
 
-    def _worker(self) -> None:
+    def _worker(self, device_id: int) -> None:
         assert self._queue is not None
         assert self._events is not None
+        torch.cuda.set_device(device_id)
         while True:
             try:
                 v = self._queue.get()
@@ -169,6 +174,8 @@ def _worker(self) -> None:
             self._queue.task_done()
             break
         name, (begin, end) = v
+        assert begin.device == end.device
+        assert begin.device.index == torch.cuda.current_device()
         end.synchronize()  # type: ignore[no-untyped-call]
         t_ms = begin.elapsed_time(end)  # type: ignore[no-untyped-call]
         self._add(name, t_ms / 1000)
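
Note (not part of the patch): a minimal sketch of why the worker thread needs an explicit torch.cuda.set_device(). PyTorch tracks the current CUDA device per thread, so a freshly spawned thread defaults to device 0 regardless of what the launching thread selected; passing the launcher's device id and re-selecting it in the thread keeps both sides on the same GPU. The main/worker/seen names below are illustrative only, and the check is skipped on machines with fewer than two GPUs.

import threading

import torch


def main() -> None:
    if torch.cuda.device_count() < 2:
        return  # the demonstration needs at least two GPUs

    torch.cuda.set_device(1)  # main thread now targets cuda:1
    main_device = torch.cuda.current_device()

    seen = {}

    def worker(device_id: int) -> None:
        # Without this call the new thread would report device 0,
        # because the current CUDA device is tracked per thread.
        torch.cuda.set_device(device_id)
        seen["worker"] = torch.cuda.current_device()

    t = threading.Thread(target=worker, args=(main_device,), daemon=True)
    t.start()
    t.join()

    assert seen["worker"] == main_device  # both threads agree on cuda:1


if __name__ == "__main__":
    main()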