Skip to content

Commit

Permalink
add AMD specific includes in cuda_prelude.h (#3614)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3614

X-link: facebookresearch/FBGEMM#691

as title

Reviewed By: q10

Differential Revision: D68638427

fbshipit-source-id: 1daf07db4ab80cfa3c44480e8ac835e3822c60c8
  • Loading branch information
Bangsheng Tang authored and facebook-github-bot committed Jan 25, 2025
1 parent 1aff241 commit f9650f9
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,31 @@
#pragma once

#include <ATen/ATen.h>

#include <cuda.h>

#ifdef __HIP_PLATFORM_AMD__
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/PhiloxUtils.cuh>

#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h> // @manual
#else
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif
#include <cassert>

namespace {

inline int get_device_sm_cnt_() {
#ifdef __HIP_PLATFORM_AMD__
hipDeviceProp_t deviceProp;
hipGetDeviceProperties(&deviceProp, c10::hip::current_device());
return deviceProp.multiProcessorCount;
#else
cudaDeviceProp* deviceProp =
at::cuda::getDeviceProperties(c10::cuda::current_device());
return deviceProp->multiProcessorCount;
#endif
}

} // namespace
Expand Down

0 comments on commit f9650f9

Please sign in to comment.