diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 29af1224e4..fa4be5841c 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -158,9 +158,9 @@ def initialize(self, context): self.map_location + ":" + str(properties.get("gpu_id")) ) elif ( - torch.xpu.is_available() + os.environ.get("TS_IPEX_GPU_ENABLE", "false") == "true" and properties.get("gpu_id") is not None - and os.environ.get("TS_IPEX_GPU_ENABLE", "false") == "true" + and torch.xpu.is_available() ): self.map_location = "xpu" self.device = torch.device(