diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index aa7881b30aa..d638307b320 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -14,6 +14,7 @@ load(
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
+    "if_not_windows",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load(
@@ -605,8 +606,9 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
+    ] + if_not_windows([
         "@nvtx_archive//:nvtx",
-    ] + if_cuda_is_configured([
+    ]) + if_cuda_is_configured([
         "//tensorflow/stream_executor/cuda:cuda_stream",
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
         "//tensorflow/core/platform/default/build_config:cudnn_plugin",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 18b7713fe2a..10984b31462 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3287,9 +3287,10 @@ tf_cuda_library(
         "//third_party/eigen3",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/profiler/lib:traceme",
-        "@nvtx_archive//:nvtx",
         "//tensorflow/core/profiler/internal:traceme_recorder",
-    ] + mkl_deps(),
+    ] + if_not_windows([
+        "@nvtx_archive//:nvtx",
+    ]) + mkl_deps(),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index a2539b0a984..ebc756bb0fc 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -3,6 +3,7 @@ load(
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
+    "if_not_windows",
 )
 load(
     "//third_party/mkl:build_defs.bzl",
@@ -203,9 +204,10 @@ tf_cuda_library(
             "//tensorflow/core:protos_all_cc",
             "//tensorflow/core/profiler/lib:traceme",
             "//tensorflow/core/grappler/optimizers:meta_optimizer",
-            "@nvtx_archive//:nvtx",
         ],
-    }),
+    }) + if_not_windows([
+        "@nvtx_archive//:nvtx",
+    ]),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ecc62c5d70e..a82effd2bb2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1086,7 +1086,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
                                    data_format_str);
   }
   const int rank =
-      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
@@ -1155,7 +1155,7 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
                                    data_format_str);
   }
   const int rank =
-      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle y_backprop;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
   ShapeHandle x;
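The `or` → `||` rewrites above (and the matching `and`/`not` changes in the files that follow) are Windows portability fixes: `and`, `or`, and `not` are standard C++ alternative tokens, but MSVC only honors them under `/permissive-` or after including `<ciso646>`, so portable code spells out `&&`, `||`, and `!`. A minimal standalone illustration (not TF code):

```cpp
#include <string>

// MSVC with default flags rejects the alternative tokens; the symbolic
// operators compile on every toolchain.
int RankFor(const std::string& fmt) {
  // return (fmt == "NDHWC" or fmt == "NCDHW") ? 5 : 4;  // breaks on MSVC
  return (fmt == "NDHWC" || fmt == "NCDHW") ? 5 : 4;     // portable
}
```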
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
index f446fb23f62..cd487b0fbe8 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc
@@ -81,7 +81,7 @@ inline bool NumConvOnDeviceWithDataTypeOverThreshold(
   for (const auto& node : context.graph_view->GetNodes()) {
     const auto* node_def = node.node();
 
-    if (!IsConv2D(*node_def) and !IsConv3D(*node_def)) {
+    if (!IsConv2D(*node_def) && !IsConv3D(*node_def)) {
       continue;
     }
     const string& device_name =
@@ -401,7 +401,7 @@ Status PrintDebugLogs(string suffix, GraphDef* graph_) {
   TF_RETURN_IF_ERROR(ReadBoolFromEnvVar(
       "TF_ENABLE_LAYOUT_OPTIMIZE_GRAPH_REWRITE_LOG",
       /*default_value=*/false, &allow_print));
-  if (not allow_print) return Status::OK();
+  if (!allow_print) return Status::OK();
 
   string prepend_path = "/tmp/logs/";
   if (prepend_path.empty()) return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 9c848d5b868..f6734d7c5bd 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -292,7 +292,7 @@ Status Transposer::CreateConstPermNode(TransposeContext* context,
   node.mutable_attr()->insert({"dtype", attr_data_type});
 
   AttrValue attr_tensor;
-  Tensor tensor(DT_INT32, TensorShape({permutation.size()}));
+  Tensor tensor(DT_INT32, TensorShape({(long long)permutation.size()}));
   for (int i = 0; i < permutation.size(); i++) {
     tensor.flat<int32>()(i) = permutation[i];
   }
@@ -728,7 +728,7 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -748,7 +748,7 @@ Status BiasAddTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsBiasAdd(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   if (!ShouldProcess(*context, *node)) {
@@ -789,7 +789,7 @@ Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
   DCHECK(IsBiasAddGrad(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   if (!ShouldProcess(*context, *node)) {
@@ -962,7 +962,7 @@ Status FusedBatchNormGradTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsFusedBatchNormGrad(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1335,7 +1335,7 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
   DCHECK(IsConcat(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1518,7 +1518,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1591,7 +1591,7 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
   DCHECK(IsShape(*node->node()));
   const int rank = GetFaninPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1636,7 +1636,7 @@ Status SliceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
   DCHECK(IsSlice(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
@@ -1907,7 +1907,7 @@ Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                           utils::MutableNodeView* node) {
   DCHECK(IsUnaryGrad(*node->node()));
   const int rank = GetFanoutPortRank(*node, 0);
-  if (rank != 4 and rank != 5) {
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
   ScopedDataFormatUpgrader data_format_upgrader(context, rank);
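The `(long long)` cast in `CreateConstPermNode` addresses a braced-initializer narrowing error: `permutation.size()` returns an unsigned `size_t`, while `TensorShape`'s initializer-list constructor takes signed 64-bit dimensions, and list-initialization makes that unsigned-to-signed conversion a hard error (MSVC reports C2397). A standalone sketch of the failure mode, assuming nothing beyond the standard library:

```cpp
#include <vector>

int main() {
  std::vector<int> permutation = {0, 2, 3, 1};
  // size_t -> long long inside a braced initializer is a narrowing
  // conversion, which list-initialization forbids:
  // long long dims[] = {permutation.size()};          // error: narrowing
  long long dims[] = {(long long)permutation.size()};  // explicit cast: OK
  return dims[0] == 4 ? 0 : 1;
}
```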
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 25107b1d768..33628531382 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1284,7 +1284,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   Status status;
 
   string x_format = fused_node.attr().at(kDataFormat).s();
-  if (x_format == "NCHW" or x_format == "NCDHW") {
+  if (x_format == "NCHW" || x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index be618180c2f..54e8189ee3c 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1035,7 +1035,7 @@ class FusedBatchNormOpBase : public OpKernel {
     const Tensor& side_input =
         has_side_input_ ? context->input(5) : empty_side_input_;
 
-    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
@@ -1209,10 +1209,10 @@ class FusedBatchNormGradOpBase : public OpKernel {
     // saves inverted variance.
     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
-    OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
+    OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         y_backprop.shape().DebugString()));
-    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                 errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
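The "Need to reshape the last 4 inputs" comment in the remapper hunk refers to the scale, offset, mean, and variance inputs: each is a 1-D `[C]` vector, and for channel-major layouts (`NCHW`/`NCDHW`) they must effectively become `[1, C, 1, 1]` so elementwise ops broadcast over the channel axis. A standalone sketch of that broadcast, with made-up sizes and values (not TF code):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Per-channel batch-norm parameters applied to an NCHW tensor. The 1-D [C]
// vectors are indexed only by c, i.e. they behave like a [1, C, 1, 1] shape
// broadcast over N, H, and W.
int main() {
  const int N = 1, C = 2, H = 2, W = 2;
  const float eps = 1e-3f;
  std::vector<float> x(N * C * H * W, 1.0f);
  const std::vector<float> scale = {2.0f, 3.0f}, offset = {0.5f, -0.5f};
  const std::vector<float> mean = {0.0f, 1.0f}, var = {1.0f, 4.0f};
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)  // channel axis is dimension 1 in NCHW
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w) {
          const int i = ((n * C + c) * H + h) * W + w;
          x[i] = (x[i] - mean[c]) / std::sqrt(var[c] + eps) * scale[c] +
                 offset[c];
        }
  std::printf("channel 0 -> %.3f, channel 1 -> %.3f\n", x[0], x[H * W]);
  return 0;
}
```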
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
index b4c6c706ff3..48179422eb2 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
@@ -149,27 +149,6 @@ __device__ EIGEN_STRONG_INLINE void ClearBit(T* bit_mask, int bit) {
   atomicAnd(bit_mask + bin, ~(T(1) << (bit & kRemainderMask)));
 }
 
-__global__ void FlipBoxes(Box* boxes, const int* num_batch_boxes,
-                          const int* box_strides, const int batch_size) {
-  // for (int b = 0; b < batch_size; ++b) {
-  //   int box_offset = box_strides[b];
-  for (const int y : CudaGridRangeY(batch_size)) {
-    int box_offset = box_strides[y];
-    Box* curr_boxes = boxes + box_offset;
-    // if (threadIdx.x == 0) {
-    //   printf(" FBx batch=%d, box_offset=%d, num_batch_boxes=%d boxes@ %p \n",
-    //   y,
-    //   box_offset, num_batch_boxes[y],curr_boxes);
-    // }
-    for (int i : GpuGridRangeX(num_batch_boxes[y])) {
-      Flipped(curr_boxes[i]);
-    }
-  }
-  // }
-}
-
-
 // Produce a global bitmask (result_mask) of selected boxes from bitmask
 // generated by NMSKernel Abort early if max_boxes boxes are selected.
 // Bitmask is num_boxes*bit_mask_len bits indicating whether to keep or
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 38f3475c58e..be0ec72a36f 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -44,7 +44,7 @@ namespace internal {
 // Eventually absl::strings will have native support for this and we will be
 // able to completely remove PrepareForStrCat().
 template <typename T>
-typename std::enable_if::value,
+typename std::enable_if::value,
                         string>::type
 PrepareForStrCat(const T& t) {
   std::stringstream ss;
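In the `errors.h` hunk the contents of the `std::enable_if<...>` angle brackets did not survive extraction (both the removed and the added line flatten to the same text), so the exact change is not recoverable here; only the shape of the edit is. For reference, upstream TensorFlow declares this helper with an `enable_if` over convertibility to `strings::AlphaNum`. A self-contained sketch of that SFINAE dispatch, with `AlphaNum` reduced to a stand-in type:

```cpp
#include <sstream>
#include <string>
#include <type_traits>

struct AlphaNum {  // stand-in for strings::AlphaNum
  AlphaNum(const std::string&) {}
  AlphaNum(int) {}
};

// Selected for types StrCat cannot consume directly: stringify via a stream.
template <typename T>
typename std::enable_if<!std::is_convertible<T, AlphaNum>::value,
                        std::string>::type
PrepareForStrCat(const T& t) {
  std::stringstream ss;
  ss << t;
  return ss.str();
}

// Selected for types StrCat already understands: pass through untouched.
template <typename T>
typename std::enable_if<std::is_convertible<T, AlphaNum>::value, T>::type
PrepareForStrCat(const T& t) {
  return t;
}

struct Widget { int id; };
std::ostream& operator<<(std::ostream& os, const Widget& w) {
  return os << "Widget#" << w.id;
}

int main() {
  std::string a = PrepareForStrCat(Widget{7});  // streamed: "Widget#7"
  int b = PrepareForStrCat(42);                 // passed through as int
  return (!a.empty() && b == 42) ? 0 : 1;
}
```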
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
index ad45878cee8..4440fb5f143 100644
--- a/tensorflow/core/lib/io/path.cc
+++ b/tensorflow/core/lib/io/path.cc
@@ -35,6 +35,8 @@ namespace tensorflow {
 namespace io {
 namespace internal {
 
+const char kPathSep[] = "/";
+
 string JoinPathImpl(std::initializer_list<StringPiece> paths) {
   string result;
 
@@ -46,18 +48,12 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
       continue;
     }
 
-    if (result[result.size() - 1] == '/') {
-      if (IsAbsolutePath(path)) {
-        strings::StrAppend(&result, path.substr(1));
-      } else {
-        strings::StrAppend(&result, path);
-      }
+    if (IsAbsolutePath(path)) path = path.substr(1);
+
+    if (result[result.size() - 1] == kPathSep[0]) {
+      strings::StrAppend(&result, path);
     } else {
-      if (IsAbsolutePath(path)) {
-        strings::StrAppend(&result, path);
-      } else {
-        strings::StrAppend(&result, "/", path);
-      }
+      strings::StrAppend(&result, kPathSep, path);
     }
   }
 
@@ -126,6 +122,7 @@ bool FixBazelEnvPath(const char* path, string* out) {
 
   return true;
 }
+
 }  // namespace internal
 
 bool IsAbsolutePath(StringPiece path) {
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index ac91b79a07f..75e5b31f3ff 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -570,10 +570,4 @@ Status ReadTextOrBinaryProto(Env* env, const string& fname,
   return ReadBinaryProto(env, fname, proto);
 }
 
-int setenv(const char* name, const char* value, int overwrite) {
-  return ::setenv(name, value, overwrite);
-}
-
-int unsetenv(const char* name) { return ::unsetenv(name); }
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/nvtx.h b/tensorflow/core/platform/nvtx.h
index ae0802df456..c951ec1cc86 100755
--- a/tensorflow/core/platform/nvtx.h
+++ b/tensorflow/core/platform/nvtx.h
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_PLATFORM_NVTX_H_
 #define TENSORFLOW_CORE_PLATFORM_NVTX_H_
 
+#ifdef _WIN32
+#include "cuda/include/nvtx3/nvToolsExt.h"
+#else
 #include "third_party/nvtx3/nvToolsExt.h"
+#endif
 
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
index ba2a979df16..a9975f66602 100644
--- a/tensorflow/core/platform/posix/env.cc
+++ b/tensorflow/core/platform/posix/env.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <dirent.h>
 #include <dlfcn.h>
 #include <errno.h>
+#include <stdlib.h>
 #include <fcntl.h>
 #include <fnmatch.h>
 #include <stdio.h>
@@ -258,4 +259,10 @@ void PosixEnv::GetLocalTempDirectories(std::vector<string>* list) {
   }
 }
 
+int setenv(const char* name, const char* value, int overwrite) {
+  return ::setenv(name, value, overwrite);
+}
+
+int unsetenv(const char* name) { return ::unsetenv(name); }
+
 }  // namespace tensorflow
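The header names in the `posix/env.cc` include hunk were stripped by extraction; `<stdlib.h>` for the added line follows from the `::setenv`/`::unsetenv` calls it enables, while the surrounding context includes are reconstructed from upstream and are best-effort. The move itself exists because `::setenv` and `::unsetenv` are POSIX-only, so defining the wrappers in the platform-generic `env.cc` broke the Windows build. A portable sketch of the same wrapper; the `_WIN32` branch is an assumption mirroring what a Windows counterpart would do with the MSVC CRT:

```cpp
#include <cstdlib>

int SetEnvVar(const char* name, const char* value, int overwrite) {
#ifdef _WIN32
  // ::setenv does not exist on Windows; emulate the overwrite flag.
  if (!overwrite && std::getenv(name) != nullptr) return 0;
  return _putenv_s(name, value);  // MSVC CRT analogue of ::setenv
#else
  return ::setenv(name, value, overwrite);
#endif
}
```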
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 54be76375c9..ade51b65e8b 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -490,7 +490,7 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
 }
 
 template <typename InType, typename OutType, typename Functor>
-void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
+void BinaryUFunc(char** args, const npy_intp* dimensions, const npy_intp* steps,
                  void* data) {
   const char* i0 = args[0];
   const char* i1 = args[1];
@@ -505,11 +505,17 @@ void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
   }
 }
 
+// Numpy changed const-ness of PyUFuncGenericFunction, provide overload.
 template <typename Functor>
 void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
                   void* data) {
   BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
 }
+template <typename Functor>
+void CompareUFunc(char** args, const npy_intp* dimensions,
+                  const npy_intp* steps, void* data) {
+  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
+}
 
 struct Bfloat16EqFunctor {
   npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 023f0ce6326..395754581ad 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1278,9 +1278,17 @@ port::Status CheckAndFetchProjectionWeights(
   cudnnDataType_t data_type;
 #if CUDNN_VERSION >= 8000
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
+      /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+      /*hiddenSize=*/&hidden_size_v,
+      /*numLayers=*/&num_layers_v,
+      /*dropoutDesc=*/&dropout_desc,
+      /*inputMode=*/&input_mode,
+      /*direction=*/&direction,
+      /*mode=*/&mode,
+      /*algo=*/&algo,
+      /*mathPrec=*/&data_type));
 #else
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
-#endif
       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
       /*hiddenSize=*/&hidden_size_v,
       /*numLayers=*/&num_layers_v,
@@ -1290,6 +1298,7 @@ port::Status CheckAndFetchProjectionWeights(
       /*mode=*/&mode,
       /*algo=*/&algo,
       /*dataType=*/&data_type));
+#endif
   int rec_proj_size_v;
   int out_proj_size_v;
   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
diff --git a/tensorflow/stream_executor/cuda/cudnn_stub.cc b/tensorflow/stream_executor/cuda/cudnn_stub.cc
index 073ba3ffd00..e30f749897e 100644
--- a/tensorflow/stream_executor/cuda/cudnn_stub.cc
+++ b/tensorflow/stream_executor/cuda/cudnn_stub.cc
@@ -53,7 +53,8 @@ cudnnStatus_t GetSymbolNotFoundError() { return CUDNN_STATUS_INTERNAL_ERROR; }
 #include "tensorflow/stream_executor/cuda/cudnn_6_0.inc"
 #elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 1
 #include "tensorflow/stream_executor/cuda/cudnn_7_0.inc"
-#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 3
+// 2 instead of 3: see https://github.com/tensorflow/tensorflow/issues/32350
+#elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 2
 #include "tensorflow/stream_executor/cuda/cudnn_7_1.inc"
 #elif CUDNN_MAJOR == 7 && CUDNN_MINOR < 4
 #include "tensorflow/stream_executor/cuda/cudnn_7_3.inc"
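The `CompareUFunc` overload pair added in `bfloat16.cc` above mirrors a NumPy API change: before NumPy 1.19 a ufunc inner loop was `void (*)(char**, npy_intp*, npy_intp*, void*)`, and from 1.19 on the two `npy_intp*` parameters are `const`-qualified. Providing both signatures lets the same source build against either set of headers, because only the matching overload converts to `PyUFuncGenericFunction`. A toy demonstration without the NumPy headers (the `npy_intp` alias is a stand-in):

```cpp
#include <cstdio>

using npy_intp = long;  // stand-in; the real type comes from numpy headers

using OldLoop = void (*)(char**, npy_intp*, npy_intp*, void*);
using NewLoop = void (*)(char**, const npy_intp*, const npy_intp*, void*);

void CompareUFunc(char**, npy_intp*, npy_intp*, void*) {}
void CompareUFunc(char**, const npy_intp*, const npy_intp*, void*) {}

int main() {
  OldLoop pre119 = CompareUFunc;   // resolves to the non-const overload
  NewLoop post119 = CompareUFunc;  // resolves to the const overload
  std::printf("%d\n", pre119 != nullptr && post119 != nullptr);
  return 0;
}
```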
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 382e4bd1fd2..8b4fafdd023 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -73,7 +73,6 @@ class AlgorithmDesc;
 class StreamExecutor;
 class ScratchAllocator;
 
-enum BatchNormalizationKind;
 
 // Convert a type to the corresponding QuantizedActivationMode.
 template <typename ElementType>
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b29610c021..f04bd260ddd 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -168,11 +168,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
         name = "eigen_archive",
         build_file = clean_dep("//third_party:eigen.BUILD"),
         patch_file = clean_dep("//third_party/eigen3:neon_casting_and_gpu_packet.patch"),
-        sha256 = "2f046557f4093becf51b44c6339873c18e2f1ea55c4b3f3a08b7d15a1d9c6e5b",  # SHARED_EIGEN_SHA
-        strip_prefix = "eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced",
+        sha256 = "bacd9508f8a636a616eef363d7f8d0f6da4c87b935132030a03793884a6ab4f1",  # SHARED_EIGEN_SHA
+        strip_prefix = "eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz",
-            "https://gitlab.com/libeigen/eigen/-/archive/4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced/eigen-4fd5d1477b221fc7daf2b7f1c7e4ee4f04ceaced.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/8c9976d7f0558fdc8d0be7476c37e5d562332955/eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955.tar.gz",
+            "https://gitlab.com/libeigen/eigen/-/archive/8c9976d7f0558fdc8d0be7476c37e5d562332955/eigen-8c9976d7f0558fdc8d0be7476c37e5d562332955.tar.gz",
         ],
     )
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index bfe7d6c5288..1247e486903 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -143,7 +143,7 @@ def InvokeNvcc(argv, log=False):
   nvccopts += undefines
   nvccopts += defines
   nvccopts += m_options
-  nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+  nvccopts += ['--compiler-options=' + ",".join(host_compiler_options)]
   nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
   # If we don't specify --keep-dir, nvcc will generate intermediate files under TEMP
  # Put them under NVCC_TEMP_DIR instead, then Bazel can ignore files under NVCC_TEMP_DIR during dependency check
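The `stream.h` hunk drops a bare `enum BatchNormalizationKind;` forward declaration, presumably because it is unused and non-conforming: ISO C++ only permits an opaque enum declaration when the underlying type is fixed, and a bare declaration is at best a vendor extension. A minimal illustration of which forms are legal:

```cpp
// ISO C++ allows an opaque enum declaration only with a fixed underlying type.
enum class ScopedKind;  // OK: scoped enums default to int
enum FixedKind : int;   // OK: underlying type stated explicitly
// enum BareKind;       // ill-formed; some compilers accept it as an extension

int main() { return 0; }
```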
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 1bd7141a372..9b625e4278a 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -531,7 +531,10 @@ def lib_name(base_name, cpu_value, version = None, static = False):
             return "lib%s.a" % base_name
         return "lib%s.so%s" % (base_name, version)
     elif cpu_value == "Windows":
-        return "%s.lib" % base_name
+        if base_name == "nvToolsExt":
+            return "lib/x64/nvToolsExt64_1.lib"
+        else:
+            return "%s.lib" % base_name
     elif cpu_value == "Darwin":
         if static:
             return "lib%s.a" % base_name
@@ -669,7 +672,7 @@ def _find_libs(repository_ctx, cuda_config):
             "nvToolsExt",
             repository_ctx,
             cpu_value,
-            cuda_config.config["cuda_library_dir"],
+            cuda_config.nvToolsExt_path,
             "1",
         ),
         "cupti": _find_cuda_lib(
@@ -762,6 +765,11 @@ def _get_cuda_config(repository_ctx):
     cufft_version = cuda_version
     cusparse_version = cuda_version
 
+    if cpu_value == "Windows":
+        nvToolsExt_path = repository_ctx.os.environ.get("NVTOOLSEXT_PATH", "C:/Program Files/NVIDIA Corporation/NvToolsExt/")
+    else:
+        nvToolsExt_path = toolkit_path
+
     return struct(
         cuda_toolkit_path = toolkit_path,
         cuda_version = cuda_version,
@@ -775,6 +783,7 @@ def _get_cuda_config(repository_ctx):
         compute_capabilities = compute_capabilities(repository_ctx),
         cpu_value = cpu_value,
         config = config,
+        nvToolsExt_path = nvToolsExt_path,
     )
 
 def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
@@ -1148,7 +1157,8 @@ def _create_local_cuda_repository(repository_ctx):
         out_dir = "cuda/bin",
     ))
 
-    if [int(x) for x in cuda_config.cudnn_version.split(".")] < [8, 0]:
+    # Select the headers based on the cuDNN version (strip '64_' for Windows).
+    if cuda_config.cudnn_version.rsplit("_", 1)[-1] < "8":
         cudnn_headers = ["cudnn.h"]
     else:
         cudnn_headers = ["cudnn_adv_infer.h",
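Taken together, the configure changes above locate NVTX on Windows: the toolkit root comes from `%NVTOOLSEXT_PATH%` (falling back to the default install directory), and `lib_name` resolves the import library to `lib/x64/nvToolsExt64_1.lib` instead of a plain `.lib` name. Once that target resolves, TF's `platform/nvtx.h` wrapper boils down to the standard NVTX range API. A minimal usage sketch (not TF code; the include path depends on how the NVTX headers are installed):

```cpp
#include "nvtx3/nvToolsExt.h"

// Emit a named range that shows up on Nsight/nvprof timelines. On Windows
// this links against the nvToolsExt64_1.lib import library resolved above;
// the NVTX3 headers themselves are header-only on other platforms.
void TracedStep() {
  nvtxRangePushA("TracedStep");  // open the range
  // ... work being profiled ...
  nvtxRangePop();                // close the range
}
```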