From 30fb2c4abaaaa966999eab11674f25b18460e609 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Sat, 11 Jun 2022 10:22:58 -0700 Subject: [PATCH] [lint] autoformat test/cpp and torch/csrc Let's have some fun. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828 Approved by: https://github.com/ezyang --- .lintrunner.toml | 12 +- aten/src/ATen/autocast_mode.h | 2 + test/cpp/api/any.cpp | 65 +- test/cpp/api/autograd.cpp | 701 ++- test/cpp/api/dataloader.cpp | 32 +- test/cpp/api/dispatch.cpp | 21 +- test/cpp/api/enum.cpp | 84 +- test/cpp/api/fft.cpp | 8 +- test/cpp/api/functional.cpp | 2085 ++++---- test/cpp/api/grad_mode.cpp | 21 +- test/cpp/api/imethod.cpp | 6 +- test/cpp/api/inference_mode.cpp | 311 +- test/cpp/api/init.cpp | 16 +- test/cpp/api/init_baseline.h | 1744 ++++++- test/cpp/api/jit.cpp | 19 +- test/cpp/api/meta_tensor.cpp | 2 +- test/cpp/api/misc.cpp | 5 +- test/cpp/api/module.cpp | 65 +- test/cpp/api/moduledict.cpp | 120 +- test/cpp/api/modulelist.cpp | 13 +- test/cpp/api/modules.cpp | 3532 +++++++------- test/cpp/api/namespace.cpp | 5 +- test/cpp/api/nn_utils.cpp | 461 +- test/cpp/api/optim.cpp | 117 +- test/cpp/api/optim_baseline.h | 4222 ++++++++++++----- test/cpp/api/parallel.cpp | 120 +- test/cpp/api/parameterdict.cpp | 3 +- test/cpp/api/rnn.cpp | 313 +- test/cpp/api/sequential.cpp | 319 +- test/cpp/api/serialize.cpp | 285 +- test/cpp/api/special.cpp | 6 +- test/cpp/api/static.cpp | 2 +- test/cpp/api/support.h | 53 +- test/cpp/api/tensor.cpp | 402 +- test/cpp/api/tensor_cuda.cpp | 12 +- test/cpp/api/tensor_flatten.cpp | 20 +- test/cpp/api/tensor_indexing.cpp | 402 +- test/cpp/api/tensor_options.cpp | 9 +- test/cpp/api/tensor_options_cuda.cpp | 16 +- test/cpp/api/transformer.cpp | 1788 ++++--- test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp | 35 +- test/cpp/c10d/ProcessGroupGlooTest.cpp | 82 +- test/cpp/c10d/ProcessGroupMPITest.cpp | 4 +- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 19 +- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 136 +- test/cpp/c10d/TCPStoreTest.cpp | 71 +- test/cpp/dist_autograd/test_dist_autograd.cpp | 6 +- test/cpp/lazy/test_backend_device.cpp | 14 +- test/cpp/lazy/test_cache.cpp | 23 +- test/cpp/lazy/test_ir.cpp | 35 +- test/cpp/lazy/test_ir_util.cpp | 12 +- test/cpp/lazy/test_lazy_ops.cpp | 4220 +++++++++------- test/cpp/lazy/test_lazy_ops_util.cpp | 59 +- test/cpp/lazy/test_lazy_ops_util.h | 39 +- test/cpp/lazy/test_misc.cpp | 7 +- test/cpp/lazy/test_shape.cpp | 4 +- test/cpp/lazy/test_symbolic_shape.cpp | 12 +- test/cpp/lazy/test_tensor_impl.cpp | 16 +- test/cpp/lazy/test_trie_cache.cpp | 21 +- test/cpp/lazy/test_util.cpp | 12 +- test/cpp/lite_interpreter_runtime/main.cpp | 12 +- .../test_mobile_profiler.cpp | 39 +- test/cpp/profiler/containers.cpp | 105 +- test/cpp/profiler/record_function.cpp | 6 +- test/cpp/rpc/e2e_test_base.h | 7 +- .../cpp/rpc/test_tensorpipe_serialization.cpp | 20 +- torch/csrc/CudaIPCTypes.cpp | 36 +- torch/csrc/CudaIPCTypes.h | 6 +- torch/csrc/DataLoader.cpp | 153 +- torch/csrc/Device.cpp | 169 +- torch/csrc/Device.h | 11 +- torch/csrc/Dtype.cpp | 138 +- torch/csrc/Dtype.h | 16 +- torch/csrc/DynamicTypes.cpp | 54 +- torch/csrc/DynamicTypes.h | 15 +- torch/csrc/Exceptions.cpp | 201 +- torch/csrc/Exceptions.h | 276 +- torch/csrc/Generator.cpp | 197 +- torch/csrc/Generator.h | 22 +- torch/csrc/Layout.cpp | 91 +- torch/csrc/Layout.h | 9 +- torch/csrc/MemoryFormat.cpp | 94 +- torch/csrc/MemoryFormat.h | 11 +- torch/csrc/Module.cpp | 1051 ++-- torch/csrc/PythonTypes.h | 10 +- torch/csrc/QScheme.cpp | 97 +- torch/csrc/QScheme.h | 9 +- torch/csrc/Size.cpp | 188 +- torch/csrc/Size.h | 8 +- torch/csrc/Storage.cpp | 302 +- torch/csrc/Storage.h | 12 +- torch/csrc/StorageMethods.cpp | 354 +- torch/csrc/StorageSharing.cpp | 404 +- torch/csrc/Stream.cpp | 123 +- torch/csrc/Stream.h | 7 +- torch/csrc/THConcat.h | 24 +- torch/csrc/THP.h | 10 +- torch/csrc/TypeInfo.cpp | 144 +- torch/csrc/Types.h | 3 +- torch/csrc/api/include/torch/all.h | 2 +- torch/csrc/api/include/torch/arg.h | 29 +- torch/csrc/api/include/torch/autograd.h | 2 +- torch/csrc/api/include/torch/cuda.h | 2 +- .../csrc/api/include/torch/data/dataloader.h | 11 +- .../include/torch/data/dataloader/stateful.h | 8 +- .../include/torch/data/dataloader/stateless.h | 4 +- .../api/include/torch/data/datasets/base.h | 6 +- .../api/include/torch/data/datasets/chunk.h | 98 +- .../include/torch/data/datasets/stateful.h | 4 +- .../include/torch/data/detail/sequencers.h | 2 +- torch/csrc/api/include/torch/data/iterator.h | 6 +- .../api/include/torch/data/samplers/base.h | 2 +- .../data/samplers/custom_batch_request.h | 2 +- .../api/include/torch/data/samplers/random.h | 4 +- .../torch/detail/TensorDataContainer.h | 272 +- torch/csrc/api/include/torch/enum.h | 90 +- .../csrc/api/include/torch/expanding_array.h | 50 +- torch/csrc/api/include/torch/fft.h | 190 +- torch/csrc/api/include/torch/imethod.h | 5 +- torch/csrc/api/include/torch/jit.h | 2 +- torch/csrc/api/include/torch/linalg.h | 577 ++- torch/csrc/api/include/torch/nn/cloneable.h | 10 +- torch/csrc/api/include/torch/nn/functional.h | 2 +- .../include/torch/nn/functional/activation.h | 495 +- .../include/torch/nn/functional/batchnorm.h | 74 +- .../api/include/torch/nn/functional/conv.h | 231 +- .../include/torch/nn/functional/distance.h | 31 +- .../api/include/torch/nn/functional/dropout.h | 108 +- .../include/torch/nn/functional/embedding.h | 161 +- .../api/include/torch/nn/functional/fold.h | 79 +- .../torch/nn/functional/instancenorm.h | 50 +- .../api/include/torch/nn/functional/linear.h | 16 +- .../api/include/torch/nn/functional/loss.h | 573 ++- .../torch/nn/functional/normalization.h | 159 +- .../api/include/torch/nn/functional/padding.h | 21 +- .../torch/nn/functional/pixelshuffle.h | 16 +- .../api/include/torch/nn/functional/pooling.h | 793 ++-- .../include/torch/nn/functional/upsampling.h | 193 +- .../api/include/torch/nn/functional/vision.h | 41 +- torch/csrc/api/include/torch/nn/init.h | 39 +- torch/csrc/api/include/torch/nn/module.h | 47 +- torch/csrc/api/include/torch/nn/modules.h | 18 +- .../api/include/torch/nn/modules/_functions.h | 7 +- .../api/include/torch/nn/modules/activation.h | 100 +- .../api/include/torch/nn/modules/adaptive.h | 63 +- .../api/include/torch/nn/modules/batchnorm.h | 100 +- .../api/include/torch/nn/modules/common.h | 65 +- .../include/torch/nn/modules/container/any.h | 16 +- .../nn/modules/container/any_module_holder.h | 25 +- .../torch/nn/modules/container/any_value.h | 12 +- .../torch/nn/modules/container/functional.h | 12 +- .../torch/nn/modules/container/moduledict.h | 39 +- .../torch/nn/modules/container/named_any.h | 12 +- .../torch/nn/modules/container/sequential.h | 34 +- .../csrc/api/include/torch/nn/modules/conv.h | 278 +- .../api/include/torch/nn/modules/distance.h | 27 +- .../api/include/torch/nn/modules/dropout.h | 59 +- .../api/include/torch/nn/modules/embedding.h | 98 +- .../csrc/api/include/torch/nn/modules/fold.h | 9 +- .../include/torch/nn/modules/instancenorm.h | 72 +- .../api/include/torch/nn/modules/linear.h | 57 +- .../csrc/api/include/torch/nn/modules/loss.h | 394 +- .../include/torch/nn/modules/normalization.h | 61 +- .../api/include/torch/nn/modules/padding.h | 183 +- .../include/torch/nn/modules/pixelshuffle.h | 12 +- .../api/include/torch/nn/modules/pooling.h | 324 +- torch/csrc/api/include/torch/nn/modules/rnn.h | 141 +- .../include/torch/nn/modules/transformer.h | 166 +- .../torch/nn/modules/transformercoder.h | 148 +- .../torch/nn/modules/transformerlayer.h | 189 +- .../api/include/torch/nn/modules/upsampling.h | 9 +- .../csrc/api/include/torch/nn/modules/utils.h | 15 +- torch/csrc/api/include/torch/nn/options.h | 10 +- .../api/include/torch/nn/options/activation.h | 47 +- .../api/include/torch/nn/options/adaptive.h | 8 +- .../api/include/torch/nn/options/batchnorm.h | 12 +- .../csrc/api/include/torch/nn/options/conv.h | 78 +- .../api/include/torch/nn/options/distance.h | 14 +- .../api/include/torch/nn/options/dropout.h | 6 +- .../api/include/torch/nn/options/embedding.h | 104 +- .../csrc/api/include/torch/nn/options/fold.h | 3 +- .../include/torch/nn/options/instancenorm.h | 12 +- .../api/include/torch/nn/options/linear.h | 7 +- .../csrc/api/include/torch/nn/options/loss.h | 259 +- .../include/torch/nn/options/normalization.h | 25 +- .../api/include/torch/nn/options/padding.h | 56 +- .../include/torch/nn/options/pixelshuffle.h | 4 +- .../api/include/torch/nn/options/pooling.h | 85 +- torch/csrc/api/include/torch/nn/options/rnn.h | 43 +- .../include/torch/nn/options/transformer.h | 18 +- .../torch/nn/options/transformercoder.h | 65 +- .../torch/nn/options/transformerlayer.h | 20 +- .../api/include/torch/nn/options/upsampling.h | 36 +- .../api/include/torch/nn/options/vision.h | 9 +- .../include/torch/nn/parallel/data_parallel.h | 30 +- .../api/include/torch/nn/utils/clip_grad.h | 52 +- .../torch/nn/utils/convert_parameters.h | 26 +- torch/csrc/api/include/torch/nn/utils/rnn.h | 115 +- torch/csrc/api/include/torch/optim/adagrad.h | 53 +- torch/csrc/api/include/torch/optim/adam.h | 56 +- torch/csrc/api/include/torch/optim/adamw.h | 56 +- torch/csrc/api/include/torch/optim/lbfgs.h | 53 +- .../csrc/api/include/torch/optim/optimizer.h | 82 +- torch/csrc/api/include/torch/optim/rmsprop.h | 41 +- .../torch/optim/schedulers/lr_scheduler.h | 25 +- .../include/torch/optim/schedulers/step_lr.h | 13 +- .../csrc/api/include/torch/optim/serialize.h | 322 +- torch/csrc/api/include/torch/optim/sgd.h | 45 +- torch/csrc/api/include/torch/ordered_dict.h | 37 +- torch/csrc/api/include/torch/serialize.h | 16 +- .../include/torch/serialize/input-archive.h | 22 +- .../include/torch/serialize/output-archive.h | 4 +- torch/csrc/api/include/torch/sparse.h | 5 +- torch/csrc/api/include/torch/special.h | 341 +- torch/csrc/api/include/torch/types.h | 31 +- torch/csrc/api/include/torch/utils.h | 37 +- torch/csrc/api/src/cuda.cpp | 6 +- torch/csrc/api/src/imethod.cpp | 3 +- torch/csrc/api/src/jit.cpp | 8 +- torch/csrc/api/src/nn/init.cpp | 16 +- torch/csrc/api/src/nn/module.cpp | 4 +- torch/csrc/api/src/nn/modules/_functions.cpp | 57 +- torch/csrc/api/src/nn/modules/activation.cpp | 181 +- torch/csrc/api/src/nn/modules/adaptive.cpp | 99 +- torch/csrc/api/src/nn/modules/batchnorm.cpp | 22 +- torch/csrc/api/src/nn/modules/conv.cpp | 359 +- torch/csrc/api/src/nn/modules/distance.cpp | 21 +- torch/csrc/api/src/nn/modules/dropout.cpp | 35 +- torch/csrc/api/src/nn/modules/embedding.cpp | 105 +- torch/csrc/api/src/nn/modules/fold.cpp | 25 +- .../csrc/api/src/nn/modules/instancenorm.cpp | 18 +- torch/csrc/api/src/nn/modules/linear.cpp | 53 +- torch/csrc/api/src/nn/modules/loss.cpp | 174 +- .../csrc/api/src/nn/modules/normalization.cpp | 63 +- torch/csrc/api/src/nn/modules/padding.cpp | 16 +- .../csrc/api/src/nn/modules/pixelshuffle.cpp | 6 +- torch/csrc/api/src/nn/modules/pooling.cpp | 352 +- torch/csrc/api/src/nn/modules/rnn.cpp | 645 ++- torch/csrc/api/src/nn/modules/transformer.cpp | 359 +- torch/csrc/api/src/nn/modules/upsampling.cpp | 3 +- torch/csrc/api/src/nn/options/activation.cpp | 29 +- torch/csrc/api/src/nn/options/adaptive.cpp | 8 +- torch/csrc/api/src/nn/options/batchnorm.cpp | 3 +- torch/csrc/api/src/nn/options/embedding.cpp | 12 +- .../csrc/api/src/nn/options/instancenorm.cpp | 3 +- torch/csrc/api/src/nn/options/linear.cpp | 21 +- .../csrc/api/src/nn/options/normalization.cpp | 12 +- torch/csrc/api/src/nn/options/padding.cpp | 3 +- torch/csrc/api/src/nn/options/rnn.cpp | 16 +- torch/csrc/api/src/nn/options/transformer.cpp | 52 +- torch/csrc/api/src/optim/adagrad.cpp | 53 +- torch/csrc/api/src/optim/adam.cpp | 70 +- torch/csrc/api/src/optim/adamw.cpp | 72 +- torch/csrc/api/src/optim/lbfgs.cpp | 486 +- torch/csrc/api/src/optim/optimizer.cpp | 73 +- torch/csrc/api/src/optim/rmsprop.cpp | 66 +- .../api/src/optim/schedulers/lr_scheduler.cpp | 22 +- .../csrc/api/src/optim/schedulers/step_lr.cpp | 19 +- torch/csrc/api/src/optim/sgd.cpp | 38 +- torch/csrc/api/src/python/init.cpp | 2 +- torch/csrc/api/src/serialize.cpp | 1 - .../csrc/api/src/serialize/input-archive.cpp | 55 +- .../csrc/api/src/serialize/output-archive.cpp | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 3556 ++++++++------ torch/csrc/autograd/FunctionsManual.h | 1186 +++-- torch/csrc/autograd/InferenceMode.h | 6 +- torch/csrc/autograd/TraceTypeManual.cpp | 54 +- torch/csrc/autograd/VariableTypeManual.cpp | 365 +- torch/csrc/autograd/VariableTypeUtils.h | 269 +- torch/csrc/autograd/anomaly_mode.h | 12 +- torch/csrc/autograd/autograd.cpp | 65 +- torch/csrc/autograd/autograd.h | 83 +- torch/csrc/autograd/autograd_meta.cpp | 186 +- .../autograd_not_implemented_fallback.cpp | 308 +- .../autograd_not_implemented_fallback.h | 3 +- torch/csrc/autograd/cpp_hook.cpp | 28 +- torch/csrc/autograd/cpp_hook.h | 11 +- torch/csrc/autograd/custom_function.cpp | 340 +- torch/csrc/autograd/custom_function.h | 152 +- torch/csrc/autograd/edge.h | 6 +- torch/csrc/autograd/engine.cpp | 449 +- torch/csrc/autograd/engine.h | 93 +- torch/csrc/autograd/forward_grad.cpp | 96 +- torch/csrc/autograd/forward_grad.h | 189 +- torch/csrc/autograd/function.cpp | 45 +- torch/csrc/autograd/function.h | 182 +- torch/csrc/autograd/function_hook.cpp | 6 +- torch/csrc/autograd/function_hook.h | 14 +- .../autograd/functions/accumulate_grad.cpp | 13 +- .../csrc/autograd/functions/accumulate_grad.h | 101 +- torch/csrc/autograd/functions/basic_ops.cpp | 14 +- torch/csrc/autograd/functions/basic_ops.h | 36 +- torch/csrc/autograd/functions/comm.cpp | 11 +- torch/csrc/autograd/functions/comm.h | 8 +- torch/csrc/autograd/functions/init.cpp | 94 +- torch/csrc/autograd/functions/pybind.h | 10 +- torch/csrc/autograd/functions/tensor.cpp | 10 +- torch/csrc/autograd/functions/tensor.h | 12 +- torch/csrc/autograd/functions/utils.cpp | 25 +- torch/csrc/autograd/functions/utils.h | 32 +- torch/csrc/autograd/grad_mode.h | 6 +- torch/csrc/autograd/init.cpp | 487 +- torch/csrc/autograd/input_buffer.cpp | 181 +- torch/csrc/autograd/input_buffer.h | 36 +- torch/csrc/autograd/input_metadata.h | 29 +- torch/csrc/autograd/profiler.h | 2 +- torch/csrc/autograd/profiler_kineto.cpp | 66 +- torch/csrc/autograd/profiler_kineto.h | 24 +- torch/csrc/autograd/profiler_legacy.cpp | 230 +- torch/csrc/autograd/profiler_legacy.h | 68 +- torch/csrc/autograd/profiler_python.cpp | 113 +- torch/csrc/autograd/profiler_python.h | 10 +- torch/csrc/autograd/python_anomaly_mode.cpp | 54 +- torch/csrc/autograd/python_anomaly_mode.h | 13 +- torch/csrc/autograd/python_autograd.h | 10 +- torch/csrc/autograd/python_cpp_function.cpp | 129 +- torch/csrc/autograd/python_cpp_function.h | 80 +- torch/csrc/autograd/python_engine.cpp | 337 +- torch/csrc/autograd/python_engine.h | 23 +- torch/csrc/autograd/python_enum_tag.h | 7 +- torch/csrc/autograd/python_fft_functions.h | 6 +- torch/csrc/autograd/python_function.cpp | 654 +-- torch/csrc/autograd/python_function.h | 113 +- torch/csrc/autograd/python_hook.cpp | 84 +- torch/csrc/autograd/python_hook.h | 12 +- .../csrc/autograd/python_legacy_variable.cpp | 152 +- torch/csrc/autograd/python_legacy_variable.h | 8 +- torch/csrc/autograd/python_linalg_functions.h | 6 +- torch/csrc/autograd/python_nn_functions.h | 6 +- torch/csrc/autograd/python_return_types.h | 6 +- .../autograd/python_saved_variable_hooks.cpp | 135 +- .../autograd/python_saved_variable_hooks.h | 18 +- torch/csrc/autograd/python_sparse_functions.h | 6 +- .../csrc/autograd/python_special_functions.h | 6 +- torch/csrc/autograd/python_torch_functions.h | 12 +- .../python_torch_functions_manual.cpp | 1047 ++-- torch/csrc/autograd/python_variable.cpp | 1023 ++-- torch/csrc/autograd/python_variable.h | 46 +- .../autograd/python_variable_indexing.cpp | 322 +- .../csrc/autograd/python_variable_indexing.h | 13 +- torch/csrc/autograd/record_function_ops.cpp | 107 +- torch/csrc/autograd/record_function_ops.h | 5 +- torch/csrc/autograd/saved_variable.cpp | 129 +- torch/csrc/autograd/saved_variable.h | 60 +- torch/csrc/autograd/saved_variable_hooks.h | 8 +- torch/csrc/autograd/symbolic.h | 6 +- torch/csrc/autograd/utils/error_messages.h | 14 +- .../autograd/utils/grad_layout_contract.h | 25 +- .../csrc/autograd/utils/python_arg_parsing.h | 44 +- torch/csrc/autograd/utils/warnings.cpp | 12 +- torch/csrc/autograd/utils/warnings.h | 25 +- torch/csrc/autograd/utils/wrap_outputs.h | 69 +- torch/csrc/autograd/variable.cpp | 665 +-- torch/csrc/autograd/variable.h | 526 +- torch/csrc/copy_utils.h | 35 +- torch/csrc/cuda/Event.cpp | 211 +- torch/csrc/cuda/Event.h | 7 +- torch/csrc/cuda/Graph.cpp | 55 +- torch/csrc/cuda/Module.cpp | 524 +- torch/csrc/cuda/Module.h | 12 +- torch/csrc/cuda/Stream.cpp | 190 +- torch/csrc/cuda/Stream.h | 8 +- torch/csrc/cuda/THCP.h | 4 +- torch/csrc/cuda/Tensor.cpp | 14 +- torch/csrc/cuda/comm.cpp | 14 +- torch/csrc/cuda/comm.h | 8 +- torch/csrc/cuda/device_set.h | 2 +- torch/csrc/cuda/nccl.cpp | 171 +- torch/csrc/cuda/nccl.h | 67 +- torch/csrc/cuda/override_macros.h | 12 +- torch/csrc/cuda/python_comm.cpp | 28 +- torch/csrc/cuda/python_comm.h | 10 +- torch/csrc/cuda/python_nccl.cpp | 28 +- torch/csrc/cuda/python_nccl.h | 16 +- torch/csrc/cuda/restore_macros.h | 16 +- torch/csrc/cuda/shared/cudart.cpp | 91 +- torch/csrc/cuda/shared/cudnn.cpp | 22 +- torch/csrc/cuda/shared/nvtx.cpp | 6 +- torch/csrc/cuda/utils.cpp | 19 +- torch/csrc/cuda/utils.h | 12 +- torch/csrc/distributed/autograd/autograd.cpp | 2 +- .../autograd/context/container.cpp | 23 +- .../distributed/autograd/context/context.h | 3 +- .../autograd/engine/dist_engine.cpp | 139 +- .../distributed/autograd/engine/dist_engine.h | 1 + .../autograd/functions/recvrpc_backward.cpp | 4 +- torch/csrc/distributed/autograd/init.cpp | 8 +- .../cleanup_autograd_context_resp.cpp | 3 +- .../rpc_messages/propagate_gradients_req.cpp | 4 +- .../rpc_messages/rpc_with_autograd.cpp | 19 +- torch/csrc/distributed/autograd/utils.cpp | 20 +- torch/csrc/distributed/autograd/utils.h | 6 +- torch/csrc/distributed/c10d/FileStore.cpp | 9 +- .../distributed/c10d/GlooDeviceFactory.cpp | 23 +- torch/csrc/distributed/c10d/NCCLUtils.cpp | 6 +- .../csrc/distributed/c10d/ParamCommsUtils.cpp | 4 +- torch/csrc/distributed/c10d/ProcessGroup.cpp | 15 +- .../distributed/c10d/ProcessGroupGloo.cpp | 285 +- .../csrc/distributed/c10d/ProcessGroupMPI.cpp | 77 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 409 +- .../c10d/ProcessGroupRoundRobin.cpp | 20 +- .../distributed/c10d/ProcessGroupWrapper.cpp | 2 +- torch/csrc/distributed/c10d/Store.cpp | 2 +- torch/csrc/distributed/c10d/TCPStore.cpp | 10 +- torch/csrc/distributed/c10d/comm.cpp | 2 +- torch/csrc/distributed/c10d/debug.cpp | 12 +- torch/csrc/distributed/c10d/debug.h | 6 +- .../distributed/c10d/default_comm_hooks.cpp | 14 +- torch/csrc/distributed/c10d/error.h | 3 +- torch/csrc/distributed/c10d/init.cpp | 100 +- torch/csrc/distributed/c10d/logger.cpp | 54 +- torch/csrc/distributed/c10d/logging.h | 41 +- .../csrc/distributed/c10d/python_comm_hook.h | 2 +- .../c10d/quantization/quantization.cpp | 18 +- .../c10d/quantization/quantization.h | 1 - .../c10d/quantization/quantization_gpu.h | 1 - .../c10d/quantization/quantization_utils.h | 2 +- torch/csrc/distributed/c10d/reducer.cpp | 115 +- torch/csrc/distributed/c10d/reducer_cuda.cpp | 5 +- torch/csrc/distributed/c10d/sequence_num.cpp | 2 +- torch/csrc/distributed/c10d/socket.cpp | 216 +- torch/csrc/distributed/c10d/socket.h | 7 +- torch/csrc/distributed/rpc/rpc_agent.cpp | 8 +- torch/csrc/generic/utils.h | 15 +- torch/csrc/lazy/backend/backend_data.h | 4 +- torch/csrc/lazy/backend/backend_device.cpp | 25 +- torch/csrc/lazy/backend/backend_device.h | 74 +- torch/csrc/lazy/backend/backend_interface.cpp | 5 +- torch/csrc/lazy/backend/backend_interface.h | 31 +- torch/csrc/lazy/backend/lowering_context.cpp | 12 +- torch/csrc/lazy/backend/lowering_context.h | 41 +- torch/csrc/lazy/core/config.cpp | 24 +- torch/csrc/lazy/core/config.h | 5 +- torch/csrc/lazy/core/debug_util.cpp | 37 +- torch/csrc/lazy/core/debug_util.h | 14 +- torch/csrc/lazy/core/hash.cpp | 19 +- torch/csrc/lazy/core/hash.h | 30 +- torch/csrc/lazy/core/ir.cpp | 40 +- torch/csrc/lazy/core/ir.h | 38 +- torch/csrc/lazy/core/ir_builder.h | 221 +- torch/csrc/lazy/core/ir_metadata.cpp | 2 +- torch/csrc/lazy/core/lazy_graph_executor.cpp | 57 +- torch/csrc/lazy/core/lazy_view.cpp | 28 +- torch/csrc/lazy/core/metrics.h | 3 +- .../csrc/lazy/core/ops/arithmetic_ir_ops.cpp | 16 +- torch/csrc/lazy/core/ops/utils.cpp | 17 +- torch/csrc/lazy/core/ops/utils.h | 13 +- torch/csrc/lazy/core/permutation_util.h | 6 +- torch/csrc/lazy/core/shape.cpp | 2 +- torch/csrc/lazy/core/shape_inference.cpp | 860 ++-- torch/csrc/lazy/core/shape_inference.h | 2 +- torch/csrc/lazy/core/tensor.cpp | 65 +- torch/csrc/lazy/core/tensor.h | 56 +- torch/csrc/lazy/core/tensor_impl.cpp | 41 +- torch/csrc/lazy/core/tensor_impl.h | 24 +- torch/csrc/lazy/core/tensor_util.cpp | 3 +- torch/csrc/lazy/core/thread_pool.cpp | 2 +- torch/csrc/lazy/core/trie.cpp | 8 +- torch/csrc/lazy/core/trie.h | 10 +- torch/csrc/lazy/core/util.h | 40 +- torch/csrc/lazy/python/init.cpp | 321 +- torch/csrc/lazy/python/init.h | 4 +- torch/csrc/lazy/python/python_util.cpp | 6 +- torch/csrc/lazy/python/python_util.h | 4 +- torch/csrc/lazy/ts_backend/dynamic_ir.cpp | 56 +- torch/csrc/lazy/ts_backend/dynamic_ir.h | 15 +- torch/csrc/lazy/ts_backend/ir_builder.h | 160 +- .../lazy/ts_backend/ops/batch_norm_ops.cpp | 83 +- .../csrc/lazy/ts_backend/ops/batch_norm_ops.h | 161 +- torch/csrc/lazy/ts_backend/ops/device_data.h | 5 +- torch/csrc/lazy/ts_backend/ops/random_ops.cpp | 29 +- torch/csrc/lazy/ts_backend/ops/random_ops.h | 15 +- torch/csrc/lazy/ts_backend/ops/to_copy.h | 95 +- .../csrc/lazy/ts_backend/tensor_aten_ops.cpp | 289 +- torch/csrc/lazy/ts_backend/tensor_aten_ops.h | 119 +- .../lazy/ts_backend/ts_autograd_functions.cpp | 44 +- .../lazy/ts_backend/ts_autograd_functions.h | 18 +- .../csrc/lazy/ts_backend/ts_backend_impl.cpp | 21 +- .../lazy/ts_backend/ts_eager_fallback.cpp | 16 +- .../csrc/lazy/ts_backend/ts_eager_fallback.h | 2 +- .../lazy/ts_backend/ts_lowering_context.cpp | 13 +- .../lazy/ts_backend/ts_lowering_context.h | 1 - .../lazy/ts_backend/ts_native_functions.cpp | 609 ++- torch/csrc/lazy/ts_backend/ts_node.cpp | 84 +- torch/csrc/lazy/ts_backend/ts_node.h | 85 +- .../csrc/lazy/ts_backend/ts_node_lowering.cpp | 140 +- torch/csrc/lazy/ts_backend/ts_node_lowering.h | 7 +- torch/csrc/monitor/events.h | 2 +- torch/csrc/multiprocessing/init.cpp | 2 +- torch/csrc/onnx/init.h | 6 +- torch/csrc/profiler/api.cpp | 5 +- torch/csrc/profiler/api.h | 15 +- torch/csrc/profiler/collection.cpp | 42 +- torch/csrc/profiler/collection.h | 13 +- torch/csrc/profiler/containers.h | 58 +- torch/csrc/profiler/cuda.cpp | 18 +- .../csrc/profiler/kineto_client_interface.cpp | 2 +- torch/csrc/profiler/kineto_shim.cpp | 54 +- torch/csrc/profiler/kineto_shim.h | 13 +- torch/csrc/profiler/nvtx_observer.cpp | 52 +- torch/csrc/profiler/util.cpp | 34 +- torch/csrc/profiler/util.h | 16 +- torch/csrc/python_dimname.cpp | 25 +- torch/csrc/python_dimname.h | 2 +- torch/csrc/serialization.cpp | 167 +- torch/csrc/serialization.h | 10 +- torch/csrc/tensor/python_tensor.cpp | 174 +- torch/csrc/tensor/python_tensor.h | 10 +- torch/csrc/utils.cpp | 108 +- torch/csrc/utils.h | 271 +- torch/csrc/utils/auto_gil.h | 6 +- torch/csrc/utils/byte_order.cpp | 239 +- torch/csrc/utils/byte_order.h | 7 +- torch/csrc/utils/cpp_stacktraces.cpp | 6 +- torch/csrc/utils/cuda_enabled.h | 4 +- torch/csrc/utils/cuda_lazy_init.cpp | 13 +- torch/csrc/utils/cuda_lazy_init.h | 4 +- torch/csrc/utils/disable_torch_function.cpp | 256 +- torch/csrc/utils/disable_torch_function.h | 56 +- torch/csrc/utils/init.cpp | 5 +- torch/csrc/utils/invalid_arguments.cpp | 182 +- torch/csrc/utils/object_ptr.cpp | 2 +- torch/csrc/utils/object_ptr.h | 61 +- torch/csrc/utils/out_types.cpp | 51 +- torch/csrc/utils/out_types.h | 9 +- torch/csrc/utils/pybind.h | 66 +- torch/csrc/utils/pycfunction_helpers.h | 2 +- torch/csrc/utils/python_arg_parser.cpp | 698 ++- torch/csrc/utils/python_arg_parser.h | 486 +- torch/csrc/utils/python_compat.h | 78 +- torch/csrc/utils/python_dispatch.cpp | 397 +- torch/csrc/utils/python_dispatch.h | 4 +- torch/csrc/utils/python_numbers.h | 4 +- torch/csrc/utils/python_scalars.h | 105 +- torch/csrc/utils/python_strings.h | 57 +- torch/csrc/utils/python_torch_function_mode.h | 11 +- torch/csrc/utils/python_tuples.h | 14 +- torch/csrc/utils/six.h | 21 +- torch/csrc/utils/structseq.cpp | 78 +- torch/csrc/utils/structseq.h | 8 +- torch/csrc/utils/tensor_apply.cpp | 87 +- torch/csrc/utils/tensor_apply.h | 22 +- torch/csrc/utils/tensor_dtypes.cpp | 7 +- torch/csrc/utils/tensor_dtypes.h | 10 +- torch/csrc/utils/tensor_flatten.cpp | 42 +- torch/csrc/utils/tensor_flatten.h | 21 +- torch/csrc/utils/tensor_layouts.cpp | 33 +- torch/csrc/utils/tensor_layouts.h | 6 +- torch/csrc/utils/tensor_list.cpp | 42 +- torch/csrc/utils/tensor_list.h | 6 +- torch/csrc/utils/tensor_memoryformats.cpp | 10 +- torch/csrc/utils/tensor_memoryformats.h | 8 +- torch/csrc/utils/tensor_new.cpp | 1291 +++-- torch/csrc/utils/tensor_new.h | 132 +- torch/csrc/utils/tensor_numpy.cpp | 232 +- torch/csrc/utils/tensor_numpy.h | 12 +- torch/csrc/utils/tensor_qschemes.cpp | 5 +- torch/csrc/utils/tensor_qschemes.h | 6 +- torch/csrc/utils/tensor_types.cpp | 87 +- torch/csrc/utils/tensor_types.h | 6 +- torch/csrc/utils/throughput_benchmark.cpp | 25 +- torch/csrc/utils/throughput_benchmark.h | 42 +- torch/csrc/utils/torch_dispatch_mode.h | 5 +- torch/csrc/utils/variadic.h | 15 +- 564 files changed, 42007 insertions(+), 27159 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 41b79dcd28b35c..b6613e8ba45b6c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -44,14 +44,10 @@ include_patterns = [ 'aten/src/ATen/*.h', 'c10/**/*.h', 'c10/**/*.cpp', - 'torch/csrc/jit/**/*.h', - 'torch/csrc/jit/**/*.cpp', - 'torch/csrc/deploy/**/*.h', - 'torch/csrc/deploy/**/*.cpp', - 'test/cpp/jit/**/*.h', - 'test/cpp/jit/**/*.cpp', - 'test/cpp/tensorexpr/**/*.h', - 'test/cpp/tensorexpr/**/*.cpp', + 'torch/csrc/**/*.h', + 'torch/csrc/**/*.cpp', + 'test/cpp/**/*.h', + 'test/cpp/**/*.cpp', ] exclude_patterns = [ 'c10/util/strong_type.h', diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 7ff53505f33930..f5e88a0b88f12e 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace at { namespace autocast { diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp index a1112210f2c770..c83ca91d99fcbd 100644 --- a/test/cpp/api/any.cpp +++ b/test/cpp/api/any.cpp @@ -100,25 +100,35 @@ TEST_F(AnyModuleTest, WrongNumberOfArguments) { #endif ASSERT_THROWS_WITH( any.forward(), - module_name + "'s forward() method expects 2 argument(s), but received 0. " - "If " + module_name + "'s forward() method has default arguments, " - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); + module_name + + "'s forward() method expects 2 argument(s), but received 0. " + "If " + + module_name + + "'s forward() method has default arguments, " + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(5), - module_name + "'s forward() method expects 2 argument(s), but received 1. " - "If " + module_name + "'s forward() method has default arguments, " - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); + module_name + + "'s forward() method expects 2 argument(s), but received 1. " + "If " + + module_name + + "'s forward() method has default arguments, " + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 2, 3), - module_name + "'s forward() method expects 2 argument(s), but received 3."); + module_name + + "'s forward() method expects 2 argument(s), but received 3."); } struct M_default_arg_with_macro : torch::nn::Module { double forward(int a, int b = 2, double c = 3.0) { return a + b + c; } + protected: - FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, torch::nn::AnyValue(3.0)}) + FORWARD_HAS_DEFAULT_ARGS( + {1, torch::nn::AnyValue(2)}, + {2, torch::nn::AnyValue(3.0)}) }; struct M_default_arg_without_macro : torch::nn::Module { @@ -127,7 +137,9 @@ struct M_default_arg_without_macro : torch::nn::Module { } }; -TEST_F(AnyModuleTest, PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) { +TEST_F( + AnyModuleTest, + PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) { { AnyModule any(M_default_arg_with_macro{}); @@ -155,22 +167,32 @@ TEST_F(AnyModuleTest, PassingArgumentsToModuleWithDefaultArgumentsInForwardMetho ASSERT_THROWS_WITH( any.forward(), - module_name + "'s forward() method expects 3 argument(s), but received 0. " - "If " + module_name + "'s forward() method has default arguments, " - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); + module_name + + "'s forward() method expects 3 argument(s), but received 0. " + "If " + + module_name + + "'s forward() method has default arguments, " + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1), - module_name + "'s forward() method expects 3 argument(s), but received 1. " - "If " + module_name + "'s forward() method has default arguments, " - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); + module_name + + "'s forward() method expects 3 argument(s), but received 1. " + "If " + + module_name + + "'s forward() method has default arguments, " + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 3), - module_name + "'s forward() method expects 3 argument(s), but received 2. " - "If " + module_name + "'s forward() method has default arguments, " - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); + module_name + + "'s forward() method expects 3 argument(s), but received 2. " + "If " + + module_name + + "'s forward() method has default arguments, " + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 2, 3.0, 4), - module_name + "'s forward() method expects 3 argument(s), but received 4."); + module_name + + "'s forward() method expects 3 argument(s), but received 4."); } } @@ -345,13 +367,14 @@ TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) { ASSERT_NE(value.try_get(), nullptr); // const and non-const types have the same typeid(), // but casting Holder to Holder is undefined - // behavior according to UBSAN: https://github.com/pytorch/pytorch/issues/26964 + // behavior according to UBSAN: + // https://github.com/pytorch/pytorch/issues/26964 // ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.get(), 5); } // This test does not work at all, because it looks like make_value // decays const int into int. -//TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) { +// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) { // auto value = make_value(5); // ASSERT_NE(value.try_get(), nullptr); // // ASSERT_NE(value.try_get(), nullptr); diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index 733c7d4edb0e6d..f3a72946e68857 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -1,8 +1,8 @@ -#include #include +#include -#include #include +#include #include @@ -11,8 +11,8 @@ using namespace torch::autograd; using namespace torch::test; -#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b))) -#define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b))) +#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b))) +#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b))) std::string graph_desc(std::shared_ptr node) { if (!node) { @@ -20,10 +20,10 @@ std::string graph_desc(std::shared_ptr node) { } auto result = node->name() + "("; auto next_edges = node->next_edges(); - for(auto& edge : next_edges) { + for (auto& edge : next_edges) { result += graph_desc(edge.function); } - return result+")"; + return result + ")"; } Variable simple_fn(const Variable& x, const Variable& y) { @@ -37,7 +37,7 @@ TEST(AutogradAPITests, BackwardSimpleTest) { backward({res.sum()}, {}); ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2})); - ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2})*2); + ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2); } TEST(AutogradAPITests, BackwardTest) { @@ -48,14 +48,14 @@ TEST(AutogradAPITests, BackwardTest) { backward({res}, {torch::ones({2, 2})}); - ASSERT_VARIABLE_EQ(x.grad(), 2* (y + torch::ones({2, 2}))); - ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2})*2)); + ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2}))); + ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2)); } TEST(AutogradAPITests, GradSimpleTest) { // basic grad - Variable x = torch::randn({2,2}, torch::requires_grad()); - Variable y = torch::randn({2,2}, torch::requires_grad()); + Variable x = torch::randn({2, 2}, torch::requires_grad()); + Variable y = torch::randn({2, 2}, torch::requires_grad()); auto res = simple_fn(x, y); auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})}); @@ -88,7 +88,7 @@ TEST(AutogradAPITests, GradNonLeafTest) { Variable y = torch::randn({2, 2}, torch::requires_grad()); Variable grad_output = torch::ones({2, 2}); - for (int i = 0; i < 5; ++ i) { + for (int i = 0; i < 5; ++i) { auto res = simple_fn(x, y); auto input_grads = grad({res}, {x}, {grad_output}, {}, true); @@ -128,18 +128,21 @@ TEST(AutogradAPITests, GradUnreachableTest) { ASSERT_FALSE(grad_res[1].defined()); // allow_unused=False, but grads contains None inside, should throw - ASSERT_THROWS_WITH(grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True"); + ASSERT_THROWS_WITH( + grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True"); } TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) { // Test that certain nodes are not erroneously executed when an input // is unreachable. See #39784 struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var) { + static Variable forward(AutogradContext* ctx, Variable var) { return var; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { ADD_FAILURE() << "This node should not be executed!"; return grad_output; } @@ -156,8 +159,8 @@ TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) { TEST(AutogradAPITests, EmptyInput) { Variable x = torch::ones({1}, torch::requires_grad()); - ASSERT_THROWS_WITH(grad({x * 2}, /*inputs=*/{}, {x}), - "grad requires non-empty inputs."); + ASSERT_THROWS_WITH( + grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs."); } TEST(AutogradAPITests, RetainGrad) { @@ -169,8 +172,7 @@ TEST(AutogradAPITests, RetainGrad) { // Warning when grad is accessed for non-leaf tensor WarningCapture warnings; ASSERT_FALSE(h1.grad().defined()); - ASSERT_TRUE( - warnings.str().find("is not a leaf") != std::string::npos); + ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos); } // It should be possible to call retain_grad() multiple times h1.retain_grad(); @@ -180,8 +182,7 @@ TEST(AutogradAPITests, RetainGrad) { // there should not be any warning when grad is accessed WarningCapture warnings; ASSERT_FALSE(h1.grad().defined()); - ASSERT_FALSE( - warnings.str().find("is not a leaf") != std::string::npos); + ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos); } // Gradient should be accumulated @@ -224,7 +225,8 @@ TEST(AutogradAPITests, AnomalyMode) { auto gr = // NOLINTNEXTLINE(bugprone-argument-comment) grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true); - ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan"); + ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})}); + , "returned nan"); auto msgs = warnings.messages(); ASSERT_EQ(msgs.size(), 2); ASSERT_TRUE( @@ -239,55 +241,68 @@ TEST(AutogradAPITests, AnomalyMode) { TEST(CustomAutogradTest, CustomFunction) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) { + static Variable forward( + AutogradContext* ctx, + Variable var1, + int mul, + Variable var2) { ctx->saved_data["mul"] = mul; ctx->save_for_backward({var1, var2}); - return var1 + mul*var2 + var1*var2; + return var1 + mul * var2 + var1 * var2; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { int mul = ctx->saved_data["mul"].toInt(); auto saved = ctx->get_saved_variables(); auto var1 = saved[0]; auto var2 = saved[1]; - variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1}; + variable_list output = { + grad_output[0] + grad_output[0] * var2, + Variable(), + grad_output[0] * mul + grad_output[0] * var1}; return output; } }; - Variable x = torch::randn({5,5}, torch::requires_grad()); - Variable y = torch::randn({5,5}, torch::requires_grad()); - auto res = MyFunction::apply(x,2,y); + Variable x = torch::randn({5, 5}, torch::requires_grad()); + Variable y = torch::randn({5, 5}, torch::requires_grad()); + auto res = MyFunction::apply(x, 2, y); auto go = torch::ones({}, torch::requires_grad()); res.sum().backward(go, false, true); - ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5,5})); - ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5,5})*2); + ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5})); + ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2); } TEST(CustomAutogradTest, FunctionReturnsInput) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var1) { + static Variable forward(AutogradContext* ctx, Variable var1) { return var1; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { - return {grad_output[0]*2}; + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + return {grad_output[0] * 2}; } }; Variable x(torch::ones(1, torch::requires_grad())); - MyFunction::apply(x).backward(torch::ones(1) , true, true); + MyFunction::apply(x).backward(torch::ones(1), true, true); ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.)); } TEST(CustomAutogradTest, FunctionReturnsUndefined) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var) { + static Variable forward(AutogradContext* ctx, Variable var) { return var * 2; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { at::Tensor undefined_tensor; return {undefined_tensor}; } @@ -305,16 +320,19 @@ TEST(CustomAutogradTest, FunctionReturnsUndefined) { ASSERT_FALSE(x.grad().defined()); ASSERT_FALSE(torch::autograd::grad( - {MyFunction::apply(x)}, {x}, {}, false, false, true)[0].defined()); + {MyFunction::apply(x)}, {x}, {}, false, false, true)[0] + .defined()); } TEST(CustomAutogradTest, MaterializeGrads) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var) { + static Variable forward(AutogradContext* ctx, Variable var) { return var; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1)); return grad_output; } @@ -326,12 +344,14 @@ TEST(CustomAutogradTest, MaterializeGrads) { TEST(CustomAutogradTest, DontMaterializeGrads) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var) { + static Variable forward(AutogradContext* ctx, Variable var) { ctx->set_materialize_grads(false); return var; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { EXPECT_FALSE(grad_output[0].defined()); return grad_output; } @@ -343,27 +363,27 @@ TEST(CustomAutogradTest, DontMaterializeGrads) { TEST(CustomAutogradTest, NoGradCustomFunction) { // Custom Function should respect grad mode - struct MyOp : public Function { - static Variable forward(AutogradContext *ctx, Variable x) { - return x+1; - } - - static variable_list backward(AutogradContext *ctx, variable_list dy) { - return dy; - } - }; - - auto x = torch::ones({5,5}, torch::requires_grad()); - { + struct MyOp : public Function { + static Variable forward(AutogradContext* ctx, Variable x) { + return x + 1; + } + + static variable_list backward(AutogradContext* ctx, variable_list dy) { + return dy; + } + }; + + auto x = torch::ones({5, 5}, torch::requires_grad()); + { at::NoGradGuard no_grad; auto y = MyOp::apply(x); ASSERT_FALSE(y.requires_grad()); - } + } } TEST(CustomAutogradTest, MarkDirty) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable v) { + static Variable forward(AutogradContext* ctx, Variable v) { // Change the value inplace auto v_data = v.data_ptr(); v_data[0] = 2; @@ -371,13 +391,15 @@ TEST(CustomAutogradTest, MarkDirty) { return v; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { - return { (grad_output[0]*2.0) }; + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + return {(grad_output[0] * 2.0)}; } }; // Clone here because modifying leafs inplace is not allowed - auto x = torch::randn({5,5}, torch::requires_grad()).clone(); + auto x = torch::randn({5, 5}, torch::requires_grad()).clone(); auto version_before = x._version(); auto out = MyFunction::apply(x); auto version_after = x._version(); @@ -387,18 +409,20 @@ TEST(CustomAutogradTest, MarkDirty) { TEST(CustomAutogradTest, MarkNonDifferentiable) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable v) { + static Variable forward(AutogradContext* ctx, Variable v) { Variable output = v > 0; ctx->mark_non_differentiable({output}); return output; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { - return { (grad_output[0]*0.0) }; + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + return {(grad_output[0] * 0.0)}; } }; - auto x = torch::randn({5,5}, torch::requires_grad()); + auto x = torch::randn({5, 5}, torch::requires_grad()); auto mask = MyFunction::apply(x); ASSERT_FALSE(mask.requires_grad()); auto y = x.masked_fill(mask, 0); @@ -407,118 +431,132 @@ TEST(CustomAutogradTest, MarkNonDifferentiable) { TEST(CustomAutogradTest, MarkNonDifferentiableMixed) { struct MyFunction : public Function { - static variable_list forward(AutogradContext *ctx, Variable input) { - Variable a = input+1; - Variable b = input+2; + static variable_list forward(AutogradContext* ctx, Variable input) { + Variable a = input + 1; + Variable b = input + 2; ctx->mark_non_differentiable({a}); - return {a,b}; + return {a, b}; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { const Variable &grad_a = grad_output[0], &grad_b = grad_output[1]; - EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5,5})); - EXPECT_VARIABLE_EQ(grad_b, torch::ones({5,5})); + EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5})); + EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5})); return {grad_b}; } }; - auto x = torch::randn({5,5}, torch::requires_grad()); + auto x = torch::randn({5, 5}, torch::requires_grad()); auto out = MyFunction::apply(x); ASSERT_FALSE(out[0].requires_grad()); ASSERT_TRUE(out[1].requires_grad()); out[1].sum().backward(); - ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5,5})); + ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5})); } TEST(CustomAutogradTest, MarkNonDifferentiableNone) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable input) { + static Variable forward(AutogradContext* ctx, Variable input) { auto output = input.clone(); ctx->mark_non_differentiable({output}); return output; } - static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_outputs) { return {}; } }; - auto x = torch::randn({5,5}, torch::requires_grad()); + auto x = torch::randn({5, 5}, torch::requires_grad()); auto r = MyFunction::apply(x * x); (r * x).sum().backward(); } TEST(CustomAutogradTest, ReturnLeafInplace) { struct Inplace : public Function { - static variable_list forward(AutogradContext *ctx, Variable a, Variable b) { + static variable_list forward(AutogradContext* ctx, Variable a, Variable b) { ctx->mark_dirty({a}); - return {a.add_(b), b+2}; + return {a.add_(b), b + 2}; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { return {grad_output[0], grad_output[0] + grad_output[1]}; } }; - Variable x = torch::randn({5,5}); - Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}); + Variable y = torch::randn({5, 5}, torch::requires_grad()); - auto out = Inplace::apply(x,y); - auto &q = out[0]; + auto out = Inplace::apply(x, y); + auto& q = out[0]; ASSERT_TRUE(torch::equal(q, x)); ASSERT_TRUE(q.requires_grad()); q.sum().backward(); - ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5,5})); + ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5})); } TEST(CustomAutogradTest, ReturnDuplicateInplace) { struct DoubleInplace : public Function { - static variable_list forward(AutogradContext *ctx, Variable x) { + static variable_list forward(AutogradContext* ctx, Variable x) { x.mul_(2); ctx->mark_dirty({x}); - return {x,x}; + return {x, x}; } - static variable_list backward(AutogradContext *ctsx, variable_list grad_outputs) { - return {grad_outputs[0]*2 + grad_outputs[1]*2}; + static variable_list backward( + AutogradContext* ctsx, + variable_list grad_outputs) { + return {grad_outputs[0] * 2 + grad_outputs[1] * 2}; } }; - auto x = torch::randn({5,5}, torch::requires_grad()); + auto x = torch::randn({5, 5}, torch::requires_grad()); - ASSERT_THROWS_WITH(DoubleInplace::apply(x), "leaf Variable that requires grad"); - // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one output"); + ASSERT_THROWS_WITH( + DoubleInplace::apply(x), "leaf Variable that requires grad"); + // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one + // output"); auto out = DoubleInplace::apply(x.clone()); - ASSERT_TRUE(torch::equal(out[0],out[1])); + ASSERT_TRUE(torch::equal(out[0], out[1])); } TEST(CustomAutogradTest, ReturnDuplicate) { struct DoubleDuplicate : public Function { - static variable_list forward(AutogradContext *ctx, Variable x) { - auto output = x*2; + static variable_list forward(AutogradContext* ctx, Variable x) { + auto output = x * 2; return {output, output}; } - static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) { - return {grad_outputs[0]*2 + grad_outputs[1]*2}; + static variable_list backward( + AutogradContext* ctx, + variable_list grad_outputs) { + return {grad_outputs[0] * 2 + grad_outputs[1] * 2}; } }; - auto x = torch::randn({5,5}, torch::requires_grad()); + auto x = torch::randn({5, 5}, torch::requires_grad()); auto out = DoubleDuplicate::apply(x); - ASSERT_TRUE(torch::equal(out[0],out[1])); + ASSERT_TRUE(torch::equal(out[0], out[1])); } TEST(CustomAutogradTest, SaveEmptyForBackward) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable input) { + static Variable forward(AutogradContext* ctx, Variable input) { ctx->save_for_backward({Variable(), input, Variable()}); - return input*input; + return input * input; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { auto saved = ctx->get_saved_variables(); EXPECT_FALSE(saved[0].defined()); EXPECT_FALSE(saved[2].defined()); @@ -526,27 +564,32 @@ TEST(CustomAutogradTest, SaveEmptyForBackward) { } }; - Variable x = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}, torch::requires_grad()); auto y = MyFunction::apply(x); y.sum().backward(); - ASSERT_VARIABLE_EQ(x.grad(), 2*x); + ASSERT_VARIABLE_EQ(x.grad(), 2 * x); } TEST(CustomAutogradTest, InvalidGradients) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable x) { - return x*2; + static Variable forward(AutogradContext* ctx, Variable x) { + return x * 2; } - static variable_list backward(AutogradContext *ctsx, variable_list grad_outputs) { - return {torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))}; + static variable_list backward( + AutogradContext* ctsx, + variable_list grad_outputs) { + return { + torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))}; } }; - auto input1 = torch::randn({5,5}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input1 = + torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true)); ASSERT_THROWS_WITH( - MyFunction::apply(input1).sum().backward(), "expected shape"); - auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true)); + MyFunction::apply(input1).sum().backward(), "expected shape"); + auto input2 = + torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true)); } TEST(CustomAutogradTest, NoGradInput) { @@ -555,12 +598,14 @@ TEST(CustomAutogradTest, NoGradInput) { return x; } - static variable_list backward(AutogradContext*, variable_list grad_outputs) { + static variable_list backward( + AutogradContext*, + variable_list grad_outputs) { return grad_outputs; } }; - Variable x = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}, torch::requires_grad()); Variable y; { at::NoGradGuard no_grad; @@ -586,13 +631,15 @@ TEST(CustomAutogradTest, TooManyGrads) { TEST(CustomAutogradTest, DepNoGrad) { struct F1 : public Function { - static variable_list forward(AutogradContext *ctx, Variable input) { + static variable_list forward(AutogradContext* ctx, Variable input) { auto out = torch::randn(input.sizes()); ctx->mark_non_differentiable({out}); return {input, out}; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { return {grad_output[0]}; } }; @@ -610,11 +657,11 @@ TEST(CustomAutogradTest, DepNoGrad) { auto x = torch::randn(5, torch::requires_grad()); auto out = F1::apply(x); Variable &a = out[0], &b = out[1]; - b = b+1; // Separate F1 and F2 by another operation + b = b + 1; // Separate F1 and F2 by another operation ASSERT_TRUE(a.requires_grad()); ASSERT_FALSE(b.requires_grad()); - auto c = F2::apply(a,b); + auto c = F2::apply(a, b); c.backward(torch::ones(c.sizes()), false, false); ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes())); } @@ -622,13 +669,13 @@ TEST(CustomAutogradTest, DepNoGrad) { TEST(CustomAutogradTest, Reentrant) { static Variable y_data = torch::randn({2, 2}); struct Reenter : public Function { - static Variable forward(AutogradContext *ctx, Variable input) { + static Variable forward(AutogradContext* ctx, Variable input) { Variable output; { at::AutoGradMode enable_grad(true); auto x = make_variable(input.tensor_data(), true); auto y = make_variable(y_data.tensor_data(), true); - output = x*y; + output = x * y; ctx->saved_data["x"] = x; ctx->saved_data["y"] = y; @@ -637,7 +684,9 @@ TEST(CustomAutogradTest, Reentrant) { return output.detach(); } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { { at::AutoGradMode enable_grad(true); auto out = ctx->saved_data["output_var"].toTensor(); @@ -647,26 +696,27 @@ TEST(CustomAutogradTest, Reentrant) { } }; - auto x = torch::randn({2,2}, torch::requires_grad()); + auto x = torch::randn({2, 2}, torch::requires_grad()); auto out = Reenter::apply(x); out.sum().backward(); ASSERT_VARIABLE_EQ(x.grad(), y_data); } - // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950 TEST(CustomAutogradTest, DeepReentrant) { struct DeepReenter : public Function { - static Variable forward(AutogradContext *ctx, Variable x) { + static Variable forward(AutogradContext* ctx, Variable x) { { at::AutoGradMode enable_grad(true); - ctx->saved_data["x"] = make_variable(x.tensor_data(), true) -1; + ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1; } return ctx->saved_data["x"].toTensor().detach(); } - static variable_list backward(AutogradContext*ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) { return grad_output; } @@ -679,7 +729,8 @@ TEST(CustomAutogradTest, DeepReentrant) { }; // This should not stack overflow - auto v = torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true)); + auto v = + torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true)); DeepReenter::apply(v).sum().backward(); } @@ -698,15 +749,17 @@ TEST(CustomAutogradTest, ReentrantPriority) { }; struct Reenter : public Function { - static Variable forward(AutogradContext *ctx, Variable x) { + static Variable forward(AutogradContext* ctx, Variable x) { { at::AutoGradMode enable_grad(true); - ctx->saved_data["x"] = make_variable(x.tensor_data(), true) -1; + ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1; } return ctx->saved_data["x"].toTensor().detach(); } - static variable_list backward(AutogradContext*ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { order.push_back(1); if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) { return grad_output; @@ -719,12 +772,13 @@ TEST(CustomAutogradTest, ReentrantPriority) { } }; - auto a = MyFunction::apply(torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true))); - auto b = Reenter::apply(torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true))); - auto v = a*b; + auto a = MyFunction::apply( + torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true))); + auto b = Reenter::apply( + torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true))); + auto v = a * b; v.backward(); - // All the reentrant tasks should be prioritized over the MyFunction backward // task. ASSERT_EQ(order.size(), 10); @@ -735,87 +789,81 @@ TEST(CustomAutogradTest, ReentrantPriority) { } TEST(CustomAutogradTest, Hooks) { - Variable x = torch::ones({5,5}, torch::requires_grad()); - Variable y = torch::ones({5,5})*4; + Variable x = torch::ones({5, 5}, torch::requires_grad()); + Variable y = torch::ones({5, 5}) * 4; y.set_requires_grad(true); int counter = 0; - std::function bw_hook([&counter](int inc, Variable grad){ - counter += inc; - }); + std::function bw_hook( + [&counter](int inc, Variable grad) { counter += inc; }); Variable z = x * x + x * 2 + x * y + y; - x.register_hook([&bw_hook](Variable grad){ - bw_hook(0, grad); - }); - auto hook_1 = z.register_hook([&bw_hook](Variable grad){ - bw_hook(1, grad); - }); - z.backward(torch::ones({5,5}), true, true); + x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); }); + auto hook_1 = + z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); }); + z.backward(torch::ones({5, 5}), true, true); ASSERT_EQ(counter, 1); - auto hook_2 = z.register_hook([&bw_hook](Variable grad){ - bw_hook(2, grad); - }); - z.backward(torch::ones({5,5}), true, true); + auto hook_2 = + z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); }); + z.backward(torch::ones({5, 5}), true, true); ASSERT_EQ(counter, 4); z.remove_hook(hook_2); - z.backward(torch::ones({5,5}), true, true); + z.backward(torch::ones({5, 5}), true, true); ASSERT_EQ(counter, 5); - std::function bw_hook_modify([](Variable grad){ - return grad.mul(2); - }); + std::function bw_hook_modify( + [](Variable grad) { return grad.mul(2); }); z.remove_hook(hook_1); z.register_hook(bw_hook_modify); y.grad().zero_(); - z.backward(torch::ones({5,5}), true, false); - ASSERT_VARIABLE_EQ(y.grad(), (x+1)*2); + z.backward(torch::ones({5, 5}), true, false); + ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2); y.register_hook(bw_hook_modify); y.grad().zero_(); - z.backward(torch::ones({5,5}), false, false); - ASSERT_VARIABLE_EQ(y.grad(), (x+1)*4); + z.backward(torch::ones({5, 5}), false, false); + ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4); ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index"); } TEST(CustomAutogradTest, HookNone) { struct NoneGradientFunction : public Function { - static variable_list forward(AutogradContext *ctx, Variable x, Variable y) { - return {x,y}; + static variable_list forward(AutogradContext* ctx, Variable x, Variable y) { + return {x, y}; } - static variable_list backward(AutogradContext *ctx, variable_list grad) { + static variable_list backward(AutogradContext* ctx, variable_list grad) { return {grad[0], Variable()}; } }; bool was_called = false; - auto hook = ([&was_called](Variable grad){ + auto hook = ([&was_called](Variable grad) { ASSERT_TRUE(grad.defined()); was_called = true; }); - auto x = torch::randn({5,5}, torch::requires_grad()); - auto y = torch::randn({5,5}); + auto x = torch::randn({5, 5}, torch::requires_grad()); + auto y = torch::randn({5, 5}); - auto out = NoneGradientFunction::apply(x,y); + auto out = NoneGradientFunction::apply(x, y); Variable rx = x[0], ry = x[1]; rx.register_hook(hook); ry.register_hook(hook); - (rx+ry).sum().backward(); + (rx + ry).sum().backward(); ASSERT_TRUE(was_called); } TEST(CustomAutogradTest, BackwardWithInputs) { - Variable x = torch::randn({5,5}, torch::requires_grad()); - Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}, torch::requires_grad()); + Variable y = torch::randn({5, 5}, torch::requires_grad()); Variable z = x * x + x * y + y * y; Variable x_grad_expected = 2 * x + y; Variable y_grad_expected = x + 2 * y; @@ -827,17 +875,19 @@ TEST(CustomAutogradTest, BackwardWithInputs) { } TEST(CustomAutogradTest, BackwardWithEmptyInputs) { - Variable x = torch::randn({5,5}, torch::requires_grad()); - Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}, torch::requires_grad()); + Variable y = torch::randn({5, 5}, torch::requires_grad()); Variable z = x * x + x * y + y * y; Variable x_grad_expected = 2 * x + y; Variable y_grad_expected = x + 2 * y; - ASSERT_THROWS_WITH(z.backward(torch::ones({5, 5}), false, false, std::vector{}), "cannot be empty"); + ASSERT_THROWS_WITH( + z.backward(torch::ones({5, 5}), false, false, std::vector{}), + "cannot be empty"); } TEST(CustomAutogradTest, BackwardWithNonLeafInputs) { - Variable x = torch::randn({5,5}, torch::requires_grad()); - Variable y = torch::randn({5,5}, torch::requires_grad()); + Variable x = torch::randn({5, 5}, torch::requires_grad()); + Variable y = torch::randn({5, 5}, torch::requires_grad()); Variable z = x * x; Variable w = y * z + x * y + y * y; @@ -854,20 +904,22 @@ TEST(CustomAutogradTest, BackwardWithNonLeafInputs) { TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) { c10::Warning::WarnAlways guard(true); - torch::Tensor x = torch::randn({5,5}).set_requires_grad(true); + torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true); auto z = x * x; { WarningCapture warnings; z.backward(torch::ones({5, 5}), c10::nullopt, true); ASSERT_TRUE( - warnings.str().find("Using backward() with create_graph=True") != std::string::npos); + warnings.str().find("Using backward() with create_graph=True") != + std::string::npos); } { WarningCapture warnings; torch::autograd::backward({z}, {torch::ones({5, 5})}, c10::nullopt, true); ASSERT_TRUE( - warnings.str().find("Using backward() with create_graph=True") != std::string::npos); + warnings.str().find("Using backward() with create_graph=True") != + std::string::npos); } } @@ -877,50 +929,69 @@ TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) { * but when no inputs require grad, we should not create this node * - check_inplace logic * - view ops - * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't test non-NDEBUG builds. + * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't + * test non-NDEBUG builds. * - tensorlist input and output * - multiple outputs / non-tensor output * - rebase_history vs set_history */ namespace { -torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) { +torch::Tensor inplace_op( + const torch::Tensor& self, + const torch::Tensor& other) { return self.add_(other); } -std::tuple two_arg_inplace_op(const torch::Tensor& self, const torch::Tensor& other) { +std::tuple two_arg_inplace_op( + const torch::Tensor& self, + const torch::Tensor& other) { other.add_(self); self.add_(other); return std::tuple(self, other); } -std::tuple two_pairs_of_view_op(const torch::Tensor& self, const torch::Tensor& other) { - // This is not allowed. We test below that this calling into the boxed kernel will raise an error +std::tuple two_pairs_of_view_op( + const torch::Tensor& self, + const torch::Tensor& other) { + // This is not allowed. We test below that this calling into the boxed kernel + // will raise an error return std::tuple(self, other); } -std::tuple non_first_view_op(const torch::Tensor& self, const torch::Tensor& other) { - // This is not allowed. We test below that this calling into the boxed kernel will raise an error +std::tuple non_first_view_op( + const torch::Tensor& self, + const torch::Tensor& other) { + // This is not allowed. We test below that this calling into the boxed kernel + // will raise an error return std::tuple(self.clone(), other); } -int64_t ret_single_non_tensor(const torch::Tensor& self, const torch::Tensor& other) { +int64_t ret_single_non_tensor( + const torch::Tensor& self, + const torch::Tensor& other) { return 12; } -torch::Tensor opt_op(const torch::Tensor& self, const c10::optional& other) { +torch::Tensor opt_op( + const torch::Tensor& self, + const c10::optional& other) { if (other.has_value()) { - return self + other.value(); + return self + other.value(); } else { - return self.clone(); + return self.clone(); } } -torch::Tensor my_custom_op(const torch::Tensor& self, const torch::Tensor& other) { +torch::Tensor my_custom_op( + const torch::Tensor& self, + const torch::Tensor& other) { return self + other; } -std::tuple ret_tuple_non_tensor(const torch::Tensor& self, const torch::Tensor& other) { +std::tuple ret_tuple_non_tensor( + const torch::Tensor& self, + const torch::Tensor& other) { auto a = self - other; auto b = self + other; return std::tuple(a, b, 12); @@ -930,15 +1001,21 @@ torch::Tensor view_op(const torch::Tensor& self) { return self; } -torch::Tensor view_op_with_extra_arg(const torch::Tensor& self, const torch::Tensor& other) { +torch::Tensor view_op_with_extra_arg( + const torch::Tensor& self, + const torch::Tensor& other) { return self; } -std::vector ret_tensor_vector_view(const torch::Tensor& self, const torch::Tensor& other) { +std::vector ret_tensor_vector_view( + const torch::Tensor& self, + const torch::Tensor& other) { return {self, self}; } -std::vector ret_tensor_vector(const torch::Tensor& self, const torch::Tensor& other) { +std::vector ret_tensor_vector( + const torch::Tensor& self, + const torch::Tensor& other) { std::vector out; out.push_back(self + other); out.push_back(self - other); @@ -948,20 +1025,24 @@ std::vector ret_tensor_vector(const torch::Tensor& self, const torch torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) { const auto& res = self.clone(); for (const auto& t : other) { - res.add_(t); + res.add_(t); } return res; } -#define REGISTER_TEST_OP(name, schema, fn) \ - auto m = MAKE_TORCH_LIBRARY(_test); \ - m.def(schema); \ - auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \ - auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \ - auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \ - m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \ - m_autograd.impl(name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \ - m_inplaceorview.impl(name, c10::DispatchKey::ADInplaceOrView, autogradNotImplementedInplaceOrViewFallback()); +#define REGISTER_TEST_OP(name, schema, fn) \ + auto m = MAKE_TORCH_LIBRARY(_test); \ + m.def(schema); \ + auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \ + auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \ + auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \ + m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \ + m_autograd.impl( \ + name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \ + m_inplaceorview.impl( \ + name, \ + c10::DispatchKey::ADInplaceOrView, \ + autogradNotImplementedInplaceOrViewFallback()); template void assertBasicChecks(F op) { @@ -975,7 +1056,9 @@ void assertBasicChecks(F op) { // # Should not have grad_fn if none require grad auto out2 = op(b, c); - ASSERT_THROWS_WITH(out2.backward(), "element 0 of tensors does not require grad and does not have a grad_fn"); + ASSERT_THROWS_WITH( + out2.backward(), + "element 0 of tensors does not require grad and does not have a grad_fn"); // TODO: Forward AD Tests? } @@ -983,15 +1066,21 @@ void assertBasicChecks(F op) { } // namespace // These tests trigger an MSVC bug in the internal arvr build -// Reproduce with: buck build @arvr/mode/win/opt //xplat/caffe2:autograd_libtorch_test_ovrsource -// It is probably caused by the lambda, see https://github.com/pytorch/pytorch/issues/48763 +// Reproduce with: buck build @arvr/mode/win/opt +// //xplat/caffe2:autograd_libtorch_test_ovrsource It is probably caused by the +// lambda, see https://github.com/pytorch/pytorch/issues/48763 #if !defined(_MSC_VER) TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) { - REGISTER_TEST_OP("ret_single_non_tensor", "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", ret_single_non_tensor); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_single_non_tensor", ""); + REGISTER_TEST_OP( + "ret_single_non_tensor", + "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", + ret_single_non_tensor); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::ret_single_non_tensor", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed(opHandle, _1, _2); + return callOpUnboxed( + opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); @@ -1001,18 +1090,26 @@ TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) { } TEST(TestAutogradNotImplementedFallback, InplaceOp) { - REGISTER_TEST_OP("inplace_op", "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", inplace_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", ""); + REGISTER_TEST_OP( + "inplace_op", + "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", + inplace_op); + auto opHandle = + c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed(opHandle, _1, _2); + return callOpUnboxed< + torch::Tensor, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); auto b = torch::tensor({1.}, {torch::kFloat32}); // Check in-place - ASSERT_THROWS_WITH(op(a, b), - "a leaf Variable that requires grad is being used in an in-place operation"); + ASSERT_THROWS_WITH( + op(a, b), + "a leaf Variable that requires grad is being used in an in-place operation"); op(b, a); a = a.clone(); b = b.clone(); @@ -1020,7 +1117,8 @@ TEST(TestAutogradNotImplementedFallback, InplaceOp) { ASSERT_TRUE(torch::allclose(c, inplace_op(a, b))); // Test in-place on view - auto base = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); + auto base = + torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); auto view = base.view(-1); auto t = torch::tensor({1.}, {torch::kFloat32}); @@ -1033,26 +1131,38 @@ TEST(TestAutogradNotImplementedFallback, InplaceOp) { ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode"); ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl()); - ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward")); + ASSERT_THAT( + op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward")); } TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) { - REGISTER_TEST_OP("two_arg_inplace_op", "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", two_arg_inplace_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_arg_inplace_op", ""); + REGISTER_TEST_OP( + "two_arg_inplace_op", + "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", + two_arg_inplace_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::two_arg_inplace_op", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + return callOpUnboxed< + std::tuple, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); auto b = torch::tensor({1.}, {torch::kFloat32}); // Both are modified in-place! - ASSERT_THROWS_WITH(op(a, b), - "a leaf Variable that requires grad is being used in an in-place operation"); - ASSERT_THROWS_WITH(op(b, a), - "a leaf Variable that requires grad is being used in an in-place operation"); + ASSERT_THROWS_WITH( + op(a, b), + "a leaf Variable that requires grad is being used in an in-place operation"); + ASSERT_THROWS_WITH( + op(b, a), + "a leaf Variable that requires grad is being used in an in-place operation"); - auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); - auto d = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); + auto c = + torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); + auto d = + torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); auto saved_version_c = c._version(); auto saved_version_d = d._version(); @@ -1062,10 +1172,16 @@ TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) { } TEST(TestAutogradNotImplementedFallback, OptOp) { - REGISTER_TEST_OP("opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", ""); - auto op = [&](const torch::Tensor& _1, const c10::optional& _2) { - return callOpUnboxed&>(opHandle, _1, _2); + REGISTER_TEST_OP( + "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op); + auto opHandle = + c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", ""); + auto op = [&](const torch::Tensor& _1, + const c10::optional& _2) { + return callOpUnboxed< + torch::Tensor, + const torch::Tensor&, + const c10::optional&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); @@ -1076,23 +1192,37 @@ TEST(TestAutogradNotImplementedFallback, OptOp) { } TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) { - REGISTER_TEST_OP("my_custom_op", "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", my_custom_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", ""); + REGISTER_TEST_OP( + "my_custom_op", + "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", + my_custom_op); + auto opHandle = + c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed(opHandle, _1, _2); + return callOpUnboxed< + torch::Tensor, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; assertBasicChecks(op); } TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) { - REGISTER_TEST_OP("ret_tuple_non_tensor", "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", ret_tuple_non_tensor); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tuple_non_tensor", ""); + REGISTER_TEST_OP( + "ret_tuple_non_tensor", + "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", + ret_tuple_non_tensor); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::ret_tuple_non_tensor", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { torch::Tensor out0; torch::Tensor out1; int64_t out2; - auto out = callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + auto out = callOpUnboxed< + std::tuple, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); std::tie(out0, out1, out2) = std::move(out); return out0; }; @@ -1101,8 +1231,10 @@ TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) { } TEST(TestAutogradNotImplementedFallback, ViewOp) { - REGISTER_TEST_OP("view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", ""); + REGISTER_TEST_OP( + "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op); + auto opHandle = + c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", ""); auto op = [&](const torch::Tensor& _1) { return callOpUnboxed(opHandle, _1); }; @@ -1111,7 +1243,8 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) { ASSERT_TRUE(v.is_view()); ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl()); - auto b1 = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); + auto b1 = + torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); auto v1 = op(b1); ASSERT_TRUE(v1.is_view()); ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl()); @@ -1120,17 +1253,27 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) { auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); // raise on rebase_history when it refreshes grad_fn - ASSERT_THROWS_WITH(v1.add_(t), "which does not have a derivative implemented is forbidden"); + ASSERT_THROWS_WITH( + v1.add_(t), "which does not have a derivative implemented is forbidden"); // base should not be aware of the views, so this is still okay b1.add_(t); - ASSERT_THROWS_WITH(v1.grad_fn(), "which does not have a derivative implemented is forbidden"); + ASSERT_THROWS_WITH( + v1.grad_fn(), + "which does not have a derivative implemented is forbidden"); } TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) { - REGISTER_TEST_OP("view_op_with_extra_arg", "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)", view_op_with_extra_arg); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op_with_extra_arg", ""); + REGISTER_TEST_OP( + "view_op_with_extra_arg", + "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)", + view_op_with_extra_arg); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::view_op_with_extra_arg", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed(opHandle, _1, _2); + return callOpUnboxed< + torch::Tensor, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; assertBasicChecks(op); auto a = torch::tensor({1.}, {torch::kFloat32}); @@ -1141,10 +1284,17 @@ TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) { } TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) { - REGISTER_TEST_OP("ret_tensor_vector_view", "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)", ret_tensor_vector_view); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector_view", ""); + REGISTER_TEST_OP( + "ret_tensor_vector_view", + "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)", + ret_tensor_vector_view); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::ret_tensor_vector_view", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + return callOpUnboxed< + std::vector, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}); auto b = torch::tensor({1.}, {torch::kFloat32}); @@ -1156,43 +1306,70 @@ TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) { } TEST(TestAutogradNotImplementedFallback, DoubleViewOP) { - REGISTER_TEST_OP("two_pairs_of_view_op", "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", two_pairs_of_view_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_pairs_of_view_op", ""); + REGISTER_TEST_OP( + "two_pairs_of_view_op", + "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", + two_pairs_of_view_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::two_pairs_of_view_op", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + return callOpUnboxed< + std::tuple, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); auto b = torch::tensor({1.}, {torch::kFloat32}); - ASSERT_THROWS_WITH(op(a, b), - "Expected only a single output in the operator schema to have a non-write alias annotation"); + ASSERT_THROWS_WITH( + op(a, b), + "Expected only a single output in the operator schema to have a non-write alias annotation"); } TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) { - REGISTER_TEST_OP("non_first_view_op", "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))", non_first_view_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::non_first_view_op", ""); + REGISTER_TEST_OP( + "non_first_view_op", + "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))", + non_first_view_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::non_first_view_op", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + return callOpUnboxed< + std::tuple, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); auto b = torch::tensor({1.}, {torch::kFloat32}); - ASSERT_THROWS_WITH(op(a, b), - "can only create view relationships between the first"); + ASSERT_THROWS_WITH( + op(a, b), "can only create view relationships between the first"); } TEST(TestAutogradNotImplementedFallback, RetTensorVector) { - REGISTER_TEST_OP("ret_tensor_vector", "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", ret_tensor_vector); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector", ""); + REGISTER_TEST_OP( + "ret_tensor_vector", + "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", + ret_tensor_vector); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::ret_tensor_vector", ""); auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { - return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2)[0]; + return callOpUnboxed< + std::vector, + const torch::Tensor&, + const torch::Tensor&>(opHandle, _1, _2)[0]; }; assertBasicChecks(op); } TEST(TestAutogradNotImplementedFallback, TensorlistOp) { - REGISTER_TEST_OP("tensorlist_op", "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", tensorlist_op); - auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::tensorlist_op", ""); + REGISTER_TEST_OP( + "tensorlist_op", + "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", + tensorlist_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow( + "_test::tensorlist_op", ""); auto op = [&](torch::Tensor _1, at::TensorList _2) { - return callOpUnboxed(opHandle, _1, _2); + return callOpUnboxed( + opHandle, _1, _2); }; auto a = torch::tensor({1.}, {torch::kFloat32}); @@ -1201,15 +1378,17 @@ TEST(TestAutogradNotImplementedFallback, TensorlistOp) { std::vector vec = {b, c}; auto out = op(a, vec); - ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[0]}), "One of the differentiated Tensors does not require grad"); - ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[1]}), "is not implemented"); + ASSERT_THROWS_WITH( + torch::autograd::grad({out}, {vec[0]}), + "One of the differentiated Tensors does not require grad"); + ASSERT_THROWS_WITH( + torch::autograd::grad({out}, {vec[1]}), "is not implemented"); ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec))); } #endif - // TODO add these tests if needed // test_once_differentiable // test_sparse_backward diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp index db66245b11f344..5dfb3b0936c2f4 100644 --- a/test/cpp/api/dataloader.cpp +++ b/test/cpp/api/dataloader.cpp @@ -63,8 +63,7 @@ TEST(DataTest, TransformCallsGetApplyCorrectly) { // dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk // contains 10, 5, 20 examples respectively. -struct DummyChunkDataReader - : public datasets::ChunkDataReader { +struct DummyChunkDataReader : public datasets::ChunkDataReader { public: using BatchType = datasets::ChunkDataReader::ChunkType; using DataType = datasets::ChunkDataReader::ExampleType; @@ -621,8 +620,9 @@ struct UnCopyableDataset : public datasets::Dataset { ~UnCopyableDataset() = default; Example<> get(size_t index) override { - return {torch::tensor({static_cast(index)}), - torch::tensor({static_cast(index)})}; + return { + torch::tensor({static_cast(index)}), + torch::tensor({static_cast(index)})}; } torch::optional size() const override { @@ -1598,8 +1598,8 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) { counter += batch_size; std::vector> batch( /*count=*/batch_size, - Example<>{torch::ones(batch_size + 1), - torch::zeros(batch_size - 1)}); + Example<>{ + torch::ones(batch_size + 1), torch::zeros(batch_size - 1)}); return batch; } return torch::nullopt; @@ -1729,8 +1729,7 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) { datasets::ChunkDatasetOptions(prefetch_count, batch_size)); auto data_loader = torch::data::make_data_loader( - dataset, - DataLoaderOptions(requested_batch_size).workers(0)); + dataset, DataLoaderOptions(requested_batch_size).workers(0)); std::string exception_msg = "The requested batch size does not match with the initialized batch " @@ -1741,8 +1740,7 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) { } TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) { - struct DummyEmptyChunkDataReader - : datasets::ChunkDataReader { + struct DummyEmptyChunkDataReader : datasets::ChunkDataReader { public: using BatchType = datasets::ChunkDataReader::ChunkType; @@ -1859,7 +1857,9 @@ TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) { samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler(); auto data_loader = torch::data::make_data_loader( - dataset.map(transforms::BatchLambda( + dataset.map(transforms::BatchLambda< + DummyChunkDataReader::BatchType, + DummyChunkDataReader::DataType>( [](DummyChunkDataReader::BatchType batch) { return std::accumulate(batch.begin(), batch.end(), 0); })), @@ -1901,7 +1901,9 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) { prefetch_count, batch_size, cache_size)); auto data_loader = torch::data::make_data_loader( - dataset.map(transforms::BatchLambda( + dataset.map(transforms::BatchLambda< + DummyChunkDataReader::BatchType, + DummyChunkDataReader::DataType>( [](DummyChunkDataReader::BatchType batch) { return std::accumulate(batch.begin(), batch.end(), 0); })), @@ -2220,7 +2222,8 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) { std::vector expected_result; { // construct expected result - for (const auto i : c10::irange((chunk_count + cross_chunk_shuffle_count - 1) / + for (const auto i : c10::irange( + (chunk_count + cross_chunk_shuffle_count - 1) / cross_chunk_shuffle_count)) { for (const auto j : c10::irange(chunk_size)) { (void)j; // Suppress unused variable warning @@ -2269,8 +2272,7 @@ TEST(DataLoaderTest, CustomPreprocessPolicy) { auto sorting_policy = [](std::vector& raw_batch_data) { std::sort(raw_batch_data.begin(), raw_batch_data.end()); }; - std::function&)> policy_function = - sorting_policy; + std::function&)> policy_function = sorting_policy; const size_t prefetch_count = 1; const size_t cache_size = 10; diff --git a/test/cpp/api/dispatch.cpp b/test/cpp/api/dispatch.cpp index ba5300659b39e0..f59a133299228a 100644 --- a/test/cpp/api/dispatch.cpp +++ b/test/cpp/api/dispatch.cpp @@ -1,22 +1,21 @@ #include -#include #include #include +#include +#include #include #include -#include +#include #include -#include #include -#include - +#include struct DispatchTest : torch::test::SeedingFixture {}; TEST_F(DispatchTest, TestAVX2) { - const std::vector ints {1, 2, 3, 4}; - const std::vector result {1, 4, 27, 256}; + const std::vector ints{1, 2, 3, 4}; + const std::vector result{1, 4, 27, 256}; const auto vals_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints); #ifdef _WIN32 @@ -31,8 +30,8 @@ TEST_F(DispatchTest, TestAVX2) { } TEST_F(DispatchTest, TestAVX512) { - const std::vector ints {1, 2, 3, 4}; - const std::vector result {1, 4, 27, 256}; + const std::vector ints{1, 2, 3, 4}; + const std::vector result{1, 4, 27, 256}; const auto vals_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints); #ifdef _WIN32 @@ -47,8 +46,8 @@ TEST_F(DispatchTest, TestAVX512) { } TEST_F(DispatchTest, TestDefault) { - const std::vector ints {1, 2, 3, 4}; - const std::vector result {1, 4, 27, 256}; + const std::vector ints{1, 2, 3, 4}; + const std::vector result{1, 4, 27, 256}; const auto vals_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints); #ifdef _WIN32 diff --git a/test/cpp/api/enum.cpp b/test/cpp/api/enum.cpp index bcf781bb1a8776..aadfae81163e35 100644 --- a/test/cpp/api/enum.cpp +++ b/test/cpp/api/enum.cpp @@ -4,51 +4,51 @@ #include -#define TORCH_ENUM_PRETTY_PRINT_TEST(name) \ -{ \ - v = torch::k##name; \ - std::string pretty_print_name("k"); \ - pretty_print_name.append(#name); \ - ASSERT_EQ(torch::enumtype::get_enum_name(v), pretty_print_name); \ -} +#define TORCH_ENUM_PRETTY_PRINT_TEST(name) \ + { \ + v = torch::k##name; \ + std::string pretty_print_name("k"); \ + pretty_print_name.append(#name); \ + ASSERT_EQ(torch::enumtype::get_enum_name(v), pretty_print_name); \ + } TEST(EnumTest, AllEnums) { c10::variant< - torch::enumtype::kLinear, - torch::enumtype::kConv1D, - torch::enumtype::kConv2D, - torch::enumtype::kConv3D, - torch::enumtype::kConvTranspose1D, - torch::enumtype::kConvTranspose2D, - torch::enumtype::kConvTranspose3D, - torch::enumtype::kSigmoid, - torch::enumtype::kTanh, - torch::enumtype::kReLU, - torch::enumtype::kLeakyReLU, - torch::enumtype::kFanIn, - torch::enumtype::kFanOut, - torch::enumtype::kConstant, - torch::enumtype::kReflect, - torch::enumtype::kReplicate, - torch::enumtype::kCircular, - torch::enumtype::kNearest, - torch::enumtype::kBilinear, - torch::enumtype::kBicubic, - torch::enumtype::kTrilinear, - torch::enumtype::kArea, - torch::enumtype::kSum, - torch::enumtype::kMean, - torch::enumtype::kMax, - torch::enumtype::kNone, - torch::enumtype::kBatchMean, - torch::enumtype::kZeros, - torch::enumtype::kBorder, - torch::enumtype::kReflection, - torch::enumtype::kRNN_TANH, - torch::enumtype::kRNN_RELU, - torch::enumtype::kLSTM, - torch::enumtype::kGRU - > v; + torch::enumtype::kLinear, + torch::enumtype::kConv1D, + torch::enumtype::kConv2D, + torch::enumtype::kConv3D, + torch::enumtype::kConvTranspose1D, + torch::enumtype::kConvTranspose2D, + torch::enumtype::kConvTranspose3D, + torch::enumtype::kSigmoid, + torch::enumtype::kTanh, + torch::enumtype::kReLU, + torch::enumtype::kLeakyReLU, + torch::enumtype::kFanIn, + torch::enumtype::kFanOut, + torch::enumtype::kConstant, + torch::enumtype::kReflect, + torch::enumtype::kReplicate, + torch::enumtype::kCircular, + torch::enumtype::kNearest, + torch::enumtype::kBilinear, + torch::enumtype::kBicubic, + torch::enumtype::kTrilinear, + torch::enumtype::kArea, + torch::enumtype::kSum, + torch::enumtype::kMean, + torch::enumtype::kMax, + torch::enumtype::kNone, + torch::enumtype::kBatchMean, + torch::enumtype::kZeros, + torch::enumtype::kBorder, + torch::enumtype::kReflection, + torch::enumtype::kRNN_TANH, + torch::enumtype::kRNN_RELU, + torch::enumtype::kLSTM, + torch::enumtype::kGRU> + v; TORCH_ENUM_PRETTY_PRINT_TEST(Linear) TORCH_ENUM_PRETTY_PRINT_TEST(Conv1D) diff --git a/test/cpp/api/fft.cpp b/test/cpp/api/fft.cpp index 5b6452d0a853fd..706ea44fe3def2 100644 --- a/test/cpp/api/fft.cpp +++ b/test/cpp/api/fft.cpp @@ -1,18 +1,18 @@ #include #include -#include #include - +#include // Naive DFT of a 1 dimensional tensor -torch::Tensor naive_dft(torch::Tensor x, bool forward=true) { +torch::Tensor naive_dft(torch::Tensor x, bool forward = true) { TORCH_INTERNAL_ASSERT(x.dim() == 1); x = x.contiguous(); auto out_tensor = torch::zeros_like(x); const int64_t len = x.size(0); - // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N), reversed for inverse transform + // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N), reversed for inverse + // transform std::vector> roots(len); const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len; for (const auto i : c10::irange(len)) { diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index d2537936592573..1090aeafd1b46f 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -12,14 +12,17 @@ using namespace torch::nn; struct FunctionalTest : torch::test::SeedingFixture {}; TEST_F(FunctionalTest, Conv1d) { - auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 5}); - auto weight = torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 3}); + auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 5}); + auto weight = + torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 3}); auto y = F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1)); - auto expected = torch::tensor({{{ 312., 348., 384.}, - { 798., 915., 1032.}}, + auto expected = torch::tensor( + {{{312., 348., 384.}, {798., 915., 1032.}}, - {{ 852., 888., 924.}, - {2553., 2670., 2787.}}}, torch::kFloat); + {{852., 888., 924.}, {2553., 2670., 2787.}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv1d(x, weight); @@ -27,16 +30,21 @@ TEST_F(FunctionalTest, Conv1d) { } TEST_F(FunctionalTest, Conv2dEven) { - auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 5}); - auto weight = torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 3, 3}); + auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 5}); + auto weight = + torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 3, 3}); auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{15219., 15570., 15921.}, - {16974., 17325., 17676.}, - {18729., 19080., 19431.}}, + auto expected = torch::tensor( + {{{{15219., 15570., 15921.}, + {16974., 17325., 17676.}, + {18729., 19080., 19431.}}, - {{37818., 38898., 39978.}, - {43218., 44298., 45378.}, - {48618., 49698., 50778.}}}}, torch::kFloat); + {{37818., 38898., 39978.}, + {43218., 44298., 45378.}, + {48618., 49698., 50778.}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv2d(x, weight); @@ -44,16 +52,19 @@ TEST_F(FunctionalTest, Conv2dEven) { } TEST_F(FunctionalTest, Conv2dUneven) { - auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 4}); - auto weight = torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 3, 2}); + auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 4}); + auto weight = + torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 3, 2}); auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{ 5289., 5442., 5595.}, - { 5901., 6054., 6207.}, - { 6513., 6666., 6819.}}, + auto expected = torch::tensor( + {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}}, - {{13227., 13704., 14181.}, - {15135., 15612., 16089.}, - {17043., 17520., 17997.}}}}, torch::kFloat); + {{13227., 13704., 14181.}, + {15135., 15612., 16089.}, + {17043., 17520., 17997.}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv2d(x, weight); @@ -61,33 +72,37 @@ TEST_F(FunctionalTest, Conv2dUneven) { } TEST_F(FunctionalTest, Conv3d) { - auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 5, 5}); - auto weight = torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 3, 3, 3}); + auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 5, 5}); + auto weight = + torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 3, 3, 3}); auto y = F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{{ 700704., 703944., 707184.}, - { 716904., 720144., 723384.}, - { 733104., 736344., 739584.}}, - - {{ 781704., 784944., 788184.}, - { 797904., 801144., 804384.}, - { 814104., 817344., 820584.}}, + auto expected = torch::tensor( + {{{{{700704., 703944., 707184.}, + {716904., 720144., 723384.}, + {733104., 736344., 739584.}}, - {{ 862704., 865944., 869184.}, - { 878904., 882144., 885384.}, - { 895104., 898344., 901584.}}}, + {{781704., 784944., 788184.}, + {797904., 801144., 804384.}, + {814104., 817344., 820584.}}, + {{862704., 865944., 869184.}, + {878904., 882144., 885384.}, + {895104., 898344., 901584.}}}, - {{{1724220., 1734021., 1743822.}, - {1773225., 1783026., 1792827.}, - {1822230., 1832031., 1841832.}}, + {{{1724220., 1734021., 1743822.}, + {1773225., 1783026., 1792827.}, + {1822230., 1832031., 1841832.}}, - {{1969245., 1979046., 1988847.}, - {2018250., 2028051., 2037852.}, - {2067255., 2077056., 2086857.}}, + {{1969245., 1979046., 1988847.}, + {2018250., 2028051., 2037852.}, + {2067255., 2077056., 2086857.}}, - {{2214270., 2224071., 2233872.}, - {2263275., 2273076., 2282877.}, - {2312280., 2322081., 2331882.}}}}}, torch::kFloat); + {{2214270., 2224071., 2233872.}, + {2263275., 2273076., 2282877.}, + {2312280., 2322081., 2331882.}}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv3d(x, weight); @@ -113,7 +128,8 @@ TEST_F(FunctionalTest, MaxPool2d) { } TEST_F(FunctionalTest, MaxPool2dBackward) { - auto input = torch::rand({1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::rand( + {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2)); auto s = output.sum(); s.backward(); @@ -158,65 +174,61 @@ TEST_F(FunctionalTest, AvgPool3d) { TEST_F(FunctionalTest, FractionalMaxPool2d) { auto x = torch::ones({2, 5, 5}); - auto y = F::fractional_max_pool2d(x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); + auto y = F::fractional_max_pool2d( + x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); ASSERT_EQ(y.sizes(), std::vector({2, 2, 2})); - auto y_with_indices = F::fractional_max_pool2d_with_indices(x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); + auto y_with_indices = F::fractional_max_pool2d_with_indices( + x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices))); ASSERT_TRUE(torch::allclose( - std::get<1>(y_with_indices), - torch::tensor({{{ 0, 2}, - {10, 12}}, - {{ 0, 2}, - {10, 12}}}))); - ASSERT_EQ(std::get<1>(y_with_indices).sizes(), std::vector({2, 2, 2})); + std::get<1>(y_with_indices), + torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}))); + ASSERT_EQ( + std::get<1>(y_with_indices).sizes(), std::vector({2, 2, 2})); auto x1 = torch::ones({2, 2, 5, 5}); - auto y1 = F::fractional_max_pool2d(x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); + auto y1 = F::fractional_max_pool2d( + x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); ASSERT_EQ(y1.ndimension(), 4); ASSERT_TRUE(torch::allclose(y1, torch::ones({2, 2, 2, 2}))); ASSERT_EQ(y1.sizes(), std::vector({2, 2, 2, 2})); - auto y1_with_indices = F::fractional_max_pool2d_with_indices(x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); + auto y1_with_indices = F::fractional_max_pool2d_with_indices( + x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); ASSERT_TRUE(torch::equal(y1, std::get<0>(y1_with_indices))); ASSERT_TRUE(torch::allclose( - std::get<1>(y1_with_indices), - torch::tensor({{{{ 0, 2}, - {10, 12}}, - {{ 0, 2}, - {10, 12}}}, - {{{ 0, 2}, - {10, 12}}, - {{ 0, 2}, - {10, 12}}}}))); - ASSERT_EQ(std::get<1>(y1_with_indices).sizes(), std::vector({2, 2, 2, 2})); + std::get<1>(y1_with_indices), + torch::tensor( + {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, + {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}}))); + ASSERT_EQ( + std::get<1>(y1_with_indices).sizes(), std::vector({2, 2, 2, 2})); } TEST_F(FunctionalTest, FractionalMaxPool3d) { auto x = torch::ones({2, 5, 5, 5}); - auto y = F::fractional_max_pool3d(x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); + auto y = F::fractional_max_pool3d( + x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); ASSERT_EQ(y.ndimension(), 4); ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); ASSERT_EQ(y.sizes(), std::vector({2, 2, 2, 2})); - auto y_with_indices = F::fractional_max_pool3d_with_indices(x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); + auto y_with_indices = F::fractional_max_pool3d_with_indices( + x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices))); ASSERT_TRUE(torch::allclose( - std::get<1>(y_with_indices), - torch::tensor({{{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}, - {{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}}))); - ASSERT_EQ(std::get<1>(y_with_indices).sizes(), std::vector({2, 2, 2, 2})); + std::get<1>(y_with_indices), + torch::tensor( + {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}, + {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}}))); + ASSERT_EQ( + std::get<1>(y_with_indices).sizes(), std::vector({2, 2, 2, 2})); } TEST_F(FunctionalTest, LPPool1d) { @@ -225,8 +237,12 @@ TEST_F(FunctionalTest, LPPool1d) { int kernel_size = 3; auto x = torch::ones({1, 1, 5}); - auto y = F::lp_pool1d(x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride)); - auto expected = (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * kernel_size).pow(1. / norm_type); + auto y = F::lp_pool1d( + x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride)); + auto expected = + (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * + kernel_size) + .pow(1. / norm_type); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose(y, expected)); @@ -239,8 +255,12 @@ TEST_F(FunctionalTest, LPPool2d) { std::vector kernel_size({2, 3}); auto x = torch::ones({1, 2, 5}); - auto y = F::lp_pool2d(x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride)); - auto expected = (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * (kernel_size[0] * kernel_size[1])).pow(1. / norm_type); + auto y = F::lp_pool2d( + x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride)); + auto expected = + (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * + (kernel_size[0] * kernel_size[1])) + .pow(1. / norm_type); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose(y, expected)); @@ -250,17 +270,17 @@ TEST_F(FunctionalTest, LPPool2d) { TEST_F(FunctionalTest, CosineSimilarity) { auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); - auto output = - F::cosine_similarity(input1, input2, F::CosineSimilarityFuncOptions().dim(1)); + auto output = F::cosine_similarity( + input1, input2, F::CosineSimilarityFuncOptions().dim(1)); auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); } TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) { - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); - auto output = - F::smooth_l1_loss(input, target); + auto output = F::smooth_l1_loss(input, target); auto expected = torch::tensor(0.0233335, torch::kFloat); auto s = output.sum(); s.backward(); @@ -269,11 +289,13 @@ TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) { } TEST_F(FunctionalTest, SmoothL1LossBeta) { - auto input = torch::tensor({0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment) - F::smooth_l1_loss(input, target, /*reduction=*/torch::kMean, /*beta=*/0.5); + F::smooth_l1_loss( + input, target, /*reduction=*/torch::kMean, /*beta=*/0.5); auto expected = torch::tensor(1.67, torch::kFloat); auto s = output.sum(); s.backward(); @@ -282,7 +304,8 @@ TEST_F(FunctionalTest, SmoothL1LossBeta) { } TEST_F(FunctionalTest, SmoothL1LossNoReduction) { - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = // NOLINTNEXTLINE(bugprone-argument-comment) @@ -295,10 +318,10 @@ TEST_F(FunctionalTest, SmoothL1LossNoReduction) { } TEST_F(FunctionalTest, HuberLossDefaultOptions) { - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); - auto output = - F::huber_loss(input, target); + auto output = F::huber_loss(input, target); auto expected = torch::tensor(0.0233335, torch::kFloat); auto s = output.sum(); s.backward(); @@ -307,7 +330,8 @@ TEST_F(FunctionalTest, HuberLossDefaultOptions) { } TEST_F(FunctionalTest, HuberLossDelta) { - auto input = torch::tensor({0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5); auto output = F::huber_loss(input, target, options); @@ -319,7 +343,8 @@ TEST_F(FunctionalTest, HuberLossDelta) { } TEST_F(FunctionalTest, HuberLossNoReduction) { - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto options = F::HuberLossFuncOptions().reduction(torch::kNone); auto output = F::huber_loss(input, target, options); @@ -331,10 +356,10 @@ TEST_F(FunctionalTest, HuberLossNoReduction) { } TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) { - auto input = torch::tensor({2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); - auto output = - F::soft_margin_loss(input, target); + auto output = F::soft_margin_loss(input, target); auto expected = torch::tensor({1.3767317}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -344,10 +369,12 @@ TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) { } TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) { - auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); - auto output = - F::multilabel_soft_margin_loss(input, target); + auto input = torch::tensor( + {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, + torch::dtype(torch::kFloat).requires_grad(true)); + auto target = + torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); + auto output = F::multilabel_soft_margin_loss(input, target); auto expected = torch::tensor({0.7608436}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -357,11 +384,12 @@ TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) { } TEST_F(FunctionalTest, SoftMarginLossNoReduction) { - auto input = torch::tensor({2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); - auto output = - F::soft_margin_loss(input, target, torch::kNone); - auto expected = torch::tensor({2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat); + auto output = F::soft_margin_loss(input, target, torch::kNone); + auto expected = torch::tensor( + {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -370,12 +398,16 @@ TEST_F(FunctionalTest, SoftMarginLossNoReduction) { } TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) { - auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); + auto input = torch::tensor( + {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, + torch::dtype(torch::kFloat).requires_grad(true)); + auto target = + torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat); - auto options = F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight); - auto output = - F::multilabel_soft_margin_loss(input, target, options); + auto options = F::MultilabelSoftMarginLossFuncOptions() + .reduction(torch::kNone) + .weight(weight); + auto output = F::multilabel_soft_margin_loss(input, target, options); auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -387,8 +419,8 @@ TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) { TEST_F(FunctionalTest, PairwiseDistance) { auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); - auto output = - F::pairwise_distance(input1, input2, F::PairwiseDistanceFuncOptions().p(1)); + auto output = F::pairwise_distance( + input1, input2, F::PairwiseDistanceFuncOptions().p(1)); auto expected = torch::tensor({6, 6}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } @@ -463,8 +495,8 @@ TEST_F(FunctionalTest, AdaptiveAvgPool3d) { } TEST_F(FunctionalTest, L1Loss) { - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = F::l1_loss(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -474,8 +506,8 @@ TEST_F(FunctionalTest, L1Loss) { } TEST_F(FunctionalTest, MSELoss) { - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = F::mse_loss(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -485,8 +517,8 @@ TEST_F(FunctionalTest, MSELoss) { } TEST_F(FunctionalTest, BCELoss) { - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = F::binary_cross_entropy(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -497,8 +529,8 @@ TEST_F(FunctionalTest, BCELoss) { TEST_F(FunctionalTest, KLDivLoss) { KLDivLoss loss; - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = F::kl_div(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -518,20 +550,22 @@ TEST_F(FunctionalTest, HingeEmbeddingLoss) { } TEST_F(FunctionalTest, GridSample) { - auto input = torch::arange(9, torch::kFloat).view(std::vector({1, 1, 3, 3})); - auto grid = torch::tensor({{ - {{-2., -1.}, {-1., -1.}, {0., -1.}}, - {{-1., 0.}, {0., 0.}, {1., 0.}}, - {{0., 1.}, {1., 1.}, {2., 1.}} - }}, torch::kFloat); + auto input = + torch::arange(9, torch::kFloat).view(std::vector({1, 1, 3, 3})); + auto grid = torch::tensor( + {{{{-2., -1.}, {-1., -1.}, {0., -1.}}, + {{-1., 0.}, {0., 0.}, {1., 0.}}, + {{0., 1.}, {1., 1.}, {2., 1.}}}}, + torch::kFloat); // bilinear, zeros, true auto options = F::GridSampleFuncOptions() - .mode(torch::kBilinear) - .padding_mode(torch::kZeros) - .align_corners(true); + .mode(torch::kBilinear) + .padding_mode(torch::kZeros) + .align_corners(true); auto output = F::grid_sample(input, grid, options); - auto expected = torch::tensor({{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); @@ -541,7 +575,8 @@ TEST_F(FunctionalTest, GridSample) { .padding_mode(torch::kZeros) .align_corners(false); output = F::grid_sample(input, grid, options); - expected = torch::tensor({{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat); + expected = torch::tensor( + {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); @@ -556,7 +591,8 @@ TEST_F(FunctionalTest, GridSample) { .padding_mode(torch::kZeros) .align_corners(true); output = F::grid_sample(input, grid, options); - expected = torch::tensor({{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat); + expected = torch::tensor( + {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); @@ -566,7 +602,8 @@ TEST_F(FunctionalTest, GridSample) { .padding_mode(torch::kBorder) .align_corners(true); output = F::grid_sample(input, grid, options); - expected = torch::tensor({{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat); + expected = torch::tensor( + {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); @@ -576,7 +613,8 @@ TEST_F(FunctionalTest, GridSample) { .padding_mode(torch::kReflection) .align_corners(true); output = F::grid_sample(input, grid, options); - expected = torch::tensor({{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat); + expected = torch::tensor( + {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } @@ -584,8 +622,7 @@ TEST_F(FunctionalTest, GridSample) { TEST_F(FunctionalTest, AffineGrid) { { // 2D affine. - auto theta = torch::arange(1., 13) - .view(std::vector({2, 2, 3})); + auto theta = torch::arange(1., 13).view(std::vector({2, 2, 3})); auto size = std::vector({2, 3, 2, 2}); auto align_corners = true; auto output = F::affine_grid(theta, size, !align_corners); @@ -602,8 +639,7 @@ TEST_F(FunctionalTest, AffineGrid) { } { // 3D affine. - auto theta = torch::arange(1., 13) - .view(std::vector({1, 3, 4})); + auto theta = torch::arange(1., 13).view(std::vector({1, 3, 4})); auto size = std::vector({1, 1, 3, 2, 2}); auto align_corners = true; auto output = F::affine_grid(theta, size, !align_corners); @@ -615,13 +651,13 @@ TEST_F(FunctionalTest, AffineGrid) { {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}}, {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}}); auto output_aligned = F::affine_grid(theta, size, align_corners); - auto expected_aligned = - torch::tensor({{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}}, - {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}}, - {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}}, - {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}}, - {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}}, - {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}}); + auto expected_aligned = torch::tensor( + {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}}, + {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}}, + {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}}, + {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}}, + {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}}, + {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}}); ASSERT_TRUE(output.allclose(expected, 1e-2)); ASSERT_TRUE(output_aligned.allclose(expected_aligned)); @@ -685,11 +721,11 @@ TEST_F(FunctionalTest, AffineGrid) { TEST_F(FunctionalTest, MultiMarginLoss) { auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); auto input = torch::tensor( - {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, - torch::dtype(torch::kFloat).requires_grad(true)); + {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({2, 1, 0}, torch::kLong); auto output = F::multi_margin_loss( - input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight)); + input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight)); auto expected = torch::tensor({0.305556}, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); @@ -707,7 +743,8 @@ TEST_F(FunctionalTest, CosineEmbeddingLoss) { } TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) { - auto input = torch::tensor({{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); auto output = F::multilabel_margin_loss(input, target); auto expected = torch::tensor({0.8500}, torch::kFloat); @@ -719,10 +756,10 @@ TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) { } TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) { - auto input = torch::tensor({{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); - auto output = F::multilabel_margin_loss( - input, target, torch::kNone); + auto output = F::multilabel_margin_loss(input, target, torch::kNone); auto expected = torch::tensor({0.8500}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -736,7 +773,10 @@ TEST_F(FunctionalTest, TripletMarginLoss) { auto positive = torch::tensor({{2., 2.}}, torch::kFloat); auto negative = torch::tensor({{0., 0.}}, torch::kFloat); auto output = F::triplet_margin_loss( - anchor, positive, negative, F::TripletMarginLossFuncOptions().margin(1.0)); + anchor, + positive, + negative, + F::TripletMarginLossFuncOptions().margin(1.0)); auto expected = torch::tensor({0.}, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); @@ -747,30 +787,29 @@ TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) { // TripletMarginLoss options as our distance function, the outputs // are equal (i.e., equal under defaults). - std::vector - reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector reductions = { + torch::kSum, torch::kMean, torch::kNone}; std::vector margins = {0.5, 1.0, 1.5}; std::vector swaps = {true, false}; for (auto& reduction : reductions) { for (auto& margin : margins) { for (const auto& swap : swaps) { - auto anchor = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); - auto positive = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); - auto negative = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto anchor = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); auto basicOptions = F::TripletMarginLossFuncOptions() .reduction(reduction) .margin(margin) .swap(swap); - auto distanceOptions = - F::TripletMarginWithDistanceLossFuncOptions() - .reduction(reduction) - .margin(margin) - .swap(swap); + auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); TripletMarginLoss basicLoss(basicOptions); TripletMarginWithDistanceLoss distanceLoss(distanceOptions); @@ -793,13 +832,16 @@ TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) { } TEST_F(FunctionalTest, NLLLoss) { - auto input = torch::tensor({{-0.1315, -3.1315, -2.5315}, - {-3.7038, -0.1038, -2.6038}, - {-2.3422, -1.3422, -0.4422}}, - torch::kFloat); + auto input = torch::tensor( + {{-0.1315, -3.1315, -2.5315}, + {-3.7038, -0.1038, -2.6038}, + {-2.3422, -1.3422, -0.4422}}, + torch::kFloat); auto target = torch::tensor({1, 0, 2}, torch::kLong); auto output = F::nll_loss( - input, target, F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); + input, + target, + F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); auto expected = torch::tensor(2.4258, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); ASSERT_TRUE(F::nll_loss(input, target).allclose(expected, 1e-04)); @@ -809,7 +851,9 @@ TEST_F(FunctionalTest, CrossEntropy) { auto input = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat); auto target = torch::tensor({0, 1}, torch::kLong); auto output = F::cross_entropy( - input, target, F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); + input, + target, + F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); auto expected = torch::tensor(0.6931, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); @@ -818,20 +862,27 @@ TEST_F(FunctionalTest, CrossEntropy) { // label smoothing with class indices input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat); output = F::cross_entropy( - input, target, F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(torch::kMean)); + input, + target, + F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction( + torch::kMean)); expected = torch::tensor(0.3326, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); // label smoothing with target probabilities target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat); output = F::cross_entropy( - input, target, F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(torch::kMean)); + input, + target, + F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction( + torch::kMean)); expected = torch::tensor(0.5701, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); } TEST_F(FunctionalTest, MaxUnpool1d) { - auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); auto y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3)); @@ -840,19 +891,25 @@ TEST_F(FunctionalTest, MaxUnpool1d) { y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 9})); - x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + x = torch::tensor( + {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); y = F::max_unpool1d( - x, indices, F::MaxUnpool1dFuncOptions(3).output_size(std::vector({1, 1, 9}))); + x, + indices, + F::MaxUnpool1dFuncOptions(3).output_size( + std::vector({1, 1, 9}))); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose( y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 9})); - x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + x = torch::tensor( + {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); - y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); + y = F::max_unpool1d( + x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE( @@ -861,53 +918,49 @@ TEST_F(FunctionalTest, MaxUnpool1d) { } TEST_F(FunctionalTest, MaxUnpool2d) { - auto indices = torch::tensor({ - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}}, torch::kLong); - auto x = torch::tensor({ - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{31, 33, 34}, - {41, 43, 44}, - {46, 48, 49}}}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto y = F::max_unpool2d(x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); + auto indices = torch::tensor( + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}}, + torch::kLong); + auto x = torch::tensor( + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}}, + torch::dtype(torch::kFloat).requires_grad(true)); + auto y = F::max_unpool2d( + x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); ASSERT_EQ(y.dim(), 4); - ASSERT_TRUE(torch::allclose(y, torch::tensor( - {{{{ 0, 0, 0, 0, 0}, - { 0, 6, 0, 8, 9}, - { 0, 0, 0, 0, 0}, - { 0, 16, 0, 18, 19}, - { 0, 21, 0, 23, 24}}}, - {{{ 0, 0, 0, 0, 0}, - { 0, 31, 0, 33, 34}, - { 0, 0, 0, 0, 0}, - { 0, 41, 0, 43, 44}, - { 0, 46, 0, 48, 49}}}} , torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + {{{{0, 0, 0, 0, 0}, + {0, 6, 0, 8, 9}, + {0, 0, 0, 0, 0}, + {0, 16, 0, 18, 19}, + {0, 21, 0, 23, 24}}}, + {{{0, 0, 0, 0, 0}, + {0, 31, 0, 33, 34}, + {0, 0, 0, 0, 0}, + {0, 41, 0, 43, 44}, + {0, 46, 0, 48, 49}}}}, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 1, 5, 5})); } TEST_F(FunctionalTest, MaxUnpool3d) { auto indices = torch::tensor({{{{{26}}}}}, torch::kLong); - auto x = torch::tensor({{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto y = F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3)); ASSERT_EQ(y.dim(), 5); - ASSERT_TRUE(torch::allclose(y, torch::tensor( - {{{{{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 26}}}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}}, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 3, 3, 3})); } @@ -917,13 +970,15 @@ TEST_F(FunctionalTest, ELU) { for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) { auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}); - auto x_bf16 = torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16); + auto x_bf16 = + torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16); x_bf16.resize_({size, size, size}); auto y_exp = torch::max(torch::zeros_like(x), x) + - torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0)); + torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0)); auto y = F::elu(x, F::ELUFuncOptions().alpha(alpha).inplace(inplace)); - auto y_bf16 = F::elu(x_bf16, F::ELUFuncOptions().alpha(alpha).inplace(inplace)); + auto y_bf16 = + F::elu(x_bf16, F::ELUFuncOptions().alpha(alpha).inplace(inplace)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1066,10 +1121,11 @@ TEST_F(FunctionalTest, Hardtanh) { auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}); auto y_exp = (x < min_val) * min_val + - ((x >= min_val) * (x <= max_val)) * x + - (x > max_val) * max_val; - auto y = F::hardtanh(x, F::HardtanhFuncOptions().min_val(min_val) - .max_val(max_val).inplace(inplace)); + ((x >= min_val) * (x <= max_val)) * x + (x > max_val) * max_val; + auto y = F::hardtanh( + x, + F::HardtanhFuncOptions().min_val(min_val).max_val(max_val).inplace( + inplace)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1091,8 +1147,11 @@ TEST_F(FunctionalTest, LeakyReLU) { auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); x.resize_({size, size, size}); auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x; - auto y = F::leaky_relu(x, F::LeakyReLUFuncOptions() - .negative_slope(negative_slope).inplace(inplace)); + auto y = F::leaky_relu( + x, + F::LeakyReLUFuncOptions() + .negative_slope(negative_slope) + .inplace(inplace)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1115,7 +1174,8 @@ TEST_F(FunctionalTest, LogSigmoid) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x)))); + auto y_exp = torch::log( + torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x)))); ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); } @@ -1131,52 +1191,59 @@ TEST_F(FunctionalTest, GumbelSoftmax) { // Shape unchanged ASSERT_TRUE(y_draw.sizes() == logits.sizes()); // One choice per draw - ASSERT_TRUE(torch::allclose(y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); } // Test 2: 1D shape, 0 and -1 dim - for(const auto dim: {0, -1}) { + for (const auto dim : {0, -1}) { auto logits = torch::randn({5}); int expected_count = 1; - auto y_draw = F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim)); + auto y_draw = F::gumbel_softmax( + logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim)); // All values positive ASSERT_GE(y_draw.min().item(), 0); // Shape unchanged ASSERT_TRUE(y_draw.sizes() == logits.sizes()); // One choice per draw - ASSERT_TRUE(torch::allclose(y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); } { // Test 3: 2D shape, 1 dim auto logits = torch::randn({5, 4}); int expected_count = 5; - auto y_draw = F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1)); + auto y_draw = F::gumbel_softmax( + logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1)); // All values positive ASSERT_GE(y_draw.min().item(), 0); // Shape unchanged ASSERT_TRUE(y_draw.sizes() == logits.sizes()); // One choice per draw - ASSERT_TRUE(torch::allclose(y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); } // Test 4: 3D shape, 1 and -1 dim // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) int dims[] = {1, -1}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers) - int expected[] = {5*3, 5*4}; + int expected[] = {5 * 3, 5 * 4}; for (const auto i : c10::irange(2)) { auto logits = torch::randn({5, 4, 3}); int expected_count = expected[i]; - auto y_draw = F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i])); + auto y_draw = F::gumbel_softmax( + logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i])); // All values positive ASSERT_GE(y_draw.min().item(), 0); // Shape unchanged ASSERT_TRUE(y_draw.sizes() == logits.sizes()); // One choice per draw - ASSERT_TRUE(torch::allclose(y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); } { // Test 5: Straight through @@ -1189,9 +1256,10 @@ TEST_F(FunctionalTest, GumbelSoftmax) { auto counts = torch::zeros_like(logits); torch::Tensor y_draw; for (const auto i : c10::irange(num_draws)) { - (void)i; // Suppress unused variable warning - y_draw = F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true)); - counts += y_draw; + (void)i; // Suppress unused variable warning + y_draw = + F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true)); + counts += y_draw; } // All values positive @@ -1250,21 +1318,23 @@ TEST_F(FunctionalTest, PReLU) { const auto w = torch::rand(24) * 200 - 100; const auto y = F::prelu(x, w); ASSERT_EQ(y.sizes(), std::vector({42, 24})); - const auto y_exp = (x < 0) * w * x + (x >= 0) * x; + const auto y_exp = (x < 0) * w * x + (x >= 0) * x; ASSERT_TRUE(torch::allclose(y, y_exp)); } TEST_F(FunctionalTest, LayerNorm) { const auto input = torch::randn({2, 2}); auto y = F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); - auto y_exp = torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5); + auto y_exp = + torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5); ASSERT_TRUE(torch::allclose(y, y_exp)); } TEST_F(FunctionalTest, GroupNorm) { const auto input = torch::randn({2, 2}); auto y = F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5)); - auto y_exp = torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5); + auto y_exp = + torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5); ASSERT_TRUE(torch::allclose(y, y_exp)); } @@ -1274,17 +1344,10 @@ TEST_F(FunctionalTest, LocalResponseNorm) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 2})); const auto y_exp = torch::tensor( - {{{73.7788, 74.1462}, - {60.1942, 60.3302}, - {60.4609, 60.5865}}, - {{75.8729, 76.2011}, - {60.9331, 61.0390}, - {61.1403, 61.2370}}, - {{77.7387, 78.0303}, - {61.5011, 61.5807}, - {61.6563, 61.7279}}}, - torch::kFloat - ); + {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}}, + {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}}, + {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); } @@ -1297,17 +1360,10 @@ TEST_F(FunctionalTest, Linear) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3})); const auto y_exp = torch::tensor( - {{{40601, 41004, 41407}, - {41403, 41814, 42225}, - {42205, 42624, 43043}}, - {{43007, 43434, 43861}, - {43809, 44244, 44679}, - {44611, 45054, 45497}}, - {{45413, 45864, 46315}, - {46215, 46674, 47133}, - {47017, 47484, 47951}}}, - torch::kFloat - ); + {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}}, + {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}}, + {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, y_exp)); } { @@ -1317,23 +1373,16 @@ TEST_F(FunctionalTest, Linear) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3})); const auto y_exp = torch::tensor( - {{{40301, 40703, 41105}, - {41103, 41513, 41923}, - {41905, 42323, 42741}}, - {{42707, 43133, 43559}, - {43509, 43943, 44377}, - {44311, 44753, 45195}}, - {{45113, 45563, 46013}, - {45915, 46373, 46831}, - {46717, 47183, 47649}}}, - torch::kFloat - ); + {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}}, + {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}}, + {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, y_exp)); } } TEST_F(FunctionalTest, Embedding) { - const auto input = torch::tensor({{1,2,4,5}, {4,3,2,9}}, torch::kLong); + const auto input = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong); auto weight = torch::empty({10, 3}); torch::nn::init::normal_(weight); auto y = F::embedding(input, weight); @@ -1342,25 +1391,37 @@ TEST_F(FunctionalTest, Embedding) { } TEST_F(FunctionalTest, EmbeddingBag) { - const auto input = torch::tensor({1,2,4,5,4,3,2,9}, torch::kLong); - auto offsets = torch::tensor({0,4}, torch::kLong); + const auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong); + auto offsets = torch::tensor({0, 4}, torch::kLong); auto weight = torch::empty({10, 3}); torch::nn::init::normal_(weight); - auto y = F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets).padding_idx(4)); - auto y_exp = std::get<0>(torch::embedding_bag(weight, input, offsets, false, 0, false, torch::Tensor(), false, 4)); + auto y = F::embedding_bag( + input, + weight, + F::EmbeddingBagFuncOptions() + .mode(torch::kSum) + .offsets(offsets) + .padding_idx(4)); + auto y_exp = std::get<0>(torch::embedding_bag( + weight, input, offsets, false, 0, false, torch::Tensor(), false, 4)); ASSERT_TRUE(torch::allclose(y, y_exp)); // no options test - const auto input_ = torch::tensor({{1,2,4,5}, {4,3,2,9}}, torch::kLong); - auto offsets_ = torch::arange(0, input_.numel(), input_.size(1), torch::TensorOptions().dtype(torch::kLong).device(input.device())); + const auto input_ = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong); + auto offsets_ = torch::arange( + 0, + input_.numel(), + input_.size(1), + torch::TensorOptions().dtype(torch::kLong).device(input.device())); y = F::embedding_bag(input_, weight); - y_exp = std::get<0>(torch::embedding_bag(weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor())); + y_exp = std::get<0>(torch::embedding_bag( + weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor())); ASSERT_TRUE(torch::allclose(y, y_exp)); } TEST_F(FunctionalTest, Bilinear) { auto input1 = torch::tensor({{1, 2, 3}, {7, 6, 5}}); - auto input2 = torch::tensor({{7, 4}, {8 ,9}}); + auto input2 = torch::tensor({{7, 4}, {8, 9}}); auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}}); auto bias = torch::tensor({1}); @@ -1379,10 +1440,13 @@ TEST_F(FunctionalTest, Bilinear) { TEST_F(FunctionalTest, Normalize) { const auto expected = torch::tensor( - {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000}, - {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}}, torch::requires_grad().dtype(torch::kFloat)); + {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000}, + {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}}, + torch::requires_grad().dtype(torch::kFloat)); { // Test #1 - auto input = torch::tensor({{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); // reduce to scalar to call .backward() @@ -1395,8 +1459,9 @@ TEST_F(FunctionalTest, Normalize) { } { // Test #2 Check variations of optional arguments - auto input = torch::tensor({{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat)); - auto output = torch::randn({1,2,5}, torch::dtype(torch::kFloat)); + auto input = torch::tensor( + {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat)); + auto output = torch::randn({1, 2, 5}, torch::dtype(torch::kFloat)); // non-null output argument F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1).out(output)); // default options @@ -1407,7 +1472,8 @@ TEST_F(FunctionalTest, Normalize) { { // Test #3 Base case of scalar tensor auto input = torch::randn({}, torch::requires_grad()); - torch::Tensor norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); + torch::Tensor norm = + F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); norm.backward(); ASSERT_EQ(input.grad().numel(), 1); @@ -1503,10 +1569,13 @@ TEST_F(FunctionalTest, RReLU) { auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); x.resize_({size, size, size}); auto x_copy = x.clone(); - auto y = F::rrelu(x, F::RReLUFuncOptions().lower(lower) - .upper(upper).inplace(inplace)); - auto z = ((x_copy >= 0) * (x_copy == y) + - (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; + auto y = F::rrelu( + x, + F::RReLUFuncOptions().lower(lower).upper(upper).inplace(inplace)); + auto z = + ((x_copy >= 0) * (x_copy == y) + + (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * + 1.0; ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1531,7 +1600,8 @@ TEST_F(FunctionalTest, RReLUDefaultOptions) { auto x_copy = x.clone(); auto y = F::rrelu(x); auto z = ((x_copy >= 0) * (x_copy == y) + - (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; + (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * + 1.0; ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1547,9 +1617,11 @@ TEST_F(FunctionalTest, CELU) { x.resize_({size, size, size}); auto x_bf16 = x.clone().to(torch::kBFloat16); auto y_exp = torch::max(torch::zeros_like(x), x) + - torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0)); + torch::min(torch::zeros_like(x), + alpha * (torch::exp(x / alpha) - 1.0)); auto y = F::celu(x, F::CELUFuncOptions().alpha(alpha).inplace(inplace)); - auto y_bf16 = F::celu(x_bf16, F::CELUFuncOptions().alpha(alpha).inplace(inplace)); + auto y_bf16 = + F::celu(x_bf16, F::CELUFuncOptions().alpha(alpha).inplace(inplace)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1571,7 +1643,7 @@ TEST_F(FunctionalTest, CELUDefaultOptions) { x.resize_({size, size, size}); auto x_bf16 = x.clone().to(torch::kBFloat16); auto y_exp = torch::max(torch::zeros_like(x), x) + - torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0)); + torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0)); auto y = F::celu(x); auto y_bf16 = F::celu(x_bf16); @@ -1583,15 +1655,14 @@ TEST_F(FunctionalTest, CELUDefaultOptions) { TEST_F(FunctionalTest, PixelShuffle) { auto x = torch::tensor( - {{{{-17, 19}, {-1, 2}}, - {{7, 14}, {-3, 1}}, - {{0, -2}, {-12, 14}}, - {{-15, 0}, {-3, 9}}}}, torch::kFloat); + {{{{-17, 19}, {-1, 2}}, + {{7, 14}, {-3, 1}}, + {{0, -2}, {-12, 14}}, + {{-15, 0}, {-3, 9}}}}, + torch::kFloat); auto y_exp = torch::tensor( - {{{{-17, 7, 19, 14}, - {0, -15, -2, 0}, - {-1, -3, 2, 1}, - {-12, -3, 14, 9}}}}, torch::kFloat); + {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, + torch::kFloat); auto y = F::pixel_shuffle(x, 2); ASSERT_EQ(y.ndimension(), 4); @@ -1623,10 +1694,10 @@ TEST_F(FunctionalTest, Softplus) { auto x = torch::linspace(-3.0, 3.0, 61); x.resize_({size, size, size}); auto y_exp = - (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + - (x > threshold) * x; - auto y = F::softplus(x, - F::SoftplusFuncOptions().beta(beta).threshold(threshold)); + (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + + (x > threshold) * x; + auto y = F::softplus( + x, F::SoftplusFuncOptions().beta(beta).threshold(threshold)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1641,9 +1712,8 @@ TEST_F(FunctionalTest, SoftplusDefaultOptions) { const auto threshold = 20.0; auto x = torch::linspace(-3.0, 3.0, 61); x.resize_({size, size, size}); - auto y_exp = - (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + - (x > threshold) * x; + auto y_exp = (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + + (x > threshold) * x; auto y = F::softplus(x); ASSERT_EQ(y.ndimension(), 3); @@ -1666,7 +1736,8 @@ TEST_F(FunctionalTest, Fold) { TEST_F(FunctionalTest, Unfold) { auto input = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3}); - auto output = F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); + auto output = + F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); auto expected = torch::tensor( {{{0.0, 0.0, 0.0, 4.0}, {0.0, 0.0, 3.0, 5.0}, @@ -1749,8 +1820,8 @@ TEST_F(FunctionalTest, Threshold) { auto x = torch::linspace(-3.0, 3.0, 61); x.resize_({size, size, size}); auto y_exp = (x <= threshold) * value + (x > threshold) * x; - auto y = F::threshold(x, - F::ThresholdFuncOptions(threshold, value).inplace(inplace)); + auto y = F::threshold( + x, F::ThresholdFuncOptions(threshold, value).inplace(inplace)); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); @@ -1761,7 +1832,8 @@ TEST_F(FunctionalTest, Threshold) { } } } - ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5)).defined()); + ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5)) + .defined()); } TEST_F(FunctionalTest, BatchNorm1d) { @@ -1775,8 +1847,15 @@ TEST_F(FunctionalTest, BatchNorm1d) { auto weight = torch::ones({num_features}); auto bias = torch::zeros({num_features}); auto output = F::batch_norm( - input, mean, variance, - F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(momentum).eps(eps).training(false)); + input, + mean, + variance, + F::BatchNormFuncOptions() + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps) + .training(false)); auto expected = (input - mean) / torch::sqrt(variance + eps); ASSERT_TRUE(output.allclose(expected)); } @@ -1801,9 +1880,19 @@ TEST_F(FunctionalTest, BatchNorm2d) { auto weight = torch::ones({num_features}); auto bias = torch::zeros({num_features}); auto output = F::batch_norm( - input, mean, variance, - F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(momentum).eps(eps).training(false)); - auto expected = torch::transpose((torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), 1, 3); + input, + mean, + variance, + F::BatchNormFuncOptions() + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps) + .training(false)); + auto expected = torch::transpose( + (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), + 1, + 3); ASSERT_TRUE(output.allclose(expected)); } @@ -1815,7 +1904,10 @@ TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) { auto mean = torch::randn(num_features); auto variance = torch::rand(num_features); auto output = F::batch_norm(input, mean, variance); - auto expected = torch::transpose((torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), 1, 3); + auto expected = torch::transpose( + (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), + 1, + 3); ASSERT_TRUE(output.allclose(expected)); } @@ -1830,9 +1922,19 @@ TEST_F(FunctionalTest, BatchNorm3d) { auto weight = torch::ones({num_features}); auto bias = torch::zeros({num_features}); auto output = F::batch_norm( - input, mean, variance, - F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(momentum).eps(eps).training(false)); - auto expected = torch::transpose((torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), 1, 4); + input, + mean, + variance, + F::BatchNormFuncOptions() + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps) + .training(false)); + auto expected = torch::transpose( + (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), + 1, + 4); ASSERT_TRUE(output.allclose(expected)); } @@ -1844,7 +1946,10 @@ TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) { auto mean = torch::randn(num_features); auto variance = torch::rand(num_features); auto output = F::batch_norm(input, mean, variance); - auto expected = torch::transpose((torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), 1, 4); + auto expected = torch::transpose( + (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), + 1, + 4); ASSERT_TRUE(output.allclose(expected)); } @@ -1859,40 +1964,42 @@ TEST_F(FunctionalTest, InstanceNorm1d) { auto weight = torch::arange((double)num_features); auto bias = torch::arange((double)num_features); auto output = F::instance_norm( - input, - F::InstanceNormFuncOptions() - .running_mean(mean) - .running_var(variance) - .weight(weight) - .bias(bias) - .momentum(momentum) - .eps(eps)); - auto expected = torch::tensor({{{ 0.0000, 0.0000, 0.0000, 0.0000}, - {-0.3416, 0.5528, 1.4472, 2.3416}, - {-0.6833, 1.1056, 2.8944, 4.6833}, - {-1.0249, 1.6584, 4.3416, 7.0249}, - {-1.3665, 2.2112, 5.7888, 9.3665}}, - {{ 0.0000, 0.0000, 0.0000, 0.0000}, - {-0.3416, 0.5528, 1.4472, 2.3416}, - {-0.6833, 1.1056, 2.8944, 4.6833}, - {-1.0249, 1.6584, 4.3416, 7.0249}, - {-1.3665, 2.2112, 5.7888, 9.3665}}}); + input, + F::InstanceNormFuncOptions() + .running_mean(mean) + .running_var(variance) + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps)); + auto expected = torch::tensor( + {{{0.0000, 0.0000, 0.0000, 0.0000}, + {-0.3416, 0.5528, 1.4472, 2.3416}, + {-0.6833, 1.1056, 2.8944, 4.6833}, + {-1.0249, 1.6584, 4.3416, 7.0249}, + {-1.3665, 2.2112, 5.7888, 9.3665}}, + {{0.0000, 0.0000, 0.0000, 0.0000}, + {-0.3416, 0.5528, 1.4472, 2.3416}, + {-0.6833, 1.1056, 2.8944, 4.6833}, + {-1.0249, 1.6584, 4.3416, 7.0249}, + {-1.3665, 2.2112, 5.7888, 9.3665}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) { auto input = torch::arange(40.).view({2, 5, 4}); auto output = F::instance_norm(input); - auto expected = torch::tensor({{{-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}}, - {{-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}, - {-1.3416, -0.4472, 0.4472, 1.3416}}}); + auto expected = torch::tensor( + {{{-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}}, + {{-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}, + {-1.3416, -0.4472, 0.4472, 1.3416}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } @@ -1901,68 +2008,52 @@ TEST_F(FunctionalTest, InstanceNorm2d) { double eps = 1e-05; double momentum = 0.1; - auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); + auto input = + torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); auto mean = torch::arange((double)num_features); auto variance = torch::arange((double)num_features); auto weight = torch::arange((double)num_features); auto bias = torch::arange((double)num_features); auto output = F::instance_norm( - input, - F::InstanceNormFuncOptions() - .running_mean(mean) - .running_var(variance) - .weight(weight) - .bias(bias) - .momentum(momentum) - .eps(eps)); - auto expected = torch::tensor({{{{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}, - {{-0.3416, 0.5528}, - { 1.4472, 2.3416}}, - {{-0.6833, 1.1056}, - { 2.8944, 4.6833}}, - {{-1.0249, 1.6584}, - { 4.3416, 7.0249}}, - {{-1.3665, 2.2112}, - { 5.7888, 9.3665}}}, - {{{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}, - {{-0.3416, 0.5528}, - { 1.4472, 2.3416}}, - {{-0.6833, 1.1056}, - { 2.8944, 4.6833}}, - {{-1.0249, 1.6584}, - { 4.3416, 7.0249}}, - {{-1.3665, 2.2112}, - { 5.7888, 9.3665}}}}); + input, + F::InstanceNormFuncOptions() + .running_mean(mean) + .running_var(variance) + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps)); + auto expected = torch::tensor( + {{{{0.0000, 0.0000}, {0.0000, 0.0000}}, + {{-0.3416, 0.5528}, {1.4472, 2.3416}}, + {{-0.6833, 1.1056}, {2.8944, 4.6833}}, + {{-1.0249, 1.6584}, {4.3416, 7.0249}}, + {{-1.3665, 2.2112}, {5.7888, 9.3665}}}, + {{{0.0000, 0.0000}, {0.0000, 0.0000}}, + {{-0.3416, 0.5528}, {1.4472, 2.3416}}, + {{-0.6833, 1.1056}, {2.8944, 4.6833}}, + {{-1.0249, 1.6584}, {4.3416, 7.0249}}, + {{-1.3665, 2.2112}, {5.7888, 9.3665}}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) { int num_features = 5; - auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); + auto input = + torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); auto output = F::instance_norm(input); - auto expected = torch::tensor({{{{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}}, - {{{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}}}); + auto expected = torch::tensor( + {{{{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}}, + {{{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } @@ -1971,108 +2062,72 @@ TEST_F(FunctionalTest, InstanceNorm3d) { double eps = 1e-05; double momentum = 0.1; - auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2}); + auto input = torch::arange(2. * num_features * 2 * 2 * 2) + .view({2, num_features, 2, 2, 2}); auto mean = torch::arange((double)num_features); auto variance = torch::arange((double)num_features); auto weight = torch::arange((double)num_features); auto bias = torch::arange((double)num_features); auto output = F::instance_norm( - input, - F::InstanceNormFuncOptions() - .running_mean(mean) - .running_var(variance) - .weight(weight) - .bias(bias) - .momentum(momentum) - .eps(eps)); - auto expected = torch::tensor({{{{{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}, - {{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}}, - {{{-0.5275, -0.0911}, - { 0.3453, 0.7818}}, - {{ 1.2182, 1.6547}, - { 2.0911, 2.5275}}}, - {{{-1.0550, -0.1822}, - { 0.6907, 1.5636}}, - {{ 2.4364, 3.3093}, - { 4.1822, 5.0550}}}, - {{{-1.5826, -0.2733}, - { 1.0360, 2.3453}}, - {{ 3.6547, 4.9640}, - { 6.2733, 7.5826}}}, - {{{-2.1101, -0.3644}, - { 1.3814, 3.1271}}, - {{ 4.8729, 6.6186}, - { 8.3644, 10.1101}}}}, - {{{{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}, - {{ 0.0000, 0.0000}, - { 0.0000, 0.0000}}}, - {{{-0.5275, -0.0911}, - { 0.3453, 0.7818}}, - {{ 1.2182, 1.6547}, - { 2.0911, 2.5275}}}, - {{{-1.0550, -0.1822}, - { 0.6907, 1.5636}}, - {{ 2.4364, 3.3093}, - { 4.1822, 5.0550}}}, - {{{-1.5826, -0.2733}, - { 1.0360, 2.3453}}, - {{ 3.6547, 4.9640}, - { 6.2733, 7.5826}}}, - {{{-2.1101, -0.3644}, - { 1.3814, 3.1271}}, - {{ 4.8729, 6.6186}, - { 8.3644, 10.1101}}}}}); + input, + F::InstanceNormFuncOptions() + .running_mean(mean) + .running_var(variance) + .weight(weight) + .bias(bias) + .momentum(momentum) + .eps(eps)); + auto expected = torch::tensor( + {{{{{0.0000, 0.0000}, {0.0000, 0.0000}}, + {{0.0000, 0.0000}, {0.0000, 0.0000}}}, + {{{-0.5275, -0.0911}, {0.3453, 0.7818}}, + {{1.2182, 1.6547}, {2.0911, 2.5275}}}, + {{{-1.0550, -0.1822}, {0.6907, 1.5636}}, + {{2.4364, 3.3093}, {4.1822, 5.0550}}}, + {{{-1.5826, -0.2733}, {1.0360, 2.3453}}, + {{3.6547, 4.9640}, {6.2733, 7.5826}}}, + {{{-2.1101, -0.3644}, {1.3814, 3.1271}}, + {{4.8729, 6.6186}, {8.3644, 10.1101}}}}, + {{{{0.0000, 0.0000}, {0.0000, 0.0000}}, + {{0.0000, 0.0000}, {0.0000, 0.0000}}}, + {{{-0.5275, -0.0911}, {0.3453, 0.7818}}, + {{1.2182, 1.6547}, {2.0911, 2.5275}}}, + {{{-1.0550, -0.1822}, {0.6907, 1.5636}}, + {{2.4364, 3.3093}, {4.1822, 5.0550}}}, + {{{-1.5826, -0.2733}, {1.0360, 2.3453}}, + {{3.6547, 4.9640}, {6.2733, 7.5826}}}, + {{{-2.1101, -0.3644}, {1.3814, 3.1271}}, + {{4.8729, 6.6186}, {8.3644, 10.1101}}}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) { int num_features = 5; - auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2}); + auto input = torch::arange(2. * num_features * 2 * 2 * 2) + .view({2, num_features, 2, 2, 2}); auto output = F::instance_norm(input); - auto expected = torch::tensor({{{{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}}, - {{{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}}}); + auto expected = torch::tensor( + {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}}, + {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}}}); ASSERT_TRUE(output.allclose(expected, 2e-04)); } @@ -2094,10 +2149,11 @@ TEST_F(FunctionalTest, Interpolate) { // test float scale factor up & down sampling for (const auto scale_factor : {0.5, 1.5, 2.0}) { auto input = torch::ones({1, 1, 2, 2}); - auto options = F::InterpolateFuncOptions() - .scale_factor(std::vector({scale_factor, scale_factor})) - .mode(torch::kBilinear) - .align_corners(align_corners); + auto options = + F::InterpolateFuncOptions() + .scale_factor(std::vector({scale_factor, scale_factor})) + .mode(torch::kBilinear) + .align_corners(align_corners); auto output = F::interpolate(input, options); auto expected_size = static_cast(std::floor(input.size(-1) * scale_factor)); @@ -2112,11 +2168,11 @@ TEST_F(FunctionalTest, Interpolate) { for (const auto align_corners : {true, false}) { for (const auto scale_factor : {0.5, 1.5, 2.0}) { auto input = torch::ones({1, 1, 2, 2, 2}); - auto options = - F::InterpolateFuncOptions() - .scale_factor(std::vector({scale_factor, scale_factor, scale_factor})) - .mode(torch::kTrilinear) - .align_corners(align_corners); + auto options = F::InterpolateFuncOptions() + .scale_factor(std::vector( + {scale_factor, scale_factor, scale_factor})) + .mode(torch::kTrilinear) + .align_corners(align_corners); auto output = F::interpolate(input, options); auto expected_size = static_cast(std::floor(input.size(-1) * scale_factor)); @@ -2130,13 +2186,16 @@ TEST_F(FunctionalTest, Interpolate) { { auto input = torch::randn({3, 2, 2}); ASSERT_THROWS_WITH( - F::interpolate(input[0], F::InterpolateFuncOptions().size(std::vector({4, 4}))), + F::interpolate( + input[0], + F::InterpolateFuncOptions().size(std::vector({4, 4}))), "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) " "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)"); ASSERT_THROWS_WITH( F::interpolate( torch::reshape(input, {1, 1, 1, 3, 2, 2}), - F::InterpolateFuncOptions().size(std::vector({1, 1, 1, 3, 4, 4}))), + F::InterpolateFuncOptions().size( + std::vector({1, 1, 1, 3, 4, 4}))), "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) " "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)"); ASSERT_THROWS_WITH( @@ -2145,10 +2204,15 @@ TEST_F(FunctionalTest, Interpolate) { ASSERT_THROWS_WITH( F::interpolate( input, - F::InterpolateFuncOptions().size(std::vector({3, 4, 4})).scale_factor(std::vector({0.5}))), + F::InterpolateFuncOptions() + .size(std::vector({3, 4, 4})) + .scale_factor(std::vector({0.5}))), "only one of size or scale_factor should be defined"); ASSERT_THROWS_WITH( - F::interpolate(input, F::InterpolateFuncOptions().scale_factor(std::vector({3, 2}))), + F::interpolate( + input, + F::InterpolateFuncOptions().scale_factor( + std::vector({3, 2}))), "scale_factor shape must match input shape. " "Input is 1D, scale_factor size is [3, 2]"); ASSERT_THROWS_WITH( @@ -2163,12 +2227,13 @@ TEST_F(FunctionalTest, Interpolate) { { auto tensor = torch::rand({2, 3, 32, 32}); std::vector osize = {8, 10}; - auto expected = at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt); + auto expected = + at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt); auto options = F::InterpolateFuncOptions() - .size(osize) - .mode(torch::kNearestExact) - .align_corners(false); + .size(osize) + .mode(torch::kNearestExact) + .align_corners(false); auto output = F::interpolate(tensor, options); ASSERT_TRUE(output.allclose(expected)); @@ -2176,25 +2241,26 @@ TEST_F(FunctionalTest, Interpolate) { { auto tensor = torch::rand({2, 3, 32, 32}); std::vector osize = {8, 10}; - auto expected = at::native::_upsample_bilinear2d_aa(tensor, osize, false, torch::nullopt); + auto expected = at::native::_upsample_bilinear2d_aa( + tensor, osize, false, torch::nullopt); auto options = F::InterpolateFuncOptions() - .size(osize) - .mode(torch::kBilinear) - .align_corners(false) - .antialias(true); + .size(osize) + .mode(torch::kBilinear) + .align_corners(false) + .antialias(true); auto output = F::interpolate(tensor, options); ASSERT_TRUE(output.allclose(expected)); - } } TEST_F(FunctionalTest, Pad1) { { auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3}); - auto output = F::pad(input, F::PadFuncOptions({1, 2}).mode(torch::kCircular)); - auto expected = torch::tensor({{{2., 0., 1., 2., 0., 1.}, - {5., 3., 4., 5., 3., 4.}}}, torch::kDouble); + auto output = + F::pad(input, F::PadFuncOptions({1, 2}).mode(torch::kCircular)); + auto expected = torch::tensor( + {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({1, 2, 6})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2202,15 +2268,17 @@ TEST_F(FunctionalTest, Pad1) { TEST_F(FunctionalTest, Pad2) { { auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3}); - auto output = F::pad(input, F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)); + auto output = + F::pad(input, F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)); auto expected = torch::tensor( - {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.}, - {3., 4., 5., 3., 4., 5., 3., 4., 5.}, - {6., 7., 8., 6., 7., 8., 6., 7., 8.}, - {0., 1., 2., 0., 1., 2., 0., 1., 2.}, - {3., 4., 5., 3., 4., 5., 3., 4., 5.}, - {6., 7., 8., 6., 7., 8., 6., 7., 8.}, - {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}}, torch::kDouble); + {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}}, + torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({1, 1, 7, 9})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2218,43 +2286,45 @@ TEST_F(FunctionalTest, Pad2) { TEST_F(FunctionalTest, Pad3) { { auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3}); - auto output = F::pad(input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular)); + auto output = F::pad( + input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular)); auto expected = torch::tensor( - {{{{{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}}, - - {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}}, - - {{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}}, - - {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}}, - - {{ 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}, - { 3., 4., 5., 3., 4., 5., 3., 4., 5.}, - { 0., 1., 2., 0., 1., 2., 0., 1., 2.}}, - - {{ 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}, - { 9., 10., 11., 9., 10., 11., 9., 10., 11.}, - { 6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}}, torch::kDouble); + {{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}}, + + {{6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}}, + + {{0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}}, + + {{6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}}, + + {{0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}, + {3., 4., 5., 3., 4., 5., 3., 4., 5.}, + {0., 1., 2., 0., 1., 2., 0., 1., 2.}}, + + {{6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}, + {9., 10., 11., 9., 10., 11., 9., 10., 11.}, + {6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}}, + torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({1, 1, 6, 5, 9})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2262,27 +2332,29 @@ TEST_F(FunctionalTest, Pad3) { TEST_F(FunctionalTest, Pad4) { { auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2}); - auto output = F::pad(input, F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect)); + auto output = + F::pad(input, F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect)); auto expected = torch::tensor( - {{{{ 3., 2., 3., 2.}, - { 1., 0., 1., 0.}, - { 3., 2., 3., 2.}, - { 1., 0., 1., 0.}}, - - {{ 7., 6., 7., 6.}, - { 5., 4., 5., 4.}, - { 7., 6., 7., 6.}, - { 5., 4., 5., 4.}}}, - - {{{11., 10., 11., 10.}, - { 9., 8., 9., 8.}, - {11., 10., 11., 10.}, - { 9., 8., 9., 8.}}, - - {{15., 14., 15., 14.}, - {13., 12., 13., 12.}, - {15., 14., 15., 14.}, - {13., 12., 13., 12.}}}}, torch::kDouble); + {{{{3., 2., 3., 2.}, + {1., 0., 1., 0.}, + {3., 2., 3., 2.}, + {1., 0., 1., 0.}}, + + {{7., 6., 7., 6.}, + {5., 4., 5., 4.}, + {7., 6., 7., 6.}, + {5., 4., 5., 4.}}}, + + {{{11., 10., 11., 10.}, + {9., 8., 9., 8.}, + {11., 10., 11., 10.}, + {9., 8., 9., 8.}}, + + {{15., 14., 15., 14.}, + {13., 12., 13., 12.}, + {15., 14., 15., 14.}, + {13., 12., 13., 12.}}}}, + torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({2, 2, 4, 4})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2290,37 +2362,39 @@ TEST_F(FunctionalTest, Pad4) { TEST_F(FunctionalTest, Pad5) { { auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3}); - auto output = F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate)); + auto output = F::pad( + input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate)); auto expected = torch::tensor( - {{{{{ 0., 0., 1., 2., 2., 2.}, - { 0., 0., 1., 2., 2., 2.}, - { 0., 0., 1., 2., 2., 2.}, - { 3., 3., 4., 5., 5., 5.}, - { 3., 3., 4., 5., 5., 5.}}, - - {{ 0., 0., 1., 2., 2., 2.}, - { 0., 0., 1., 2., 2., 2.}, - { 0., 0., 1., 2., 2., 2.}, - { 3., 3., 4., 5., 5., 5.}, - { 3., 3., 4., 5., 5., 5.}}, - - {{ 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 9., 9., 10., 11., 11., 11.}, - { 9., 9., 10., 11., 11., 11.}}, - - {{ 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 9., 9., 10., 11., 11., 11.}, - { 9., 9., 10., 11., 11., 11.}}, - - {{ 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 6., 6., 7., 8., 8., 8.}, - { 9., 9., 10., 11., 11., 11.}, - { 9., 9., 10., 11., 11., 11.}}}}}, torch::kDouble); + {{{{{0., 0., 1., 2., 2., 2.}, + {0., 0., 1., 2., 2., 2.}, + {0., 0., 1., 2., 2., 2.}, + {3., 3., 4., 5., 5., 5.}, + {3., 3., 4., 5., 5., 5.}}, + + {{0., 0., 1., 2., 2., 2.}, + {0., 0., 1., 2., 2., 2.}, + {0., 0., 1., 2., 2., 2.}, + {3., 3., 4., 5., 5., 5.}, + {3., 3., 4., 5., 5., 5.}}, + + {{6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {9., 9., 10., 11., 11., 11.}, + {9., 9., 10., 11., 11., 11.}}, + + {{6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {9., 9., 10., 11., 11., 11.}, + {9., 9., 10., 11., 11., 11.}}, + + {{6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {6., 6., 7., 8., 8., 8.}, + {9., 9., 10., 11., 11., 11.}, + {9., 9., 10., 11., 11., 11.}}}}}, + torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({1, 1, 5, 5, 6})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2328,31 +2402,31 @@ TEST_F(FunctionalTest, Pad5) { TEST_F(FunctionalTest, Pad6) { { auto input = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3}); - auto output = F::pad(input, F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect)); + auto output = F::pad( + input, F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect)); auto expected = torch::tensor( - {{{{{ 9., 10., 11., 10., 9.}, - { 6., 7., 8., 7., 6.}, - { 9., 10., 11., 10., 9.}}, + {{{{{9., 10., 11., 10., 9.}, + {6., 7., 8., 7., 6.}, + {9., 10., 11., 10., 9.}}, - {{ 3., 4., 5., 4., 3.}, - { 0., 1., 2., 1., 0.}, - { 3., 4., 5., 4., 3.}}, + {{3., 4., 5., 4., 3.}, {0., 1., 2., 1., 0.}, {3., 4., 5., 4., 3.}}, - {{ 9., 10., 11., 10., 9.}, - { 6., 7., 8., 7., 6.}, - { 9., 10., 11., 10., 9.}}, + {{9., 10., 11., 10., 9.}, + {6., 7., 8., 7., 6.}, + {9., 10., 11., 10., 9.}}, - {{ 15., 16., 17., 16., 15.}, - { 12., 13., 14., 13., 12.}, - { 15., 16., 17., 16., 15.}}, + {{15., 16., 17., 16., 15.}, + {12., 13., 14., 13., 12.}, + {15., 16., 17., 16., 15.}}, - {{ 9., 10., 11., 10., 9.}, - { 6., 7., 8., 7., 6.}, - { 9., 10., 11., 10., 9.}}, + {{9., 10., 11., 10., 9.}, + {6., 7., 8., 7., 6.}, + {9., 10., 11., 10., 9.}}, - {{ 3., 4., 5., 4., 3.}, - { 0., 1., 2., 1., 0.}, - { 3., 4., 5., 4., 3.}}}}}, torch::kDouble); + {{3., 4., 5., 4., 3.}, + {0., 1., 2., 1., 0.}, + {3., 4., 5., 4., 3.}}}}}, + torch::kDouble); ASSERT_EQ(output.sizes(), std::vector({1, 1, 6, 3, 5})); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -2360,7 +2434,8 @@ TEST_F(FunctionalTest, Pad6) { TEST_F(FunctionalTest, Pad7) { { auto input = torch::ones({1, 1, 1, 1}, torch::kDouble); - auto output = F::pad(input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0)); + auto output = F::pad( + input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0)); ASSERT_EQ(output.sizes(), std::vector({1, 1, 1, 3})); auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble); } @@ -2379,30 +2454,28 @@ TEST_F(FunctionalTest, CTCLoss) { const auto target_lengths = torch::tensor({30, 25, 20}); const auto input_lengths = torch::tensor({50, 50, 50}); const auto targets = - torch::randint(1, 15, {target_lengths.sum().item()}, torch::kInt); - const auto log_probs = torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2); + torch::randint(1, 15, {target_lengths.sum().item()}, torch::kInt); + const auto log_probs = + torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2); const auto _input_lengths = input_lengths.to(torch::kFloat); ASSERT_THROWS_WITH( - F::ctc_loss( - log_probs, targets, _input_lengths, target_lengths), + F::ctc_loss(log_probs, targets, _input_lengths, target_lengths), "input_lengths must be integral"); const auto target_lengths_ = target_lengths.to(torch::kFloat); ASSERT_THROWS_WITH( - F::ctc_loss( - log_probs, targets, input_lengths, target_lengths_), + F::ctc_loss(log_probs, targets, input_lengths, target_lengths_), "target_lengths must be integral"); } { // test CTCLoss length checks const auto target_lengths = torch::tensor({30, 25, 20}); const auto input_lengths = torch::tensor({50, 50, 50}); const auto targets = torch::randint(1, 15, {3, 29}, torch::kInt); - const auto log_probs = torch::randn({50, 3, 15}, torch::kFloat) - .log_softmax(2); + const auto log_probs = + torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2); ASSERT_THROWS_WITH( - F::ctc_loss( - log_probs, targets, input_lengths, target_lengths), + F::ctc_loss(log_probs, targets, input_lengths, target_lengths), "Expected tensor to have size at least 30 at dimension 1"); } { // test CTCLoss empty target @@ -2410,32 +2483,38 @@ TEST_F(FunctionalTest, CTCLoss) { const auto target_lengths = torch::tensor({0, 0, 0}); const auto input_lengths = torch::tensor({50, 50, 50}); const auto targets = - torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong); + torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong); const auto log_probs = - torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); + torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); const auto loss = F::ctc_loss( - log_probs, targets, input_lengths, target_lengths, - F::CTCLossFuncOptions().reduction(torch::kNone)); + log_probs, + targets, + input_lengths, + target_lengths, + F::CTCLossFuncOptions().reduction(torch::kNone)); ASSERT_TRUE(loss.ge(0).all().item()); ASSERT_TRUE(torch::allclose( - -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss)); + -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss)); } { const auto target_lengths = torch::tensor({0, 9, 0}); const auto input_lengths = torch::tensor({50, 50, 50}); const auto targets = torch::randint(1, 15, {9}, torch::kLong); const auto log_probs = - torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); + torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); const auto loss = F::ctc_loss( - log_probs, targets, input_lengths, target_lengths, - F::CTCLossFuncOptions().reduction(torch::kNone)); + log_probs, + targets, + input_lengths, + target_lengths, + F::CTCLossFuncOptions().reduction(torch::kNone)); ASSERT_TRUE(loss.ge(0).all().item()); ASSERT_TRUE(torch::allclose( -log_probs.sum(0) - .index_select(0, torch::tensor({0, 2}, torch::kLong)) - .slice(1, 0, 1).view({2}), - loss.index_select(0, torch::tensor({0, 2}, torch::kLong)) - )); + .index_select(0, torch::tensor({0, 2}, torch::kLong)) + .slice(1, 0, 1) + .view({2}), + loss.index_select(0, torch::tensor({0, 2}, torch::kLong)))); } } } @@ -2444,17 +2523,26 @@ TEST_F(FunctionalTest, PoissonNLLLoss) { const auto input = torch::tensor({0.5, 1.5, 2.5}); const auto target = torch::tensor({1., 2., 3.}); const auto component_wise_loss = torch::exp(input) - target * input; - ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss), - F::poisson_nll_loss(input, target))); - ASSERT_TRUE(torch::allclose(component_wise_loss, - F::poisson_nll_loss(input, target, - F::PoissonNLLLossFuncOptions().reduction(torch::kNone)))); - ASSERT_TRUE(torch::allclose(torch::sum(component_wise_loss), - F::poisson_nll_loss(input, target, - F::PoissonNLLLossFuncOptions().reduction(torch::kSum)))); - ASSERT_TRUE(torch::allclose(torch::mean(component_wise_loss), - F::poisson_nll_loss(input, target, - F::PoissonNLLLossFuncOptions().reduction(torch::kMean)))); + ASSERT_TRUE(torch::allclose( + torch::mean(component_wise_loss), F::poisson_nll_loss(input, target))); + ASSERT_TRUE(torch::allclose( + component_wise_loss, + F::poisson_nll_loss( + input, + target, + F::PoissonNLLLossFuncOptions().reduction(torch::kNone)))); + ASSERT_TRUE(torch::allclose( + torch::sum(component_wise_loss), + F::poisson_nll_loss( + input, + target, + F::PoissonNLLLossFuncOptions().reduction(torch::kSum)))); + ASSERT_TRUE(torch::allclose( + torch::mean(component_wise_loss), + F::poisson_nll_loss( + input, + target, + F::PoissonNLLLossFuncOptions().reduction(torch::kMean)))); } TEST_F(FunctionalTest, MarginRankingLoss) { @@ -2463,9 +2551,8 @@ TEST_F(FunctionalTest, MarginRankingLoss) { const auto input2 = torch::randn(15) * 10; const auto target = torch::randn(15).sign(); ASSERT_TRUE(torch::allclose( - F::margin_ranking_loss(input1, input2, target), - (-target * (input1 - input2)).clamp(0).mean() - )); + F::margin_ranking_loss(input1, input2, target), + (-target * (input1 - input2)).clamp(0).mean())); } { const auto input1 = torch::randn(15) * 10; @@ -2473,11 +2560,13 @@ TEST_F(FunctionalTest, MarginRankingLoss) { const auto target = torch::randn(15).sign(); const auto margin = 0.5; ASSERT_TRUE(torch::allclose( - F::margin_ranking_loss(input1, input2, target, - F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum) - ), - (-target * (input1 - input2) + margin).clamp(0).sum() - )); + F::margin_ranking_loss( + input1, + input2, + target, + F::MarginRankingLossFuncOptions().margin(0.5).reduction( + torch::kSum)), + (-target * (input1 - input2) + margin).clamp(0).sum())); } { const auto input1 = torch::randn(15) * 10; @@ -2485,24 +2574,28 @@ TEST_F(FunctionalTest, MarginRankingLoss) { const auto target = torch::randn(15).sign(); const auto margin = 0.5; ASSERT_TRUE(torch::allclose( - F::margin_ranking_loss(input1, input2, target, - F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kMean) - ), - (-target * (input1 - input2) + margin).clamp(0).mean() - )); + F::margin_ranking_loss( + input1, + input2, + target, + F::MarginRankingLossFuncOptions().margin(0.5).reduction( + torch::kMean)), + (-target * (input1 - input2) + margin).clamp(0).mean())); } } TEST_F(FunctionalTest, ConvTranspose1d) { auto x = torch::arange(20.).view({2, 2, 5}); auto weight = torch::arange(18.).view({2, 3, 3}); - auto y = F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); - auto expected = torch::tensor({{{ 45., 104., 179., 212., 245., 188., 107.}, - { 60., 140., 242., 293., 344., 260., 146.}, - { 75., 176., 305., 374., 443., 332., 185.}}, - {{ 135., 304., 509., 542., 575., 428., 237.}, - { 210., 460., 752., 803., 854., 620., 336.}, - { 285., 616., 995., 1064., 1133., 812., 435.}}}); + auto y = + F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); + auto expected = torch::tensor( + {{{45., 104., 179., 212., 245., 188., 107.}, + {60., 140., 242., 293., 344., 260., 146.}, + {75., 176., 305., 374., 443., 332., 185.}}, + {{135., 304., 509., 542., 575., 428., 237.}, + {210., 460., 752., 803., 854., 620., 336.}, + {285., 616., 995., 1064., 1133., 812., 435.}}}); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv_transpose1d(x, weight); @@ -2512,28 +2605,30 @@ TEST_F(FunctionalTest, ConvTranspose1d) { TEST_F(FunctionalTest, ConvTranspose2dEven) { auto x = torch::arange(50.).view({1, 2, 5, 5}); auto weight = torch::arange(54.).view({2, 3, 3, 3}); - auto y = F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{ 675., 1402., 2183., 2270., 2357., 1634., 849.}, - { 1560., 3240., 5044., 5236., 5428., 3760., 1952.}, - { 2685., 5574., 8673., 8988., 9303., 6438., 3339.}, - { 3180., 6594., 10248., 10563., 10878., 7518., 3894.}, - { 3675., 7614., 11823., 12138., 12453., 8598., 4449.}, - { 2820., 5832., 9040., 9268., 9496., 6544., 3380.}, - { 1605., 3314., 5129., 5252., 5375., 3698., 1907.}}, - {{ 900., 1870., 2912., 3053., 3194., 2210., 1146.}, - { 2100., 4356., 6772., 7072., 7372., 5092., 2636.}, - { 3630., 7518., 11670., 12147., 12624., 8706., 4500.}, - { 4395., 9078., 14055., 14532., 15009., 10326., 5325.}, - { 5160., 10638., 16440., 16917., 17394., 11946., 6150.}, - { 3900., 8028., 12388., 12724., 13060., 8956., 4604.}, - { 2190., 4502., 6938., 7115., 7292., 4994., 2564.}}, - {{ 1125., 2338., 3641., 3836., 4031., 2786., 1443.}, - { 2640., 5472., 8500., 8908., 9316., 6424., 3320.}, - { 4575., 9462., 14667., 15306., 15945., 10974., 5661.}, - { 5610., 11562., 17862., 18501., 19140., 13134., 6756.}, - { 6645., 13662., 21057., 21696., 22335., 15294., 7851.}, - { 4980., 10224., 15736., 16180., 16624., 11368., 5828.}, - { 2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}}); + auto y = + F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); + auto expected = torch::tensor( + {{{{675., 1402., 2183., 2270., 2357., 1634., 849.}, + {1560., 3240., 5044., 5236., 5428., 3760., 1952.}, + {2685., 5574., 8673., 8988., 9303., 6438., 3339.}, + {3180., 6594., 10248., 10563., 10878., 7518., 3894.}, + {3675., 7614., 11823., 12138., 12453., 8598., 4449.}, + {2820., 5832., 9040., 9268., 9496., 6544., 3380.}, + {1605., 3314., 5129., 5252., 5375., 3698., 1907.}}, + {{900., 1870., 2912., 3053., 3194., 2210., 1146.}, + {2100., 4356., 6772., 7072., 7372., 5092., 2636.}, + {3630., 7518., 11670., 12147., 12624., 8706., 4500.}, + {4395., 9078., 14055., 14532., 15009., 10326., 5325.}, + {5160., 10638., 16440., 16917., 17394., 11946., 6150.}, + {3900., 8028., 12388., 12724., 13060., 8956., 4604.}, + {2190., 4502., 6938., 7115., 7292., 4994., 2564.}}, + {{1125., 2338., 3641., 3836., 4031., 2786., 1443.}, + {2640., 5472., 8500., 8908., 9316., 6424., 3320.}, + {4575., 9462., 14667., 15306., 15945., 10974., 5661.}, + {5610., 11562., 17862., 18501., 19140., 13134., 6756.}, + {6645., 13662., 21057., 21696., 22335., 15294., 7851.}, + {4980., 10224., 15736., 16180., 16624., 11368., 5828.}, + {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}}); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv_transpose2d(x, weight); @@ -2543,28 +2638,30 @@ TEST_F(FunctionalTest, ConvTranspose2dEven) { TEST_F(FunctionalTest, ConvTranspose2dUneven) { auto x = torch::arange(40.).view({1, 2, 5, 4}); auto weight = torch::arange(36.).view({2, 3, 3, 2}); - auto y = F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{ 360., 758., 796., 834., 440.}, - { 832., 1752., 1836., 1920., 1012.}, - {1432., 3014., 3152., 3290., 1732.}, - {1696., 3566., 3704., 3842., 2020.}, - {1960., 4118., 4256., 4394., 2308.}, - {1504., 3152., 3252., 3352., 1756.}, - { 856., 1790., 1844., 1898., 992.}}, - {{ 480., 1010., 1072., 1134., 596.}, - {1120., 2352., 2484., 2616., 1372.}, - {1936., 4058., 4268., 4478., 2344.}, - {2344., 4898., 5108., 5318., 2776.}, - {2752., 5738., 5948., 6158., 3208.}, - {2080., 4328., 4476., 4624., 2404.}, - {1168., 2426., 2504., 2582., 1340.}}, - {{ 600., 1262., 1348., 1434., 752.}, - {1408., 2952., 3132., 3312., 1732.}, - {2440., 5102., 5384., 5666., 2956.}, - {2992., 6230., 6512., 6794., 3532.}, - {3544., 7358., 7640., 7922., 4108.}, - {2656., 5504., 5700., 5896., 3052.}, - {1480., 3062., 3164., 3266., 1688.}}}}); + auto y = + F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); + auto expected = torch::tensor( + {{{{360., 758., 796., 834., 440.}, + {832., 1752., 1836., 1920., 1012.}, + {1432., 3014., 3152., 3290., 1732.}, + {1696., 3566., 3704., 3842., 2020.}, + {1960., 4118., 4256., 4394., 2308.}, + {1504., 3152., 3252., 3352., 1756.}, + {856., 1790., 1844., 1898., 992.}}, + {{480., 1010., 1072., 1134., 596.}, + {1120., 2352., 2484., 2616., 1372.}, + {1936., 4058., 4268., 4478., 2344.}, + {2344., 4898., 5108., 5318., 2776.}, + {2752., 5738., 5948., 6158., 3208.}, + {2080., 4328., 4476., 4624., 2404.}, + {1168., 2426., 2504., 2582., 1340.}}, + {{600., 1262., 1348., 1434., 752.}, + {1408., 2952., 3132., 3312., 1732.}, + {2440., 5102., 5384., 5666., 2956.}, + {2992., 6230., 6512., 6794., 3532.}, + {3544., 7358., 7640., 7922., 4108.}, + {2656., 5504., 5700., 5896., 3052.}, + {1480., 3062., 3164., 3266., 1688.}}}}); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv_transpose2d(x, weight); @@ -2574,25 +2671,15 @@ TEST_F(FunctionalTest, ConvTranspose2dUneven) { TEST_F(FunctionalTest, ConvTranspose3d) { auto x = torch::arange(16.).view({1, 2, 2, 2, 2}); auto weight = torch::arange(32.).view({2, 2, 2, 2, 2}); - auto y = F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); - auto expected = torch::tensor({{{{{ 128., 280., 154.}, - { 304., 664., 364.}, - { 184., 400., 218.}}, - {{ 352., 768., 420.}, - { 832., 1808., 984.}, - { 496., 1072., 580.}}, - {{ 256., 552., 298.}, - { 592., 1272., 684.}, - { 344., 736., 394.}}}, - {{{ 192., 424., 234.}, - { 464., 1016., 556.}, - { 280., 608., 330.}}, - {{ 544., 1184., 644.}, - {1280., 2768., 1496.}, - { 752., 1616., 868.}}, - {{ 384., 824., 442.}, - { 880., 1880., 1004.}, - { 504., 1072., 570.}}}}}); + auto y = + F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); + auto expected = torch::tensor( + {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}}, + {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}}, + {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}}, + {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}}, + {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}}, + {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}}); ASSERT_TRUE(torch::allclose(y, expected)); auto y_no_options = F::conv_transpose3d(x, weight); @@ -2607,7 +2694,10 @@ TEST_F(FunctionalTest, AlphaDropout) { for (const auto rate : {0.2, 0.5, 0.8}) { for (const auto inplace : {false, true}) { auto input_ = input.clone(); - auto output = F::alpha_dropout(input_, F::AlphaDropoutFuncOptions().p(rate).training(false).inplace(inplace)); + auto output = F::alpha_dropout( + input_, + F::AlphaDropoutFuncOptions().p(rate).training(false).inplace( + inplace)); ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); if (inplace) { @@ -2628,7 +2718,10 @@ TEST_F(FunctionalTest, FeatureAlphaDropout) { for (const auto rate : {0.2, 0.5, 0.8}) { for (const auto inplace : {false, true}) { auto input_ = input.clone(); - auto output = F::feature_alpha_dropout(input_, F::FeatureAlphaDropoutFuncOptions().p(rate).training(false).inplace(inplace)); + auto output = F::feature_alpha_dropout( + input_, + F::FeatureAlphaDropoutFuncOptions().p(rate).training(false).inplace( + inplace)); ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); if (inplace) { @@ -2685,49 +2778,50 @@ TEST_F(FunctionalTest, Dropout3d) { ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined()); } -template +template void test_isfinite(const at::Device& device) { const std::vector values = { - std::numeric_limits::lowest(), - 0, 1, 42, - std::numeric_limits::min(), - std::numeric_limits::max() - }; + std::numeric_limits::lowest(), + 0, + 1, + 42, + std::numeric_limits::min(), + std::numeric_limits::max()}; for (const auto value : values) { - const auto x = torch::full({3, 3}, value, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::full( + {3, 3}, value, torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::isfinite(x).all().template item()); } if (std::numeric_limits::has_infinity) { const auto inf = std::numeric_limits::infinity(); - const auto x = torch::tensor({ - -inf, - std::numeric_limits::lowest(), - static_cast(0), - static_cast(1), - static_cast(42), - std::numeric_limits::min(), - std::numeric_limits::max(), - inf - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {-inf, + std::numeric_limits::lowest(), + static_cast(0), + static_cast(1), + static_cast(42), + std::numeric_limits::min(), + std::numeric_limits::max(), + inf}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose( - // torch::allclose does not support comparing torch::kBool - torch::isfinite(x).toType(torch::kInt), - torch::tensor( - {false, true, true, true, true, true, true, false}, - torch::TensorOptions().device(device) - ).toType(torch::kInt) - )); + // torch::allclose does not support comparing torch::kBool + torch::isfinite(x).toType(torch::kInt), + torch::tensor( + {false, true, true, true, true, true, true, false}, + torch::TensorOptions().device(device)) + .toType(torch::kInt))); } if (std::numeric_limits::has_quiet_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::quiet_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::quiet_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_FALSE(torch::isfinite(x).all().template item()); } if (std::numeric_limits::has_signaling_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::signaling_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::signaling_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_FALSE(torch::isfinite(x).all().template item()); } } @@ -2755,49 +2849,50 @@ TEST_F(FunctionalTest, isfinite_CUDA) { test_isfinite(device); } -template +template void test_isinf(const at::Device& device) { const std::vector values = { - std::numeric_limits::lowest(), - 0, 1, 42, - std::numeric_limits::min(), - std::numeric_limits::max() - }; + std::numeric_limits::lowest(), + 0, + 1, + 42, + std::numeric_limits::min(), + std::numeric_limits::max()}; for (const auto value : values) { - const auto x = torch::full({3, 3}, value, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::full( + {3, 3}, value, torch::TensorOptions().dtype(S).device(device)); ASSERT_FALSE(torch::isinf(x).all().template item()); } if (std::numeric_limits::has_infinity) { const auto inf = std::numeric_limits::infinity(); - const auto x = torch::tensor({ - -inf, - std::numeric_limits::lowest(), - static_cast(0), - static_cast(1), - static_cast(42), - std::numeric_limits::min(), - std::numeric_limits::max(), - inf - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {-inf, + std::numeric_limits::lowest(), + static_cast(0), + static_cast(1), + static_cast(42), + std::numeric_limits::min(), + std::numeric_limits::max(), + inf}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose( - // torch::allclose does not support comparing torch::kBool - torch::isinf(x).toType(torch::kInt), - torch::tensor( - {true, false, false, false, false, false, false, true}, - torch::TensorOptions().device(device) - ).toType(torch::kInt) - )); + // torch::allclose does not support comparing torch::kBool + torch::isinf(x).toType(torch::kInt), + torch::tensor( + {true, false, false, false, false, false, false, true}, + torch::TensorOptions().device(device)) + .toType(torch::kInt))); } if (std::numeric_limits::has_quiet_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::quiet_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::quiet_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_FALSE(torch::isinf(x).all().template item()); } if (std::numeric_limits::has_signaling_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::signaling_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::signaling_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_FALSE(torch::isinf(x).all().template item()); } } @@ -2825,17 +2920,20 @@ TEST_F(FunctionalTest, isinf_CUDA) { test_isinf(device); } -template +template void test_allclose(const at::Device& device) { const std::vector values = { - std::numeric_limits::lowest(), - 0, 1, 42, - std::numeric_limits::min(), - std::numeric_limits::max() - }; + std::numeric_limits::lowest(), + 0, + 1, + 42, + std::numeric_limits::min(), + std::numeric_limits::max()}; for (const auto value : values) { - const auto x = torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); - const auto y = torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); + const auto x = + torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); + const auto y = + torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose(x, x)); ASSERT_TRUE(torch::allclose(x, y)); ASSERT_TRUE(torch::allclose(y, x)); @@ -2844,32 +2942,32 @@ void test_allclose(const at::Device& device) { } if (std::numeric_limits::has_infinity) { const auto inf = std::numeric_limits::infinity(); - const auto x = torch::tensor({-inf, inf}, - torch::TensorOptions().dtype(S).device(device)); - const auto y = torch::tensor({-inf, inf}, - torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {-inf, inf}, torch::TensorOptions().dtype(S).device(device)); + const auto y = torch::tensor( + {-inf, inf}, torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose(x, x)); ASSERT_TRUE(torch::allclose(x, y)); ASSERT_TRUE(torch::allclose(y, x)); } if (std::numeric_limits::has_quiet_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::quiet_NaN() - }, torch::TensorOptions().dtype(S).device(device)); - const auto y = torch::tensor({ - std::numeric_limits::quiet_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::quiet_NaN()}, + torch::TensorOptions().dtype(S).device(device)); + const auto y = torch::tensor( + {std::numeric_limits::quiet_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true)); ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true)); ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true)); } if (std::numeric_limits::has_signaling_NaN) { - const auto x = torch::tensor({ - std::numeric_limits::signaling_NaN() - }, torch::TensorOptions().dtype(S).device(device)); - const auto y = torch::tensor({ - std::numeric_limits::signaling_NaN() - }, torch::TensorOptions().dtype(S).device(device)); + const auto x = torch::tensor( + {std::numeric_limits::signaling_NaN()}, + torch::TensorOptions().dtype(S).device(device)); + const auto y = torch::tensor( + {std::numeric_limits::signaling_NaN()}, + torch::TensorOptions().dtype(S).device(device)); ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true)); ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true)); ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true)); @@ -2900,165 +2998,178 @@ TEST_F(FunctionalTest, AllClose_CUDA) { } TEST_F(FunctionalTest, BCEWithLogitsLoss) { - { // test BCE with logits raises if target and input are different size - { - const auto target = torch::rand(5); - const auto input = torch::rand({5, 1}); - ASSERT_THROWS_WITH( - F::binary_cross_entropy_with_logits(input, target), - "must be the same as input size" - ); - } - { - const auto target = torch::rand({5, 1}); - const auto input = torch::rand(5); - ASSERT_THROWS_WITH( - F::binary_cross_entropy_with_logits(input, target), - "must be the same as input size" - ); - } - } - { // test BCE with logits gives same result as sigmoid and bce loss - auto sigmoid = Sigmoid(); + {// test BCE with logits raises if target and input are different size + {const auto target = torch::rand(5); + const auto input = torch::rand({5, 1}); + ASSERT_THROWS_WITH( + F::binary_cross_entropy_with_logits(input, target), + "must be the same as input size"); +} +{ + const auto target = torch::rand({5, 1}); + const auto input = torch::rand(5); + ASSERT_THROWS_WITH( + F::binary_cross_entropy_with_logits(input, target), + "must be the same as input size"); +} +} +{ // test BCE with logits gives same result as sigmoid and bce loss + auto sigmoid = Sigmoid(); - auto target = torch::rand({64, 4}); - auto output = torch::rand({64, 4}) - 0.5; + auto target = torch::rand({64, 4}); + auto output = torch::rand({64, 4}) - 0.5; - ASSERT_TRUE(torch::allclose( + ASSERT_TRUE(torch::allclose( F::binary_cross_entropy_with_logits(output, target), - F::binary_cross_entropy(sigmoid(output), target) - )); - - auto weight = torch::rand(4); - ASSERT_TRUE(torch::allclose( - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ), - F::binary_cross_entropy(sigmoid(output), target, - F::BinaryCrossEntropyFuncOptions().weight(weight) - ) - )); + F::binary_cross_entropy(sigmoid(output), target))); - target = torch::zeros({4, 1}, torch::kFloat); - output = torch::empty({4, 1}, torch::kFloat).fill_(-100); + auto weight = torch::rand(4); + ASSERT_TRUE(torch::allclose( + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)), + F::binary_cross_entropy( + sigmoid(output), + target, + F::BinaryCrossEntropyFuncOptions().weight(weight)))); + + target = torch::zeros({4, 1}, torch::kFloat); + output = torch::empty({4, 1}, torch::kFloat).fill_(-100); - ASSERT_TRUE(torch::allclose( + ASSERT_TRUE(torch::allclose( F::binary_cross_entropy_with_logits(output, target), - F::binary_cross_entropy(sigmoid(output), target) - )); + F::binary_cross_entropy(sigmoid(output), target))); - ASSERT_TRUE(torch::allclose( - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone) - ), - F::binary_cross_entropy(sigmoid(output), target, - F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone) - ) - )); - - weight = torch::rand({1}, torch::kFloat); - ASSERT_TRUE(torch::allclose( - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ), - F::binary_cross_entropy(sigmoid(output), target, - F::BinaryCrossEntropyFuncOptions().weight(weight) - ) - )); - } - { // test BCE with logits has correct grad at zero - const auto output = torch::zeros({3, 1}, torch::requires_grad()); - const auto target = torch::zeros({3, 1}); - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum) - ).backward(); - const auto expected_grad = torch::empty({3, 1}).fill_(0.5); - ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); - } - { // test BCE with logits broadcasts weights - const auto target = torch::rand({16, 4}); - const auto output = torch::rand({16, 4}) - 0.5; - - auto weight = torch::rand(4); - auto out1 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ); - - weight = weight.expand({16, 4}).contiguous(); - auto out2 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ); - - ASSERT_TRUE(torch::allclose(out1, out2)); - - weight = torch::rand({16, 1}); - out1 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ); - - weight = weight.expand({16, 4}).contiguous(); - out2 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight) - ); - - ASSERT_TRUE(torch::allclose(out1, out2)); - } - { // test BCE with logits ones in pos weights are the same as none - const auto target = torch::rand({64, 4}); - const auto output = torch::rand({64, 4}) - 0.5; - const auto pos_weight = torch::ones({64, 4}); + ASSERT_TRUE(torch::allclose( + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone)), + F::binary_cross_entropy( + sigmoid(output), + target, + F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone)))); + + weight = torch::rand({1}, torch::kFloat); + ASSERT_TRUE(torch::allclose( + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)), + F::binary_cross_entropy( + sigmoid(output), + target, + F::BinaryCrossEntropyFuncOptions().weight(weight)))); +} +{ // test BCE with logits has correct grad at zero + const auto output = torch::zeros({3, 1}, torch::requires_grad()); + const auto target = torch::zeros({3, 1}); + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum)) + .backward(); + const auto expected_grad = torch::empty({3, 1}).fill_(0.5); + ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); +} +{ // test BCE with logits broadcasts weights + const auto target = torch::rand({16, 4}); + const auto output = torch::rand({16, 4}) - 0.5; + + auto weight = torch::rand(4); + auto out1 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); + + weight = weight.expand({16, 4}).contiguous(); + auto out2 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); + + ASSERT_TRUE(torch::allclose(out1, out2)); + + weight = torch::rand({16, 1}); + out1 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); + + weight = weight.expand({16, 4}).contiguous(); + out2 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); + + ASSERT_TRUE(torch::allclose(out1, out2)); +} +{ // test BCE with logits ones in pos weights are the same as none + const auto target = torch::rand({64, 4}); + const auto output = torch::rand({64, 4}) - 0.5; + const auto pos_weight = torch::ones({64, 4}); - ASSERT_TRUE(torch::allclose( + ASSERT_TRUE(torch::allclose( F::binary_cross_entropy_with_logits(output, target), - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight) - ) - )); - } - { // test BCE with logits broadcasts pos weights - const auto target = torch::rand({64, 4}); - const auto output = torch::rand({64, 4}) - 0.5; - const auto pos_weight = torch::rand(4); - const auto out1 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight) - ); - - const auto pos_weight1 = pos_weight.expand({1, 4}); - const auto out2 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight) - ); - - const auto pos_weight2 = pos_weight.expand({64, 4}); - const auto out3 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight) - ); - - ASSERT_TRUE(torch::allclose(out1, out2)); - ASSERT_TRUE(torch::allclose(out1, out3)); - } - { // test BCE with logits with pos weight has correct grad at zero - const auto output = torch::zeros({3, 1}, torch::requires_grad()); - const auto target = torch::zeros({3, 1}); - const auto pos_weight = torch::ones({3, 1}); - F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum) - ).backward(); - const auto expected_grad = torch::empty({3, 1}).fill_(0.5); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - const auto grad = output.grad(); - ASSERT_TRUE(torch::allclose(grad, expected_grad)); - } - { // test BCE with logits stability - const auto output = torch::tensor({0., -120.}); - const auto target = torch::tensor({0., 1.}); - const auto pos_weight = torch::tensor({1., 1.}); - - const auto out1 = F::binary_cross_entropy_with_logits(output, target); - ASSERT_TRUE(torch::isfinite(out1).all().item()); - - const auto out2 = F::binary_cross_entropy_with_logits(output, target, - F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight) - ); - ASSERT_TRUE(torch::isfinite(out2).all().item()); - } + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight( + pos_weight)))); +} +{ // test BCE with logits broadcasts pos weights + const auto target = torch::rand({64, 4}); + const auto output = torch::rand({64, 4}) - 0.5; + const auto pos_weight = torch::rand(4); + const auto out1 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); + + const auto pos_weight1 = pos_weight.expand({1, 4}); + const auto out2 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); + + const auto pos_weight2 = pos_weight.expand({64, 4}); + const auto out3 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); + + ASSERT_TRUE(torch::allclose(out1, out2)); + ASSERT_TRUE(torch::allclose(out1, out3)); +} +{ // test BCE with logits with pos weight has correct grad at zero + const auto output = torch::zeros({3, 1}, torch::requires_grad()); + const auto target = torch::zeros({3, 1}); + const auto pos_weight = torch::ones({3, 1}); + F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions() + .pos_weight(pos_weight) + .reduction(torch::kSum)) + .backward(); + const auto expected_grad = torch::empty({3, 1}).fill_(0.5); + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + const auto grad = output.grad(); + ASSERT_TRUE(torch::allclose(grad, expected_grad)); +} +{ // test BCE with logits stability + const auto output = torch::tensor({0., -120.}); + const auto target = torch::tensor({0., 1.}); + const auto pos_weight = torch::tensor({1., 1.}); + + const auto out1 = F::binary_cross_entropy_with_logits(output, target); + ASSERT_TRUE(torch::isfinite(out1).all().item()); + + const auto out2 = F::binary_cross_entropy_with_logits( + output, + target, + F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); + ASSERT_TRUE(torch::isfinite(out2).all().item()); +} } diff --git a/test/cpp/api/grad_mode.cpp b/test/cpp/api/grad_mode.cpp index ff5aa57b35ca20..edb4e21234b997 100644 --- a/test/cpp/api/grad_mode.cpp +++ b/test/cpp/api/grad_mode.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include using namespace torch::autograd; using namespace torch::test; @@ -38,15 +38,17 @@ TEST(GradModeTest, TestRequiresGradViewOp) { } TEST(GradModeTest, TestRequiresGradViewOpExiting) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor view_out, tmp; { torch::AutoGradMode mode(false); - view_out = a.view({2, 3}); // go through kernels: VariableType, ADInplaceOrView, CPU - assert_tensor_creation_meta(view_out, torch::autograd::CreationMeta::NO_GRAD_MODE); + view_out = a.view( + {2, 3}); // go through kernels: VariableType, ADInplaceOrView, CPU + assert_tensor_creation_meta( + view_out, torch::autograd::CreationMeta::NO_GRAD_MODE); ASSERT_EQ(view_out.requires_grad(), requires_grad); ASSERT_TRUE(view_out.is_leaf()); } @@ -60,14 +62,17 @@ TEST(GradModeTest, TestRequiresGradViewOpExiting) { } if (requires_grad) { - ASSERT_THROWS_WITH(view_out.mul_(2), // go through kernels: VariableType, ADInplaceOrView, CPU - "A view was created in no_grad mode and is being modified inplace"); + ASSERT_THROWS_WITH( + view_out.mul_( + 2), // go through kernels: VariableType, ADInplaceOrView, CPU + "A view was created in no_grad mode and is being modified inplace"); } else { - view_out.mul_(2); + view_out.mul_(2); } tmp = view_out.view({2, 3}); ASSERT_EQ(tmp.requires_grad(), requires_grad); - assert_tensor_creation_meta(tmp, torch::autograd::CreationMeta::NO_GRAD_MODE); + assert_tensor_creation_meta( + tmp, torch::autograd::CreationMeta::NO_GRAD_MODE); } } diff --git a/test/cpp/api/imethod.cpp b/test/cpp/api/imethod.cpp index 41104914a2a180..6d257d377acb19 100644 --- a/test/cpp/api/imethod.cpp +++ b/test/cpp/api/imethod.cpp @@ -17,8 +17,10 @@ const char* path(const char* envname, const char* path) { return env ? env : path; } -// Run `python torch/csrc/deploy/example/generate_examples.py` before running the following tests. -// TODO(jwtan): Figure out a way to automate the above step for development. (CI has it already.) +// Run `python torch/csrc/deploy/example/generate_examples.py` before running +// the following tests. +// TODO(jwtan): Figure out a way to automate the above step for development. (CI +// has it already.) TEST(IMethodTest, CallMethod) { auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit)); auto scriptMethod = scriptModel.get_method("forward"); diff --git a/test/cpp/api/inference_mode.cpp b/test/cpp/api/inference_mode.cpp index f11db467ce725e..c6b0f592df0829 100644 --- a/test/cpp/api/inference_mode.cpp +++ b/test/cpp/api/inference_mode.cpp @@ -1,39 +1,48 @@ -#include #include #include +#include using namespace torch::autograd; using namespace torch::test; namespace { - torch::Tensor functional_op(torch::Tensor& x) { - return x * x; - } +torch::Tensor functional_op(torch::Tensor& x) { + return x * x; +} - void inplace_op(torch::Tensor& x) { - x.mul_(1); - } +void inplace_op(torch::Tensor& x) { + x.mul_(1); +} - torch::Tensor view_op(torch::Tensor& x) { - return x.view({2, 3}); - } +torch::Tensor view_op(torch::Tensor& x) { + return x.view({2, 3}); +} - /* - Only the following combos of Autograd & ADInplaceOrView keys on tensors are valid: - - Autograd=true, ADInplaceOrView=true (normal tensor) - - Autograd=false, ADInplaceOrView=false (inference tensor) - Tensors created in InferenceMode are mostly inference tensors. The only exception - is that view of normal tensors created in InferenceMode still produce normal tensor. - */ - void assert_TLS_states(bool inference_mode) { - ASSERT_EQ(InferenceMode::is_enabled(), inference_mode); - ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::ADInplaceOrView)); - ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset)); - ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode); - ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::ADInplaceOrView), !inference_mode); - ASSERT_EQ(GradMode::is_enabled(), !inference_mode); - } +/* + Only the following combos of Autograd & ADInplaceOrView keys on tensors are + valid: + - Autograd=true, ADInplaceOrView=true (normal tensor) + - Autograd=false, ADInplaceOrView=false (inference tensor) + Tensors created in InferenceMode are mostly inference tensors. The only + exception is that view of normal tensors created in InferenceMode still + produce normal tensor. +*/ +void assert_TLS_states(bool inference_mode) { + ASSERT_EQ(InferenceMode::is_enabled(), inference_mode); + ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded( + c10::DispatchKey::ADInplaceOrView)); + ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included( + c10::autograd_dispatch_keyset)); + ASSERT_EQ( + c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), + inference_mode); + ASSERT_EQ( + c10::impl::tls_is_dispatch_key_included( + c10::DispatchKey::ADInplaceOrView), + !inference_mode); + ASSERT_EQ(GradMode::is_enabled(), !inference_mode); } +} // namespace TEST(InferenceModeTest, TestTLSState) { assert_TLS_states(false); @@ -57,7 +66,8 @@ TEST(InferenceModeTest, TestInferenceTensorCreation) { ASSERT_FALSE(c.requires_grad()); ASSERT_TRUE(c.is_inference()); - // requires_grad doesn't change inference tensor behavior inside InferenceMode. + // requires_grad doesn't change inference tensor behavior inside + // InferenceMode. torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true); ASSERT_TRUE(tmp.requires_grad()); ASSERT_TRUE(tmp.is_inference()); @@ -78,9 +88,11 @@ TEST(InferenceModeTest, TestExistingAutogradSession) { InferenceMode guard; inplace_op(a); } - // Performing backward should trigger error since `a`'s version has been bumped. - ASSERT_THROWS_WITH(out.backward(torch::ones_like(out)), - "one of the variables needed for gradient computation has been modified by an inplace operation") + // Performing backward should trigger error since `a`'s version has been + // bumped. + ASSERT_THROWS_WITH( + out.backward(torch::ones_like(out)), + "one of the variables needed for gradient computation has been modified by an inplace operation") } TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) { @@ -88,7 +100,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) { for (bool requires_grad : {true, false}) { torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); - torch::Tensor func_out = functional_op(c); // go through kernels: CPU + torch::Tensor func_out = functional_op(c); // go through kernels: CPU ASSERT_TRUE(func_out.is_inference()); ASSERT_FALSE(func_out.requires_grad()); } @@ -99,7 +111,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) { for (bool requires_grad : {true, false}) { torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); - inplace_op(c); // go through kernels: CPU + inplace_op(c); // go through kernels: CPU ASSERT_TRUE(c.is_inference()); ASSERT_EQ(c.requires_grad(), requires_grad); } @@ -110,7 +122,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) { for (bool requires_grad : {true, false}) { torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); - torch::Tensor view_out = view_op(c); // go through kernels: CPU + torch::Tensor view_out = view_op(c); // go through kernels: CPU ASSERT_TRUE(view_out.is_inference()); // Note this is different from NoGradMode but makes sense. ASSERT_FALSE(view_out.requires_grad()); @@ -120,17 +132,20 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) { TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) { torch::Tensor inference_tensor; - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { { InferenceMode guard; - inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); + inference_tensor = + torch::ones({1, 2, 3}).set_requires_grad(requires_grad); } - // Due to issue #54614, this might run slower compared to InferenceMode since - // intermediate tensors are normal tensors, and they might dispatch to VariableType - // kernels. This is fine since users can easily fix it by moving - // it inside InferenceMode block. - torch::Tensor tmp = functional_op(inference_tensor); // go through kernels: ADInplaceOrView(fallthrough), CPU + // Due to issue #54614, this might run slower compared to InferenceMode + // since intermediate tensors are normal tensors, and they might dispatch to + // VariableType kernels. This is fine since users can easily fix it by + // moving it inside InferenceMode block. + torch::Tensor tmp = + functional_op(inference_tensor); // go through kernels: + // ADInplaceOrView(fallthrough), CPU ASSERT_FALSE(tmp.is_inference()); ASSERT_FALSE(tmp.requires_grad()); } @@ -138,24 +153,29 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) { TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) { torch::Tensor inference_tensor; - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { { InferenceMode guard; - inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); + inference_tensor = + torch::ones({1, 2, 3}).set_requires_grad(requires_grad); } - ASSERT_THROWS_WITH(inplace_op(inference_tensor), // go through kernels: ADInplaceOrView, CPU - "Inplace update to inference tensor outside InferenceMode is not allowed"); + ASSERT_THROWS_WITH( + inplace_op( + inference_tensor), // go through kernels: ADInplaceOrView, CPU + "Inplace update to inference tensor outside InferenceMode is not allowed"); } } TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) { torch::Tensor inference_tensor; - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { { InferenceMode guard; - inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); + inference_tensor = + torch::ones({1, 2, 3}).set_requires_grad(requires_grad); } - torch::Tensor out = view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU + torch::Tensor out = + view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU ASSERT_TRUE(out.is_inference()); ASSERT_FALSE(out.requires_grad()); ASSERT_FALSE(out.is_view()); @@ -164,24 +184,25 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) { } TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); { c10::InferenceMode guard; - inplace_op(a); // go through kernels: ADInplaceOrView, CPU + inplace_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(a.is_inference()); ASSERT_EQ(a.requires_grad(), requires_grad); // inplace -> inplace - inplace_op(a); // go through kernels: ADInplaceOrView, CPU + inplace_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(a.is_inference()); ASSERT_EQ(a.requires_grad(), requires_grad); // inplace -> inplace -> view - torch::Tensor view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU + torch::Tensor view_out = + view_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(view_out.is_inference()); ASSERT_EQ(view_out.requires_grad(), requires_grad); } @@ -189,19 +210,20 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) { } TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); { c10::InferenceMode guard; - inplace_op(a); // go through kernels: ADInplaceOrView, CPU + inplace_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(a.is_inference()); ASSERT_EQ(a.requires_grad(), requires_grad); } - torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU + torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, + // ADInplaceOrView(fallthrough), CPU ASSERT_FALSE(tmp.is_inference()); ASSERT_EQ(tmp.requires_grad(), requires_grad); @@ -209,14 +231,14 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) { ASSERT_FALSE(a.is_inference()); ASSERT_EQ(a.requires_grad(), requires_grad); - tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU + tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU ASSERT_FALSE(tmp.is_inference()); ASSERT_EQ(tmp.requires_grad(), requires_grad); } } TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor view_out, tmp; @@ -231,25 +253,26 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) { // Storage(self.storage()), self.key_set(), self.dtype()); // ``` // In addition, these view output tensors are normal in the sense they - // have both Autograd and ADInplaceOrView keys. But they're still special - // since they'll have CreationMeta::INFERENCE_MODE. In other words they behave - // exactly the same as a view tensor created in no_grad mode. + // have both Autograd and ADInplaceOrView keys. But they're still + // special since they'll have CreationMeta::INFERENCE_MODE. In other + // words they behave exactly the same as a view tensor created in + // no_grad mode. - view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU + view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(view_out.is_inference()); assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE); ASSERT_EQ(view_out.requires_grad(), requires_grad); ASSERT_TRUE(view_out.is_leaf()); // view -> view - tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU + tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(tmp.is_inference()); assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE); ASSERT_EQ(tmp.requires_grad(), requires_grad); ASSERT_TRUE(tmp.is_leaf()); // view -> view -> inplace - inplace_op(tmp); // kernels: ADInplaceOrView, CPU + inplace_op(tmp); // kernels: ADInplaceOrView, CPU assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE); ASSERT_FALSE(tmp.is_inference()); ASSERT_EQ(tmp.requires_grad(), requires_grad); @@ -260,14 +283,14 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) { } TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor view_out, tmp; { c10::InferenceMode guard; - view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU + view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU ASSERT_FALSE(view_out.is_inference()); assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE); ASSERT_EQ(view_out.requires_grad(), requires_grad); @@ -279,8 +302,10 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) { ASSERT_EQ(tmp.requires_grad(), requires_grad); if (requires_grad) { - ASSERT_THROWS_WITH(inplace_op(view_out), // go through kernels: VariableType, ADInplaceOrView, CPU - "A view was created in inference mode and is being modified inplace") + ASSERT_THROWS_WITH( + inplace_op(view_out), // go through kernels: VariableType, + // ADInplaceOrView, CPU + "A view was created in inference mode and is being modified inplace") } else { inplace_op(view_out); } @@ -292,7 +317,7 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) { } TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor c; { @@ -300,8 +325,10 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) { c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); } - // add(Tensor, Tensor) is safe with inference tensor since it doesn't save any variable for backward. - torch::Tensor out = c.add(s); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU + // add(Tensor, Tensor) is safe with inference tensor since it doesn't save + // any variable for backward. + torch::Tensor out = c.add(s); // go through kernels: VariableType, + // ADInplaceOrView(fallthrough), CPU ASSERT_FALSE(out.is_inference()); ASSERT_EQ(out.requires_grad(), requires_grad); if (requires_grad) { @@ -313,19 +340,21 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) { if (requires_grad) { // mul(self, other) saves variable when requires_grad=true - ASSERT_THROWS_WITH(c.mul(s), - "Inference tensors cannot be saved for backward."); + ASSERT_THROWS_WITH( + c.mul(s), "Inference tensors cannot be saved for backward."); // Inference tensor in TensorList input std::vector inputs = {s, c}; - ASSERT_THROWS_WITH(torch::stack(inputs), // go through kernels: VariableType(ERROR)!, ADInplaceOrView(fallthrough), CPU - "Inference tensors cannot be saved for backward.") + ASSERT_THROWS_WITH( + torch::stack(inputs), // go through kernels: VariableType(ERROR)!, + // ADInplaceOrView(fallthrough), CPU + "Inference tensors cannot be saved for backward.") } } } TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor c; @@ -335,22 +364,29 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) { } if (requires_grad) { - ASSERT_THROWS_WITH(a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, CPU - "Inference tensors cannot be saved for backward."); - - ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU - "out=... arguments don't support automatic differentiation, but one of the arguments requires grad") + ASSERT_THROWS_WITH( + a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, + // CPU + "Inference tensors cannot be saved for backward."); + + ASSERT_THROWS_WITH( + torch::mul_out( + /*out=*/c, s, s), // go through kernels: VariableType(ERROR!), + // ADInplaceOrView, CPU + "out=... arguments don't support automatic differentiation, but one of the arguments requires grad") } else { a.mul_(c); - ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU - "Inplace update to inference tensor outside InferenceMode is not allowed"); + ASSERT_THROWS_WITH( + torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, + // ADInplaceOrView(ERROR!), CPU + "Inplace update to inference tensor outside InferenceMode is not allowed"); } } } TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor c; { @@ -358,32 +394,36 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) { c = torch::ones({1, 2, 3}); } - // view_as is a composite op which calls view() with only one tensor argument. - // So there isn't a mixed inference tensor and normal tensor inputs for view ops. - torch::Tensor tmp1 = c.view_as(s); // go through kernels: ADInplaceOrView, CPU + // view_as is a composite op which calls view() with only one tensor + // argument. So there isn't a mixed inference tensor and normal tensor + // inputs for view ops. + torch::Tensor tmp1 = + c.view_as(s); // go through kernels: ADInplaceOrView, CPU ASSERT_TRUE(tmp1.is_inference()); ASSERT_FALSE(tmp1.requires_grad()); // This is fine since it's equivalent as s.view(c.sizes()) which // isn't a mixed input scenario. - torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU + torch::Tensor tmp2 = + s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU ASSERT_FALSE(tmp2.is_inference()); ASSERT_EQ(tmp2.requires_grad(), requires_grad); } } TEST(InferenceModeTest, TestHandleDirectViewOnRebase) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor view_out; { InferenceMode guard; - view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU + view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU } if (requires_grad) { - ASSERT_THROWS_WITH(inplace_op(view_out), - "A view was created in inference mode and is being modified inplace") + ASSERT_THROWS_WITH( + inplace_op(view_out), + "A view was created in inference mode and is being modified inplace") } else { inplace_op(view_out); } @@ -391,18 +431,19 @@ TEST(InferenceModeTest, TestHandleDirectViewOnRebase) { } TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor a = s.clone(); torch::Tensor view_out; { InferenceMode guard; - view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU + view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU } inplace_op(a); if (requires_grad) { - ASSERT_THROWS_WITH(view_out.grad_fn(), - "A view was created in inference mode and its base or another view of its base has been modified inplace"); + ASSERT_THROWS_WITH( + view_out.grad_fn(), + "A view was created in inference mode and its base or another view of its base has been modified inplace"); } else { view_out.grad_fn(); } @@ -416,14 +457,16 @@ TEST(InferenceModeTest, TestCreationMetaPropagation) { InferenceMode guard; b = s.view_as(s); } - ASSERT_THROWS_WITH(b.add_(1), - "A view was created in inference mode and is being modified inplace"); + ASSERT_THROWS_WITH( + b.add_(1), + "A view was created in inference mode and is being modified inplace"); { AutoGradMode mode(false); c = b.view_as(b); } - ASSERT_THROWS_WITH(c.add_(1), - "A view was created in inference mode and is being modified inplace"); + ASSERT_THROWS_WITH( + c.add_(1), + "A view was created in inference mode and is being modified inplace"); } TEST(InferenceModeTest, TestCreationMetaPropagationInput) { @@ -437,20 +480,22 @@ TEST(InferenceModeTest, TestCreationMetaPropagationInput) { s = s.view_as(s); c = s.split_with_sizes({1, 1}); } - for (auto& b_el: b) { + for (auto& b_el : b) { assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE); - ASSERT_THROWS_WITH(b_el.add_(1), - "A view was created in inference mode and is being modified inplace"); + ASSERT_THROWS_WITH( + b_el.add_(1), + "A view was created in inference mode and is being modified inplace"); } - for (auto& c_el: c) { + for (auto& c_el : c) { assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE); - ASSERT_THROWS_WITH(c_el.add_(1), - "A view was created in inference mode and is being modified inplace"); + ASSERT_THROWS_WITH( + c_el.add_(1), + "A view was created in inference mode and is being modified inplace"); } } TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) { - for (bool requires_grad: {true, false}) { + for (bool requires_grad : {true, false}) { torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad); torch::Tensor t; { @@ -461,8 +506,9 @@ TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) { ASSERT_FALSE(t.requires_grad()); } - ASSERT_THROWS_WITH(t.copy_(s), - "Inplace update to inference tensor outside InferenceMode is not allowed"); + ASSERT_THROWS_WITH( + t.copy_(s), + "Inplace update to inference tensor outside InferenceMode is not allowed"); } } @@ -473,8 +519,9 @@ TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) { t = torch::ones({1, 2, 3}); } t.set_requires_grad(false); - ASSERT_THROWS_WITH(t.set_requires_grad(true), - "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed."); + ASSERT_THROWS_WITH( + t.set_requires_grad(true), + "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed."); } TEST(InferenceModeTest, TestAccessVersionCounter) { @@ -482,19 +529,23 @@ TEST(InferenceModeTest, TestAccessVersionCounter) { { InferenceMode guard; t = torch::ones({1, 2, 3}); - ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(), - "Inference tensors do not track version counter."); + ASSERT_THROWS_WITH( + t.unsafeGetTensorImpl()->version_counter().current_version(), + "Inference tensors do not track version counter."); t.unsafeGetTensorImpl()->bump_version(); } - ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(), - "Inference tensors do not track version counter."); - ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->bump_version(), - "Inplace update to inference tensor outside InferenceMode is not allowed."); + ASSERT_THROWS_WITH( + t.unsafeGetTensorImpl()->version_counter().current_version(), + "Inference tensors do not track version counter."); + ASSERT_THROWS_WITH( + t.unsafeGetTensorImpl()->bump_version(), + "Inplace update to inference tensor outside InferenceMode is not allowed."); // Suggested workaround torch::Tensor c = t.clone(); uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version(); c.unsafeGetTensorImpl()->bump_version(); - ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1); + ASSERT_EQ( + c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1); } TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) { @@ -511,11 +562,13 @@ TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) { } s.copy_(t); s.add_(t); - ASSERT_THROWS_WITH(t.copy_(s), - "Inplace update to inference tensor outside InferenceMode is not allowed"); + ASSERT_THROWS_WITH( + t.copy_(s), + "Inplace update to inference tensor outside InferenceMode is not allowed"); - ASSERT_THROWS_WITH(t.add_(s), - "Inplace update to inference tensor outside InferenceMode is not allowed"); + ASSERT_THROWS_WITH( + t.add_(s), + "Inplace update to inference tensor outside InferenceMode is not allowed"); } TEST(InferenceModeTest, TestComplexViewInInferenceMode) { @@ -552,18 +605,27 @@ TEST(InferenceModeTest, TestComplexViewInNormalMode) { TEST(InferenceModeTest, TestCustomFunction) { struct MyFunction : public Function { - static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) { + static Variable forward( + AutogradContext* ctx, + Variable var1, + int mul, + Variable var2) { ctx->saved_data["mul"] = mul; ctx->save_for_backward({var1, var2}); - return var1 + mul*var2 + var1*var2; + return var1 + mul * var2 + var1 * var2; } - static variable_list backward(AutogradContext *ctx, variable_list grad_output) { + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { int mul = ctx->saved_data["mul"].toInt(); auto saved = ctx->get_saved_variables(); auto var1 = saved[0]; auto var2 = saved[1]; - variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1}; + variable_list output = { + grad_output[0] + grad_output[0] * var2, + Variable(), + grad_output[0] * mul + grad_output[0] * var1}; return output; } }; @@ -586,5 +648,6 @@ TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) { WarningCapture warnings; at::AutoNonVariableTypeMode guard; ASSERT_TRUE( - warnings.str().find("AutoNonVariableTypeMode is deprecated") != std::string::npos); + warnings.str().find("AutoNonVariableTypeMode is deprecated") != + std::string::npos); } diff --git a/test/cpp/api/init.cpp b/test/cpp/api/init.cpp index 222d4f1171c4d1..345b6be3195d47 100644 --- a/test/cpp/api/init.cpp +++ b/test/cpp/api/init.cpp @@ -9,7 +9,6 @@ #include #include - void check_exact_values( const std::vector& parameters, const std::vector>& expected_parameters) { @@ -19,7 +18,8 @@ void check_exact_values( auto layerParameters = parameters[i]; auto expectedLayerParameters = expected_parameters[i]; - if (static_cast(layerParameters.size(0)) != expectedLayerParameters.size()) { + if (static_cast(layerParameters.size(0)) != + expectedLayerParameters.size()) { std::cout << "layer #" << i << " layerParameters size: " << layerParameters.size(0) << " != " @@ -29,7 +29,8 @@ void check_exact_values( } for (const auto p : c10::irange(layerParameters.size(0))) { - // Always compare using double dtype, regardless of the original dtype of the tensors + // Always compare using double dtype, regardless of the original dtype of + // the tensors auto tensor = layerParameters[p].to(torch::kFloat64); auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64); @@ -110,20 +111,17 @@ TEST(InitTest, CanInitializeTensorThatRequiresGrad) { } TEST(InitTest, CalculateGainWithTanh) { - double gain = - torch::nn::init::calculate_gain(torch::kTanh); + double gain = torch::nn::init::calculate_gain(torch::kTanh); ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0); } TEST(InitTest, CalculateGainWithRelu) { - double gain = - torch::nn::init::calculate_gain(torch::kReLU); + double gain = torch::nn::init::calculate_gain(torch::kReLU); ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0)); } TEST(InitTest, CalculateGainWithLeakyRelu) { - double gain = - torch::nn::init::calculate_gain(torch::kLeakyReLU); + double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU); ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2)))); } diff --git a/test/cpp/api/init_baseline.h b/test/cpp/api/init_baseline.h index 1555da67a04426..878afccb72ef2c 100644 --- a/test/cpp/api/init_baseline.h +++ b/test/cpp/api/init_baseline.h @@ -8,173 +8,1613 @@ namespace expected_parameters { inline std::vector> Xavier_Uniform() { return { - { - torch::tensor({0.20493895, 0.03128183, -0.25481248, 0.2471149, -0.5009514, -0.30953097, -0.13073051}), - torch::tensor({-0.25438663, -0.18269452, -0.42803344, -0.111086845, 0.11163068, -0.34021688, -0.026800662}), - torch::tensor({0.37384087, -0.053685665, 0.014514029, -0.04505247, 0.10569024, 0.33205634, 0.4946832}), - torch::tensor({0.3316471, 0.49581498, -0.03776875, -0.46913308, -0.24757874, 0.35559112, -0.003385365}), - torch::tensor({-0.259574, -0.40019324, -0.48873278, -0.4407689, -0.10592803, 0.28639567, 0.2823406}), - torch::tensor({-0.5036581, 0.32575953, -0.40865222, -0.110405415, -0.21175116, -0.10059005, -0.10253662}), - torch::tensor({-0.46862572, -0.45091572, -0.08171874, 0.0067536235, -0.2372373, 0.19672471, -0.47004014}), - torch::tensor({-0.035244048, 0.45926183, -0.21301463, 0.47157794, 0.18912864, -0.47129482, 0.33041543}), - torch::tensor({-0.0602628, -0.23312837, 0.41760528, -0.42201394, 0.05603814, -0.10933924, 0.37293315}), - torch::tensor({0.14577848, 0.25093573, 0.18443125, -0.1255838, -0.10982844, -0.43036246, 0.28296888}), - torch::tensor({0.4146415, 0.35732478, -0.36837178, 0.023291528, -0.3681398, -0.28748098, -0.304308}), - torch::tensor({0.17847049, -0.3112055, -0.011393666, 0.021969378, 0.3366434, -0.39476636, -0.35851932}), - torch::tensor({-0.3032406, 0.3655283, -0.18772447, 0.44049758, 0.18884337, 0.066128254, -0.0038875341}), - torch::tensor({-0.10323581, 0.06552267, -0.119249105, -0.0036694407, 0.066633284, -0.40849328, -0.2737187}), - torch::tensor({0.4216993, -0.4238164, -0.037499547, 0.51661307, 0.1886499, 0.014786005, -0.45257226}), - }, - { - torch::tensor({0.2400673, 0.06230098, -0.29922318, -0.3467335, -0.13797283, 0.19630808, 0.44112986, 0.25716078, -0.05036038, 0.15680045, -0.43874466, -0.3819657, 0.20867342, -0.25330856, 0.2151525}), - torch::tensor({-0.31570244, -0.22150977, -0.3683649, 0.23337424, -0.045568883, 0.34417605, 0.27676803, 0.24746233, 0.014380664, -0.13826859, -0.09723839, 0.05943495, 0.22168803, -0.3133133, 0.37533647}), - torch::tensor({-0.04862556, -0.37474066, -0.24196841, 0.39570248, 0.408989, -0.41424486, 0.31541896, 0.2241252, 0.264714, 0.37857938, -0.24102591, 0.1412192, 0.18301463, -0.13214865, 0.14966142}), - torch::tensor({-0.12866935, 0.27649486, -0.12408146, -0.16671932, 0.112585604, 0.15862381, -0.21849588, 0.03953293, 0.25917625, -0.044496298, 0.13610226, -0.107862085, 0.15674818, -0.32395893, -0.26297447}), - torch::tensor({-0.22700138, 0.41099417, -0.12033805, -0.0012210608, -0.21667299, 0.44644886, 0.43678015, -0.3372825, -0.3625426, -0.33898476, -0.002156794, -0.113996506, -0.29272172, -0.16040304, 0.084492445}), - torch::tensor({-0.23366496, 0.09909475, -0.10255319, -0.21670331, 0.061440647, 0.36772507, -0.30235183, 0.02076611, -0.16491795, 0.4388535, -0.4242998, -0.42872632, 0.44067752, -0.28294754, 0.08574128}), - torch::tensor({-0.038597256, -0.09420869, -0.09988415, 0.28417766, 0.021375448, -0.43541366, -0.2640171, -0.15245524, 0.2250452, -0.28940699, 0.42168647, -0.09960759, -0.08030257, 0.3504178, 0.22477299}), - torch::tensor({0.37929094, 0.25868815, -0.13566399, -0.2967139, -0.033274055, 0.37013078, -0.15009373, -0.41473246, 0.18332559, 0.43534964, -0.12731421, -0.3703034, -0.40564942, 0.112071455, -0.03386289}), - torch::tensor({-0.22583716, 0.090396106, 0.1698333, 0.35567743, 0.34720868, -0.066940606, -0.39433825, -0.40411252, 0.41755867, 0.19769311, 0.19494373, -0.3869386, 0.4141268, 0.4236647, 0.4037714}), - torch::tensor({-0.37726268, -0.16874415, -0.3075773, 0.4234959, -0.19215873, -0.2041774, 0.23430097, -0.20687759, -0.22026259, -0.03911844, -0.042985946, -0.34836975, 0.3728277, -0.19727562, 0.15863329}), - torch::tensor({0.38897908, 0.22553718, 0.063316464, 0.3805148, 0.060117245, -0.20690633, 0.4230638, 0.10584676, -0.43633774, -0.12731794, -0.30462736, 0.39209586, -0.07385549,-0.40764633, -0.028113335}), - torch::tensor({0.2808559, 0.11618626, 0.14141095, 0.041534156, 0.16672957, -0.10896447, -0.17790166, -0.41801435, -0.3369025, 0.19382352, -0.26480114, 0.06416017, 0.14274675, 0.03166446, -0.28995082}), - torch::tensor({0.42768306, -0.26005447, 0.36783344, -0.35576212, -0.10757655, 0.24327022, -0.18272284, 0.3756786, -0.30775294, -0.37555724, -0.20165718, 0.07229227, 0.41177452, -0.21350017, 0.15993619}), - torch::tensor({-0.112119585, -0.09698379, 0.3288377, -0.34658423, 0.047500044, 0.42056376, -0.061452597, 0.34723365, -0.13772246, 0.35999, -0.43260327, -0.06445408, -0.07855147, 0.14493519, 0.17545414}), - torch::tensor({0.34337813, -0.066628546, -0.01773429, 0.3062569, -0.121003985, 0.39204246, -0.29776025, -0.04839292, -0.024020255, 0.1995511, 0.30574924, -0.07088503, -0.37050778, 0.22159088, 0.13377577}), - }, - { - torch::tensor({-0.095182985, -0.35154456, -0.33943355, 0.14092201, 0.5576401, -0.4759267, 0.35954505, -0.30801514, -0.11571318, 0.47157025, -0.1343652, 0.05409521, -0.41528726, 0.5057125, -0.076797724}), - torch::tensor({-0.4345107, 0.17395526, -0.42240727, -0.4714136, 0.036191404, 0.47101927, -0.16811755, 0.2796178, 0.51051295, 0.39403576, -0.3116357, -0.065123916, -0.18695068,-0.47772023, 0.00024545193}), - }, + { + torch::tensor( + {0.20493895, + 0.03128183, + -0.25481248, + 0.2471149, + -0.5009514, + -0.30953097, + -0.13073051}), + torch::tensor( + {-0.25438663, + -0.18269452, + -0.42803344, + -0.111086845, + 0.11163068, + -0.34021688, + -0.026800662}), + torch::tensor( + {0.37384087, + -0.053685665, + 0.014514029, + -0.04505247, + 0.10569024, + 0.33205634, + 0.4946832}), + torch::tensor( + {0.3316471, + 0.49581498, + -0.03776875, + -0.46913308, + -0.24757874, + 0.35559112, + -0.003385365}), + torch::tensor( + {-0.259574, + -0.40019324, + -0.48873278, + -0.4407689, + -0.10592803, + 0.28639567, + 0.2823406}), + torch::tensor( + {-0.5036581, + 0.32575953, + -0.40865222, + -0.110405415, + -0.21175116, + -0.10059005, + -0.10253662}), + torch::tensor( + {-0.46862572, + -0.45091572, + -0.08171874, + 0.0067536235, + -0.2372373, + 0.19672471, + -0.47004014}), + torch::tensor( + {-0.035244048, + 0.45926183, + -0.21301463, + 0.47157794, + 0.18912864, + -0.47129482, + 0.33041543}), + torch::tensor( + {-0.0602628, + -0.23312837, + 0.41760528, + -0.42201394, + 0.05603814, + -0.10933924, + 0.37293315}), + torch::tensor( + {0.14577848, + 0.25093573, + 0.18443125, + -0.1255838, + -0.10982844, + -0.43036246, + 0.28296888}), + torch::tensor( + {0.4146415, + 0.35732478, + -0.36837178, + 0.023291528, + -0.3681398, + -0.28748098, + -0.304308}), + torch::tensor( + {0.17847049, + -0.3112055, + -0.011393666, + 0.021969378, + 0.3366434, + -0.39476636, + -0.35851932}), + torch::tensor( + {-0.3032406, + 0.3655283, + -0.18772447, + 0.44049758, + 0.18884337, + 0.066128254, + -0.0038875341}), + torch::tensor( + {-0.10323581, + 0.06552267, + -0.119249105, + -0.0036694407, + 0.066633284, + -0.40849328, + -0.2737187}), + torch::tensor( + {0.4216993, + -0.4238164, + -0.037499547, + 0.51661307, + 0.1886499, + 0.014786005, + -0.45257226}), + }, + { + torch::tensor( + {0.2400673, + 0.06230098, + -0.29922318, + -0.3467335, + -0.13797283, + 0.19630808, + 0.44112986, + 0.25716078, + -0.05036038, + 0.15680045, + -0.43874466, + -0.3819657, + 0.20867342, + -0.25330856, + 0.2151525}), + torch::tensor( + {-0.31570244, + -0.22150977, + -0.3683649, + 0.23337424, + -0.045568883, + 0.34417605, + 0.27676803, + 0.24746233, + 0.014380664, + -0.13826859, + -0.09723839, + 0.05943495, + 0.22168803, + -0.3133133, + 0.37533647}), + torch::tensor( + {-0.04862556, + -0.37474066, + -0.24196841, + 0.39570248, + 0.408989, + -0.41424486, + 0.31541896, + 0.2241252, + 0.264714, + 0.37857938, + -0.24102591, + 0.1412192, + 0.18301463, + -0.13214865, + 0.14966142}), + torch::tensor( + {-0.12866935, + 0.27649486, + -0.12408146, + -0.16671932, + 0.112585604, + 0.15862381, + -0.21849588, + 0.03953293, + 0.25917625, + -0.044496298, + 0.13610226, + -0.107862085, + 0.15674818, + -0.32395893, + -0.26297447}), + torch::tensor( + {-0.22700138, + 0.41099417, + -0.12033805, + -0.0012210608, + -0.21667299, + 0.44644886, + 0.43678015, + -0.3372825, + -0.3625426, + -0.33898476, + -0.002156794, + -0.113996506, + -0.29272172, + -0.16040304, + 0.084492445}), + torch::tensor( + {-0.23366496, + 0.09909475, + -0.10255319, + -0.21670331, + 0.061440647, + 0.36772507, + -0.30235183, + 0.02076611, + -0.16491795, + 0.4388535, + -0.4242998, + -0.42872632, + 0.44067752, + -0.28294754, + 0.08574128}), + torch::tensor( + {-0.038597256, + -0.09420869, + -0.09988415, + 0.28417766, + 0.021375448, + -0.43541366, + -0.2640171, + -0.15245524, + 0.2250452, + -0.28940699, + 0.42168647, + -0.09960759, + -0.08030257, + 0.3504178, + 0.22477299}), + torch::tensor( + {0.37929094, + 0.25868815, + -0.13566399, + -0.2967139, + -0.033274055, + 0.37013078, + -0.15009373, + -0.41473246, + 0.18332559, + 0.43534964, + -0.12731421, + -0.3703034, + -0.40564942, + 0.112071455, + -0.03386289}), + torch::tensor( + {-0.22583716, + 0.090396106, + 0.1698333, + 0.35567743, + 0.34720868, + -0.066940606, + -0.39433825, + -0.40411252, + 0.41755867, + 0.19769311, + 0.19494373, + -0.3869386, + 0.4141268, + 0.4236647, + 0.4037714}), + torch::tensor( + {-0.37726268, + -0.16874415, + -0.3075773, + 0.4234959, + -0.19215873, + -0.2041774, + 0.23430097, + -0.20687759, + -0.22026259, + -0.03911844, + -0.042985946, + -0.34836975, + 0.3728277, + -0.19727562, + 0.15863329}), + torch::tensor( + {0.38897908, + 0.22553718, + 0.063316464, + 0.3805148, + 0.060117245, + -0.20690633, + 0.4230638, + 0.10584676, + -0.43633774, + -0.12731794, + -0.30462736, + 0.39209586, + -0.07385549, + -0.40764633, + -0.028113335}), + torch::tensor( + {0.2808559, + 0.11618626, + 0.14141095, + 0.041534156, + 0.16672957, + -0.10896447, + -0.17790166, + -0.41801435, + -0.3369025, + 0.19382352, + -0.26480114, + 0.06416017, + 0.14274675, + 0.03166446, + -0.28995082}), + torch::tensor( + {0.42768306, + -0.26005447, + 0.36783344, + -0.35576212, + -0.10757655, + 0.24327022, + -0.18272284, + 0.3756786, + -0.30775294, + -0.37555724, + -0.20165718, + 0.07229227, + 0.41177452, + -0.21350017, + 0.15993619}), + torch::tensor( + {-0.112119585, + -0.09698379, + 0.3288377, + -0.34658423, + 0.047500044, + 0.42056376, + -0.061452597, + 0.34723365, + -0.13772246, + 0.35999, + -0.43260327, + -0.06445408, + -0.07855147, + 0.14493519, + 0.17545414}), + torch::tensor( + {0.34337813, + -0.066628546, + -0.01773429, + 0.3062569, + -0.121003985, + 0.39204246, + -0.29776025, + -0.04839292, + -0.024020255, + 0.1995511, + 0.30574924, + -0.07088503, + -0.37050778, + 0.22159088, + 0.13377577}), + }, + { + torch::tensor( + {-0.095182985, + -0.35154456, + -0.33943355, + 0.14092201, + 0.5576401, + -0.4759267, + 0.35954505, + -0.30801514, + -0.11571318, + 0.47157025, + -0.1343652, + 0.05409521, + -0.41528726, + 0.5057125, + -0.076797724}), + torch::tensor( + {-0.4345107, + 0.17395526, + -0.42240727, + -0.4714136, + 0.036191404, + 0.47101927, + -0.16811755, + 0.2796178, + 0.51051295, + 0.39403576, + -0.3116357, + -0.065123916, + -0.18695068, + -0.47772023, + 0.00024545193}), + }, }; } inline std::vector> Xavier_Normal() { return { - { - torch::tensor({-0.21151732, 0.31257284, -0.18201339, -0.3855622, 0.028025549, -0.20083663, 0.18333313}), - torch::tensor({-0.22010927, 0.41458952, 0.19888626, 0.14368738, -0.30642822, 0.05438269, 0.032663286}), - torch::tensor({-0.22758777, 0.07366481, 0.34382868, -0.027100464, 0.22004184, -0.5563846, -0.0075437957}), - torch::tensor({0.4128839, 0.80112267, 0.29702467, 0.11372463, 0.3320348, -0.34456047, 0.011332423}), - torch::tensor({0.81295794, 0.37259677, 0.16366935, 0.15845336, -0.2500635, -0.42430383, 0.49051273}), - torch::tensor({0.051942326, -0.48588628, -0.14455895, -0.04322197, -0.09566793, 0.17296337, 0.3008876}), - torch::tensor({0.16390441, 0.023760417, 0.26016212, -0.005876346, 0.29881194, -0.23449583, -0.090267815}), - torch::tensor({-0.05661722, 0.57766485, 0.20810172, -0.70001936, -0.36073124, 0.059406225, -0.35497883}), - torch::tensor({0.034237247, 0.33308733, -0.4206602, 0.14325368, -0.24534757, 0.27866775, -0.07457621}), - torch::tensor({-0.42675552, 0.29772332, -0.44859084, 0.17689292, 0.04772785, 0.03324038, -0.24688111}), - torch::tensor({0.19078615, -0.57796097, 0.39555034, -0.063268825, 0.23570086, 0.29840353, 0.12504078}), - torch::tensor({-0.45496583, 0.61388826, 0.039676014, -0.15409455, -0.5167084, -0.15379032, -0.14318566}), - torch::tensor({-0.19097842, -0.44253433, -0.26487318, -0.6266602, -0.33180708, 0.4737542, 0.05777406}), - torch::tensor({0.11455678, -0.04364589, 0.19224901, -0.084812194, -0.40097243, -0.197128, -0.161051}), - torch::tensor({-0.15780471, 0.25975332, -0.26742947, 0.2529001, 0.34761095, -0.5309911, -0.3337848}), - }, - { - torch::tensor({-0.11110223, -0.08603837, -0.39927804, -0.0037998888, 0.3163225, 0.41147578, -0.4212508, 0.27163807, 0.16257726, 0.07001552, -0.17712006, -0.28190005, 0.43368816, -0.22742769, 0.14976384}), - torch::tensor({0.15162551, 0.022756116, -0.337443, -0.1823836, -0.04240044, -0.25083202, -0.26616275, 0.16712677, -0.3765372, 0.18503848, -0.51644665, -0.622171, 0.05665474, -0.4386439, 0.33809182}), - torch::tensor({-0.42894447, 0.44781545, -0.16271128, -0.16386595, 0.25165728, 0.054181106, 0.0077175284, 0.44132262, -0.18739006, -0.37522528, 0.15394184, -0.32285675, 0.2957909, 0.19089724, 0.32357433}), - torch::tensor({-0.11477537, 0.21132338, 0.0032186622, 0.09303321, -0.41689086, -0.6386192, 0.009335384, -0.088360295, -0.09855254, -0.014685968, -0.22642778, 0.17631726, 0.87642914, -0.4308766, 0.13190313}), - torch::tensor({-0.07384103, 0.086509995, 0.30259287, -0.10428588, 0.23001948, -0.12647822, -0.30277884, -0.17739438, -0.60285777, 0.024281243, -0.052177392, 0.807063, 0.5194553, -0.08695924, 0.084167756}), - torch::tensor({0.13818856, 0.50950116, -0.053579096, -0.007894313, 0.067052916, 0.0014321166, 0.20509668, 0.105126604, -0.093184546, 0.33830634, -0.24917552, 0.22737288, -0.2688545, -0.17480324, -0.106032155}), - torch::tensor({-0.4179092, 0.1311413, 0.5997842, 0.059327736, -0.13675568, 0.47349188, 0.0011002938, -0.32478496, -0.28000802, 0.19441846, 0.08356549, -0.07100728, 0.33710754,0.27447432, 0.070220366}), - torch::tensor({-0.23930988, -0.7056576, -0.14566903, -0.0707464, 0.03609119, 0.13131015, 0.085740715, -0.25335756, 0.22949918, 0.40511283, -0.021134255, -0.090214714, 0.052266303, -0.07446028, -0.0024477358}), - torch::tensor({0.6245137, 0.34285438, -0.068128414, 0.09410469, 0.6568622, -0.6944174, 0.63067895, 0.10417306, -0.25729227, 0.2526477, -0.1139792, -0.067401096, 0.20603673, -0.28586346, 0.601752}), - torch::tensor({-0.26997554, -0.12263605, -0.12787469, -0.05121646, 0.57187945, -0.03529285, -0.26288798, 0.046064075, 0.33631605, -0.1457349, -0.23712163, -0.19353649, -0.024511583, 0.28424242, 0.3383595}), - torch::tensor({-0.07561234, 0.10794748, -0.0437702, -0.56154686, 0.18596707, 0.07370188, 0.059135318, 0.33133864, -0.35610923, 0.36255547, -0.24472283, 0.052193344, -0.09056281, 0.14071976, 0.3979389}), - torch::tensor({0.15498109, -0.08727514, -0.1047872, 0.23060486, -0.37544468, 0.30660948, -0.07733662, 0.59290683, 0.08534651, 0.56152266, -0.058316797, 0.64462274, -0.1825164,0.29704455, -0.14194927}), - torch::tensor({0.21056584, 0.5165385, -0.11343903, -0.2790672, -0.12210965, 0.08408014, -0.25138792, 0.20601495, -0.14802553, 0.053201754, 0.055186458, 0.17179312, 0.18927598,-0.37573355, -0.5282227}), - torch::tensor({-0.045720562, 0.16112846, -0.43607467, 0.21109799, -0.008630521, 0.32987764, 0.13289018, 0.0899868, -0.2679775, -0.24753189, 0.23707163, 0.12620693, -0.6093915,0.25720116, 0.33193505}), - torch::tensor({-0.4173836, -0.19854523, -0.3151219, 0.14798103, 0.18053184, 0.07821397, 0.051448464, -0.024434408, 0.41581196, -0.031413753, -0.35996363, -0.23361506, -0.08951727, -0.052789558, 0.054483347}), - }, - { - torch::tensor({0.41098386, 0.45162097, -0.11947367, 0.4141703, -0.2222035, 0.1805965, 0.3349034, -0.3735036, 0.13126858, -0.14054178, -0.25333884, 0.56775486, 0.17848605, -0.079786636, 0.24059108}), - torch::tensor({-0.19057135, -0.13091972, -0.015175606, 0.17821532, 0.26070553, -0.35303566, -0.42888534, -0.038996607, -0.21060736, 0.2925792, 0.053198263, 0.3793282, 0.3753942, 0.19560203, 0.2662913}), - }, + { + torch::tensor( + {-0.21151732, + 0.31257284, + -0.18201339, + -0.3855622, + 0.028025549, + -0.20083663, + 0.18333313}), + torch::tensor( + {-0.22010927, + 0.41458952, + 0.19888626, + 0.14368738, + -0.30642822, + 0.05438269, + 0.032663286}), + torch::tensor( + {-0.22758777, + 0.07366481, + 0.34382868, + -0.027100464, + 0.22004184, + -0.5563846, + -0.0075437957}), + torch::tensor( + {0.4128839, + 0.80112267, + 0.29702467, + 0.11372463, + 0.3320348, + -0.34456047, + 0.011332423}), + torch::tensor( + {0.81295794, + 0.37259677, + 0.16366935, + 0.15845336, + -0.2500635, + -0.42430383, + 0.49051273}), + torch::tensor( + {0.051942326, + -0.48588628, + -0.14455895, + -0.04322197, + -0.09566793, + 0.17296337, + 0.3008876}), + torch::tensor( + {0.16390441, + 0.023760417, + 0.26016212, + -0.005876346, + 0.29881194, + -0.23449583, + -0.090267815}), + torch::tensor( + {-0.05661722, + 0.57766485, + 0.20810172, + -0.70001936, + -0.36073124, + 0.059406225, + -0.35497883}), + torch::tensor( + {0.034237247, + 0.33308733, + -0.4206602, + 0.14325368, + -0.24534757, + 0.27866775, + -0.07457621}), + torch::tensor( + {-0.42675552, + 0.29772332, + -0.44859084, + 0.17689292, + 0.04772785, + 0.03324038, + -0.24688111}), + torch::tensor( + {0.19078615, + -0.57796097, + 0.39555034, + -0.063268825, + 0.23570086, + 0.29840353, + 0.12504078}), + torch::tensor( + {-0.45496583, + 0.61388826, + 0.039676014, + -0.15409455, + -0.5167084, + -0.15379032, + -0.14318566}), + torch::tensor( + {-0.19097842, + -0.44253433, + -0.26487318, + -0.6266602, + -0.33180708, + 0.4737542, + 0.05777406}), + torch::tensor( + {0.11455678, + -0.04364589, + 0.19224901, + -0.084812194, + -0.40097243, + -0.197128, + -0.161051}), + torch::tensor( + {-0.15780471, + 0.25975332, + -0.26742947, + 0.2529001, + 0.34761095, + -0.5309911, + -0.3337848}), + }, + { + torch::tensor( + {-0.11110223, + -0.08603837, + -0.39927804, + -0.0037998888, + 0.3163225, + 0.41147578, + -0.4212508, + 0.27163807, + 0.16257726, + 0.07001552, + -0.17712006, + -0.28190005, + 0.43368816, + -0.22742769, + 0.14976384}), + torch::tensor( + {0.15162551, + 0.022756116, + -0.337443, + -0.1823836, + -0.04240044, + -0.25083202, + -0.26616275, + 0.16712677, + -0.3765372, + 0.18503848, + -0.51644665, + -0.622171, + 0.05665474, + -0.4386439, + 0.33809182}), + torch::tensor( + {-0.42894447, + 0.44781545, + -0.16271128, + -0.16386595, + 0.25165728, + 0.054181106, + 0.0077175284, + 0.44132262, + -0.18739006, + -0.37522528, + 0.15394184, + -0.32285675, + 0.2957909, + 0.19089724, + 0.32357433}), + torch::tensor( + {-0.11477537, + 0.21132338, + 0.0032186622, + 0.09303321, + -0.41689086, + -0.6386192, + 0.009335384, + -0.088360295, + -0.09855254, + -0.014685968, + -0.22642778, + 0.17631726, + 0.87642914, + -0.4308766, + 0.13190313}), + torch::tensor( + {-0.07384103, + 0.086509995, + 0.30259287, + -0.10428588, + 0.23001948, + -0.12647822, + -0.30277884, + -0.17739438, + -0.60285777, + 0.024281243, + -0.052177392, + 0.807063, + 0.5194553, + -0.08695924, + 0.084167756}), + torch::tensor( + {0.13818856, + 0.50950116, + -0.053579096, + -0.007894313, + 0.067052916, + 0.0014321166, + 0.20509668, + 0.105126604, + -0.093184546, + 0.33830634, + -0.24917552, + 0.22737288, + -0.2688545, + -0.17480324, + -0.106032155}), + torch::tensor( + {-0.4179092, + 0.1311413, + 0.5997842, + 0.059327736, + -0.13675568, + 0.47349188, + 0.0011002938, + -0.32478496, + -0.28000802, + 0.19441846, + 0.08356549, + -0.07100728, + 0.33710754, + 0.27447432, + 0.070220366}), + torch::tensor( + {-0.23930988, + -0.7056576, + -0.14566903, + -0.0707464, + 0.03609119, + 0.13131015, + 0.085740715, + -0.25335756, + 0.22949918, + 0.40511283, + -0.021134255, + -0.090214714, + 0.052266303, + -0.07446028, + -0.0024477358}), + torch::tensor( + {0.6245137, + 0.34285438, + -0.068128414, + 0.09410469, + 0.6568622, + -0.6944174, + 0.63067895, + 0.10417306, + -0.25729227, + 0.2526477, + -0.1139792, + -0.067401096, + 0.20603673, + -0.28586346, + 0.601752}), + torch::tensor( + {-0.26997554, + -0.12263605, + -0.12787469, + -0.05121646, + 0.57187945, + -0.03529285, + -0.26288798, + 0.046064075, + 0.33631605, + -0.1457349, + -0.23712163, + -0.19353649, + -0.024511583, + 0.28424242, + 0.3383595}), + torch::tensor( + {-0.07561234, + 0.10794748, + -0.0437702, + -0.56154686, + 0.18596707, + 0.07370188, + 0.059135318, + 0.33133864, + -0.35610923, + 0.36255547, + -0.24472283, + 0.052193344, + -0.09056281, + 0.14071976, + 0.3979389}), + torch::tensor( + {0.15498109, + -0.08727514, + -0.1047872, + 0.23060486, + -0.37544468, + 0.30660948, + -0.07733662, + 0.59290683, + 0.08534651, + 0.56152266, + -0.058316797, + 0.64462274, + -0.1825164, + 0.29704455, + -0.14194927}), + torch::tensor( + {0.21056584, + 0.5165385, + -0.11343903, + -0.2790672, + -0.12210965, + 0.08408014, + -0.25138792, + 0.20601495, + -0.14802553, + 0.053201754, + 0.055186458, + 0.17179312, + 0.18927598, + -0.37573355, + -0.5282227}), + torch::tensor( + {-0.045720562, + 0.16112846, + -0.43607467, + 0.21109799, + -0.008630521, + 0.32987764, + 0.13289018, + 0.0899868, + -0.2679775, + -0.24753189, + 0.23707163, + 0.12620693, + -0.6093915, + 0.25720116, + 0.33193505}), + torch::tensor( + {-0.4173836, + -0.19854523, + -0.3151219, + 0.14798103, + 0.18053184, + 0.07821397, + 0.051448464, + -0.024434408, + 0.41581196, + -0.031413753, + -0.35996363, + -0.23361506, + -0.08951727, + -0.052789558, + 0.054483347}), + }, + { + torch::tensor( + {0.41098386, + 0.45162097, + -0.11947367, + 0.4141703, + -0.2222035, + 0.1805965, + 0.3349034, + -0.3735036, + 0.13126858, + -0.14054178, + -0.25333884, + 0.56775486, + 0.17848605, + -0.079786636, + 0.24059108}), + torch::tensor( + {-0.19057135, + -0.13091972, + -0.015175606, + 0.17821532, + 0.26070553, + -0.35303566, + -0.42888534, + -0.038996607, + -0.21060736, + 0.2925792, + 0.053198263, + 0.3793282, + 0.3753942, + 0.19560203, + 0.2662913}), + }, }; } inline std::vector> Kaiming_Normal() { return { - { - torch::tensor({-0.37498012, 0.5541324, -0.32267526, -0.6835287, 0.049683988, -0.35604528, 0.3250149}), - torch::tensor({-0.39021203, 0.7349887, 0.35258764, 0.2547305, -0.5432392, 0.0964102, 0.057905816}), - torch::tensor({-0.40347, 0.13059375, 0.6095431, -0.048043985, 0.3900925, -0.9863645, -0.01337372}), - torch::tensor({0.7319649, 1.4202386, 0.5265685, 0.2016122, 0.5886347, -0.61084044, 0.02009024}), - torch::tensor({1.4412204, 0.66054344, 0.29015473, 0.28090778, -0.44331524, -0.75221026, 0.8695861}), - torch::tensor({0.0920839, -0.8613843, -0.25627562, -0.076624356, -0.16960111, 0.30663127, 0.5334167}), - torch::tensor({0.29057145, 0.042122718, 0.46121812, -0.010417648, 0.52973694, -0.41571668, -0.16002773}), - torch::tensor({-0.1003716, 1.0240903, 0.36892492, -1.2410016, -0.6395081, 0.105315976, -0.6293102}), - torch::tensor({0.06069615, 0.5905007, -0.7457508, 0.25396162, -0.43495473, 0.4940251, -0.13220948}), - torch::tensor({-0.75655663, 0.52780706, -0.79526657, 0.31359762, 0.08461243, 0.058928896, -0.43767342}), - torch::tensor({0.3382277, -1.0246153, 0.7012358, -0.11216363, 0.41785294, 0.5290129, 0.22167361}), - torch::tensor({-0.8065682, 1.0883075, 0.070338055, -0.27318043, -0.91602606, -0.2726411, -0.25384104}), - torch::tensor({-0.33856857, -0.78452945, -0.46956995, -1.1109498, -0.5882311, 0.83987635, 0.10242246}), - torch::tensor({0.20308746, -0.07737589, 0.34082106, -0.15035595, -0.71084815, -0.3494706, -0.28551292}), - torch::tensor({-0.27975786, 0.4604934, -0.47410175, 0.4483439, 0.6162483, -0.9413465, -0.59173715}), - }, - { - torch::tensor({-0.15712228, -0.12167663, -0.5646644, -0.005373854, 0.44734758, 0.58191466, -0.59573853, 0.38415426, 0.22991896, 0.0990169, -0.2504856, -0.3986669, 0.6133277, -0.3216313, 0.21179804}), - torch::tensor({0.21443085, 0.032182008, -0.47721645, -0.25792935, -0.059963275, -0.35473004, -0.376411, 0.23635295, -0.532504, 0.26168394, -0.7303659, -0.8798827, 0.0801219, -0.6203361, 0.47813404}), - torch::tensor({-0.60661906, 0.6333067, -0.2301085, -0.23174146, 0.35589716, 0.076623656, 0.0109142335, 0.6241244, -0.26500958, -0.53064865, 0.21770664, -0.45658842, 0.4183115,0.26996946, 0.45760322}), - torch::tensor({-0.16231689, 0.29885638, 0.004551876, 0.13156882, -0.5895727, -0.9031439, 0.013202227, -0.124960326, -0.13937433, -0.020769095, -0.32021725, 0.24935026, 1.239458, -0.6093515, 0.18653919}), - torch::tensor({-0.10442699, 0.12234361, 0.42793095, -0.14748251, 0.32529667, -0.17886722, -0.42819393, -0.25087354, -0.85256964, 0.03433886, -0.07378998, 1.1413594, 0.73462075, -0.12297893, 0.11903118}), - torch::tensor({0.19542812, 0.72054344, -0.075772285, -0.011164244, 0.09482714, 0.0020253188, 0.2900505, 0.14867148, -0.13178284, 0.4784374, -0.3523874, 0.32155383, -0.3802177,-0.24720912, -0.14995211}), - torch::tensor({-0.5910129, 0.18546182, 0.8482229, 0.08390209, -0.19340174, 0.6696186, 0.0015560504, -0.4593153, -0.39599115, 0.27494922, 0.11817944, -0.10041946, 0.47674203, 0.38816532, 0.0993066}), - torch::tensor({-0.33843526, -0.9979506, -0.20600711, -0.10005052, 0.05104065, 0.1857006, 0.12125569, -0.3583017, 0.32456085, 0.57291603, -0.02988835, -0.12758288, 0.07391571, -0.105302736, -0.003461621}), - torch::tensor({0.88319576, 0.4848693, -0.09634813, 0.13308413, 0.9289434, -0.98205453, 0.8919147, 0.14732295, -0.3638662, 0.35729778, -0.16119093, -0.09531955, 0.29137993, -0.404272, 0.8510058}), - torch::tensor({-0.38180307, -0.17343357, -0.1808421, -0.07243101, 0.8087596, -0.049911626, -0.37177977, 0.06514444, 0.4756227, -0.20610029, -0.33534062, -0.27370194, -0.034664612, 0.4019795, 0.4785126}), - torch::tensor({-0.106931984, 0.15266079, -0.06190041, -0.7941472, 0.26299715, 0.1042302, 0.083629966, 0.46858358, -0.5036145, 0.5127309, -0.34609035, 0.07381254, -0.12807515, 0.19900778, 0.5627706}), - torch::tensor({0.21917637, -0.123425685, -0.14819148, 0.32612452, -0.53095895, 0.4336113, -0.10937049, 0.83849686, 0.1206982, 0.794113, -0.08247241, 0.91163427, -0.25811717, 0.42008442, -0.20074657}), - torch::tensor({0.29778504, 0.73049575, -0.16042702, -0.3946606, -0.17268912, 0.11890727, -0.35551623, 0.2913491, -0.20933971, 0.07523864, 0.078045435, 0.24295215, 0.26767665, -0.5313675, -0.74701965}), - torch::tensor({-0.06465864, 0.22787006, -0.61670274, 0.2985376, -0.0122054, 0.46651745, 0.1879351, 0.12726055, -0.37897742, -0.35006294, 0.33526993, 0.17848356, -0.86180973, 0.3637374, 0.46942705}), - torch::tensor({-0.59026957, -0.28078535, -0.44564965, 0.2092768, 0.25531057, 0.110611245, 0.072759114, -0.034555472, 0.5880469, -0.044425756, -0.50906545, -0.33038157, -0.12659654, -0.07465571, 0.07705109}), - }, - { - torch::tensor({0.43752572, 0.48078725, -0.12718944, 0.44091794, -0.23655368, 0.19225967, 0.35653186, -0.39762494, 0.13974607, -0.14961815, -0.26969978, 0.6044212, 0.19001292, -0.08493936, 0.25612876}), - torch::tensor({-0.20287868, -0.13937469, -0.016155666, 0.1897247, 0.27754223, -0.37583515, -0.45658332, -0.041515056, -0.22420865, 0.31147435, 0.056633875, 0.4038257, 0.39963764, 0.20823427, 0.28348872}), - }, + { + torch::tensor( + {-0.37498012, + 0.5541324, + -0.32267526, + -0.6835287, + 0.049683988, + -0.35604528, + 0.3250149}), + torch::tensor( + {-0.39021203, + 0.7349887, + 0.35258764, + 0.2547305, + -0.5432392, + 0.0964102, + 0.057905816}), + torch::tensor( + {-0.40347, + 0.13059375, + 0.6095431, + -0.048043985, + 0.3900925, + -0.9863645, + -0.01337372}), + torch::tensor( + {0.7319649, + 1.4202386, + 0.5265685, + 0.2016122, + 0.5886347, + -0.61084044, + 0.02009024}), + torch::tensor( + {1.4412204, + 0.66054344, + 0.29015473, + 0.28090778, + -0.44331524, + -0.75221026, + 0.8695861}), + torch::tensor( + {0.0920839, + -0.8613843, + -0.25627562, + -0.076624356, + -0.16960111, + 0.30663127, + 0.5334167}), + torch::tensor( + {0.29057145, + 0.042122718, + 0.46121812, + -0.010417648, + 0.52973694, + -0.41571668, + -0.16002773}), + torch::tensor( + {-0.1003716, + 1.0240903, + 0.36892492, + -1.2410016, + -0.6395081, + 0.105315976, + -0.6293102}), + torch::tensor( + {0.06069615, + 0.5905007, + -0.7457508, + 0.25396162, + -0.43495473, + 0.4940251, + -0.13220948}), + torch::tensor( + {-0.75655663, + 0.52780706, + -0.79526657, + 0.31359762, + 0.08461243, + 0.058928896, + -0.43767342}), + torch::tensor( + {0.3382277, + -1.0246153, + 0.7012358, + -0.11216363, + 0.41785294, + 0.5290129, + 0.22167361}), + torch::tensor( + {-0.8065682, + 1.0883075, + 0.070338055, + -0.27318043, + -0.91602606, + -0.2726411, + -0.25384104}), + torch::tensor( + {-0.33856857, + -0.78452945, + -0.46956995, + -1.1109498, + -0.5882311, + 0.83987635, + 0.10242246}), + torch::tensor( + {0.20308746, + -0.07737589, + 0.34082106, + -0.15035595, + -0.71084815, + -0.3494706, + -0.28551292}), + torch::tensor( + {-0.27975786, + 0.4604934, + -0.47410175, + 0.4483439, + 0.6162483, + -0.9413465, + -0.59173715}), + }, + { + torch::tensor( + {-0.15712228, + -0.12167663, + -0.5646644, + -0.005373854, + 0.44734758, + 0.58191466, + -0.59573853, + 0.38415426, + 0.22991896, + 0.0990169, + -0.2504856, + -0.3986669, + 0.6133277, + -0.3216313, + 0.21179804}), + torch::tensor( + {0.21443085, + 0.032182008, + -0.47721645, + -0.25792935, + -0.059963275, + -0.35473004, + -0.376411, + 0.23635295, + -0.532504, + 0.26168394, + -0.7303659, + -0.8798827, + 0.0801219, + -0.6203361, + 0.47813404}), + torch::tensor( + {-0.60661906, + 0.6333067, + -0.2301085, + -0.23174146, + 0.35589716, + 0.076623656, + 0.0109142335, + 0.6241244, + -0.26500958, + -0.53064865, + 0.21770664, + -0.45658842, + 0.4183115, + 0.26996946, + 0.45760322}), + torch::tensor( + {-0.16231689, + 0.29885638, + 0.004551876, + 0.13156882, + -0.5895727, + -0.9031439, + 0.013202227, + -0.124960326, + -0.13937433, + -0.020769095, + -0.32021725, + 0.24935026, + 1.239458, + -0.6093515, + 0.18653919}), + torch::tensor( + {-0.10442699, + 0.12234361, + 0.42793095, + -0.14748251, + 0.32529667, + -0.17886722, + -0.42819393, + -0.25087354, + -0.85256964, + 0.03433886, + -0.07378998, + 1.1413594, + 0.73462075, + -0.12297893, + 0.11903118}), + torch::tensor( + {0.19542812, + 0.72054344, + -0.075772285, + -0.011164244, + 0.09482714, + 0.0020253188, + 0.2900505, + 0.14867148, + -0.13178284, + 0.4784374, + -0.3523874, + 0.32155383, + -0.3802177, + -0.24720912, + -0.14995211}), + torch::tensor( + {-0.5910129, + 0.18546182, + 0.8482229, + 0.08390209, + -0.19340174, + 0.6696186, + 0.0015560504, + -0.4593153, + -0.39599115, + 0.27494922, + 0.11817944, + -0.10041946, + 0.47674203, + 0.38816532, + 0.0993066}), + torch::tensor( + {-0.33843526, + -0.9979506, + -0.20600711, + -0.10005052, + 0.05104065, + 0.1857006, + 0.12125569, + -0.3583017, + 0.32456085, + 0.57291603, + -0.02988835, + -0.12758288, + 0.07391571, + -0.105302736, + -0.003461621}), + torch::tensor( + {0.88319576, + 0.4848693, + -0.09634813, + 0.13308413, + 0.9289434, + -0.98205453, + 0.8919147, + 0.14732295, + -0.3638662, + 0.35729778, + -0.16119093, + -0.09531955, + 0.29137993, + -0.404272, + 0.8510058}), + torch::tensor( + {-0.38180307, + -0.17343357, + -0.1808421, + -0.07243101, + 0.8087596, + -0.049911626, + -0.37177977, + 0.06514444, + 0.4756227, + -0.20610029, + -0.33534062, + -0.27370194, + -0.034664612, + 0.4019795, + 0.4785126}), + torch::tensor( + {-0.106931984, + 0.15266079, + -0.06190041, + -0.7941472, + 0.26299715, + 0.1042302, + 0.083629966, + 0.46858358, + -0.5036145, + 0.5127309, + -0.34609035, + 0.07381254, + -0.12807515, + 0.19900778, + 0.5627706}), + torch::tensor( + {0.21917637, + -0.123425685, + -0.14819148, + 0.32612452, + -0.53095895, + 0.4336113, + -0.10937049, + 0.83849686, + 0.1206982, + 0.794113, + -0.08247241, + 0.91163427, + -0.25811717, + 0.42008442, + -0.20074657}), + torch::tensor( + {0.29778504, + 0.73049575, + -0.16042702, + -0.3946606, + -0.17268912, + 0.11890727, + -0.35551623, + 0.2913491, + -0.20933971, + 0.07523864, + 0.078045435, + 0.24295215, + 0.26767665, + -0.5313675, + -0.74701965}), + torch::tensor( + {-0.06465864, + 0.22787006, + -0.61670274, + 0.2985376, + -0.0122054, + 0.46651745, + 0.1879351, + 0.12726055, + -0.37897742, + -0.35006294, + 0.33526993, + 0.17848356, + -0.86180973, + 0.3637374, + 0.46942705}), + torch::tensor( + {-0.59026957, + -0.28078535, + -0.44564965, + 0.2092768, + 0.25531057, + 0.110611245, + 0.072759114, + -0.034555472, + 0.5880469, + -0.044425756, + -0.50906545, + -0.33038157, + -0.12659654, + -0.07465571, + 0.07705109}), + }, + { + torch::tensor( + {0.43752572, + 0.48078725, + -0.12718944, + 0.44091794, + -0.23655368, + 0.19225967, + 0.35653186, + -0.39762494, + 0.13974607, + -0.14961815, + -0.26969978, + 0.6044212, + 0.19001292, + -0.08493936, + 0.25612876}), + torch::tensor( + {-0.20287868, + -0.13937469, + -0.016155666, + 0.1897247, + 0.27754223, + -0.37583515, + -0.45658332, + -0.041515056, + -0.22420865, + 0.31147435, + 0.056633875, + 0.4038257, + 0.39963764, + 0.20823427, + 0.28348872}), + }, }; } inline std::vector> Kaiming_Uniform() { return { - { - torch::tensor({0.36331797, 0.055456758, -0.45173424, 0.43808794, -0.8880919, -0.5487398, -0.23176038}), - torch::tensor({-0.45097935, -0.32388276, -0.7588222, -0.19693595, 0.19790006, -0.6031401, -0.04751247}), - torch::tensor({0.66274905, -0.09517455, 0.02573061, -0.07986951, 0.18736875, 0.588673, 0.8769796}), - torch::tensor({0.5879475, 0.8789861, -0.06695682, -0.8316841, -0.43891022, 0.63039577, -0.0060015917}), - torch::tensor({-0.4601755, -0.7094668, -0.86643064, -0.7813998, -0.18779033, 0.50772536, 0.5005363}), - torch::tensor({-0.89289045, 0.57751, -0.724463, -0.19572788, -0.3753947, -0.17832708, -0.18177801}), - torch::tensor({-0.8307846, -0.7993882, -0.14487183, 0.011972845, -0.4205768, 0.34875572, -0.8332921}), - torch::tensor({-0.062481046, 0.8141842, -0.37763458, 0.83601844, 0.33528924, -0.83551645, 0.58576393}), - torch::tensor({-0.10683453, -0.4132924, 0.7403351, -0.7481508, 0.09934509, -0.19383776, 0.66113985}), - torch::tensor({0.25843763, 0.44486153, 0.32696164, -0.22263634, -0.19470501, -0.76295114, 0.5016502}), - torch::tensor({0.73508084, 0.6334691, -0.6530534, 0.041291475, -0.6526422, -0.50964934, -0.53948045}), - torch::tensor({0.31639445, -0.5517084, -0.020198822, 0.038947523, 0.596805, -0.69984597, -0.63558686}), - torch::tensor({-0.5375881, 0.6480124, -0.3327999, 0.78091884, 0.33478355, 0.11723292, -0.0068918467}), - torch::tensor({-0.18301755, 0.11615932, -0.21140611, -0.0065051913, 0.11812818, -0.72418123, -0.4852514}), - torch::tensor({0.74759305, -0.75134623, -0.06647962, 0.9158571, 0.33444047, 0.026212811, -0.8023249}), - }, - { - torch::tensor({0.33950645, 0.08810693, -0.42316547, -0.49035522, -0.19512305, 0.2776215, 0.62385184, 0.3636803, -0.07122034, 0.2217493, -0.62047863, -0.5401811, 0.29510874, -0.3582324, 0.30427164}), - torch::tensor({-0.44647068, -0.3132621, -0.5209466, 0.33004105, -0.064444125, 0.4867385, 0.3914091, 0.34996456, 0.020337284, -0.19554132, -0.13751587, 0.084053695, 0.31351423,-0.44309196, 0.5308059}), - torch::tensor({-0.06876695, -0.5299633, -0.342195, 0.5596078, 0.5783978, -0.5858307, 0.44606978, 0.31696087, 0.37436205, 0.5353921, -0.34086213, 0.19971412, 0.2588218, -0.1868864, 0.21165323}), - torch::tensor({-0.18196595, 0.39102274, -0.17547771, -0.2357767, 0.1592201, 0.22432798, -0.30899984, 0.055908024, 0.3665306, -0.062927246, 0.1924777, -0.15254003, 0.22167546, -0.4581471, -0.37190205}), - torch::tensor({-0.32102844, 0.58123356, -0.17018372, -0.0017268658, -0.30642188, 0.63137406, 0.6177004, -0.4769895, -0.51271266, -0.47939685, -0.0030501485, -0.1612154, -0.413971, -0.22684419, 0.119490385}), - torch::tensor({-0.33045214, 0.14014107, -0.14503211, -0.30646476, 0.08689022, 0.52004176, -0.42759007, 0.029367685, -0.23322919, 0.6206326, -0.60005057, -0.60631055, 0.62321216, -0.40014827, 0.12125647}), - torch::tensor({-0.05458474, -0.1332312, -0.14125755, 0.40188795, 0.03022945, -0.6157679, -0.37337655, -0.21560428, 0.31826198, -0.40928328, 0.59635466, -0.1408664, -0.11356497, 0.4955656, 0.317877}), - torch::tensor({0.53639835, 0.36584032, -0.19185784, -0.4196168, -0.047056615, 0.523444, -0.2122646, -0.58652025, 0.2592615, 0.6156774, -0.18004948, -0.5236881, -0.5736749, 0.15849298, -0.04788935}), - torch::tensor({-0.31938198, 0.12783945, 0.24018055, 0.5030039, 0.49102718, -0.09466827, -0.5576785, -0.57150143, 0.5905171, 0.2795803, 0.27569205, -0.5472138, 0.58566374, 0.5991524, 0.571019}), - torch::tensor({-0.53353, -0.23864025, -0.43498003, 0.5989136, -0.2717535, -0.28875044, 0.33135164, -0.2925691, -0.31149834, -0.055321813, -0.060791314, -0.49266922, 0.527258, -0.27898985, 0.22434139}), - torch::tensor({0.55009943, 0.31895775, 0.089542985, 0.53812927, 0.085018635, -0.29260972, 0.59830266, 0.14968991, -0.6170747, -0.18005475, -0.43080813, 0.5545073, -0.104447424, -0.576499, -0.039758265}), - torch::tensor({0.39719027, 0.16431224, 0.19998527, 0.058738172, 0.23579127, -0.15409905, -0.25159094, -0.59116155, -0.4764521, 0.2741078, -0.37448537, 0.09073615, 0.20187438, 0.044780314, -0.4100524}), - torch::tensor({0.6048352, -0.36777255, 0.52019507, -0.5031236, -0.15213624, 0.34403604, -0.25840908, 0.53128976, -0.43522838, -0.53111815, -0.28518632, 0.10223669, 0.5823371, -0.30193484, 0.22618395}), - torch::tensor({-0.15856105, -0.13715577, 0.4650467, -0.49014413, 0.06717521, 0.59476703, -0.08690709, 0.49106258, -0.194769, 0.50910276, -0.6117934, -0.09115183, -0.111088574,0.20496935, 0.24812967}), - torch::tensor({0.48561007, -0.094227016, -0.025080085, 0.43311268, -0.17112547, 0.55443174, -0.42109656, -0.068437934, -0.03396976, 0.2822079, 0.4323948, -0.10024661, -0.52397716, 0.31337678, 0.18918753}), - }, - { - torch::tensor({-0.10132998, -0.37424773, -0.3613546, 0.15002292, 0.59365314, -0.5066626, 0.38276488, -0.32790715, -0.12318605, 0.5020248, -0.14304265, 0.057588696, -0.442107, 0.538372, -0.081757426}), - torch::tensor({-0.46257195, 0.18518949, -0.44968688, -0.5018581, 0.03852868, 0.5014383, -0.1789748, 0.29767585, 0.5434825, 0.419483, -0.33176154, -0.06932968, -0.19902417, -0.508572, 0.00026130676}), - }, + { + torch::tensor( + {0.36331797, + 0.055456758, + -0.45173424, + 0.43808794, + -0.8880919, + -0.5487398, + -0.23176038}), + torch::tensor( + {-0.45097935, + -0.32388276, + -0.7588222, + -0.19693595, + 0.19790006, + -0.6031401, + -0.04751247}), + torch::tensor( + {0.66274905, + -0.09517455, + 0.02573061, + -0.07986951, + 0.18736875, + 0.588673, + 0.8769796}), + torch::tensor( + {0.5879475, + 0.8789861, + -0.06695682, + -0.8316841, + -0.43891022, + 0.63039577, + -0.0060015917}), + torch::tensor( + {-0.4601755, + -0.7094668, + -0.86643064, + -0.7813998, + -0.18779033, + 0.50772536, + 0.5005363}), + torch::tensor( + {-0.89289045, + 0.57751, + -0.724463, + -0.19572788, + -0.3753947, + -0.17832708, + -0.18177801}), + torch::tensor( + {-0.8307846, + -0.7993882, + -0.14487183, + 0.011972845, + -0.4205768, + 0.34875572, + -0.8332921}), + torch::tensor( + {-0.062481046, + 0.8141842, + -0.37763458, + 0.83601844, + 0.33528924, + -0.83551645, + 0.58576393}), + torch::tensor( + {-0.10683453, + -0.4132924, + 0.7403351, + -0.7481508, + 0.09934509, + -0.19383776, + 0.66113985}), + torch::tensor( + {0.25843763, + 0.44486153, + 0.32696164, + -0.22263634, + -0.19470501, + -0.76295114, + 0.5016502}), + torch::tensor( + {0.73508084, + 0.6334691, + -0.6530534, + 0.041291475, + -0.6526422, + -0.50964934, + -0.53948045}), + torch::tensor( + {0.31639445, + -0.5517084, + -0.020198822, + 0.038947523, + 0.596805, + -0.69984597, + -0.63558686}), + torch::tensor( + {-0.5375881, + 0.6480124, + -0.3327999, + 0.78091884, + 0.33478355, + 0.11723292, + -0.0068918467}), + torch::tensor( + {-0.18301755, + 0.11615932, + -0.21140611, + -0.0065051913, + 0.11812818, + -0.72418123, + -0.4852514}), + torch::tensor( + {0.74759305, + -0.75134623, + -0.06647962, + 0.9158571, + 0.33444047, + 0.026212811, + -0.8023249}), + }, + { + torch::tensor( + {0.33950645, + 0.08810693, + -0.42316547, + -0.49035522, + -0.19512305, + 0.2776215, + 0.62385184, + 0.3636803, + -0.07122034, + 0.2217493, + -0.62047863, + -0.5401811, + 0.29510874, + -0.3582324, + 0.30427164}), + torch::tensor( + {-0.44647068, + -0.3132621, + -0.5209466, + 0.33004105, + -0.064444125, + 0.4867385, + 0.3914091, + 0.34996456, + 0.020337284, + -0.19554132, + -0.13751587, + 0.084053695, + 0.31351423, + -0.44309196, + 0.5308059}), + torch::tensor( + {-0.06876695, + -0.5299633, + -0.342195, + 0.5596078, + 0.5783978, + -0.5858307, + 0.44606978, + 0.31696087, + 0.37436205, + 0.5353921, + -0.34086213, + 0.19971412, + 0.2588218, + -0.1868864, + 0.21165323}), + torch::tensor( + {-0.18196595, + 0.39102274, + -0.17547771, + -0.2357767, + 0.1592201, + 0.22432798, + -0.30899984, + 0.055908024, + 0.3665306, + -0.062927246, + 0.1924777, + -0.15254003, + 0.22167546, + -0.4581471, + -0.37190205}), + torch::tensor( + {-0.32102844, + 0.58123356, + -0.17018372, + -0.0017268658, + -0.30642188, + 0.63137406, + 0.6177004, + -0.4769895, + -0.51271266, + -0.47939685, + -0.0030501485, + -0.1612154, + -0.413971, + -0.22684419, + 0.119490385}), + torch::tensor( + {-0.33045214, + 0.14014107, + -0.14503211, + -0.30646476, + 0.08689022, + 0.52004176, + -0.42759007, + 0.029367685, + -0.23322919, + 0.6206326, + -0.60005057, + -0.60631055, + 0.62321216, + -0.40014827, + 0.12125647}), + torch::tensor( + {-0.05458474, + -0.1332312, + -0.14125755, + 0.40188795, + 0.03022945, + -0.6157679, + -0.37337655, + -0.21560428, + 0.31826198, + -0.40928328, + 0.59635466, + -0.1408664, + -0.11356497, + 0.4955656, + 0.317877}), + torch::tensor( + {0.53639835, + 0.36584032, + -0.19185784, + -0.4196168, + -0.047056615, + 0.523444, + -0.2122646, + -0.58652025, + 0.2592615, + 0.6156774, + -0.18004948, + -0.5236881, + -0.5736749, + 0.15849298, + -0.04788935}), + torch::tensor( + {-0.31938198, + 0.12783945, + 0.24018055, + 0.5030039, + 0.49102718, + -0.09466827, + -0.5576785, + -0.57150143, + 0.5905171, + 0.2795803, + 0.27569205, + -0.5472138, + 0.58566374, + 0.5991524, + 0.571019}), + torch::tensor( + {-0.53353, + -0.23864025, + -0.43498003, + 0.5989136, + -0.2717535, + -0.28875044, + 0.33135164, + -0.2925691, + -0.31149834, + -0.055321813, + -0.060791314, + -0.49266922, + 0.527258, + -0.27898985, + 0.22434139}), + torch::tensor( + {0.55009943, + 0.31895775, + 0.089542985, + 0.53812927, + 0.085018635, + -0.29260972, + 0.59830266, + 0.14968991, + -0.6170747, + -0.18005475, + -0.43080813, + 0.5545073, + -0.104447424, + -0.576499, + -0.039758265}), + torch::tensor( + {0.39719027, + 0.16431224, + 0.19998527, + 0.058738172, + 0.23579127, + -0.15409905, + -0.25159094, + -0.59116155, + -0.4764521, + 0.2741078, + -0.37448537, + 0.09073615, + 0.20187438, + 0.044780314, + -0.4100524}), + torch::tensor( + {0.6048352, + -0.36777255, + 0.52019507, + -0.5031236, + -0.15213624, + 0.34403604, + -0.25840908, + 0.53128976, + -0.43522838, + -0.53111815, + -0.28518632, + 0.10223669, + 0.5823371, + -0.30193484, + 0.22618395}), + torch::tensor( + {-0.15856105, + -0.13715577, + 0.4650467, + -0.49014413, + 0.06717521, + 0.59476703, + -0.08690709, + 0.49106258, + -0.194769, + 0.50910276, + -0.6117934, + -0.09115183, + -0.111088574, + 0.20496935, + 0.24812967}), + torch::tensor( + {0.48561007, + -0.094227016, + -0.025080085, + 0.43311268, + -0.17112547, + 0.55443174, + -0.42109656, + -0.068437934, + -0.03396976, + 0.2822079, + 0.4323948, + -0.10024661, + -0.52397716, + 0.31337678, + 0.18918753}), + }, + { + torch::tensor( + {-0.10132998, + -0.37424773, + -0.3613546, + 0.15002292, + 0.59365314, + -0.5066626, + 0.38276488, + -0.32790715, + -0.12318605, + 0.5020248, + -0.14304265, + 0.057588696, + -0.442107, + 0.538372, + -0.081757426}), + torch::tensor( + {-0.46257195, + 0.18518949, + -0.44968688, + -0.5018581, + 0.03852868, + 0.5014383, + -0.1789748, + 0.29767585, + 0.5434825, + 0.419483, + -0.33176154, + -0.06932968, + -0.19902417, + -0.508572, + 0.00026130676}), + }, }; } diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp index fa235788d1a682..ae706e6d38781e 100644 --- a/test/cpp/api/jit.cpp +++ b/test/cpp/api/jit.cpp @@ -25,17 +25,17 @@ TEST(TorchScriptTest, CanCompileMultipleFunctions) { ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item()); - ASSERT_EQ(2, module->run_method("test_relu", a, b).toTensor().item()); + ASSERT_EQ( + 2, module->run_method("test_relu", a, b).toTensor().item()); ASSERT_TRUE( - 0x200 == module->run_method("test_while", a, b).toTensor().item()); + 0x200 == + module->run_method("test_while", a, b).toTensor().item()); at::IValue list = c10::List({3, 4}); ASSERT_EQ(2, module->run_method("test_len", list).toInt()); - } - TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) { auto module = torch::jit::compile(R"JIT( def nested_loop(a: List[List[Tensor]], b: int): @@ -51,11 +51,13 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) { module->run_method("nested_loop", list_of_lists, b); auto generic_list = c10::impl::GenericList(at::TensorType::get()); - auto empty_generic_list = c10::impl::GenericList(at::ListType::create(at::TensorType::get())); + auto empty_generic_list = + c10::impl::GenericList(at::ListType::create(at::TensorType::get())); empty_generic_list.push_back(generic_list); module->run_method("nested_loop", empty_generic_list, b); - auto too_many_lists = c10::impl::GenericList(at::ListType::create(at::ListType::create(at::TensorType::get()))); + auto too_many_lists = c10::impl::GenericList( + at::ListType::create(at::ListType::create(at::TensorType::get()))); too_many_lists.push_back(empty_generic_list); try { module->run_method("nested_loop", too_many_lists, b); @@ -69,7 +71,6 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) { }; } - TEST(TorchScriptTest, TestDictArgMatching) { auto module = torch::jit::compile(R"JIT( def dict_op(a: Dict[str, Tensor], b: str): @@ -88,11 +89,10 @@ TEST(TorchScriptTest, TestTupleArgMatching) { )JIT"); c10::List int_list({1}); - auto tuple_generic_list = c10::ivalue::Tuple::create({ int_list }); + auto tuple_generic_list = c10::ivalue::Tuple::create({int_list}); // doesn't fail on arg matching module->run_method("tuple_op", tuple_generic_list); - } TEST(TorchScriptTest, TestOptionalArgMatching) { @@ -109,7 +109,6 @@ TEST(TorchScriptTest, TestOptionalArgMatching) { ASSERT_EQ(2, module->run_method("optional_tuple_op", optional_tuple).toInt()); ASSERT_EQ( 0, module->run_method("optional_tuple_op", torch::jit::IValue()).toInt()); - } TEST(TorchScriptTest, TestPickle) { diff --git a/test/cpp/api/meta_tensor.cpp b/test/cpp/api/meta_tensor.cpp index 286eaf2c5d5fb1..a1db68087a1212 100644 --- a/test/cpp/api/meta_tensor.cpp +++ b/test/cpp/api/meta_tensor.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include #include diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp index 734cea27e5cca7..d469c655a9a2d8 100644 --- a/test/cpp/api/misc.cpp +++ b/test/cpp/api/misc.cpp @@ -55,8 +55,9 @@ TEST(NoGradTest, SetsGradModeCorrectly) { torch::Tensor s = y.sum(); // Mimicking python API behavior: - ASSERT_THROWS_WITH(s.backward(), - "element 0 of tensors does not require grad and does not have a grad_fn") + ASSERT_THROWS_WITH( + s.backward(), + "element 0 of tensors does not require grad and does not have a grad_fn") } struct AutogradTest : torch::test::SeedingFixture { diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp index bb4632431d4469..d0bd743e843e77 100644 --- a/test/cpp/api/module.cpp +++ b/test/cpp/api/module.cpp @@ -152,7 +152,8 @@ TEST_F(ModuleTest, RegisterParameterUndefinedTensor) { struct TestModel : public torch::nn::Module {}; { TestModel model; - model.register_parameter("undefined_tensor", torch::Tensor(), /*requires_grad=*/false); + model.register_parameter( + "undefined_tensor", torch::Tensor(), /*requires_grad=*/false); ASSERT_EQ(model.parameters().size(), 0); } { @@ -163,11 +164,10 @@ TEST_F(ModuleTest, RegisterParameterUndefinedTensor) { ASSERT_EQ(model.parameters().size(), 0); ASSERT_EQ( - count_substr_occurrences( - warnings.str(), - "Ignoring the `requires_grad=true` function parameter" - ), - 1); + count_substr_occurrences( + warnings.str(), + "Ignoring the `requires_grad=true` function parameter"), + 1); } } @@ -229,7 +229,8 @@ TEST_F(ModuleTest, AsCastsModulesCorrectly) { } void test_DeviceOrDtypeConversionSkipsUndefinedTensor( - torch::Device to_device, torch::Dtype to_dtype) { + torch::Device to_device, + torch::Dtype to_dtype) { { // Case 1: Undefined tensors as parameters Linear module(LinearOptions(10, 20).bias(false)); @@ -248,7 +249,8 @@ void test_DeviceOrDtypeConversionSkipsUndefinedTensor( } { // Case 2: Undefined tensors as buffers - BatchNorm1d module(BatchNorm1dOptions(5).track_running_stats(false).affine(true)); + BatchNorm1d module( + BatchNorm1dOptions(5).track_running_stats(false).affine(true)); ASSERT_TRUE(module->weight.defined()); ASSERT_FALSE(module->running_mean.defined()); @@ -269,7 +271,8 @@ TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor) { } TEST_F(ModuleTest, DeviceOrDtypeConversionSkipsUndefinedTensor_CUDA) { - test_DeviceOrDtypeConversionSkipsUndefinedTensor(torch::kCUDA, torch::kDouble); + test_DeviceOrDtypeConversionSkipsUndefinedTensor( + torch::kCUDA, torch::kDouble); } TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) { @@ -285,7 +288,8 @@ TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) { ASSERT_TRUE(pointer_equal(named_params["weight"], module->weight)); } { - BatchNorm1d module(BatchNorm1dOptions(5).track_running_stats(false).affine(false)); + BatchNorm1d module( + BatchNorm1dOptions(5).track_running_stats(false).affine(false)); auto buffers = module->buffers(); ASSERT_EQ(buffers.size(), 0); @@ -293,7 +297,8 @@ TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) { ASSERT_EQ(named_buffers.size(), 0); } { - BatchNorm1d module(BatchNorm1dOptions(5).track_running_stats(true).affine(false)); + BatchNorm1d module( + BatchNorm1dOptions(5).track_running_stats(true).affine(false)); auto buffers = module->buffers(); ASSERT_EQ(buffers.size(), 3); @@ -301,11 +306,15 @@ TEST_F(ModuleTest, ParametersAndBuffersAccessorSkipsUndefinedTensor) { ASSERT_EQ(named_buffers.size(), 3); ASSERT_TRUE(pointer_equal(buffers[0], named_buffers["running_mean"])); - ASSERT_TRUE(pointer_equal(named_buffers["running_mean"], module->running_mean)); + ASSERT_TRUE( + pointer_equal(named_buffers["running_mean"], module->running_mean)); ASSERT_TRUE(pointer_equal(buffers[1], named_buffers["running_var"])); - ASSERT_TRUE(pointer_equal(named_buffers["running_var"], module->running_var)); - ASSERT_TRUE(pointer_equal(buffers[2], named_buffers["num_batches_tracked"])); - ASSERT_TRUE(pointer_equal(named_buffers["num_batches_tracked"], module->num_batches_tracked)); + ASSERT_TRUE( + pointer_equal(named_buffers["running_var"], module->running_var)); + ASSERT_TRUE( + pointer_equal(buffers[2], named_buffers["num_batches_tracked"])); + ASSERT_TRUE(pointer_equal( + named_buffers["num_batches_tracked"], module->num_batches_tracked)); } } @@ -394,7 +403,9 @@ struct TestDistinctParametersModule torch::Tensor buffer; }; -void testDistinctParameters(std::shared_ptr m1, std::shared_ptr m2) { +void testDistinctParameters( + std::shared_ptr m1, + std::shared_ptr m2) { auto params1 = m1->named_parameters(); auto params2 = m2->named_parameters(); ASSERT_EQ(params1.size(), 6); @@ -850,16 +861,17 @@ std::shared_ptr make_deeply_nested_test_container() { std::vector> make_key_value_pairs_for_deeply_nested_container() { - return {{"test_prefix", 0}, - {"test_prefix.0", 1}, - {"test_prefix.0.0", 2}, - {"test_prefix.0.1", 3}, - {"test_prefix.1", 4}, - {"test_prefix.2", 5}, - {"test_prefix.2.0", 6}, - {"test_prefix.2.1", 7}, - {"test_prefix.2.1.0", 8}, - {"test_prefix.2.1.1", 9}}; + return { + {"test_prefix", 0}, + {"test_prefix.0", 1}, + {"test_prefix.0.0", 2}, + {"test_prefix.0.1", 3}, + {"test_prefix.1", 4}, + {"test_prefix.2", 5}, + {"test_prefix.2.0", 6}, + {"test_prefix.2.1", 7}, + {"test_prefix.2.1.0", 8}, + {"test_prefix.2.1.1", 9}}; } TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) { @@ -1022,7 +1034,6 @@ TEST_F(ModuleTest, PrettyPrint) { float y_; }; - ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule"); ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)"); } diff --git a/test/cpp/api/moduledict.cpp b/test/cpp/api/moduledict.cpp index 3451421cfeb9ce..88a46f37a8c650 100644 --- a/test/cpp/api/moduledict.cpp +++ b/test/cpp/api/moduledict.cpp @@ -18,10 +18,9 @@ TEST_F(ModuleDictTest, ConstructsFromList) { }; std::vector>> list = { - {"module_1", std::make_shared(1)}, - {"module_2", std::make_shared(2)}, - {"module_3", std::make_shared(3)} - }; + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, + {"module_3", std::make_shared(3)}}; ModuleDict dict(list); ASSERT_EQ(dict->size(), 3); } @@ -33,9 +32,9 @@ TEST_F(ModuleDictTest, ConstructsFromordereddict) { }; torch::OrderedDict> ordereddict = { - {"module_1", std::make_shared(1)}, - {"module_2", std::make_shared(2)}, - {"module_3", std::make_shared(3)}, + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, + {"module_3", std::make_shared(3)}, }; ModuleDict dict(ordereddict); ASSERT_EQ(dict->size(), 3); @@ -51,22 +50,19 @@ TEST_F(ModuleDictTest, UpdatePopClearContains) { ASSERT_TRUE(dict->empty()); // Update by List std::vector>> list1 = { - {"module_1", std::make_shared(1)} - }; + {"module_1", std::make_shared(1)}}; dict->update(list1); ASSERT_EQ(dict->size(), 1); ASSERT_TRUE(dict->contains("module_1")); // Update by OrderedDict torch::OrderedDict> ordereddict = { - {"module_2", std::make_shared(2)} - }; + {"module_2", std::make_shared(2)}}; dict->update(ordereddict); ASSERT_EQ(dict->size(), 2); ASSERT_TRUE(dict->contains("module_2")); // Update by another ModuleDict - std::vector>>list2 = { - {"module_3", std::make_shared(3)} - }; + std::vector>> list2 = { + {"module_3", std::make_shared(3)}}; ModuleDict updatedict(list2); dict->update(*updatedict); ASSERT_EQ(dict->size(), 3); @@ -87,32 +83,28 @@ TEST_F(ModuleDictTest, UpdateExist) { int value; }; std::vector>> list1 = { - {"module_1", std::make_shared(1)}, - {"module_2", std::make_shared(2)} - }; + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}}; ModuleDict dict(list1); ASSERT_EQ(dict->at("module_2").value, 2); // Update by list std::vector>> list2 = { - {"module_2", std::make_shared(0)}, - {"module_3", std::make_shared(3)} - }; + {"module_2", std::make_shared(0)}, + {"module_3", std::make_shared(3)}}; dict->update(list2); ASSERT_EQ(dict->size(), 3); ASSERT_EQ(dict->at("module_2").value, 0); // Update by ordereddict torch::OrderedDict> ordereddict = { - {"module_3", std::make_shared(0)}, - {"module_4", std::make_shared(4)} - }; + {"module_3", std::make_shared(0)}, + {"module_4", std::make_shared(4)}}; dict->update(ordereddict); ASSERT_EQ(dict->size(), 4); ASSERT_EQ(dict->at("module_3").value, 0); // Update by ModuleDict std::vector>> list3 = { - {"module_4", std::make_shared(0)}, - {"module_1", std::make_shared(0)} - }; + {"module_4", std::make_shared(0)}, + {"module_1", std::make_shared(0)}}; ModuleDict dict2(list3); dict->update(*dict2); ASSERT_EQ(dict->size(), 4); @@ -127,9 +119,9 @@ TEST_F(ModuleDictTest, Keys) { }; torch::OrderedDict> ordereddict = { - {"linear", Linear(10, 3).ptr()}, - {"conv", Conv2d(1, 2, 3).ptr()}, - {"dropout", Dropout(0.5).ptr()}, + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, }; ModuleDict dict(ordereddict); const auto& keys = dict->keys(); @@ -149,8 +141,8 @@ TEST_F(ModuleDictTest, Values) { }; torch::OrderedDict> ordereddict = { - {"module_1", std::make_shared(1)}, - {"module_2", std::make_shared(2)}, + {"module_1", std::make_shared(1)}, + {"module_2", std::make_shared(2)}, }; ModuleDict dict(ordereddict); const auto& values = dict->values(); @@ -160,29 +152,27 @@ TEST_F(ModuleDictTest, Values) { dict->begin(), dict->end(), ordereddict.begin(), - [](const auto& lhs, - const auto& rhs) { + [](const auto& lhs, const auto& rhs) { return lhs.value().get() == rhs.value().get(); })); } TEST_F(ModuleDictTest, SanityCheckForHoldingStandardModules) { torch::OrderedDict> ordereddict = { - {"linear", Linear(10, 3).ptr()}, - {"conv", Conv2d(1, 2, 3).ptr()}, - {"dropout", Dropout(0.5).ptr()}, - {"batch", BatchNorm2d(5).ptr()}, - {"embedding", Embedding(4, 10).ptr()}, - {"lstm", LSTM(4, 5).ptr()} - }; + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, + {"batch", BatchNorm2d(5).ptr()}, + {"embedding", Embedding(4, 10).ptr()}, + {"lstm", LSTM(4, 5).ptr()}}; ModuleDict dict(ordereddict); } TEST_F(ModuleDictTest, HasReferenceSemantics) { torch::OrderedDict> ordereddict = { - {"linear1", Linear(2, 3).ptr()}, - {"linear2", Linear(3, 4).ptr()}, - {"linear3", Linear(4, 5).ptr()}, + {"linear1", Linear(2, 3).ptr()}, + {"linear2", Linear(3, 4).ptr()}, + {"linear3", Linear(4, 5).ptr()}, }; ModuleDict first(ordereddict); ModuleDict second(ordereddict); @@ -192,24 +182,25 @@ TEST_F(ModuleDictTest, HasReferenceSemantics) { first->begin(), first->end(), second->begin(), - [](const auto& lhs, - const auto& rhs) { + [](const auto& lhs, const auto& rhs) { return lhs.value().get() == rhs.value().get(); })); } void iscloneable_helper(torch::Device device) { torch::OrderedDict> ordereddict = { - {"linear", Linear(2, 3).ptr()}, - {"relu", Functional(torch::relu).ptr()}, - {"batch", BatchNorm1d(3).ptr()}, + {"linear", Linear(2, 3).ptr()}, + {"relu", Functional(torch::relu).ptr()}, + {"batch", BatchNorm1d(3).ptr()}, }; ModuleDict dict(ordereddict); dict->to(device); - ModuleDict clone = std::dynamic_pointer_cast(dict->clone(device)); + ModuleDict clone = + std::dynamic_pointer_cast(dict->clone(device)); ASSERT_EQ(dict->size(), clone->size()); - for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end(); ++it, ++it_c) { + for (auto it = dict->begin(), it_c = clone->begin(); it != dict->end(); + ++it, ++it_c) { // The key should be same ASSERT_EQ(it->key(), it_c->key()); // The modules should be the same kind (type). @@ -245,9 +236,9 @@ TEST_F(ModuleDictTest, IsCloneable_CUDA) { TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { torch::OrderedDict> ordereddict1 = { - {"linear", Linear(10, 3).ptr()}, - {"conv", Conv2d(1, 2, 3).ptr()}, - {"test", Dropout(0.5).ptr()}, + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"test", Dropout(0.5).ptr()}, }; ModuleDict dict(ordereddict1); @@ -258,9 +249,7 @@ TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { // Update Existing torch::OrderedDict> ordereddict2 = { - {"lstm", LSTM(4, 5).ptr()}, - {"test", BatchNorm2d(5).ptr()} - }; + {"lstm", LSTM(4, 5).ptr()}, {"test", BatchNorm2d(5).ptr()}}; dict->update(ordereddict2); modules = dict->children(); @@ -273,9 +262,9 @@ TEST_F(ModuleDictTest, RegistersElementsAsSubmodules) { TEST_F(ModuleDictTest, CloneToDevice_CUDA) { torch::OrderedDict> ordereddict = { - {"linear", Linear(2, 3).ptr()}, - {"relu", Functional(torch::relu).ptr()}, - {"batch", BatchNorm1d(3).ptr()}, + {"linear", Linear(2, 3).ptr()}, + {"relu", Functional(torch::relu).ptr()}, + {"batch", BatchNorm1d(3).ptr()}, }; ModuleDict dict(ordereddict); torch::Device device(torch::kCUDA, 0); @@ -291,13 +280,12 @@ TEST_F(ModuleDictTest, CloneToDevice_CUDA) { TEST_F(ModuleDictTest, PrettyPrintModuleDict) { torch::OrderedDict> ordereddict = { - {"linear", Linear(10, 3).ptr()}, - {"conv", Conv2d(1, 2, 3).ptr()}, - {"dropout", Dropout(0.5).ptr()}, - {"batch", BatchNorm2d(5).ptr()}, - {"embedding", Embedding(4, 10).ptr()}, - {"lstm", LSTM(4, 5).ptr()} - }; + {"linear", Linear(10, 3).ptr()}, + {"conv", Conv2d(1, 2, 3).ptr()}, + {"dropout", Dropout(0.5).ptr()}, + {"batch", BatchNorm2d(5).ptr()}, + {"embedding", Embedding(4, 10).ptr()}, + {"lstm", LSTM(4, 5).ptr()}}; ModuleDict dict(ordereddict); ASSERT_EQ( diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp index aa4fd05c11d29e..e8a0bebeb94547 100644 --- a/test/cpp/api/modulelist.cpp +++ b/test/cpp/api/modulelist.cpp @@ -184,8 +184,8 @@ TEST_F(ModuleListTest, ExtendPushesModulesFromOtherModuleList) { ASSERT_TRUE(b[0]->as()); ASSERT_TRUE(b[1]->as()); - std::vector> c = {std::make_shared(), - std::make_shared()}; + std::vector> c = { + std::make_shared(), std::make_shared()}; b->extend(c); ASSERT_EQ(b->size(), 4); @@ -291,13 +291,12 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) { TEST_F(ModuleListTest, RangeBasedForLoop) { torch::nn::ModuleList mlist( - torch::nn::Linear(3, 4), - torch::nn::BatchNorm1d(4), - torch::nn::Dropout(0.5) - ); + torch::nn::Linear(3, 4), + torch::nn::BatchNorm1d(4), + torch::nn::Dropout(0.5)); std::stringstream buffer; - for (const auto &module : *mlist) { + for (const auto& module : *mlist) { module->pretty_print(buffer); } } diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index cdf4f0ea0deb07..26d999c5c1efbf 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -5,11 +5,11 @@ #include -#include -#include -#include #include +#include +#include #include +#include using namespace torch::nn; using namespace torch::test; @@ -40,14 +40,16 @@ struct ModulesTest : torch::test::SeedingFixture {}; TEST_F(ModulesTest, Conv1d) { Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false)); - model->weight.set_data(torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3})); - auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)).reshape({2, 3, 5}); + model->weight.set_data( + torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3})); + auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({2, 3, 5}); auto y = model(x); - auto expected = torch::tensor({{{ 312., 348., 384.}, - { 798., 915., 1032.}}, + auto expected = torch::tensor( + {{{312., 348., 384.}, {798., 915., 1032.}}, - {{ 852., 888., 924.}, - {2553., 2670., 2787.}}}, torch::kFloat); + {{852., 888., 924.}, {2553., 2670., 2787.}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -61,22 +63,26 @@ TEST_F(ModulesTest, Conv1dSameStrided) { options.stride(1).padding(torch::kSame); Conv1d model_valid(options); ASSERT_THROWS_WITH( - [&]{ Conv1d model_invalid(options.stride(2)); }(), - "padding='same' is not supported for strided convolutions"); + [&] { Conv1d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); } TEST_F(ModulesTest, Conv2dEven) { Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); - model->weight.set_data(torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3})); - auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 5}); + model->weight.set_data( + torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3})); + auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 5}); auto y = model(x); - auto expected = torch::tensor({{{{15219., 15570., 15921.}, - {16974., 17325., 17676.}, - {18729., 19080., 19431.}}, - - {{37818., 38898., 39978.}, - {43218., 44298., 45378.}, - {48618., 49698., 50778.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{15219., 15570., 15921.}, + {16974., 17325., 17676.}, + {18729., 19080., 19431.}}, + + {{37818., 38898., 39978.}, + {43218., 44298., 45378.}, + {48618., 49698., 50778.}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -87,16 +93,18 @@ TEST_F(ModulesTest, Conv2dEven) { TEST_F(ModulesTest, Conv2dUneven) { Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false)); - model->weight.set_data(torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2})); - auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 4}); + model->weight.set_data( + torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2})); + auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 4}); auto y = model(x); - auto expected = torch::tensor({{{{ 5289., 5442., 5595.}, - { 5901., 6054., 6207.}, - { 6513., 6666., 6819.}}, + auto expected = torch::tensor( + {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}}, - {{13227., 13704., 14181.}, - {15135., 15612., 16089.}, - {17043., 17520., 17997.}}}}, torch::kFloat); + {{13227., 13704., 14181.}, + {15135., 15612., 16089.}, + {17043., 17520., 17997.}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -110,41 +118,47 @@ TEST_F(ModulesTest, Conv2dSameStrided) { options.stride(1).padding(torch::kSame); Conv2d model_valid(options); ASSERT_THROWS_WITH( - [&]{ Conv2d model_invalid(options.stride(2)); }(), - "padding='same' is not supported for strided convolutions"); + [&] { Conv2d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); ASSERT_THROWS_WITH( - [&]{ Conv2d model_invalid(options.stride({1, 2})); }(), - "padding='same' is not supported for strided convolutions"); + [&] { + Conv2d model_invalid(options.stride({1, 2})); + }(), + "padding='same' is not supported for strided convolutions"); } TEST_F(ModulesTest, Conv3d) { Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); - model->weight.set_data(torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3})); - auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true)).reshape({1, 3, 5, 5, 5}); + model->weight.set_data( + torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3})); + auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true)) + .reshape({1, 3, 5, 5, 5}); auto y = model(x); - auto expected = torch::tensor({{{{{ 700704., 703944., 707184.}, - { 716904., 720144., 723384.}, - { 733104., 736344., 739584.}}, - - {{ 781704., 784944., 788184.}, - { 797904., 801144., 804384.}, - { 814104., 817344., 820584.}}, - - {{ 862704., 865944., 869184.}, - { 878904., 882144., 885384.}, - { 895104., 898344., 901584.}}}, - - {{{1724220., 1734021., 1743822.}, - {1773225., 1783026., 1792827.}, - {1822230., 1832031., 1841832.}}, - - {{1969245., 1979046., 1988847.}, - {2018250., 2028051., 2037852.}, - {2067255., 2077056., 2086857.}}, - - {{2214270., 2224071., 2233872.}, - {2263275., 2273076., 2282877.}, - {2312280., 2322081., 2331882.}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{700704., 703944., 707184.}, + {716904., 720144., 723384.}, + {733104., 736344., 739584.}}, + + {{781704., 784944., 788184.}, + {797904., 801144., 804384.}, + {814104., 817344., 820584.}}, + + {{862704., 865944., 869184.}, + {878904., 882144., 885384.}, + {895104., 898344., 901584.}}}, + + {{{1724220., 1734021., 1743822.}, + {1773225., 1783026., 1792827.}, + {1822230., 1832031., 1841832.}}, + + {{1969245., 1979046., 1988847.}, + {2018250., 2028051., 2037852.}, + {2067255., 2077056., 2086857.}}, + + {{2214270., 2224071., 2233872.}, + {2263275., 2273076., 2282877.}, + {2312280., 2322081., 2331882.}}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -158,11 +172,13 @@ TEST_F(ModulesTest, Conv3dSameStrided) { options.stride(1).padding(torch::kSame); Conv3d model_valid(options); ASSERT_THROWS_WITH( - [&]{ Conv3d model_invalid(options.stride(2)); }(), - "padding='same' is not supported for strided convolutions"); + [&] { Conv3d model_invalid(options.stride(2)); }(), + "padding='same' is not supported for strided convolutions"); ASSERT_THROWS_WITH( - [&]{ Conv3d model_invalid(options.stride({1, 2, 1})); }(), - "padding='same' is not supported for strided convolutions"); + [&] { + Conv3d model_invalid(options.stride({1, 2, 1})); + }(), + "padding='same' is not supported for strided convolutions"); } TEST_F(ModulesTest, ConvTranspose1d) { @@ -170,12 +186,13 @@ TEST_F(ModulesTest, ConvTranspose1d) { model->weight.set_data(torch::arange(18.).view({2, 3, 3})); auto x = torch::arange(20.).reshape({2, 2, 5}); auto y = model(x); - auto expected = torch::tensor({{{ 45., 104., 179., 212., 245., 188., 107.}, - { 60., 140., 242., 293., 344., 260., 146.}, - { 75., 176., 305., 374., 443., 332., 185.}}, - {{ 135., 304., 509., 542., 575., 428., 237.}, - { 210., 460., 752., 803., 854., 620., 336.}, - { 285., 616., 995., 1064., 1133., 812., 435.}}}); + auto expected = torch::tensor( + {{{45., 104., 179., 212., 245., 188., 107.}, + {60., 140., 242., 293., 344., 260., 146.}, + {75., 176., 305., 374., 443., 332., 185.}}, + {{135., 304., 509., 542., 575., 428., 237.}, + {210., 460., 752., 803., 854., 620., 336.}, + {285., 616., 995., 1064., 1133., 812., 435.}}}); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -189,27 +206,28 @@ TEST_F(ModulesTest, ConvTranspose2dEven) { model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3})); auto x = torch::arange(50.).view({1, 2, 5, 5}); auto y = model(x); - auto expected = torch::tensor({{{{ 675., 1402., 2183., 2270., 2357., 1634., 849.}, - { 1560., 3240., 5044., 5236., 5428., 3760., 1952.}, - { 2685., 5574., 8673., 8988., 9303., 6438., 3339.}, - { 3180., 6594., 10248., 10563., 10878., 7518., 3894.}, - { 3675., 7614., 11823., 12138., 12453., 8598., 4449.}, - { 2820., 5832., 9040., 9268., 9496., 6544., 3380.}, - { 1605., 3314., 5129., 5252., 5375., 3698., 1907.}}, - {{ 900., 1870., 2912., 3053., 3194., 2210., 1146.}, - { 2100., 4356., 6772., 7072., 7372., 5092., 2636.}, - { 3630., 7518., 11670., 12147., 12624., 8706., 4500.}, - { 4395., 9078., 14055., 14532., 15009., 10326., 5325.}, - { 5160., 10638., 16440., 16917., 17394., 11946., 6150.}, - { 3900., 8028., 12388., 12724., 13060., 8956., 4604.}, - { 2190., 4502., 6938., 7115., 7292., 4994., 2564.}}, - {{ 1125., 2338., 3641., 3836., 4031., 2786., 1443.}, - { 2640., 5472., 8500., 8908., 9316., 6424., 3320.}, - { 4575., 9462., 14667., 15306., 15945., 10974., 5661.}, - { 5610., 11562., 17862., 18501., 19140., 13134., 6756.}, - { 6645., 13662., 21057., 21696., 22335., 15294., 7851.}, - { 4980., 10224., 15736., 16180., 16624., 11368., 5828.}, - { 2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}}); + auto expected = torch::tensor( + {{{{675., 1402., 2183., 2270., 2357., 1634., 849.}, + {1560., 3240., 5044., 5236., 5428., 3760., 1952.}, + {2685., 5574., 8673., 8988., 9303., 6438., 3339.}, + {3180., 6594., 10248., 10563., 10878., 7518., 3894.}, + {3675., 7614., 11823., 12138., 12453., 8598., 4449.}, + {2820., 5832., 9040., 9268., 9496., 6544., 3380.}, + {1605., 3314., 5129., 5252., 5375., 3698., 1907.}}, + {{900., 1870., 2912., 3053., 3194., 2210., 1146.}, + {2100., 4356., 6772., 7072., 7372., 5092., 2636.}, + {3630., 7518., 11670., 12147., 12624., 8706., 4500.}, + {4395., 9078., 14055., 14532., 15009., 10326., 5325.}, + {5160., 10638., 16440., 16917., 17394., 11946., 6150.}, + {3900., 8028., 12388., 12724., 13060., 8956., 4604.}, + {2190., 4502., 6938., 7115., 7292., 4994., 2564.}}, + {{1125., 2338., 3641., 3836., 4031., 2786., 1443.}, + {2640., 5472., 8500., 8908., 9316., 6424., 3320.}, + {4575., 9462., 14667., 15306., 15945., 10974., 5661.}, + {5610., 11562., 17862., 18501., 19140., 13134., 6756.}, + {6645., 13662., 21057., 21696., 22335., 15294., 7851.}, + {4980., 10224., 15736., 16180., 16624., 11368., 5828.}, + {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}}); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -219,31 +237,33 @@ TEST_F(ModulesTest, ConvTranspose2dEven) { } TEST_F(ModulesTest, ConvTranspose2dUneven) { - ConvTranspose2d model(ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false)); + ConvTranspose2d model( + ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false)); model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2})); auto x = torch::arange(40.).view({1, 2, 5, 4}); auto y = model(x); - auto expected = torch::tensor({{{{ 360., 758., 796., 834., 440.}, - { 832., 1752., 1836., 1920., 1012.}, - {1432., 3014., 3152., 3290., 1732.}, - {1696., 3566., 3704., 3842., 2020.}, - {1960., 4118., 4256., 4394., 2308.}, - {1504., 3152., 3252., 3352., 1756.}, - { 856., 1790., 1844., 1898., 992.}}, - {{ 480., 1010., 1072., 1134., 596.}, - {1120., 2352., 2484., 2616., 1372.}, - {1936., 4058., 4268., 4478., 2344.}, - {2344., 4898., 5108., 5318., 2776.}, - {2752., 5738., 5948., 6158., 3208.}, - {2080., 4328., 4476., 4624., 2404.}, - {1168., 2426., 2504., 2582., 1340.}}, - {{ 600., 1262., 1348., 1434., 752.}, - {1408., 2952., 3132., 3312., 1732.}, - {2440., 5102., 5384., 5666., 2956.}, - {2992., 6230., 6512., 6794., 3532.}, - {3544., 7358., 7640., 7922., 4108.}, - {2656., 5504., 5700., 5896., 3052.}, - {1480., 3062., 3164., 3266., 1688.}}}}); + auto expected = torch::tensor( + {{{{360., 758., 796., 834., 440.}, + {832., 1752., 1836., 1920., 1012.}, + {1432., 3014., 3152., 3290., 1732.}, + {1696., 3566., 3704., 3842., 2020.}, + {1960., 4118., 4256., 4394., 2308.}, + {1504., 3152., 3252., 3352., 1756.}, + {856., 1790., 1844., 1898., 992.}}, + {{480., 1010., 1072., 1134., 596.}, + {1120., 2352., 2484., 2616., 1372.}, + {1936., 4058., 4268., 4478., 2344.}, + {2344., 4898., 5108., 5318., 2776.}, + {2752., 5738., 5948., 6158., 3208.}, + {2080., 4328., 4476., 4624., 2404.}, + {1168., 2426., 2504., 2582., 1340.}}, + {{600., 1262., 1348., 1434., 752.}, + {1408., 2952., 3132., 3312., 1732.}, + {2440., 5102., 5384., 5666., 2956.}, + {2992., 6230., 6512., 6794., 3532.}, + {3544., 7358., 7640., 7922., 4108.}, + {2656., 5504., 5700., 5896., 3052.}, + {1480., 3062., 3164., 3266., 1688.}}}}); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -257,24 +277,13 @@ TEST_F(ModulesTest, ConvTranspose3d) { model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2})); auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2}); auto y = model(x); - auto expected = torch::tensor({{{{{ 128., 280., 154.}, - { 304., 664., 364.}, - { 184., 400., 218.}}, - {{ 352., 768., 420.}, - { 832., 1808., 984.}, - { 496., 1072., 580.}}, - {{ 256., 552., 298.}, - { 592., 1272., 684.}, - { 344., 736., 394.}}}, - {{{ 192., 424., 234.}, - { 464., 1016., 556.}, - { 280., 608., 330.}}, - {{ 544., 1184., 644.}, - {1280., 2768., 1496.}, - { 752., 1616., 868.}}, - {{ 384., 824., 442.}, - { 880., 1880., 1004.}, - { 504., 1072., 570.}}}}}); + auto expected = torch::tensor( + {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}}, + {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}}, + {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}}, + {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}}, + {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}}, + {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}}); ASSERT_TRUE(torch::allclose(y, expected)); torch::Tensor s = y.sum(); @@ -291,7 +300,7 @@ TEST_F(ModulesTest, MaxPool1d) { s.backward(); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.sizes(), std::vector({1, 1, 2})); } @@ -303,10 +312,11 @@ TEST_F(ModulesTest, MaxPool1dReturnIndices) { std::tie(y, indices) = model->forward_with_indices(x); ASSERT_EQ(y.dim(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 2})); - ASSERT_TRUE(torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong))); + ASSERT_TRUE( + torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({1, 1, 2})); } @@ -318,7 +328,7 @@ TEST_F(ModulesTest, MaxPool2dEven) { s.backward(); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.sizes(), std::vector({2, 2, 2})); } @@ -343,14 +353,11 @@ TEST_F(ModulesTest, MaxPool2dReturnIndices) { std::tie(y, indices) = model->forward_with_indices(x); ASSERT_EQ(y.dim(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); ASSERT_EQ(y.sizes(), std::vector({2, 2, 2})); ASSERT_TRUE(torch::allclose( - indices, - torch::tensor({{{ 0, 2}, - {10, 12}}, - {{ 0, 2}, - {10, 12}}}, torch::kLong))); + indices, + torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({2, 2, 2})); } @@ -378,15 +385,11 @@ TEST_F(ModulesTest, MaxPool3dReturnIndices) { ASSERT_EQ(y.sizes(), std::vector({2, 2, 2, 2})); ASSERT_TRUE(torch::allclose( - indices, - torch::tensor({{{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}, - {{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}}, torch::kLong))); + indices, + torch::tensor( + {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}, + {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}}, + torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({2, 2, 2, 2})); } @@ -465,11 +468,7 @@ TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) { ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); ASSERT_EQ(y.sizes(), std::vector({2, 2, 2})); ASSERT_TRUE(torch::allclose( - indices, - torch::tensor({{{ 0, 2}, - {10, 12}}, - {{ 0, 2}, - {10, 12}}}))); + indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}))); ASSERT_EQ(indices.sizes(), std::vector({2, 2, 2})); } @@ -497,15 +496,10 @@ TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) { ASSERT_EQ(y.sizes(), std::vector({2, 2, 2, 2})); ASSERT_TRUE(torch::allclose( - indices, - torch::tensor({{{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}, - {{{ 0, 2}, - {10, 12}}, - {{50, 52}, - {60, 62}}}}))); + indices, + torch::tensor( + {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}, + {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}}))); ASSERT_EQ(indices.sizes(), std::vector({2, 2, 2, 2})); } @@ -517,7 +511,10 @@ TEST_F(ModulesTest, LPPool1d) { LPPool1d model(LPPool1dOptions(norm_type, kernel_size).stride(stride)); auto x = torch::ones({1, 1, 5}); auto y = model(x); - auto expected = (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * kernel_size).pow(1. / norm_type); + auto expected = + (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * + kernel_size) + .pow(1. / norm_type); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose(y, expected)); @@ -532,7 +529,10 @@ TEST_F(ModulesTest, LPPool2d) { LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride)); auto x = torch::ones({1, 2, 5}); auto y = model(x); - auto expected = (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * (kernel_size[0] * kernel_size[1])).pow(1. / norm_type); + auto expected = + (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * + (kernel_size[0] * kernel_size[1])) + .pow(1. / norm_type); ASSERT_EQ(y.ndimension(), 3); ASSERT_TRUE(torch::allclose(y, expected)); @@ -541,7 +541,8 @@ TEST_F(ModulesTest, LPPool2d) { TEST_F(ModulesTest, Identity) { Identity identity; - auto input = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = identity->forward(input); auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat); auto s = output.sum(); @@ -553,7 +554,8 @@ TEST_F(ModulesTest, Identity) { TEST_F(ModulesTest, Flatten) { Flatten flatten; - auto input = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = flatten->forward(input); auto expected = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat); auto s = output.sum(); @@ -564,16 +566,16 @@ TEST_F(ModulesTest, Flatten) { // Testing with optional arguments start_dim and end_dim Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3)); - input = torch::tensor({ - {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, - {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}} - }, torch::dtype(torch::kFloat).requires_grad(true)); // Tensor with sizes (2, 2, 2, 2) + input = torch::tensor( + {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, + {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}}, + torch::dtype(torch::kFloat) + .requires_grad(true)); // Tensor with sizes (2, 2, 2, 2) output = flatten_optional_dims->forward(input); - expected = torch::tensor({ - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}} - }, torch::kFloat); // Tensor with sizes (2, 2, 4) + expected = torch::tensor( + {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}}, + torch::kFloat); // Tensor with sizes (2, 2, 4) s = output.sum(); s.backward(); @@ -613,7 +615,8 @@ TEST_F(ModulesTest, Unflatten) { TEST_F(ModulesTest, AdaptiveMaxPool1d) { AdaptiveMaxPool1d model(3); - auto x = torch::tensor({{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto y = model(x); torch::Tensor s = y.sum(); @@ -626,14 +629,16 @@ TEST_F(ModulesTest, AdaptiveMaxPool1d) { TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) { AdaptiveMaxPool1d model(3); - auto x = torch::tensor({{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); torch::Tensor y, indices; std::tie(y, indices) = model->forward_with_indices(x); ASSERT_EQ(y.dim(), 3); ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 3})); - ASSERT_TRUE(torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong))); + ASSERT_TRUE( + torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({1, 1, 3})); } @@ -646,14 +651,14 @@ TEST_F(ModulesTest, AdaptiveMaxPool2dEven) { s.backward(); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}, - {{31, 33, 34}, - {41, 43, 44}, - {46, 48, 49}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, + {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}, + }, + torch::kFloat))); ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.sizes(), std::vector({2, 3, 3})); } @@ -667,14 +672,14 @@ TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) { s.backward(); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{5, 7}, - {13, 15}, - {17, 19}}, - {{25, 27}, - {33, 35}, - {37, 39}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{5, 7}, {13, 15}, {17, 19}}, + {{25, 27}, {33, 35}, {37, 39}}, + }, + torch::kFloat))); ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.sizes(), std::vector({2, 3, 2})); } @@ -691,25 +696,25 @@ TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}, - {{31, 33, 34}, - {41, 43, 44}, - {46, 48, 49}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, + {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 3, 3})); ASSERT_EQ(indices.ndimension(), 3); - ASSERT_TRUE(torch::allclose(indices, torch::tensor({ - {{6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}, - {{6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}, - }, torch::kLong))); + ASSERT_TRUE(torch::allclose( + indices, + torch::tensor( + { + {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, + {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, + }, + torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({2, 3, 3})); } @@ -725,25 +730,25 @@ TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{5, 7}, - {13, 15}, - {17, 19}}, - {{25, 27}, - {33, 35}, - {37, 39}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{5, 7}, {13, 15}, {17, 19}}, + {{25, 27}, {33, 35}, {37, 39}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 3, 2})); ASSERT_EQ(indices.ndimension(), 3); - ASSERT_TRUE(torch::allclose(indices, torch::tensor({ - {{5, 7}, - {13, 15}, - {17, 19}}, - {{5, 7}, - {13, 15}, - {17, 19}}, - }, torch::kLong))); + ASSERT_TRUE(torch::allclose( + indices, + torch::tensor( + { + {{5, 7}, {13, 15}, {17, 19}}, + {{5, 7}, {13, 15}, {17, 19}}, + }, + torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({2, 3, 2})); } @@ -758,17 +763,15 @@ TEST_F(ModulesTest, AdaptiveMaxPool3d) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 4); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{21, 22, 23}, - {25, 26, 27}, - {29, 30, 31}}, - {{37, 38, 39}, - {41, 42, 43}, - {45, 46, 47}}, - {{53, 54, 55}, - {57, 58, 59}, - {61, 62, 63}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, + {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, + {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 3, 3, 3})); } @@ -784,37 +787,34 @@ TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 4); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{21, 22, 23}, - {25, 26, 27}, - {29, 30, 31}}, - {{37, 38, 39}, - {41, 42, 43}, - {45, 46, 47}}, - {{53, 54, 55}, - {57, 58, 59}, - {61, 62, 63}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, + {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, + {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 3, 3, 3})); ASSERT_EQ(indices.ndimension(), 4); - ASSERT_TRUE(torch::allclose(indices, torch::tensor({ - {{21, 22, 23}, - {25, 26, 27}, - {29, 30, 31}}, - {{37, 38, 39}, - {41, 42, 43}, - {45, 46, 47}}, - {{53, 54, 55}, - {57, 58, 59}, - {61, 62, 63}}, - }, torch::kLong))); + ASSERT_TRUE(torch::allclose( + indices, + torch::tensor( + { + {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, + {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, + {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, + }, + torch::kLong))); ASSERT_EQ(indices.sizes(), std::vector({1, 3, 3, 3})); } TEST_F(ModulesTest, AdaptiveAvgPool1d) { AdaptiveAvgPool1d model(3); - auto x = torch::tensor({{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto y = model(x); torch::Tensor s = y.sum(); @@ -822,7 +822,8 @@ TEST_F(ModulesTest, AdaptiveAvgPool1d) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat))); + ASSERT_TRUE( + torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 3})); } @@ -837,14 +838,14 @@ TEST_F(ModulesTest, AdaptiveAvgPool2dEven) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{ 3.0, 4.5, 6.0}, - {10.5, 12.0, 13.5}, - {18.0, 19.5, 21.0}}, - {{28.0, 29.5, 31.0}, - {35.5, 37.0, 38.5}, - {43.0, 44.5, 46.0}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}}, + {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 3, 3})); } @@ -859,14 +860,14 @@ TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{2.5, 4.5}, - {8.5, 10.5}, - {14.5, 16.5}}, - {{22.5, 24.5}, - {28.5, 30.5}, - {34.5, 36.5}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}}, + {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 3, 2})); } @@ -881,184 +882,166 @@ TEST_F(ModulesTest, AdaptiveAvgPool3d) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 4); - ASSERT_TRUE(torch::allclose(y, torch::tensor({ - {{10.5, 11.5, 12.5}, - {14.5, 15.5, 16.5}, - {18.5, 19.5, 20.5}}, - {{26.5, 27.5, 28.5}, - {30.5, 31.5, 32.5}, - {34.5, 35.5, 36.5}}, - {{42.5, 43.5, 44.5}, - {46.5, 47.5, 48.5}, - {50.5, 51.5, 52.5}}, - }, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + { + {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}}, + {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}}, + {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}}, + }, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 3, 3, 3})); } TEST_F(ModulesTest, MaxUnpool1d) { auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); - auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto model = MaxUnpool1d{3}; auto y = model->forward(x, indices); ASSERT_EQ(y.dim(), 3); - ASSERT_TRUE(torch::allclose(y, - torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 9})); indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); - x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); + x = torch::tensor( + {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); model = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)}; y = model->forward(x, indices, std::vector({1, 1, 5})); ASSERT_EQ(y.dim(), 3); - ASSERT_TRUE(torch::allclose(y, - torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat))); + ASSERT_TRUE( + torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 5})); } TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) { - MaxPool1d pool {MaxPool1dOptions(2).stride(2)}; - MaxUnpool1d unpool {MaxUnpool1dOptions(2).stride(2)}; + MaxPool1d pool{MaxPool1dOptions(2).stride(2)}; + MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)}; auto input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat); torch::Tensor output, indices; std::tie(output, indices) = pool->forward_with_indices(input); ASSERT_TRUE(torch::allclose( - unpool(output, indices), - torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}} , torch::kFloat))); + unpool(output, indices), + torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat))); // Example showcasing the use of output_size input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat); std::tie(output, indices) = pool->forward_with_indices(input); ASSERT_TRUE(torch::allclose( - unpool(output, indices, input.sizes().vec()), - torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}} , torch::kFloat))); + unpool(output, indices, input.sizes().vec()), + torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat))); ASSERT_TRUE(torch::allclose( - unpool(output, indices), - torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}} , torch::kFloat))); + unpool(output, indices), + torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat))); } TEST_F(ModulesTest, MaxUnpool2d) { - auto indices = torch::tensor({ - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}}, torch::kLong); - auto x = torch::tensor({ - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{31, 33, 34}, - {41, 43, 44}, - {46, 48, 49}}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto indices = torch::tensor( + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}}, + torch::kLong); + auto x = torch::tensor( + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto model = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)}; auto y = model->forward(x, indices); ASSERT_EQ(y.dim(), 4); - ASSERT_TRUE(torch::allclose(y, torch::tensor( - {{{{ 0, 0, 0, 0, 0}, - { 0, 6, 0, 8, 9}, - { 0, 0, 0, 0, 0}, - { 0, 16, 0, 18, 19}, - { 0, 21, 0, 23, 24}}}, - {{{ 0, 0, 0, 0, 0}, - { 0, 31, 0, 33, 34}, - { 0, 0, 0, 0, 0}, - { 0, 41, 0, 43, 44}, - { 0, 46, 0, 48, 49}}}} , torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + {{{{0, 0, 0, 0, 0}, + {0, 6, 0, 8, 9}, + {0, 0, 0, 0, 0}, + {0, 16, 0, 18, 19}, + {0, 21, 0, 23, 24}}}, + {{{0, 0, 0, 0, 0}, + {0, 31, 0, 33, 34}, + {0, 0, 0, 0, 0}, + {0, 41, 0, 43, 44}, + {0, 46, 0, 48, 49}}}}, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({2, 1, 5, 5})); } TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) { - MaxPool2d pool {MaxPool2dOptions(2).stride(2)}; - MaxUnpool2d unpool {MaxUnpool2dOptions(2).stride(2)}; - auto input = torch::tensor({{{{ 1, 2, 3, 4}, - { 5, 6, 7, 8}, - { 9, 10, 11, 12}, - {13, 14, 15, 16}}}}, torch::kFloat); + MaxPool2d pool{MaxPool2dOptions(2).stride(2)}; + MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)}; + auto input = torch::tensor( + {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}}, + torch::kFloat); torch::Tensor output, indices; std::tie(output, indices) = pool->forward_with_indices(input); ASSERT_TRUE(torch::allclose( - unpool(output, indices), - torch::tensor({{{{ 0, 0, 0, 0}, - { 0, 6, 0, 8}, - { 0, 0, 0, 0}, - { 0, 14, 0, 16}}}} , torch::kFloat))); + unpool(output, indices), + torch::tensor( + {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}}, + torch::kFloat))); ASSERT_TRUE(torch::allclose( - unpool(output, indices, std::vector{1, 1, 5, 5}), - torch::tensor({{{{ 0, 0, 0, 0, 0}, - { 6, 0, 8, 0, 0}, - { 0, 0, 0, 14, 0}, - { 16, 0, 0, 0, 0}, - { 0, 0, 0, 0, 0}}}}, torch::kFloat))); + unpool(output, indices, std::vector{1, 1, 5, 5}), + torch::tensor( + {{{{0, 0, 0, 0, 0}, + {6, 0, 8, 0, 0}, + {0, 0, 0, 14, 0}, + {16, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}}}}, + torch::kFloat))); } TEST_F(ModulesTest, MaxUnpool3d) { auto indices = torch::tensor({{{{{26}}}}}, torch::kLong); - auto x = torch::tensor({{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); auto model = MaxUnpool3d{3}; auto y = model->forward(x, indices); ASSERT_EQ(y.dim(), 5); - ASSERT_TRUE(torch::allclose(y, torch::tensor( - {{{{{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 26}}}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}}, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 3, 3, 3})); } TEST_F(ModulesTest, MaxUnpool3dOutputSize) { auto indices = torch::tensor( - {{{{{21, 23}, - {29, 31}}, - {{53, 55}, - {61, 63}}}}}, torch::kLong); - auto x = torch::tensor( - {{{{{21, 23}, - {29, 31}}, - {{53, 55}, - {61, 63}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); + {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong); + auto x = torch::tensor( + {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)}; auto y = model->forward(x, indices, std::vector({1, 1, 4, 4, 4})); ASSERT_EQ(y.dim(), 5); - ASSERT_TRUE(torch::allclose(y, torch::tensor( - {{{{{ 0, 0, 0, 0}, - { 0, 0, 0, 0}, - { 0, 0, 0, 0}, - { 0, 0, 0, 0}}, - {{ 0, 0, 0, 0}, - { 0, 21, 0, 23}, - { 0, 0, 0, 0}, - { 0, 29, 0, 31}}, - {{ 0, 0, 0, 0}, - { 0, 0, 0, 0}, - { 0, 0, 0, 0}, - { 0, 0, 0, 0}}, - {{ 0, 0, 0, 0}, - { 0, 53, 0, 55}, - { 0, 0, 0, 0}, - { 0, 61, 0, 63}}}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, + torch::tensor( + {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}, + {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}}, + {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}, + {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}}, + torch::kFloat))); ASSERT_EQ(y.sizes(), std::vector({1, 1, 4, 4, 4})); } TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) { - MaxPool3d pool {MaxPool3dOptions(3).stride(2)}; - MaxUnpool3d unpool {MaxUnpool3dOptions(3).stride(2)}; + MaxPool3d pool{MaxPool3dOptions(3).stride(2)}; + MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)}; auto input = torch::randn({20, 16, 51, 33, 15}); torch::Tensor output, indices; std::tie(output, indices) = pool->forward_with_indices(input); auto unpooled_output = unpool(output, indices); - ASSERT_EQ(unpooled_output.sizes(), std::vector({20, 16, 51, 33, 15})); + ASSERT_EQ( + unpooled_output.sizes(), std::vector({20, 16, 51, 33, 15})); } TEST_F(ModulesTest, Linear) { @@ -1101,34 +1084,22 @@ TEST_F(ModulesTest, Linear) { TEST_F(ModulesTest, LocalResponseNorm) { { LocalResponseNorm model(LocalResponseNormOptions(2)); - const auto x = torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2}); + const auto x = + torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2}); auto y = model(x); const auto y_exp = torch::tensor( - {{{{73.7788, 74.1462}, - {74.5031, 74.8572}, - {75.2010, 75.5420}}, - - {{61.6057, 61.7227}, - {61.8347, 61.9418}, - {62.0441, 62.1418}}, - - {{62.2349, 62.3235}, - {62.4077, 62.4877}, - {62.5635, 62.6353}}}, - - {{{79.3915, 79.6491}, - {79.8978, 80.1446}, - {80.3827, 80.6190}}, - - {{63.0317, 63.0742}, - {63.1135, 63.1496}, - {63.1826, 63.2126}}, - - {{63.2396, 63.2637}, - {63.2850, 63.3036}, - {63.3195, 63.3328}}}}, - torch::kFloat - ); + {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}}, + + {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}}, + + {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}}, + + {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}}, + + {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}}, + + {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}}, + torch::kFloat); torch::Tensor s = y.sum(); s.backward(); @@ -1220,7 +1191,8 @@ TEST_F(ModulesTest, Fold) { TEST_F(ModulesTest, Unfold) { { Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2)); - auto input = torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3}); + auto input = + torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3}); auto output = model(input); auto expected = torch::tensor( {{{0.0, 0.0, 0.0, 6.0}, @@ -1315,7 +1287,8 @@ TEST_F(ModulesTest, EmbeddingFromPretrained) { auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}}); Embedding embedding = torch::nn::Embedding::from_pretrained(weight); auto input = torch::tensor({1}, torch::kLong); - ASSERT_TRUE(torch::allclose(embedding(input), torch::tensor({4.0000, 5.1000, 6.3000}))); + ASSERT_TRUE(torch::allclose( + embedding(input), torch::tensor({4.0000, 5.1000, 6.3000}))); } TEST_F(ModulesTest, EmbeddingBagFromPretrained) { @@ -1323,7 +1296,8 @@ TEST_F(ModulesTest, EmbeddingBagFromPretrained) { EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight); auto input = torch::zeros({{1, 2}}, torch::kLong); input[0] = torch::tensor({1, 0}); - ASSERT_TRUE(torch::allclose(embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500}))); + ASSERT_TRUE(torch::allclose( + embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500}))); } TEST_F(ModulesTest, AlphaDropout) { @@ -1525,7 +1499,8 @@ TEST_F(ModulesTest, BatchNorm1dStateful) { } TEST_F(ModulesTest, BatchNorm1dStateless) { - BatchNorm1d bn(BatchNorm1dOptions(5).track_running_stats(false).affine(false)); + BatchNorm1d bn( + BatchNorm1dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(bn->running_mean.defined()); ASSERT_FALSE(bn->running_var.defined()); @@ -1540,16 +1515,17 @@ TEST_F(ModulesTest, BatchNorm1d) { auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_(); auto output = bn->forward(input); - auto expected = torch::tensor({{{ 0.0000, 1.0000}, - { 2.0000, 3.0000}, - { 4.0000, 5.0000}, - { 6.0000, 7.0000}, - { 8.0000, 9.0000}}, - {{10.0000, 10.9999}, - {11.9999, 12.9999}, - {13.9999, 14.9999}, - {15.9999, 16.9999}, - {17.9999, 18.9999}}}); + auto expected = torch::tensor( + {{{0.0000, 1.0000}, + {2.0000, 3.0000}, + {4.0000, 5.0000}, + {6.0000, 7.0000}, + {8.0000, 9.0000}}, + {{10.0000, 10.9999}, + {11.9999, 12.9999}, + {13.9999, 14.9999}, + {15.9999, 16.9999}, + {17.9999, 18.9999}}}); ASSERT_TRUE(output.allclose(expected)); auto s = output.sum(); s.backward(); @@ -1585,7 +1561,8 @@ TEST_F(ModulesTest, BatchNorm2dStateful) { } TEST_F(ModulesTest, BatchNorm2dStateless) { - BatchNorm2d bn(BatchNorm2dOptions(5).track_running_stats(false).affine(false)); + BatchNorm2d bn( + BatchNorm2dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(bn->running_mean.defined()); ASSERT_FALSE(bn->running_var.defined()); @@ -1598,28 +1575,20 @@ TEST_F(ModulesTest, BatchNorm2d) { BatchNorm2d bn(5); bn->eval(); - auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); + auto input = + torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); auto output = bn->forward(input); - auto expected = torch::tensor({{{{ 0.0000, 1.0000}, - { 2.0000, 3.0000}}, - {{ 4.0000, 5.0000}, - { 6.0000, 7.0000}}, - {{ 8.0000, 9.0000}, - {10.0000, 10.9999}}, - {{11.9999, 12.9999}, - {13.9999, 14.9999}}, - {{15.9999, 16.9999}, - {17.9999, 18.9999}}}, - {{{19.9999, 20.9999}, - {21.9999, 22.9999}}, - {{23.9999, 24.9999}, - {25.9999, 26.9999}}, - {{27.9999, 28.9999}, - {29.9998, 30.9998}}, - {{31.9998, 32.9998}, - {33.9998, 34.9998}}, - {{35.9998, 36.9998}, - {37.9998, 38.9998}}}}); + auto expected = torch::tensor( + {{{{0.0000, 1.0000}, {2.0000, 3.0000}}, + {{4.0000, 5.0000}, {6.0000, 7.0000}}, + {{8.0000, 9.0000}, {10.0000, 10.9999}}, + {{11.9999, 12.9999}, {13.9999, 14.9999}}, + {{15.9999, 16.9999}, {17.9999, 18.9999}}}, + {{{19.9999, 20.9999}, {21.9999, 22.9999}}, + {{23.9999, 24.9999}, {25.9999, 26.9999}}, + {{27.9999, 28.9999}, {29.9998, 30.9998}}, + {{31.9998, 32.9998}, {33.9998, 34.9998}}, + {{35.9998, 36.9998}, {37.9998, 38.9998}}}}); ASSERT_TRUE(output.allclose(expected)); auto s = output.sum(); s.backward(); @@ -1655,7 +1624,8 @@ TEST_F(ModulesTest, BatchNorm3dStateful) { } TEST_F(ModulesTest, BatchNorm3dStateless) { - BatchNorm3d bn(BatchNorm3dOptions(5).track_running_stats(false).affine(false)); + BatchNorm3d bn( + BatchNorm3dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(bn->running_mean.defined()); ASSERT_FALSE(bn->running_var.defined()); @@ -1668,48 +1638,30 @@ TEST_F(ModulesTest, BatchNorm3d) { BatchNorm3d bn(5); bn->eval(); - auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); + auto input = + torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); auto output = bn->forward(input); - auto expected = torch::tensor({{{{{ 0.0000, 1.0000}, - { 2.0000, 3.0000}}, - {{ 4.0000, 5.0000}, - { 6.0000, 7.0000}}}, - {{{ 8.0000, 9.0000}, - {10.0000, 10.9999}}, - {{11.9999, 12.9999}, - {13.9999, 14.9999}}}, - {{{15.9999, 16.9999}, - {17.9999, 18.9999}}, - {{19.9999, 20.9999}, - {21.9999, 22.9999}}}, - {{{23.9999, 24.9999}, - {25.9999, 26.9999}}, - {{27.9999, 28.9999}, - {29.9998, 30.9998}}}, - {{{31.9998, 32.9998}, - {33.9998, 34.9998}}, - {{35.9998, 36.9998}, - {37.9998, 38.9998}}}}, - {{{{39.9998, 40.9998}, - {41.9998, 42.9998}}, - {{43.9998, 44.9998}, - {45.9998, 46.9998}}}, - {{{47.9998, 48.9998}, - {49.9997, 50.9997}}, - {{51.9997, 52.9997}, - {53.9997, 54.9997}}}, - {{{55.9997, 56.9997}, - {57.9997, 58.9997}}, - {{59.9997, 60.9997}, - {61.9997, 62.9997}}}, - {{{63.9997, 64.9997}, - {65.9997, 66.9997}}, - {{67.9997, 68.9997}, - {69.9996, 70.9996}}}, - {{{71.9996, 72.9996}, - {73.9996, 74.9996}}, - {{75.9996, 76.9996}, - {77.9996, 78.9996}}}}}); + auto expected = torch::tensor( + {{{{{0.0000, 1.0000}, {2.0000, 3.0000}}, + {{4.0000, 5.0000}, {6.0000, 7.0000}}}, + {{{8.0000, 9.0000}, {10.0000, 10.9999}}, + {{11.9999, 12.9999}, {13.9999, 14.9999}}}, + {{{15.9999, 16.9999}, {17.9999, 18.9999}}, + {{19.9999, 20.9999}, {21.9999, 22.9999}}}, + {{{23.9999, 24.9999}, {25.9999, 26.9999}}, + {{27.9999, 28.9999}, {29.9998, 30.9998}}}, + {{{31.9998, 32.9998}, {33.9998, 34.9998}}, + {{35.9998, 36.9998}, {37.9998, 38.9998}}}}, + {{{{39.9998, 40.9998}, {41.9998, 42.9998}}, + {{43.9998, 44.9998}, {45.9998, 46.9998}}}, + {{{47.9998, 48.9998}, {49.9997, 50.9997}}, + {{51.9997, 52.9997}, {53.9997, 54.9997}}}, + {{{55.9997, 56.9997}, {57.9997, 58.9997}}, + {{59.9997, 60.9997}, {61.9997, 62.9997}}}, + {{{63.9997, 64.9997}, {65.9997, 66.9997}}, + {{67.9997, 68.9997}, {69.9996, 70.9996}}}, + {{{71.9996, 72.9996}, {73.9996, 74.9996}}, + {{75.9996, 76.9996}, {77.9996, 78.9996}}}}}); ASSERT_TRUE(output.allclose(expected)); auto s = output.sum(); s.backward(); @@ -1718,7 +1670,8 @@ TEST_F(ModulesTest, BatchNorm3d) { } TEST_F(ModulesTest, InstanceNorm1dStateful) { - InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(true).affine(true)); + InstanceNorm1d instance_norm( + InstanceNorm1dOptions(5).track_running_stats(true).affine(true)); ASSERT_TRUE(instance_norm->options.track_running_stats()); @@ -1745,7 +1698,8 @@ TEST_F(ModulesTest, InstanceNorm1dStateful) { } TEST_F(ModulesTest, InstanceNorm1dStateless) { - InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(false).affine(false)); + InstanceNorm1d instance_norm( + InstanceNorm1dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(instance_norm->running_mean.defined()); ASSERT_FALSE(instance_norm->running_var.defined()); @@ -1760,16 +1714,17 @@ TEST_F(ModulesTest, InstanceNorm1d) { auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_(); auto output = instance_norm->forward(input); - auto expected = torch::tensor({{{-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}}, - {{-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}, - {-1.0000, 1.0000}}}); + auto expected = torch::tensor( + {{{-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}}, + {{-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}, + {-1.0000, 1.0000}}}); ASSERT_TRUE(output.allclose(expected, 1e-3)); auto s = output.sum(); s.backward(); @@ -1778,7 +1733,8 @@ TEST_F(ModulesTest, InstanceNorm1d) { } TEST_F(ModulesTest, InstanceNorm2dStateful) { - InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(true).affine(true)); + InstanceNorm2d instance_norm( + InstanceNorm2dOptions(5).track_running_stats(true).affine(true)); ASSERT_TRUE(instance_norm->options.track_running_stats()); @@ -1805,7 +1761,8 @@ TEST_F(ModulesTest, InstanceNorm2dStateful) { } TEST_F(ModulesTest, InstanceNorm2dStateless) { - InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(false).affine(false)); + InstanceNorm2d instance_norm( + InstanceNorm2dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(instance_norm->running_mean.defined()); ASSERT_FALSE(instance_norm->running_var.defined()); @@ -1818,28 +1775,20 @@ TEST_F(ModulesTest, InstanceNorm2d) { InstanceNorm2d instance_norm(5); instance_norm->eval(); - auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); + auto input = + torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); auto output = instance_norm->forward(input); - auto expected = torch::tensor({{{{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}}, - {{{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}, - {{-1.3416, -0.4472}, - { 0.4472, 1.3416}}}}); + auto expected = torch::tensor( + {{{{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}}, + {{{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}, + {{-1.3416, -0.4472}, {0.4472, 1.3416}}}}); ASSERT_TRUE(output.allclose(expected, 1e-3)); auto s = output.sum(); s.backward(); @@ -1848,7 +1797,8 @@ TEST_F(ModulesTest, InstanceNorm2d) { } TEST_F(ModulesTest, InstanceNorm3dStateful) { - InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(true).affine(true)); + InstanceNorm3d instance_norm( + InstanceNorm3dOptions(5).track_running_stats(true).affine(true)); ASSERT_TRUE(instance_norm->options.track_running_stats()); @@ -1875,7 +1825,8 @@ TEST_F(ModulesTest, InstanceNorm3dStateful) { } TEST_F(ModulesTest, InstanceNorm3dStateless) { - InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(false).affine(false)); + InstanceNorm3d instance_norm( + InstanceNorm3dOptions(5).track_running_stats(false).affine(false)); ASSERT_FALSE(instance_norm->running_mean.defined()); ASSERT_FALSE(instance_norm->running_var.defined()); @@ -1888,48 +1839,30 @@ TEST_F(ModulesTest, InstanceNorm3d) { InstanceNorm3d instance_norm(5); instance_norm->eval(); - auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); + auto input = + torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); auto output = instance_norm->forward(input); - auto expected = torch::tensor({{{{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}}, - {{{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}, - {{{-1.5275, -1.0911}, - {-0.6547, -0.2182}}, - {{ 0.2182, 0.6547}, - { 1.0911, 1.5275}}}}}); + auto expected = torch::tensor( + {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}}, + {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}, + {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, + {{0.2182, 0.6547}, {1.0911, 1.5275}}}}}); ASSERT_TRUE(output.allclose(expected, 1e-3)); auto s = output.sum(); s.backward(); @@ -1973,8 +1906,8 @@ TEST_F(ModulesTest, Linear2_CUDA) { TEST_F(ModulesTest, L1Loss) { L1Loss loss; - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = loss->forward(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -1985,8 +1918,8 @@ TEST_F(ModulesTest, L1Loss) { TEST_F(ModulesTest, MSELoss) { MSELoss loss; - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = loss->forward(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -1997,8 +1930,8 @@ TEST_F(ModulesTest, MSELoss) { TEST_F(ModulesTest, BCELoss) { BCELoss loss; - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = loss->forward(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -2009,8 +1942,8 @@ TEST_F(ModulesTest, BCELoss) { TEST_F(ModulesTest, KLDivLoss) { KLDivLoss loss; - auto input = torch::randn({5,6}, torch::requires_grad()); - auto target = torch::empty({5,6}).random_(2); + auto input = torch::randn({5, 6}, torch::requires_grad()); + auto target = torch::empty({5, 6}).random_(2); auto output = loss->forward(torch::sigmoid(input), target); auto s = output.sum(); s.backward(); @@ -2021,7 +1954,9 @@ TEST_F(ModulesTest, KLDivLoss) { TEST_F(ModulesTest, HingeEmbeddingLoss) { HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2)); - auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{2, 22, 4}, {20, 10, 0}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat); auto output = loss->forward(input, target); auto expected = torch::tensor({10}, torch::kFloat); @@ -2035,7 +1970,9 @@ TEST_F(ModulesTest, HingeEmbeddingLoss) { TEST_F(ModulesTest, MultiMarginLoss) { auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight)); - auto input = torch::tensor({{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({2, 1, 0}, torch::kLong); auto output = loss->forward(input, target); auto expected = torch::tensor({0.305556}, torch::kFloat); @@ -2048,8 +1985,10 @@ TEST_F(ModulesTest, MultiMarginLoss) { TEST_F(ModulesTest, CosineEmbeddingLoss) { CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5)); - auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto input2 = torch::tensor({{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input1 = torch::tensor( + {{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input2 = torch::tensor( + {{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({1, -1}); auto output = cos(input1, input2, target); auto expected = torch::tensor({0.1004}, torch::kFloat); @@ -2063,7 +2002,8 @@ TEST_F(ModulesTest, CosineEmbeddingLoss) { TEST_F(ModulesTest, SmoothL1LossDefaultOptions) { SmoothL1Loss loss; - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor(0.0233335, torch::kFloat); @@ -2076,7 +2016,8 @@ TEST_F(ModulesTest, SmoothL1LossDefaultOptions) { TEST_F(ModulesTest, HuberLossDefaultOptions) { HuberLoss loss; - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor(0.0233335, torch::kFloat); @@ -2089,7 +2030,8 @@ TEST_F(ModulesTest, HuberLossDefaultOptions) { TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) { MultiLabelMarginLoss loss; - auto input = torch::tensor({{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); auto output = loss->forward(input, target); auto expected = torch::tensor({0.8500}, torch::kFloat); @@ -2102,7 +2044,8 @@ TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) { TEST_F(ModulesTest, SmoothL1LossNoReduction) { SmoothL1Loss loss(/*reduction=*/torch::kNone); - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat); @@ -2115,7 +2058,8 @@ TEST_F(ModulesTest, SmoothL1LossNoReduction) { TEST_F(ModulesTest, HuberLossNoReduction) { HuberLoss loss(/*reduction=*/torch::kNone); - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat); @@ -2128,7 +2072,8 @@ TEST_F(ModulesTest, HuberLossNoReduction) { TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) { MultiLabelMarginLoss loss(torch::kNone); - auto input = torch::tensor({{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); auto output = loss->forward(input, target); auto expected = torch::tensor({0.8500}, torch::kFloat); @@ -2142,7 +2087,8 @@ TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) { TEST_F(ModulesTest, SmoothL1LossBeta) { auto options = SmoothL1LossOptions().beta(0.2); SmoothL1Loss loss(options); - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor(0.108333, torch::kFloat); @@ -2156,7 +2102,8 @@ TEST_F(ModulesTest, SmoothL1LossBeta) { TEST_F(ModulesTest, HuberLossDelta) { auto options = HuberLossOptions().delta(0.2); HuberLoss loss(options); - auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0., 1., 5.}, torch::kFloat); auto output = loss(input, target); auto expected = torch::tensor(0.0216666, torch::kFloat); @@ -2169,9 +2116,12 @@ TEST_F(ModulesTest, HuberLossDelta) { TEST_F(ModulesTest, TripletMarginLoss) { TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0)); - auto anchor = torch::tensor({{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto positive = torch::tensor({{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto negative = torch::tensor({{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto anchor = torch::tensor( + {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = torch::tensor( + {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = torch::tensor( + {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = loss->forward(anchor, positive, negative); auto expected = torch::tensor({0.}, torch::kFloat); auto s = output.sum(); @@ -2186,30 +2136,28 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) { // TripletMarginLoss options as our distance function, the outputs // are equal (i.e., equal under defaults). - std::vector - reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector reductions = { + torch::kSum, torch::kMean, torch::kNone}; std::vector margins = {0.5, 1.0, 1.5}; std::vector swaps = {true, false}; for (auto& reduction : reductions) { for (auto& margin : margins) { for (const auto swap : swaps) { - auto anchor = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); - auto positive = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); - auto negative = - torch::randn({100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); - - auto basicOptions = TripletMarginLossOptions() - .reduction(reduction) - .margin(margin) - .swap(swap); - auto distanceOptions = - TripletMarginWithDistanceLossOptions() - .reduction(reduction) - .margin(margin) - .swap(swap); + auto anchor = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto positive = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + auto negative = torch::randn( + {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); + + auto basicOptions = + TripletMarginLossOptions().reduction(reduction).margin(margin).swap( + swap); + auto distanceOptions = TripletMarginWithDistanceLossOptions() + .reduction(reduction) + .margin(margin) + .swap(swap); TripletMarginLoss basicLoss(basicOptions); TripletMarginWithDistanceLoss distanceLoss(distanceOptions); @@ -2219,8 +2167,10 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) { auto distanceOperatorOutput = distanceLoss(anchor, positive, negative); ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6)); - ASSERT_TRUE(distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6)); - ASSERT_TRUE(distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6)); + ASSERT_TRUE( + distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6)); + ASSERT_TRUE( + distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6)); // handle for torch::kNone reduction auto sum = distanceOutput.sum(); @@ -2239,15 +2189,14 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) { return torch::pairwise_distance(x, y); }; - auto cosine_distance = [&](const torch::Tensor& x, - const torch::Tensor& y) { + auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) { return 1.0 - torch::cosine_similarity(x, y); }; std::vector distance_functions = {pairwise_distance, cosine_distance}; - std::vector - reductions = {torch::kSum, torch::kMean, torch::kNone}; + std::vector reductions = { + torch::kSum, torch::kMean, torch::kNone}; std::vector margins = {0.5, 1.0, 1.5}; std::vector swaps = {true, false}; @@ -2255,12 +2204,11 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { for (auto& reduction : reductions) { for (auto& margin : margins) { for (const auto swap : swaps) { - auto moduleOptions = - TripletMarginWithDistanceLossOptions() - .distance_function(function) - .reduction(reduction) - .margin(margin) - .swap(swap); + auto moduleOptions = TripletMarginWithDistanceLossOptions() + .distance_function(function) + .reduction(reduction) + .margin(margin) + .swap(swap); auto functionOptions = torch::nn::functional::TripletMarginWithDistanceLossFuncOptions() .distance_function(function) @@ -2279,11 +2227,13 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { auto moduleOutput = distanceLoss->forward(anchor, positive, negative); auto moduleOperatorOutput = distanceLoss(anchor, positive, negative); - auto functionOutput = torch::nn::functional::triplet_margin_with_distance_loss( - anchor, positive, negative, functionOptions); + auto functionOutput = + torch::nn::functional::triplet_margin_with_distance_loss( + anchor, positive, negative, functionOptions); ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6)); - ASSERT_TRUE(moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6)); + ASSERT_TRUE( + moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6)); } } } @@ -2292,10 +2242,11 @@ TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { TEST_F(ModulesTest, NLLLoss) { NLLLoss loss; - auto input = torch::tensor({{-0.1315, -3.1315, -2.5315}, - {-3.7038, -0.1038, -2.6038}, - {-2.3422, -1.3422, -0.4422}}, - torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{-0.1315, -3.1315, -2.5315}, + {-3.7038, -0.1038, -2.6038}, + {-2.3422, -1.3422, -0.4422}}, + torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({1, 0, 2}, torch::kLong); auto output = loss->forward(input, target); auto expected = torch::tensor(2.4258, torch::kFloat); @@ -2304,13 +2255,15 @@ TEST_F(ModulesTest, NLLLoss) { ASSERT_TRUE(output.allclose(expected, 1e-04)); ASSERT_TRUE( - NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)) - ->forward(input, target).allclose(expected, 1e-04)); + NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)) + ->forward(input, target) + .allclose(expected, 1e-04)); } TEST_F(ModulesTest, CrossEntropyLoss) { CrossEntropyLoss loss; - auto input = torch::tensor({{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({0, 1}, torch::kLong); auto output = loss->forward(input, target); auto expected = torch::tensor(0.6931, torch::kFloat); @@ -2320,12 +2273,16 @@ TEST_F(ModulesTest, CrossEntropyLoss) { ASSERT_TRUE(output.allclose(expected, 1e-04)); ASSERT_EQ(input.sizes(), input.grad().sizes()); ASSERT_TRUE( - CrossEntropyLoss(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)) - ->forward(input, target).allclose(expected, 1e-04)); + CrossEntropyLoss( + CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)) + ->forward(input, target) + .allclose(expected, 1e-04)); // label smoothing with class indices - loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean)); - input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + loss = CrossEntropyLoss( + CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean)); + input = torch::tensor( + {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); target = torch::tensor({0, 1}, torch::kLong); output = loss->forward(input, target); expected = torch::tensor(0.3326, torch::kFloat); @@ -2336,8 +2293,10 @@ TEST_F(ModulesTest, CrossEntropyLoss) { ASSERT_EQ(input.sizes(), input.grad().sizes()); // label smoothing with with target probabilities - loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean)); - input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + loss = CrossEntropyLoss( + CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean)); + input = torch::tensor( + {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat); output = loss->forward(input, target); expected = torch::tensor(0.5701, torch::kFloat); @@ -2346,13 +2305,14 @@ TEST_F(ModulesTest, CrossEntropyLoss) { ASSERT_TRUE(output.allclose(expected, 1e-04)); ASSERT_EQ(input.sizes(), input.grad().sizes()); - } TEST_F(ModulesTest, CosineSimilarity) { CosineSimilarity cos(CosineSimilarityOptions().dim(1)); - auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input1 = torch::tensor( + {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input2 = torch::tensor( + {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = cos->forward(input1, input2); auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat); auto s = output.sum(); @@ -2364,7 +2324,8 @@ TEST_F(ModulesTest, CosineSimilarity) { TEST_F(ModulesTest, SoftMarginLossDefaultOptions) { SoftMarginLoss loss; - auto input = torch::tensor({2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); auto output = loss->forward(input, target); auto expected = torch::tensor({1.3767317}, torch::kFloat); @@ -2377,8 +2338,11 @@ TEST_F(ModulesTest, SoftMarginLossDefaultOptions) { TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) { MultiLabelSoftMarginLoss loss; - auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); + auto input = torch::tensor( + {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, + torch::dtype(torch::kFloat).requires_grad(true)); + auto target = + torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); auto output = loss->forward(input, target); auto expected = torch::tensor({0.7608436}, torch::kFloat); auto s = output.sum(); @@ -2390,10 +2354,12 @@ TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) { TEST_F(ModulesTest, SoftMarginLossNoReduction) { SoftMarginLoss loss(torch::kNone); - auto input = torch::tensor({2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input = torch::tensor( + {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); auto output = loss->forward(input, target); - auto expected = torch::tensor({2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat); + auto expected = torch::tensor( + {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat); auto s = output.sum(); s.backward(); @@ -2402,10 +2368,14 @@ TEST_F(ModulesTest, SoftMarginLossNoReduction) { } TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) { - auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); + auto input = torch::tensor( + {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, + torch::dtype(torch::kFloat).requires_grad(true)); + auto target = + torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat); - auto options = MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight); + auto options = + MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight); MultiLabelSoftMarginLoss loss = MultiLabelSoftMarginLoss(options); auto output = loss->forward(input, target); auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat); @@ -2418,8 +2388,10 @@ TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) { TEST_F(ModulesTest, PairwiseDistance) { PairwiseDistance dist(PairwiseDistanceOptions().p(1)); - auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); - auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input1 = torch::tensor( + {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto input2 = torch::tensor( + {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); auto output = dist->forward(input1, input2); auto expected = torch::tensor({6, 6}, torch::kFloat); auto s = output.sum(); @@ -2433,7 +2405,7 @@ TEST_F(ModulesTest, ELU) { const auto size = 3; for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) { for (const auto inplace : {false, true}) { - ELU model {ELUOptions().alpha(alpha).inplace(inplace)}; + ELU model{ELUOptions().alpha(alpha).inplace(inplace)}; auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}); if (!inplace) { @@ -2448,7 +2420,8 @@ TEST_F(ModulesTest, ELU) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) + - torch::min(torch::zeros_like(x_orig), alpha * (torch::exp(x_orig) - 1.0)); + torch::min(torch::zeros_like(x_orig), + alpha * (torch::exp(x_orig) - 1.0)); ASSERT_TRUE(torch::allclose(y, y_exp)); if (inplace) { ASSERT_TRUE(torch::allclose(x, y_exp)); @@ -2489,7 +2462,7 @@ TEST_F(ModulesTest, SELU) { TEST_F(ModulesTest, Hardshrink) { const auto size = 3; for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { - Hardshrink model {HardshrinkOptions().lambda(lambda)}; + Hardshrink model{HardshrinkOptions().lambda(lambda)}; auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}).set_requires_grad(true); auto y = model(x); @@ -2509,7 +2482,9 @@ TEST_F(ModulesTest, Hardtanh) { for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) { for (const auto max_val : {0.42, 1.0, 4.2}) { for (const auto inplace : {false, true}) { - Hardtanh model {HardtanhOptions().min_val(min_val).max_val(max_val).inplace(inplace)}; + Hardtanh model{ + HardtanhOptions().min_val(min_val).max_val(max_val).inplace( + inplace)}; auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}); if (!inplace) { @@ -2523,8 +2498,8 @@ TEST_F(ModulesTest, Hardtanh) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); auto y_exp = (x_orig < min_val) * min_val + - ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig + - (x_orig > max_val) * max_val; + ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig + + (x_orig > max_val) * max_val; ASSERT_TRUE(torch::allclose(y, y_exp)); if (inplace) { ASSERT_TRUE(torch::allclose(x, y_exp)); @@ -2537,12 +2512,14 @@ TEST_F(ModulesTest, Hardtanh) { } TEST_F(ModulesTest, HardtanhMinValGEMaxVal) { - ASSERT_THROWS_WITH(Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)}, - "max_val must be greater than min_val"); - ASSERT_THROWS_WITH(Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)}, - "max_val must be greater than min_val"); + ASSERT_THROWS_WITH( + Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)}, + "max_val must be greater than min_val"); + ASSERT_THROWS_WITH( + Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)}, + "max_val must be greater than min_val"); - Hardtanh ht {HardtanhOptions().min_val(-0.42).max_val(0.42)}; + Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)}; ht->options.min_val(0.42); ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val"); ht->options.max_val(-0.42); @@ -2554,7 +2531,8 @@ TEST_F(ModulesTest, LeakyReLU) { for (const auto inplace : {false, true}) { for (const auto negative_slope : {0.0, 0.42, 1.0}) { for (const auto type : {torch::kFloat, torch::kBFloat16}) { - LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)}; + LeakyReLU model{ + LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)}; auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); x.resize_({size, size, size}); if (!inplace) { @@ -2567,7 +2545,8 @@ TEST_F(ModulesTest, LeakyReLU) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig; + auto y_exp = + (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig; ASSERT_TRUE(torch::allclose(y, y_exp)); if (inplace) { ASSERT_TRUE(torch::allclose(x, y_exp)); @@ -2592,7 +2571,8 @@ TEST_F(ModulesTest, LogSigmoid) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x)))); + auto y_exp = torch::log( + torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x)))); ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); } @@ -2635,14 +2615,19 @@ TEST_F(ModulesTest, LogSoftmax) { TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) { { // log_probs actually returns log_proba - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); auto x = torch::randn({4, 8}); auto logprob_out = asfm->log_prob(x); - ASSERT_TRUE(torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4))); + ASSERT_TRUE( + torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4))); } { // test predict - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}) + .div_value(2.) + .head_bias(true)); auto x = torch::randn({64, 8}); auto logprob_out = asfm->log_prob(x); auto predict_out = asfm->predict(x); @@ -2650,7 +2635,8 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) { } { // cluster sizes - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16}); auto y = torch::tensor({0, 17}, torch::kLong); auto asm_out = asfm(x, y); @@ -2658,7 +2644,8 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) { } { // forward returns the same thing as log_probs - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); auto x = torch::randn({4, 8}); auto logprob_out = asfm->log_prob(x); NLLLoss nll_loss; @@ -2671,17 +2658,20 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) { auto expected = nll_loss->forward(logprob_out, y); ASSERT_TRUE(torch::allclose(loss, expected)); - ASSERT_TRUE(torch::allclose(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze())); + ASSERT_TRUE(torch::allclose( + out, logprob_out.gather(1, y.unsqueeze(1)).squeeze())); } } { // test no batch dim - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); auto x = torch::randn({1, 16}); auto y = torch::tensor({17}); auto x2 = x.squeeze(0); auto y2 = y.squeeze(0); - ASSERT_TRUE(torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)); + ASSERT_TRUE( + torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)); } } @@ -2707,11 +2697,11 @@ TEST_F(ModulesTest, PReLU) { const auto num_parameters = 42; const auto init = 0.42; - PReLU model {PReLUOptions().num_parameters(num_parameters).init(init)}; + PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)}; ASSERT_EQ(model->weight.sizes(), std::vector({num_parameters})); - ASSERT_TRUE(torch::allclose(model->weight, - torch::full(num_parameters, init))); + ASSERT_TRUE( + torch::allclose(model->weight, torch::full(num_parameters, init))); const auto x = torch::rand({100, num_parameters}) * 200 - 100; const auto y = model(x); @@ -2722,7 +2712,7 @@ TEST_F(ModulesTest, PReLU) { ASSERT_EQ(y.ndimension(), x.ndimension()); ASSERT_EQ(y.sizes(), x.sizes()); - const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x; + const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x; ASSERT_TRUE(torch::allclose(y, y_exp)); } @@ -2768,7 +2758,8 @@ TEST_F(ModulesTest, ReLU6) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig + (x_orig > 6) * 6; + auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig + + (x_orig > 6) * 6; ASSERT_TRUE(torch::allclose(y, y_exp)); if (inplace) { ASSERT_TRUE(torch::allclose(x, y_exp)); @@ -2784,7 +2775,8 @@ TEST_F(ModulesTest, RReLU) { for (const auto upper : {0.3, 0.4, 0.5}) { for (const auto inplace : {false, true}) { for (const auto type : {torch::kFloat, torch::kBFloat16}) { - RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)}; + RReLU model{ + RReLUOptions().lower(lower).upper(upper).inplace(inplace)}; auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); x.resize_({size, size, size}); if (!inplace) { @@ -2797,8 +2789,10 @@ TEST_F(ModulesTest, RReLU) { ASSERT_EQ(s.ndimension(), 0); ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto z = ((x_orig >= 0) * (x_orig == y) + - (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0; + auto z = + ((x_orig >= 0) * (x_orig == y) + + (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * + 1.0; ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); if (inplace) { ASSERT_TRUE(torch::allclose(x, y)); @@ -2815,7 +2809,7 @@ TEST_F(ModulesTest, CELU) { const auto size = 3; for (const auto inplace : {false, true}) { for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) { - CELU model {CELUOptions().alpha(alpha).inplace(inplace)}; + CELU model{CELUOptions().alpha(alpha).inplace(inplace)}; auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}); if (!inplace) { @@ -2829,7 +2823,8 @@ TEST_F(ModulesTest, CELU) { ASSERT_EQ(y.ndimension(), 3); ASSERT_EQ(y.sizes(), std::vector({size, size, size})); auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) + - torch::min(torch::zeros_like(x_orig), alpha * (torch::exp(x_orig / alpha) - 1.0)); + torch::min(torch::zeros_like(x_orig), + alpha * (torch::exp(x_orig / alpha) - 1.0)); ASSERT_TRUE(torch::allclose(y, y_exp)); if (inplace) { ASSERT_TRUE(torch::allclose(x, y_exp)); @@ -2898,15 +2893,14 @@ TEST_F(ModulesTest, Sigmoid) { TEST_F(ModulesTest, PixelShuffle) { PixelShuffle module(/*upscale_factor=*/2); auto x = torch::tensor( - {{{{-17, 19}, {-1, 2}}, - {{7, 14}, {-3, 1}}, - {{0, -2}, {-12, 14}}, - {{-15, 0}, {-3, 9}}}}, torch::kFloat); + {{{{-17, 19}, {-1, 2}}, + {{7, 14}, {-3, 1}}, + {{0, -2}, {-12, 14}}, + {{-15, 0}, {-3, 9}}}}, + torch::kFloat); auto y_exp = torch::tensor( - {{{{-17, 7, 19, 14}, - {0, -15, -2, 0}, - {-1, -3, 2, 1}, - {-12, -3, 14, 9}}}}, torch::kFloat); + {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, + torch::kFloat); auto y = module(x); ASSERT_EQ(y.ndimension(), 4); @@ -2936,12 +2930,12 @@ TEST_F(ModulesTest, Softplus) { const auto size = 3; for (const auto beta : {0.5, 1.0, 2.0}) { for (const auto threshold : {1.0, 3.0, 5.0}) { - Softplus model {SoftplusOptions().beta(beta).threshold(threshold)}; + Softplus model{SoftplusOptions().beta(beta).threshold(threshold)}; auto x = torch::linspace(-3.0, 3.0, 61); x.resize_({size, size, size}); auto y_exp = - (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + - (x > threshold) * x; + (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + + (x > threshold) * x; auto y = model(x); ASSERT_EQ(y.ndimension(), 3); @@ -2954,7 +2948,7 @@ TEST_F(ModulesTest, Softplus) { TEST_F(ModulesTest, Softshrink) { const auto size = 3; for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) { - Softshrink model {/*lambda=*/lambda}; + Softshrink model{/*lambda=*/lambda}; auto x = torch::linspace(-10.0, 10.0, size * size * size); x.resize_({size, size, size}).set_requires_grad(true); auto y = model(x); @@ -3002,11 +2996,12 @@ TEST_F(ModulesTest, Threshold) { for (const auto threshold : {0.5, 1.0, 2.0}) { for (const auto value : {0.5, 1.0, 2.0}) { for (const auto inplace : {false, true}) { - Threshold model {ThresholdOptions(threshold, value).inplace(inplace)}; + Threshold model{ThresholdOptions(threshold, value).inplace(inplace)}; auto x = torch::linspace(-3.0, 3.0, 61); x.resize_({size, size, size}); auto x_orig = x.clone(); - auto y_exp = (x_orig <= threshold) * value + (x_orig > threshold) * x_orig; + auto y_exp = + (x_orig <= threshold) * value + (x_orig > threshold) * x_orig; auto y = model(x); ASSERT_EQ(y.ndimension(), 3); @@ -3088,10 +3083,11 @@ TEST_F(ModulesTest, Upsampling2D) { for (const auto align_corners : {true, false}) { // test float scale factor up & down sampling for (const auto scale_factor : {0.5, 1.5, 2.0}) { - Upsample model(UpsampleOptions() - .scale_factor(std::vector({scale_factor, scale_factor})) - .mode(torch::kBilinear) - .align_corners(align_corners)); + Upsample model( + UpsampleOptions() + .scale_factor(std::vector({scale_factor, scale_factor})) + .mode(torch::kBilinear) + .align_corners(align_corners)); auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad()); auto output = model->forward(input); auto expected_size = @@ -3109,10 +3105,11 @@ TEST_F(ModulesTest, Upsampling2D) { for (const auto align_corners : {true, false}) { // test float scale factor up & down sampling for (const auto scale_factor : {0.5, 1.5, 2.0}) { - Upsample model(UpsampleOptions() - .scale_factor(std::vector({scale_factor, scale_factor})) - .mode(torch::kBicubic) - .align_corners(align_corners)); + Upsample model( + UpsampleOptions() + .scale_factor(std::vector({scale_factor, scale_factor})) + .mode(torch::kBicubic) + .align_corners(align_corners)); auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad()); auto output = model->forward(input); auto expected_size = @@ -3146,11 +3143,11 @@ TEST_F(ModulesTest, Upsampling3D) { for (const auto align_corners : {true, false}) { // test float scale factor up & down sampling for (const auto scale_factor : {0.5, 1.5, 2.0}) { - Upsample model( - UpsampleOptions() - .scale_factor(std::vector({scale_factor, scale_factor, scale_factor})) - .mode(torch::kTrilinear) - .align_corners(align_corners)); + Upsample model(UpsampleOptions() + .scale_factor(std::vector( + {scale_factor, scale_factor, scale_factor})) + .mode(torch::kTrilinear) + .align_corners(align_corners)); auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad()); auto output = model->forward(input); auto expected_size = @@ -3168,18 +3165,18 @@ TEST_F(ModulesTest, Upsampling3D) { } TEST_F(ModulesTest, CTCLoss) { - CTCLoss loss {CTCLossOptions().reduction(torch::kNone)}; + CTCLoss loss{CTCLossOptions().reduction(torch::kNone)}; const auto target_lengths = torch::tensor({0, 0, 0}); const auto input_lengths = torch::tensor({50, 50, 50}); const auto targets = - torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong); + torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong); const auto log_probs = - torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); + torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); const auto output = - loss->forward(log_probs, targets, input_lengths, target_lengths); + loss->forward(log_probs, targets, input_lengths, target_lengths); ASSERT_TRUE(output.ge(0).all().item()); ASSERT_TRUE(torch::allclose( - -log_probs.sum(0).slice(1, 0, 1).view_as(output), output)); + -log_probs.sum(0).slice(1, 0, 1).view_as(output), output)); } TEST_F(ModulesTest, PoissonNLLLoss) { @@ -3187,25 +3184,19 @@ TEST_F(ModulesTest, PoissonNLLLoss) { const auto target = torch::tensor({1., 2., 3.}); const auto component_wise_loss = torch::exp(input) - target * input; { - PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kNone)}; - ASSERT_TRUE(torch::allclose( - component_wise_loss, - loss->forward(input, target) - )); + PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kNone)}; + ASSERT_TRUE( + torch::allclose(component_wise_loss, loss->forward(input, target))); } { - PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kSum)}; + PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kSum)}; ASSERT_TRUE(torch::allclose( - torch::sum(component_wise_loss), - loss->forward(input, target) - )); + torch::sum(component_wise_loss), loss->forward(input, target))); } { - PoissonNLLLoss loss {PoissonNLLLossOptions().reduction(torch::kMean)}; + PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kMean)}; ASSERT_TRUE(torch::allclose( - torch::mean(component_wise_loss), - loss->forward(input, target) - )); + torch::mean(component_wise_loss), loss->forward(input, target))); } } @@ -3216,529 +3207,547 @@ TEST_F(ModulesTest, MarginRankingLoss) { const auto input2 = torch::randn(15) * 10; const auto target = torch::randn(15).sign(); ASSERT_TRUE(torch::allclose( - loss->forward(input1, input2, target), - (-target * (input1 - input2)).clamp(0).mean() - )); + loss->forward(input1, input2, target), + (-target * (input1 - input2)).clamp(0).mean())); } { - MarginRankingLoss loss {MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)}; + MarginRankingLoss loss{ + MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)}; const auto input1 = torch::randn(15) * 10; const auto input2 = torch::randn(15) * 10; const auto target = torch::randn(15).sign(); const auto margin = 0.5; ASSERT_TRUE(torch::allclose( - loss->forward(input1, input2, target), - (-target * (input1 - input2) + margin).clamp(0).sum() - )); + loss->forward(input1, input2, target), + (-target * (input1 - input2) + margin).clamp(0).sum())); } { - MarginRankingLoss loss {MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)}; + MarginRankingLoss loss{ + MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)}; const auto input1 = torch::randn(15) * 10; const auto input2 = torch::randn(15) * 10; const auto target = torch::randn(15).sign(); const auto margin = 0.5; ASSERT_TRUE(torch::allclose( - loss->forward(input1, input2, target), - (-target * (input1 - input2) + margin).clamp(0).mean() - )); + loss->forward(input1, input2, target), + (-target * (input1 - input2) + margin).clamp(0).mean())); } } TEST_F(ModulesTest, BCEWithLogitsLoss) { - { // test BCE with logits raises if target and input are different size - { - const auto target = torch::rand(5); - const auto input = torch::rand({5, 1}); - ASSERT_THROWS_WITH( - BCEWithLogitsLoss()(input, target), - "must be the same as input size" - ); - } - { - const auto target = torch::rand({5, 1}); - const auto input = torch::rand(5); - ASSERT_THROWS_WITH( - BCEWithLogitsLoss()(input, target), - "must be the same as input size" - ); - } - } - { // test BCE with logits gives same result as sigmoid and bce loss - auto sigmoid = Sigmoid(); + {// test BCE with logits raises if target and input are different size + {const auto target = torch::rand(5); + const auto input = torch::rand({5, 1}); + ASSERT_THROWS_WITH( + BCEWithLogitsLoss()(input, target), "must be the same as input size"); +} +{ + const auto target = torch::rand({5, 1}); + const auto input = torch::rand(5); + ASSERT_THROWS_WITH( + BCEWithLogitsLoss()(input, target), "must be the same as input size"); +} +} +{ // test BCE with logits gives same result as sigmoid and bce loss + auto sigmoid = Sigmoid(); - auto target = torch::rand({64, 4}); - auto output = torch::rand({64, 4}) - 0.5; + auto target = torch::rand({64, 4}); + auto output = torch::rand({64, 4}) - 0.5; - ASSERT_TRUE(torch::allclose( - BCEWithLogitsLoss()(output, target), - BCELoss()(sigmoid(output), target) - )); + ASSERT_TRUE(torch::allclose( + BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target))); - auto weight = torch::rand(4); - ASSERT_TRUE(torch::allclose( - BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target), - BCELoss( - BCELossOptions().weight(weight) - )(sigmoid(output), target) - )); + auto weight = torch::rand(4); + ASSERT_TRUE(torch::allclose( + BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target), + BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target))); - target = torch::zeros({4, 1}, torch::kFloat); - output = torch::empty({4, 1}, torch::kFloat).fill_(-100); + target = torch::zeros({4, 1}, torch::kFloat); + output = torch::empty({4, 1}, torch::kFloat).fill_(-100); - ASSERT_TRUE(torch::allclose( - BCEWithLogitsLoss()(output, target), - BCELoss()(sigmoid(output), target) - )); + ASSERT_TRUE(torch::allclose( + BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target))); - ASSERT_TRUE(torch::allclose( - BCEWithLogitsLoss( - BCEWithLogitsLossOptions().reduction(torch::kNone) - )(output, target), - BCELoss( - BCELossOptions().reduction(torch::kNone) - )(sigmoid(output), target) - )); - - weight = torch::rand({1}, torch::kFloat); - ASSERT_TRUE(torch::allclose( - BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target), - BCELoss( - BCELossOptions().weight(weight) - )(sigmoid(output), target) - )); - } - { // test BCE with logits has correct grad at zero - const auto output = torch::zeros({3, 1}, torch::requires_grad()); - const auto target = torch::zeros({3, 1}); - BCEWithLogitsLoss(BCEWithLogitsLossOptions() - .reduction(torch::kSum))(output, target).backward(); - const auto expected_grad = torch::empty({3, 1}).fill_(0.5); - ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); - } - { // test BCE with logits broadcasts weights - const auto target = torch::rand({16, 4}); - const auto output = torch::rand({16, 4}) - 0.5; + ASSERT_TRUE(torch::allclose( + BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))( + output, target), + BCELoss(BCELossOptions().reduction(torch::kNone))( + sigmoid(output), target))); - auto weight = torch::rand(4); - auto out1 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target); + weight = torch::rand({1}, torch::kFloat); + ASSERT_TRUE(torch::allclose( + BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target), + BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target))); +} +{ // test BCE with logits has correct grad at zero + const auto output = torch::zeros({3, 1}, torch::requires_grad()); + const auto target = torch::zeros({3, 1}); + BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))( + output, target) + .backward(); + const auto expected_grad = torch::empty({3, 1}).fill_(0.5); + ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); +} +{ // test BCE with logits broadcasts weights + const auto target = torch::rand({16, 4}); + const auto output = torch::rand({16, 4}) - 0.5; - weight = weight.expand({16, 4}).contiguous(); - auto out2 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target); + auto weight = torch::rand(4); + auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target); - ASSERT_TRUE(torch::allclose(out1, out2)); + weight = weight.expand({16, 4}).contiguous(); + auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target); - weight = torch::rand({16, 1}); - out1 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target); + ASSERT_TRUE(torch::allclose(out1, out2)); - weight = weight.expand({16, 4}).contiguous(); - out2 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().weight(weight) - )(output, target); + weight = torch::rand({16, 1}); + out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target); - ASSERT_TRUE(torch::allclose(out1, out2)); - } - { // test BCE with logits ones in pos weights are the same as none - const auto target = torch::rand({64, 4}); - const auto output = torch::rand({64, 4}) - 0.5; - const auto pos_weight = torch::ones({64, 4}); + weight = weight.expand({16, 4}).contiguous(); + out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( + output, target); - ASSERT_TRUE(torch::allclose( + ASSERT_TRUE(torch::allclose(out1, out2)); +} +{ // test BCE with logits ones in pos weights are the same as none + const auto target = torch::rand({64, 4}); + const auto output = torch::rand({64, 4}) - 0.5; + const auto pos_weight = torch::ones({64, 4}); + + ASSERT_TRUE(torch::allclose( BCEWithLogitsLoss()(output, target), - BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight) - )(output, target) - )); - } - { // test BCE with logits broadcasts pos weights - const auto target = torch::rand({64, 4}); - const auto output = torch::rand({64, 4}) - 0.5; - const auto pos_weight = torch::rand(4); - const auto out1 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight) - )(output, target); - - const auto pos_weight1 = pos_weight.expand({1, 4}); - const auto out2 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight) - )(output, target); - - const auto pos_weight2 = pos_weight.expand({64, 4}); - const auto out3 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight) - )(output, target); - - ASSERT_TRUE(torch::allclose(out1, out2)); - ASSERT_TRUE(torch::allclose(out1, out3)); - } - { // test BCE with logits with pos weight has correct grad at zero - const auto output = torch::zeros({3, 1}, torch::requires_grad()); - const auto target = torch::zeros({3, 1}); - const auto pos_weight = torch::ones({3, 1}); - BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight).reduction(torch::kSum) - )(output, target).backward(); - const auto expected_grad = torch::empty({3, 1}).fill_(0.5); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - const auto grad = output.grad(); - ASSERT_TRUE(torch::allclose(grad, expected_grad)); - } - { // test BCE with logits stability - const auto output = torch::tensor({0., -120.}); - const auto target = torch::tensor({0., 1.}); - const auto pos_weight = torch::tensor({1., 1.}); - - const auto out1 = BCEWithLogitsLoss()(output, target); - ASSERT_TRUE(torch::isfinite(out1).all().item()); - - const auto out2 = BCEWithLogitsLoss( - BCEWithLogitsLossOptions().pos_weight(pos_weight) - )(output, target); - ASSERT_TRUE(torch::isfinite(out2).all().item()); - } + BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))( + output, target))); +} +{ // test BCE with logits broadcasts pos weights + const auto target = torch::rand({64, 4}); + const auto output = torch::rand({64, 4}) - 0.5; + const auto pos_weight = torch::rand(4); + const auto out1 = BCEWithLogitsLoss( + BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); + + const auto pos_weight1 = pos_weight.expand({1, 4}); + const auto out2 = BCEWithLogitsLoss( + BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); + + const auto pos_weight2 = pos_weight.expand({64, 4}); + const auto out3 = BCEWithLogitsLoss( + BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); + + ASSERT_TRUE(torch::allclose(out1, out2)); + ASSERT_TRUE(torch::allclose(out1, out3)); +} +{ // test BCE with logits with pos weight has correct grad at zero + const auto output = torch::zeros({3, 1}, torch::requires_grad()); + const auto target = torch::zeros({3, 1}); + const auto pos_weight = torch::ones({3, 1}); + BCEWithLogitsLoss( + BCEWithLogitsLossOptions().pos_weight(pos_weight).reduction(torch::kSum))( + output, target) + .backward(); + const auto expected_grad = torch::empty({3, 1}).fill_(0.5); + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + const auto grad = output.grad(); + ASSERT_TRUE(torch::allclose(grad, expected_grad)); +} +{ // test BCE with logits stability + const auto output = torch::tensor({0., -120.}); + const auto target = torch::tensor({0., 1.}); + const auto pos_weight = torch::tensor({1., 1.}); + + const auto out1 = BCEWithLogitsLoss()(output, target); + ASSERT_TRUE(torch::isfinite(out1).all().item()); + + const auto out2 = BCEWithLogitsLoss( + BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); + ASSERT_TRUE(torch::isfinite(out2).all().item()); +} } namespace detail { - namespace F = torch::nn::functional; +namespace F = torch::nn::functional; - torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) { - TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0)); - TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1)); - auto retval = torch::zeros({a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32); - for (const auto i : c10::irange(a.size(0))) { - for (const auto j : c10::irange(a.size(1))) { - retval[i][j] = torch::matmul(a[i][j], b[i][j]); - } +torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) { + TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0)); + TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1)); + auto retval = torch::zeros( + {a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32); + for (const auto i : c10::irange(a.size(0))) { + for (const auto j : c10::irange(a.size(1))) { + retval[i][j] = torch::matmul(a[i][j], b[i][j]); } - return retval; } + return retval; +} - torch::Tensor _softmax(const torch::Tensor& x) { - auto output = torch::zeros(x.sizes()); - for (const auto i : c10::irange(x.size(0))) { - for (const auto j : c10::irange(x.size(1))) { - for (const auto k : c10::irange(x.size(2))) { - const auto& x_curr = x[i][j][k]; - const auto e_x = torch::exp(x_curr - torch::max(x_curr)); - output[i][j][k] = e_x / torch::sum(e_x); - } +torch::Tensor _softmax(const torch::Tensor& x) { + auto output = torch::zeros(x.sizes()); + for (const auto i : c10::irange(x.size(0))) { + for (const auto j : c10::irange(x.size(1))) { + for (const auto k : c10::irange(x.size(2))) { + const auto& x_curr = x[i][j][k]; + const auto e_x = torch::exp(x_curr - torch::max(x_curr)); + output[i][j][k] = e_x / torch::sum(e_x); } } - return output; } + return output; +} - std::tuple _scaled_dot_attn_ref( - const torch::Tensor& Q, const torch::Tensor& K, const torch::Tensor& V, - at::IntArrayRef dims, const torch::Tensor& unseen_mask = {}, +std::tuple _scaled_dot_attn_ref( + const torch::Tensor& Q, + const torch::Tensor& K, + const torch::Tensor& V, + at::IntArrayRef dims, + const torch::Tensor& unseen_mask = {}, const torch::Tensor& key_padding_mask = {}, bool average_attn_weights = true) { - auto QKT = _batchmatmul( - Q, - K.permute({0, 1, 3, 2}) / std::sqrt(dims[3]) - ); - const auto b1 = QKT.size(0); - const auto b2 = QKT.size(1); - const auto s1 = QKT.size(2); - const auto s2 = QKT.size(3); - if (unseen_mask.defined() || key_padding_mask.defined()) { - for (const auto i : c10::irange(b1)) { - for (const auto j : c10::irange(b2)) { - for (const auto m : c10::irange(s1)) { - for (const auto n : c10::irange(s2)) { - if (unseen_mask.defined() && unseen_mask[m][n].item() == 0) { - QKT[i][j][m][n] = -std::numeric_limits::infinity(); - } - if (key_padding_mask.defined() && key_padding_mask[i][n].item() != 0) { - QKT[i][j][m][n] = -std::numeric_limits::infinity(); - } + auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3])); + const auto b1 = QKT.size(0); + const auto b2 = QKT.size(1); + const auto s1 = QKT.size(2); + const auto s2 = QKT.size(3); + if (unseen_mask.defined() || key_padding_mask.defined()) { + for (const auto i : c10::irange(b1)) { + for (const auto j : c10::irange(b2)) { + for (const auto m : c10::irange(s1)) { + for (const auto n : c10::irange(s2)) { + if (unseen_mask.defined() && + unseen_mask[m][n].item() == 0) { + QKT[i][j][m][n] = -std::numeric_limits::infinity(); + } + if (key_padding_mask.defined() && + key_padding_mask[i][n].item() != 0) { + QKT[i][j][m][n] = -std::numeric_limits::infinity(); } } } } } - auto reference = _softmax(QKT); - auto ref_attn_weight = reference; - if (average_attn_weights) { - // NOLINTNEXTLINE(bugprone-argument-comment) - ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2; - } - reference = _batchmatmul(reference, V); - return std::tie(reference, ref_attn_weight); - } - - torch::Tensor _split_heads_ref(const torch::Tensor& X, at::IntArrayRef dims, int nheads, int d_head) { - auto X_split = X.reshape({dims[0], dims[1], nheads, d_head}); - auto X_split_transposed = X_split.permute({0, 2, 1, 3}); - return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head}); } - - torch::Tensor _combine_heads_ref(const torch::Tensor& X, at::IntArrayRef dims, int nheads, int d_head) { - auto X_transposed = X.permute({0, 2, 1, 3}); - auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head}); - return reference; + auto reference = _softmax(QKT); + auto ref_attn_weight = reference; + if (average_attn_weights) { + // NOLINTNEXTLINE(bugprone-argument-comment) + ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2; } - - torch::Tensor _fc(torch::Tensor X, torch::Tensor X_weight, torch::Tensor X_bias) { - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - auto X_fc_b = X_bias; + reference = _batchmatmul(reference, V); + return std::tie(reference, ref_attn_weight); +} + +torch::Tensor _split_heads_ref( + const torch::Tensor& X, + at::IntArrayRef dims, + int nheads, + int d_head) { + auto X_split = X.reshape({dims[0], dims[1], nheads, d_head}); + auto X_split_transposed = X_split.permute({0, 2, 1, 3}); + return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head}); +} + +torch::Tensor _combine_heads_ref( + const torch::Tensor& X, + at::IntArrayRef dims, + int nheads, + int d_head) { + auto X_transposed = X.permute({0, 2, 1, 3}); + auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head}); + return reference; +} + +torch::Tensor _fc( + torch::Tensor X, + torch::Tensor X_weight, + torch::Tensor X_bias) { + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto X_fc_b = X_bias; + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + auto X_fc_w = X_weight; + return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b; +} + +void _multihead_attn_test_helper( + bool add_key_padding_mask = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + bool saved_kv = false, + bool same_embed_dim = false, + bool average_attn_weights = true) { + std::random_device device; + std::mt19937 generator(device()); + std::uniform_int_distribution d_2_10(2, 10); + std::uniform_int_distribution d_3_10(3, 10); + bool registration_checked = false; + for (const auto i : c10::irange(100)) { + (void)i; // Suppress unused variable warning + const auto batch_sz = d_2_10(generator); + const auto seq_len = d_2_10(generator); + const auto d_head = d_3_10(generator); + const auto nheads = d_3_10(generator); + const auto d_model = d_head * nheads; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int kv_dim; + if (same_embed_dim) { + kv_dim = d_model; + } else { + std::uniform_int_distribution d(5, 20); + kv_dim = d(generator); + while (kv_dim == d_model) { + kv_dim = d(generator); + } + } + std::vector dims{batch_sz, seq_len, kv_dim}; + torch::Tensor saved_k; + torch::Tensor saved_k_tensor; + torch::Tensor saved_v; + torch::Tensor saved_v_tensor; + if (saved_kv) { + saved_k = torch::rand({batch_sz * nheads, seq_len, d_head}); + saved_k_tensor = saved_k; + saved_v = torch::rand({batch_sz * nheads, seq_len, d_head}); + saved_v_tensor = saved_v; + } + torch::Tensor key_padding_mask; + torch::Tensor key_padding_mask_tensor; + if (add_key_padding_mask) { + const auto seq_mask = torch::randint(0, 2, {1, seq_len}); + key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1; + key_padding_mask_tensor = key_padding_mask; + } + const auto decoder_state = torch::rand({batch_sz, d_model}); + const torch::Tensor K = torch::rand(dims); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - auto X_fc_w = X_weight; - return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b; - } + const torch::Tensor V = K; + const torch::Tensor Q = + decoder_state.clone().resize_({batch_sz, 1, d_model}); + auto attn_mask = torch::randint(0, 2, {1, seq_len}); + const torch::Tensor attn_mask_tensor = attn_mask.clone(); + attn_mask_tensor.masked_fill_( + attn_mask_tensor == 0, -std::numeric_limits::infinity()); + attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0)); - void _multihead_attn_test_helper(bool add_key_padding_mask = false, - bool add_bias_kv = false, bool add_zero_attn = false, - bool saved_kv = false, bool same_embed_dim = false, - bool average_attn_weights = true) { - std::random_device device; - std::mt19937 generator(device()); - std::uniform_int_distribution d_2_10(2, 10); - std::uniform_int_distribution d_3_10(3, 10); - bool registration_checked = false; - for (const auto i : c10::irange(100)) { - (void)i; // Suppress unused variable warning - const auto batch_sz = d_2_10(generator); - const auto seq_len = d_2_10(generator); - const auto d_head = d_3_10(generator); - const auto nheads = d_3_10(generator); - const auto d_model = d_head * nheads; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int kv_dim; + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + const torch::Tensor decoder_state_tensor = decoder_state; + const torch::Tensor source_hid_tensor = K.transpose(0, 1); + + const auto options = MultiheadAttentionOptions(d_model, nheads) + .add_bias_kv(add_bias_kv) + .add_zero_attn(add_zero_attn) + .kdim(kv_dim) + .vdim(kv_dim); + const auto multihead_attn_module = MultiheadAttention(options); + + if (!registration_checked) { + // make sure parameters are all registered correctly + auto named_parameters = multihead_attn_module->named_parameters(); if (same_embed_dim) { - kv_dim = d_model; + ASSERT_TRUE(named_parameters.contains("in_proj_weight")); } else { - std::uniform_int_distribution d(5, 20); - kv_dim = d(generator); - while (kv_dim == d_model) { - kv_dim = d(generator); - } + ASSERT_TRUE(named_parameters.contains("q_proj_weight")); + ASSERT_TRUE(named_parameters.contains("k_proj_weight")); + ASSERT_TRUE(named_parameters.contains("v_proj_weight")); } - std::vector dims {batch_sz, seq_len, kv_dim}; - torch::Tensor saved_k; - torch::Tensor saved_k_tensor; - torch::Tensor saved_v; - torch::Tensor saved_v_tensor; - if (saved_kv) { - saved_k = torch::rand({batch_sz * nheads, seq_len, d_head}); - saved_k_tensor = saved_k; - saved_v = torch::rand({batch_sz * nheads, seq_len, d_head}); - saved_v_tensor = saved_v; - } - torch::Tensor key_padding_mask; - torch::Tensor key_padding_mask_tensor; - if (add_key_padding_mask) { - const auto seq_mask = torch::randint(0, 2, {1, seq_len}); - key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1; - key_padding_mask_tensor = key_padding_mask; - } - const auto decoder_state = torch::rand({batch_sz, d_model}); - const torch::Tensor K = torch::rand(dims); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - const torch::Tensor V = K; - const torch::Tensor Q = decoder_state.clone().resize_({batch_sz, 1, d_model}); - auto attn_mask = torch::randint(0, 2, {1, seq_len}); - const torch::Tensor attn_mask_tensor = attn_mask.clone(); - attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, -std::numeric_limits::infinity()); - attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0)); - - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - const torch::Tensor decoder_state_tensor = decoder_state; - const torch::Tensor source_hid_tensor = K.transpose(0, 1); - - const auto options = MultiheadAttentionOptions(d_model, nheads) - .add_bias_kv(add_bias_kv) - .add_zero_attn(add_zero_attn) - .kdim(kv_dim) - .vdim(kv_dim); - const auto multihead_attn_module = MultiheadAttention(options); - - if (!registration_checked) { - // make sure parameters are all registered correctly - auto named_parameters = multihead_attn_module->named_parameters(); - if (same_embed_dim) { - ASSERT_TRUE(named_parameters.contains("in_proj_weight")); - } - else { - ASSERT_TRUE(named_parameters.contains("q_proj_weight")); - ASSERT_TRUE(named_parameters.contains("k_proj_weight")); - ASSERT_TRUE(named_parameters.contains("v_proj_weight")); - } - if (add_bias_kv) { - ASSERT_TRUE(named_parameters.contains("bias_k")); - ASSERT_TRUE(named_parameters.contains("bias_v")); - } - // make sure sub modules are all registered correctly - auto submodules = multihead_attn_module->named_children(); - ASSERT_TRUE(submodules.contains("out_proj")); - registration_checked = true; - } - - torch::Tensor bias_k; - torch::Tensor bias_v; if (add_bias_kv) { - bias_k = multihead_attn_module->bias_k.detach(); - bias_v = multihead_attn_module->bias_v.detach(); - } else { - bias_k.reset(); - bias_v.reset(); + ASSERT_TRUE(named_parameters.contains("bias_k")); + ASSERT_TRUE(named_parameters.contains("bias_v")); } + // make sure sub modules are all registered correctly + auto submodules = multihead_attn_module->named_children(); + ASSERT_TRUE(submodules.contains("out_proj")); + registration_checked = true; + } - torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - torch::Tensor _V = source_hid_tensor; - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - torch::Tensor _K = source_hid_tensor; - - torch::Tensor result; - torch::Tensor result_weight; - if (multihead_attn_module->_qkv_same_embed_dim) { - std::tie(result, result_weight) = F::multi_head_attention_forward( - _Q, _K, _V, + torch::Tensor bias_k; + torch::Tensor bias_v; + if (add_bias_kv) { + bias_k = multihead_attn_module->bias_k.detach(); + bias_v = multihead_attn_module->bias_v.detach(); + } else { + bias_k.reset(); + bias_v.reset(); + } + + torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1); + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + torch::Tensor _V = source_hid_tensor; + // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) + torch::Tensor _K = source_hid_tensor; + + torch::Tensor result; + torch::Tensor result_weight; + if (multihead_attn_module->_qkv_same_embed_dim) { + std::tie(result, result_weight) = F::multi_head_attention_forward( + _Q, + _K, + _V, F::MultiheadAttentionForwardFuncOptions( - /*embed_dim_to_check=*/d_model, - /*num_heads=*/nheads, - /*in_proj_weight=*/multihead_attn_module->in_proj_weight, - /*in_proj_bias=*/multihead_attn_module->in_proj_bias, - /*bias_k=*/multihead_attn_module->bias_k, - /*bias_v=*/multihead_attn_module->bias_v, - /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), - /*dropout_p=*/multihead_attn_module->options.dropout(), - /*out_proj_weight=*/multihead_attn_module->out_proj->weight, - /*out_proj_bias=*/multihead_attn_module->out_proj->bias - ) - .training(multihead_attn_module->is_training()) - .key_padding_mask(key_padding_mask_tensor) - .need_weights(true) - .attn_mask(attn_mask_tensor) - .static_k(saved_k_tensor) - .static_v(saved_v_tensor) - .average_attn_weights(average_attn_weights) - ); - } else { - std::tie(result, result_weight) = F::multi_head_attention_forward( - _Q, _K, _V, + /*embed_dim_to_check=*/d_model, + /*num_heads=*/nheads, + /*in_proj_weight=*/multihead_attn_module->in_proj_weight, + /*in_proj_bias=*/multihead_attn_module->in_proj_bias, + /*bias_k=*/multihead_attn_module->bias_k, + /*bias_v=*/multihead_attn_module->bias_v, + /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), + /*dropout_p=*/multihead_attn_module->options.dropout(), + /*out_proj_weight=*/multihead_attn_module->out_proj->weight, + /*out_proj_bias=*/multihead_attn_module->out_proj->bias) + .training(multihead_attn_module->is_training()) + .key_padding_mask(key_padding_mask_tensor) + .need_weights(true) + .attn_mask(attn_mask_tensor) + .static_k(saved_k_tensor) + .static_v(saved_v_tensor) + .average_attn_weights(average_attn_weights)); + } else { + std::tie(result, result_weight) = F::multi_head_attention_forward( + _Q, + _K, + _V, F::MultiheadAttentionForwardFuncOptions( - /*embed_dim_to_check=*/d_model, - /*num_heads=*/nheads, - /*in_proj_weight=*/{}, - /*in_proj_bias=*/multihead_attn_module->in_proj_bias, - /*bias_k=*/multihead_attn_module->bias_k, - /*bias_v=*/multihead_attn_module->bias_v, - /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), - /*dropout_p=*/multihead_attn_module->options.dropout(), - /*out_proj_weight=*/multihead_attn_module->out_proj->weight, - /*out_proj_bias=*/multihead_attn_module->out_proj->bias - ) - .training(multihead_attn_module->is_training()) - .key_padding_mask(key_padding_mask_tensor) - .need_weights(true) - .attn_mask(attn_mask_tensor) - .use_separate_proj_weight(true) - .q_proj_weight(multihead_attn_module->q_proj_weight) - .k_proj_weight(multihead_attn_module->k_proj_weight) - .v_proj_weight(multihead_attn_module->v_proj_weight) - .static_k(saved_k_tensor) - .static_v(saved_v_tensor) - .average_attn_weights(average_attn_weights) - ); - } - result = result.squeeze(0).detach(); - torch::Tensor q_proj_weight; - torch::Tensor k_proj_weight; - torch::Tensor v_proj_weight; - if (multihead_attn_module->_qkv_same_embed_dim) { - q_proj_weight = multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model); - k_proj_weight = multihead_attn_module->in_proj_weight.slice(/*dim=*/0, d_model, (d_model * 2)); - v_proj_weight = multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2)); - } else { - q_proj_weight = multihead_attn_module->q_proj_weight; - k_proj_weight = multihead_attn_module->k_proj_weight; - v_proj_weight = multihead_attn_module->v_proj_weight; - } - auto Q_fc = _fc(Q, q_proj_weight, multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model)); - auto K_fc = _fc(K, k_proj_weight, multihead_attn_module->in_proj_bias.slice(/*dim=*/0, d_model, (d_model * 2))); - auto V_fc = _fc(V, v_proj_weight, multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2))); - - if (add_bias_kv) { - K_fc = torch::cat({K_fc, bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1}/*, axis=0*/)}, /*dim=*/1); - V_fc = torch::cat({V_fc, bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1}/*, axis=0*/)}, /*dim=*/1); - if (attn_mask.defined()) { - attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); - } - if (key_padding_mask.defined()) { - key_padding_mask = torch::cat({key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, /*dim=*/1); - } - dims[1] += 1; + /*embed_dim_to_check=*/d_model, + /*num_heads=*/nheads, + /*in_proj_weight=*/{}, + /*in_proj_bias=*/multihead_attn_module->in_proj_bias, + /*bias_k=*/multihead_attn_module->bias_k, + /*bias_v=*/multihead_attn_module->bias_v, + /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), + /*dropout_p=*/multihead_attn_module->options.dropout(), + /*out_proj_weight=*/multihead_attn_module->out_proj->weight, + /*out_proj_bias=*/multihead_attn_module->out_proj->bias) + .training(multihead_attn_module->is_training()) + .key_padding_mask(key_padding_mask_tensor) + .need_weights(true) + .attn_mask(attn_mask_tensor) + .use_separate_proj_weight(true) + .q_proj_weight(multihead_attn_module->q_proj_weight) + .k_proj_weight(multihead_attn_module->k_proj_weight) + .v_proj_weight(multihead_attn_module->v_proj_weight) + .static_k(saved_k_tensor) + .static_v(saved_v_tensor) + .average_attn_weights(average_attn_weights)); + } + result = result.squeeze(0).detach(); + torch::Tensor q_proj_weight; + torch::Tensor k_proj_weight; + torch::Tensor v_proj_weight; + if (multihead_attn_module->_qkv_same_embed_dim) { + q_proj_weight = + multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model); + k_proj_weight = multihead_attn_module->in_proj_weight.slice( + /*dim=*/0, d_model, (d_model * 2)); + v_proj_weight = + multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2)); + } else { + q_proj_weight = multihead_attn_module->q_proj_weight; + k_proj_weight = multihead_attn_module->k_proj_weight; + v_proj_weight = multihead_attn_module->v_proj_weight; + } + auto Q_fc = + _fc(Q, + q_proj_weight, + multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model)); + auto K_fc = + _fc(K, + k_proj_weight, + multihead_attn_module->in_proj_bias.slice( + /*dim=*/0, d_model, (d_model * 2))); + auto V_fc = _fc( + V, + v_proj_weight, + multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2))); + + if (add_bias_kv) { + K_fc = torch::cat( + {K_fc, + bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)}, + /*dim=*/1); + V_fc = torch::cat( + {V_fc, + bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)}, + /*dim=*/1); + if (attn_mask.defined()) { + attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); } - const auto Q_split = _split_heads_ref( - Q_fc, {batch_sz, 1, d_model}, nheads, d_head - ); - torch::Tensor K_split; - if (saved_k.defined()) { - K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head}); - } else { - K_split = _split_heads_ref(K_fc, dims, nheads, d_head); + if (key_padding_mask.defined()) { + key_padding_mask = torch::cat( + {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, + /*dim=*/1); } - torch::Tensor V_split; - if (saved_v.defined()) { - V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head}); - } else { - V_split = _split_heads_ref(V_fc, dims, nheads, d_head); + dims[1] += 1; + } + const auto Q_split = + _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head); + torch::Tensor K_split; + if (saved_k.defined()) { + K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head}); + } else { + K_split = _split_heads_ref(K_fc, dims, nheads, d_head); + } + torch::Tensor V_split; + if (saved_v.defined()) { + V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head}); + } else { + V_split = _split_heads_ref(V_fc, dims, nheads, d_head); + } + if (add_zero_attn) { + dims[1] += 1; + K_split = torch::cat( + {K_split, + torch::zeros( + {K_split.size(0), K_split.size(1), 1, K_split.size(3)})}, + /*dim=*/2); + V_split = torch::cat( + {V_split, + torch::zeros( + {V_split.size(0), V_split.size(1), 1, V_split.size(3)})}, + /*dim=*/2); + if (attn_mask.defined()) { + attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); } - if (add_zero_attn) { - dims[1] += 1; - K_split = torch::cat({ - K_split, - torch::zeros({K_split.size(0), K_split.size(1), 1, K_split.size(3)}) - }, /*dim=*/2); - V_split = torch::cat({ - V_split, - torch::zeros({V_split.size(0), V_split.size(1), 1, V_split.size(3)}) - }, /*dim=*/2); - if (attn_mask.defined()) { - attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); - } - if (key_padding_mask.defined()) { - key_padding_mask = torch::cat({key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, /*dim=*/1); - } + if (key_padding_mask.defined()) { + key_padding_mask = torch::cat( + {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, + /*dim=*/1); } - torch::Tensor attn_heads; - torch::Tensor ref_attn_weight; - std::tie(attn_heads, ref_attn_weight) = _scaled_dot_attn_ref( - Q_split, - K_split, - V_split, - Q_split.sizes(), - attn_mask, - key_padding_mask, - average_attn_weights - ); - const auto combined_attn_heads = _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head); - auto reference = _fc(combined_attn_heads, multihead_attn_module->out_proj->weight, multihead_attn_module->out_proj->bias); - // NOLINTNEXTLINE(bugprone-argument-comment) - reference = torch::squeeze(reference, /*axis=*/1); - - // result = reference - ASSERT_EQ(result.sizes(), std::vector({batch_sz, d_model})); - ASSERT_TRUE(torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true)); - - // result_weight = ref_attn_weight - result_weight = result_weight.detach(); - ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes()); - ASSERT_TRUE(torch::allclose(result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true)); } + torch::Tensor attn_heads; + torch::Tensor ref_attn_weight; + std::tie(attn_heads, ref_attn_weight) = _scaled_dot_attn_ref( + Q_split, + K_split, + V_split, + Q_split.sizes(), + attn_mask, + key_padding_mask, + average_attn_weights); + const auto combined_attn_heads = + _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head); + auto reference = + _fc(combined_attn_heads, + multihead_attn_module->out_proj->weight, + multihead_attn_module->out_proj->bias); + // NOLINTNEXTLINE(bugprone-argument-comment) + reference = torch::squeeze(reference, /*axis=*/1); + + // result = reference + ASSERT_EQ(result.sizes(), std::vector({batch_sz, d_model})); + ASSERT_TRUE( + torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true)); + + // result_weight = ref_attn_weight + result_weight = result_weight.detach(); + ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes()); + ASSERT_TRUE(torch::allclose( + result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true)); } } +} // namespace detail TEST_F(ModulesTest, MultiheadAttention) { using namespace ::detail; @@ -3746,89 +3755,80 @@ TEST_F(ModulesTest, MultiheadAttention) { for (auto average_attn_weights : {false, true}) { // test_multihead_attn_add_zero_attn _multihead_attn_test_helper( - /*add_key_padding_mask=*/false, - /*add_bias_kv=*/false, - /*add_zero_attn=*/true, - /*saved_kv=*/false, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/true, + /*saved_kv=*/false, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights); // test_multihead_attn_add_bias_kv _multihead_attn_test_helper( - /*add_key_padding_mask=*/false, - /*add_bias_kv=*/true, - /*add_zero_attn=*/false, - /*saved_kv=*/false, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/false, + /*add_bias_kv=*/true, + /*add_zero_attn=*/false, + /*saved_kv=*/false, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights); // test_multihead_attn_no_masking(): _multihead_attn_test_helper(); // test_multihead_attn_key_padding_mask _multihead_attn_test_helper( - /*add_key_padding_mask=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, - /*saved_kv=*/false, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + /*saved_kv=*/false, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights); // test_multihead_attn_saved_kv _multihead_attn_test_helper( - /*add_key_padding_mask=*/false, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, - /*saved_kv=*/true, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + /*saved_kv=*/true, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights); // test_multihead_attn_add_bias_kv_zero_attn _multihead_attn_test_helper( - /*add_key_padding_mask=*/true, - /*add_bias_kv=*/true, - /*add_zero_attn=*/true, - /*saved_kv=*/false, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/true, + /*add_bias_kv=*/true, + /*add_zero_attn=*/true, + /*saved_kv=*/false, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights); // test_multihead_attn_all_arguments1 _multihead_attn_test_helper( - /*add_key_padding_mask=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/true, - /*saved_kv=*/true, - /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ); - - ASSERT_THROWS_WITH( - // test_multihead_attn_all_arguments2 - _multihead_attn_test_helper( /*add_key_padding_mask=*/true, - /*add_bias_kv=*/true, + /*add_bias_kv=*/false, /*add_zero_attn=*/true, /*saved_kv=*/true, /*same_embed_dim=*/false, - /*average_attn_weights=*/average_attn_weights - ), - "bias cannot be added to static key" - ); + /*average_attn_weights=*/average_attn_weights); + + ASSERT_THROWS_WITH( + // test_multihead_attn_all_arguments2 + _multihead_attn_test_helper( + /*add_key_padding_mask=*/true, + /*add_bias_kv=*/true, + /*add_zero_attn=*/true, + /*saved_kv=*/true, + /*same_embed_dim=*/false, + /*average_attn_weights=*/average_attn_weights), + "bias cannot be added to static key"); // test_multihead_attn_all_arguments3 _multihead_attn_test_helper( - /*add_key_padding_mask=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/true, - /*saved_kv=*/true, - /*same_embed_dim=*/true, - /*average_attn_weights=*/average_attn_weights - ); + /*add_key_padding_mask=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/true, + /*saved_kv=*/true, + /*same_embed_dim=*/true, + /*average_attn_weights=*/average_attn_weights); } } @@ -3837,10 +3837,10 @@ TEST_F(ModulesTest, PrettyPrintIdentity) { } TEST_F(ModulesTest, PrettyPrintFlatten) { - ASSERT_EQ(c10::str(Flatten()), - "torch::nn::Flatten(start_dim=1, end_dim=-1)"); - ASSERT_EQ(c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))), - "torch::nn::Flatten(start_dim=2, end_dim=4)"); + ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)"); + ASSERT_EQ( + c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))), + "torch::nn::Flatten(start_dim=2, end_dim=4)"); } TEST_F(ModulesTest, PrettyPrintUnflatten) { @@ -3860,16 +3860,18 @@ TEST_F(ModulesTest, ReflectionPad1d) { ReflectionPad1d m(ReflectionPad1dOptions(2)); auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); auto output = m(input); - auto expected = torch::tensor({{{2., 1., 0., 1., 2., 3., 2., 1.}, - {6., 5., 4., 5., 6., 7., 6., 5.}}}, torch::kFloat); + auto expected = torch::tensor( + {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReflectionPad1d m(ReflectionPad1dOptions({3, 1})); auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); auto output = m(input); - auto expected = torch::tensor({{{3., 2., 1., 0., 1., 2., 3., 2.}, - {7., 6., 5., 4., 5., 6., 7., 6.}}}, torch::kFloat); + auto expected = torch::tensor( + {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -3879,24 +3881,28 @@ TEST_F(ModulesTest, ReflectionPad2d) { ReflectionPad2d m(ReflectionPad2dOptions(2)); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{8., 7., 6., 7., 8., 7., 6.}, - {5., 4., 3., 4., 5., 4., 3.}, - {2., 1., 0., 1., 2., 1., 0.}, - {5., 4., 3., 4., 5., 4., 3.}, - {8., 7., 6., 7., 8., 7., 6.}, - {5., 4., 3., 4., 5., 4., 3.}, - {2., 1., 0., 1., 2., 1., 0.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{8., 7., 6., 7., 8., 7., 6.}, + {5., 4., 3., 4., 5., 4., 3.}, + {2., 1., 0., 1., 2., 1., 0.}, + {5., 4., 3., 4., 5., 4., 3.}, + {8., 7., 6., 7., 8., 7., 6.}, + {5., 4., 3., 4., 5., 4., 3.}, + {2., 1., 0., 1., 2., 1., 0.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReflectionPad2d m(ReflectionPad2dOptions({1, 1, 2, 0})); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{7., 6., 7., 8., 7.}, - {4., 3., 4., 5., 4.}, - {1., 0., 1., 2., 1.}, - {4., 3., 4., 5., 4.}, - {7., 6., 7., 8., 7.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{7., 6., 7., 8., 7.}, + {4., 3., 4., 5., 4.}, + {1., 0., 1., 2., 1.}, + {4., 3., 4., 5., 4.}, + {7., 6., 7., 8., 7.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -3906,49 +3912,39 @@ TEST_F(ModulesTest, ReflectionPad3d) { ReflectionPad3d m(ReflectionPad3dOptions(1)); auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{7., 6., 7., 6.}, - {5., 4., 5., 4.}, - {7., 6., 7., 6.}, - {5., 4., 5., 4.}}, - {{3., 2., 3., 2.}, - {1., 0., 1., 0.}, - {3., 2., 3., 2.}, - {1., 0., 1., 0.}}, - {{7., 6., 7., 6.}, - {5., 4., 5., 4.}, - {7., 6., 7., 6.}, - {5., 4., 5., 4.}}, - {{3., 2., 3., 2.}, - {1., 0., 1., 0.}, - {3., 2., 3., 2.}, - {1., 0., 1., 0.}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{7., 6., 7., 6.}, + {5., 4., 5., 4.}, + {7., 6., 7., 6.}, + {5., 4., 5., 4.}}, + {{3., 2., 3., 2.}, + {1., 0., 1., 0.}, + {3., 2., 3., 2.}, + {1., 0., 1., 0.}}, + {{7., 6., 7., 6.}, + {5., 4., 5., 4.}, + {7., 6., 7., 6.}, + {5., 4., 5., 4.}}, + {{3., 2., 3., 2.}, + {1., 0., 1., 0.}, + {3., 2., 3., 2.}, + {1., 0., 1., 0.}}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2})); auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{6., 7., 6.}, - {4., 5., 4.}, - {6., 7., 6.}}, - {{2., 3., 2.}, - {0., 1., 0.}, - {2., 3., 2.}}, - {{6., 7., 6.}, - {4., 5., 4.}, - {6., 7., 6.}}, - {{10., 11., 10.}, - {8., 9., 8.}, - {10., 11., 10.}}, - {{14., 15., 14.}, - {12., 13., 12.}, - {14., 15., 14.}}, - {{10., 11., 10.}, - {8., 9., 8.}, - {10., 11., 10.}}, - {{6., 7., 6.}, - {4., 5., 4.}, - {6., 7., 6.}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}, + {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}}, + {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}, + {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}}, + {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}}, + {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}}, + {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}}, + torch::kFloat); ASSERT_EQ(output.sizes(), std::vector({1, 1, 7, 3, 3})); ASSERT_TRUE(output.allclose(expected)); } @@ -3958,16 +3954,18 @@ TEST_F(ModulesTest, ReplicationPad1d) { ReplicationPad1d m(ReplicationPad1dOptions(2)); auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); auto output = m(input); - auto expected = torch::tensor({{{0., 0., 0., 1., 2., 3., 3., 3.}, - {4., 4., 4., 5., 6., 7., 7., 7.}}}, torch::kFloat); + auto expected = torch::tensor( + {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReplicationPad1d m(ReplicationPad1dOptions({3, 1})); auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); auto output = m(input); - auto expected = torch::tensor({{{0., 0., 0., 0., 1., 2., 3., 3.}, - {4., 4., 4., 4., 5., 6., 7., 7.}}}, torch::kFloat); + auto expected = torch::tensor( + {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -3977,24 +3975,28 @@ TEST_F(ModulesTest, ReplicationPad2d) { ReplicationPad2d m(ReplicationPad2dOptions(2)); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{0., 0., 0., 1., 2., 2., 2.}, - {0., 0., 0., 1., 2., 2., 2.}, - {0., 0., 0., 1., 2., 2., 2.}, - {3., 3., 3., 4., 5., 5., 5.}, - {6., 6., 6., 7., 8., 8., 8.}, - {6., 6., 6., 7., 8., 8., 8.}, - {6., 6., 6., 7., 8., 8., 8.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{0., 0., 0., 1., 2., 2., 2.}, + {0., 0., 0., 1., 2., 2., 2.}, + {0., 0., 0., 1., 2., 2., 2.}, + {3., 3., 3., 4., 5., 5., 5.}, + {6., 6., 6., 7., 8., 8., 8.}, + {6., 6., 6., 7., 8., 8., 8.}, + {6., 6., 6., 7., 8., 8., 8.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0})); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{0., 0., 1., 2., 2.}, - {0., 0., 1., 2., 2.}, - {0., 0., 1., 2., 2.}, - {3., 3., 4., 5., 5.}, - {6., 6., 7., 8., 8.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{0., 0., 1., 2., 2.}, + {0., 0., 1., 2., 2.}, + {0., 0., 1., 2., 2.}, + {3., 3., 4., 5., 5.}, + {6., 6., 7., 8., 8.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -4004,53 +4006,57 @@ TEST_F(ModulesTest, ReplicationPad3d) { ReplicationPad3d m(ReplicationPad3dOptions(1)); auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{0., 0., 1., 1.}, - {0., 0., 1., 1.}, - {2., 2., 3., 3.}, - {2., 2., 3., 3.}}, - {{0., 0., 1., 1.}, - {0., 0., 1., 1.}, - {2., 2., 3., 3.}, - {2., 2., 3., 3.}}, - {{4., 4., 5., 5.}, - {4., 4., 5., 5.}, - {6., 6., 7., 7.}, - {6., 6., 7., 7.}}, - {{4., 4., 5., 5.}, - {4., 4., 5., 5.}, - {6., 6., 7., 7.}, - {6., 6., 7., 7.}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{0., 0., 1., 1.}, + {0., 0., 1., 1.}, + {2., 2., 3., 3.}, + {2., 2., 3., 3.}}, + {{0., 0., 1., 1.}, + {0., 0., 1., 1.}, + {2., 2., 3., 3.}, + {2., 2., 3., 3.}}, + {{4., 4., 5., 5.}, + {4., 4., 5., 5.}, + {6., 6., 7., 7.}, + {6., 6., 7., 7.}}, + {{4., 4., 5., 5.}, + {4., 4., 5., 5.}, + {6., 6., 7., 7.}, + {6., 6., 7., 7.}}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{0., 0., 1., 1., 1.}, - {0., 0., 1., 1., 1.}, - {2., 2., 3., 3., 3.}, - {2., 2., 3., 3., 3.}, - {2., 2., 3., 3., 3.}}, - {{0., 0., 1., 1., 1.}, - {0., 0., 1., 1., 1.}, - {2., 2., 3., 3., 3.}, - {2., 2., 3., 3., 3.}, - {2., 2., 3., 3., 3.}}, - {{4., 4., 5., 5., 5.}, - {4., 4., 5., 5., 5.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}}, - {{4., 4., 5., 5., 5.}, - {4., 4., 5., 5., 5.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}}, - {{4., 4., 5., 5., 5.}, - {4., 4., 5., 5., 5.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}, - {6., 6., 7., 7., 7.}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{0., 0., 1., 1., 1.}, + {0., 0., 1., 1., 1.}, + {2., 2., 3., 3., 3.}, + {2., 2., 3., 3., 3.}, + {2., 2., 3., 3., 3.}}, + {{0., 0., 1., 1., 1.}, + {0., 0., 1., 1., 1.}, + {2., 2., 3., 3., 3.}, + {2., 2., 3., 3., 3.}, + {2., 2., 3., 3., 3.}}, + {{4., 4., 5., 5., 5.}, + {4., 4., 5., 5., 5.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}}, + {{4., 4., 5., 5., 5.}, + {4., 4., 5., 5., 5.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}}, + {{4., 4., 5., 5., 5.}, + {4., 4., 5., 5., 5.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}, + {6., 6., 7., 7., 7.}}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -4060,24 +4066,28 @@ TEST_F(ModulesTest, ZeroPad2d) { ZeroPad2d m(ZeroPad2dOptions(2)); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{0., 0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 1., 2., 0., 0.}, - {0., 0., 3., 4., 5., 0., 0.}, - {0., 0., 6., 7., 8., 0., 0.}, - {0., 0., 0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0., 0., 0.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{0., 0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 1., 2., 0., 0.}, + {0., 0., 3., 4., 5., 0., 0.}, + {0., 0., 6., 7., 8., 0., 0.}, + {0., 0., 0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0., 0., 0.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0})); auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); auto output = m(input); - auto expected = torch::tensor({{{{0., 0., 0., 0., 0.}, - {0., 0., 0., 0., 0.}, - {0., 0., 1., 2., 0.}, - {0., 3., 4., 5., 0.}, - {0., 6., 7., 8., 0.}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{0., 0., 0., 0., 0.}, + {0., 0., 0., 0., 0.}, + {0., 0., 1., 2., 0.}, + {0., 3., 4., 5., 0.}, + {0., 6., 7., 8., 0.}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -4087,16 +4097,20 @@ TEST_F(ModulesTest, ConstantPad1d) { ConstantPad1d m(ConstantPad1dOptions(2, 3.5)); auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); auto output = m(input); - auto expected = torch::tensor({{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000}, - {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}}, torch::kFloat); + auto expected = torch::tensor( + {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000}, + {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5)); auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3}); auto output = m(input); - auto expected = torch::tensor({{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}}, torch::kFloat); + auto expected = torch::tensor( + {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -4106,23 +4120,27 @@ TEST_F(ModulesTest, ConstantPad2d) { ConstantPad2d m(ConstantPad2dOptions(2, 3.5)); auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, - {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, torch::kFloat); + auto expected = torch::tensor( + {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, + {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 0.0000, 1.0000}, - {3.5000, 3.5000, 3.5000, 2.0000, 3.0000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, torch::kFloat); + auto expected = torch::tensor( + {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 0.0000, 1.0000}, + {3.5000, 3.5000, 3.5000, 2.0000, 3.0000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } @@ -4132,66 +4150,75 @@ TEST_F(ModulesTest, ConstantPad3d) { ConstantPad3d m(ConstantPad3dOptions(1, 3.5)); auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 0.0000, 1.0000, 3.5000}, - {3.5000, 2.0000, 3.0000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 4.0000, 5.0000, 3.5000}, - {3.5000, 6.0000, 7.0000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 0.0000, 1.0000, 3.5000}, + {3.5000, 2.0000, 3.0000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 4.0000, 5.0000, 3.5000}, + {3.5000, 6.0000, 7.0000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000}}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } { ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); auto output = m(input); - auto expected = torch::tensor({{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, - {3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 4.0000, 5.0000, 3.5000, 3.5000}, - {3.5000, 6.0000, 7.0000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, - {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, - {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}}, torch::kFloat); + auto expected = torch::tensor( + {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, + {3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 4.0000, 5.0000, 3.5000, 3.5000}, + {3.5000, 6.0000, 7.0000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, + {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, + {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}}, + torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } } TEST_F(ModulesTest, CrossMapLRN2d) { /// size 3, default options - auto input = torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true); - auto expected = torch::tensor({{{{0.00000000, 0.99997497, 1.99980010}, - {2.99932500, 3.99840070, 4.99687700}, - {5.99460600, 6.99143740, 7.98722360}}}}, torch::kFloat32); - auto grad_expected = torch::tensor({{{{1.00000000, 0.99992496, 0.99970007}, - {0.99932520, 0.99880093, 0.99812720}, - {0.99730474, 0.99633380, 0.99521490}}}}, torch::kFloat32); + auto input = + torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true); + auto expected = torch::tensor( + {{{{0.00000000, 0.99997497, 1.99980010}, + {2.99932500, 3.99840070, 4.99687700}, + {5.99460600, 6.99143740, 7.98722360}}}}, + torch::kFloat32); + auto grad_expected = torch::tensor( + {{{{1.00000000, 0.99992496, 0.99970007}, + {0.99932520, 0.99880093, 0.99812720}, + {0.99730474, 0.99633380, 0.99521490}}}}, + torch::kFloat32); auto crossmaplrn2d = CrossMapLRN2d(3); auto output = crossmaplrn2d(input); output.sum().backward(); @@ -4200,35 +4227,47 @@ TEST_F(ModulesTest, CrossMapLRN2d) { ASSERT_TRUE(output.allclose(expected)); /// size change - crossmaplrn2d = CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1)); + crossmaplrn2d = + CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1)); output = crossmaplrn2d(input); - expected = torch::tensor({{{{0.00000000, 0.99998120, 1.99985000}, - {2.99949400, 3.99880050, 4.99765800}, - {5.99595300, 6.99357600, 7.99041300}}}}, torch::kFloat32); + expected = torch::tensor( + {{{{0.00000000, 0.99998120, 1.99985000}, + {2.99949400, 3.99880050, 4.99765800}, + {5.99595300, 6.99357600, 7.99041300}}}}, + torch::kFloat32); ASSERT_TRUE(output.allclose(expected)); /// alpha change - crossmaplrn2d = CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1)); + crossmaplrn2d = + CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1)); output = crossmaplrn2d(input); - expected = torch::tensor({{{{0.00000000, 0.99975010, 1.99800230}, - {2.99326750, 3.98407440, 4.96897600}, - {5.94656100, 6.91545720, 7.87434340}}}}, torch::kFloat32); + expected = torch::tensor( + {{{{0.00000000, 0.99975010, 1.99800230}, + {2.99326750, 3.98407440, 4.96897600}, + {5.94656100, 6.91545720, 7.87434340}}}}, + torch::kFloat32); ASSERT_TRUE(output.allclose(expected)); /// beta change - crossmaplrn2d = CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1)); + crossmaplrn2d = + CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1)); output = crossmaplrn2d(input); - expected = torch::tensor({{{{0.00000000, 0.99996830, 1.99974680}, - {2.99914500, 3.99797440, 4.99604460}, - {5.99316840, 6.98915600, 7.98382000}}}}, torch::kFloat32); + expected = torch::tensor( + {{{{0.00000000, 0.99996830, 1.99974680}, + {2.99914500, 3.99797440, 4.99604460}, + {5.99316840, 6.98915600, 7.98382000}}}}, + torch::kFloat32); ASSERT_TRUE(output.allclose(expected)); /// k change - crossmaplrn2d = CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2)); + crossmaplrn2d = + CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2)); output = crossmaplrn2d(input); - expected = torch::tensor({{{{0.00000000, 0.59459610, 1.18914770}, - {1.78361000, 2.37793870, 2.97208900}, - {3.56601700, 4.15967700, 4.75302650}}}}, torch::kFloat32); + expected = torch::tensor( + {{{{0.00000000, 0.59459610, 1.18914770}, + {1.78361000, 2.37793870, 2.97208900}, + {3.56601700, 4.15967700, 4.75302650}}}}, + torch::kFloat32); ASSERT_TRUE(output.allclose(expected)); } @@ -4238,15 +4277,13 @@ TEST_F(ModulesTest, RNNCell) { auto input = torch::randn({3, 1}); auto hx = torch::randn({3, 2}); auto output = rnn(input, hx); - auto expected = torch::tensor({{-0.5078, 0.4380}, - {-0.7215, 0.2969}, - {-0.1304, 0.0653}}); + auto expected = + torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}}); ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); output = rnn(input); - expected = torch::tensor({{-0.0775, 0.6688}, - {-0.0734, 0.4759}, - {-0.0725, 0.4225}}); + expected = + torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}}); ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); } @@ -4259,24 +4296,20 @@ TEST_F(ModulesTest, LSTMCell) { auto output = rnn(input, std::make_tuple(hx, cx)); auto output_hx = std::get<0>(output); auto output_cx = std::get<1>(output); - auto expected_hx = torch::tensor({{-0.2462, 0.0810}, - {-0.2206, 0.1867}, - {-0.0146, 0.0429}}); - auto expected_cx = torch::tensor({{-0.4480, 0.1071}, - {-0.6245, 0.2687}, - {-0.0322, 0.0518}}); + auto expected_hx = + torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}}); + auto expected_cx = + torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}}); ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04)); ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04)); output = rnn(input); output_hx = std::get<0>(output); output_cx = std::get<1>(output); - expected_hx = torch::tensor({{-0.1331, 0.1634}, - {-0.1494, 0.2869}, - {-0.1428, 0.2263}}); - expected_cx = torch::tensor({{-0.2679, 0.2180}, - {-0.3049, 0.3493}, - {-0.2896, 0.2853}}); + expected_hx = + torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}}); + expected_cx = + torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}}); ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04)); ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04)); } @@ -4287,28 +4320,29 @@ TEST_F(ModulesTest, GRUCell) { auto input = torch::randn({3, 1}); auto hx = torch::randn({3, 2}); auto output = rnn(input, hx); - auto expected = torch::tensor({{ 1.0243, 0.3227}, - {-0.5659, 0.0330}, - {-0.4030, -0.2800}}); + auto expected = + torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}}); ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); output = rnn(input); - expected = torch::tensor({{-0.0085, 0.1095}, - {-0.1291, 0.2675}, - {-0.1339, 0.2725}}); + expected = + torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}}); ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); } TEST_F(ModulesTest, PrettyPrintLinear) { ASSERT_EQ( - c10::str(Linear(3, 4)), "torch::nn::Linear(in_features=3, out_features=4, bias=true)"); + c10::str(Linear(3, 4)), + "torch::nn::Linear(in_features=3, out_features=4, bias=true)"); } TEST_F(ModulesTest, PrettyPrintBilinear) { ASSERT_EQ( - c10::str(Bilinear(3, 2, 4)), "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)"); + c10::str(Bilinear(3, 2, 4)), + "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)"); ASSERT_EQ( - c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))), "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)"); + c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))), + "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)"); } TEST_F(ModulesTest, PrettyPrintConv) { @@ -4334,27 +4368,25 @@ TEST_F(ModulesTest, PrettyPrintConv) { c10::str(Conv3d(4, 4, std::vector{5, 6, 7})), "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])"); { - const auto options = - Conv3dOptions(4, 4, std::vector{5, 6, 7}) - .stride({1, 2, 3}) - .padding(1) - .dilation(0) - .groups(2) - .bias(false) - .padding_mode(torch::kCircular); + const auto options = Conv3dOptions(4, 4, std::vector{5, 6, 7}) + .stride({1, 2, 3}) + .padding(1) + .dilation(0) + .groups(2) + .bias(false) + .padding_mode(torch::kCircular); ASSERT_EQ( - c10::str( - Conv3d(options)), - "torch::nn::Conv3d(" - "4, " - "4, " - "kernel_size=[5, 6, 7], " - "stride=[1, 2, 3], " - "padding=[1, 1, 1], " - "dilation=[0, 0, 0], " - "groups=2, " - "bias=false, " - "padding_mode=kCircular)"); + c10::str(Conv3d(options)), + "torch::nn::Conv3d(" + "4, " + "4, " + "kernel_size=[5, 6, 7], " + "stride=[1, 2, 3], " + "padding=[1, 1, 1], " + "dilation=[0, 0, 0], " + "groups=2, " + "bias=false, " + "padding_mode=kCircular)"); } } @@ -4383,34 +4415,36 @@ TEST_F(ModulesTest, PrettyPrintConvTranspose) { { const auto options = ConvTranspose3dOptions(4, 4, std::vector{5, 6, 7}) - .stride({1, 2, 3}) - .padding(1) - .dilation(0) - .groups(2) - .bias(false) - .padding_mode(torch::kCircular); + .stride({1, 2, 3}) + .padding(1) + .dilation(0) + .groups(2) + .bias(false) + .padding_mode(torch::kCircular); ASSERT_EQ( - c10::str( - ConvTranspose3d(options)), - "torch::nn::ConvTranspose3d(" - "4, " - "4, " - "kernel_size=[5, 6, 7], " - "stride=[1, 2, 3], " - "padding=[1, 1, 1], " - "dilation=[0, 0, 0], " - "groups=2, " - "bias=false, " - "padding_mode=kCircular)"); + c10::str(ConvTranspose3d(options)), + "torch::nn::ConvTranspose3d(" + "4, " + "4, " + "kernel_size=[5, 6, 7], " + "stride=[1, 2, 3], " + "padding=[1, 1, 1], " + "dilation=[0, 0, 0], " + "groups=2, " + "bias=false, " + "padding_mode=kCircular)"); } } TEST_F(ModulesTest, PrettyPrintUpsample) { ASSERT_EQ( - c10::str(Upsample(UpsampleOptions().size(std::vector({2, 4, 4})))), + c10::str( + Upsample(UpsampleOptions().size(std::vector({2, 4, 4})))), "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)"); ASSERT_EQ( - c10::str(Upsample(UpsampleOptions().scale_factor(std::vector({0.5, 1.5})).mode(torch::kBilinear))), + c10::str(Upsample(UpsampleOptions() + .scale_factor(std::vector({0.5, 1.5})) + .mode(torch::kBilinear))), "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)"); } @@ -4419,7 +4453,8 @@ TEST_F(ModulesTest, PrettyPrintFold) { c10::str(Fold(FoldOptions({2, 2}, {5, 5}))), "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])"); ASSERT_EQ( - c10::str(Fold(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))), + c10::str(Fold( + FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))), "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])"); } @@ -4428,7 +4463,8 @@ TEST_F(ModulesTest, PrettyPrintUnfold) { c10::str(Unfold(torch::IntArrayRef({2, 4}))), "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])"); ASSERT_EQ( - c10::str(Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))), + c10::str( + Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))), "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])"); } @@ -4482,10 +4518,12 @@ TEST_F(ModulesTest, PrettyPrintAvgPool) { TEST_F(ModulesTest, PrettyPrinFractionalMaxPool) { ASSERT_EQ( - c10::str(FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))), + c10::str( + FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))), "torch::nn::FractionalMaxPool2d()"); ASSERT_EQ( - c10::str(FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))), + c10::str( + FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))), "torch::nn::FractionalMaxPool3d()"); } @@ -4500,7 +4538,9 @@ TEST_F(ModulesTest, PrettyPrintLPPool) { c10::str(LPPool2d(2, std::vector({1, 2}))), "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)"); ASSERT_EQ( - c10::str(LPPool2d(LPPool2dOptions(1, std::vector({3, 4})).stride({5, 6}).ceil_mode(true))), + c10::str(LPPool2d(LPPool2dOptions(1, std::vector({3, 4})) + .stride({5, 6}) + .ceil_mode(true))), "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)"); } @@ -4524,7 +4564,8 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) { c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, c10::nullopt}))), "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])"); ASSERT_EQ( - c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({c10::nullopt, c10::nullopt}))), + c10::str(AdaptiveMaxPool2d( + AdaptiveMaxPool2dOptions({c10::nullopt, c10::nullopt}))), "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])"); ASSERT_EQ( @@ -4534,10 +4575,12 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) { c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))), "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])"); ASSERT_EQ( - c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, c10::nullopt, 7}))), + c10::str( + AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, c10::nullopt, 7}))), "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])"); ASSERT_EQ( - c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({c10::nullopt, c10::nullopt, c10::nullopt}))), + c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions( + {c10::nullopt, c10::nullopt, c10::nullopt}))), "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])"); } @@ -4556,7 +4599,8 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) { c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, c10::nullopt}))), "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])"); ASSERT_EQ( - c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({c10::nullopt, c10::nullopt}))), + c10::str(AdaptiveAvgPool2d( + AdaptiveAvgPool2dOptions({c10::nullopt, c10::nullopt}))), "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])"); ASSERT_EQ( @@ -4566,10 +4610,12 @@ TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) { c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))), "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])"); ASSERT_EQ( - c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, c10::nullopt, 7}))), + c10::str( + AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, c10::nullopt, 7}))), "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])"); ASSERT_EQ( - c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({c10::nullopt, c10::nullopt, c10::nullopt}))), + c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions( + {c10::nullopt, c10::nullopt, c10::nullopt}))), "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])"); } @@ -4588,26 +4634,39 @@ TEST_F(ModulesTest, PrettyPrintMaxUnpool) { c10::str(MaxUnpool2d(std::vector{5, 6})), "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])"); ASSERT_EQ( - c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector{5, 6}).stride({3, 4}).padding({1, 2}))), + c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector{5, 6}) + .stride({3, 4}) + .padding({1, 2}))), "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])"); } TEST_F(ModulesTest, PrettyPrintDropout) { ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)"); - ASSERT_EQ(c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)"); - ASSERT_EQ(c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))), "torch::nn::Dropout(p=0.42, inplace=true)"); + ASSERT_EQ( + c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)"); + ASSERT_EQ( + c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))), + "torch::nn::Dropout(p=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintDropout2d) { - ASSERT_EQ(c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)"); - ASSERT_EQ(c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)"); - ASSERT_EQ(c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))), "torch::nn::Dropout2d(p=0.42, inplace=true)"); + ASSERT_EQ( + c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)"); + ASSERT_EQ( + c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)"); + ASSERT_EQ( + c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))), + "torch::nn::Dropout2d(p=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintDropout3d) { - ASSERT_EQ(c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)"); - ASSERT_EQ(c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)"); - ASSERT_EQ(c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), "torch::nn::Dropout3d(p=0.42, inplace=true)"); + ASSERT_EQ( + c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)"); + ASSERT_EQ( + c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)"); + ASSERT_EQ( + c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), + "torch::nn::Dropout3d(p=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintFunctional) { @@ -4616,76 +4675,90 @@ TEST_F(ModulesTest, PrettyPrintFunctional) { TEST_F(ModulesTest, PrettyPrintBatchNorm1d) { ASSERT_EQ( - c10::str(BatchNorm1d( - BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(BatchNorm1d(BatchNorm1dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintBatchNorm2d) { ASSERT_EQ( - c10::str(BatchNorm2d( - BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(BatchNorm2d(BatchNorm2dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintBatchNorm3d) { ASSERT_EQ( - c10::str(BatchNorm3d( - BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(BatchNorm3d(BatchNorm3dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) { ASSERT_EQ( - c10::str(InstanceNorm1d( - InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(InstanceNorm1d(InstanceNorm1dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) { ASSERT_EQ( - c10::str(InstanceNorm2d( - InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(InstanceNorm2d(InstanceNorm2dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) { ASSERT_EQ( - c10::str(InstanceNorm3d( - InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false) - .track_running_stats(true))), + c10::str(InstanceNorm3d(InstanceNorm3dOptions(4) + .eps(0.5) + .momentum(0.1) + .affine(false) + .track_running_stats(true))), "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)"); } TEST_F(ModulesTest, PrettyPrintLayerNorm) { ASSERT_EQ( - c10::str(LayerNorm(LayerNormOptions({2, 2}))), + c10::str(LayerNorm(LayerNormOptions({2, 2}))), "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)"); - ASSERT_EQ( - c10::str(LayerNorm(LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))), - "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)"); + ASSERT_EQ( + c10::str(LayerNorm( + LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))), + "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)"); } TEST_F(ModulesTest, PrettyPrintGroupNorm) { ASSERT_EQ( - c10::str(GroupNorm(GroupNormOptions(2, 2))), - "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)"); + c10::str(GroupNorm(GroupNormOptions(2, 2))), + "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)"); ASSERT_EQ( - c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))), - "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)"); + c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))), + "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)"); } TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) { ASSERT_EQ( - c10::str(LocalResponseNorm(LocalResponseNormOptions(2))), + c10::str(LocalResponseNorm(LocalResponseNormOptions(2))), "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)"); ASSERT_EQ( - c10::str(LocalResponseNorm(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))), + c10::str(LocalResponseNorm( + LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))), "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)"); } @@ -4697,7 +4770,12 @@ TEST_F(ModulesTest, PrettyPrintEmbedding) { c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))), "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)"); ASSERT_EQ( - c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))), + c10::str(Embedding(EmbeddingOptions(10, 2) + .padding_idx(3) + .max_norm(2) + .norm_type(2.5) + .scale_grad_by_freq(true) + .sparse(true))), "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"); } @@ -4709,35 +4787,42 @@ TEST_F(ModulesTest, PrettyPrintEmbeddingBag) { c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))), "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)"); ASSERT_EQ( - c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))), + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) + .max_norm(2) + .norm_type(2.5) + .scale_grad_by_freq(true) + .sparse(true))), "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"); ASSERT_EQ( - c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum))), + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) + .max_norm(2) + .norm_type(2.5) + .scale_grad_by_freq(true) + .sparse(true) + .mode(torch::kSum))), "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)"); ASSERT_EQ( - c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(5))), + c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) + .max_norm(2) + .norm_type(2.5) + .scale_grad_by_freq(true) + .sparse(true) + .mode(torch::kSum) + .padding_idx(5))), "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)"); } TEST_F(ModulesTest, PrettyPrintL1Loss) { - ASSERT_EQ( - c10::str(L1Loss()), - "torch::nn::L1Loss()"); + ASSERT_EQ(c10::str(L1Loss()), "torch::nn::L1Loss()"); } TEST_F(ModulesTest, PrettyPrintKLDivLoss) { - ASSERT_EQ( - c10::str(KLDivLoss()), - "torch::nn::KLDivLoss()"); + ASSERT_EQ(c10::str(KLDivLoss()), "torch::nn::KLDivLoss()"); } TEST_F(ModulesTest, PrettyPrintMSELoss) { - ASSERT_EQ( - c10::str(MSELoss()), - "torch::nn::MSELoss()"); + ASSERT_EQ(c10::str(MSELoss()), "torch::nn::MSELoss()"); } TEST_F(ModulesTest, PrettyPrintBCELoss) { - ASSERT_EQ( - c10::str(BCELoss()), - "torch::nn::BCELoss()"); + ASSERT_EQ(c10::str(BCELoss()), "torch::nn::BCELoss()"); } TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) { ASSERT_EQ( @@ -4753,7 +4838,8 @@ TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) { TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) { ASSERT_EQ( - c10::str(TripletMarginLoss(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))), + c10::str(TripletMarginLoss( + TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))), "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)"); } @@ -4772,21 +4858,22 @@ TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) { } TEST_F(ModulesTest, PrettyPrintNLLLoss) { - ASSERT_EQ( - c10::str(NLLLoss()), "torch::nn::NLLLoss()"); + ASSERT_EQ(c10::str(NLLLoss()), "torch::nn::NLLLoss()"); } TEST_F(ModulesTest, PrettyPrinCrossEntropyLoss) { - ASSERT_EQ( - c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()"); + ASSERT_EQ(c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()"); } TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) { - ASSERT_EQ(c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()"); + ASSERT_EQ( + c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()"); } TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) { - ASSERT_EQ(c10::str(MultiLabelSoftMarginLoss()), "torch::nn::MultiLabelSoftMarginLoss()"); + ASSERT_EQ( + c10::str(MultiLabelSoftMarginLoss()), + "torch::nn::MultiLabelSoftMarginLoss()"); } TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) { @@ -4807,7 +4894,8 @@ TEST_F(ModulesTest, PrettyPrintPairwiseDistance) { c10::str(PairwiseDistance()), "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)"); ASSERT_EQ( - c10::str(PairwiseDistance(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))), + c10::str(PairwiseDistance( + PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))), "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)"); } @@ -4892,7 +4980,9 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) { TestModule() : torch::nn::Module("TestModule"), fc(register_module("fc", torch::nn::Linear(4, 5))), - table(register_module("table", torch::nn::Embedding(EmbeddingOptions(10, 2)))), + table(register_module( + "table", + torch::nn::Embedding(EmbeddingOptions(10, 2)))), inner(register_module("inner", std::make_shared())) { } @@ -4915,14 +5005,16 @@ TEST_F(ModulesTest, PrettyPrintNestedModel) { TEST_F(ModulesTest, PrettyPrintELU) { ASSERT_EQ(c10::str(ELU()), "torch::nn::ELU(alpha=1)"); - ASSERT_EQ(c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))), - "torch::nn::ELU(alpha=42.42, inplace=true)"); + ASSERT_EQ( + c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))), + "torch::nn::ELU(alpha=42.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintSELU) { ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()"); - ASSERT_EQ(c10::str(SELU(SELUOptions().inplace(true))), - "torch::nn::SELU(inplace=true)"); + ASSERT_EQ( + c10::str(SELU(SELUOptions().inplace(true))), + "torch::nn::SELU(inplace=true)"); } TEST_F(ModulesTest, PrettyPrintGLU) { @@ -4932,24 +5024,25 @@ TEST_F(ModulesTest, PrettyPrintGLU) { TEST_F(ModulesTest, PrettyPrintHardshrink) { ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)"); - ASSERT_EQ(c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))), - "torch::nn::Hardshrink(42.42)"); + ASSERT_EQ( + c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))), + "torch::nn::Hardshrink(42.42)"); } TEST_F(ModulesTest, PrettyPrintHardtanh) { - ASSERT_EQ(c10::str(Hardtanh()), - "torch::nn::Hardtanh(min_val=-1, max_val=1)"); - ASSERT_EQ(c10::str(Hardtanh( - HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))), - "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)"); + ASSERT_EQ(c10::str(Hardtanh()), "torch::nn::Hardtanh(min_val=-1, max_val=1)"); + ASSERT_EQ( + c10::str(Hardtanh( + HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))), + "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintLeakyReLU) { - ASSERT_EQ(c10::str(LeakyReLU()), - "torch::nn::LeakyReLU(negative_slope=0.01)"); - ASSERT_EQ(c10::str(LeakyReLU( - LeakyReLUOptions().negative_slope(0.42).inplace(true))), - "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)"); + ASSERT_EQ(c10::str(LeakyReLU()), "torch::nn::LeakyReLU(negative_slope=0.01)"); + ASSERT_EQ( + c10::str( + LeakyReLU(LeakyReLUOptions().negative_slope(0.42).inplace(true))), + "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintLogSigmoid) { @@ -4965,8 +5058,9 @@ TEST_F(ModulesTest, PrettyPrintSoftmin) { } TEST_F(ModulesTest, PrettyPrintLogSoftmax) { - ASSERT_EQ(c10::str(LogSoftmax(LogSoftmaxOptions(1))), - "torch::nn::LogSoftmax(dim=1)"); + ASSERT_EQ( + c10::str(LogSoftmax(LogSoftmaxOptions(1))), + "torch::nn::LogSoftmax(dim=1)"); } TEST_F(ModulesTest, PrettyPrintSoftmax2d) { @@ -4975,38 +5069,40 @@ TEST_F(ModulesTest, PrettyPrintSoftmax2d) { TEST_F(ModulesTest, PrettyPrintPReLU) { ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)"); - ASSERT_EQ(c10::str(PReLU(PReLUOptions().num_parameters(42))), - "torch::nn::PReLU(num_parameters=42)"); + ASSERT_EQ( + c10::str(PReLU(PReLUOptions().num_parameters(42))), + "torch::nn::PReLU(num_parameters=42)"); } TEST_F(ModulesTest, PrettyPrintReLU) { ASSERT_EQ(c10::str(ReLU()), "torch::nn::ReLU()"); - ASSERT_EQ(c10::str(ReLU(ReLUOptions().inplace(true))), - "torch::nn::ReLU(inplace=true)"); - ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), - "torch::nn::ReLU(inplace=true)"); + ASSERT_EQ( + c10::str(ReLU(ReLUOptions().inplace(true))), + "torch::nn::ReLU(inplace=true)"); + ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), "torch::nn::ReLU(inplace=true)"); } TEST_F(ModulesTest, PrettyPrintReLU6) { ASSERT_EQ(c10::str(ReLU6()), "torch::nn::ReLU6()"); - ASSERT_EQ(c10::str(ReLU6(ReLU6Options().inplace(true))), - "torch::nn::ReLU6(inplace=true)"); - ASSERT_EQ(c10::str(ReLU6(/*inplace=*/true)), - "torch::nn::ReLU6(inplace=true)"); + ASSERT_EQ( + c10::str(ReLU6(ReLU6Options().inplace(true))), + "torch::nn::ReLU6(inplace=true)"); + ASSERT_EQ( + c10::str(ReLU6(/*inplace=*/true)), "torch::nn::ReLU6(inplace=true)"); } TEST_F(ModulesTest, PrettyPrintRReLU) { - ASSERT_EQ(c10::str(RReLU()), - "torch::nn::RReLU(lower=0.125, upper=0.333333)"); - ASSERT_EQ(c10::str(RReLU( - RReLUOptions().lower(0.24).upper(0.42).inplace(true))), - "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)"); + ASSERT_EQ(c10::str(RReLU()), "torch::nn::RReLU(lower=0.125, upper=0.333333)"); + ASSERT_EQ( + c10::str(RReLU(RReLUOptions().lower(0.24).upper(0.42).inplace(true))), + "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintCELU) { ASSERT_EQ(c10::str(CELU()), "torch::nn::CELU(alpha=1)"); - ASSERT_EQ(c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))), - "torch::nn::CELU(alpha=42.42, inplace=true)"); + ASSERT_EQ( + c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))), + "torch::nn::CELU(alpha=42.42, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintSigmoid) { @@ -5014,8 +5110,9 @@ TEST_F(ModulesTest, PrettyPrintSigmoid) { } TEST_F(ModulesTest, PrettyPrintPixelShuffle) { - ASSERT_EQ(c10::str(PixelShuffle(PixelShuffleOptions(5))), - "torch::nn::PixelShuffle(upscale_factor=5)"); + ASSERT_EQ( + c10::str(PixelShuffle(PixelShuffleOptions(5))), + "torch::nn::PixelShuffle(upscale_factor=5)"); } TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) { @@ -5025,17 +5122,17 @@ TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) { } TEST_F(ModulesTest, PrettyPrintSoftplus) { - ASSERT_EQ(c10::str(Softplus()), - "torch::nn::Softplus(beta=1, threshold=20)"); - ASSERT_EQ(c10::str(Softplus( - SoftplusOptions().beta(0.24).threshold(42.42))), - "torch::nn::Softplus(beta=0.24, threshold=42.42)"); + ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)"); + ASSERT_EQ( + c10::str(Softplus(SoftplusOptions().beta(0.24).threshold(42.42))), + "torch::nn::Softplus(beta=0.24, threshold=42.42)"); } TEST_F(ModulesTest, PrettyPrintSoftshrink) { ASSERT_EQ(c10::str(Softshrink()), "torch::nn::Softshrink(0.5)"); - ASSERT_EQ(c10::str(Softshrink(SoftshrinkOptions(42.42))), - "torch::nn::Softshrink(42.42)"); + ASSERT_EQ( + c10::str(Softshrink(SoftshrinkOptions(42.42))), + "torch::nn::Softshrink(42.42)"); } TEST_F(ModulesTest, PrettyPrintSoftsign) { @@ -5051,131 +5148,158 @@ TEST_F(ModulesTest, PrettyPrintTanhshrink) { } TEST_F(ModulesTest, PrettyPrintThreshold) { - ASSERT_EQ(c10::str(Threshold(24.24, 42.42)), - "torch::nn::Threshold(threshold=24.24, value=42.42)"); - ASSERT_EQ(c10::str(Threshold( - ThresholdOptions(42.42, 24.24).inplace(true))), - "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)"); + ASSERT_EQ( + c10::str(Threshold(24.24, 42.42)), + "torch::nn::Threshold(threshold=24.24, value=42.42)"); + ASSERT_EQ( + c10::str(Threshold(ThresholdOptions(42.42, 24.24).inplace(true))), + "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintCTCLoss) { ASSERT_EQ(c10::str(CTCLoss()), "torch::nn::CTCLoss()"); - ASSERT_EQ(c10::str(CTCLoss( - CTCLossOptions().blank(42).zero_infinity(false) - .reduction(torch::kSum))), "torch::nn::CTCLoss()"); + ASSERT_EQ( + c10::str( + CTCLoss(CTCLossOptions().blank(42).zero_infinity(false).reduction( + torch::kSum))), + "torch::nn::CTCLoss()"); } TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) { ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()"); - ASSERT_EQ(c10::str(PoissonNLLLoss( - PoissonNLLLossOptions().log_input(false).full(true).eps(0.42) - .reduction(torch::kSum))), - "torch::nn::PoissonNLLLoss()"); + ASSERT_EQ( + c10::str(PoissonNLLLoss(PoissonNLLLossOptions() + .log_input(false) + .full(true) + .eps(0.42) + .reduction(torch::kSum))), + "torch::nn::PoissonNLLLoss()"); } TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) { ASSERT_EQ(c10::str(MarginRankingLoss()), "torch::nn::MarginRankingLoss()"); - ASSERT_EQ(c10::str(MarginRankingLoss( - MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))), - "torch::nn::MarginRankingLoss()"); + ASSERT_EQ( + c10::str(MarginRankingLoss( + MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))), + "torch::nn::MarginRankingLoss()"); } TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) { - ASSERT_EQ(c10::str(CrossMapLRN2d(4)), - "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)"); - ASSERT_EQ(c10::str(CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))), - "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)"); + ASSERT_EQ( + c10::str(CrossMapLRN2d(4)), + "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)"); + ASSERT_EQ( + c10::str( + CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))), + "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)"); } TEST_F(ModulesTest, PrettyPrintAlphaDropout) { - ASSERT_EQ(c10::str(AlphaDropout()), - "torch::nn::AlphaDropout(p=0.5, inplace=false)"); - ASSERT_EQ(c10::str(AlphaDropout(AlphaDropoutOptions(0.2))), - "torch::nn::AlphaDropout(p=0.2, inplace=false)"); - ASSERT_EQ(c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))), - "torch::nn::AlphaDropout(p=0.2, inplace=true)"); + ASSERT_EQ( + c10::str(AlphaDropout()), + "torch::nn::AlphaDropout(p=0.5, inplace=false)"); + ASSERT_EQ( + c10::str(AlphaDropout(AlphaDropoutOptions(0.2))), + "torch::nn::AlphaDropout(p=0.2, inplace=false)"); + ASSERT_EQ( + c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))), + "torch::nn::AlphaDropout(p=0.2, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) { - ASSERT_EQ(c10::str(FeatureAlphaDropout()), - "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)"); - ASSERT_EQ(c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))), - "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)"); - ASSERT_EQ(c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))), - "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)"); + ASSERT_EQ( + c10::str(FeatureAlphaDropout()), + "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)"); + ASSERT_EQ( + c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))), + "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)"); + ASSERT_EQ( + c10::str( + FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))), + "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)"); } TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) { ASSERT_EQ(c10::str(BCEWithLogitsLoss()), "torch::nn::BCEWithLogitsLoss()"); - ASSERT_EQ(c10::str(BCEWithLogitsLoss( - BCEWithLogitsLossOptions() - .weight(torch::ones({3, 3})) - .pos_weight(torch::ones({3, 3})) - .reduction(torch::kSum))), - "torch::nn::BCEWithLogitsLoss()"); + ASSERT_EQ( + c10::str(BCEWithLogitsLoss(BCEWithLogitsLossOptions() + .weight(torch::ones({3, 3})) + .pos_weight(torch::ones({3, 3})) + .reduction(torch::kSum))), + "torch::nn::BCEWithLogitsLoss()"); } TEST_F(ModulesTest, PrettyPrintMultiheadAttention) { - ASSERT_EQ(c10::str(MultiheadAttention(20, 10)), - "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)"); - ASSERT_EQ(c10::str(MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))), - "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)"); + ASSERT_EQ( + c10::str(MultiheadAttention(20, 10)), + "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)"); + ASSERT_EQ( + c10::str( + MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))), + "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)"); } TEST_F(ModulesTest, PrettyPrintRNNCell) { - ASSERT_EQ(c10::str(RNNCell(20, 10)), - "torch::nn::RNNCell(20, 10)"); - ASSERT_EQ(c10::str(RNNCell(RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))), - "torch::nn::RNNCell(20, 10, bias=false)"); - ASSERT_EQ(c10::str(RNNCell(RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))), - "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)"); + ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)"); + ASSERT_EQ( + c10::str(RNNCell( + RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))), + "torch::nn::RNNCell(20, 10, bias=false)"); + ASSERT_EQ( + c10::str(RNNCell( + RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))), + "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)"); } TEST_F(ModulesTest, PrettyPrintLSTMCell) { - ASSERT_EQ(c10::str(LSTMCell(20, 10)), - "torch::nn::LSTMCell(20, 10)"); - ASSERT_EQ(c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))), - "torch::nn::LSTMCell(20, 10, bias=false)"); + ASSERT_EQ(c10::str(LSTMCell(20, 10)), "torch::nn::LSTMCell(20, 10)"); + ASSERT_EQ( + c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))), + "torch::nn::LSTMCell(20, 10, bias=false)"); } TEST_F(ModulesTest, PrettyPrintGRUCell) { - ASSERT_EQ(c10::str(GRUCell(20, 10)), - "torch::nn::GRUCell(20, 10)"); - ASSERT_EQ(c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))), - "torch::nn::GRUCell(20, 10, bias=false)"); + ASSERT_EQ(c10::str(GRUCell(20, 10)), "torch::nn::GRUCell(20, 10)"); + ASSERT_EQ( + c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))), + "torch::nn::GRUCell(20, 10, bias=false)"); } TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) { { - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); ASSERT_EQ( - c10::str(asfm), - "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" - " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n" - " (tail): torch::nn::ModuleList(\n" - " (0): torch::nn::Sequential(\n" - " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" - " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n" - " )\n" - " )\n" - ")"); + c10::str(asfm), + "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" + " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n" + " (tail): torch::nn::ModuleList(\n" + " (0): torch::nn::Sequential(\n" + " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" + " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n" + " )\n" + " )\n" + ")"); } { - AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true)); + AdaptiveLogSoftmaxWithLoss asfm( + AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}) + .div_value(2.) + .head_bias(true)); ASSERT_EQ( - c10::str(asfm), - "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" - " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n" - " (tail): torch::nn::ModuleList(\n" - " (0): torch::nn::Sequential(\n" - " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" - " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n" - " )\n" - " (1): torch::nn::Sequential(\n" - " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n" - " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n" - " )\n" - " )\n" - ")"); + c10::str(asfm), + "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" + " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n" + " (tail): torch::nn::ModuleList(\n" + " (0): torch::nn::Sequential(\n" + " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" + " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n" + " )\n" + " (1): torch::nn::Sequential(\n" + " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n" + " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n" + " )\n" + " )\n" + ")"); } } diff --git a/test/cpp/api/namespace.cpp b/test/cpp/api/namespace.cpp index 0ee5b70b9375ac..0952d09db55d38 100644 --- a/test/cpp/api/namespace.cpp +++ b/test/cpp/api/namespace.cpp @@ -4,13 +4,14 @@ struct Node {}; -// If `torch::autograd::Note` is leaked into the root namespace, the following compile error would throw: +// If `torch::autograd::Note` is leaked into the root namespace, the following +// compile error would throw: // ``` // void NotLeakingSymbolsFromTorchAutogradNamespace_test_func(Node *node) {} // ^ // error: reference to `Node` is ambiguous // ``` -void NotLeakingSymbolsFromTorchAutogradNamespace_test_func(Node *node) {} +void NotLeakingSymbolsFromTorchAutogradNamespace_test_func(Node* node) {} TEST(NamespaceTests, NotLeakingSymbolsFromTorchAutogradNamespace) { // Checks that we are not leaking symbols from the diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp index 82ed648f9e8429..3d24749a965323 100644 --- a/test/cpp/api/nn_utils.cpp +++ b/test/cpp/api/nn_utils.cpp @@ -7,8 +7,8 @@ #include #include -#include #include +#include using namespace torch::nn; @@ -138,45 +138,65 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { // // norms_finite Norm types that should produce finite total norm std::vector> test_cases({ - // Test errors from an infinite grad - std::make_tuple(false, false, Vector({inf, -inf}), norms_except_0, Vector({0})), - std::make_tuple(false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), - std::make_tuple(true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), - std::make_tuple(false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), - - // Test errors from a NaN grad - std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})), - std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})), - std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})), - std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})), - - // Test a grad that should never error - std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all), - std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all), - std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all), - std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all), - - // Test a grad that will overflow to inf for only some norm orders - std::make_tuple( - false, false, Vector({2e200, -2e200}), Vector({3.5, 2, -2, -3.5}), - Vector({inf, 1, 0.1, 0, -1, -0.1})), - std::make_tuple( - false, true, Vector({2e200, -2e200}), Vector({3.5, 2}), - Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), - std::make_tuple( - true, false, Vector({2e200, -2e200}), Vector({3.5, 2}), - Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), - std::make_tuple( - false, true, Vector({2e200, -2e200}), Vector({3.5, 2}), - Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), + // Test errors from an infinite grad + std::make_tuple( + false, false, Vector({inf, -inf}), norms_except_0, Vector({0})), + std::make_tuple( + false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), + std::make_tuple( + true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), + std::make_tuple( + false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0), + + // Test errors from a NaN grad + std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})), + std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})), + std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})), + std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})), + + // Test a grad that should never error + std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all), + std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all), + std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all), + std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all), + + // Test a grad that will overflow to inf for only some norm orders + std::make_tuple( + false, + false, + Vector({2e200, -2e200}), + Vector({3.5, 2, -2, -3.5}), + Vector({inf, 1, 0.1, 0, -1, -0.1})), + std::make_tuple( + false, + true, + Vector({2e200, -2e200}), + Vector({3.5, 2}), + Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), + std::make_tuple( + true, + false, + Vector({2e200, -2e200}), + Vector({3.5, 2}), + Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), + std::make_tuple( + false, + true, + Vector({2e200, -2e200}), + Vector({3.5, 2}), + Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})), }); - auto gen_parameters = []( - double scalar, - bool grad_only_one_elem, - bool prefix_finite_grad_param, - torch::DeviceType device_type) { - auto param = torch::ones(10, torch::TensorOptions().dtype(torch::kDouble).device(device_type).requires_grad(true)); + auto gen_parameters = [](double scalar, + bool grad_only_one_elem, + bool prefix_finite_grad_param, + torch::DeviceType device_type) { + auto param = torch::ones( + 10, + torch::TensorOptions() + .dtype(torch::kDouble) + .device(device_type) + .requires_grad(true)); if (grad_only_one_elem) { param[1].mul(scalar).sum().backward(); } else { @@ -185,7 +205,12 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { std::vector parameters; if (prefix_finite_grad_param) { - auto prefix_param = torch::ones(1, torch::TensorOptions().dtype(torch::kDouble).device(device_type).requires_grad(true)); + auto prefix_param = torch::ones( + 1, + torch::TensorOptions() + .dtype(torch::kDouble) + .device(device_type) + .requires_grad(true)); prefix_param.mul(1).sum().backward(); parameters.push_back(prefix_param); } @@ -195,16 +220,15 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { }; auto run_test_case = [&gen_parameters]( - double norm_type, - bool error_if_nonfinite, - double scalar, - bool grad_only_one_elem, - bool prefix_finite_grad_param, - bool is_norm_nonfinite, - torch::DeviceType device_type) { + double norm_type, + bool error_if_nonfinite, + double scalar, + bool grad_only_one_elem, + bool prefix_finite_grad_param, + bool is_norm_nonfinite, + torch::DeviceType device_type) { std::stringstream ss; - ss << "device: " << device_type - << ", norm_type: " << norm_type + ss << "device: " << device_type << ", norm_type: " << norm_type << ", error_if_nonfinite: " << error_if_nonfinite << ", scalar: " << scalar << ", grad_only_one_elem: " << grad_only_one_elem @@ -213,10 +237,7 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { std::string msg = ss.str(); auto parameters = gen_parameters( - scalar, - grad_only_one_elem, - prefix_finite_grad_param, - device_type); + scalar, grad_only_one_elem, prefix_finite_grad_param, device_type); if (is_norm_nonfinite && error_if_nonfinite) { std::vector grads_before; @@ -226,14 +247,25 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { grads_before.push_back(p.grad().clone()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW(utils::clip_grad_norm_(parameters, 1., norm_type, true), std::exception) << msg; + EXPECT_THROW( + utils::clip_grad_norm_(parameters, 1., norm_type, true), + std::exception) + << msg; // Grads should not change if error is thrown for (const auto p_idx : c10::irange(parameters.size())) { - ASSERT_TRUE(torch::allclose(parameters[p_idx].grad(), grads_before[p_idx], 1.0, 0.0, /*equal_nan*/ true)) << msg; + ASSERT_TRUE(torch::allclose( + parameters[p_idx].grad(), + grads_before[p_idx], + 1.0, + 0.0, + /*equal_nan*/ true)) + << msg; } } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_NO_THROW(utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite)) << msg; + EXPECT_NO_THROW( + utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite)) + << msg; } }; @@ -252,24 +284,24 @@ TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) { for (auto scalar : scalars) { for (auto norm_type : norms_nonfinite) { run_test_case( - norm_type, - error_if_nonfinite, - scalar, - grad_only_one_elem, - prefix_finite_grad_param, - true, - device_type); + norm_type, + error_if_nonfinite, + scalar, + grad_only_one_elem, + prefix_finite_grad_param, + true, + device_type); } for (auto norm_type : norms_finite) { run_test_case( - norm_type, - error_if_nonfinite, - scalar, - grad_only_one_elem, - prefix_finite_grad_param, - false, - device_type); + norm_type, + error_if_nonfinite, + scalar, + grad_only_one_elem, + prefix_finite_grad_param, + false, + device_type); } } } @@ -295,10 +327,8 @@ TEST_F(NNUtilsTest, ClipGradValue) { utils::clip_grad_value_(l->parameters(), clip_value); for (const auto& p : l->parameters()) { if (p.grad().defined()) { - ASSERT_LE( - p.grad().data().max().item().toFloat(), clip_value); - ASSERT_GE( - p.grad().data().min().item().toFloat(), -clip_value); + ASSERT_LE(p.grad().data().max().item().toFloat(), clip_value); + ASSERT_GE(p.grad().data().min().item().toFloat(), -clip_value); } } } @@ -316,24 +346,21 @@ TEST_F(NNUtilsTest, ClipGradValue) { TEST_F(NNUtilsTest, ConvertParameters) { std::vector parameters{ - torch::arange(9, torch::kFloat32), - torch::arange(9, torch::kFloat32).view({3, 3}), - torch::arange(8, torch::kFloat32).view({2, 2, 2}) - }; - - auto expected = torch::cat({ - torch::arange(9, torch::kFloat32), - torch::arange(9, torch::kFloat32).view(-1), - torch::arange(8, torch::kFloat32).view(-1) - }); + torch::arange(9, torch::kFloat32), + torch::arange(9, torch::kFloat32).view({3, 3}), + torch::arange(8, torch::kFloat32).view({2, 2, 2})}; + + auto expected = torch::cat( + {torch::arange(9, torch::kFloat32), + torch::arange(9, torch::kFloat32).view(-1), + torch::arange(8, torch::kFloat32).view(-1)}); auto vector = utils::parameters_to_vector(parameters); ASSERT_TRUE(vector.allclose(expected)); std::vector zero_parameters{ - torch::zeros({9}, torch::kFloat32), - torch::zeros({9}, torch::kFloat32).view({3, 3}), - torch::zeros({8}, torch::kFloat32).view({2, 2, 2}) - }; + torch::zeros({9}, torch::kFloat32), + torch::zeros({9}, torch::kFloat32).view({3, 3}), + torch::zeros({8}, torch::kFloat32).view({2, 2, 2})}; utils::vector_to_parameters(vector, zero_parameters); for (const auto i : c10::irange(zero_parameters.size())) { @@ -366,29 +393,30 @@ int64_t PackedSequenceTest_batch_size = 5; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables) int64_t PackedSequenceTest_max_length = 6; -std::vector PackedSequenceTest_ordered_sequence(torch::ScalarType tensor_type) { +std::vector PackedSequenceTest_ordered_sequence( + torch::ScalarType tensor_type) { std::vector seqs; seqs.reserve(PackedSequenceTest_batch_size); for (const auto i : c10::irange(PackedSequenceTest_batch_size)) { (void)i; // Suppress unused variable warning - seqs.emplace_back(torch::empty({ - torch::randint(1, PackedSequenceTest_max_length, {1}).item() - }, tensor_type)); + seqs.emplace_back(torch::empty( + {torch::randint(1, PackedSequenceTest_max_length, {1}).item()}, + tensor_type)); } for (auto& s : seqs) { s.random_(-128, 128); } sort( - seqs.begin(), - seqs.end(), - [&](const torch::Tensor& t1, const torch::Tensor& t2) { - return t1.size(0) > t2.size(0); - } - ); + seqs.begin(), + seqs.end(), + [&](const torch::Tensor& t1, const torch::Tensor& t2) { + return t1.size(0) > t2.size(0); + }); return seqs; } -std::tuple PackedSequenceTest_padded_sequence(torch::ScalarType tensor_type) { +std::tuple PackedSequenceTest_padded_sequence( + torch::ScalarType tensor_type) { // Create Tensor of random padded sequences auto ordered = PackedSequenceTest_ordered_sequence(tensor_type); auto lengths = torch::empty({(int64_t)ordered.size()}, torch::kInt64); @@ -399,18 +427,22 @@ std::tuple PackedSequenceTest_padded_sequence(torc return std::make_tuple(padded_tensor, lengths); } -void assert_is_equal_packed_sequence(const rnn_utils::PackedSequence& a, const rnn_utils::PackedSequence& b) { +void assert_is_equal_packed_sequence( + const rnn_utils::PackedSequence& a, + const rnn_utils::PackedSequence& b) { ASSERT_TRUE(torch::allclose(a.data(), b.data())); ASSERT_TRUE(torch::allclose(a.batch_sizes(), b.batch_sizes())); ASSERT_TRUE( - (!a.sorted_indices().defined() && !b.sorted_indices().defined()) || - torch::allclose(a.sorted_indices(), b.sorted_indices())); + (!a.sorted_indices().defined() && !b.sorted_indices().defined()) || + torch::allclose(a.sorted_indices(), b.sorted_indices())); ASSERT_TRUE( - (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) || - torch::allclose(a.unsorted_indices(), b.unsorted_indices())); + (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) || + torch::allclose(a.unsorted_indices(), b.unsorted_indices())); } -void assert_is_same_packed_sequence(const rnn_utils::PackedSequence& a, const rnn_utils::PackedSequence& b) { +void assert_is_same_packed_sequence( + const rnn_utils::PackedSequence& a, + const rnn_utils::PackedSequence& b) { ASSERT_TRUE(a.data().is_same(b.data())); ASSERT_TRUE(a.batch_sizes().is_same(b.batch_sizes())); ASSERT_TRUE(a.sorted_indices().is_same(b.sorted_indices())); @@ -423,52 +455,64 @@ TEST_F(PackedSequenceTest, WrongOrder) { auto b_a = rnn_utils::pad_sequence({b, a}); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_THROW( - rnn_utils::pack_padded_sequence( - b_a, torch::tensor({22, 25}), /*batch_first=*/false, /*enforce_sorted=*/true), - c10::Error); + rnn_utils::pack_padded_sequence( + b_a, + torch::tensor({22, 25}), + /*batch_first=*/false, + /*enforce_sorted=*/true), + c10::Error); } TEST_F(PackedSequenceTest, TotalLength) { torch::Tensor padded, lengths; std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kFloat); int64_t max_length = torch::max(lengths).item(); - rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence(padded, lengths); + rnn_utils::PackedSequence packed = + rnn_utils::pack_padded_sequence(padded, lengths); // test ValueError if total_length < max_length for (int64_t total_length : std::vector{-1, 0, max_length - 1}) { for (bool batch_first : std::vector{true, false}) { auto err_fn = [&]() { rnn_utils::pad_packed_sequence( - packed, - /*batch_first=*/batch_first, - /*padding_value=*/0.0, - /*total_length=*/total_length); + packed, + /*batch_first=*/batch_first, + /*padding_value=*/0.0, + /*total_length=*/total_length); }; - ASSERT_THROWS_WITH(err_fn(), - "Expected total_length to be at least the length of the longest sequence in input"); + ASSERT_THROWS_WITH( + err_fn(), + "Expected total_length to be at least the length of the longest sequence in input"); } } // test that pad_packed_sequence returns results of correct length for (bool batch_first : std::vector{true, false}) { torch::Tensor no_extra_pad, ignored; - std::tie(no_extra_pad, ignored) = rnn_utils::pad_packed_sequence( - packed, /*batch_first=*/batch_first); + std::tie(no_extra_pad, ignored) = + rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first); for (int64_t total_length_delta : std::vector{0, 1, 8}) { int64_t total_length = max_length + total_length_delta; torch::Tensor unpacked, lengths_out; std::tie(unpacked, lengths_out) = rnn_utils::pad_packed_sequence( - packed, /*batch_first=*/batch_first, /*padding_value=*/0.0, /*total_length=*/total_length); + packed, + /*batch_first=*/batch_first, + /*padding_value=*/0.0, + /*total_length=*/total_length); ASSERT_TRUE(torch::allclose(lengths, lengths_out)); ASSERT_EQ(unpacked.size(batch_first ? 1 : 0), total_length); torch::Tensor ref_output, extra_pad; if (total_length_delta == 0) { ref_output = no_extra_pad; } else if (batch_first) { - extra_pad = torch::zeros({PackedSequenceTest_batch_size, total_length_delta}, no_extra_pad.options()); + extra_pad = torch::zeros( + {PackedSequenceTest_batch_size, total_length_delta}, + no_extra_pad.options()); ref_output = torch::cat({no_extra_pad, extra_pad}, 1); } else { - extra_pad = torch::zeros({total_length_delta, PackedSequenceTest_batch_size}, no_extra_pad.options()); + extra_pad = torch::zeros( + {total_length_delta, PackedSequenceTest_batch_size}, + no_extra_pad.options()); ref_output = torch::cat({no_extra_pad, extra_pad}, 0); } ASSERT_TRUE(torch::allclose(unpacked, ref_output)); @@ -481,11 +525,16 @@ TEST_F(PackedSequenceTest, To) { torch::Tensor padded, lengths; std::tie(padded, lengths) = PackedSequenceTest_padded_sequence(torch::kInt); rnn_utils::PackedSequence a = rnn_utils::pack_padded_sequence( - padded, lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted).cpu(); + padded, + lengths, + /*batch_first=*/false, + /*enforce_sorted=*/enforce_sorted) + .cpu(); assert_is_same_packed_sequence(a, a.to(torch::kCPU)); assert_is_same_packed_sequence(a, a.cpu()); - assert_is_same_packed_sequence(a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32))); + assert_is_same_packed_sequence( + a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32))); if (torch::cuda::is_available()) { auto b = a.cuda(); @@ -493,21 +542,23 @@ TEST_F(PackedSequenceTest, To) { assert_is_same_packed_sequence(b, b.cuda()); assert_is_equal_packed_sequence(a, b.to(torch::kCPU)); assert_is_equal_packed_sequence(b, a.to(torch::kCUDA)); - assert_is_equal_packed_sequence(a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32))); + assert_is_equal_packed_sequence( + a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32))); assert_is_same_packed_sequence(b, b.to(torch::kInt32)); } } } TEST_F(NNUtilsTest, PackSequence) { - auto _compatibility_test = [&]( - torch::ArrayRef sequences, - torch::Tensor lengths, - bool batch_first, - bool enforce_sorted = false) { + auto _compatibility_test = [&](torch::ArrayRef sequences, + torch::Tensor lengths, + bool batch_first, + bool enforce_sorted = false) { torch::Tensor padded = rnn_utils::pad_sequence(sequences, batch_first); - rnn_utils::PackedSequence packed = rnn_utils::pack_sequence(sequences, enforce_sorted); - std::tuple unpacked = rnn_utils::pad_packed_sequence(packed, batch_first); + rnn_utils::PackedSequence packed = + rnn_utils::pack_sequence(sequences, enforce_sorted); + std::tuple unpacked = + rnn_utils::pad_packed_sequence(packed, batch_first); ASSERT_TRUE(torch::allclose(padded, std::get<0>(unpacked))); rnn_utils::PackedSequence pack_padded = rnn_utils::pack_padded_sequence( padded, lengths, batch_first, enforce_sorted); @@ -518,33 +569,42 @@ TEST_F(NNUtilsTest, PackSequence) { auto a = torch::tensor({1, 2, 3}); auto b = torch::tensor({4, 5}); auto c = torch::tensor({6}); - rnn_utils::PackedSequence packed = rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false); + rnn_utils::PackedSequence packed = + rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false); auto expected = torch::tensor({1, 4, 6, 2, 5, 3}); ASSERT_TRUE(torch::allclose(packed.batch_sizes(), torch::tensor({3, 2, 1}))); ASSERT_TRUE(torch::allclose(packed.data(), expected)); - ASSERT_TRUE(torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2}))); - ASSERT_TRUE(torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2}))); + ASSERT_TRUE( + torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2}))); + ASSERT_TRUE( + torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2}))); - rnn_utils::PackedSequence packed_unsorted = rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false); - ASSERT_TRUE(torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1}))); + rnn_utils::PackedSequence packed_unsorted = + rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false); + ASSERT_TRUE( + torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1}))); ASSERT_TRUE(torch::allclose(packed_unsorted.data(), expected)); - ASSERT_TRUE(torch::allclose(packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1}))); - ASSERT_TRUE(torch::allclose(packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0}))); + ASSERT_TRUE(torch::allclose( + packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1}))); + ASSERT_TRUE(torch::allclose( + packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0}))); // single dimensional, enforce_sorted = True - rnn_utils::PackedSequence packed_enforce_sorted = rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true); - ASSERT_TRUE(torch::allclose(packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1}))); + rnn_utils::PackedSequence packed_enforce_sorted = + rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true); + ASSERT_TRUE(torch::allclose( + packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1}))); ASSERT_TRUE(torch::allclose(packed_enforce_sorted.data(), expected)); ASSERT_FALSE(packed_enforce_sorted.sorted_indices().defined()); ASSERT_FALSE(packed_enforce_sorted.unsorted_indices().defined()); ASSERT_THROWS_WITH( - rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true), - "must be sorted in decreasing order"); + rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true), + "must be sorted in decreasing order"); ASSERT_THROWS_WITH( - rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true), - "You can pass `enforce_sorted=False`"); + rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true), + "You can pass `enforce_sorted=False`"); // more dimensions int64_t maxlen = 9; @@ -557,9 +617,7 @@ TEST_F(NNUtilsTest, PackSequence) { lengths_vec.emplace_back(seq_len); std::vector tensor_sizes{seq_len, 5}; tensor_sizes.insert( - tensor_sizes.end(), - trailing_dims.begin(), - trailing_dims.end()); + tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end()); sequences.emplace_back(torch::rand(tensor_sizes)); } std::vector unsorted_sequences; @@ -568,9 +626,9 @@ TEST_F(NNUtilsTest, PackSequence) { unsorted_sequences.emplace_back(s.clone()); } std::shuffle( - std::begin(unsorted_sequences), - std::end(unsorted_sequences), - std::default_random_engine{}); + std::begin(unsorted_sequences), + std::end(unsorted_sequences), + std::default_random_engine{}); std::vector unsorted_sequences_lengths_vec; for (const auto& t : unsorted_sequences) { @@ -582,24 +640,25 @@ TEST_F(NNUtilsTest, PackSequence) { for (bool batch_first : std::vector{true, false}) { for (bool enforce_sorted : std::vector{true, false}) { _compatibility_test( - sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted); + sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted); } _compatibility_test( - unsorted_sequences, torch::tensor(unsorted_sequences_lengths_vec), batch_first); + unsorted_sequences, + torch::tensor(unsorted_sequences_lengths_vec), + batch_first); } } } TEST_F(NNUtilsTest, PackPaddedSequence) { - auto generate_test_case = [&]( - torch::ArrayRef sorted_lengths, - bool should_shuffle) { + auto generate_test_case = [&](torch::ArrayRef sorted_lengths, + bool should_shuffle) { auto pad = [&](torch::Tensor tensor, int64_t length) { std::vector tensor_sizes{length - tensor.size(0)}; tensor_sizes.insert( - tensor_sizes.end(), - tensor.sizes().slice(1).begin(), - tensor.sizes().slice(1).end()); + tensor_sizes.end(), + tensor.sizes().slice(1).begin(), + tensor.sizes().slice(1).end()); return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())}); }; int64_t max_length = sorted_lengths[0]; @@ -611,19 +670,22 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { total++; } } - batch_sizes[i-1] = total; + batch_sizes[i - 1] = total; } std::vector tensors_to_be_cat; - for (int64_t i = 1; i < static_cast(sorted_lengths.size() + 1); i++) { - int64_t l = sorted_lengths.at(i-1); - tensors_to_be_cat.emplace_back(pad(i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length)); + for (int64_t i = 1; i < static_cast(sorted_lengths.size() + 1); + i++) { + int64_t l = sorted_lengths.at(i - 1); + tensors_to_be_cat.emplace_back(pad( + i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length)); } auto padded = torch::cat(tensors_to_be_cat, 1); std::vector expected_data_vec; for (const auto n : c10::irange(batch_sizes.size(0))) { int64_t batch_size = batch_sizes[n].item(); for (const auto i : c10::irange(batch_size)) { - expected_data_vec.emplace_back(torch::arange(1., 6) + (i + 1) * 100 + 5 * n); + expected_data_vec.emplace_back( + torch::arange(1., 6) + (i + 1) * 100 + 5 * n); } } auto expected_data = torch::stack(expected_data_vec, /*dim=*/0); @@ -636,9 +698,9 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { permutation.emplace_back(i); } std::shuffle( - std::begin(permutation), - std::end(permutation), - std::default_random_engine{}); + std::begin(permutation), + std::end(permutation), + std::default_random_engine{}); unsorted_indices = torch::tensor(permutation); padded = padded.index_select(1, unsorted_indices); @@ -649,15 +711,18 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { } return std::make_tuple( - padded.requires_grad_(), lengths, expected_data, batch_sizes, unsorted_indices); + padded.requires_grad_(), + lengths, + expected_data, + batch_sizes, + unsorted_indices); }; std::vector, bool>> test_cases = { - // sorted_lengths, should_shuffle - {{10, 8, 4, 2, 2, 2, 1}, false}, - {{11, 10, 8, 6, 4, 3, 1}, false}, - {{11, 10, 8, 6, 4, 3, 1}, true} - }; + // sorted_lengths, should_shuffle + {{10, 8, 4, 2, 2, 2, 1}, false}, + {{11, 10, 8, 6, 4, 3, 1}, false}, + {{11, 10, 8, 6, 4, 3, 1}, true}}; for (const auto& test_case : test_cases) { for (bool batch_first : std::vector{true, false}) { @@ -665,9 +730,10 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { std::vector sorted_lengths = std::get<0>(test_case); bool should_shuffle = std::get<1>(test_case); - torch::Tensor padded, lengths, expected_data, batch_sizes, unsorted_indices; - std::tie(padded, lengths, expected_data, batch_sizes, unsorted_indices) = generate_test_case( - sorted_lengths, should_shuffle); + torch::Tensor padded, lengths, expected_data, batch_sizes, + unsorted_indices; + std::tie(padded, lengths, expected_data, batch_sizes, unsorted_indices) = + generate_test_case(sorted_lengths, should_shuffle); auto src = padded; if (batch_first) { @@ -676,16 +742,21 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { // check output rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence( - src, lengths, /*batch_first=*/batch_first, /*enforce_sorted=*/!should_shuffle); + src, + lengths, + /*batch_first=*/batch_first, + /*enforce_sorted=*/!should_shuffle); ASSERT_TRUE(torch::allclose(packed.data(), expected_data)); ASSERT_TRUE(torch::allclose(packed.batch_sizes(), batch_sizes)); ASSERT_TRUE( - (!packed.unsorted_indices().defined() && !unsorted_indices.defined()) || - torch::allclose(packed.unsorted_indices(), unsorted_indices)); + (!packed.unsorted_indices().defined() && + !unsorted_indices.defined()) || + torch::allclose(packed.unsorted_indices(), unsorted_indices)); // test inverse torch::Tensor unpacked, unpacked_len; - std::tie(unpacked, unpacked_len) = rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first); + std::tie(unpacked, unpacked_len) = + rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first); ASSERT_TRUE(torch::allclose(unpacked, src)); ASSERT_TRUE(torch::allclose(unpacked_len, lengths)); @@ -706,21 +777,29 @@ TEST_F(NNUtilsTest, PackPaddedSequence) { for (const auto i : c10::irange(lengths.size(0))) { int64_t l = lengths[i].item(); ASSERT_TRUE(torch::allclose( - padded.grad().narrow(0, 0, l).select(1, i), - grad_output.narrow(0, 0, l).select(1, i))); + padded.grad().narrow(0, 0, l).select(1, i), + grad_output.narrow(0, 0, l).select(1, i))); if (l < 10) { ASSERT_EQ( - padded.grad().narrow(0, l, padded.grad().size(0) - l).select(1, i).abs().sum().item(), - 0); + padded.grad() + .narrow(0, l, padded.grad().size(0) - l) + .select(1, i) + .abs() + .sum() + .item(), + 0); } } } } // test error messages - ASSERT_THROWS_WITH(rnn_utils::pack_padded_sequence(torch::randn({3, 3}), torch::tensor({1, 3, 2})), + ASSERT_THROWS_WITH( + rnn_utils::pack_padded_sequence( + torch::randn({3, 3}), torch::tensor({1, 3, 2})), "You can pass `enforce_sorted=False`"); - ASSERT_THROWS_WITH(rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})), + ASSERT_THROWS_WITH( + rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})), "empty tensor"); } @@ -729,9 +808,9 @@ TEST_F(NNUtilsTest, PadSequence) { torch::NoGradGuard no_grad; std::vector tensor_sizes{length - tensor.size(0)}; tensor_sizes.insert( - tensor_sizes.end(), - tensor.sizes().slice(1).begin(), - tensor.sizes().slice(1).end()); + tensor_sizes.end(), + tensor.sizes().slice(1).begin(), + tensor.sizes().slice(1).end()); return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())}); }; @@ -770,15 +849,13 @@ TEST_F(NNUtilsTest, PadSequence) { int64_t seq_len = i * i; std::vector tensor_sizes{seq_len, 5}; tensor_sizes.insert( - tensor_sizes.end(), - trailing_dims.begin(), - trailing_dims.end()); + tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end()); sequences.emplace_back(torch::rand(tensor_sizes)); } std::shuffle( - std::begin(sequences), - std::end(sequences), - std::default_random_engine{}); + std::begin(sequences), + std::end(sequences), + std::default_random_engine{}); std::vector expected_tensors; for (const torch::Tensor& seq : sequences) { // NOLINTNEXTLINE(performance-inefficient-vector-operation) diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp index 02ef6aadbcd57f..6bdc23e0f9bab4 100644 --- a/test/cpp/api/optim.cpp +++ b/test/cpp/api/optim.cpp @@ -44,7 +44,10 @@ bool test_optimizer_xor(Options options) { inputs.set_requires_grad(true); - auto step = [&](OptimizerClass& optimizer, Sequential model, torch::Tensor inputs, torch::Tensor labels) { + auto step = [&](OptimizerClass& optimizer, + Sequential model, + torch::Tensor inputs, + torch::Tensor labels) { auto closure = [&]() { optimizer.zero_grad(); auto x = model->forward(inputs); @@ -102,16 +105,24 @@ void check_exact_values( assign_parameter( parameters, "0.weight", - torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64)); + torch::tensor( + {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, + torch::kFloat64)); assign_parameter( - parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64)); + parameters, + "0.bias", + torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64)); + assign_parameter( + parameters, + "2.weight", + torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64)); assign_parameter( - parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64)); - assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64)); + parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64)); auto optimizer = OptimizerClass(parameters.values(), options); torch::Tensor input = - torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64).reshape({3, 2}); + torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64) + .reshape({3, 2}); for (const auto i : c10::irange(kIterations)) { optimizer.zero_grad(); @@ -127,9 +138,11 @@ void check_exact_values( expected_parameters.at(i / kSampleEvery).size() == parameters.size()); for (const auto p : c10::irange(parameters.size())) { ASSERT_TRUE(parameters[p]->defined()); - // Always compare using double dtype, regardless of the original dtype of the tensors + // Always compare using double dtype, regardless of the original dtype + // of the tensors auto computed = parameters[p]->flatten().to(torch::kFloat64); - auto expected = expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64); + auto expected = + expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64); if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) { std::cout << "Iteration " << i << ": " << computed << " != " << expected << " (parameter " << p << ")" @@ -161,14 +174,17 @@ TEST(OptimTest, OptimizerAccessors) { torch::equal(params[i], params_1[i]); } - // test for add_param_group() when one or more params existing in another param_group - // are passed in the new param group to be added + // test for add_param_group() when one or more params existing in another + // param_group are passed in the new param group to be added ASSERT_THROWS_WITH( - optimizer.add_param_group(OptimizerParamGroup(params)), "some parameters appear in more than one parameter group"); + optimizer.add_param_group(OptimizerParamGroup(params)), + "some parameters appear in more than one parameter group"); // test for state() with non-const reference return - auto& state_ = static_cast(*(optimizer.state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())])); - state_.step(state_.step()+1); + auto& state_ = static_cast( + *(optimizer + .state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())])); + state_.step(state_.step() + 1); const auto& optimizer_ = Adagrad(params, options); optimizer_.defaults(); @@ -188,19 +204,25 @@ TEST(OptimTest, OptimizerAccessors) { 1); \ } -struct MyOptimizerOptions : public OptimizerCloneableOptions { - MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}; +struct MyOptimizerOptions + : public OptimizerCloneableOptions { + MyOptimizerOptions(double lr = 1.0) : lr_(lr){}; TORCH_ARG(double, lr) = 1.0; }; TEST(OptimTest, OldInterface) { struct MyOptimizer : Optimizer { using Optimizer::Optimizer; - torch::Tensor step(LossClosure closure = nullptr) override { return {};} + torch::Tensor step(LossClosure closure = nullptr) override { + return {}; + } explicit MyOptimizer( - std::vector params, MyOptimizerOptions defaults = {}) : - // NOLINTNEXTLINE(performance-move-const-arg) - Optimizer({std::move(OptimizerParamGroup(params))}, std::make_unique(defaults)) {} + std::vector params, + MyOptimizerOptions defaults = {}) + : // NOLINTNEXTLINE(performance-move-const-arg) + Optimizer( + {std::move(OptimizerParamGroup(params))}, + std::make_unique(defaults)) {} }; std::vector parameters = { torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})}; @@ -249,7 +271,8 @@ TEST(OptimTest, XORConvergence_SGD) { TEST(OptimTest, XORConvergence_LBFGS) { ASSERT_TRUE(test_optimizer_xor(LBFGSOptions(1.0))); - ASSERT_TRUE(test_optimizer_xor(LBFGSOptions(1.0).line_search_fn("strong_wolfe"))); + ASSERT_TRUE(test_optimizer_xor( + LBFGSOptions(1.0).line_search_fn("strong_wolfe"))); } TEST(OptimTest, XORConvergence_Adagrad) { @@ -296,8 +319,7 @@ TEST(OptimTest, XORConvergence_AdamW) { } TEST(OptimTest, XORConvergence_AdamWWithAmsgrad) { - ASSERT_TRUE(test_optimizer_xor( - AdamWOptions(0.1).amsgrad(true))); + ASSERT_TRUE(test_optimizer_xor(AdamWOptions(0.1).amsgrad(true))); } TEST(OptimTest, ProducesPyTorchValues_AdamW) { @@ -382,9 +404,7 @@ TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) { } TEST(OptimTest, ProducesPyTorchValues_LBFGS) { - check_exact_values( - LBFGSOptions(1.0), - expected_parameters::LBFGS()); + check_exact_values(LBFGSOptions(1.0), expected_parameters::LBFGS()); } TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) { @@ -461,38 +481,38 @@ TEST(OptimTest, AddParameter_LBFGS) { // REQUIRE this doesn't throw } -//Check whether the learning rate of the parameter groups in the optimizer are the -//same as the expected learning rates given in the epoch:learning rate map +// Check whether the learning rate of the parameter groups in the optimizer are +// the same as the expected learning rates given in the epoch:learning rate map void check_lr_change( Optimizer& optimizer, LRScheduler& lr_scheduler, std::map expected_epoch_lrs) { - - //Find maximum epoch in map - unsigned kIterations = - std::max_element(expected_epoch_lrs.begin(), - expected_epoch_lrs.end(), - [] (const std::pair& a, - const std::pair& b) -> bool { - return a.second > b.second; - })->first; - - for(unsigned i = 0; i <= kIterations; i++) { + // Find maximum epoch in map + unsigned kIterations = std::max_element( + expected_epoch_lrs.begin(), + expected_epoch_lrs.end(), + [](const std::pair& a, + const std::pair& b) -> bool { + return a.second > b.second; + }) + ->first; + + for (unsigned i = 0; i <= kIterations; i++) { const auto epoch_iter = expected_epoch_lrs.find(i); - if(epoch_iter != expected_epoch_lrs.end()) - { - //Compare the similarity of the two floating point learning rates - ASSERT_TRUE(fabs(epoch_iter->second - optimizer.param_groups()[0].options().get_lr()) < - std::numeric_limits::epsilon()); + if (epoch_iter != expected_epoch_lrs.end()) { + // Compare the similarity of the two floating point learning rates + ASSERT_TRUE( + fabs( + epoch_iter->second - + optimizer.param_groups()[0].options().get_lr()) < + std::numeric_limits::epsilon()); } optimizer.step(); lr_scheduler.step(); } - } TEST(OptimTest, CheckLRChange_StepLR_Adam) { - torch::Tensor parameters = torch::zeros({1}); auto optimizer = Adam({parameters}, AdamOptions().lr(1e-3)); @@ -500,11 +520,8 @@ TEST(OptimTest, CheckLRChange_StepLR_Adam) { const double gamma = 0.5; StepLR step_lr_scheduler(optimizer, step_size, gamma); - //The learning rate should have halved at epoch 20 - const std::map expected_epoch_lrs = { - {1, 1e-3}, - {25, 5e-4} - }; + // The learning rate should have halved at epoch 20 + const std::map expected_epoch_lrs = {{1, 1e-3}, {25, 5e-4}}; check_lr_change(optimizer, step_lr_scheduler, expected_epoch_lrs); } diff --git a/test/cpp/api/optim_baseline.h b/test/cpp/api/optim_baseline.h index 1c8ec9154bbafd..68841463e565d8 100644 --- a/test/cpp/api/optim_baseline.h +++ b/test/cpp/api/optim_baseline.h @@ -8,1350 +8,3052 @@ namespace expected_parameters { inline std::vector> LBFGS() { return { - { - torch::tensor({-0.20959197386869663, -0.49580870398532073, -0.1313442585372408, -0.3287331939506787, -0.24613947168465267, 0.705889510763571}), - torch::tensor({-0.10412662274500666, -0.2644705062031845, 0.7102859961803084}), - torch::tensor({-0.19787984636009417, -0.5320223708266223, -0.5396083236337847}), - torch::tensor({-0.43108206822505857}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, - { - torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}), - torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}), - torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}), - torch::tensor({-4.776742087865583}), - }, + { + torch::tensor( + {-0.20959197386869663, + -0.49580870398532073, + -0.1313442585372408, + -0.3287331939506787, + -0.24613947168465267, + 0.705889510763571}), + torch::tensor( + {-0.10412662274500666, -0.2644705062031845, 0.7102859961803084}), + torch::tensor( + {-0.19787984636009417, -0.5320223708266223, -0.5396083236337847}), + torch::tensor({-0.43108206822505857}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, + { + torch::tensor( + {0.4377600774755075, + 0.3828823919505896, + 0.5308031277873992, + 0.5752746453369446, + 0.23943592910168343, + 1.3739197373644627}), + torch::tensor( + {2.209263823172053, 2.154134023426646, 2.534834254325867}), + torch::tensor( + {-4.091952315741579, -4.67916063385269, -4.781279234594454}), + torch::tensor({-4.776742087865583}), + }, }; } inline std::vector> LBFGS_with_line_search() { return { - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, - { - torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}), - torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), - torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}), - torch::tensor({-0.07113812430174116}), - }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, + { + torch::tensor( + {-0.2108988568338871, + -0.4975560466422629, + -0.14129216202471762, + -0.3420288967865903, + -0.2523635082803723, + 0.6975570255493777}), + torch::tensor( + {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), + torch::tensor( + {-0.05080313011597659, + -0.39413518751058996, + -0.28433759745928844}), + torch::tensor({-0.07113812430174116}), + }, }; } inline std::vector> Adam() { return { - { - torch::tensor({0.7890972864438472, 0.5024410688121617, 0.8587073313055582, 0.6579707241208395, 0.7476356819075531, 1.697556420651692}), - torch::tensor({0.891467636010675, 0.7020513497567501, 1.6892012709428947}), - torch::tensor({-1.0508030958460797, -1.3941351509567657, -1.284337577714353}), - torch::tensor({-1.071138110298716}), - }, - { - torch::tensor({8.233039313231828, 7.971150747377481, 6.6436209506776, 6.470977407900541, 6.170125488259256, 7.1507391033435015}), - torch::tensor({8.417695070103735, 6.597188212844593, 7.23175710827678}), - torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}), - torch::tensor({-6.435639096011218}), - }, - { - torch::tensor({8.233424596059296, 7.971537360032308, 6.643920150720394, 6.47127807553724, 6.170405874224489, 7.151021086137982}), - torch::tensor({8.418084791214294, 6.597493171180545, 7.232043740621598}), - torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}), - torch::tensor({-6.4359165566974985}), - }, - { - torch::tensor({8.233424610557648, 7.971537374586563, 6.643920161995285, 6.471278086877829, 6.170405884785074, 7.151021096766405}), - torch::tensor({8.418084805901902, 6.597493182713584, 7.2320437514477875}), - torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}), - torch::tensor({-6.435916580217771}), - }, - { - torch::tensor({8.233424610575101, 7.971537374611125, 6.643920162027962, 6.471278086923278, 6.170405884809245, 7.15102109680004}), - torch::tensor({8.418084805946389, 6.597493182796847, 7.232043751509309}), - torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}), - torch::tensor({-6.435916596115672}), - }, - { - torch::tensor({8.233424610594858, 7.971537374639166, 6.643920162065571, 6.471278086975759, 6.170405884836981, 7.1510210968387975}), - torch::tensor({8.418084805997614, 6.59749318289335, 7.232043751580523}), - torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}), - torch::tensor({-6.43591661462546}), - }, - { - torch::tensor({8.233424610617288, 7.971537374671012, 6.643920162108285, 6.471278087035362, 6.170405884868481, 7.151021096882811}), - torch::tensor({8.418084806055795, 6.59749318300295, 7.232043751661401}), - torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}), - torch::tensor({-6.4359166356471755}), - }, - { - torch::tensor({8.233424610642352, 7.9715373747065925, 6.6439201621560064, 6.471278087101955, 6.1704058849036745, 7.151021096931989}), - torch::tensor({8.418084806120799, 6.597493183125404, 7.232043751751764}), - torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}), - torch::tensor({-6.43591665913422}), - }, - { - torch::tensor({8.233424610670035, 7.97153737474589, 6.6439201622087145, 6.471278087175502, 6.170405884942545, 7.151021096986302}), - torch::tensor({8.418084806192592, 6.597493183260647, 7.232043751851564}), - torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}), - torch::tensor({-6.435916685074594}), - }, - { - torch::tensor({8.233424610700348, 7.971537374788922, 6.643920162266433, 6.4712780872560405, 6.17040588498511, 7.151021097045779}), - torch::tensor({8.418084806271214, 6.597493183408747, 7.232043751960854}), - torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}), - torch::tensor({-6.435916713480863}), - }, - { - torch::tensor({8.233424610733326, 7.971537374835737, 6.643920162329225, 6.471278087343659, 6.170405885031416, 7.151021097110483}), - torch::tensor({8.418084806356743, 6.597493183569867, 7.232043752079749}), - torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}), - torch::tensor({-6.43591674438434}), - }, + { + torch::tensor( + {0.7890972864438472, + 0.5024410688121617, + 0.8587073313055582, + 0.6579707241208395, + 0.7476356819075531, + 1.697556420651692}), + torch::tensor( + {0.891467636010675, 0.7020513497567501, 1.6892012709428947}), + torch::tensor( + {-1.0508030958460797, -1.3941351509567657, -1.284337577714353}), + torch::tensor({-1.071138110298716}), + }, + { + torch::tensor( + {8.233039313231828, + 7.971150747377481, + 6.6436209506776, + 6.470977407900541, + 6.170125488259256, + 7.1507391033435015}), + torch::tensor( + {8.417695070103735, 6.597188212844593, 7.23175710827678}), + torch::tensor( + {-6.729624357635757, -7.09743493108154, -6.753301896575352}), + torch::tensor({-6.435639096011218}), + }, + { + torch::tensor( + {8.233424596059296, + 7.971537360032308, + 6.643920150720394, + 6.47127807553724, + 6.170405874224489, + 7.151021086137982}), + torch::tensor( + {8.418084791214294, 6.597493171180545, 7.232043740621598}), + torch::tensor( + {-6.729918250724671, -7.097730102046093, -6.753584809755359}), + torch::tensor({-6.4359165566974985}), + }, + { + torch::tensor( + {8.233424610557648, + 7.971537374586563, + 6.643920161995285, + 6.471278086877829, + 6.170405884785074, + 7.151021096766405}), + torch::tensor( + {8.418084805901902, 6.597493182713584, 7.2320437514477875}), + torch::tensor( + {-6.72991829363266, -7.097730147102975, -6.753584838821182}), + torch::tensor({-6.435916580217771}), + }, + { + torch::tensor( + {8.233424610575101, + 7.971537374611125, + 6.643920162027962, + 6.471278086923278, + 6.170405884809245, + 7.15102109680004}), + torch::tensor( + {8.418084805946389, 6.597493182796847, 7.232043751509309}), + torch::tensor( + {-6.729918332327653, -7.097730188349552, -6.753584861205486}), + torch::tensor({-6.435916596115672}), + }, + { + torch::tensor( + {8.233424610594858, + 7.971537374639166, + 6.643920162065571, + 6.471278086975759, + 6.170405884836981, + 7.1510210968387975}), + torch::tensor( + {8.418084805997614, 6.59749318289335, 7.232043751580523}), + torch::tensor( + {-6.72991837738045, -7.097730236373201, -6.753584887267492}), + torch::tensor({-6.43591661462546}), + }, + { + torch::tensor( + {8.233424610617288, + 7.971537374671012, + 6.643920162108285, + 6.471278087035362, + 6.170405884868481, + 7.151021096882811}), + torch::tensor( + {8.418084806055795, 6.59749318300295, 7.232043751661401}), + torch::tensor( + {-6.729918428547273, -7.09773029091405, -6.753584916866329}), + torch::tensor({-6.4359166356471755}), + }, + { + torch::tensor( + {8.233424610642352, + 7.9715373747065925, + 6.6439201621560064, + 6.471278087101955, + 6.1704058849036745, + 7.151021096931989}), + torch::tensor( + {8.418084806120799, 6.597493183125404, 7.232043751751764}), + torch::tensor( + {-6.729918485714688, -7.0977303518511805, -6.753584949936365}), + torch::tensor({-6.43591665913422}), + }, + { + torch::tensor( + {8.233424610670035, + 7.97153737474589, + 6.6439201622087145, + 6.471278087175502, + 6.170405884942545, + 7.151021096986302}), + torch::tensor( + {8.418084806192592, 6.597493183260647, 7.232043751851564}), + torch::tensor( + {-6.729918548853505, -7.097730419153473, -6.753584986460725}), + torch::tensor({-6.435916685074594}), + }, + { + torch::tensor( + {8.233424610700348, + 7.971537374788922, + 6.643920162266433, + 6.4712780872560405, + 6.17040588498511, + 7.151021097045779}), + torch::tensor( + {8.418084806271214, 6.597493183408747, 7.232043751960854}), + torch::tensor( + {-6.7299186179943, -7.097730492853521, -6.753585026457088}), + torch::tensor({-6.435916713480863}), + }, + { + torch::tensor( + {8.233424610733326, + 7.971537374835737, + 6.643920162329225, + 6.471278087343659, + 6.170405885031416, + 7.151021097110483}), + torch::tensor( + {8.418084806356743, 6.597493183569867, 7.232043752079749}), + torch::tensor( + {-6.729918693213275, -7.097730573032567, -6.753585069969552}), + torch::tensor({-6.43591674438434}), + }, }; } inline std::vector> Adam_with_weight_decay() { return { - { - torch::tensor({0.7890990163499767, 0.5024427688479549, 0.858707365154099, 0.65797076763247, 0.7476358193232038, 1.6975559791029715}), - torch::tensor({0.8914677624298939, 0.7020513562204098, 1.6892012237887575}), - torch::tensor({-1.050803095786311, -1.3941351504224309, -1.2843375776028747}), - torch::tensor({-1.0711381102847533}), - }, - { - torch::tensor({0.17835734655765323, 0.2542117171890537, 0.19681971909229715, 0.23522651199260597, 0.17806083719648957, 0.22943655675307303}), - torch::tensor({0.6227676931552837, 0.6058596954431213, 0.6077176546857177}), - torch::tensor({-1.4259755901844118, -1.4333355461952704, -1.408545526635006}), - torch::tensor({-2.0710783081666215}), - }, - { - torch::tensor({0.17965695035191162, 0.24254352340441693, 0.17964663531482672, 0.24250834976541322, 0.17962893833698693, 0.24249920074277215}), - torch::tensor({0.6287144967638043, 0.6286955805603279, 0.6286563093833837}), - torch::tensor({-1.4123887230853596, -1.4124007126659273, -1.4122701589749163}), - torch::tensor({-2.063357041247863}), - }, - { - torch::tensor({0.1796366651819, 0.24250861931831874, 0.17963731759793083, 0.24250861142436989, 0.1796372002681969, 0.24250890248031373}), - torch::tensor({0.6287221269294724, 0.6287225821354421, 0.6287220274975922}), - torch::tensor({-1.4123466103044011, -1.4123465669572683, -1.4123462614739388}), - torch::tensor({-2.063368365143669}), - }, - { - torch::tensor({0.17963666103165563, 0.24250882317446784, 0.17963665831217887, 0.24250882481082656, 0.17963666029066117, 0.24250882426223175}), - torch::tensor({0.6287216329900817, 0.6287216340515608, 0.6287216326960158}), - torch::tensor({-1.4123467542623926, -1.4123467542350234, -1.4123467478191443}), - torch::tensor({-2.0633690432440437}), - }, - { - torch::tensor({0.17963666098500394, 0.24250882442377164, 0.17963666099348902, 0.2425088244120223, 0.1796366609725109, 0.24250882441058697}), - torch::tensor({0.6287216343798432, 0.6287216343800675, 0.6287216343742645}), - torch::tensor({-1.4123467490742723, -1.412346749072554, -1.4123467490678536}), - torch::tensor({-2.0633690434425396}), - }, - { - torch::tensor({0.17963666098407144, 0.24250882442250174, 0.17963666098407347, 0.2425088244224233, 0.17963666098409325, 0.2425088244225157}), - torch::tensor({0.6287216343836609, 0.6287216343836147, 0.6287216343836255}), - torch::tensor({-1.412346749067226, -1.412346749067243, -1.412346749067169}), - torch::tensor({-2.063369043434909}), - }, - { - torch::tensor({0.17963666098406988, 0.2425088244224408, 0.17963666098407077, 0.24250882442244073, 0.17963666098407008, 0.2425088244224409}), - torch::tensor({0.6287216343837067, 0.6287216343837065, 0.6287216343837069}), - torch::tensor({-1.4123467490671706, -1.412346749067171, -1.4123467490671713}), - torch::tensor({-2.0633690434349057}), - }, - { - torch::tensor({0.17963666098407038, 0.24250882442244104, 0.17963666098407027, 0.24250882442244104, 0.17963666098407025, 0.24250882442244098}), - torch::tensor({0.6287216343837067, 0.628721634383707, 0.6287216343837067}), - torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}), - torch::tensor({-2.0633690434349052}), - }, - { - torch::tensor({0.1796366609840706, 0.24250882442244143, 0.17963666098407047, 0.24250882442244096, 0.17963666098407025, 0.24250882442244098}), - torch::tensor({0.6287216343837069, 0.6287216343837067, 0.6287216343837067}), - torch::tensor({-1.4123467490671706, -1.4123467490671706, -1.4123467490671708}), - torch::tensor({-2.0633690434349052}), - }, - { - torch::tensor({0.1796366609840692, 0.24250882442244046, 0.17963666098407022, 0.24250882442244082, 0.17963666098407, 0.24250882442244104}), - torch::tensor({0.6287216343837063, 0.6287216343837068, 0.6287216343837067}), - torch::tensor({-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}), - torch::tensor({-2.0633690434349052}), - }, + { + torch::tensor( + {0.7890990163499767, + 0.5024427688479549, + 0.858707365154099, + 0.65797076763247, + 0.7476358193232038, + 1.6975559791029715}), + torch::tensor( + {0.8914677624298939, 0.7020513562204098, 1.6892012237887575}), + torch::tensor( + {-1.050803095786311, -1.3941351504224309, -1.2843375776028747}), + torch::tensor({-1.0711381102847533}), + }, + { + torch::tensor( + {0.17835734655765323, + 0.2542117171890537, + 0.19681971909229715, + 0.23522651199260597, + 0.17806083719648957, + 0.22943655675307303}), + torch::tensor( + {0.6227676931552837, 0.6058596954431213, 0.6077176546857177}), + torch::tensor( + {-1.4259755901844118, -1.4333355461952704, -1.408545526635006}), + torch::tensor({-2.0710783081666215}), + }, + { + torch::tensor( + {0.17965695035191162, + 0.24254352340441693, + 0.17964663531482672, + 0.24250834976541322, + 0.17962893833698693, + 0.24249920074277215}), + torch::tensor( + {0.6287144967638043, 0.6286955805603279, 0.6286563093833837}), + torch::tensor( + {-1.4123887230853596, -1.4124007126659273, -1.4122701589749163}), + torch::tensor({-2.063357041247863}), + }, + { + torch::tensor( + {0.1796366651819, + 0.24250861931831874, + 0.17963731759793083, + 0.24250861142436989, + 0.1796372002681969, + 0.24250890248031373}), + torch::tensor( + {0.6287221269294724, 0.6287225821354421, 0.6287220274975922}), + torch::tensor( + {-1.4123466103044011, -1.4123465669572683, -1.4123462614739388}), + torch::tensor({-2.063368365143669}), + }, + { + torch::tensor( + {0.17963666103165563, + 0.24250882317446784, + 0.17963665831217887, + 0.24250882481082656, + 0.17963666029066117, + 0.24250882426223175}), + torch::tensor( + {0.6287216329900817, 0.6287216340515608, 0.6287216326960158}), + torch::tensor( + {-1.4123467542623926, -1.4123467542350234, -1.4123467478191443}), + torch::tensor({-2.0633690432440437}), + }, + { + torch::tensor( + {0.17963666098500394, + 0.24250882442377164, + 0.17963666099348902, + 0.2425088244120223, + 0.1796366609725109, + 0.24250882441058697}), + torch::tensor( + {0.6287216343798432, 0.6287216343800675, 0.6287216343742645}), + torch::tensor( + {-1.4123467490742723, -1.412346749072554, -1.4123467490678536}), + torch::tensor({-2.0633690434425396}), + }, + { + torch::tensor( + {0.17963666098407144, + 0.24250882442250174, + 0.17963666098407347, + 0.2425088244224233, + 0.17963666098409325, + 0.2425088244225157}), + torch::tensor( + {0.6287216343836609, 0.6287216343836147, 0.6287216343836255}), + torch::tensor( + {-1.412346749067226, -1.412346749067243, -1.412346749067169}), + torch::tensor({-2.063369043434909}), + }, + { + torch::tensor( + {0.17963666098406988, + 0.2425088244224408, + 0.17963666098407077, + 0.24250882442244073, + 0.17963666098407008, + 0.2425088244224409}), + torch::tensor( + {0.6287216343837067, 0.6287216343837065, 0.6287216343837069}), + torch::tensor( + {-1.4123467490671706, -1.412346749067171, -1.4123467490671713}), + torch::tensor({-2.0633690434349057}), + }, + { + torch::tensor( + {0.17963666098407038, + 0.24250882442244104, + 0.17963666098407027, + 0.24250882442244104, + 0.17963666098407025, + 0.24250882442244098}), + torch::tensor( + {0.6287216343837067, 0.628721634383707, 0.6287216343837067}), + torch::tensor( + {-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}), + torch::tensor({-2.0633690434349052}), + }, + { + torch::tensor( + {0.1796366609840706, + 0.24250882442244143, + 0.17963666098407047, + 0.24250882442244096, + 0.17963666098407025, + 0.24250882442244098}), + torch::tensor( + {0.6287216343837069, 0.6287216343837067, 0.6287216343837067}), + torch::tensor( + {-1.4123467490671706, -1.4123467490671706, -1.4123467490671708}), + torch::tensor({-2.0633690434349052}), + }, + { + torch::tensor( + {0.1796366609840692, + 0.24250882442244046, + 0.17963666098407022, + 0.24250882442244082, + 0.17963666098407, + 0.24250882442244104}), + torch::tensor( + {0.6287216343837063, 0.6287216343837068, 0.6287216343837067}), + torch::tensor( + {-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}), + torch::tensor({-2.0633690434349052}), + }, }; } -inline std::vector> Adam_with_weight_decay_and_amsgrad() { +inline std::vector> +Adam_with_weight_decay_and_amsgrad() { return { - { - torch::tensor({0.7890972867575196, 0.5024410692260988, 0.8587073313091852, 0.6579707241257546, 0.7476356819241026, 1.6975564206261673}), - torch::tensor({0.8914676360248869, 0.7020513497574256, 1.6892012709389561}), - torch::tensor({-1.050803095846074, -1.3941351509567128, -1.284337577714342}), - torch::tensor({-1.0711381102987145}), - }, - { - torch::tensor({6.790598887618061, 6.914995398136696, 6.41533478566264, 6.297644005485053, 5.845162499872375, 6.862229173597117}), - torch::tensor({7.958707058914726, 6.511338624975532, 7.100969502256063}), - torch::tensor({-6.690689640539306, -7.056584601121166, -6.72114879738572}), - torch::tensor({-6.406608295022552}), - }, - { - torch::tensor({4.707506618354547, 5.291519064582759, 6.0451502264500006, 6.024403702678936, 5.309533822430375, 6.388110918107735}), - torch::tensor({7.200495189000188, 6.398387074819269, 6.904125817198589}), - torch::tensor({-6.664150387053514, -7.026716194929788, -6.705821732490459}), - torch::tensor({-6.396310969025695}), - }, - { - torch::tensor({2.9508632109188633, 3.7657755643994775, 5.60741774331852, 5.6957903028180565, 4.70145185833677, 5.835064148607041}), - torch::tensor({6.343524400109462, 6.258242740866945, 6.663022973860484}), - torch::tensor({-6.630461605133603, -6.988854886932907, -6.686194352796841}), - torch::tensor({-6.383010489575922}), - }, - { - torch::tensor({1.7128944635692829, 2.536365345915568, 5.140416924817106, 5.33803266121343, 4.083921806116116, 5.254596369238127}), - torch::tensor({5.477917349690043, 6.100068192681452, 6.394747035239918}), - torch::tensor({-6.591742325353548, -6.945355504749947, -6.663584873152447}), - torch::tensor({-6.367675348676994}), - }, - { - torch::tensor({0.9341502247285258, 1.6339620765410685, 4.6679910940755835, 4.967688979298023, 3.4933141073198866, 4.678195347615295}), - torch::tensor({4.655117321178743, 5.9294836450698245, 6.110061909503652}), - torch::tensor({-6.549116242899458, -6.897486328511809, -6.638629863638681}), - torch::tensor({-6.350731975696792}), - }, - { - torch::tensor({0.483518223008081, 1.014261673266094, 4.205928052060015, 4.596195204751035, 2.9502123780175378, 4.125826973031755}), - torch::tensor({3.90359317709392, 5.750505227817309, 5.816617309738603}), - torch::tensor({-6.503371263778851, -6.846137347295816, -6.611773185243846}), - torch::tensor({-6.3324769650576656}), - }, - { - torch::tensor({0.2393576446248765, 0.6100241101779533, 3.764601561942264, 4.231602962540335, 2.4647709193637635, 3.6096961114614476}), - torch::tensor({3.236721556734532, 5.566168160977344, 5.520085344708356}), - torch::tensor({-6.455100840665729, -6.791979259673106, -6.583347815648856}), - torch::tensor({-6.3131323905801136}), - }, - { - torch::tensor({0.11401265024593016, 0.3570521972760832, 3.350517925954931, 3.8795333009419823, 2.0402068661130683, 3.136550602110189}), - torch::tensor({2.6579570162250215, 5.378834741966309, 5.224742933241745}), - torch::tensor({-6.4047717084756375, -6.7355397098532155, -6.55361495837462}), - torch::tensor({-6.2928722155069945}), - }, - { - torch::tensor({0.05251515193791458, 0.20410212473600725, 2.9673680881961273, 3.543794405777883, 1.6752677855061209, 2.709287985107431}), - torch::tensor({2.164479166686583, 5.190372657839918, 4.9338240234040756}), - torch::tensor({-6.352761531270841, -6.6772456859648175, -6.52278570167088}), - torch::tensor({-6.271836898876738}), - }, - { - torch::tensor({0.023489947480426834, 0.11428338573638941, 2.616797245623764, 3.226821571853439, 1.3659994589608537, 2.328136084816453}), - torch::tensor({1.74978620664146, 5.002269811977871, 4.649756802441968}), - torch::tensor({-6.299382007948917, -6.617449564196286, -6.491034254081261}), - torch::tensor({-6.250142306631259}), - }, + { + torch::tensor( + {0.7890972867575196, + 0.5024410692260988, + 0.8587073313091852, + 0.6579707241257546, + 0.7476356819241026, + 1.6975564206261673}), + torch::tensor( + {0.8914676360248869, 0.7020513497574256, 1.6892012709389561}), + torch::tensor( + {-1.050803095846074, -1.3941351509567128, -1.284337577714342}), + torch::tensor({-1.0711381102987145}), + }, + { + torch::tensor( + {6.790598887618061, + 6.914995398136696, + 6.41533478566264, + 6.297644005485053, + 5.845162499872375, + 6.862229173597117}), + torch::tensor( + {7.958707058914726, 6.511338624975532, 7.100969502256063}), + torch::tensor( + {-6.690689640539306, -7.056584601121166, -6.72114879738572}), + torch::tensor({-6.406608295022552}), + }, + { + torch::tensor( + {4.707506618354547, + 5.291519064582759, + 6.0451502264500006, + 6.024403702678936, + 5.309533822430375, + 6.388110918107735}), + torch::tensor( + {7.200495189000188, 6.398387074819269, 6.904125817198589}), + torch::tensor( + {-6.664150387053514, -7.026716194929788, -6.705821732490459}), + torch::tensor({-6.396310969025695}), + }, + { + torch::tensor( + {2.9508632109188633, + 3.7657755643994775, + 5.60741774331852, + 5.6957903028180565, + 4.70145185833677, + 5.835064148607041}), + torch::tensor( + {6.343524400109462, 6.258242740866945, 6.663022973860484}), + torch::tensor( + {-6.630461605133603, -6.988854886932907, -6.686194352796841}), + torch::tensor({-6.383010489575922}), + }, + { + torch::tensor( + {1.7128944635692829, + 2.536365345915568, + 5.140416924817106, + 5.33803266121343, + 4.083921806116116, + 5.254596369238127}), + torch::tensor( + {5.477917349690043, 6.100068192681452, 6.394747035239918}), + torch::tensor( + {-6.591742325353548, -6.945355504749947, -6.663584873152447}), + torch::tensor({-6.367675348676994}), + }, + { + torch::tensor( + {0.9341502247285258, + 1.6339620765410685, + 4.6679910940755835, + 4.967688979298023, + 3.4933141073198866, + 4.678195347615295}), + torch::tensor( + {4.655117321178743, 5.9294836450698245, 6.110061909503652}), + torch::tensor( + {-6.549116242899458, -6.897486328511809, -6.638629863638681}), + torch::tensor({-6.350731975696792}), + }, + { + torch::tensor( + {0.483518223008081, + 1.014261673266094, + 4.205928052060015, + 4.596195204751035, + 2.9502123780175378, + 4.125826973031755}), + torch::tensor( + {3.90359317709392, 5.750505227817309, 5.816617309738603}), + torch::tensor( + {-6.503371263778851, -6.846137347295816, -6.611773185243846}), + torch::tensor({-6.3324769650576656}), + }, + { + torch::tensor( + {0.2393576446248765, + 0.6100241101779533, + 3.764601561942264, + 4.231602962540335, + 2.4647709193637635, + 3.6096961114614476}), + torch::tensor( + {3.236721556734532, 5.566168160977344, 5.520085344708356}), + torch::tensor( + {-6.455100840665729, -6.791979259673106, -6.583347815648856}), + torch::tensor({-6.3131323905801136}), + }, + { + torch::tensor( + {0.11401265024593016, + 0.3570521972760832, + 3.350517925954931, + 3.8795333009419823, + 2.0402068661130683, + 3.136550602110189}), + torch::tensor( + {2.6579570162250215, 5.378834741966309, 5.224742933241745}), + torch::tensor( + {-6.4047717084756375, -6.7355397098532155, -6.55361495837462}), + torch::tensor({-6.2928722155069945}), + }, + { + torch::tensor( + {0.05251515193791458, + 0.20410212473600725, + 2.9673680881961273, + 3.543794405777883, + 1.6752677855061209, + 2.709287985107431}), + torch::tensor( + {2.164479166686583, 5.190372657839918, 4.9338240234040756}), + torch::tensor( + {-6.352761531270841, -6.6772456859648175, -6.52278570167088}), + torch::tensor({-6.271836898876738}), + }, + { + torch::tensor( + {0.023489947480426834, + 0.11428338573638941, + 2.616797245623764, + 3.226821571853439, + 1.3659994589608537, + 2.328136084816453}), + torch::tensor( + {1.74978620664146, 5.002269811977871, 4.649756802441968}), + torch::tensor( + {-6.299382007948917, -6.617449564196286, -6.491034254081261}), + torch::tensor({-6.250142306631259}), + }, }; } inline std::vector> AdamW() { return { - { - torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}), - torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}), - torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), - torch::tensor({-1.0704267290556988}), - }, - { - torch::tensor({3.3165329599188507, 3.223120441823618, 2.665544565239194, 2.6044341406663225, 2.479859063483047, 2.836831717112226}), - torch::tensor({3.3885192024669744, 2.6544147219174556, 2.8709245656887328}), - torch::tensor({-2.70172647102137, -2.836731459490802, -2.69652471546253}), - torch::tensor({-2.575239255076019}), - }, - { - torch::tensor({2.231471944853865, 2.3549328325971755, 1.5699078054795328, 1.6160272935884685, 1.5339085081403547, 1.7397405105941612}), - torch::tensor({2.8552579170807926, 1.8369866847839356, 1.9735168512425862}), - torch::tensor({-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}), - torch::tensor({-1.6180915942867784}), - }, - { - torch::tensor({2.084688381515552, 2.3141612674892946, 1.4850714710140511, 1.5961047256668386, 1.440300645879787, 1.6065354941586025}), - torch::tensor({3.0111385685659444, 1.955556497153507, 1.9596562467797627}), - torch::tensor({-2.889337305884852, -2.965249100126337, -1.7721676671605975}), - torch::tensor({-1.4001341655590005}), - }, - { - torch::tensor({2.0465343456006604, 2.311613891239368, 1.4666717526896398, 1.601383980913499, 1.4223660595993763, 1.5711552625612757}), - torch::tensor({3.07151984580744, 2.0112690538174802, 1.9592484602763875}), - torch::tensor({-3.0186469726426863, -3.093855445542849, -1.7367953899738784}), - torch::tensor({-1.3299011560804312}), - }, - { - torch::tensor({2.039659777412556, 2.3178034179536273, 1.4654302718412722, 1.6094701969162322, 1.4230510816446773, 1.565168902852383}), - torch::tensor({3.1007583934270064, 2.039757113618415, 1.9652096140698696}), - torch::tensor({-3.0880626664330832, -3.166705422245348, -1.73538367534238}), - torch::tensor({-1.3130428735015893}), - }, - { - torch::tensor({2.0413773043991963, 2.3251469369586366, 1.4690808101517236, 1.6174065798291044, 1.4280274009117935, 1.5682418226469732}), - torch::tensor({3.118843540209399, 2.057729936485249, 1.9742319629710936}), - torch::tensor({-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}), - torch::tensor({-1.3148644134154366}), - }, - { - torch::tensor({2.0452604138357113, 2.332074253989847, 1.4738845773449165, 1.6246403004735728, 1.4335712611625357, 1.573826630920094}), - torch::tensor({3.1324088069784093, 2.0711763619826575, 1.9841582498316732}), - torch::tensor({-3.16737058959847, -3.2529206463859146, -1.7602788393925501}), - torch::tensor({-1.32281766461531}), - }, - { - torch::tensor({2.0495243704493262, 2.338413341249581, 1.4787599440132637, 1.631210274009555, 1.438849155552895, 1.5798736919537595}), - torch::tensor({3.1438209015414227, 2.0823943437659658, 1.9940075805973108}), - torch::tensor({-3.19628690363529, -3.2845941643030367, -1.7752333900055153}), - torch::tensor({-1.332456718314933}), - }, - { - torch::tensor({2.0536979895206295, 2.3442520601250334, 1.4834272584224222, 1.6372462654983486, 1.4437517398490174, 1.585780877834892}), - torch::tensor({3.1540081461072447, 2.0923381262560454, 2.0034284957296107}), - torch::tensor({-3.222142968201519, -3.312867602521477, -1.7898220261118043}), - torch::tensor({-1.3422692037690986}), - }, - { - torch::tensor({2.0576784836825315, 2.3496934395759377, 1.4878413407927933, 1.6428612479757005, 1.4483225979568104, 1.5914034339763325}), - torch::tensor({3.163383747232199, 2.101446878895216, 2.012344413569353}), - torch::tensor({-3.246000281299229, -3.338904166978488, -1.8037666936489785}), - torch::tensor({-1.3517884775416527}), - }, + { + torch::tensor( + {0.7912062750121864, + 0.5074166292785842, + 0.8601202529258052, + 0.6613910130887053, + 0.7501593169903569, + 1.6905808503961983}), + torch::tensor( + {0.8925529482073002, 0.7050308347536254, 1.682309255842939}), + torch::tensor( + {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), + torch::tensor({-1.0704267290556988}), + }, + { + torch::tensor( + {3.3165329599188507, + 3.223120441823618, + 2.665544565239194, + 2.6044341406663225, + 2.479859063483047, + 2.836831717112226}), + torch::tensor( + {3.3885192024669744, 2.6544147219174556, 2.8709245656887328}), + torch::tensor( + {-2.70172647102137, -2.836731459490802, -2.69652471546253}), + torch::tensor({-2.575239255076019}), + }, + { + torch::tensor( + {2.231471944853865, + 2.3549328325971755, + 1.5699078054795328, + 1.6160272935884685, + 1.5339085081403547, + 1.7397405105941612}), + torch::tensor( + {2.8552579170807926, 1.8369866847839356, 1.9735168512425862}), + torch::tensor( + {-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}), + torch::tensor({-1.6180915942867784}), + }, + { + torch::tensor( + {2.084688381515552, + 2.3141612674892946, + 1.4850714710140511, + 1.5961047256668386, + 1.440300645879787, + 1.6065354941586025}), + torch::tensor( + {3.0111385685659444, 1.955556497153507, 1.9596562467797627}), + torch::tensor( + {-2.889337305884852, -2.965249100126337, -1.7721676671605975}), + torch::tensor({-1.4001341655590005}), + }, + { + torch::tensor( + {2.0465343456006604, + 2.311613891239368, + 1.4666717526896398, + 1.601383980913499, + 1.4223660595993763, + 1.5711552625612757}), + torch::tensor( + {3.07151984580744, 2.0112690538174802, 1.9592484602763875}), + torch::tensor( + {-3.0186469726426863, -3.093855445542849, -1.7367953899738784}), + torch::tensor({-1.3299011560804312}), + }, + { + torch::tensor( + {2.039659777412556, + 2.3178034179536273, + 1.4654302718412722, + 1.6094701969162322, + 1.4230510816446773, + 1.565168902852383}), + torch::tensor( + {3.1007583934270064, 2.039757113618415, 1.9652096140698696}), + torch::tensor( + {-3.0880626664330832, -3.166705422245348, -1.73538367534238}), + torch::tensor({-1.3130428735015893}), + }, + { + torch::tensor( + {2.0413773043991963, + 2.3251469369586366, + 1.4690808101517236, + 1.6174065798291044, + 1.4280274009117935, + 1.5682418226469732}), + torch::tensor( + {3.118843540209399, 2.057729936485249, 1.9742319629710936}), + torch::tensor( + {-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}), + torch::tensor({-1.3148644134154366}), + }, + { + torch::tensor( + {2.0452604138357113, + 2.332074253989847, + 1.4738845773449165, + 1.6246403004735728, + 1.4335712611625357, + 1.573826630920094}), + torch::tensor( + {3.1324088069784093, 2.0711763619826575, 1.9841582498316732}), + torch::tensor( + {-3.16737058959847, -3.2529206463859146, -1.7602788393925501}), + torch::tensor({-1.32281766461531}), + }, + { + torch::tensor( + {2.0495243704493262, + 2.338413341249581, + 1.4787599440132637, + 1.631210274009555, + 1.438849155552895, + 1.5798736919537595}), + torch::tensor( + {3.1438209015414227, 2.0823943437659658, 1.9940075805973108}), + torch::tensor( + {-3.19628690363529, -3.2845941643030367, -1.7752333900055153}), + torch::tensor({-1.332456718314933}), + }, + { + torch::tensor( + {2.0536979895206295, + 2.3442520601250334, + 1.4834272584224222, + 1.6372462654983486, + 1.4437517398490174, + 1.585780877834892}), + torch::tensor( + {3.1540081461072447, 2.0923381262560454, 2.0034284957296107}), + torch::tensor( + {-3.222142968201519, -3.312867602521477, -1.7898220261118043}), + torch::tensor({-1.3422692037690986}), + }, + { + torch::tensor( + {2.0576784836825315, + 2.3496934395759377, + 1.4878413407927933, + 1.6428612479757005, + 1.4483225979568104, + 1.5914034339763325}), + torch::tensor( + {3.163383747232199, 2.101446878895216, 2.012344413569353}), + torch::tensor( + {-3.246000281299229, -3.338904166978488, -1.8037666936489785}), + torch::tensor({-1.3517884775416527}), + }, }; } inline std::vector> AdamW_without_weight_decay() { return { - { - torch::tensor({0.7890972864438476, 0.5024410688121617, 0.858707331305558, 0.6579707241208395, 0.7476356819075531, 1.6975564206516922}), - torch::tensor({0.891467636010675, 0.70205134975675, 1.689201270942895}), - torch::tensor({-1.0508030958460797, -1.3941351509567654, -1.284337577714353}), - torch::tensor({-1.071138110298716}), - }, - { - torch::tensor({8.233039313231831, 7.971150747377481, 6.643620950677599, 6.47097740790054, 6.170125488259256, 7.150739103343502}), - torch::tensor({8.417695070103738, 6.597188212844593, 7.23175710827678}), - torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}), - torch::tensor({-6.435639096011218}), - }, - { - torch::tensor({8.233424596059299, 7.971537360032308, 6.643920150720393, 6.471278075537239, 6.170405874224489, 7.151021086137983}), - torch::tensor({8.418084791214298, 6.597493171180545, 7.232043740621598}), - torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}), - torch::tensor({-6.4359165566974985}), - }, - { - torch::tensor({8.233424610557652, 7.971537374586563, 6.643920161995284, 6.471278086877828, 6.170405884785074, 7.151021096766406}), - torch::tensor({8.418084805901906, 6.597493182713584, 7.2320437514477875}), - torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}), - torch::tensor({-6.435916580217771}), - }, - { - torch::tensor({8.233424610575105, 7.971537374611125, 6.643920162027961, 6.471278086923277, 6.170405884809245, 7.151021096800041}), - torch::tensor({8.418084805946393, 6.597493182796847, 7.232043751509309}), - torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}), - torch::tensor({-6.435916596115672}), - }, - { - torch::tensor({8.233424610594861, 7.971537374639166, 6.64392016206557, 6.471278086975758, 6.170405884836981, 7.151021096838798}), - torch::tensor({8.418084805997617, 6.59749318289335, 7.232043751580523}), - torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}), - torch::tensor({-6.43591661462546}), - }, - { - torch::tensor({8.233424610617291, 7.971537374671012, 6.643920162108284, 6.471278087035361, 6.170405884868481, 7.151021096882812}), - torch::tensor({8.418084806055798, 6.59749318300295, 7.232043751661401}), - torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}), - torch::tensor({-6.4359166356471755}), - }, - { - torch::tensor({8.233424610642356, 7.9715373747065925, 6.643920162156006, 6.471278087101954, 6.1704058849036745, 7.15102109693199}), - torch::tensor({8.418084806120802, 6.597493183125404, 7.232043751751764}), - torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}), - torch::tensor({-6.43591665913422}), - }, - { - torch::tensor({8.233424610670038, 7.97153737474589, 6.643920162208714, 6.471278087175501, 6.170405884942545, 7.151021096986303}), - torch::tensor({8.418084806192596, 6.597493183260647, 7.232043751851564}), - torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}), - torch::tensor({-6.435916685074594}), - }, - { - torch::tensor({8.233424610700352, 7.971537374788922, 6.643920162266432, 6.47127808725604, 6.17040588498511, 7.1510210970457795}), - torch::tensor({8.418084806271217, 6.597493183408747, 7.232043751960854}), - torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}), - torch::tensor({-6.435916713480863}), - }, - { - torch::tensor({8.23342461073333, 7.971537374835737, 6.643920162329224, 6.471278087343658, 6.170405885031416, 7.151021097110484}), - torch::tensor({8.418084806356747, 6.597493183569867, 7.232043752079749}), - torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}), - torch::tensor({-6.43591674438434}), - }, + { + torch::tensor( + {0.7890972864438476, + 0.5024410688121617, + 0.858707331305558, + 0.6579707241208395, + 0.7476356819075531, + 1.6975564206516922}), + torch::tensor( + {0.891467636010675, 0.70205134975675, 1.689201270942895}), + torch::tensor( + {-1.0508030958460797, -1.3941351509567654, -1.284337577714353}), + torch::tensor({-1.071138110298716}), + }, + { + torch::tensor( + {8.233039313231831, + 7.971150747377481, + 6.643620950677599, + 6.47097740790054, + 6.170125488259256, + 7.150739103343502}), + torch::tensor( + {8.417695070103738, 6.597188212844593, 7.23175710827678}), + torch::tensor( + {-6.729624357635757, -7.09743493108154, -6.753301896575352}), + torch::tensor({-6.435639096011218}), + }, + { + torch::tensor( + {8.233424596059299, + 7.971537360032308, + 6.643920150720393, + 6.471278075537239, + 6.170405874224489, + 7.151021086137983}), + torch::tensor( + {8.418084791214298, 6.597493171180545, 7.232043740621598}), + torch::tensor( + {-6.729918250724671, -7.097730102046093, -6.753584809755359}), + torch::tensor({-6.4359165566974985}), + }, + { + torch::tensor( + {8.233424610557652, + 7.971537374586563, + 6.643920161995284, + 6.471278086877828, + 6.170405884785074, + 7.151021096766406}), + torch::tensor( + {8.418084805901906, 6.597493182713584, 7.2320437514477875}), + torch::tensor( + {-6.72991829363266, -7.097730147102975, -6.753584838821182}), + torch::tensor({-6.435916580217771}), + }, + { + torch::tensor( + {8.233424610575105, + 7.971537374611125, + 6.643920162027961, + 6.471278086923277, + 6.170405884809245, + 7.151021096800041}), + torch::tensor( + {8.418084805946393, 6.597493182796847, 7.232043751509309}), + torch::tensor( + {-6.729918332327653, -7.097730188349552, -6.753584861205486}), + torch::tensor({-6.435916596115672}), + }, + { + torch::tensor( + {8.233424610594861, + 7.971537374639166, + 6.64392016206557, + 6.471278086975758, + 6.170405884836981, + 7.151021096838798}), + torch::tensor( + {8.418084805997617, 6.59749318289335, 7.232043751580523}), + torch::tensor( + {-6.72991837738045, -7.097730236373201, -6.753584887267492}), + torch::tensor({-6.43591661462546}), + }, + { + torch::tensor( + {8.233424610617291, + 7.971537374671012, + 6.643920162108284, + 6.471278087035361, + 6.170405884868481, + 7.151021096882812}), + torch::tensor( + {8.418084806055798, 6.59749318300295, 7.232043751661401}), + torch::tensor( + {-6.729918428547273, -7.09773029091405, -6.753584916866329}), + torch::tensor({-6.4359166356471755}), + }, + { + torch::tensor( + {8.233424610642356, + 7.9715373747065925, + 6.643920162156006, + 6.471278087101954, + 6.1704058849036745, + 7.15102109693199}), + torch::tensor( + {8.418084806120802, 6.597493183125404, 7.232043751751764}), + torch::tensor( + {-6.729918485714688, -7.0977303518511805, -6.753584949936365}), + torch::tensor({-6.43591665913422}), + }, + { + torch::tensor( + {8.233424610670038, + 7.97153737474589, + 6.643920162208714, + 6.471278087175501, + 6.170405884942545, + 7.151021096986303}), + torch::tensor( + {8.418084806192596, 6.597493183260647, 7.232043751851564}), + torch::tensor( + {-6.729918548853505, -7.097730419153473, -6.753584986460725}), + torch::tensor({-6.435916685074594}), + }, + { + torch::tensor( + {8.233424610700352, + 7.971537374788922, + 6.643920162266432, + 6.47127808725604, + 6.17040588498511, + 7.1510210970457795}), + torch::tensor( + {8.418084806271217, 6.597493183408747, 7.232043751960854}), + torch::tensor( + {-6.7299186179943, -7.097730492853521, -6.753585026457088}), + torch::tensor({-6.435916713480863}), + }, + { + torch::tensor( + {8.23342461073333, + 7.971537374835737, + 6.643920162329224, + 6.471278087343658, + 6.170405885031416, + 7.151021097110484}), + torch::tensor( + {8.418084806356747, 6.597493183569867, 7.232043752079749}), + torch::tensor( + {-6.729918693213275, -7.097730573032567, -6.753585069969552}), + torch::tensor({-6.43591674438434}), + }, }; } inline std::vector> AdamW_with_amsgrad() { return { - { - torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}), - torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}), - torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), - torch::tensor({-1.0704267290556988}), - }, - { - torch::tensor({3.3017259270507915, 3.2082991753694565, 2.653930978510442, 2.5927674339810585, 2.4689608790182933, 2.825873703467739}), - torch::tensor({3.373698198112671, 2.6425942964586664, 2.8597930424244304}), - torch::tensor({-2.690360632302962, -2.8253191596069525, -2.6855499873057473}), - torch::tensor({-2.5644658591929406}), - }, - { - torch::tensor({2.222607725541013, 2.3447188854637004, 1.5614270655258826, 1.606610018462357, 1.5260497191448619, 1.7309643622674138}), - torch::tensor({2.84137462783552, 1.824806600633721, 1.9620493659996037}), - torch::tensor({-2.576642773625787, -2.6706153846815766, -1.8799876863754623}), - torch::tensor({-1.6044722984810953}), - }, - { - torch::tensor({2.0739558768648205, 2.3008338863863496, 1.4738888208638767, 1.5829485271829449, 1.4296176764284294, 1.5939984909850073}), - torch::tensor({2.9908013612792415, 1.936590940953305, 1.941691630199464}), - torch::tensor({-2.846562884997548, -2.9195962101501203, -1.746484716887341}), - torch::tensor({-1.381525131003179}), - }, - { - torch::tensor({2.0333926094256953, 2.294977109171754, 1.452870514716895, 1.584853677999522, 1.4086299433181402, 1.5548201727855224}), - torch::tensor({3.0454817801193976, 1.9867169062383696, 1.935312753106444}), - torch::tensor({-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}), - torch::tensor({-1.30563541393247}), - }, - { - torch::tensor({2.02392417168201, 2.2977279587859, 1.4488511120131309, 1.5894646930743725, 1.406253686759073, 1.5450647949022756}), - torch::tensor({3.069025602955343, 2.0096872138967488, 1.935438546309299}), - torch::tensor({-3.016103148166836, -3.0893062953033583, -1.6925290685615872}), - torch::tensor({-1.282870120405012}), - }, - { - torch::tensor({2.0230257817348316, 2.3016167065040647, 1.4496901978629444, 1.5939034289777392, 1.4082421794430946, 1.5444538756003756}), - torch::tensor({3.0814016132787954, 2.022150201844143, 1.9387429991308658}), - torch::tensor({-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}), - torch::tensor({-1.2786980736911064}), - }, - { - torch::tensor({2.024305922065404, 2.305101549510461, 1.4516863420588493, 1.5976447954882376, 1.4108596097183552, 1.5464236284303425}), - torch::tensor({3.0892574233545065, 2.0300944858242236, 1.943040321845021}), - torch::tensor({-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}), - torch::tensor({-1.2806281811826203}), - }, - { - torch::tensor({2.0259854116237364, 2.3080169152218293, 1.4537680296813915, 1.6007369432392426, 1.4132529823064277, 1.5489035046525346}), - torch::tensor({3.0949717180851337, 2.0358251379915764, 1.9473249654893}), - torch::tensor({-3.0808231377905426, -3.160021699873689, -1.7062031001273494}), - torch::tensor({-1.28423401369705}), - }, - { - torch::tensor({2.0275923948348638, 2.3104512601389637, 1.455657715721078, 1.6033123357613526, 1.4153003204463288, 1.5512775896116622}), - torch::tensor({3.099479021299846, 2.0403012223048775, 1.9512285931847464}), - torch::tensor({-3.092151979336299, -3.1725453680885267, -1.7120689614428697}), - torch::tensor({-1.2880095517062655}), - }, - { - torch::tensor({2.029022468328371, 2.3125066985045892, 1.4573100228823295, 1.605484259933419, 1.417038245960655, 1.5533932056240227}), - torch::tensor({3.103195011518616, 2.043964003458376, 1.9546640840748621}), - torch::tensor({-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}), - torch::tensor({-1.2914899987134136}), - }, + { + torch::tensor( + {0.7912062750121864, + 0.5074166292785842, + 0.8601202529258052, + 0.6613910130887053, + 0.7501593169903569, + 1.6905808503961983}), + torch::tensor( + {0.8925529482073002, 0.7050308347536254, 1.682309255842939}), + torch::tensor( + {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), + torch::tensor({-1.0704267290556988}), + }, + { + torch::tensor( + {3.3017259270507915, + 3.2082991753694565, + 2.653930978510442, + 2.5927674339810585, + 2.4689608790182933, + 2.825873703467739}), + torch::tensor( + {3.373698198112671, 2.6425942964586664, 2.8597930424244304}), + torch::tensor( + {-2.690360632302962, -2.8253191596069525, -2.6855499873057473}), + torch::tensor({-2.5644658591929406}), + }, + { + torch::tensor( + {2.222607725541013, + 2.3447188854637004, + 1.5614270655258826, + 1.606610018462357, + 1.5260497191448619, + 1.7309643622674138}), + torch::tensor( + {2.84137462783552, 1.824806600633721, 1.9620493659996037}), + torch::tensor( + {-2.576642773625787, -2.6706153846815766, -1.8799876863754623}), + torch::tensor({-1.6044722984810953}), + }, + { + torch::tensor( + {2.0739558768648205, + 2.3008338863863496, + 1.4738888208638767, + 1.5829485271829449, + 1.4296176764284294, + 1.5939984909850073}), + torch::tensor( + {2.9908013612792415, 1.936590940953305, 1.941691630199464}), + torch::tensor( + {-2.846562884997548, -2.9195962101501203, -1.746484716887341}), + torch::tensor({-1.381525131003179}), + }, + { + torch::tensor( + {2.0333926094256953, + 2.294977109171754, + 1.452870514716895, + 1.584853677999522, + 1.4086299433181402, + 1.5548201727855224}), + torch::tensor( + {3.0454817801193976, 1.9867169062383696, 1.935312753106444}), + torch::tensor( + {-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}), + torch::tensor({-1.30563541393247}), + }, + { + torch::tensor( + {2.02392417168201, + 2.2977279587859, + 1.4488511120131309, + 1.5894646930743725, + 1.406253686759073, + 1.5450647949022756}), + torch::tensor( + {3.069025602955343, 2.0096872138967488, 1.935438546309299}), + torch::tensor( + {-3.016103148166836, -3.0893062953033583, -1.6925290685615872}), + torch::tensor({-1.282870120405012}), + }, + { + torch::tensor( + {2.0230257817348316, + 2.3016167065040647, + 1.4496901978629444, + 1.5939034289777392, + 1.4082421794430946, + 1.5444538756003756}), + torch::tensor( + {3.0814016132787954, 2.022150201844143, 1.9387429991308658}), + torch::tensor( + {-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}), + torch::tensor({-1.2786980736911064}), + }, + { + torch::tensor( + {2.024305922065404, + 2.305101549510461, + 1.4516863420588493, + 1.5976447954882376, + 1.4108596097183552, + 1.5464236284303425}), + torch::tensor( + {3.0892574233545065, 2.0300944858242236, 1.943040321845021}), + torch::tensor( + {-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}), + torch::tensor({-1.2806281811826203}), + }, + { + torch::tensor( + {2.0259854116237364, + 2.3080169152218293, + 1.4537680296813915, + 1.6007369432392426, + 1.4132529823064277, + 1.5489035046525346}), + torch::tensor( + {3.0949717180851337, 2.0358251379915764, 1.9473249654893}), + torch::tensor( + {-3.0808231377905426, -3.160021699873689, -1.7062031001273494}), + torch::tensor({-1.28423401369705}), + }, + { + torch::tensor( + {2.0275923948348638, + 2.3104512601389637, + 1.455657715721078, + 1.6033123357613526, + 1.4153003204463288, + 1.5512775896116622}), + torch::tensor( + {3.099479021299846, 2.0403012223048775, 1.9512285931847464}), + torch::tensor( + {-3.092151979336299, -3.1725453680885267, -1.7120689614428697}), + torch::tensor({-1.2880095517062655}), + }, + { + torch::tensor( + {2.029022468328371, + 2.3125066985045892, + 1.4573100228823295, + 1.605484259933419, + 1.417038245960655, + 1.5533932056240227}), + torch::tensor( + {3.103195011518616, 2.043964003458376, 1.9546640840748621}), + torch::tensor( + {-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}), + torch::tensor({-1.2914899987134136}), + }, }; } inline std::vector> Adagrad() { return { - { - torch::tensor({0.7891011045987429, 0.502443924512199, 0.8587078329085825, 0.6579710994224826, 0.7476364836215006, 1.697557019500397}), - torch::tensor({0.8914687688941954, 0.7020514988069096, 1.6892015076050444}), - torch::tensor({-1.0508031297732776, -1.3941351871450518, -1.284337597261839}), - torch::tensor({-1.071138124161711}), - }, - { - torch::tensor({2.4079229696892583, 2.2346803754764286, 1.6967885588547365, 1.552279695827649, 1.2259044248443602, 2.221279696180243}), - torch::tensor({2.9334079162217193, 1.7619824934767887, 2.3464577179091473}), - torch::tensor({-2.221396083069719, -2.549950976011168, -1.9709315957317095}), - torch::tensor({-1.5858816837541876}), - }, - { - torch::tensor({2.510404433941812, 2.3522584510262887, 1.7921695110761213, 1.657755825836846, 1.2891186618593045, 2.291878516133922}), - torch::tensor({3.092171180776419, 1.8971624370952997, 2.438734251283465}), - torch::tensor({-2.437641633486504, -2.7704264590526573, -2.0949471699460225}), - torch::tensor({-1.6769121890401757}), - }, - { - torch::tensor({2.5652648968109415, 2.4155313947260972, 1.844241233613541, 1.7156513351246399, 1.3245206506797171, 2.3315409972138825}), - torch::tensor({3.178399916514377, 1.9721945764936502, 2.4909037706250428}), - torch::tensor({-2.5658710403147933, -2.901921821645266, -2.168560672193225}), - torch::tensor({-1.7307903926154131}), - }, - { - torch::tensor({2.6021584494332592, 2.4582101324909065, 1.8796060082750778, 1.7550965207414717, 1.3489253597999988, 2.3589345190118247}), - torch::tensor({3.2368674310041516, 2.0236468833666894, 2.52707132741292}), - torch::tensor({-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}), - torch::tensor({-1.7692090167089707}), - }, - { - torch::tensor({2.629700772579208, 2.4901377017698683, 1.906173477530586, 1.7847957161833832, 1.3674517119505822, 2.3797578857769905}), - torch::tensor({3.2807643102638546, 2.062561811940094, 2.5546379424362775}), - torch::tensor({-2.7286379977755035, -3.0695109399636236, -2.262081199960513}), - torch::tensor({-1.7990936323432214}), - }, - { - torch::tensor({2.6515471766995247, 2.51550257362603, 1.927341363452414, 1.8084994719811576, 1.3823309942932445, 2.3964995243914373}), - torch::tensor({3.3157334001309473, 2.093728023484945, 2.5768468697402924}), - torch::tensor({-2.786981763434855, -3.129746439571402, -2.29562487034177}), - torch::tensor({-1.8235564908139104}), - }, - { - torch::tensor({2.6695780544837886, 2.53646401614724, 1.9448721033433505, 1.828157582353901, 1.3947329882074622, 2.4104657178934947}), - torch::tensor({3.344694775590452, 2.1196465761628516, 2.5954050923596252}), - torch::tensor({-2.8363936812536537, -3.1808219609745194, -2.32404190866147}), - torch::tensor({-1.8442667636913117}), - }, - { - torch::tensor({2.684883801533072, 2.5542762192735515, 1.9597939532350015, 1.844909608012419, 1.4053459079217485, 2.4224257790968386}), - torch::tensor({3.369349515259956, 2.1417845308976795, 2.611319989214332}), - torch::tensor({-2.879251075341889, -3.225165734647855, -2.3486956737228057}), - torch::tensor({-1.86222449978646}), - }, - { - torch::tensor({2.698151012423769, 2.56972998600169, 1.972757472697587, 1.8594775691681182, 1.4146081751022495, 2.43287021079559}), - torch::tensor({3.390772758897601, 2.1610741754331757, 2.6252349489549824}), - torch::tensor({-2.917092322961074, -3.264351563375218, -2.370468664387175}), - torch::tensor({-1.8780765115117757}), - }, - { - torch::tensor({2.7098389356033783, 2.5833548721723747, 1.9841994925173085, 1.8723468731726323, 1.4228158926355312, 2.4421305315945085}), - torch::tensor({3.4096859099156673, 2.178143852041279, 2.6375854547611364}), - torch::tensor({-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}), - torch::tensor({-1.8922653655195538}), - }, + { + torch::tensor( + {0.7891011045987429, + 0.502443924512199, + 0.8587078329085825, + 0.6579710994224826, + 0.7476364836215006, + 1.697557019500397}), + torch::tensor( + {0.8914687688941954, 0.7020514988069096, 1.6892015076050444}), + torch::tensor( + {-1.0508031297732776, -1.3941351871450518, -1.284337597261839}), + torch::tensor({-1.071138124161711}), + }, + { + torch::tensor( + {2.4079229696892583, + 2.2346803754764286, + 1.6967885588547365, + 1.552279695827649, + 1.2259044248443602, + 2.221279696180243}), + torch::tensor( + {2.9334079162217193, 1.7619824934767887, 2.3464577179091473}), + torch::tensor( + {-2.221396083069719, -2.549950976011168, -1.9709315957317095}), + torch::tensor({-1.5858816837541876}), + }, + { + torch::tensor( + {2.510404433941812, + 2.3522584510262887, + 1.7921695110761213, + 1.657755825836846, + 1.2891186618593045, + 2.291878516133922}), + torch::tensor( + {3.092171180776419, 1.8971624370952997, 2.438734251283465}), + torch::tensor( + {-2.437641633486504, -2.7704264590526573, -2.0949471699460225}), + torch::tensor({-1.6769121890401757}), + }, + { + torch::tensor( + {2.5652648968109415, + 2.4155313947260972, + 1.844241233613541, + 1.7156513351246399, + 1.3245206506797171, + 2.3315409972138825}), + torch::tensor( + {3.178399916514377, 1.9721945764936502, 2.4909037706250428}), + torch::tensor( + {-2.5658710403147933, -2.901921821645266, -2.168560672193225}), + torch::tensor({-1.7307903926154131}), + }, + { + torch::tensor( + {2.6021584494332592, + 2.4582101324909065, + 1.8796060082750778, + 1.7550965207414717, + 1.3489253597999988, + 2.3589345190118247}), + torch::tensor( + {3.2368674310041516, 2.0236468833666894, 2.52707132741292}), + torch::tensor( + {-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}), + torch::tensor({-1.7692090167089707}), + }, + { + torch::tensor( + {2.629700772579208, + 2.4901377017698683, + 1.906173477530586, + 1.7847957161833832, + 1.3674517119505822, + 2.3797578857769905}), + torch::tensor( + {3.2807643102638546, 2.062561811940094, 2.5546379424362775}), + torch::tensor( + {-2.7286379977755035, -3.0695109399636236, -2.262081199960513}), + torch::tensor({-1.7990936323432214}), + }, + { + torch::tensor( + {2.6515471766995247, + 2.51550257362603, + 1.927341363452414, + 1.8084994719811576, + 1.3823309942932445, + 2.3964995243914373}), + torch::tensor( + {3.3157334001309473, 2.093728023484945, 2.5768468697402924}), + torch::tensor( + {-2.786981763434855, -3.129746439571402, -2.29562487034177}), + torch::tensor({-1.8235564908139104}), + }, + { + torch::tensor( + {2.6695780544837886, + 2.53646401614724, + 1.9448721033433505, + 1.828157582353901, + 1.3947329882074622, + 2.4104657178934947}), + torch::tensor( + {3.344694775590452, 2.1196465761628516, 2.5954050923596252}), + torch::tensor( + {-2.8363936812536537, -3.1808219609745194, -2.32404190866147}), + torch::tensor({-1.8442667636913117}), + }, + { + torch::tensor( + {2.684883801533072, + 2.5542762192735515, + 1.9597939532350015, + 1.844909608012419, + 1.4053459079217485, + 2.4224257790968386}), + torch::tensor( + {3.369349515259956, 2.1417845308976795, 2.611319989214332}), + torch::tensor( + {-2.879251075341889, -3.225165734647855, -2.3486956737228057}), + torch::tensor({-1.86222449978646}), + }, + { + torch::tensor( + {2.698151012423769, + 2.56972998600169, + 1.972757472697587, + 1.8594775691681182, + 1.4146081751022495, + 2.43287021079559}), + torch::tensor( + {3.390772758897601, 2.1610741754331757, 2.6252349489549824}), + torch::tensor( + {-2.917092322961074, -3.264351563375218, -2.370468664387175}), + torch::tensor({-1.8780765115117757}), + }, + { + torch::tensor( + {2.7098389356033783, + 2.5833548721723747, + 1.9841994925173085, + 1.8723468731726323, + 1.4228158926355312, + 2.4421305315945085}), + torch::tensor( + {3.4096859099156673, 2.178143852041279, 2.6375854547611364}), + torch::tensor( + {-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}), + torch::tensor({-1.8922653655195538}), + }, }; } inline std::vector> Adagrad_with_weight_decay() { return { - { - torch::tensor({0.7891011218979068, 0.5024439415126254, 0.8587078332470682, 0.6579710998575992, 0.7476364849956589, 1.6975570150849029}), - torch::tensor({0.8914687701583902, 0.7020514988715463, 1.6892015071335027}), - torch::tensor({-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}), - torch::tensor({-1.0711381241615712}), - }, - { - torch::tensor({0.1846116678522213, 0.24944077103107917, 0.18651745437755768, 0.25219093533041764, 0.18712037968446713, 0.25289206444055234}), - torch::tensor({0.6482869597891656, 0.6580215784646755, 0.6581256007663537}), - torch::tensor({-1.454709711443681, -1.4748063405174818, -1.4811625946604765}), - torch::tensor({-1.905292836544363}), - }, - { - torch::tensor({0.18059895999281475, 0.2438515539257779, 0.18067177884778182, 0.24397186395008694, 0.18168388351830797, 0.24533853846052017}), - torch::tensor({0.6325250261983028, 0.6331827793513023, 0.6366659383355598}), - torch::tensor({-1.420803333750877, -1.4215627240541653, -1.4320264544533396}), - torch::tensor({-2.030135641848322}), - }, - { - torch::tensor({0.17981392697398363, 0.2427571544305695, 0.17981150414451733, 0.24275725992310523, 0.18014798619115763, 0.2432144956227816}), - torch::tensor({0.6294321320817985, 0.6294873737410742, 0.6306958589251878}), - torch::tensor({-1.4139253354785764, -1.413902680470981, -1.4173628530293867}), - torch::tensor({-2.056210117690093}), - }, - { - torch::tensor({0.17967006242163747, 0.24255582734557277, 0.17966873677301953, 0.24255462870545766, 0.1797588230898851, 0.24267729072765756}), - torch::tensor({0.6288576295241085, 0.6288643132826753, 0.6291921485342001}), - torch::tensor({-1.4126465879787569, -1.4126335126907266, -1.4135586793353685}), - torch::tensor({-2.0618018405404825}), - }, - { - torch::tensor({0.17964321284685653, 0.24251808241139364, 0.17964291377171066, 0.24251779104198598, 0.1796651574178102, 0.2425481059085456}), - torch::tensor({0.628748693136779, 0.6287498167193976, 0.6288312441271762}), - torch::tensor({-1.412405895385289, -1.4124029484481164, -1.4126313051315378}), - torch::tensor({-2.0630223163099304}), - }, - { - torch::tensor({0.1796379973927849, 0.2425107191246215, 0.1796379363134342, 0.2425106592536174, 0.17964321802205502, 0.24251786094585864}), - torch::tensor({0.6287272170354161, 0.6287274414587727, 0.6287468362309863}), - torch::tensor({-1.4123588626342634, -1.412358263650784, -1.412412480569672}), - torch::tensor({-2.0632918101480255}), - }, - { - torch::tensor({0.17963694231402444, 0.24250922426893615, 0.1796369298071621, 0.2425092121007451, 0.17963815759528073, 0.24251088666939838}), - torch::tensor({0.6287228195255881, 0.6287228675172439, 0.628727383936762}), - torch::tensor({-1.4123493065102781, -1.412349184462438, -1.4123617872243597}), - torch::tensor({-2.063351765096138}), - }, - { - torch::tensor({0.1796367215904632, 0.24250891070978667, 0.17963671897003045, 0.24250890818318074, 0.17963700091084855, 0.24250929278196054}), - torch::tensor({0.6287218911936107, 0.6287219017313679, 0.6287229399204574}), - torch::tensor({-1.4123473011084142, -1.412347275640343, -1.41235016959507}), - torch::tensor({-2.0633651674043505}), - }, - { - torch::tensor({0.17963667424978783, 0.24250884333120687, 0.17963667368829764, 0.24250884279379462, 0.17963673796131557, 0.24250893047794023}), - torch::tensor({0.6287216908150736, 0.6287216931558691, 0.6287219299749583}), - torch::tensor({-1.4123468700596724, -1.4123468646187736, -1.4123475243360133}), - torch::tensor({-2.0633681724342527}), - }, - { - torch::tensor({0.1796366639185348, 0.24250882860835257, 0.17963666379614568, 0.24250882849182892, 0.17963667838367053, 0.2425088483939741}), - torch::tensor({0.6287216468984888, 0.6287216474215305, 0.6287217011907862}), - torch::tensor({-1.4123467758545658, -1.412346774671038, -1.4123469244007658}), - torch::tensor({-2.0633688474977467}), - }, + { + torch::tensor( + {0.7891011218979068, + 0.5024439415126254, + 0.8587078332470682, + 0.6579710998575992, + 0.7476364849956589, + 1.6975570150849029}), + torch::tensor( + {0.8914687701583902, 0.7020514988715463, 1.6892015071335027}), + torch::tensor( + {-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}), + torch::tensor({-1.0711381241615712}), + }, + { + torch::tensor( + {0.1846116678522213, + 0.24944077103107917, + 0.18651745437755768, + 0.25219093533041764, + 0.18712037968446713, + 0.25289206444055234}), + torch::tensor( + {0.6482869597891656, 0.6580215784646755, 0.6581256007663537}), + torch::tensor( + {-1.454709711443681, -1.4748063405174818, -1.4811625946604765}), + torch::tensor({-1.905292836544363}), + }, + { + torch::tensor( + {0.18059895999281475, + 0.2438515539257779, + 0.18067177884778182, + 0.24397186395008694, + 0.18168388351830797, + 0.24533853846052017}), + torch::tensor( + {0.6325250261983028, 0.6331827793513023, 0.6366659383355598}), + torch::tensor( + {-1.420803333750877, -1.4215627240541653, -1.4320264544533396}), + torch::tensor({-2.030135641848322}), + }, + { + torch::tensor( + {0.17981392697398363, + 0.2427571544305695, + 0.17981150414451733, + 0.24275725992310523, + 0.18014798619115763, + 0.2432144956227816}), + torch::tensor( + {0.6294321320817985, 0.6294873737410742, 0.6306958589251878}), + torch::tensor( + {-1.4139253354785764, -1.413902680470981, -1.4173628530293867}), + torch::tensor({-2.056210117690093}), + }, + { + torch::tensor( + {0.17967006242163747, + 0.24255582734557277, + 0.17966873677301953, + 0.24255462870545766, + 0.1797588230898851, + 0.24267729072765756}), + torch::tensor( + {0.6288576295241085, 0.6288643132826753, 0.6291921485342001}), + torch::tensor( + {-1.4126465879787569, -1.4126335126907266, -1.4135586793353685}), + torch::tensor({-2.0618018405404825}), + }, + { + torch::tensor( + {0.17964321284685653, + 0.24251808241139364, + 0.17964291377171066, + 0.24251779104198598, + 0.1796651574178102, + 0.2425481059085456}), + torch::tensor( + {0.628748693136779, 0.6287498167193976, 0.6288312441271762}), + torch::tensor( + {-1.412405895385289, -1.4124029484481164, -1.4126313051315378}), + torch::tensor({-2.0630223163099304}), + }, + { + torch::tensor( + {0.1796379973927849, + 0.2425107191246215, + 0.1796379363134342, + 0.2425106592536174, + 0.17964321802205502, + 0.24251786094585864}), + torch::tensor( + {0.6287272170354161, 0.6287274414587727, 0.6287468362309863}), + torch::tensor( + {-1.4123588626342634, -1.412358263650784, -1.412412480569672}), + torch::tensor({-2.0632918101480255}), + }, + { + torch::tensor( + {0.17963694231402444, + 0.24250922426893615, + 0.1796369298071621, + 0.2425092121007451, + 0.17963815759528073, + 0.24251088666939838}), + torch::tensor( + {0.6287228195255881, 0.6287228675172439, 0.628727383936762}), + torch::tensor( + {-1.4123493065102781, -1.412349184462438, -1.4123617872243597}), + torch::tensor({-2.063351765096138}), + }, + { + torch::tensor( + {0.1796367215904632, + 0.24250891070978667, + 0.17963671897003045, + 0.24250890818318074, + 0.17963700091084855, + 0.24250929278196054}), + torch::tensor( + {0.6287218911936107, 0.6287219017313679, 0.6287229399204574}), + torch::tensor( + {-1.4123473011084142, -1.412347275640343, -1.41235016959507}), + torch::tensor({-2.0633651674043505}), + }, + { + torch::tensor( + {0.17963667424978783, + 0.24250884333120687, + 0.17963667368829764, + 0.24250884279379462, + 0.17963673796131557, + 0.24250893047794023}), + torch::tensor( + {0.6287216908150736, 0.6287216931558691, 0.6287219299749583}), + torch::tensor( + {-1.4123468700596724, -1.4123468646187736, -1.4123475243360133}), + torch::tensor({-2.0633681724342527}), + }, + { + torch::tensor( + {0.1796366639185348, + 0.24250882860835257, + 0.17963666379614568, + 0.24250882849182892, + 0.17963667838367053, + 0.2425088483939741}), + torch::tensor( + {0.6287216468984888, 0.6287216474215305, 0.6287217011907862}), + torch::tensor( + {-1.4123467758545658, -1.412346774671038, -1.4123469244007658}), + torch::tensor({-2.0633688474977467}), + }, }; } -inline std::vector> Adagrad_with_weight_decay_and_lr_decay() { +inline std::vector> +Adagrad_with_weight_decay_and_lr_decay() { return { - { - torch::tensor({0.7891011046018798, 0.5024439245163383, 0.8587078329086189, 0.6579710994225316, 0.747636483621666, 1.697557019500142}), - torch::tensor({0.8914687688943375, 0.7020514988069164, 1.6892015076050049}), - torch::tensor({-1.0508031297732776, -1.3941351871450511, -1.284337597261839}), - torch::tensor({-1.0711381241617108}), - }, - { - torch::tensor({2.346218944110103, 2.191939439502003, 1.683355201740813, 1.5405520021635604, 1.2137800230828062, 2.205283463717303}), - torch::tensor({2.9090564593404, 1.7509657336815554, 2.336166413186925}), - torch::tensor({-2.206159683368316, -2.5344318233445415, -1.9622783535807609}), - torch::tensor({-1.5796101463783623}), - }, - { - torch::tensor({2.3889328781057233, 2.2678221038007296, 1.7667624725138267, 1.6358015176639822, 1.2655767687152566, 2.261088056711282}), - torch::tensor({3.045569451994985, 1.8770196253823253, 2.4192707519566765}), - torch::tensor({-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}), - torch::tensor({-1.664722108226537}), - }, - { - torch::tensor({2.3886137557806384, 2.2922158071009178, 1.8078384116424007, 1.684352474440932, 1.290353948335789, 2.2870715509706496}), - torch::tensor({3.111110355394278, 1.9438501730282314, 2.4630249355872826}), - torch::tensor({-2.5226122034499263, -2.857315093916292, -2.143964860243905}), - torch::tensor({-1.7130685809905042}), - }, - { - torch::tensor({2.374703352203156, 2.298804499257456, 1.8330249458212446, 1.7151661013307244, 1.3048586226945842, 2.3017650590464274}), - torch::tensor({3.150318222034133, 1.9877926185369321, 2.491399976401679}), - torch::tensor({-2.601415913361488, -2.938203895113964, -2.1892988334550028}), - torch::tensor({-1.7462964261966805}), - }, - { - torch::tensor({2.3553658567303812, 2.297191758042688, 1.8501154749072124, 1.736836058688188, 1.3141313000193942, 2.3107452592153854}), - torch::tensor({3.1762315339155434, 2.0197585204578647, 2.5117041377790197}), - torch::tensor({-2.6606644002288697, -2.9991216074293856, -2.223413376189609}), - torch::tensor({-1.7712905233118807}), - }, - { - torch::tensor({2.3338052201696207, 2.2913023710914993, 1.8624163948044772, 1.7530300731725454, 1.3203313209234842, 2.3163969478854747}), - torch::tensor({3.1943525925688934, 2.044447386769377, 2.527109724607397}), - torch::tensor({-2.7076634717294894, -3.047500808469036, -2.250495807208967}), - torch::tensor({-1.7911288238757486}), - }, - { - torch::tensor({2.3114979154644892, 2.2830501835377808, 1.871616142999356, 1.765632597660841, 1.324565631636651, 2.319939234205203}), - torch::tensor({3.2074779925809085, 2.0642940833670544, 2.5392671301471235}), - torch::tensor({-2.7463093287485925, -3.087315541134716, -2.272780318857348}), - torch::tensor({-1.8074516661537263}), - }, - { - torch::tensor({2.2891841627387346, 2.2734716995793693, 1.8786818999895825, 1.7757301317117602, 1.3274682997719436, 2.3220679353993825}), - torch::tensor({3.2172019619454075, 2.0807140893178175, 2.5491374815141876}), - torch::tensor({-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}), - torch::tensor({-1.821234722948421}), - }, - { - torch::tensor({2.2672498238343066, 2.2631678037928893, 1.8842131287032622, 1.7840007705383882, 1.3294311820750493, 2.323211243034543}), - torch::tensor({3.224507488068445, 2.094598223519413, 2.5573257155791715}), - torch::tensor({-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}), - torch::tensor({-1.8331040438272388}), - }, - { - torch::tensor({2.2458961718688957, 2.2525031725114775, 1.8886034384961112, 1.7908930341267952, 1.3307102435291205, 2.323647497600462}), - torch::tensor({3.230036385413078, 2.1065407459636134, 2.5642349249609664}), - torch::tensor({-2.83151249424399, -3.1751926295316566, -2.3219682378974036}), - torch::tensor({-1.8434843744626483}), - }, + { + torch::tensor( + {0.7891011046018798, + 0.5024439245163383, + 0.8587078329086189, + 0.6579710994225316, + 0.747636483621666, + 1.697557019500142}), + torch::tensor( + {0.8914687688943375, 0.7020514988069164, 1.6892015076050049}), + torch::tensor( + {-1.0508031297732776, -1.3941351871450511, -1.284337597261839}), + torch::tensor({-1.0711381241617108}), + }, + { + torch::tensor( + {2.346218944110103, + 2.191939439502003, + 1.683355201740813, + 1.5405520021635604, + 1.2137800230828062, + 2.205283463717303}), + torch::tensor( + {2.9090564593404, 1.7509657336815554, 2.336166413186925}), + torch::tensor( + {-2.206159683368316, -2.5344318233445415, -1.9622783535807609}), + torch::tensor({-1.5796101463783623}), + }, + { + torch::tensor( + {2.3889328781057233, + 2.2678221038007296, + 1.7667624725138267, + 1.6358015176639822, + 1.2655767687152566, + 2.261088056711282}), + torch::tensor( + {3.045569451994985, 1.8770196253823253, 2.4192707519566765}), + torch::tensor( + {-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}), + torch::tensor({-1.664722108226537}), + }, + { + torch::tensor( + {2.3886137557806384, + 2.2922158071009178, + 1.8078384116424007, + 1.684352474440932, + 1.290353948335789, + 2.2870715509706496}), + torch::tensor( + {3.111110355394278, 1.9438501730282314, 2.4630249355872826}), + torch::tensor( + {-2.5226122034499263, -2.857315093916292, -2.143964860243905}), + torch::tensor({-1.7130685809905042}), + }, + { + torch::tensor( + {2.374703352203156, + 2.298804499257456, + 1.8330249458212446, + 1.7151661013307244, + 1.3048586226945842, + 2.3017650590464274}), + torch::tensor( + {3.150318222034133, 1.9877926185369321, 2.491399976401679}), + torch::tensor( + {-2.601415913361488, -2.938203895113964, -2.1892988334550028}), + torch::tensor({-1.7462964261966805}), + }, + { + torch::tensor( + {2.3553658567303812, + 2.297191758042688, + 1.8501154749072124, + 1.736836058688188, + 1.3141313000193942, + 2.3107452592153854}), + torch::tensor( + {3.1762315339155434, 2.0197585204578647, 2.5117041377790197}), + torch::tensor( + {-2.6606644002288697, -2.9991216074293856, -2.223413376189609}), + torch::tensor({-1.7712905233118807}), + }, + { + torch::tensor( + {2.3338052201696207, + 2.2913023710914993, + 1.8624163948044772, + 1.7530300731725454, + 1.3203313209234842, + 2.3163969478854747}), + torch::tensor( + {3.1943525925688934, 2.044447386769377, 2.527109724607397}), + torch::tensor( + {-2.7076634717294894, -3.047500808469036, -2.250495807208967}), + torch::tensor({-1.7911288238757486}), + }, + { + torch::tensor( + {2.3114979154644892, + 2.2830501835377808, + 1.871616142999356, + 1.765632597660841, + 1.324565631636651, + 2.319939234205203}), + torch::tensor( + {3.2074779925809085, 2.0642940833670544, 2.5392671301471235}), + torch::tensor( + {-2.7463093287485925, -3.087315541134716, -2.272780318857348}), + torch::tensor({-1.8074516661537263}), + }, + { + torch::tensor( + {2.2891841627387346, + 2.2734716995793693, + 1.8786818999895825, + 1.7757301317117602, + 1.3274682997719436, + 2.3220679353993825}), + torch::tensor( + {3.2172019619454075, 2.0807140893178175, 2.5491374815141876}), + torch::tensor( + {-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}), + torch::tensor({-1.821234722948421}), + }, + { + torch::tensor( + {2.2672498238343066, + 2.2631678037928893, + 1.8842131287032622, + 1.7840007705383882, + 1.3294311820750493, + 2.323211243034543}), + torch::tensor( + {3.224507488068445, 2.094598223519413, 2.5573257155791715}), + torch::tensor( + {-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}), + torch::tensor({-1.8331040438272388}), + }, + { + torch::tensor( + {2.2458961718688957, + 2.2525031725114775, + 1.8886034384961112, + 1.7908930341267952, + 1.3307102435291205, + 2.323647497600462}), + torch::tensor( + {3.230036385413078, 2.1065407459636134, 2.5642349249609664}), + torch::tensor( + {-2.83151249424399, -3.1751926295316566, -2.3219682378974036}), + torch::tensor({-1.8434843744626483}), + }, }; } inline std::vector> RMSprop() { return { - { - torch::tensor({0.7890625772821005, 0.502415108650816, 0.8587027713011453, 0.657967312300643, 0.7476283936579036, 1.6975509766054537}), - torch::tensor({0.8914573371873159, 0.7020499947573374, 1.6891991194739453}), - torch::tensor({-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}), - torch::tensor({-1.0711379842715099}), - }, - { - torch::tensor({2.448571858277443, 2.2809152044417678, 1.7346424449151965, 1.5940004770230667, 1.250761131839982, 2.248993270255382}), - torch::tensor({2.994661478530102, 1.8150485290864256, 2.382542610897819}), - torch::tensor({-2.3036981738757825, -2.6337299521275646, -2.018370122358821}), - torch::tensor({-1.620787559800898}), - }, - { - torch::tensor({2.5837582475607785, 2.4365737242301537, 1.8622886519354538, 1.7357065282848232, 1.3369695670141974, 2.3454934716983695}), - torch::tensor({3.2061266499381618, 1.9981112525417788, 2.5092495986614}), - torch::tensor({-2.6110809365525958, -2.9484807193016787, -2.194898560798439}), - torch::tensor({-1.7501043480625826}), - }, - { - torch::tensor({2.669969051134511, 2.536559412710799, 1.9456091681389671, 1.828914948091767, 1.3952956766999587, 2.4110816686341923}), - torch::tensor({3.343672975593657, 2.1204057198913002, 2.5961524902119497}), - torch::tensor({-2.8372329851331006, -3.1817729538857207, -2.3249971853996954}), - torch::tensor({-1.8450422173907486}), - }, - { - torch::tensor({2.7375365004059122, 2.6153071545358633, 2.0117493624534313, 1.9033001982031035, 1.4427501882445097, 2.4646213743186127}), - torch::tensor({3.452912454199796, 2.2190451524127535, 2.667552790123282}), - torch::tensor({-3.0329479456731505, -3.384582488936652, -2.4377299824997136}), - torch::tensor({-1.9271014784118226}), - }, - { - torch::tensor({2.7952372917068753, 2.682820220375722, 2.0687223272686994, 1.967654548778711, 1.4844410726622166, 2.5117888904510117}), - torch::tensor({3.5471904628565745, 2.305113548262141, 2.7307948248967304}), - torch::tensor({-3.2141190290332537, -3.572944633614449, -2.5421970206546827}), - torch::tensor({-2.0029976985219666}), - }, - { - torch::tensor({2.8467333937519483, 2.7432785177110395, 2.119898810135385, 2.0256805416741255, 1.5225256464280221, 2.554983108087885}), - torch::tensor({3.632098323876194, 2.383289304179778, 2.7889864719999222}), - torch::tensor({-3.387579944167926, -3.7537658010839294, -2.6423123266260427}), - torch::tensor({-2.075616951445725}), - }, - { - torch::tensor({2.8938600498060203, 2.798776984183615, 2.16697381475156, 2.079238538430203, 1.5580820887115123, 2.5954023023969692}), - torch::tensor({3.71043881435304, 2.455919099321953, 2.8436784441941008}), - torch::tensor({-3.5567368287146417, -3.930484868709691, -2.740026479434574}), - torch::tensor({-2.146398256871758}), - }, - { - torch::tensor({2.937649394399939, 2.850492965312311, 2.210902777232446, 2.1293746183147633, 1.5917084661873866, 2.6337100421079533}), - torch::tensor({3.7837853328443516, 2.5243155701130604, 2.8957265009949373}), - torch::tensor({-3.7234268485210475, -4.104949193318518, -2.836390693799751}), - torch::tensor({-2.216118773360611}), - }, - { - torch::tensor({2.9787316798887558, 2.8991451078473207, 2.252272487010975, 2.1767288729202767, 1.6237627746697592, 2.670302507579268}), - torch::tensor({3.8530980655045086, 2.589275531553025, 2.9456388178450936}), - torch::tensor({-3.8886856459619636, -4.278191888396593, -2.9319964928350313}), - torch::tensor({-2.285217699505124}), - }, - { - torch::tensor({3.017515620579049, 2.9452004251453268, 2.291469925522962, 2.2217217025782916, 1.6544730927621272, 2.705431441204422}), - torch::tensor({3.9190041004420166, 2.651317624465938, 2.993736489599992}), - torch::tensor({-4.053111559341913, -4.450801801162238, -3.0271845519131957}), - torch::tensor({-2.3539498905973444}), - }, + { + torch::tensor( + {0.7890625772821005, + 0.502415108650816, + 0.8587027713011453, + 0.657967312300643, + 0.7476283936579036, + 1.6975509766054537}), + torch::tensor( + {0.8914573371873159, 0.7020499947573374, 1.6891991194739453}), + torch::tensor( + {-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}), + torch::tensor({-1.0711379842715099}), + }, + { + torch::tensor( + {2.448571858277443, + 2.2809152044417678, + 1.7346424449151965, + 1.5940004770230667, + 1.250761131839982, + 2.248993270255382}), + torch::tensor( + {2.994661478530102, 1.8150485290864256, 2.382542610897819}), + torch::tensor( + {-2.3036981738757825, -2.6337299521275646, -2.018370122358821}), + torch::tensor({-1.620787559800898}), + }, + { + torch::tensor( + {2.5837582475607785, + 2.4365737242301537, + 1.8622886519354538, + 1.7357065282848232, + 1.3369695670141974, + 2.3454934716983695}), + torch::tensor( + {3.2061266499381618, 1.9981112525417788, 2.5092495986614}), + torch::tensor( + {-2.6110809365525958, -2.9484807193016787, -2.194898560798439}), + torch::tensor({-1.7501043480625826}), + }, + { + torch::tensor( + {2.669969051134511, + 2.536559412710799, + 1.9456091681389671, + 1.828914948091767, + 1.3952956766999587, + 2.4110816686341923}), + torch::tensor( + {3.343672975593657, 2.1204057198913002, 2.5961524902119497}), + torch::tensor( + {-2.8372329851331006, -3.1817729538857207, -2.3249971853996954}), + torch::tensor({-1.8450422173907486}), + }, + { + torch::tensor( + {2.7375365004059122, + 2.6153071545358633, + 2.0117493624534313, + 1.9033001982031035, + 1.4427501882445097, + 2.4646213743186127}), + torch::tensor( + {3.452912454199796, 2.2190451524127535, 2.667552790123282}), + torch::tensor( + {-3.0329479456731505, -3.384582488936652, -2.4377299824997136}), + torch::tensor({-1.9271014784118226}), + }, + { + torch::tensor( + {2.7952372917068753, + 2.682820220375722, + 2.0687223272686994, + 1.967654548778711, + 1.4844410726622166, + 2.5117888904510117}), + torch::tensor( + {3.5471904628565745, 2.305113548262141, 2.7307948248967304}), + torch::tensor( + {-3.2141190290332537, -3.572944633614449, -2.5421970206546827}), + torch::tensor({-2.0029976985219666}), + }, + { + torch::tensor( + {2.8467333937519483, + 2.7432785177110395, + 2.119898810135385, + 2.0256805416741255, + 1.5225256464280221, + 2.554983108087885}), + torch::tensor( + {3.632098323876194, 2.383289304179778, 2.7889864719999222}), + torch::tensor( + {-3.387579944167926, -3.7537658010839294, -2.6423123266260427}), + torch::tensor({-2.075616951445725}), + }, + { + torch::tensor( + {2.8938600498060203, + 2.798776984183615, + 2.16697381475156, + 2.079238538430203, + 1.5580820887115123, + 2.5954023023969692}), + torch::tensor( + {3.71043881435304, 2.455919099321953, 2.8436784441941008}), + torch::tensor( + {-3.5567368287146417, -3.930484868709691, -2.740026479434574}), + torch::tensor({-2.146398256871758}), + }, + { + torch::tensor( + {2.937649394399939, + 2.850492965312311, + 2.210902777232446, + 2.1293746183147633, + 1.5917084661873866, + 2.6337100421079533}), + torch::tensor( + {3.7837853328443516, 2.5243155701130604, 2.8957265009949373}), + torch::tensor( + {-3.7234268485210475, -4.104949193318518, -2.836390693799751}), + torch::tensor({-2.216118773360611}), + }, + { + torch::tensor( + {2.9787316798887558, + 2.8991451078473207, + 2.252272487010975, + 2.1767288729202767, + 1.6237627746697592, + 2.670302507579268}), + torch::tensor( + {3.8530980655045086, 2.589275531553025, 2.9456388178450936}), + torch::tensor( + {-3.8886856459619636, -4.278191888396593, -2.9319964928350313}), + torch::tensor({-2.285217699505124}), + }, + { + torch::tensor( + {3.017515620579049, + 2.9452004251453268, + 2.291469925522962, + 2.2217217025782916, + 1.6544730927621272, + 2.705431441204422}), + torch::tensor( + {3.9190041004420166, 2.651317624465938, 2.993736489599992}), + torch::tensor( + {-4.053111559341913, -4.450801801162238, -3.0271845519131957}), + torch::tensor({-2.3539498905973444}), + }, }; } inline std::vector> RMSprop_with_weight_decay() { return { - { - torch::tensor({0.7890798754118442, 0.5024321083861885, 0.8587031097835685, 0.6579677474141494, 0.7476297677960806, 1.6975465611838714}), - torch::tensor({0.891458601354904, 0.7020500593937647, 1.6891986479348047}), - torch::tensor({-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}), - torch::tensor({-1.0711379841318796}), - }, - { - torch::tensor({0.2139892652405453, 0.2779011713353896, 0.18684802794665187, 0.2507569370785562, 0.19145335235130007, 0.2557687813140708}), - torch::tensor({0.6720959116689083, 0.6480734848064099, 0.654263007004671}), - torch::tensor({-1.4357633640899097, -1.4493557950073235, -1.4619011018357073}), - torch::tensor({-1.9673083558727926}), - }, - { - torch::tensor({0.23961935744660673, 0.30354236865888945, 0.19567694278583514, 0.2544696440133763, 0.21982879020261814, 0.27495711471979495}), - torch::tensor({0.6927895724658635, 0.638015535479105, 0.6523245375960234}), - torch::tensor({-1.413722583500382, -1.4170291001633526, -1.4166977298480703}), - torch::tensor({-2.0626651115437147}), - }, - { - torch::tensor({0.2506635865272117, 0.314639511428635, 0.25116892910818034, 0.30431399579592144, 0.25219625048710015, 0.3160110008170742}), - torch::tensor({0.7051419232960522, 0.6699011906397543, 0.699097284678438}), - torch::tensor({-1.4206083241624232, -1.4257037444100107, -1.4171061826065132}), - torch::tensor({-2.075537874763694}), - }, - { - torch::tensor({0.23285924743063605, 0.29652494304777544, 0.2335322002738168, 0.2969991261380461, 0.23358272245229555, 0.2973997498166104}), - torch::tensor({0.6855589594925036, 0.6796983775695974, 0.6864174803983276}), - torch::tensor({-1.43110762794651, -1.4334934742818164, -1.422739552145125}), - torch::tensor({-2.0842642493046184}), - }, - { - torch::tensor({0.23356397699389828, 0.29737142391985355, 0.23367622061822368, 0.29749447597160267, 0.23418481357395918, 0.29818122925156104}), - torch::tensor({0.6866530583001205, 0.6858933385102559, 0.6883944045412603}), - torch::tensor({-1.4564955509607018, -1.4583548131500643, -1.4418225445708595}), - torch::tensor({-2.1064103749186183}), - }, - { - torch::tensor({0.2318717301174723, 0.2952904159872858, 0.23194024439476665, 0.29537019824987687, 0.2316421336904657, 0.2951041425983894}), - torch::tensor({0.6834813130194509, 0.6834401711464199, 0.6837275457100463}), - torch::tensor({-1.4647835805276763, -1.4653452408179053, -1.4571142112777709}), - torch::tensor({-2.1209598505912086}), - }, - { - torch::tensor({0.2308683396504178, 0.2940474629750448, 0.23089067678260966, 0.29407615110959306, 0.23064069314214175, 0.29379043611390243}), - torch::tensor({0.6815062281792611, 0.6815233687209215, 0.6812759203026146}), - torch::tensor({-1.4643013018530682, -1.4644523635284246, -1.4617493939684878}), - torch::tensor({-2.1247293635678854}), - }, - { - torch::tensor({0.23066464201462678, 0.29376059273730426, 0.23067069245857366, 0.2937690399784267, 0.23057551211606675, 0.2936477517373107}), - torch::tensor({0.6809028781780304, 0.6809134105028244, 0.6807404613096301}), - torch::tensor({-1.4637927352177986, -1.4638374228010727, -1.46299287102643}), - torch::tensor({-2.1258082720638107}), - }, - { - torch::tensor({0.23062625079199173, 0.2936990787425707, 0.2306278924729115, 0.2937014834661651, 0.23059813368157003, 0.29366073890476396}), - torch::tensor({0.6807251804689082, 0.6807295616357246, 0.6806523640328994}), - torch::tensor({-1.4635790398985618, -1.463592926902286, -1.4633272688236565}), - torch::tensor({-2.1261396358141798}), - }, - { - torch::tensor({0.23061701122700193, 0.293683983817782, 0.23061747865501653, 0.29368467998690806, 0.23060855595638208, 0.2936719507340021}), - torch::tensor({0.6806714673830832, 0.6806730903175793, 0.6806434720800856}), - torch::tensor({-1.4635008778278134, -1.4635052859178375, -1.4634208375068285}), - torch::tensor({-2.1262432969587723}), - }, + { + torch::tensor( + {0.7890798754118442, + 0.5024321083861885, + 0.8587031097835685, + 0.6579677474141494, + 0.7476297677960806, + 1.6975465611838714}), + torch::tensor( + {0.891458601354904, 0.7020500593937647, 1.6891986479348047}), + torch::tensor( + {-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}), + torch::tensor({-1.0711379841318796}), + }, + { + torch::tensor( + {0.2139892652405453, + 0.2779011713353896, + 0.18684802794665187, + 0.2507569370785562, + 0.19145335235130007, + 0.2557687813140708}), + torch::tensor( + {0.6720959116689083, 0.6480734848064099, 0.654263007004671}), + torch::tensor( + {-1.4357633640899097, -1.4493557950073235, -1.4619011018357073}), + torch::tensor({-1.9673083558727926}), + }, + { + torch::tensor( + {0.23961935744660673, + 0.30354236865888945, + 0.19567694278583514, + 0.2544696440133763, + 0.21982879020261814, + 0.27495711471979495}), + torch::tensor( + {0.6927895724658635, 0.638015535479105, 0.6523245375960234}), + torch::tensor( + {-1.413722583500382, -1.4170291001633526, -1.4166977298480703}), + torch::tensor({-2.0626651115437147}), + }, + { + torch::tensor( + {0.2506635865272117, + 0.314639511428635, + 0.25116892910818034, + 0.30431399579592144, + 0.25219625048710015, + 0.3160110008170742}), + torch::tensor( + {0.7051419232960522, 0.6699011906397543, 0.699097284678438}), + torch::tensor( + {-1.4206083241624232, -1.4257037444100107, -1.4171061826065132}), + torch::tensor({-2.075537874763694}), + }, + { + torch::tensor( + {0.23285924743063605, + 0.29652494304777544, + 0.2335322002738168, + 0.2969991261380461, + 0.23358272245229555, + 0.2973997498166104}), + torch::tensor( + {0.6855589594925036, 0.6796983775695974, 0.6864174803983276}), + torch::tensor( + {-1.43110762794651, -1.4334934742818164, -1.422739552145125}), + torch::tensor({-2.0842642493046184}), + }, + { + torch::tensor( + {0.23356397699389828, + 0.29737142391985355, + 0.23367622061822368, + 0.29749447597160267, + 0.23418481357395918, + 0.29818122925156104}), + torch::tensor( + {0.6866530583001205, 0.6858933385102559, 0.6883944045412603}), + torch::tensor( + {-1.4564955509607018, -1.4583548131500643, -1.4418225445708595}), + torch::tensor({-2.1064103749186183}), + }, + { + torch::tensor( + {0.2318717301174723, + 0.2952904159872858, + 0.23194024439476665, + 0.29537019824987687, + 0.2316421336904657, + 0.2951041425983894}), + torch::tensor( + {0.6834813130194509, 0.6834401711464199, 0.6837275457100463}), + torch::tensor( + {-1.4647835805276763, -1.4653452408179053, -1.4571142112777709}), + torch::tensor({-2.1209598505912086}), + }, + { + torch::tensor( + {0.2308683396504178, + 0.2940474629750448, + 0.23089067678260966, + 0.29407615110959306, + 0.23064069314214175, + 0.29379043611390243}), + torch::tensor( + {0.6815062281792611, 0.6815233687209215, 0.6812759203026146}), + torch::tensor( + {-1.4643013018530682, -1.4644523635284246, -1.4617493939684878}), + torch::tensor({-2.1247293635678854}), + }, + { + torch::tensor( + {0.23066464201462678, + 0.29376059273730426, + 0.23067069245857366, + 0.2937690399784267, + 0.23057551211606675, + 0.2936477517373107}), + torch::tensor( + {0.6809028781780304, 0.6809134105028244, 0.6807404613096301}), + torch::tensor( + {-1.4637927352177986, -1.4638374228010727, -1.46299287102643}), + torch::tensor({-2.1258082720638107}), + }, + { + torch::tensor( + {0.23062625079199173, + 0.2936990787425707, + 0.2306278924729115, + 0.2937014834661651, + 0.23059813368157003, + 0.29366073890476396}), + torch::tensor( + {0.6807251804689082, 0.6807295616357246, 0.6806523640328994}), + torch::tensor( + {-1.4635790398985618, -1.463592926902286, -1.4633272688236565}), + torch::tensor({-2.1261396358141798}), + }, + { + torch::tensor( + {0.23061701122700193, + 0.293683983817782, + 0.23061747865501653, + 0.29368467998690806, + 0.23060855595638208, + 0.2936719507340021}), + torch::tensor( + {0.6806714673830832, 0.6806730903175793, 0.6806434720800856}), + torch::tensor( + {-1.4635008778278134, -1.4635052859178375, -1.4634208375068285}), + torch::tensor({-2.1262432969587723}), + }, }; } -inline std::vector> RMSprop_with_weight_decay_and_centered() { +inline std::vector> +RMSprop_with_weight_decay_and_centered() { return { - { - torch::tensor({0.7941000061626792, 0.507452636734552, 0.8637405354185987, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}), - torch::tensor({0.8964950370033696, 0.7070877948157552, 1.6942369105467197}), - torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), - torch::tensor({-1.0761757981162612}), - }, - { - torch::tensor({2.3762999876885833, 2.239095829416783, 1.726175067071914, 1.5891569459230444, 1.2410074108588462, 2.2345431036725723}), - torch::tensor({2.990896455635836, 1.8152108764849464, 2.377985429759037}), - torch::tensor({-2.3071822180635286, -2.636859516619699, -2.0198181394256642}), - torch::tensor({-1.622583045791722}), - }, - { - torch::tensor({2.372800588647971, 2.3022753207224254, 1.836028714221617, 1.7190937269287105, 1.3068955839895078, 2.3035835673200364}), - torch::tensor({3.1656599892042343, 1.9942937608209466, 2.4947143457182657}), - torch::tensor({-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}), - torch::tensor({-1.7513053380188806}), - }, - { - torch::tensor({2.2398453700818455, 2.2513384246965904, 1.8892176431436287, 1.7921873754661686, 1.3310951408713538, 2.3236392222350397}), - torch::tensor({3.240166119454613, 2.109742813600189, 2.5651614461576973}), - torch::tensor({-2.8388734382997454, -3.1824200770676123, -2.324831397600949}), - torch::tensor({-1.8460315737386976}), - }, - { - torch::tensor({1.9829606312242465, 2.097356567850692, 1.9050263843525033, 1.8325835415812346, 1.3222762370713104, 2.3024963133870147}), - torch::tensor({3.2465360572089974, 2.1967266045869915, 2.6091992649970672}), - torch::tensor({-3.0326878099587207, -3.3827004807595005, -2.436989182250496}), - torch::tensor({-1.928273216206344}), - }, - { - torch::tensor({1.6051175329080525, 1.8332107491649114, 1.8794767349053179, 1.8403588051948856, 1.273824111314107, 2.2296571379436823}), - torch::tensor({3.1814362940910437, 2.263019214072847, 2.6273016977574013}), - torch::tensor({-3.210932646440219, -3.567153254014387, -2.541016943923914}), - torch::tensor({-2.0049155134617154}), - }, - { - torch::tensor({1.1588059349082709, 1.477861379523226, 1.7992410089026634, 1.806460009198667, 1.1739931551629919, 2.08647960875392}), - torch::tensor({3.03843703712275, 2.308203068375877, 2.6125393914734083}), - torch::tensor({-3.379830678608588, -3.741970414470626, -2.6410082400846546}), - torch::tensor({-2.079294995910487}), - }, - { - torch::tensor({0.7701433312419088, 1.1105026677424745, 1.646507516936639, 1.71625269098179, 1.013748545414221, 1.8532966501655352}), - torch::tensor({2.827176875885245, 2.327401948159928, 2.5535309398603405}), - torch::tensor({-3.54193329850986, -3.9096652952123145, -2.739408870192437}), - torch::tensor({-2.1537939241668997}), - }, - { - torch::tensor({0.5598923129351211, 0.8460500042788701, 1.4084175549165017, 1.5547314210944563, 0.8019580519338424, 1.5258384663629627}), - torch::tensor({2.5774950379490265, 2.313101306699127, 2.4388695757441745}), - torch::tensor({-3.6974974230160087, -4.070190514312716, -2.8378932675718405}), - torch::tensor({-2.2307225014430423}), - }, - { - torch::tensor({0.5016784472836648, 0.7258690889265433, 1.0976902935953956, 1.319949187972513, 0.5853930356154851, 1.1446978015944624}), - torch::tensor({2.3235249877284945, 2.2592840970420176, 2.2681461698609375}), - torch::tensor({-3.8444921272569115, -4.22021051361099, -2.9373192115434263}), - torch::tensor({-2.312733063937045}), - }, - { - torch::tensor({0.4875468895095056, 0.6878747871467128, 0.7787871237567606, 1.0462592546102176, 0.4416468896022397, 0.8122992916762792}), - torch::tensor({2.1078734515587483, 2.17034337037527, 2.0666325968568535}), - torch::tensor({-3.9782695475825216, -4.352093055115415, -3.0377809502927033}), - torch::tensor({-2.403496388200805}), - }, + { + torch::tensor( + {0.7941000061626792, + 0.507452636734552, + 0.8637405354185987, + 0.663005089317529, + 0.7526661272860107, + 1.7025887305065852}), + torch::tensor( + {0.8964950370033696, 0.7070877948157552, 1.6942369105467197}), + torch::tensor( + {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), + torch::tensor({-1.0761757981162612}), + }, + { + torch::tensor( + {2.3762999876885833, + 2.239095829416783, + 1.726175067071914, + 1.5891569459230444, + 1.2410074108588462, + 2.2345431036725723}), + torch::tensor( + {2.990896455635836, 1.8152108764849464, 2.377985429759037}), + torch::tensor( + {-2.3071822180635286, -2.636859516619699, -2.0198181394256642}), + torch::tensor({-1.622583045791722}), + }, + { + torch::tensor( + {2.372800588647971, + 2.3022753207224254, + 1.836028714221617, + 1.7190937269287105, + 1.3068955839895078, + 2.3035835673200364}), + torch::tensor( + {3.1656599892042343, 1.9942937608209466, 2.4947143457182657}), + torch::tensor( + {-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}), + torch::tensor({-1.7513053380188806}), + }, + { + torch::tensor( + {2.2398453700818455, + 2.2513384246965904, + 1.8892176431436287, + 1.7921873754661686, + 1.3310951408713538, + 2.3236392222350397}), + torch::tensor( + {3.240166119454613, 2.109742813600189, 2.5651614461576973}), + torch::tensor( + {-2.8388734382997454, -3.1824200770676123, -2.324831397600949}), + torch::tensor({-1.8460315737386976}), + }, + { + torch::tensor( + {1.9829606312242465, + 2.097356567850692, + 1.9050263843525033, + 1.8325835415812346, + 1.3222762370713104, + 2.3024963133870147}), + torch::tensor( + {3.2465360572089974, 2.1967266045869915, 2.6091992649970672}), + torch::tensor( + {-3.0326878099587207, -3.3827004807595005, -2.436989182250496}), + torch::tensor({-1.928273216206344}), + }, + { + torch::tensor( + {1.6051175329080525, + 1.8332107491649114, + 1.8794767349053179, + 1.8403588051948856, + 1.273824111314107, + 2.2296571379436823}), + torch::tensor( + {3.1814362940910437, 2.263019214072847, 2.6273016977574013}), + torch::tensor( + {-3.210932646440219, -3.567153254014387, -2.541016943923914}), + torch::tensor({-2.0049155134617154}), + }, + { + torch::tensor( + {1.1588059349082709, + 1.477861379523226, + 1.7992410089026634, + 1.806460009198667, + 1.1739931551629919, + 2.08647960875392}), + torch::tensor( + {3.03843703712275, 2.308203068375877, 2.6125393914734083}), + torch::tensor( + {-3.379830678608588, -3.741970414470626, -2.6410082400846546}), + torch::tensor({-2.079294995910487}), + }, + { + torch::tensor( + {0.7701433312419088, + 1.1105026677424745, + 1.646507516936639, + 1.71625269098179, + 1.013748545414221, + 1.8532966501655352}), + torch::tensor( + {2.827176875885245, 2.327401948159928, 2.5535309398603405}), + torch::tensor( + {-3.54193329850986, -3.9096652952123145, -2.739408870192437}), + torch::tensor({-2.1537939241668997}), + }, + { + torch::tensor( + {0.5598923129351211, + 0.8460500042788701, + 1.4084175549165017, + 1.5547314210944563, + 0.8019580519338424, + 1.5258384663629627}), + torch::tensor( + {2.5774950379490265, 2.313101306699127, 2.4388695757441745}), + torch::tensor( + {-3.6974974230160087, -4.070190514312716, -2.8378932675718405}), + torch::tensor({-2.2307225014430423}), + }, + { + torch::tensor( + {0.5016784472836648, + 0.7258690889265433, + 1.0976902935953956, + 1.319949187972513, + 0.5853930356154851, + 1.1446978015944624}), + torch::tensor( + {2.3235249877284945, 2.2592840970420176, 2.2681461698609375}), + torch::tensor( + {-3.8444921272569115, -4.22021051361099, -2.9373192115434263}), + torch::tensor({-2.312733063937045}), + }, + { + torch::tensor( + {0.4875468895095056, + 0.6878747871467128, + 0.7787871237567606, + 1.0462592546102176, + 0.4416468896022397, + 0.8122992916762792}), + torch::tensor( + {2.1078734515587483, 2.17034337037527, 2.0666325968568535}), + torch::tensor( + {-3.9782695475825216, -4.352093055115415, -3.0377809502927033}), + torch::tensor({-2.403496388200805}), + }, }; } -inline std::vector> RMSprop_with_weight_decay_and_centered_and_momentum() { +inline std::vector> +RMSprop_with_weight_decay_and_centered_and_momentum() { return { - { - torch::tensor({0.7941000061626794, 0.507452636734552, 0.8637405354185985, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}), - torch::tensor({0.8964950370033699, 0.7070877948157552, 1.6942369105467197}), - torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), - torch::tensor({-1.0761757981162612}), - }, - { - torch::tensor({11.587263945492355, 12.552112516667206, 10.773002960161074, 10.782117868337808, 9.675467654064093, 10.830689360054789}), - torch::tensor({15.298238342006444, 11.252244653209866, 11.423905295074075}), - torch::tensor({-11.287147147258441, -11.673871066494183, -11.143068139029769}), - torch::tensor({-10.744790465364126}), - }, - { - torch::tensor({5.993130757784388, 7.778269455146452, 9.705741295559012, 9.974952848613889, 8.171307305871647, 9.551498426643077}), - torch::tensor({12.811268477045155, 10.912201832960703, 10.87477550647832}), - torch::tensor({-11.20842921856976, -11.58706973895515, -11.098172235374586}), - torch::tensor({-10.714110383698559}), - }, - { - torch::tensor({1.917316794757853, 3.442098373003915, 8.160846071267297, 8.76673426856121, 6.163892823252042, 7.748894752821816}), - torch::tensor({9.52929937981379, 10.371703621802425, 10.02242566317017}), - torch::tensor({-11.07914626767133, -11.444639737948599, -11.02397978065452}), - torch::tensor({-10.663204622623406}), - }, - { - torch::tensor({0.24211162925745067, 0.8235150923738451, 6.109652191353378, 7.070860554523037, 3.8366635637770212, 5.46037058418296}), - torch::tensor({5.7908039507441, 9.534309069066389, 8.752252906881251}), - torch::tensor({-10.868651889371552, -11.212965695734527, -10.90242744782103}), - torch::tensor({-10.579596899816439}), - }, - { - torch::tensor({0.0024206009020476234, 0.05521740497689468, 3.753606156332189, 4.9331546064599685, 1.7094621184709604, 3.022224882400484}), - torch::tensor({2.4729429920325234, 8.290211439306459, 6.983317870704776}), - torch::tensor({-10.529133489023623, -10.839885990130032, -10.704345435808353}), - torch::tensor({-10.44279235413811}), - }, - { - torch::tensor({8.523664833406631e-06, -0.00018498015809617104, 1.6343074841140277, 2.683608480982546, 0.41425107807132744, 1.092111816609512}), - torch::tensor({0.553119873538318, 6.566845593450314, 4.783317472190566}), - torch::tensor({-9.990101114696575, -10.24914448933998, -10.38447825909146}), - torch::tensor({-10.220382375374728}), - }, - { - torch::tensor({5.3669182339397725e-08, -2.899704029399283e-07, 0.37916783268568177, 0.9399553431452395, 0.02859528129337607, 0.17650614337704745}), - torch::tensor({0.03166973497545419, 4.442846994093523, 2.5203464928754724}), - torch::tensor({-9.15653357178671, -9.339631853060773, -9.875729313751442}), - torch::tensor({-9.862669711962374}), - }, - { - torch::tensor({2.1133356499004335e-06, 2.4524630407768025e-06, 0.023655729923601883, 0.14273709578291396, -8.950192389690758e-05, 0.004237697008964042}), - torch::tensor({-0.00012364097582548376, 2.291191859107928, 0.8331414409602524}), - torch::tensor({-7.922566174765117, -8.003055545094796, -9.086673634672907}), - torch::tensor({-9.297519364373224}), - }, - { - torch::tensor({0.0023497430294992434, 0.0028611316714725037, 0.0006998739627296072, 0.003657156536057531, 0.001654303471369622, 0.0018171459470053366}), - torch::tensor({0.004569191565477355, 0.7292466599711233, 0.11475431260766135}), - torch::tensor({-6.223834483308681, -6.185383631607397, -7.912955414853613}), - torch::tensor({-8.430731662958186}), - }, - { - torch::tensor({0.10393820340367545, 0.13982074666181732, 0.0831407198272949, 0.10183584198629944, 0.13949594516972202, 0.17822672100147108}), - torch::tensor({0.340394645020639, 0.24860888862359687, 0.3191404515531066}), - torch::tensor({-4.174294597914298, -4.037528929635062, -6.297198700024484}), - torch::tensor({-7.182093090194918}), - }, + { + torch::tensor( + {0.7941000061626794, + 0.507452636734552, + 0.8637405354185985, + 0.663005089317529, + 0.7526661272860107, + 1.7025887305065852}), + torch::tensor( + {0.8964950370033699, 0.7070877948157552, 1.6942369105467197}), + torch::tensor( + {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), + torch::tensor({-1.0761757981162612}), + }, + { + torch::tensor( + {11.587263945492355, + 12.552112516667206, + 10.773002960161074, + 10.782117868337808, + 9.675467654064093, + 10.830689360054789}), + torch::tensor( + {15.298238342006444, 11.252244653209866, 11.423905295074075}), + torch::tensor( + {-11.287147147258441, -11.673871066494183, -11.143068139029769}), + torch::tensor({-10.744790465364126}), + }, + { + torch::tensor( + {5.993130757784388, + 7.778269455146452, + 9.705741295559012, + 9.974952848613889, + 8.171307305871647, + 9.551498426643077}), + torch::tensor( + {12.811268477045155, 10.912201832960703, 10.87477550647832}), + torch::tensor( + {-11.20842921856976, -11.58706973895515, -11.098172235374586}), + torch::tensor({-10.714110383698559}), + }, + { + torch::tensor( + {1.917316794757853, + 3.442098373003915, + 8.160846071267297, + 8.76673426856121, + 6.163892823252042, + 7.748894752821816}), + torch::tensor( + {9.52929937981379, 10.371703621802425, 10.02242566317017}), + torch::tensor( + {-11.07914626767133, -11.444639737948599, -11.02397978065452}), + torch::tensor({-10.663204622623406}), + }, + { + torch::tensor( + {0.24211162925745067, + 0.8235150923738451, + 6.109652191353378, + 7.070860554523037, + 3.8366635637770212, + 5.46037058418296}), + torch::tensor( + {5.7908039507441, 9.534309069066389, 8.752252906881251}), + torch::tensor( + {-10.868651889371552, -11.212965695734527, -10.90242744782103}), + torch::tensor({-10.579596899816439}), + }, + { + torch::tensor( + {0.0024206009020476234, + 0.05521740497689468, + 3.753606156332189, + 4.9331546064599685, + 1.7094621184709604, + 3.022224882400484}), + torch::tensor( + {2.4729429920325234, 8.290211439306459, 6.983317870704776}), + torch::tensor( + {-10.529133489023623, -10.839885990130032, -10.704345435808353}), + torch::tensor({-10.44279235413811}), + }, + { + torch::tensor( + {8.523664833406631e-06, + -0.00018498015809617104, + 1.6343074841140277, + 2.683608480982546, + 0.41425107807132744, + 1.092111816609512}), + torch::tensor( + {0.553119873538318, 6.566845593450314, 4.783317472190566}), + torch::tensor( + {-9.990101114696575, -10.24914448933998, -10.38447825909146}), + torch::tensor({-10.220382375374728}), + }, + { + torch::tensor( + {5.3669182339397725e-08, + -2.899704029399283e-07, + 0.37916783268568177, + 0.9399553431452395, + 0.02859528129337607, + 0.17650614337704745}), + torch::tensor( + {0.03166973497545419, 4.442846994093523, 2.5203464928754724}), + torch::tensor( + {-9.15653357178671, -9.339631853060773, -9.875729313751442}), + torch::tensor({-9.862669711962374}), + }, + { + torch::tensor( + {2.1133356499004335e-06, + 2.4524630407768025e-06, + 0.023655729923601883, + 0.14273709578291396, + -8.950192389690758e-05, + 0.004237697008964042}), + torch::tensor( + {-0.00012364097582548376, 2.291191859107928, 0.8331414409602524}), + torch::tensor( + {-7.922566174765117, -8.003055545094796, -9.086673634672907}), + torch::tensor({-9.297519364373224}), + }, + { + torch::tensor( + {0.0023497430294992434, + 0.0028611316714725037, + 0.0006998739627296072, + 0.003657156536057531, + 0.001654303471369622, + 0.0018171459470053366}), + torch::tensor( + {0.004569191565477355, 0.7292466599711233, 0.11475431260766135}), + torch::tensor( + {-6.223834483308681, -6.185383631607397, -7.912955414853613}), + torch::tensor({-8.430731662958186}), + }, + { + torch::tensor( + {0.10393820340367545, + 0.13982074666181732, + 0.0831407198272949, + 0.10183584198629944, + 0.13949594516972202, + 0.17822672100147108}), + torch::tensor( + {0.340394645020639, 0.24860888862359687, 0.3191404515531066}), + torch::tensor( + {-4.174294597914298, -4.037528929635062, -6.297198700024484}), + torch::tensor({-7.182093090194918}), + }, }; } inline std::vector> SGD() { return { - { - torch::tensor({-0.21063957030131192, -0.4972093725858961, -0.13931849072410168, -0.33939101965581686, -0.25112865488453673, 0.6992101966874735}), - torch::tensor({-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}), - torch::tensor({-0.07998325778863398, -0.42149210515421365, -0.33498349553944556}), - torch::tensor({-0.14255126505509488}), - }, - { - torch::tensor({-0.15543131540224012, -0.42351103963720343, -0.04196796248622072, -0.2095223178068499, -0.16031407286541022, 0.8209742464453325}), - torch::tensor({0.07724343607160136, 0.03387529472490231, 1.0028793648054941}), - torch::tensor({-0.8213382425894498, -1.1570800333254736, -1.615476033165743}), - torch::tensor({-1.8734090731084845}), - }, - { - torch::tensor({-0.13342791770744886, -0.3941509709488104, -0.011470356542661934, -0.16885142516066962, -0.13306680693528108, 0.8576491729785701}), - torch::tensor({0.15081014600761677, 0.13560816175111742, 1.0971559708365837}), - torch::tensor({-0.9780975407869251, -1.3215153697157924, -1.876021387605152}), - torch::tensor({-2.2024413056528886}), - }, - { - torch::tensor({-0.11963097684681223, -0.37573675130134543, 0.0069987166413883715, -0.14420855651125972, -0.11733423659038758, 0.8788673419128562}), - torch::tensor({0.1969829338759005, 0.1973461164047132, 1.1520119567305152}), - torch::tensor({-1.0677802792431819, -1.4166561260631119, -2.022033753216991}), - torch::tensor({-2.383452427292781}), - }, - { - torch::tensor({-0.10950806441156272, -0.3622226699218595, 0.02028489243523426, -0.1264725422838007, -0.10635775660996463, 0.8936912722040982}), - torch::tensor({0.23089462331826793, 0.24184450074084418, 1.1904864598387046}), - torch::tensor({-1.1306213044009719, -1.4837186483578142, -2.122884602514208}), - torch::tensor({-2.5071352505158395}), - }, - { - torch::tensor({-0.10149090356585248, -0.3515172115812867, 0.030662536099764083, -0.11261325211798616, -0.09797248308626623, 0.905027632401109}), - torch::tensor({0.25777759826689434, 0.2766609657536915, 1.2199973265718322}), - torch::tensor({-1.1789655573653979, -1.5355073692636774, -2.199612583884608}), - torch::tensor({-2.600529541471662}), - }, - { - torch::tensor({-0.09484472748389533, -0.3426405023243085, 0.03917399284640637, -0.10124188994381228, -0.09121264836307835, 0.9141743475340721}), - torch::tensor({0.2800829300171032, 0.3052600200290069, 1.2438661306695873}), - torch::tensor({-1.2182324765944266, -1.5776851394085492, -2.2613704866316295}), - torch::tensor({-2.6752743361973184}), - }, - { - torch::tensor({-0.08916446117741175, -0.33505233521798666, 0.04638527943959316, -0.09160422984057517, -0.08556486270584644, 0.9218219103015535}), - torch::tensor({0.2991619380154852, 0.3295237551295101, 1.2638639017720827}), - torch::tensor({-1.251282493526328, -1.6132564639504312, -2.3129529937213853}), - torch::tensor({-2.73741957239466}), - }, - { - torch::tensor({-0.08420245801272856, -0.3284224385121882, 0.05263847708646642, -0.08324438788845245, -0.08072424164719598, 0.9283806476306355}), - torch::tensor({0.31584087342663564, 0.35059019818200393, 1.2810450644764015}), - torch::tensor({-1.2798091496372141, -1.6440072538210193, -2.357180462961105}), - torch::tensor({-2.7905023459395872}), - }, - { - torch::tensor({-0.07979600214534928, -0.3225337978155753, 0.0581562720006689, -0.07586555700667826, -0.07649523955108037, 0.9341138824526719}), - torch::tensor({0.3306627217189733, 0.3692005578577212, 1.2960873917356066}), - torch::tensor({-1.3048976883823566, -1.6710855742501123, -2.395849898454614}), - torch::tensor({-2.836765085555123}), - }, - { - torch::tensor({-0.07583232846497832, -0.3172360102461862, 0.06309179259248046, -0.06926361352067158, -0.07274510848082802, 0.9392004636935606}), - torch::tensor({0.3440038606091545, 0.3858647867996722, 1.3094518934419668}), - torch::tensor({-1.3272851146877218, -1.6952731308502653, -2.4301754289421598}), - torch::tensor({-2.8777164728823017}), - }, + { + torch::tensor( + {-0.21063957030131192, + -0.4972093725858961, + -0.13931849072410168, + -0.33939101965581686, + -0.25112865488453673, + 0.6992101966874735}), + torch::tensor( + {-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}), + torch::tensor( + {-0.07998325778863398, + -0.42149210515421365, + -0.33498349553944556}), + torch::tensor({-0.14255126505509488}), + }, + { + torch::tensor( + {-0.15543131540224012, + -0.42351103963720343, + -0.04196796248622072, + -0.2095223178068499, + -0.16031407286541022, + 0.8209742464453325}), + torch::tensor( + {0.07724343607160136, 0.03387529472490231, 1.0028793648054941}), + torch::tensor( + {-0.8213382425894498, -1.1570800333254736, -1.615476033165743}), + torch::tensor({-1.8734090731084845}), + }, + { + torch::tensor( + {-0.13342791770744886, + -0.3941509709488104, + -0.011470356542661934, + -0.16885142516066962, + -0.13306680693528108, + 0.8576491729785701}), + torch::tensor( + {0.15081014600761677, 0.13560816175111742, 1.0971559708365837}), + torch::tensor( + {-0.9780975407869251, -1.3215153697157924, -1.876021387605152}), + torch::tensor({-2.2024413056528886}), + }, + { + torch::tensor( + {-0.11963097684681223, + -0.37573675130134543, + 0.0069987166413883715, + -0.14420855651125972, + -0.11733423659038758, + 0.8788673419128562}), + torch::tensor( + {0.1969829338759005, 0.1973461164047132, 1.1520119567305152}), + torch::tensor( + {-1.0677802792431819, -1.4166561260631119, -2.022033753216991}), + torch::tensor({-2.383452427292781}), + }, + { + torch::tensor( + {-0.10950806441156272, + -0.3622226699218595, + 0.02028489243523426, + -0.1264725422838007, + -0.10635775660996463, + 0.8936912722040982}), + torch::tensor( + {0.23089462331826793, 0.24184450074084418, 1.1904864598387046}), + torch::tensor( + {-1.1306213044009719, -1.4837186483578142, -2.122884602514208}), + torch::tensor({-2.5071352505158395}), + }, + { + torch::tensor( + {-0.10149090356585248, + -0.3515172115812867, + 0.030662536099764083, + -0.11261325211798616, + -0.09797248308626623, + 0.905027632401109}), + torch::tensor( + {0.25777759826689434, 0.2766609657536915, 1.2199973265718322}), + torch::tensor( + {-1.1789655573653979, -1.5355073692636774, -2.199612583884608}), + torch::tensor({-2.600529541471662}), + }, + { + torch::tensor( + {-0.09484472748389533, + -0.3426405023243085, + 0.03917399284640637, + -0.10124188994381228, + -0.09121264836307835, + 0.9141743475340721}), + torch::tensor( + {0.2800829300171032, 0.3052600200290069, 1.2438661306695873}), + torch::tensor( + {-1.2182324765944266, -1.5776851394085492, -2.2613704866316295}), + torch::tensor({-2.6752743361973184}), + }, + { + torch::tensor( + {-0.08916446117741175, + -0.33505233521798666, + 0.04638527943959316, + -0.09160422984057517, + -0.08556486270584644, + 0.9218219103015535}), + torch::tensor( + {0.2991619380154852, 0.3295237551295101, 1.2638639017720827}), + torch::tensor( + {-1.251282493526328, -1.6132564639504312, -2.3129529937213853}), + torch::tensor({-2.73741957239466}), + }, + { + torch::tensor( + {-0.08420245801272856, + -0.3284224385121882, + 0.05263847708646642, + -0.08324438788845245, + -0.08072424164719598, + 0.9283806476306355}), + torch::tensor( + {0.31584087342663564, 0.35059019818200393, 1.2810450644764015}), + torch::tensor( + {-1.2798091496372141, -1.6440072538210193, -2.357180462961105}), + torch::tensor({-2.7905023459395872}), + }, + { + torch::tensor( + {-0.07979600214534928, + -0.3225337978155753, + 0.0581562720006689, + -0.07586555700667826, + -0.07649523955108037, + 0.9341138824526719}), + torch::tensor( + {0.3306627217189733, 0.3692005578577212, 1.2960873917356066}), + torch::tensor( + {-1.3048976883823566, -1.6710855742501123, -2.395849898454614}), + torch::tensor({-2.836765085555123}), + }, + { + torch::tensor( + {-0.07583232846497832, + -0.3172360102461862, + 0.06309179259248046, + -0.06926361352067158, + -0.07274510848082802, + 0.9392004636935606}), + torch::tensor( + {0.3440038606091545, 0.3858647867996722, 1.3094518934419668}), + torch::tensor( + {-1.3272851146877218, -1.6952731308502653, -2.4301754289421598}), + torch::tensor({-2.8777164728823017}), + }, }; } inline std::vector> SGD_with_weight_decay() { return { - { - torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}), - torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), - torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), - torch::tensor({-0.14248012693079315}), - }, - { - torch::tensor({-0.13579982290274883, -0.3765456284475787, -0.03166970700350034, -0.18102559254681197, -0.1373234786735746, 0.7522156177001302}), - torch::tensor({0.08550003826014418, 0.051563225553454196, 0.9321399061276381}), - torch::tensor({-0.796312238882584, -1.1010063686038731, -1.5363716774172782}), - torch::tensor({-1.8045854907382846}), - }, - { - torch::tensor({-0.09659168723529124, -0.30562076936588267, 0.0067128671455129185, -0.1166002367977548, -0.09012083166238948, 0.7264953102453368}), - torch::tensor({0.16531808496504802, 0.16488328577596398, 0.9610743966573317}), - torch::tensor({-0.9202466399245914, -1.2052829272891832, -1.7049756710541348}), - torch::tensor({-2.0415977924493043}), - }, - { - torch::tensor({-0.06728100597713035, -0.2496589601654196, 0.03186158526394668, -0.07105441484407878, -0.056478595544178806, 0.6910758436366733}), - torch::tensor({0.21707768347081777, 0.23575238192099465, 0.9564382346520686}), - torch::tensor({-0.9788195039029999, -1.2447191597975946, -1.762020156061963}), - torch::tensor({-2.131504419683077}), - }, - { - torch::tensor({-0.04304955053155505, -0.20206572730420902, 0.050959513946324475, -0.034700093557440984, -0.02922465201167018, 0.6547611705604361}), - torch::tensor({0.2563898231537708, 0.2878867158887637, 0.9414221685252802}), - torch::tensor({-1.0143969472996655, -1.2623288365082088, -1.7800471460065668}), - torch::tensor({-2.170255083720924}), - }, - { - torch::tensor({-0.022154717038262738, -0.16036518660639862, 0.06644401410758827, -0.004183373274651896, -0.005965877978527781, 0.6200298215101535}), - torch::tensor({0.2886406829874717, 0.32924516791460257, 0.9230983700837223}), - torch::tensor({-1.0397895250773481, -1.2710914166240181, -1.7807758009603087}), - torch::tensor({-2.1862978976514738}), - }, - { - torch::tensor({-0.0037439139848317077, -0.12328293308251938, 0.07944696186805641, 0.022100305718442022, 0.014399113804332037, 0.587697912745227}), - torch::tensor({0.3162871074692008, 0.36346293565421134, 0.9042402154310412}), - torch::tensor({-1.060234961430088, -1.2762264965487675, -1.7731268727630662}), - torch::tensor({-2.191253945056341}), - }, - { - torch::tensor({0.012675985938854726, -0.09003711893222131, 0.09059095692632844, 0.04506778924310349, 0.03247299240601001, 0.5579755127260052}), - torch::tensor({0.3406226998933173, 0.3924947745885882, 0.8860121369119325}), - torch::tensor({-1.0781407849705034, -1.2800528898634018, -1.7613120374342215}), - torch::tensor({-2.190575043873577}), - }, - { - torch::tensor({0.027425440985777993, -0.06008809958617219, 0.10026092920861808, 0.06531092947039244, 0.048628754907931976, 0.5308215072596255}), - torch::tensor({0.36239744520280553, 0.4175162387638887, 0.8688788105023479}), - torch::tensor({-1.0946579691370502, -1.283610342226948, -1.7474706191775764}), - torch::tensor({-2.1870021744944763}), - }, - { - torch::tensor({0.04073250980147411, -0.03303024103555013, 0.1087177047593139, 0.08324870459183518, 0.0631222868881554, 0.5060892094042873}), - torch::tensor({0.38208249693950175, 0.4393002654989596, 0.8529817924677643}), - torch::tensor({-1.1103326127955466, -1.287332405916359, -1.73273866274852}), - torch::tensor({-2.1819672316721337}), - }, - { - torch::tensor({0.05277160918732605, -0.008539186625351441, 0.1161515444487197, 0.09919929206676087, 0.07614530177703588, 0.48359250162323586}), - torch::tensor({0.3999968617221315, 0.45839442009256354, 0.8383132966805791}), - torch::tensor({-1.1254107858333455, -1.2913604197768889, -1.717739109221235}), - torch::tensor({-2.1762368071604308}), - }, + { + torch::tensor( + {-0.21042867144447805, + -0.49671181653925384, + -0.13917719856207697, + -0.3390489907590303, + -0.2508762913762564, + 0.6985126396619242}), + torch::tensor( + {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), + torch::tensor( + {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), + torch::tensor({-0.14248012693079315}), + }, + { + torch::tensor( + {-0.13579982290274883, + -0.3765456284475787, + -0.03166970700350034, + -0.18102559254681197, + -0.1373234786735746, + 0.7522156177001302}), + torch::tensor( + {0.08550003826014418, 0.051563225553454196, 0.9321399061276381}), + torch::tensor( + {-0.796312238882584, -1.1010063686038731, -1.5363716774172782}), + torch::tensor({-1.8045854907382846}), + }, + { + torch::tensor( + {-0.09659168723529124, + -0.30562076936588267, + 0.0067128671455129185, + -0.1166002367977548, + -0.09012083166238948, + 0.7264953102453368}), + torch::tensor( + {0.16531808496504802, 0.16488328577596398, 0.9610743966573317}), + torch::tensor( + {-0.9202466399245914, -1.2052829272891832, -1.7049756710541348}), + torch::tensor({-2.0415977924493043}), + }, + { + torch::tensor( + {-0.06728100597713035, + -0.2496589601654196, + 0.03186158526394668, + -0.07105441484407878, + -0.056478595544178806, + 0.6910758436366733}), + torch::tensor( + {0.21707768347081777, 0.23575238192099465, 0.9564382346520686}), + torch::tensor( + {-0.9788195039029999, -1.2447191597975946, -1.762020156061963}), + torch::tensor({-2.131504419683077}), + }, + { + torch::tensor( + {-0.04304955053155505, + -0.20206572730420902, + 0.050959513946324475, + -0.034700093557440984, + -0.02922465201167018, + 0.6547611705604361}), + torch::tensor( + {0.2563898231537708, 0.2878867158887637, 0.9414221685252802}), + torch::tensor( + {-1.0143969472996655, -1.2623288365082088, -1.7800471460065668}), + torch::tensor({-2.170255083720924}), + }, + { + torch::tensor( + {-0.022154717038262738, + -0.16036518660639862, + 0.06644401410758827, + -0.004183373274651896, + -0.005965877978527781, + 0.6200298215101535}), + torch::tensor( + {0.2886406829874717, 0.32924516791460257, 0.9230983700837223}), + torch::tensor( + {-1.0397895250773481, -1.2710914166240181, -1.7807758009603087}), + torch::tensor({-2.1862978976514738}), + }, + { + torch::tensor( + {-0.0037439139848317077, + -0.12328293308251938, + 0.07944696186805641, + 0.022100305718442022, + 0.014399113804332037, + 0.587697912745227}), + torch::tensor( + {0.3162871074692008, 0.36346293565421134, 0.9042402154310412}), + torch::tensor( + {-1.060234961430088, -1.2762264965487675, -1.7731268727630662}), + torch::tensor({-2.191253945056341}), + }, + { + torch::tensor( + {0.012675985938854726, + -0.09003711893222131, + 0.09059095692632844, + 0.04506778924310349, + 0.03247299240601001, + 0.5579755127260052}), + torch::tensor( + {0.3406226998933173, 0.3924947745885882, 0.8860121369119325}), + torch::tensor( + {-1.0781407849705034, -1.2800528898634018, -1.7613120374342215}), + torch::tensor({-2.190575043873577}), + }, + { + torch::tensor( + {0.027425440985777993, + -0.06008809958617219, + 0.10026092920861808, + 0.06531092947039244, + 0.048628754907931976, + 0.5308215072596255}), + torch::tensor( + {0.36239744520280553, 0.4175162387638887, 0.8688788105023479}), + torch::tensor( + {-1.0946579691370502, -1.283610342226948, -1.7474706191775764}), + torch::tensor({-2.1870021744944763}), + }, + { + torch::tensor( + {0.04073250980147411, + -0.03303024103555013, + 0.1087177047593139, + 0.08324870459183518, + 0.0631222868881554, + 0.5060892094042873}), + torch::tensor( + {0.38208249693950175, 0.4393002654989596, 0.8529817924677643}), + torch::tensor( + {-1.1103326127955466, -1.287332405916359, -1.73273866274852}), + torch::tensor({-2.1819672316721337}), + }, + { + torch::tensor( + {0.05277160918732605, + -0.008539186625351441, + 0.1161515444487197, + 0.09919929206676087, + 0.07614530177703588, + 0.48359250162323586}), + torch::tensor( + {0.3999968617221315, 0.45839442009256354, 0.8383132966805791}), + torch::tensor( + {-1.1254107858333455, -1.2913604197768889, -1.717739109221235}), + torch::tensor({-2.1762368071604308}), + }, }; } -inline std::vector> SGD_with_weight_decay_and_momentum() { +inline std::vector> +SGD_with_weight_decay_and_momentum() { return { - { - torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}), - torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), - torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), - torch::tensor({-0.14248012693079315}), - }, - { - torch::tensor({0.0056118487251954775, -0.0710915563059199, 0.07701400891926036, 0.047067327035013866, 0.0428654052972598, 0.4352977220593751}), - torch::tensor({0.23834837300214828, 0.32366382503704183, 0.7128321016634689}), - torch::tensor({-1.041947788394885, -1.1730950187020548, -1.7648205873351157}), - torch::tensor({-2.3359277661920594}), - }, - { - torch::tensor({0.11520007183759418, 0.1289453768763286, 0.14586845555951963, 0.1775341535876219, 0.15614155642578995, 0.33379126147460536}), - torch::tensor({0.465853656413685, 0.520197917876909, 0.7274876508280723}), - torch::tensor({-1.2034746444882527, -1.286126969233868, -1.604528340632377}), - torch::tensor({-2.203215909196624}), - }, - { - torch::tensor({0.15331258730374997, 0.197909036233604, 0.16663814647374195, 0.2183320498727895, 0.1803274550482287, 0.28362745794417826}), - torch::tensor({0.5532312776994917, 0.5834224152126115, 0.6903579410976886}), - torch::tensor({-1.3052171323471546, -1.3514190497186434, -1.5153574535010634}), - torch::tensor({-2.123181139806548}), - }, - { - torch::tensor({0.16814113185552507, 0.22386572201448868, 0.17413795101952864, 0.23280515326261633, 0.1839142207976228, 0.2614499495870909}), - torch::tensor({0.592282876576759, 0.6083877519652824, 0.663438748699906}), - torch::tensor({-1.3591143274292896, -1.383673065830997, -1.467157893517277}), - torch::tensor({-2.087859547998447}), - }, - { - torch::tensor({0.1743742243877178, 0.2343126153059798, 0.17716942927642254, 0.23838669643330088, 0.18308461132092924, 0.25149544624452974}), - torch::tensor({0.6108281747800746, 0.6192657661217672, 0.6475519545045926}), - torch::tensor({-1.3860527054444407, -1.398816664238087, -1.4412527948055516}), - torch::tensor({-2.0731939075659627}), - }, - { - torch::tensor({0.1771465478751462, 0.23875859951719522, 0.1784868271584857, 0.2406786372566496, 0.18181103291606765, 0.24687877342069478}), - torch::tensor({0.6198586021174767, 0.6242349464856269, 0.638736845373371}), - torch::tensor({-1.3993307716862977, -1.4058965193851591, -1.42747775986796}), - torch::tensor({-2.0672675843404598}), - }, - { - torch::tensor({0.17843093585357683, 0.24073954802700465, 0.17908697027440873, 0.2416675839909268, 0.18088350526559058, 0.24467193314356378}), - torch::tensor({0.6243071074374693, 0.6265628975677455, 0.6339840865876518}), - torch::tensor({-1.4058750036106915, -1.4092362337714568, -1.4202202926903085}), - torch::tensor({-2.0649062340635584}), - }, - { - torch::tensor({0.17904350645021613, 0.24165496946247034, 0.17936920658487726, 0.24211164489776849, 0.18031858582735988, 0.2435923992630521}), - torch::tensor({0.626513445507806, 0.6276715667697311, 0.6314641991686346}), - torch::tensor({-1.409113940967948, -1.410830795235453, -1.4164247285253404}), - torch::tensor({-2.0639728292802046}), - }, - { - torch::tensor({0.17934167113683835, 0.242089962404631, 0.17950490408309286, 0.24231745350706005, 0.17999989292556767, 0.2430557755257577}), - torch::tensor({0.6276131793232345, 0.6282062328090801, 0.6301427155170752}), - torch::tensor({-1.4107251789010826, -1.4116011824171857, -1.4144511767962422}), - torch::tensor({-2.0636056316673934}), - }, - { - torch::tensor({0.17948886155124505, 0.24230096332204806, 0.17957117450689372, 0.242415213133214, 0.17982712042628357, 0.2427862039224869}), - torch::tensor({0.6281635672171683, 0.6284667582211864, 0.6294549191500093}), - torch::tensor({-1.4115305541843781, -1.4119772978756444, -1.4134296522818641}), - torch::tensor({-2.0634616066978615}), - }, + { + torch::tensor( + {-0.21042867144447805, + -0.49671181653925384, + -0.13917719856207697, + -0.3390489907590303, + -0.2508762913762564, + 0.6985126396619242}), + torch::tensor( + {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), + torch::tensor( + {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), + torch::tensor({-0.14248012693079315}), + }, + { + torch::tensor( + {0.0056118487251954775, + -0.0710915563059199, + 0.07701400891926036, + 0.047067327035013866, + 0.0428654052972598, + 0.4352977220593751}), + torch::tensor( + {0.23834837300214828, 0.32366382503704183, 0.7128321016634689}), + torch::tensor( + {-1.041947788394885, -1.1730950187020548, -1.7648205873351157}), + torch::tensor({-2.3359277661920594}), + }, + { + torch::tensor( + {0.11520007183759418, + 0.1289453768763286, + 0.14586845555951963, + 0.1775341535876219, + 0.15614155642578995, + 0.33379126147460536}), + torch::tensor( + {0.465853656413685, 0.520197917876909, 0.7274876508280723}), + torch::tensor( + {-1.2034746444882527, -1.286126969233868, -1.604528340632377}), + torch::tensor({-2.203215909196624}), + }, + { + torch::tensor( + {0.15331258730374997, + 0.197909036233604, + 0.16663814647374195, + 0.2183320498727895, + 0.1803274550482287, + 0.28362745794417826}), + torch::tensor( + {0.5532312776994917, 0.5834224152126115, 0.6903579410976886}), + torch::tensor( + {-1.3052171323471546, -1.3514190497186434, -1.5153574535010634}), + torch::tensor({-2.123181139806548}), + }, + { + torch::tensor( + {0.16814113185552507, + 0.22386572201448868, + 0.17413795101952864, + 0.23280515326261633, + 0.1839142207976228, + 0.2614499495870909}), + torch::tensor( + {0.592282876576759, 0.6083877519652824, 0.663438748699906}), + torch::tensor( + {-1.3591143274292896, -1.383673065830997, -1.467157893517277}), + torch::tensor({-2.087859547998447}), + }, + { + torch::tensor( + {0.1743742243877178, + 0.2343126153059798, + 0.17716942927642254, + 0.23838669643330088, + 0.18308461132092924, + 0.25149544624452974}), + torch::tensor( + {0.6108281747800746, 0.6192657661217672, 0.6475519545045926}), + torch::tensor( + {-1.3860527054444407, -1.398816664238087, -1.4412527948055516}), + torch::tensor({-2.0731939075659627}), + }, + { + torch::tensor( + {0.1771465478751462, + 0.23875859951719522, + 0.1784868271584857, + 0.2406786372566496, + 0.18181103291606765, + 0.24687877342069478}), + torch::tensor( + {0.6198586021174767, 0.6242349464856269, 0.638736845373371}), + torch::tensor( + {-1.3993307716862977, -1.4058965193851591, -1.42747775986796}), + torch::tensor({-2.0672675843404598}), + }, + { + torch::tensor( + {0.17843093585357683, + 0.24073954802700465, + 0.17908697027440873, + 0.2416675839909268, + 0.18088350526559058, + 0.24467193314356378}), + torch::tensor( + {0.6243071074374693, 0.6265628975677455, 0.6339840865876518}), + torch::tensor( + {-1.4058750036106915, -1.4092362337714568, -1.4202202926903085}), + torch::tensor({-2.0649062340635584}), + }, + { + torch::tensor( + {0.17904350645021613, + 0.24165496946247034, + 0.17936920658487726, + 0.24211164489776849, + 0.18031858582735988, + 0.2435923992630521}), + torch::tensor( + {0.626513445507806, 0.6276715667697311, 0.6314641991686346}), + torch::tensor( + {-1.409113940967948, -1.410830795235453, -1.4164247285253404}), + torch::tensor({-2.0639728292802046}), + }, + { + torch::tensor( + {0.17934167113683835, + 0.242089962404631, + 0.17950490408309286, + 0.24231745350706005, + 0.17999989292556767, + 0.2430557755257577}), + torch::tensor( + {0.6276131793232345, 0.6282062328090801, 0.6301427155170752}), + torch::tensor( + {-1.4107251789010826, -1.4116011824171857, -1.4144511767962422}), + torch::tensor({-2.0636056316673934}), + }, + { + torch::tensor( + {0.17948886155124505, + 0.24230096332204806, + 0.17957117450689372, + 0.242415213133214, + 0.17982712042628357, + 0.2427862039224869}), + torch::tensor( + {0.6281635672171683, 0.6284667582211864, 0.6294549191500093}), + torch::tensor( + {-1.4115305541843781, -1.4119772978756444, -1.4134296522818641}), + torch::tensor({-2.0634616066978615}), + }, }; } -inline std::vector> SGD_with_weight_decay_and_nesterov_momentum() { +inline std::vector> +SGD_with_weight_decay_and_nesterov_momentum() { return { - { - torch::tensor({-0.21040617235121148, -0.49689727139951717, -0.13754215970803657, -0.33701686525263036, -0.2500172388792182, 0.700697918175925}), - torch::tensor({-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}), - torch::tensor({-0.10624536304143092, -0.4461132561477894, -0.3805647497874434}), - torch::tensor({-0.2068230782168696}), - }, - { - torch::tensor({-0.1262387113548655, -0.3844658218758334, 0.03124406856508884, -0.11170532152425781, -0.09823268522398332, 0.9040698525178972}), - torch::tensor({0.17551336074135096, 0.27976614792027166, 1.2138399680985128}), - torch::tensor({-1.592840413595591, -1.8986806244521564, -2.966181914454827}), - torch::tensor({-3.7728444542017687}), - }, - { - torch::tensor({-0.11614716303292183, -0.3709539909720773, 0.04307078045512772, -0.09588329367245825, -0.08795603365024904, 0.9178771227283019}), - torch::tensor({0.20944042006388683, 0.3195483889401668, 1.2500270348310718}), - torch::tensor({-1.635011052494502, -1.9463243375558272, -3.035708036973984}), - torch::tensor({-3.8570351018212796}), - }, - { - torch::tensor({-0.10793942832760066, -0.35995697973682966, 0.05260329955808716, -0.08312010825923577, -0.07986326997915319, 0.9287409473303162}), - torch::tensor({0.2370574459090396, 0.35168415020524857, 1.278618438127574}), - torch::tensor({-1.669141810658011, -1.984894370767313, -3.091259532917102}), - torch::tensor({-3.923827025320545}), - }, - { - torch::tensor({-0.1010142826857921, -0.35067247612415425, 0.06058642765135953, -0.07242353828264116, -0.07320722520220559, 0.9376663294528951}), - torch::tensor({0.26037531373638517, 0.37864768429039036, 1.3021925174954938}), - torch::tensor({-1.6978623668013235, -2.017346013780729, -3.137511248751908}), - torch::tensor({-3.9791368472670334}), - }, - { - torch::tensor({-0.0950223925827384, -0.34263425874631004, 0.06744912149060932, -0.0632219668955612, -0.06756850374320933, 0.9452179348012486}), - torch::tensor({0.2805629021730173, 0.40186559210837897, 1.322201974233735}), - torch::tensor({-1.7226667375672964, -2.0453651314263936, -3.177094625235675}), - torch::tensor({-4.02626353351958}), - }, - { - torch::tensor({-0.08974074058929343, -0.3355446553621404, 0.07346375443579244, -0.055152336279101065, -0.06268648871001672, 0.9517469338705254}), - torch::tensor({0.2983667593362755, 0.42224471824689497, 1.3395523443811077}), - torch::tensor({-1.7445022249557358, -2.070022023204061, -3.2116640112699977}), - torch::tensor({-4.0672678681014025}), - }, - { - torch::tensor({-0.08501805567029425, -0.32920173535901165, 0.07881418855733362, -0.04796950604202488, -0.058388411064112675, 0.9574862878804136}), - torch::tensor({0.314293480998118, 0.44039784591234016, 1.3548455497581404}), - torch::tensor({-1.7640079833055848, -2.092039516337239, -3.2423272727017984}), - torch::tensor({-4.1035226231344275}), - }, - { - torch::tensor({-0.08074691916762683, -0.32346209125355385, 0.08363041954091821, -0.04150011326440421, -0.05455400195824525, 0.9625982669377223}), - torch::tensor({0.3287029554083447, 0.456758647445438, 1.3685016029692183}), - torch::tensor({-1.781635969965323, -2.1119291060463254, -3.2698625668054278}), - torch::tensor({-4.135987715622241}), - }, - { - torch::tensor({-0.07684825926741544, -0.3182201487496606, 0.08800775942949753, -0.03561702050647377, -0.05109626940874274, 0.967200308063431}), - torch::tensor({0.3418602073865105, 0.47164537681262036, 1.3808249543962092}), - torch::tensor({-1.7977177360253929, -2.1300662313234127, -3.2948372910743537}), - torch::tensor({-4.165360368453936}), - }, - { - torch::tensor({-0.07326215742204903, -0.31339589848358795, 0.09201816976416921, -0.030224217178854797, -0.0479503174605994, 0.9713800632469923}), - torch::tensor({0.35396614509118257, 0.485298524494989, 1.3920431643924076}), - torch::tensor({-1.8125038126190638, -2.146734711618823, -3.3176778240157505}), - torch::tensor({-4.192162739857097}), - }, + { + torch::tensor( + {-0.21040617235121148, + -0.49689727139951717, + -0.13754215970803657, + -0.33701686525263036, + -0.2500172388792182, + 0.700697918175925}), + torch::tensor( + {-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}), + torch::tensor( + {-0.10624536304143092, -0.4461132561477894, -0.3805647497874434}), + torch::tensor({-0.2068230782168696}), + }, + { + torch::tensor( + {-0.1262387113548655, + -0.3844658218758334, + 0.03124406856508884, + -0.11170532152425781, + -0.09823268522398332, + 0.9040698525178972}), + torch::tensor( + {0.17551336074135096, 0.27976614792027166, 1.2138399680985128}), + torch::tensor( + {-1.592840413595591, -1.8986806244521564, -2.966181914454827}), + torch::tensor({-3.7728444542017687}), + }, + { + torch::tensor( + {-0.11614716303292183, + -0.3709539909720773, + 0.04307078045512772, + -0.09588329367245825, + -0.08795603365024904, + 0.9178771227283019}), + torch::tensor( + {0.20944042006388683, 0.3195483889401668, 1.2500270348310718}), + torch::tensor( + {-1.635011052494502, -1.9463243375558272, -3.035708036973984}), + torch::tensor({-3.8570351018212796}), + }, + { + torch::tensor( + {-0.10793942832760066, + -0.35995697973682966, + 0.05260329955808716, + -0.08312010825923577, + -0.07986326997915319, + 0.9287409473303162}), + torch::tensor( + {0.2370574459090396, 0.35168415020524857, 1.278618438127574}), + torch::tensor( + {-1.669141810658011, -1.984894370767313, -3.091259532917102}), + torch::tensor({-3.923827025320545}), + }, + { + torch::tensor( + {-0.1010142826857921, + -0.35067247612415425, + 0.06058642765135953, + -0.07242353828264116, + -0.07320722520220559, + 0.9376663294528951}), + torch::tensor( + {0.26037531373638517, 0.37864768429039036, 1.3021925174954938}), + torch::tensor( + {-1.6978623668013235, -2.017346013780729, -3.137511248751908}), + torch::tensor({-3.9791368472670334}), + }, + { + torch::tensor( + {-0.0950223925827384, + -0.34263425874631004, + 0.06744912149060932, + -0.0632219668955612, + -0.06756850374320933, + 0.9452179348012486}), + torch::tensor( + {0.2805629021730173, 0.40186559210837897, 1.322201974233735}), + torch::tensor( + {-1.7226667375672964, -2.0453651314263936, -3.177094625235675}), + torch::tensor({-4.02626353351958}), + }, + { + torch::tensor( + {-0.08974074058929343, + -0.3355446553621404, + 0.07346375443579244, + -0.055152336279101065, + -0.06268648871001672, + 0.9517469338705254}), + torch::tensor( + {0.2983667593362755, 0.42224471824689497, 1.3395523443811077}), + torch::tensor( + {-1.7445022249557358, -2.070022023204061, -3.2116640112699977}), + torch::tensor({-4.0672678681014025}), + }, + { + torch::tensor( + {-0.08501805567029425, + -0.32920173535901165, + 0.07881418855733362, + -0.04796950604202488, + -0.058388411064112675, + 0.9574862878804136}), + torch::tensor( + {0.314293480998118, 0.44039784591234016, 1.3548455497581404}), + torch::tensor( + {-1.7640079833055848, -2.092039516337239, -3.2423272727017984}), + torch::tensor({-4.1035226231344275}), + }, + { + torch::tensor( + {-0.08074691916762683, + -0.32346209125355385, + 0.08363041954091821, + -0.04150011326440421, + -0.05455400195824525, + 0.9625982669377223}), + torch::tensor( + {0.3287029554083447, 0.456758647445438, 1.3685016029692183}), + torch::tensor( + {-1.781635969965323, -2.1119291060463254, -3.2698625668054278}), + torch::tensor({-4.135987715622241}), + }, + { + torch::tensor( + {-0.07684825926741544, + -0.3182201487496606, + 0.08800775942949753, + -0.03561702050647377, + -0.05109626940874274, + 0.967200308063431}), + torch::tensor( + {0.3418602073865105, 0.47164537681262036, 1.3808249543962092}), + torch::tensor( + {-1.7977177360253929, -2.1300662313234127, -3.2948372910743537}), + torch::tensor({-4.165360368453936}), + }, + { + torch::tensor( + {-0.07326215742204903, + -0.31339589848358795, + 0.09201816976416921, + -0.030224217178854797, + -0.0479503174605994, + 0.9713800632469923}), + torch::tensor( + {0.35396614509118257, 0.485298524494989, 1.3920431643924076}), + torch::tensor( + {-1.8125038126190638, -2.146734711618823, -3.3176778240157505}), + torch::tensor({-4.192162739857097}), + }, }; } diff --git a/test/cpp/api/parallel.cpp b/test/cpp/api/parallel.cpp index 1789fe5000877e..ca5e9ad566ab09 100644 --- a/test/cpp/api/parallel.cpp +++ b/test/cpp/api/parallel.cpp @@ -230,65 +230,65 @@ TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) { TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) { struct M : torch::nn::Cloneable { - M() { - reset(); - } - - void reset() override { - conv = register_module("conv", - torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2))); - fc = register_module("fc", torch::nn::Linear(8, 2)); - } - - torch::Tensor forward(torch::Tensor x) { - x = conv->forward(x); - x = torch::relu(x); - x = x.view({-1, 8}); - x = fc->forward(x); - return torch::log_softmax(x, /*dim=*/1); - } - - torch::nn::Conv2d conv{nullptr}; - torch::nn::Linear fc{nullptr}; - }; - - // prepare modules and inputs - auto input = torch::ones({16, 2, 3, 3}); - auto input_dp = torch::ones({16, 2, 3, 3}); - auto model = std::make_shared(); - auto model_dp = std::dynamic_pointer_cast(model->clone()); - - // run 3 training iterations - for (const auto i : c10::irange(3)) { - input += i; - input_dp += i; - - // non-prallel training - torch::optim::SGD optim( - model->parameters(), torch::optim::SGDOptions(0.1)); - auto output = model->forward(input); - auto loss = torch::mse_loss(output, torch::zeros_like(output)); - loss.backward(); - optim.step(); - - // data-parallel training - torch::optim::SGD optim_dp( - model_dp->parameters(), torch::optim::SGDOptions(0.1)); - auto output_dp = parallel::data_parallel(model_dp, input_dp); - auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp)); - loss_dp.backward(); - optim_dp.step(); - - // make sure that weights are the same - model->to(torch::kCPU); - model_dp->to(torch::kCPU); - auto params = model->parameters(); - auto params_dp = model_dp->parameters(); - ASSERT_EQ(params.size(), params_dp.size()); - for (auto it = params.begin(), it_dp = params_dp.begin(); - it != params.end() && it_dp != params.end(); - ++it, ++it_dp) { - ASSERT_TRUE(torch::allclose(*it, *it_dp)); - } + M() { + reset(); } + + void reset() override { + conv = register_module( + "conv", + torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2))); + fc = register_module("fc", torch::nn::Linear(8, 2)); + } + + torch::Tensor forward(torch::Tensor x) { + x = conv->forward(x); + x = torch::relu(x); + x = x.view({-1, 8}); + x = fc->forward(x); + return torch::log_softmax(x, /*dim=*/1); + } + + torch::nn::Conv2d conv{nullptr}; + torch::nn::Linear fc{nullptr}; + }; + + // prepare modules and inputs + auto input = torch::ones({16, 2, 3, 3}); + auto input_dp = torch::ones({16, 2, 3, 3}); + auto model = std::make_shared(); + auto model_dp = std::dynamic_pointer_cast(model->clone()); + + // run 3 training iterations + for (const auto i : c10::irange(3)) { + input += i; + input_dp += i; + + // non-prallel training + torch::optim::SGD optim(model->parameters(), torch::optim::SGDOptions(0.1)); + auto output = model->forward(input); + auto loss = torch::mse_loss(output, torch::zeros_like(output)); + loss.backward(); + optim.step(); + + // data-parallel training + torch::optim::SGD optim_dp( + model_dp->parameters(), torch::optim::SGDOptions(0.1)); + auto output_dp = parallel::data_parallel(model_dp, input_dp); + auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp)); + loss_dp.backward(); + optim_dp.step(); + + // make sure that weights are the same + model->to(torch::kCPU); + model_dp->to(torch::kCPU); + auto params = model->parameters(); + auto params_dp = model_dp->parameters(); + ASSERT_EQ(params.size(), params_dp.size()); + for (auto it = params.begin(), it_dp = params_dp.begin(); + it != params.end() && it_dp != params.end(); + ++it, ++it_dp) { + ASSERT_TRUE(torch::allclose(*it, *it_dp)); + } + } } diff --git a/test/cpp/api/parameterdict.cpp b/test/cpp/api/parameterdict.cpp index 21dd1b31d5a88c..1643ba1abe079e 100644 --- a/test/cpp/api/parameterdict.cpp +++ b/test/cpp/api/parameterdict.cpp @@ -62,8 +62,7 @@ TEST_F(ParameterDictTest, InsertAndPop) { ParameterDict dict; dict->insert("A", torch::tensor({1.0})); ASSERT_EQ(dict->size(), 1); - ASSERT_THROWS_WITH( - dict->pop("B"), "Parameter 'B' is not defined"); + ASSERT_THROWS_WITH(dict->pop("B"), "Parameter 'B' is not defined"); torch::Tensor p = dict->pop("A"); ASSERT_EQ(dict->size(), 0); ASSERT_TRUE(torch::eq(p, torch::tensor({1.0})).item()); diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp index a647ede97cbe15..3c2dde654db096 100644 --- a/test/cpp/api/rnn.cpp +++ b/test/cpp/api/rnn.cpp @@ -66,7 +66,9 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) { return true; }; -void check_lstm_sizes(std::tuple> lstm_output) { +void check_lstm_sizes( + std::tuple> + lstm_output) { // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch // 10 and 16 time steps (10 x 16 x n) @@ -95,7 +97,9 @@ void check_lstm_sizes(std::tuple(), 0); } -void check_lstm_sizes_proj(std::tuple> lstm_output) { +void check_lstm_sizes_proj( + std::tuple> + lstm_output) { // Expect the LSTM to have 32 outputs and 3 layers, with an input of batch // 10 and 16 time steps (10 x 16 x n) @@ -146,7 +150,8 @@ TEST_F(RNNTest, CheckOutputSizes) { auto next_hx = std::get<0>(std::get<1>(next)); auto next_cx = std::get<1>(std::get<1>(next)); - torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); + torch::Tensor diff = + torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); // Hiddens changed ASSERT_GT(diff.abs().sum().item(), 1e-3); @@ -229,22 +234,23 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) { flat = torch::cat({hx, cx}, 0).view(16); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - float h_out[] = {0.7889, - 0.9003, - 0.7769, - 0.8905, - 0.7635, - 0.8794, - 0.7484, - 0.8666, - 1.1647, - 1.6106, - 1.1425, - 1.5726, - 1.1187, - 1.5329, - 1.0931, - 1.4911}; + float h_out[] = { + 0.7889, + 0.9003, + 0.7769, + 0.8905, + 0.7635, + 0.8794, + 0.7484, + 0.8666, + 1.1647, + 1.6106, + 1.1425, + 1.5726, + 1.1187, + 1.5329, + 1.0931, + 1.4911}; for (size_t i = 0; i < 16; i++) { ASSERT_LT(std::abs(flat[i].item() - h_out[i]), 1e-3); } @@ -256,23 +262,26 @@ TEST_F(RNNTest, EndToEndLSTM) { } TEST_F(RNNTest, EndToEndLSTMProj) { - ASSERT_TRUE(test_RNN_xor( - [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); })); + ASSERT_TRUE(test_RNN_xor([](int s) { + return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); + })); } TEST_F(RNNTest, EndToEndGRU) { - ASSERT_TRUE( - test_RNN_xor([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); })); + ASSERT_TRUE(test_RNN_xor( + [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); })); } TEST_F(RNNTest, EndToEndRNNRelu) { - ASSERT_TRUE(test_RNN_xor( - [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); })); + ASSERT_TRUE(test_RNN_xor([](int s) { + return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); + })); } TEST_F(RNNTest, EndToEndRNNTanh) { - ASSERT_TRUE(test_RNN_xor( - [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); })); + ASSERT_TRUE(test_RNN_xor([](int s) { + return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); + })); } TEST_F(RNNTest, Sizes_CUDA) { @@ -297,7 +306,8 @@ TEST_F(RNNTest, Sizes_CUDA) { auto next_hx = std::get<0>(std::get<1>(next)); auto next_cx = std::get<1>(std::get<1>(next)); - torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); + torch::Tensor diff = + torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0); // Hiddens changed ASSERT_GT(diff.abs().sum().item(), 1e-3); @@ -339,7 +349,10 @@ TEST_F(RNNTest, EndToEndLSTM_CUDA) { TEST_F(RNNTest, EndToEndLSTMProj_CUDA) { ASSERT_TRUE(test_RNN_xor( - [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); }, true)); + [](int s) { + return LSTM(LSTMOptions(s, s).num_layers(2).proj_size(s / 2)); + }, + true)); } TEST_F(RNNTest, EndToEndGRU_CUDA) { @@ -349,11 +362,17 @@ TEST_F(RNNTest, EndToEndGRU_CUDA) { TEST_F(RNNTest, EndToEndRNNRelu_CUDA) { ASSERT_TRUE(test_RNN_xor( - [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }, true)); + [](int s) { + return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); + }, + true)); } TEST_F(RNNTest, EndToEndRNNTanh_CUDA) { ASSERT_TRUE(test_RNN_xor( - [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }, true)); + [](int s) { + return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); + }, + true)); } TEST_F(RNNTest, PrettyPrintRNNs) { @@ -361,13 +380,15 @@ TEST_F(RNNTest, PrettyPrintRNNs) { c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))), "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)"); ASSERT_EQ( - c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32))), + c10::str( + LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32))), "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false, proj_size=32)"); ASSERT_EQ( c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))), "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)"); ASSERT_EQ( - c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh))), + c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity( + torch::kTanh))), "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)"); } @@ -380,30 +401,40 @@ TEST_F(RNNTest, BidirectionalFlattenParameters) { } template -void copyParameters(torch::nn::ModuleHolder& target, std::string t_suffix, - const torch::nn::ModuleHolder& source, std::string s_suffix) { +void copyParameters( + torch::nn::ModuleHolder& target, + std::string t_suffix, + const torch::nn::ModuleHolder& source, + std::string s_suffix) { at::NoGradGuard guard; - target->named_parameters()["weight_ih_l" + t_suffix].copy_(source->named_parameters()["weight_ih_l" + s_suffix]); - target->named_parameters()["weight_hh_l" + t_suffix].copy_(source->named_parameters()["weight_hh_l" + s_suffix]); - target->named_parameters()["bias_ih_l" + t_suffix].copy_(source->named_parameters()["bias_ih_l" + s_suffix]); - target->named_parameters()["bias_hh_l" + t_suffix].copy_(source->named_parameters()["bias_hh_l" + s_suffix]); + target->named_parameters()["weight_ih_l" + t_suffix].copy_( + source->named_parameters()["weight_ih_l" + s_suffix]); + target->named_parameters()["weight_hh_l" + t_suffix].copy_( + source->named_parameters()["weight_hh_l" + s_suffix]); + target->named_parameters()["bias_ih_l" + t_suffix].copy_( + source->named_parameters()["bias_ih_l" + s_suffix]); + target->named_parameters()["bias_hh_l" + t_suffix].copy_( + source->named_parameters()["bias_hh_l" + s_suffix]); } std::tuple gru_output_to_device( - std::tuple gru_output, torch::Device device) { + std::tuple gru_output, + torch::Device device) { return std::make_tuple( - std::get<0>(gru_output).to(device), - std::get<1>(gru_output).to(device)); + std::get<0>(gru_output).to(device), std::get<1>(gru_output).to(device)); } -std::tuple> lstm_output_to_device( - std::tuple> lstm_output, torch::Device device) { +std::tuple> +lstm_output_to_device( + std::tuple> + lstm_output, + torch::Device device) { auto hidden_states = std::get<1>(lstm_output); return std::make_tuple( - std::get<0>(lstm_output).to(device), - std::make_tuple( - std::get<0>(hidden_states).to(device), - std::get<1>(hidden_states).to(device))); + std::get<0>(lstm_output).to(device), + std::make_tuple( + std::get<0>(hidden_states).to(device), + std::get<1>(hidden_states).to(device))); } // This test is a port of python code introduced here: @@ -411,14 +442,16 @@ std::tuple> lstm_output_ // Reverse forward of bidirectional GRU should act // as regular forward of unidirectional GRU void BidirectionalGRUReverseForward(bool cuda) { - auto opt = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false) - .device(cuda ? torch::kCUDA : torch::kCPU); + auto opt = torch::TensorOptions() + .dtype(torch::kFloat32) + .requires_grad(false) + .device(cuda ? torch::kCUDA : torch::kCPU); auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1}); auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1}); auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false); - GRU bi_grus {gru_options.bidirectional(true)}; - GRU reverse_gru {gru_options.bidirectional(false)}; + GRU bi_grus{gru_options.bidirectional(true)}; + GRU reverse_gru{gru_options.bidirectional(false)}; if (cuda) { bi_grus->to(torch::kCUDA); @@ -437,16 +470,19 @@ void BidirectionalGRUReverseForward(bool cuda) { reverse_output = gru_output_to_device(reverse_output, torch::kCPU); } - ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); + ASSERT_EQ( + std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); auto size = std::get<0>(bi_output).size(0); for (int i = 0; i < size; i++) { - ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item(), - std::get<0>(reverse_output)[size - 1 - i][0][0].item()); + ASSERT_EQ( + std::get<0>(bi_output)[i][0][1].item(), + std::get<0>(reverse_output)[size - 1 - i][0][0].item()); } // The hidden states of the reversed GRUs sits // in the odd indices in the first dimension. - ASSERT_EQ(std::get<1>(bi_output)[1][0][0].item(), - std::get<1>(reverse_output)[0][0][0].item()); + ASSERT_EQ( + std::get<1>(bi_output)[1][0][0].item(), + std::get<1>(reverse_output)[0][0][0].item()); } TEST_F(RNNTest, BidirectionalGRUReverseForward) { @@ -460,15 +496,17 @@ TEST_F(RNNTest, BidirectionalGRUReverseForward_CUDA) { // Reverse forward of bidirectional LSTM should act // as regular forward of unidirectional LSTM void BidirectionalLSTMReverseForwardTest(bool cuda) { - auto opt = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false) - .device(cuda ? torch::kCUDA : torch::kCPU); + auto opt = torch::TensorOptions() + .dtype(torch::kFloat32) + .requires_grad(false) + .device(cuda ? torch::kCUDA : torch::kCPU); auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1}); auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1}); auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false); - LSTM bi_lstm {lstm_opt.bidirectional(true)}; - LSTM reverse_lstm {lstm_opt.bidirectional(false)}; + LSTM bi_lstm{lstm_opt.bidirectional(true)}; + LSTM reverse_lstm{lstm_opt.bidirectional(false)}; if (cuda) { bi_lstm->to(torch::kCUDA); @@ -487,18 +525,22 @@ void BidirectionalLSTMReverseForwardTest(bool cuda) { reverse_output = lstm_output_to_device(reverse_output, torch::kCPU); } - ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); + ASSERT_EQ( + std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0)); auto size = std::get<0>(bi_output).size(0); for (int i = 0; i < size; i++) { - ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item(), - std::get<0>(reverse_output)[size - 1 - i][0][0].item()); + ASSERT_EQ( + std::get<0>(bi_output)[i][0][1].item(), + std::get<0>(reverse_output)[size - 1 - i][0][0].item()); } // The hidden states of the reversed LSTM sits // in the odd indices in the first dimension. - ASSERT_EQ(std::get<0>(std::get<1>(bi_output))[1][0][0].item(), - std::get<0>(std::get<1>(reverse_output))[0][0][0].item()); - ASSERT_EQ(std::get<1>(std::get<1>(bi_output))[1][0][0].item(), - std::get<1>(std::get<1>(reverse_output))[0][0][0].item()); + ASSERT_EQ( + std::get<0>(std::get<1>(bi_output))[1][0][0].item(), + std::get<0>(std::get<1>(reverse_output))[0][0][0].item()); + ASSERT_EQ( + std::get<1>(std::get<1>(bi_output))[1][0][0].item(), + std::get<1>(std::get<1>(reverse_output))[0][0][0].item()); } TEST_F(RNNTest, BidirectionalLSTMReverseForward) { @@ -511,15 +553,17 @@ TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) { TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { // Create two GRUs with the same options - auto opt = GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); - GRU gru_cpu {opt}; - GRU gru_cuda {opt}; + auto opt = + GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); + GRU gru_cpu{opt}; + GRU gru_cuda{opt}; // Copy weights and biases from CPU GRU to CUDA GRU { at::NoGradGuard guard; for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) { - gru_cuda->named_parameters()[param.key()].copy_(gru_cpu->named_parameters()[param.key()]); + gru_cuda->named_parameters()[param.key()].copy_( + gru_cpu->named_parameters()[param.key()]); } } @@ -530,12 +574,13 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { gru_cuda->to(torch::kCUDA); // Create the same inputs - auto input_opt = torch::TensorOptions() - .dtype(torch::kFloat32).requires_grad(false); - auto input_cpu = torch::tensor({1, 2, 3, 4, 5, 6}, input_opt) - .reshape({3, 1, 2}); + auto input_opt = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + auto input_cpu = + torch::tensor({1, 2, 3, 4, 5, 6}, input_opt).reshape({3, 1, 2}); auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, input_opt) - .reshape({3, 1, 2}).to(torch::kCUDA); + .reshape({3, 1, 2}) + .to(torch::kCUDA); // Call forward on both GRUs auto output_cpu = gru_cpu->forward(input_cpu); @@ -546,14 +591,16 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { // Assert that the output and state are equal on CPU and CUDA ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { - ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); + ASSERT_EQ( + std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); } for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { ASSERT_NEAR( - std::get<0>(output_cpu)[i][j][k].item(), - std::get<0>(output_cuda)[i][j][k].item(), 1e-5); + std::get<0>(output_cpu)[i][j][k].item(), + std::get<0>(output_cuda)[i][j][k].item(), + 1e-5); } } } @@ -561,15 +608,17 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { // Create two LSTMs with the same options - auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); - LSTM lstm_cpu {opt}; - LSTM lstm_cuda {opt}; + auto opt = + LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true); + LSTM lstm_cpu{opt}; + LSTM lstm_cuda{opt}; // Copy weights and biases from CPU LSTM to CUDA LSTM { at::NoGradGuard guard; for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) { - lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]); + lstm_cuda->named_parameters()[param.key()].copy_( + lstm_cpu->named_parameters()[param.key()]); } } @@ -579,12 +628,13 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { // Move LSTM to CUDA lstm_cuda->to(torch::kCUDA); - auto options = torch::TensorOptions() - .dtype(torch::kFloat32).requires_grad(false); - auto input_cpu = torch::tensor({1, 2, 3, 4, 5, 6}, options) - .reshape({3, 1, 2}); + auto options = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + auto input_cpu = + torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2}); auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options) - .reshape({3, 1, 2}).to(torch::kCUDA); + .reshape({3, 1, 2}) + .to(torch::kCUDA); // Call forward on both LSTMs auto output_cpu = lstm_cpu->forward(input_cpu); @@ -595,14 +645,16 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { // Assert that the output and state are equal on CPU and CUDA ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { - ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); + ASSERT_EQ( + std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); } for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { ASSERT_NEAR( - std::get<0>(output_cpu)[i][j][k].item(), - std::get<0>(output_cuda)[i][j][k].item(), 1e-5); + std::get<0>(output_cpu)[i][j][k].item(), + std::get<0>(output_cuda)[i][j][k].item(), + 1e-5); } } } @@ -610,15 +662,20 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) { // Create two LSTMs with the same options - auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true).proj_size(2); - LSTM lstm_cpu {opt}; - LSTM lstm_cuda {opt}; + auto opt = LSTMOptions(2, 4) + .num_layers(3) + .batch_first(false) + .bidirectional(true) + .proj_size(2); + LSTM lstm_cpu{opt}; + LSTM lstm_cuda{opt}; // Copy weights and biases from CPU LSTM to CUDA LSTM { at::NoGradGuard guard; for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) { - lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]); + lstm_cuda->named_parameters()[param.key()].copy_( + lstm_cpu->named_parameters()[param.key()]); } } @@ -628,12 +685,13 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) { // Move LSTM to CUDA lstm_cuda->to(torch::kCUDA); - auto options = torch::TensorOptions() - .dtype(torch::kFloat32).requires_grad(false); - auto input_cpu = torch::tensor({1, 2, 3, 4, 5, 6}, options) - .reshape({3, 1, 2}); + auto options = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + auto input_cpu = + torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2}); auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options) - .reshape({3, 1, 2}).to(torch::kCUDA); + .reshape({3, 1, 2}) + .to(torch::kCUDA); // Call forward on both LSTMs auto output_cpu = lstm_cpu->forward(input_cpu); @@ -644,14 +702,16 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) { // Assert that the output and state are equal on CPU and CUDA ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim()); for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) { - ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); + ASSERT_EQ( + std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i)); } for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) { for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) { for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) { ASSERT_NEAR( - std::get<0>(output_cpu)[i][j][k].item(), - std::get<0>(output_cuda)[i][j][k].item(), 1e-5); + std::get<0>(output_cpu)[i][j][k].item(), + std::get<0>(output_cuda)[i][j][k].item(), + 1e-5); } } } @@ -661,46 +721,55 @@ TEST_F(RNNTest, UsePackedSequenceAsInput) { { torch::manual_seed(0); auto m = RNN(2, 3); - torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); + torch::nn::utils::rnn::PackedSequence packed_input = + torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); auto rnn_output = m->forward_with_packed_input(packed_input); auto expected_output = torch::tensor( - {{-0.0645, -0.7274, 0.4531}, - {-0.3970, -0.6950, 0.6009}, - {-0.3877, -0.7310, 0.6806}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + {{-0.0645, -0.7274, 0.4531}, + {-0.3970, -0.6950, 0.6009}, + {-0.3877, -0.7310, 0.6806}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); // Test passing optional argument to `RNN::forward_with_packed_input` rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor()); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); auto m = LSTM(2, 3); - torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); + torch::nn::utils::rnn::PackedSequence packed_input = + torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); auto rnn_output = m->forward_with_packed_input(packed_input); auto expected_output = torch::tensor( - {{-0.2693, -0.1240, 0.0744}, - {-0.3889, -0.1919, 0.1183}, - {-0.4425, -0.2314, 0.1386}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + {{-0.2693, -0.1240, 0.0744}, + {-0.3889, -0.1919, 0.1183}, + {-0.4425, -0.2314, 0.1386}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); // Test passing optional argument to `LSTM::forward_with_packed_input` rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); auto m = GRU(2, 3); - torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); + torch::nn::utils::rnn::PackedSequence packed_input = + torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})}); auto rnn_output = m->forward_with_packed_input(packed_input); auto expected_output = torch::tensor( - {{-0.1134, 0.0467, 0.2336}, - {-0.1189, 0.0502, 0.2960}, - {-0.1138, 0.0484, 0.3110}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + {{-0.1134, 0.0467, 0.2336}, + {-0.1189, 0.0502, 0.2960}, + {-0.1138, 0.0484, 0.3110}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); // Test passing optional argument to `GRU::forward_with_packed_input` rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor()); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04)); } } diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index cac719af6cde59..1c60a176a56d9f 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -30,11 +30,10 @@ TEST_F(SequentialTest, ConstructsFromSharedPointer) { std::make_shared(1), std::make_shared(2), std::make_shared(3)); ASSERT_EQ(sequential->size(), 3); - Sequential sequential_named({ - {"m1", std::make_shared(1)}, - {std::string("m2"), std::make_shared(2)}, - {"m3", std::make_shared(3)} - }); + Sequential sequential_named( + {{"m1", std::make_shared(1)}, + {std::string("m2"), std::make_shared(2)}, + {"m3", std::make_shared(3)}}); ASSERT_EQ(sequential->size(), 3); } @@ -56,17 +55,15 @@ TEST_F(SequentialTest, ConstructsFromConcreteType) { copy_count = 0; Sequential sequential(M(1), M(2), M(3)); ASSERT_EQ(sequential->size(), 3); - // NOTE: The current implementation expects each module to be copied exactly once, - // which happens when the module is passed into `std::make_shared()`. - // TODO: Find a way to avoid copying, and then delete the copy constructor of `M`. + // NOTE: The current implementation expects each module to be copied exactly + // once, which happens when the module is passed into `std::make_shared()`. + // TODO: Find a way to avoid copying, and then delete the copy constructor of + // `M`. ASSERT_EQ(copy_count, 3); copy_count = 0; - Sequential sequential_named({ - {"m1", M(1)}, - {std::string("m2"), M(2)}, - {"m3", M(3)} - }); + Sequential sequential_named( + {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}}); ASSERT_EQ(sequential->size(), 3); ASSERT_EQ(copy_count, 3); } @@ -88,11 +85,8 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) { Sequential sequential(M(1), M(2), M(3)); ASSERT_EQ(sequential->size(), 3); - Sequential sequential_named({ - {"m1", M(1)}, - {std::string("m2"), M(2)}, - {"m3", M(3)} - }); + Sequential sequential_named( + {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}}); ASSERT_EQ(sequential->size(), 3); } @@ -144,7 +138,7 @@ TEST_F(SequentialTest, PushBackAddsAnElement) { // named and unnamed AnyModule's Sequential sequential_any; - auto a=torch::nn::AnyModule(torch::nn::Linear(1,2)); + auto a = torch::nn::AnyModule(torch::nn::Linear(1, 2)); ASSERT_EQ(sequential_any->size(), 0); ASSERT_TRUE(sequential_any->is_empty()); sequential_any->push_back(a); @@ -317,8 +311,8 @@ TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) { ASSERT_TRUE(b[0]->as()); ASSERT_TRUE(b[1]->as()); - std::vector> c = {std::make_shared(), - std::make_shared()}; + std::vector> c = { + std::make_shared(), std::make_shared()}; b->extend(c); ASSERT_EQ(b->size(), 4); @@ -415,14 +409,13 @@ TEST_F(SequentialTest, PrettyPrintSequential) { " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n" ")"); - Sequential sequential_named({ - {"linear", Linear(10, 3)}, - {"conv2d", Conv2d(1, 2, 3)}, - {"dropout", Dropout(0.5)}, - {"batchnorm2d", BatchNorm2d(5)}, - {"embedding", Embedding(4, 10)}, - {"lstm", LSTM(4, 5)} - }); + Sequential sequential_named( + {{"linear", Linear(10, 3)}, + {"conv2d", Conv2d(1, 2, 3)}, + {"dropout", Dropout(0.5)}, + {"batchnorm2d", BatchNorm2d(5)}, + {"embedding", Embedding(4, 10)}, + {"lstm", LSTM(4, 5)}}); ASSERT_EQ( c10::str(sequential_named), "torch::nn::Sequential(\n" @@ -437,60 +430,60 @@ TEST_F(SequentialTest, PrettyPrintSequential) { TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) { { - Sequential sequential(Identity(), ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false))); - std::dynamic_pointer_cast(sequential[1])->weight.set_data(torch::arange(18.).reshape({3, 2, 3})); + Sequential sequential( + Identity(), + ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false))); + std::dynamic_pointer_cast(sequential[1]) + ->weight.set_data(torch::arange(18.).reshape({3, 2, 3})); auto x = torch::arange(30.).reshape({2, 3, 5}); auto y = sequential->forward(x); - auto expected = torch::tensor({{{ 150., 333., 552., 615., 678., 501., 276.}, - { 195., 432., 714., 804., 894., 654., 357.}}, - {{ 420., 918., 1497., 1560., 1623., 1176., 636.}, - { 600., 1287., 2064., 2154., 2244., 1599., 852.}}}); + auto expected = torch::tensor( + {{{150., 333., 552., 615., 678., 501., 276.}, + {195., 432., 714., 804., 894., 654., 357.}}, + {{420., 918., 1497., 1560., 1623., 1176., 636.}, + {600., 1287., 2064., 2154., 2244., 1599., 852.}}}); ASSERT_TRUE(torch::allclose(y, expected)); } { - Sequential sequential(Identity(), ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false))); - std::dynamic_pointer_cast(sequential[1])->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3})); + Sequential sequential( + Identity(), + ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false))); + std::dynamic_pointer_cast(sequential[1]) + ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3})); auto x = torch::arange(75.).reshape({1, 3, 5, 5}); auto y = sequential->forward(x); - auto expected = torch::tensor({{{{ 2250., 4629., 7140., 7311., 7482., 5133., 2640.}, - { 4995., 10272., 15837., 16206., 16575., 11364., 5841.}, - { 8280., 17019., 26226., 26820., 27414., 18783., 9648.}, - { 9225., 18954., 29196., 29790., 30384., 20808., 10683.}, - {10170., 20889., 32166., 32760., 33354., 22833., 11718.}, - { 7515., 15420., 23721., 24144., 24567., 16800., 8613.}, - { 4140., 8487., 13044., 13269., 13494., 9219., 4722.}}, - {{ 2925., 6006., 9246., 9498., 9750., 6672., 3423.}, - { 6480., 13296., 20454., 20985., 21516., 14712., 7542.}, - {10710., 21960., 33759., 34596., 35433., 24210., 12402.}, - {12060., 24705., 37944., 38781., 39618., 27045., 13842.}, - {13410., 27450., 42129., 42966., 43803., 29880., 15282.}, - { 9810., 20064., 30768., 31353., 31938., 21768., 11124.}, - { 5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}}); + auto expected = torch::tensor( + {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.}, + {4995., 10272., 15837., 16206., 16575., 11364., 5841.}, + {8280., 17019., 26226., 26820., 27414., 18783., 9648.}, + {9225., 18954., 29196., 29790., 30384., 20808., 10683.}, + {10170., 20889., 32166., 32760., 33354., 22833., 11718.}, + {7515., 15420., 23721., 24144., 24567., 16800., 8613.}, + {4140., 8487., 13044., 13269., 13494., 9219., 4722.}}, + {{2925., 6006., 9246., 9498., 9750., 6672., 3423.}, + {6480., 13296., 20454., 20985., 21516., 14712., 7542.}, + {10710., 21960., 33759., 34596., 35433., 24210., 12402.}, + {12060., 24705., 37944., 38781., 39618., 27045., 13842.}, + {13410., 27450., 42129., 42966., 43803., 29880., 15282.}, + {9810., 20064., 30768., 31353., 31938., 21768., 11124.}, + {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}}); ASSERT_TRUE(torch::allclose(y, expected)); } { - Sequential sequential(Identity(), ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false))); - std::dynamic_pointer_cast(sequential[1])->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2})); + Sequential sequential( + Identity(), + ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false))); + std::dynamic_pointer_cast(sequential[1]) + ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2})); auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2}); auto y = sequential->forward(x); - auto expected = torch::tensor({{{{{ 128., 280., 154.}, - { 304., 664., 364.}, - { 184., 400., 218.}}, - {{ 352., 768., 420.}, - { 832., 1808., 984.}, - { 496., 1072., 580.}}, - {{ 256., 552., 298.}, - { 592., 1272., 684.}, - { 344., 736., 394.}}}, - {{{ 192., 424., 234.}, - { 464., 1016., 556.}, - { 280., 608., 330.}}, - {{ 544., 1184., 644.}, - {1280., 2768., 1496.}, - { 752., 1616., 868.}}, - {{ 384., 824., 442.}, - { 880., 1880., 1004.}, - { 504., 1072., 570.}}}}}); + auto expected = torch::tensor( + {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}}, + {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}}, + {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}}, + {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}}, + {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}}, + {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}}); ASSERT_TRUE(torch::allclose(y, expected)); } { @@ -516,157 +509,165 @@ TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) { auto value = key; Sequential sequential(MultiheadAttention(embed_dim, num_heads)); - auto output = sequential->forward>(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)); + auto output = sequential->forward>( + query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)); auto attn_output = std::get<0>(output); auto attn_output_expected = torch::tensor( - {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, - {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}}}); - ASSERT_TRUE(torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04)); + {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663}, + {0.0674, + -0.0056, + 0.1324, + 0.0922, + 0.0160, + -0.0934, + -0.1700, + 0.1663}}}); + ASSERT_TRUE( + torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04)); auto attn_output_weights = std::get<1>(output); auto attn_output_weights_expected = torch::tensor( - {{{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}, - {{0.3333, 0.3333, 0.3333}}}); - ASSERT_TRUE(torch::allclose(attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04)); + {{{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}, + {{0.3333, 0.3333, 0.3333}}}); + ASSERT_TRUE(torch::allclose( + attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04)); } { auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat)); Sequential sequential(MaxUnpool1d(3)); auto y = sequential->forward(x, indices); - auto expected = torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat); + auto expected = + torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); } { auto indices = torch::tensor( - {{{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}}, torch::kLong); + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}}, + torch::kLong); auto x = torch::tensor( - {{{{ 6, 8, 9}, - {16, 18, 19}, - {21, 23, 24}}}, - {{{31, 33, 34}, - {41, 43, 44}, - {46, 48, 49}}}}, torch::dtype(torch::kFloat)); - Sequential sequential(MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1))); + {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, + {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}}, + torch::dtype(torch::kFloat)); + Sequential sequential( + MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1))); auto y = sequential->forward(x, indices); auto expected = torch::tensor( - {{{{ 0, 0, 0, 0, 0}, - { 0, 6, 0, 8, 9}, - { 0, 0, 0, 0, 0}, - { 0, 16, 0, 18, 19}, - { 0, 21, 0, 23, 24}}}, - {{{ 0, 0, 0, 0, 0}, - { 0, 31, 0, 33, 34}, - { 0, 0, 0, 0, 0}, - { 0, 41, 0, 43, 44}, - { 0, 46, 0, 48, 49}}}} , torch::kFloat); + {{{{0, 0, 0, 0, 0}, + {0, 6, 0, 8, 9}, + {0, 0, 0, 0, 0}, + {0, 16, 0, 18, 19}, + {0, 21, 0, 23, 24}}}, + {{{0, 0, 0, 0, 0}, + {0, 31, 0, 33, 34}, + {0, 0, 0, 0, 0}, + {0, 41, 0, 43, 44}, + {0, 46, 0, 48, 49}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); } { auto indices = torch::tensor({{{{{26}}}}}, torch::kLong); - auto x = torch::tensor({{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto x = torch::tensor( + {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); Sequential sequential(MaxUnpool3d(3)); auto y = sequential->forward(x, indices); auto expected = torch::tensor( - {{{{{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 0}}, - {{ 0, 0, 0}, - { 0, 0, 0}, - { 0, 0, 26}}}}}, torch::kFloat); + {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, + {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}}, + torch::kFloat); ASSERT_TRUE(torch::allclose(y, expected)); } { torch::manual_seed(0); Sequential sequential(Identity(), RNN(2, 3)); auto x = torch::ones({2, 3, 2}); - auto rnn_output = sequential->forward>(x); + auto rnn_output = + sequential->forward>(x); auto expected_output = torch::tensor( - {{{-0.0645, -0.7274, 0.4531}, - {-0.0645, -0.7274, 0.4531}, - {-0.0645, -0.7274, 0.4531}}, - {{-0.3970, -0.6950, 0.6009}, - {-0.3970, -0.6950, 0.6009}, - {-0.3970, -0.6950, 0.6009}}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); + {{{-0.0645, -0.7274, 0.4531}, + {-0.0645, -0.7274, 0.4531}, + {-0.0645, -0.7274, 0.4531}}, + {{-0.3970, -0.6950, 0.6009}, + {-0.3970, -0.6950, 0.6009}, + {-0.3970, -0.6950, 0.6009}}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); Sequential sequential(Identity(), LSTM(2, 3)); auto x = torch::ones({2, 3, 2}); - auto rnn_output = sequential->forward>>(x); + auto rnn_output = sequential->forward< + std::tuple>>(x); auto expected_output = torch::tensor( - {{{-0.2693, -0.1240, 0.0744}, - {-0.2693, -0.1240, 0.0744}, - {-0.2693, -0.1240, 0.0744}}, - {{-0.3889, -0.1919, 0.1183}, - {-0.3889, -0.1919, 0.1183}, - {-0.3889, -0.1919, 0.1183}}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); + {{{-0.2693, -0.1240, 0.0744}, + {-0.2693, -0.1240, 0.0744}, + {-0.2693, -0.1240, 0.0744}}, + {{-0.3889, -0.1919, 0.1183}, + {-0.3889, -0.1919, 0.1183}, + {-0.3889, -0.1919, 0.1183}}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); Sequential sequential(Identity(), GRU(2, 3)); auto x = torch::ones({2, 3, 2}); - auto rnn_output = sequential->forward>(x); + auto rnn_output = + sequential->forward>(x); auto expected_output = torch::tensor( - {{{-0.1134, 0.0467, 0.2336}, - {-0.1134, 0.0467, 0.2336}, - {-0.1134, 0.0467, 0.2336}}, - {{-0.1189, 0.0502, 0.2960}, - {-0.1189, 0.0502, 0.2960}, - {-0.1189, 0.0502, 0.2960}}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); + {{{-0.1134, 0.0467, 0.2336}, + {-0.1134, 0.0467, 0.2336}, + {-0.1134, 0.0467, 0.2336}}, + {{-0.1189, 0.0502, 0.2960}, + {-0.1189, 0.0502, 0.2960}, + {-0.1189, 0.0502, 0.2960}}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); Sequential sequential(Identity(), RNNCell(2, 3)); auto x = torch::ones({2, 2}); auto rnn_output = sequential->forward(x); - auto expected_output = torch::tensor( - {{-0.0645, -0.7274, 0.4531}, - {-0.0645, -0.7274, 0.4531}}); + auto expected_output = + torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}}); ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); Sequential sequential(Identity(), LSTMCell(2, 3)); auto x = torch::ones({2, 2}); - auto rnn_output = sequential->forward>(x); - auto expected_output = torch::tensor( - {{-0.2693, -0.1240, 0.0744}, - {-0.2693, -0.1240, 0.0744}}); - ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); + auto rnn_output = + sequential->forward>(x); + auto expected_output = + torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}}); + ASSERT_TRUE(torch::allclose( + std::get<0>(rnn_output), expected_output, 1e-05, 2e-04)); } { torch::manual_seed(0); Sequential sequential(Identity(), GRUCell(2, 3)); auto x = torch::ones({2, 2}); auto rnn_output = sequential->forward(x); - auto expected_output = torch::tensor( - {{-0.1134, 0.0467, 0.2336}, - {-0.1134, 0.0467, 0.2336}}); + auto expected_output = + torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}}); ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04)); } } diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index ecad2348674b79..0cf8ed88c4188b 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -1,8 +1,8 @@ #include -#include #include #include +#include #include @@ -37,7 +37,9 @@ torch::Tensor save_and_load(torch::Tensor input) { } // namespace template -void is_optimizer_param_group_equal(const OptimizerParamGroup& lhs, const OptimizerParamGroup& rhs) { +void is_optimizer_param_group_equal( + const OptimizerParamGroup& lhs, + const OptimizerParamGroup& rhs) { const auto& lhs_params = lhs.params(); const auto& rhs_params = rhs.params(); @@ -45,26 +47,36 @@ void is_optimizer_param_group_equal(const OptimizerParamGroup& lhs, const Optimi for (const auto j : c10::irange(lhs_params.size())) { ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j])); } - ASSERT_TRUE(static_cast(lhs.options()) == static_cast(rhs.options())); + ASSERT_TRUE( + static_cast(lhs.options()) == + static_cast(rhs.options())); } template void is_optimizer_state_equal( - const ska::flat_hash_map>& lhs_state, - const ska::flat_hash_map>& rhs_state) { - + const ska::flat_hash_map>& + lhs_state, + const ska::flat_hash_map>& + rhs_state) { ASSERT_TRUE(lhs_state.size() == rhs_state.size()); for (const auto& value : lhs_state) { auto found = rhs_state.find(value.first); ASSERT_TRUE(found != rhs_state.end()); - const DerivedOptimizerParamState& lhs_curr_state = static_cast(*(value.second.get())); - const DerivedOptimizerParamState& rhs_curr_state = static_cast(*(found->second.get())); + const DerivedOptimizerParamState& lhs_curr_state = + static_cast(*(value.second.get())); + const DerivedOptimizerParamState& rhs_curr_state = + static_cast(*(found->second.get())); ASSERT_TRUE(lhs_curr_state == rhs_curr_state); } } -template -void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_global_state = false) { +template < + typename OptimizerClass, + typename DerivedOptimizerOptions, + typename DerivedOptimizerParamState> +void test_serialize_optimizer( + DerivedOptimizerOptions options, + bool only_has_global_state = false) { torch::manual_seed(0); auto model1 = Linear(5, 2); auto model2 = Linear(5, 2); @@ -86,14 +98,10 @@ void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_glo // Make some optimizers auto optim1 = OptimizerClass( {torch::optim::OptimizerParamGroup(model1->parameters())}, options); - auto optim2 = OptimizerClass( - model2->parameters(), options); - auto optim2_2 = OptimizerClass( - model2->parameters(), options); - auto optim3 = OptimizerClass( - model3->parameters(), options); - auto optim3_2 = OptimizerClass( - model3->parameters(), options); + auto optim2 = OptimizerClass(model2->parameters(), options); + auto optim2_2 = OptimizerClass(model2->parameters(), options); + auto optim3 = OptimizerClass(model3->parameters(), options); + auto optim3_2 = OptimizerClass(model3->parameters(), options); auto x = torch::ones({10, 5}); @@ -126,9 +134,11 @@ void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_glo auto& optim3_2_state = optim3_2.state(); auto& optim3_state = optim3.state(); - // optim3_2 and optim1 should have param_groups and state of size 1 and state_size respectively + // optim3_2 and optim1 should have param_groups and state of size 1 and + // state_size respectively ASSERT_TRUE(optim3_2_param_groups.size() == 1); - // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one global state + // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one + // global state unsigned state_size = only_has_global_state ? 1 : 2; ASSERT_TRUE(optim3_2_state.size() == state_size); @@ -136,11 +146,13 @@ void test_serialize_optimizer(DerivedOptimizerOptions options, bool only_has_glo ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size()); ASSERT_TRUE(optim3_2_state.size() == optim3_state.size()); - // checking correctness of serialization logic for optimizer.param_groups_ and optimizer.state_ + // checking correctness of serialization logic for optimizer.param_groups_ and + // optimizer.state_ for (const auto i : c10::irange(optim3_2_param_groups.size())) { is_optimizer_param_group_equal( - optim3_2_param_groups[i], optim3_param_groups[i]); - is_optimizer_state_equal(optim3_2_state, optim3_state); + optim3_2_param_groups[i], optim3_param_groups[i]); + is_optimizer_state_equal( + optim3_2_state, optim3_state); } // Do step2 for model 3 @@ -205,7 +217,8 @@ TEST(SerializeTest, KeysFunc) { auto tempfile = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; for (const auto i : c10::irange(3)) { - output_archive.write("element/" + c10::to_string(i), c10::IValue(static_cast(i))); + output_archive.write( + "element/" + c10::to_string(i), c10::IValue(static_cast(i))); } output_archive.save_to(tempfile.name); torch::serialize::InputArchive input_archive; @@ -221,7 +234,8 @@ TEST(SerializeTest, TryReadFunc) { auto tempfile = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; for (const auto i : c10::irange(3)) { - output_archive.write("element/" + c10::to_string(i), c10::IValue(static_cast(i))); + output_archive.write( + "element/" + c10::to_string(i), c10::IValue(static_cast(i))); } output_archive.save_to(tempfile.name); torch::serialize::InputArchive input_archive; @@ -266,7 +280,7 @@ TEST(SerializeTest, BasicViaFunc) { std::string serialized; torch::save(x, [&](const void* buf, size_t n) { - serialized.append(reinterpret_cast(buf), n); + serialized.append(reinterpret_cast(buf), n); return n; }); torch::Tensor y; @@ -277,14 +291,17 @@ TEST(SerializeTest, BasicViaFunc) { ASSERT_TRUE(x.allclose(y)); torch::Tensor z; - torch::load(z, [&](uint64_t pos, void* buf, size_t n) -> size_t { - if (pos >= serialized.size()) return 0; - size_t nbytes = std::min(static_cast(pos) + n, - serialized.size()) - pos; - memcpy(buf, serialized.data() + pos, nbytes); - return nbytes; - }, - [&]() -> size_t { return serialized.size(); }); + torch::load( + z, + [&](uint64_t pos, void* buf, size_t n) -> size_t { + if (pos >= serialized.size()) + return 0; + size_t nbytes = + std::min(static_cast(pos) + n, serialized.size()) - pos; + memcpy(buf, serialized.data() + pos, nbytes); + return nbytes; + }, + [&]() -> size_t { return serialized.size(); }); ASSERT_TRUE(z.defined()); ASSERT_EQ(x.sizes().vec(), z.sizes().vec()); ASSERT_TRUE(x.allclose(z)); @@ -473,7 +490,8 @@ TEST(SerializeTest, Optim) { } TEST(SerializeTest, Optim_Adagrad) { - test_serialize_optimizer(AdagradOptions(1e-1)); + test_serialize_optimizer( + AdagradOptions(1e-1)); // bc compatibility check auto model1 = Linear(5, 2); @@ -488,17 +506,19 @@ TEST(SerializeTest, Optim_Adagrad) { optimizer.step(); }; step(optim1, model1); - auto optim1_2 = Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1)); + auto optim1_2 = + Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1)); // fill up with optim1 sum_buffers std::vector sum_buffers; - // fill up with optim1 state_buffers + // fill up with optim1 state_buffers std::vector step_buffers; const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); - for (const auto & param : params_) { + for (const auto& param : params_) { auto key_ = c10::guts::to_string(param.unsafeGetTensorImpl()); - const AdagradParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); + const AdagradParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); sum_buffers.emplace_back(curr_state_.sum()); step_buffers.emplace_back(curr_state_.step()); } @@ -508,19 +528,23 @@ TEST(SerializeTest, Optim_Adagrad) { write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers); write_step_buffers(output_archive, "step_buffers", step_buffers); output_archive.save_to(optim_tempfile_old_format.name); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_SGD) { - test_serialize_optimizer(SGDOptions(1e-1).momentum(0.9)); + test_serialize_optimizer( + SGDOptions(1e-1).momentum(0.9)); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); - // added a tensor for lazy init check - when all params do not have a momentum buffer entry - model1_params.emplace_back(torch::randn({2,3})); - auto optim1 = torch::optim::SGD(model1_params, torch::optim::SGDOptions(0.01).momentum(0.9)); + // added a tensor for lazy init check - when all params do not have a momentum + // buffer entry + model1_params.emplace_back(torch::randn({2, 3})); + auto optim1 = torch::optim::SGD( + model1_params, torch::optim::SGDOptions(0.01).momentum(0.9)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { @@ -536,9 +560,10 @@ TEST(SerializeTest, Optim_SGD) { const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { - if(i != (params_.size() - 1)) { + if (i != (params_.size() - 1)) { auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); - const SGDParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); + const SGDParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); momentum_buffers.emplace_back(curr_state_.momentum_buffer()); } } @@ -546,23 +571,29 @@ TEST(SerializeTest, Optim_SGD) { // write momentum_buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; - write_tensors_to_archive(output_archive, "momentum_buffers", momentum_buffers); + write_tensors_to_archive( + output_archive, "momentum_buffers", momentum_buffers); write_int_value(output_archive, "iteration_", iteration_); output_archive.save_to(optim_tempfile_old_format.name); - auto optim1_2 = SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9)); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + auto optim1_2 = + SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9)); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_Adam) { - test_serialize_optimizer(AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5)); + test_serialize_optimizer( + AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5)); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); - // added a tensor for lazy init check - when all params do not have entry in buffers - model1_params.emplace_back(torch::randn({2,3})); - auto optim1 = torch::optim::Adam(model1_params, torch::optim::AdamOptions().weight_decay(0.5)); + // added a tensor for lazy init check - when all params do not have entry in + // buffers + model1_params.emplace_back(torch::randn({2, 3})); + auto optim1 = torch::optim::Adam( + model1_params, torch::optim::AdamOptions().weight_decay(0.5)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { @@ -580,13 +611,14 @@ TEST(SerializeTest, Optim_Adam) { const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { - if(i != (params_.size() - 1)) { + if (i != (params_.size() - 1)) { auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); - const AdamParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); + const AdamParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); step_buffers.emplace_back(curr_state_.step()); exp_average_buffers.emplace_back(curr_state_.exp_avg()); exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq()); - if(curr_state_.max_exp_avg_sq().defined()) { + if (curr_state_.max_exp_avg_sq().defined()) { max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq()); } } @@ -595,24 +627,32 @@ TEST(SerializeTest, Optim_Adam) { auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_step_buffers(output_archive, "step_buffers", step_buffers); - write_tensors_to_archive(output_archive, "exp_average_buffers", exp_average_buffers); - write_tensors_to_archive(output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); - write_tensors_to_archive(output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); + write_tensors_to_archive( + output_archive, "exp_average_buffers", exp_average_buffers); + write_tensors_to_archive( + output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); + write_tensors_to_archive( + output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions()); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_AdamW) { - test_serialize_optimizer(AdamWOptions().lr(0.99999).amsgrad(true).betas(std::make_tuple(0.999, 0.1))); + test_serialize_optimizer( + AdamWOptions().lr(0.99999).amsgrad(true).betas( + std::make_tuple(0.999, 0.1))); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); - // added a tensor for lazy init check - when all params do not have entry in buffers - model1_params.emplace_back(torch::randn({2,3})); - auto optim1 = torch::optim::AdamW(model1_params, torch::optim::AdamWOptions().weight_decay(0.5)); + // added a tensor for lazy init check - when all params do not have entry in + // buffers + model1_params.emplace_back(torch::randn({2, 3})); + auto optim1 = torch::optim::AdamW( + model1_params, torch::optim::AdamWOptions().weight_decay(0.5)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { @@ -630,13 +670,14 @@ TEST(SerializeTest, Optim_AdamW) { const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { - if(i != (params_.size() - 1)) { + if (i != (params_.size() - 1)) { auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); - const AdamWParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); + const AdamWParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); step_buffers.emplace_back(curr_state_.step()); exp_average_buffers.emplace_back(curr_state_.exp_avg()); exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq()); - if(curr_state_.max_exp_avg_sq().defined()) { + if (curr_state_.max_exp_avg_sq().defined()) { max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq()); } } @@ -645,12 +686,16 @@ TEST(SerializeTest, Optim_AdamW) { auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_step_buffers(output_archive, "step_buffers", step_buffers); - write_tensors_to_archive(output_archive, "exp_average_buffers", exp_average_buffers); - write_tensors_to_archive(output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); - write_tensors_to_archive(output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); + write_tensors_to_archive( + output_archive, "exp_average_buffers", exp_average_buffers); + write_tensors_to_archive( + output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); + write_tensors_to_archive( + output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions()); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } @@ -662,8 +707,9 @@ TEST(SerializeTest, Optim_RMSprop) { auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); - // added a tensor for lazy init check - when all params do not have a momentum buffer entry - model1_params.emplace_back(torch::randn({2,3})); + // added a tensor for lazy init check - when all params do not have a momentum + // buffer entry + model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::RMSprop(model1_params, options); auto x = torch::ones({10, 5}); @@ -681,14 +727,15 @@ TEST(SerializeTest, Optim_RMSprop) { const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { - if(i != (params_.size() - 1)) { + if (i != (params_.size() - 1)) { auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); - const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); + const RMSpropParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); square_average_buffers.emplace_back(curr_state_.square_avg()); - if(curr_state_.momentum_buffer().defined()) { + if (curr_state_.momentum_buffer().defined()) { momentum_buffers.emplace_back(curr_state_.momentum_buffer()); } - if(curr_state_.grad_avg().defined()) { + if (curr_state_.grad_avg().defined()) { grad_average_buffers.emplace_back(curr_state_.grad_avg()); } } @@ -696,21 +743,27 @@ TEST(SerializeTest, Optim_RMSprop) { // write buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; - write_tensors_to_archive(output_archive, "square_average_buffers", square_average_buffers); - write_tensors_to_archive(output_archive, "momentum_buffers", momentum_buffers); - write_tensors_to_archive(output_archive, "grad_average_buffers", grad_average_buffers); + write_tensors_to_archive( + output_archive, "square_average_buffers", square_average_buffers); + write_tensors_to_archive( + output_archive, "momentum_buffers", momentum_buffers); + write_tensors_to_archive( + output_archive, "grad_average_buffers", grad_average_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = RMSprop(model1_params, options); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); const auto& params1_2_ = optim1_2.param_groups()[0].params(); auto& optim1_2_state = optim1_2.state(); // old RMSprop didn't track step value for (const auto i : c10::irange(params1_2_.size())) { - if(i != (params1_2_.size() - 1)) { + if (i != (params1_2_.size() - 1)) { auto key_ = c10::guts::to_string(params_[i].unsafeGetTensorImpl()); auto key1_2_ = c10::guts::to_string(params1_2_[i].unsafeGetTensorImpl()); - const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); - RMSpropParamState& curr_state1_2_ = static_cast(*(optim1_2_state.at(key_).get())); + const RMSpropParamState& curr_state_ = + static_cast(*(optim1_state.at(key_).get())); + RMSpropParamState& curr_state1_2_ = + static_cast(*(optim1_2_state.at(key_).get())); curr_state1_2_.step(curr_state_.step()); } } @@ -718,13 +771,16 @@ TEST(SerializeTest, Optim_RMSprop) { } TEST(SerializeTest, Optim_LBFGS) { - test_serialize_optimizer(LBFGSOptions(), true); + test_serialize_optimizer( + LBFGSOptions(), true); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); - // added a tensor for lazy init check - when all params do not have entry in buffers - model1_params.emplace_back(torch::randn({2,3})); - auto optim1 = torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions()); + // added a tensor for lazy init check - when all params do not have entry in + // buffers + model1_params.emplace_back(torch::randn({2, 3})); + auto optim1 = + torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions()); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { @@ -742,7 +798,8 @@ TEST(SerializeTest, Optim_LBFGS) { const auto& params_ = optim1.param_groups()[0].params(); auto key_ = c10::guts::to_string(params_[0].unsafeGetTensorImpl()); - const auto& optim1_state = static_cast(*(optim1.state().at(key_).get())); + const auto& optim1_state = + static_cast(*(optim1.state().at(key_).get())); d = optim1_state.d(); t = at::tensor(optim1_state.t()); H_diag = optim1_state.H_diag(); @@ -763,11 +820,13 @@ TEST(SerializeTest, Optim_LBFGS) { output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions()); - OLD_SERIALIZATION_LOGIC_WARNING_CHECK(torch::load, optim1_2, optim_tempfile_old_format.name); + OLD_SERIALIZATION_LOGIC_WARNING_CHECK( + torch::load, optim1_2, optim_tempfile_old_format.name); const auto& params1_2_ = optim1_2.param_groups()[0].params(); auto param_key = c10::guts::to_string(params1_2_[0].unsafeGetTensorImpl()); - auto& optim1_2_state = static_cast(*(optim1_2.state().at(param_key).get())); + auto& optim1_2_state = + static_cast(*(optim1_2.state().at(param_key).get())); // old LBFGS didn't track func_evals, n_iter, ro, al values optim1_2_state.func_evals(optim1_state.func_evals()); @@ -873,7 +932,8 @@ TEST( TEST(SerializeTest, VectorOfTensors) { torch::manual_seed(0); - std::vector x_vec = { torch::randn({1, 2}), torch::randn({3, 4}) }; + std::vector x_vec = { + torch::randn({1, 2}), torch::randn({3, 4})}; std::stringstream stream; torch::save(x_vec, stream); @@ -903,11 +963,14 @@ TEST(SerializeTest, IValue) { input_archive.read("value", ivalue_out); ASSERT_EQ(ivalue_out.toInt(), 1); - ASSERT_THROWS_WITH(input_archive.read("bad_key", ivalue_out), "does not have a field with name"); + ASSERT_THROWS_WITH( + input_archive.read("bad_key", ivalue_out), + "does not have a field with name"); } -// NOTE: if a `Module` contains unserializable submodules (e.g. `nn::Functional`), -// we expect those submodules to be skipped when the `Module` is being serialized. +// NOTE: if a `Module` contains unserializable submodules (e.g. +// `nn::Functional`), we expect those submodules to be skipped when the `Module` +// is being serialized. TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) { struct A : torch::nn::Module { A() { @@ -924,13 +987,14 @@ TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) { torch::serialize::InputArchive relu_archive; // Submodule with name "relu" should not exist in the `InputArchive`, - // because the "relu" submodule is an `nn::Functional` and is not serializable. + // because the "relu" submodule is an `nn::Functional` and is not + // serializable. ASSERT_FALSE(archive.try_read("relu", relu_archive)); } -// NOTE: If a `Module` contains unserializable submodules (e.g. `nn::Functional`), -// we don't check the existence of those submodules in the `InputArchive` when -// deserializing. +// NOTE: If a `Module` contains unserializable submodules (e.g. +// `nn::Functional`), we don't check the existence of those submodules in the +// `InputArchive` when deserializing. TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) { struct B : torch::nn::Module { B() { @@ -946,8 +1010,8 @@ TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) { }; auto out = std::make_shared(); - // Manually change the values of "b.foo", so that we can check whether the buffer - // contains these values after deserialization. + // Manually change the values of "b.foo", so that we can check whether the + // buffer contains these values after deserialization. out->named_buffers()["b.foo"].fill_(1); auto tempfile = c10::make_tempfile(); torch::save(out, tempfile.name); @@ -961,23 +1025,24 @@ TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) { ASSERT_TRUE(archive.try_read("b", archive_b)); ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true)); - // Submodule with name "relu1" should not exist in `archive_b`, because the "relu1" - // submodule is an `nn::Functional` and is not serializable. + // Submodule with name "relu1" should not exist in `archive_b`, because the + // "relu1" submodule is an `nn::Functional` and is not serializable. ASSERT_FALSE(archive_b.try_read("relu1", archive_relu)); - // Submodule with name "relu2" should not exist in `archive`, because the "relu2" - // submodule is an `nn::Functional` and is not serializable. + // Submodule with name "relu2" should not exist in `archive`, because the + // "relu2" submodule is an `nn::Functional` and is not serializable. ASSERT_FALSE(archive.try_read("relu2", archive_relu)); auto in = std::make_shared(); - // `torch::load(...)` works without error, even though `A` contains the `nn::Functional` - // submodules while the serialized file doesn't, because the `nn::Functional` submodules - // are not serializable and thus ignored when deserializing. + // `torch::load(...)` works without error, even though `A` contains the + // `nn::Functional` submodules while the serialized file doesn't, because the + // `nn::Functional` submodules are not serializable and thus ignored when + // deserializing. torch::load(in, tempfile.name); // Check that the "b.foo" buffer is correctly deserialized from the file. const int output = in->named_buffers()["b.foo"].sum().item(); - // `output` should equal to the sum of the values we manually assigned to "b.foo" before - // serialization. + // `output` should equal to the sum of the values we manually assigned to + // "b.foo" before serialization. ASSERT_EQ(output, 5); } diff --git a/test/cpp/api/special.cpp b/test/cpp/api/special.cpp index e73360645d50a6..9e121968d41f0e 100644 --- a/test/cpp/api/special.cpp +++ b/test/cpp/api/special.cpp @@ -1,13 +1,13 @@ #include -#include #include +#include #include // Simple test that verifies the special namespace is registered properly // properly in C++ TEST(SpecialTest, special) { - auto t = torch::randn(128, torch::kDouble); - torch::special::gammaln(t); + auto t = torch::randn(128, torch::kDouble); + torch::special::gammaln(t); } diff --git a/test/cpp/api/static.cpp b/test/cpp/api/static.cpp index 36bb5deb00cfea..44fed18dd6a844 100644 --- a/test/cpp/api/static.cpp +++ b/test/cpp/api/static.cpp @@ -1,8 +1,8 @@ #include #include -#include #include +#include #include #include diff --git a/test/cpp/api/support.h b/test/cpp/api/support.h index ea729fcf217463..e2c88150445d29 100644 --- a/test/cpp/api/support.h +++ b/test/cpp/api/support.h @@ -4,8 +4,8 @@ #include -#include #include +#include #include #include #include @@ -53,8 +53,10 @@ struct WarningCapture : public WarningHandler { return c10::Join("\n", messages_); } - void process(const SourceLocation& source_location, const std::string& msg, const bool /*verbatim*/) - override { + void process( + const SourceLocation& source_location, + const std::string& msg, + const bool /*verbatim*/) override { messages_.push_back(msg); } @@ -67,22 +69,30 @@ inline bool pointer_equal(at::Tensor first, at::Tensor second) { return first.data_ptr() == second.data_ptr(); } -// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor)` branch -// in `TestCase.assertEqual` in torch/testing/_internal/common_utils.py -inline void assert_tensor_equal(at::Tensor a, at::Tensor b, bool allow_inf=false) { +// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y, +// torch.Tensor)` branch in `TestCase.assertEqual` in +// torch/testing/_internal/common_utils.py +inline void assert_tensor_equal( + at::Tensor a, + at::Tensor b, + bool allow_inf = false) { ASSERT_TRUE(a.sizes() == b.sizes()); if (a.numel() > 0) { - if (a.device().type() == torch::kCPU && (a.scalar_type() == torch::kFloat16 || a.scalar_type() == torch::kBFloat16)) { + if (a.device().type() == torch::kCPU && + (a.scalar_type() == torch::kFloat16 || + a.scalar_type() == torch::kBFloat16)) { // CPU half and bfloat16 tensors don't have the methods we need below a = a.to(torch::kFloat32); } - if (a.device().type() == torch::kCUDA && a.scalar_type() == torch::kBFloat16) { + if (a.device().type() == torch::kCUDA && + a.scalar_type() == torch::kBFloat16) { // CUDA bfloat16 tensors don't have the methods we need below a = a.to(torch::kFloat32); } b = b.to(a); - if ((a.scalar_type() == torch::kBool) != (b.scalar_type() == torch::kBool)) { + if ((a.scalar_type() == torch::kBool) != + (b.scalar_type() == torch::kBool)) { TORCH_CHECK(false, "Was expecting both tensors to be bool type."); } else { if (a.scalar_type() == torch::kBool && b.scalar_type() == torch::kBool) { @@ -116,8 +126,9 @@ inline void assert_tensor_equal(at::Tensor a, at::Tensor b, bool allow_inf=false } } -// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor)` branch -// in `TestCase.assertNotEqual` in torch/testing/_internal/common_utils.py +// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y, +// torch.Tensor)` branch in `TestCase.assertNotEqual` in +// torch/testing/_internal/common_utils.py inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) { if (x.sizes() != y.sizes()) { return; @@ -139,7 +150,9 @@ inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) { } } -inline int count_substr_occurrences(const std::string& str, const std::string& substr) { +inline int count_substr_occurrences( + const std::string& str, + const std::string& substr) { int count = 0; size_t pos = str.find(substr); @@ -154,12 +167,14 @@ inline int count_substr_occurrences(const std::string& str, const std::string& s // A RAII, thread local (!) guard that changes default dtype upon // construction, and sets it back to the original dtype upon destruction. // -// Usage of this guard is synchronized across threads, so that at any given time, -// only one guard can take effect. +// Usage of this guard is synchronized across threads, so that at any given +// time, only one guard can take effect. struct AutoDefaultDtypeMode { static std::mutex default_dtype_mutex; - AutoDefaultDtypeMode(c10::ScalarType default_dtype) : prev_default_dtype(torch::typeMetaToScalarType(torch::get_default_dtype())) { + AutoDefaultDtypeMode(c10::ScalarType default_dtype) + : prev_default_dtype( + torch::typeMetaToScalarType(torch::get_default_dtype())) { default_dtype_mutex.lock(); torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype)); } @@ -170,11 +185,13 @@ struct AutoDefaultDtypeMode { c10::ScalarType prev_default_dtype; }; - -inline void assert_tensor_creation_meta(torch::Tensor& x, torch::autograd::CreationMeta creation_meta) { +inline void assert_tensor_creation_meta( + torch::Tensor& x, + torch::autograd::CreationMeta creation_meta) { auto autograd_meta = x.unsafeGetTensorImpl()->autograd_meta(); TORCH_CHECK(autograd_meta); - auto view_meta = static_cast(autograd_meta); + auto view_meta = + static_cast(autograd_meta); TORCH_CHECK(view_meta->has_bw_view()); ASSERT_EQ(view_meta->get_creation_meta(), creation_meta); } diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp index eac3a04205ddb3..78d629f97ef76d 100644 --- a/test/cpp/api/tensor.cpp +++ b/test/cpp/api/tensor.cpp @@ -179,7 +179,9 @@ TEST(TensorTest, AtTensorCtorScalar) { ASSERT_EQ(tensor.dtype(), at::kComplexFloat); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 2.0))); - tensor = at::tensor(c10::complex(1.0, 2.0), at::dtype(at::kComplexFloat)) + 0.5; + tensor = + at::tensor(c10::complex(1.0, 2.0), at::dtype(at::kComplexFloat)) + + 0.5; ASSERT_EQ(tensor.numel(), 1); ASSERT_EQ(tensor.dtype(), at::kComplexFloat); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 2.0))); @@ -189,7 +191,9 @@ TEST(TensorTest, AtTensorCtorScalar) { ASSERT_EQ(tensor.dtype(), at::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 2.0))); - tensor = at::tensor(c10::complex(1.0, 2.0), at::dtype(at::kComplexDouble)) + 0.5; + tensor = + at::tensor(c10::complex(1.0, 2.0), at::dtype(at::kComplexDouble)) + + 0.5; ASSERT_EQ(tensor.numel(), 1); ASSERT_EQ(tensor.dtype(), at::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 2.0))); @@ -217,14 +221,20 @@ TEST(TensorTest, AtTensorCtorSingleDim) { ASSERT_TRUE(almost_equal(tensor[1], 2.25)); ASSERT_TRUE(almost_equal(tensor[2], 3.125)); - tensor = at::tensor({c10::complex(1.5, 0.15), c10::complex(1.5, 0.15), c10::complex(3.125, 0.3125)}); + tensor = at::tensor( + {c10::complex(1.5, 0.15), + c10::complex(1.5, 0.15), + c10::complex(3.125, 0.3125)}); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), at::kComplexFloat); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 0.15))); ASSERT_TRUE(almost_equal(tensor[1], c10::complex(1.5, 0.15))); ASSERT_TRUE(almost_equal(tensor[2], c10::complex(3.125, 0.3125))); - tensor = at::tensor({c10::complex(1.5, 0.15), c10::complex(1.5, 0.15), c10::complex(3.125, 0.3125)}); + tensor = at::tensor( + {c10::complex(1.5, 0.15), + c10::complex(1.5, 0.15), + c10::complex(3.125, 0.3125)}); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), at::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 0.15))); @@ -246,14 +256,20 @@ TEST(TensorTest, AtTensorCtorSingleDim) { ASSERT_TRUE(almost_equal(tensor[1], 2.25)); ASSERT_TRUE(almost_equal(tensor[2], 3.125)); - tensor = at::tensor(std::vector>({c10::complex(1.5, 0.15), c10::complex(1.5, 0.15), c10::complex(3.125, 0.3125)})); + tensor = at::tensor(std::vector>( + {c10::complex(1.5, 0.15), + c10::complex(1.5, 0.15), + c10::complex(3.125, 0.3125)})); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), at::kComplexFloat); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 0.15))); ASSERT_TRUE(almost_equal(tensor[1], c10::complex(1.5, 0.15))); ASSERT_TRUE(almost_equal(tensor[2], c10::complex(3.125, 0.3125))); - tensor = at::tensor(std::vector>({c10::complex(1.5, 0.15), c10::complex(1.5, 0.15), c10::complex(3.125, 0.3125)})); + tensor = at::tensor(std::vector>( + {c10::complex(1.5, 0.15), + c10::complex(1.5, 0.15), + c10::complex(3.125, 0.3125)})); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), at::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5, 0.15))); @@ -277,9 +293,16 @@ TEST(TensorTest, AtTensorCtorSingleDim) { } std::vector> x = { - {1.1, -1.1}, {2.2, -2.2}, {3.3, -3.3}, {4.4, -4.4}, {5.5, -5.5}, - {6.6, -6.6}, {7.7, -7.7}, {8.8, -8.8}, {9.9, -9.9}, {10.0, -10.0} - }; + {1.1, -1.1}, + {2.2, -2.2}, + {3.3, -3.3}, + {4.4, -4.4}, + {5.5, -5.5}, + {6.6, -6.6}, + {7.7, -7.7}, + {8.8, -8.8}, + {9.9, -9.9}, + {10.0, -10.0}}; tensor = at::tensor(x); ASSERT_EQ(tensor.numel(), x.size()); ASSERT_EQ(tensor.dtype(), at::kComplexDouble); @@ -289,7 +312,8 @@ TEST(TensorTest, AtTensorCtorSingleDim) { } TEST(TensorTest, AtTensorCastRealToComplex) { - auto tensor = at::tensor(std::vector({1.5, 2.5, 3.5}), at::kComplexDouble); + auto tensor = + at::tensor(std::vector({1.5, 2.5, 3.5}), at::kComplexDouble); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), at::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5))); @@ -311,16 +335,21 @@ TEST(TensorTest, AtTensorCastRealToComplex) { TEST(TensorTest, AtTensorCastComplexToRealErrorChecks) { { - ASSERT_THROWS_WITH(at::tensor(c10::complex(0.1, 0.2), at::kFloat), - "\"tensor_cpu\" not implemented for 'Float'"); + ASSERT_THROWS_WITH( + at::tensor(c10::complex(0.1, 0.2), at::kFloat), + "\"tensor_cpu\" not implemented for 'Float'"); } { - ASSERT_THROWS_WITH(at::tensor({c10::complex(0.1, 0.2)}, at::kFloat), - "\"tensor_cpu\" not implemented for 'Float'"); + ASSERT_THROWS_WITH( + at::tensor({c10::complex(0.1, 0.2)}, at::kFloat), + "\"tensor_cpu\" not implemented for 'Float'"); } { - ASSERT_THROWS_WITH(at::tensor(std::vector>{c10::complex(0.1, 0.2)}, at::kFloat), - "\"tensor_cpu\" not implemented for 'Float'"); + ASSERT_THROWS_WITH( + at::tensor( + std::vector>{c10::complex(0.1, 0.2)}, + at::kFloat), + "\"tensor_cpu\" not implemented for 'Float'"); } } @@ -332,7 +361,8 @@ TEST(TensorTest, TorchTensorCtorScalarIntegralType) { ASSERT_EQ(tensor.item(), 123); } -void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorScalarFloatingType_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); auto tensor = torch::tensor(123.456f); @@ -355,8 +385,10 @@ void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType defau } TEST(TensorTest, TorchTensorCtorScalarFloatingType) { - test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorScalarFloatingType_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorScalarFloatingType_expected_dtype( + /*default_dtype=*/torch::kDouble); } TEST(TensorTest, TorchTensorCtorScalarBoolType) { @@ -415,7 +447,8 @@ TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) { ASSERT_TRUE(exactly_equal(tensor[2], 3)); } -void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorSingleDimFloatingType_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); auto tensor = torch::tensor({1.5, 2.25, 3.125}); @@ -466,8 +499,10 @@ void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType de } TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) { - test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorSingleDimFloatingType_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorSingleDimFloatingType_expected_dtype( + /*default_dtype=*/torch::kDouble); } TEST(TensorTest, TorchTensorCtorSingleDimBoolType) { @@ -493,74 +528,91 @@ TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) { auto tensor = torch::tensor({{1, 2}}); ASSERT_EQ(tensor.dtype(), torch::kLong); ASSERT_EQ(tensor.sizes(), std::vector({1, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{1}, {2}}); ASSERT_EQ(tensor.dtype(), torch::kLong); ASSERT_EQ(tensor.sizes(), std::vector({2, 1})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{{1, 2}}}); ASSERT_EQ(tensor.dtype(), torch::kLong); ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{{1}, {2}}}); ASSERT_EQ(tensor.dtype(), torch::kLong); ASSERT_EQ(tensor.sizes(), std::vector({1, 2, 1})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{1, 2}, {3, 4}}); ASSERT_EQ(tensor.dtype(), torch::kLong); ASSERT_EQ(tensor.sizes(), std::vector({2, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kLong).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 5, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{{{{{{{{{1}}}}}}}}}}); ASSERT_EQ(tensor.dtype(), torch::kLong); - ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); - ASSERT_TRUE(torch::allclose(tensor, torch::full({1}, 1, torch::kLong).view(tensor.sizes()))); + ASSERT_EQ( + tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); + ASSERT_TRUE(torch::allclose( + tensor, torch::full({1}, 1, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { auto tensor = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}}); ASSERT_EQ(tensor.dtype(), torch::kLong); - ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); + ASSERT_EQ( + tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 2})); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } } -void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorMultiDimFloatingType_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); { auto tensor = torch::tensor({{1.0, 2.0}}); ASSERT_EQ(tensor.dtype(), default_dtype); ASSERT_EQ(tensor.sizes(), std::vector({1, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { - auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}}); + auto tensor = torch::tensor( + {{{{{{{{1.0, 2.0, 3.0}}}}}, + {{{{{4.0, 5.0, 6.0}}}}}, + {{{{{7.0, 8.0, 9.0}}}}}}}}); ASSERT_EQ(tensor.dtype(), default_dtype); ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 3, 1, 1, 1, 1, 3})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } } TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) { - test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorMultiDimFloatingType_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorMultiDimFloatingType_expected_dtype( + /*default_dtype=*/torch::kDouble); } TEST(TensorTest, TorchTensorCtorMultiDimBoolType) { @@ -591,43 +643,52 @@ TEST(TensorTest, TorchTensorCtorMultiDimWithOptions) { auto tensor = torch::tensor({{1, 2}}, torch::dtype(torch::kInt)); ASSERT_EQ(tensor.dtype(), torch::kInt); ASSERT_EQ(tensor.sizes(), std::vector({1, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes()))); ASSERT_FALSE(tensor.requires_grad()); } { - auto tensor = torch::tensor({{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); + auto tensor = torch::tensor( + {{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); ASSERT_EQ(tensor.dtype(), torch::kFloat); ASSERT_EQ(tensor.sizes(), std::vector({2, 2})); - ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes()))); + ASSERT_TRUE(torch::allclose( + tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes()))); ASSERT_TRUE(tensor.requires_grad()); } } TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) { { - ASSERT_THROWS_WITH(torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}), - "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1"); + ASSERT_THROWS_WITH( + torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}), + "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1"); } { - ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}), - "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double"); + ASSERT_THROWS_WITH( + torch::tensor({{{1, 2.0}, {1, 2.0}}}), + "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double"); } { - ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}), - "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double"); + ASSERT_THROWS_WITH( + torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}), + "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double"); } { - ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}), - "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int"); + ASSERT_THROWS_WITH( + torch::tensor({{{true}, {2}}}), + "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int"); } { - ASSERT_THROWS_WITH(torch::tensor({{{true, 2}}}), - "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int"); + ASSERT_THROWS_WITH( + torch::tensor({{{true, 2}}}), + "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int"); } } TEST(TensorTest, TorchTensorCastRealToComplex) { - auto tensor = torch::tensor(std::vector({1.5, 2.5, 3.5}), torch::kComplexDouble); + auto tensor = torch::tensor( + std::vector({1.5, 2.5, 3.5}), torch::kComplexDouble); ASSERT_EQ(tensor.numel(), 3); ASSERT_EQ(tensor.dtype(), torch::kComplexDouble); ASSERT_TRUE(almost_equal(tensor[0], c10::complex(1.5))); @@ -649,40 +710,56 @@ TEST(TensorTest, TorchTensorCastRealToComplex) { TEST(TensorTest, TorchTensorCastComplexToRealErrorChecks) { { - ASSERT_THROWS_WITH(torch::tensor(c10::complex(0.1, 0.2), torch::kFloat), - "value cannot be converted to type float without overflow"); + ASSERT_THROWS_WITH( + torch::tensor(c10::complex(0.1, 0.2), torch::kFloat), + "value cannot be converted to type float without overflow"); } { - ASSERT_THROWS_WITH(torch::tensor({c10::complex(0.1, 0.2), c10::complex(0.3, 0.4)}, torch::kFloat), - "value cannot be converted to type float without overflow"); + ASSERT_THROWS_WITH( + torch::tensor( + {c10::complex(0.1, 0.2), c10::complex(0.3, 0.4)}, + torch::kFloat), + "value cannot be converted to type float without overflow"); } { - ASSERT_THROWS_WITH(torch::tensor(std::vector>{c10::complex(0.1, 0.2), c10::complex(0.3, 0.4)}, torch::kFloat), - "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); + ASSERT_THROWS_WITH( + torch::tensor( + std::vector>{ + c10::complex(0.1, 0.2), c10::complex(0.3, 0.4)}, + torch::kFloat), + "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); } } -void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorMultiDim_CUDA_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); auto tensor = torch::tensor( - {{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}}, - torch::dtype(default_dtype).device(torch::kCUDA)); + {{{{{{{{1.0, 2.0, 3.0}}}}}, + {{{{{4.0, 5.0, 6.0}}}}}, + {{{{{7.0, 8.0, 9.0}}}}}}}}, + torch::dtype(default_dtype).device(torch::kCUDA)); ASSERT_TRUE(tensor.device().is_cuda()); ASSERT_EQ(tensor.dtype(), default_dtype); ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 3, 1, 1, 1, 1, 3})); ASSERT_TRUE(torch::allclose( - tensor, - torch::arange(1, 10, default_dtype).view(tensor.sizes()).to(torch::kCUDA))); + tensor, + torch::arange(1, 10, default_dtype) + .view(tensor.sizes()) + .to(torch::kCUDA))); ASSERT_FALSE(tensor.requires_grad()); } TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) { - test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorMultiDim_CUDA_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorMultiDim_CUDA_expected_dtype( + /*default_dtype=*/torch::kDouble); } -void test_TorchTensorCtorZeroSizedDim_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorZeroSizedDim_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); { auto tensor = torch::tensor({}); @@ -729,55 +806,94 @@ void test_TorchTensorCtorZeroSizedDim_expected_dtype(c10::ScalarType default_dty { auto tensor = torch::tensor({{{{{{{{{{}}}}}}}}}}); ASSERT_EQ(tensor.numel(), 0); - ASSERT_EQ(tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 0})); + ASSERT_EQ( + tensor.sizes(), std::vector({1, 1, 1, 1, 1, 1, 1, 1, 1, 0})); ASSERT_EQ(tensor.dtype(), default_dtype); ASSERT_FALSE(tensor.requires_grad()); } } TEST(TensorTest, TorchTensorCtorZeroSizedDim) { - test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorZeroSizedDim_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorZeroSizedDim_expected_dtype( + /*default_dtype=*/torch::kDouble); } -void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype); ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype); - ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype); - ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), default_dtype); + ASSERT_EQ( + torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), + default_dtype); + ASSERT_EQ( + torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), + default_dtype); } TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) { ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong); ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong); - ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong); - ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong); - - test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble); + ASSERT_EQ( + torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong); + ASSERT_EQ( + torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong); + + test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype( + /*default_dtype=*/torch::kDouble); } -void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(c10::ScalarType default_dtype) { +void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype( + c10::ScalarType default_dtype) { AutoDefaultDtypeMode dtype_mode(default_dtype); - ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong); - ASSERT_EQ(torch::tensor(at::ArrayRef({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong); - ASSERT_EQ(torch::tensor(std::vector({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong); - - ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype); - ASSERT_EQ(torch::tensor(at::ArrayRef({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype); - ASSERT_EQ(torch::tensor(std::vector({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype); - - ASSERT_EQ(torch::tensor({1.f, 2.f, 3.f}, torch::TensorOptions()).dtype(), default_dtype); - ASSERT_EQ(torch::tensor(at::ArrayRef({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype); - ASSERT_EQ(torch::tensor(std::vector({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype); + ASSERT_EQ( + torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong); + ASSERT_EQ( + torch::tensor(at::ArrayRef({1, 2, 3}), torch::TensorOptions()) + .dtype(), + torch::kLong); + ASSERT_EQ( + torch::tensor(std::vector({1, 2, 3}), torch::TensorOptions()) + .dtype(), + torch::kLong); + + ASSERT_EQ( + torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), + default_dtype); + ASSERT_EQ( + torch::tensor(at::ArrayRef({1., 2., 3.}), torch::TensorOptions()) + .dtype(), + default_dtype); + ASSERT_EQ( + torch::tensor(std::vector({1., 2., 3.}), torch::TensorOptions()) + .dtype(), + default_dtype); + + ASSERT_EQ( + torch::tensor({1.f, 2.f, 3.f}, torch::TensorOptions()).dtype(), + default_dtype); + ASSERT_EQ( + torch::tensor( + at::ArrayRef({1.f, 2.f, 3.f}), torch::TensorOptions()) + .dtype(), + default_dtype); + ASSERT_EQ( + torch::tensor(std::vector({1.f, 2.f, 3.f}), torch::TensorOptions()) + .dtype(), + default_dtype); } TEST(TensorTest, TorchTensorCtorWithNonDtypeOptions) { - test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kFloat); - test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kDouble); + test_TorchTensorCtorWithNonDtypeOptions_expected_dtype( + /*default_dtype=*/torch::kFloat); + test_TorchTensorCtorWithNonDtypeOptions_expected_dtype( + /*default_dtype=*/torch::kDouble); } void test_Arange_expected_dtype(c10::ScalarType default_dtype) { @@ -796,77 +912,79 @@ TEST(TensorTest, Arange) { } TEST(TensorTest, PrettyPrintTensorDataContainer) { + { ASSERT_EQ(c10::str(torch::detail::TensorDataContainer(1.1)), "1.1"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer(1.1)), - "1.1"); - } - { - ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({1.1, 2.2})), - "{1.1, 2.2}"); + c10::str(torch::detail::TensorDataContainer({1.1, 2.2})), "{1.1, 2.2}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({{1, 2}, {3, 4}})), - "{{1, 2}, {3, 4}}"); + c10::str(torch::detail::TensorDataContainer({{1, 2}, {3, 4}})), + "{{1, 2}, {3, 4}}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}})), - "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}"); + c10::str(torch::detail::TensorDataContainer( + {{{{{{{{1.1, 2.2, 3.3}}}}}, + {{{{{4.4, 5.5, 6.6}}}}}, + {{{{{7.7, 8.8, 9.9}}}}}}}})), + "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1}}}}}}}}}})), - "{{{{{{{{{{1}}}}}}}}}}"); + c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1}}}}}}}}}})), + "{{{{{{{{{{1}}}}}}}}}}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({{{{{{{{{{}}}}}}}}}})), - "{{{{{{{{{{}}}}}}}}}}"); + c10::str(torch::detail::TensorDataContainer({{{{{{{{{{}}}}}}}}}})), + "{{{{{{{{{{}}}}}}}}}}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1, 2}}}}}}}}}})), - "{{{{{{{{{{1, 2}}}}}}}}}}"); + c10::str(torch::detail::TensorDataContainer({{{{{{{{{{1, 2}}}}}}}}}})), + "{{{{{{{{{{1, 2}}}}}}}}}}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer(at::ArrayRef({1.1, 2.2}))), - "{1.1, 2.2}"); + c10::str(torch::detail::TensorDataContainer( + at::ArrayRef({1.1, 2.2}))), + "{1.1, 2.2}"); } { ASSERT_EQ( - c10::str(torch::detail::TensorDataContainer(std::vector({1.1, 2.2}))), - "{1.1, 2.2}"); + c10::str(torch::detail::TensorDataContainer( + std::vector({1.1, 2.2}))), + "{1.1, 2.2}"); } } TEST(TensorTest, TensorDataContainerCallingAccessorOfWrongType) { { ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer(1.1).init_list(), - "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); + torch::detail::TensorDataContainer(1.1).init_list(), + "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer(1.1).tensor(), - "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); + torch::detail::TensorDataContainer(1.1).tensor(), + "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); } { ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer({1.1, 2.2}).scalar(), - "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); + torch::detail::TensorDataContainer({1.1, 2.2}).scalar(), + "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer({1.1, 2.2}).tensor(), - "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); + torch::detail::TensorDataContainer({1.1, 2.2}).tensor(), + "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); } { ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer(at::ArrayRef({1.1, 2.2})).scalar(), - "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); + torch::detail::TensorDataContainer(at::ArrayRef({1.1, 2.2})) + .scalar(), + "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); ASSERT_THROWS_WITH( - torch::detail::TensorDataContainer(at::ArrayRef({1.1, 2.2})).init_list(), - "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); + torch::detail::TensorDataContainer(at::ArrayRef({1.1, 2.2})) + .init_list(), + "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); } } @@ -968,17 +1086,17 @@ TEST(TensorTest, BackwardAndGrad) { } TEST(TensorTest, BackwardCreatesOnesGrad) { - const auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true)); + const auto x = + torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true)); x.backward(); - ASSERT_TRUE(torch::equal(x.grad(), - torch::ones_like(x))); + ASSERT_TRUE(torch::equal(x.grad(), torch::ones_like(x))); } TEST(TensorTest, BackwardNonScalarOutputs) { auto x = torch::randn({5, 5}, torch::requires_grad()); auto y = x * x; - ASSERT_THROWS_WITH(y.backward(), - "grad can be implicitly created only for scalar outputs"); + ASSERT_THROWS_WITH( + y.backward(), "grad can be implicitly created only for scalar outputs"); } TEST(TensorTest, IsLeaf) { @@ -1040,15 +1158,18 @@ TEST(TensorTest, RequiresGradInplace) { ASSERT_TRUE(x.requires_grad()); auto y = x * x; - ASSERT_THROWS_WITH(y.requires_grad_(false), - "you can only change requires_grad flags of leaf variables."); + ASSERT_THROWS_WITH( + y.requires_grad_(false), + "you can only change requires_grad flags of leaf variables."); x.requires_grad_(false); ASSERT_FALSE(x.requires_grad()); - const auto int_tensor = torch::tensor({5}, at::TensorOptions().dtype(torch::kInt)); - ASSERT_THROWS_WITH(int_tensor.requires_grad_(true), - "Only Tensors of floating point and complex dtype can require gradients"); + const auto int_tensor = + torch::tensor({5}, at::TensorOptions().dtype(torch::kInt)); + ASSERT_THROWS_WITH( + int_tensor.requires_grad_(true), + "Only Tensors of floating point and complex dtype can require gradients"); } TEST(TensorTest, StdDimension) { @@ -1078,21 +1199,14 @@ TEST(TensorTest, ReshapeAlias) { // that it matches the behavior of as_strided and view. auto x = torch::randn({3, 3}); ASSERT_TRUE(torch::equal( - torch::_reshape_alias(x, {2, 2}, {1, 2}), - torch::as_strided(x, {2, 2}, {1, 2}) - )); - ASSERT_TRUE(torch::equal( - torch::_reshape_alias(x, {9}, {1}), - x.view({-1}) - )); + torch::_reshape_alias(x, {2, 2}, {1, 2}), + torch::as_strided(x, {2, 2}, {1, 2}))); + ASSERT_TRUE(torch::equal(torch::_reshape_alias(x, {9}, {1}), x.view({-1}))); // Test that the backward works fine. auto y = torch::randn({3, 3}, torch::requires_grad(true)); auto z = torch::clone(y).detach().requires_grad_(true); (y * y).view({-1}).mean().backward(); torch::_reshape_alias((z * z), {9}, {1}).mean().backward(); - ASSERT_TRUE(torch::equal( - y.grad(), - z.grad() - )); + ASSERT_TRUE(torch::equal(y.grad(), z.grad())); } diff --git a/test/cpp/api/tensor_cuda.cpp b/test/cpp/api/tensor_cuda.cpp index fdd1189302cead..7a89f0ae536786 100644 --- a/test/cpp/api/tensor_cuda.cpp +++ b/test/cpp/api/tensor_cuda.cpp @@ -82,9 +82,9 @@ TEST(TensorTest, ToTensorAndTensorAttributes_MultiCUDA) { REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided); } - TEST(TensorTest, ToDoesNotCopyWhenOptionsAreAllTheSame_CUDA) { - auto tensor = at::empty({3, 4}, at::TensorOptions(at::kFloat).device(at::Device("cuda"))); + auto tensor = at::empty( + {3, 4}, at::TensorOptions(at::kFloat).device(at::Device("cuda"))); auto hopefully_not_copy = tensor.to(tensor.options()); ASSERT_EQ(hopefully_not_copy.data_ptr(), tensor.data_ptr()); hopefully_not_copy = tensor.to(at::kFloat); @@ -117,11 +117,9 @@ TEST(TensorTest, ToDeviceAndDtype_MultiCUDA) { TEST(TensorTest, MagmaInitializesCorrectly_CUDA) { // Any tensor will work here as long as it's invertible // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float data[] = { 1, 1, 1, 0, - 0, 3, 1, 2, - 2, 3, 1, 0, - 1, 0, 2, 1 }; - auto tensor = at::from_blob(data, {4, 4}, at::TensorOptions(at::kFloat)).cuda(); + float data[] = {1, 1, 1, 0, 0, 3, 1, 2, 2, 3, 1, 0, 1, 0, 2, 1}; + auto tensor = + at::from_blob(data, {4, 4}, at::TensorOptions(at::kFloat)).cuda(); if (at::hasMAGMA()) { at::inverse(tensor); } diff --git a/test/cpp/api/tensor_flatten.cpp b/test/cpp/api/tensor_flatten.cpp index 36888e8b03565b..d510891ee11b2e 100644 --- a/test/cpp/api/tensor_flatten.cpp +++ b/test/cpp/api/tensor_flatten.cpp @@ -1,9 +1,9 @@ #include #include -#include -#include #include +#include +#include using namespace torch::test; @@ -12,9 +12,11 @@ TEST(UnflattenDenseTensorTest, TestEmptyTensor) { auto emptyTensor2 = at::tensor(std::vector()); auto tensor1 = at::tensor({1, 2, 3}); auto tensor2 = at::tensor({4, 5}); - auto tensorList = std::vector({tensor1, emptyTensor1, emptyTensor2, tensor2}); + auto tensorList = + std::vector({tensor1, emptyTensor1, emptyTensor2, tensor2}); auto flatTensor = at::tensor({1, 2, 3, 4, 5}); - auto unflatten_results = torch::utils::unflatten_dense_tensors(flatTensor, tensorList); + auto unflatten_results = + torch::utils::unflatten_dense_tensors(flatTensor, tensorList); ASSERT_EQ(unflatten_results.size(), 4); ASSERT_EQ(unflatten_results.at(0).numel(), 3); ASSERT_EQ(unflatten_results.at(1).numel(), 0); @@ -29,11 +31,13 @@ TEST(UnflattenDenseTensorTest, TestEmptyTensor) { // as other non-empty tenosr like unflatten_results.at(3). // after fix, the empty tensor and non-empty tensor do not share the same // storage. - ASSERT_NE(unflatten_results.at(1).data_ptr(), unflatten_results.at(3).data_ptr()); + ASSERT_NE( + unflatten_results.at(1).data_ptr(), unflatten_results.at(3).data_ptr()); unflatten_results.at(1).resize_(1); unflatten_results.at(2).resize_(1); // after resizing the two empty tensors, the resized tensors do not share - // the same storage. without fix in unflatten_dense_tensors() for empty tensors, - // the resized tensors will share the same storage. - ASSERT_NE(unflatten_results.at(1).data_ptr(), unflatten_results.at(2).data_ptr()); + // the same storage. without fix in unflatten_dense_tensors() for empty + // tensors, the resized tensors will share the same storage. + ASSERT_NE( + unflatten_results.at(1).data_ptr(), unflatten_results.at(2).data_ptr()); } diff --git a/test/cpp/api/tensor_indexing.cpp b/test/cpp/api/tensor_indexing.cpp index 740ede02b1e9ff..3a8589df7bfd79 100644 --- a/test/cpp/api/tensor_indexing.cpp +++ b/test/cpp/api/tensor_indexing.cpp @@ -18,7 +18,14 @@ TEST(TensorIndexingTest, Slice) { TEST(TensorIndexingTest, TensorIndex) { { - std::vector indices = {None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}; + std::vector indices = { + None, + "...", + Ellipsis, + 0, + true, + Slice(1, None, 2), + torch::tensor({1, 2})}; ASSERT_TRUE(indices[0].is_none()); ASSERT_TRUE(indices[1].is_ellipsis()); ASSERT_TRUE(indices[2].is_ellipsis()); @@ -35,12 +42,15 @@ TEST(TensorIndexingTest, TensorIndex) { } ASSERT_THROWS_WITH( - TensorIndex(".."), - "Expected \"...\" to represent an ellipsis index, but got \"..\""); + TensorIndex(".."), + "Expected \"...\" to represent an ellipsis index, but got \"..\""); { - std::vector indices = {None, "...", Ellipsis, 0, true, Slice(1, None, 2)}; - ASSERT_EQ(c10::str(indices), c10::str("(None, ..., ..., 0, true, 1:", INDEX_MAX, ":2)")); + std::vector indices = { + None, "...", Ellipsis, 0, true, Slice(1, None, 2)}; + ASSERT_EQ( + c10::str(indices), + c10::str("(None, ..., ..., 0, true, 1:", INDEX_MAX, ":2)")); ASSERT_EQ(c10::str(indices[0]), "None"); ASSERT_EQ(c10::str(indices[1]), "..."); ASSERT_EQ(c10::str(indices[2]), "..."); @@ -49,24 +59,53 @@ TEST(TensorIndexingTest, TensorIndex) { ASSERT_EQ(c10::str(indices[5]), c10::str("1:", INDEX_MAX, ":2")); } - ASSERT_EQ(c10::str(std::vector({Slice()})), c10::str("(0:", INDEX_MAX, ":1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, None)})), c10::str("(0:", INDEX_MAX, ":1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, None, None)})), c10::str("(0:", INDEX_MAX, ":1)")); - - ASSERT_EQ(c10::str(std::vector({Slice(1, None)})), c10::str("(1:", INDEX_MAX, ":1)")); - ASSERT_EQ(c10::str(std::vector({Slice(1, None, None)})), c10::str("(1:", INDEX_MAX, ":1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, 3)})), c10::str("(0:3:1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, 3, None)})), c10::str("(0:3:1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, None, 2)})), c10::str("(0:", INDEX_MAX, ":2)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, None, -1)})), c10::str("(", INDEX_MAX, ":", INDEX_MIN, ":-1)")); - - ASSERT_EQ(c10::str(std::vector({Slice(1, 3)})), c10::str("(1:3:1)")); - ASSERT_EQ(c10::str(std::vector({Slice(1, None, 2)})), c10::str("(1:", INDEX_MAX, ":2)")); - ASSERT_EQ(c10::str(std::vector({Slice(1, None, -1)})), c10::str("(1:", INDEX_MIN, ":-1)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, 3, 2)})), c10::str("(0:3:2)")); - ASSERT_EQ(c10::str(std::vector({Slice(None, 3, -1)})), c10::str("(", INDEX_MAX, ":3:-1)")); - - ASSERT_EQ(c10::str(std::vector({Slice(1, 3, 2)})), c10::str("(1:3:2)")); + ASSERT_EQ( + c10::str(std::vector({Slice()})), + c10::str("(0:", INDEX_MAX, ":1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, None)})), + c10::str("(0:", INDEX_MAX, ":1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, None, None)})), + c10::str("(0:", INDEX_MAX, ":1)")); + + ASSERT_EQ( + c10::str(std::vector({Slice(1, None)})), + c10::str("(1:", INDEX_MAX, ":1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(1, None, None)})), + c10::str("(1:", INDEX_MAX, ":1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, 3)})), + c10::str("(0:3:1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, 3, None)})), + c10::str("(0:3:1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, None, 2)})), + c10::str("(0:", INDEX_MAX, ":2)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, None, -1)})), + c10::str("(", INDEX_MAX, ":", INDEX_MIN, ":-1)")); + + ASSERT_EQ( + c10::str(std::vector({Slice(1, 3)})), c10::str("(1:3:1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(1, None, 2)})), + c10::str("(1:", INDEX_MAX, ":2)")); + ASSERT_EQ( + c10::str(std::vector({Slice(1, None, -1)})), + c10::str("(1:", INDEX_MIN, ":-1)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, 3, 2)})), + c10::str("(0:3:2)")); + ASSERT_EQ( + c10::str(std::vector({Slice(None, 3, -1)})), + c10::str("(", INDEX_MAX, ":3:-1)")); + + ASSERT_EQ( + c10::str(std::vector({Slice(1, 3, 2)})), + c10::str("(1:3:2)")); } TEST(TensorIndexingTest, TestNoIndices) { @@ -74,13 +113,25 @@ TEST(TensorIndexingTest, TestNoIndices) { torch::Tensor value = torch::randn({20, 20}); std::vector indices; - ASSERT_THROWS_WITH(tensor.index({}), "Passing an empty index list to Tensor::index() is not valid syntax"); - ASSERT_THROWS_WITH(tensor.index_put_({}, 1), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); - ASSERT_THROWS_WITH(tensor.index_put_({}, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index({}), + "Passing an empty index list to Tensor::index() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index_put_({}, 1), + "Passing an empty index list to Tensor::index_put_() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index_put_({}, value), + "Passing an empty index list to Tensor::index_put_() is not valid syntax"); - ASSERT_THROWS_WITH(tensor.index(indices), "Passing an empty index list to Tensor::index() is not valid syntax"); - ASSERT_THROWS_WITH(tensor.index_put_(indices, 1), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); - ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index(indices), + "Passing an empty index list to Tensor::index() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index_put_(indices, 1), + "Passing an empty index list to Tensor::index_put_() is not valid syntax"); + ASSERT_THROWS_WITH( + tensor.index_put_(indices, value), + "Passing an empty index list to Tensor::index_put_() is not valid syntax"); } TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) { @@ -95,14 +146,17 @@ TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20})); - torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20})); + torch::Tensor result_with_init_list = + tensor.index_put_({index}, torch::ones({20})); ASSERT_TRUE(result.equal(result_with_init_list)); } { torch::Tensor tensor = torch::randn({20, 20}); torch::Tensor index = torch::arange(10, torch::kLong).cpu(); - torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20})); - torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({1, 20})); + torch::Tensor result = + at::index_put_(tensor, {index}, torch::ones({1, 20})); + torch::Tensor result_with_init_list = + tensor.index_put_({index}, torch::ones({1, 20})); ASSERT_TRUE(result.equal(result_with_init_list)); } } @@ -126,15 +180,19 @@ TEST(TensorIndexingTest, TestNone) { auto v = torch::randn({5, 7, 3}); ASSERT_EQ(v.index({None}).sizes(), torch::IntArrayRef({1, 5, 7, 3})); ASSERT_EQ(v.index({Slice(), None}).sizes(), torch::IntArrayRef({5, 1, 7, 3})); - ASSERT_EQ(v.index({Slice(), None, None}).sizes(), torch::IntArrayRef({5, 1, 1, 7, 3})); + ASSERT_EQ( + v.index({Slice(), None, None}).sizes(), + torch::IntArrayRef({5, 1, 1, 7, 3})); ASSERT_EQ(v.index({"...", None}).sizes(), torch::IntArrayRef({5, 7, 3, 1})); } TEST(TensorIndexingTest, TestStep) { auto v = torch::arange(10); assert_tensor_equal(v.index({Slice(None, None, 1)}), v); - assert_tensor_equal(v.index({Slice(None, None, 2)}), torch::tensor({0, 2, 4, 6, 8})); - assert_tensor_equal(v.index({Slice(None, None, 3)}), torch::tensor({0, 3, 6, 9})); + assert_tensor_equal( + v.index({Slice(None, None, 2)}), torch::tensor({0, 2, 4, 6, 8})); + assert_tensor_equal( + v.index({Slice(None, None, 3)}), torch::tensor({0, 3, 6, 9})); assert_tensor_equal(v.index({Slice(None, None, 11)}), torch::tensor({0})); assert_tensor_equal(v.index({Slice(1, 6, 2)}), torch::tensor({1, 3, 5})); } @@ -149,9 +207,12 @@ TEST(TensorIndexingTest, TestStepAssignment) { TEST(TensorIndexingTest, TestBoolIndices) { { auto v = torch::randn({5, 7, 3}); - auto boolIndices = torch::tensor({true, false, true, true, false}, torch::kBool); + auto boolIndices = + torch::tensor({true, false, true, true, false}, torch::kBool); ASSERT_EQ(v.index({boolIndices}).sizes(), torch::IntArrayRef({3, 7, 3})); - assert_tensor_equal(v.index({boolIndices}), torch::stack({v.index({0}), v.index({2}), v.index({3})})); + assert_tensor_equal( + v.index({boolIndices}), + torch::stack({v.index({0}), v.index({2}), v.index({3})})); } { auto v = torch::tensor({true, false, true}, torch::kBool); @@ -161,11 +222,17 @@ TEST(TensorIndexingTest, TestBoolIndices) { { WarningCapture warnings; - ASSERT_EQ(v.index({boolIndices}).sizes(), v.index({uint8Indices}).sizes()); + ASSERT_EQ( + v.index({boolIndices}).sizes(), v.index({uint8Indices}).sizes()); assert_tensor_equal(v.index({boolIndices}), v.index({uint8Indices})); - assert_tensor_equal(v.index({boolIndices}), torch::tensor({true}, torch::kBool)); - - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 2); + assert_tensor_equal( + v.index({boolIndices}), torch::tensor({true}, torch::kBool)); + + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 2); } } } @@ -177,13 +244,13 @@ TEST(TensorIndexingTest, TestBoolIndicesAccumulate) { assert_tensor_equal(y, torch::ones({10, 10})); } - TEST(TensorIndexingTest, TestMultipleBoolIndices) { auto v = torch::randn({5, 7, 3}); // note: these broadcast together and are transposed to the first dim auto mask1 = torch::tensor({1, 0, 1, 1, 0}, torch::kBool); auto mask2 = torch::tensor({1, 1, 1}, torch::kBool); - ASSERT_EQ(v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); + ASSERT_EQ( + v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); } TEST(TensorIndexingTest, TestByteMask) { @@ -196,7 +263,11 @@ TEST(TensorIndexingTest, TestByteMask) { ASSERT_EQ(v.index({mask}).sizes(), torch::IntArrayRef({3, 7, 3})); assert_tensor_equal(v.index({mask}), torch::stack({v[0], v[2], v[3]})); - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 2); + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 2); } } { @@ -214,7 +285,11 @@ TEST(TensorIndexingTest, TestByteMaskAccumulate) { y.index_put_({mask}, y.index({mask}), /*accumulate=*/true); assert_tensor_equal(y, torch::ones({10, 10})); - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 2); + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 2); } } @@ -226,9 +301,14 @@ TEST(TensorIndexingTest, TestMultipleByteMask) { { WarningCapture warnings; - ASSERT_EQ(v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); + ASSERT_EQ( + v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 2); + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 2); } } @@ -242,18 +322,24 @@ TEST(TensorIndexingTest, TestByteMask2d) { TEST(TensorIndexingTest, TestIntIndices) { auto v = torch::randn({5, 7, 3}); - ASSERT_EQ(v.index({torch::tensor({0, 4, 2})}).sizes(), torch::IntArrayRef({3, 7, 3})); - ASSERT_EQ(v.index({Slice(), torch::tensor({0, 4, 2})}).sizes(), torch::IntArrayRef({5, 3, 3})); - ASSERT_EQ(v.index({Slice(), torch::tensor({{0, 1}, {4, 3}})}).sizes(), torch::IntArrayRef({5, 2, 2, 3})); + ASSERT_EQ( + v.index({torch::tensor({0, 4, 2})}).sizes(), + torch::IntArrayRef({3, 7, 3})); + ASSERT_EQ( + v.index({Slice(), torch::tensor({0, 4, 2})}).sizes(), + torch::IntArrayRef({5, 3, 3})); + ASSERT_EQ( + v.index({Slice(), torch::tensor({{0, 1}, {4, 3}})}).sizes(), + torch::IntArrayRef({5, 2, 2, 3})); } - TEST(TensorIndexingTest, TestIntIndices2d) { // From the NumPy indexing example auto x = torch::arange(0, 12, torch::kLong).view({4, 3}); auto rows = torch::tensor({{0, 0}, {3, 3}}); auto columns = torch::tensor({{0, 2}, {0, 2}}); - assert_tensor_equal(x.index({rows, columns}), torch::tensor({{0, 2}, {9, 11}})); + assert_tensor_equal( + x.index({rows, columns}), torch::tensor({{0, 2}, {9, 11}})); } TEST(TensorIndexingTest, TestIntIndicesBroadcast) { @@ -285,20 +371,30 @@ TEST(TensorIndexingTest, TestEmptyNdimIndex) { { auto x = torch::randn({5}, device); assert_tensor_equal( - torch::empty({0, 2}, device), - x.index({torch::empty({0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); + torch::empty({0, 2}, device), + x.index({torch::empty( + {0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); } { auto x = torch::randn({2, 3, 4, 5}, device); assert_tensor_equal( - torch::empty({2, 0, 6, 4, 5}, device), - x.index({Slice(), torch::empty({0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); + torch::empty({2, 0, 6, 4, 5}, device), + x.index( + {Slice(), + torch::empty( + {0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); } { auto x = torch::empty({10, 0}); - ASSERT_EQ(x.index({torch::tensor({1, 2})}).sizes(), torch::IntArrayRef({2, 0})); - ASSERT_EQ(x.index({torch::tensor({}, torch::kLong), torch::tensor({}, torch::kLong)}).sizes(), torch::IntArrayRef({0})); - ASSERT_THROWS_WITH(x.index({Slice(), torch::tensor({0, 1})}), "for dimension with size 0"); + ASSERT_EQ( + x.index({torch::tensor({1, 2})}).sizes(), torch::IntArrayRef({2, 0})); + ASSERT_EQ( + x.index( + {torch::tensor({}, torch::kLong), torch::tensor({}, torch::kLong)}) + .sizes(), + torch::IntArrayRef({0})); + ASSERT_THROWS_WITH( + x.index({Slice(), torch::tensor({0, 1})}), "for dimension with size 0"); } } @@ -307,14 +403,18 @@ TEST(TensorIndexingTest, TestEmptyNdimIndex_CUDA) { { auto x = torch::randn({5}, device); assert_tensor_equal( - torch::empty({0, 2}, device), - x.index({torch::empty({0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); + torch::empty({0, 2}, device), + x.index({torch::empty( + {0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); } { auto x = torch::randn({2, 3, 4, 5}, device); assert_tensor_equal( - torch::empty({2, 0, 6, 4, 5}, device), - x.index({Slice(), torch::empty({0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); + torch::empty({2, 0, 6, 4, 5}, device), + x.index( + {Slice(), + torch::empty( + {0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); } } @@ -322,14 +422,20 @@ TEST(TensorIndexingTest, TestEmptyNdimIndexBool) { torch::Device device(torch::kCPU); auto x = torch::randn({5}, device); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(x.index({torch::empty({0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), c10::Error); + ASSERT_THROW( + x.index({torch::empty( + {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), + c10::Error); } TEST(TensorIndexingTest, TestEmptyNdimIndexBool_CUDA) { torch::Device device(torch::kCUDA); auto x = torch::randn({5}, device); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(x.index({torch::empty({0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), c10::Error); + ASSERT_THROW( + x.index({torch::empty( + {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), + c10::Error); } TEST(TensorIndexingTest, TestEmptySlice) { @@ -385,8 +491,8 @@ TEST(TensorIndexingTest, TestIndexSetitemBoolsSlices) { std::vector tensors = {torch::randn({2, 3}), torch::tensor(3)}; for (auto& a : tensors) { - // prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s - // (some of these ops already prefix a 1 to the size) + // prefix with a 1,1, to ensure we are compatible with numpy which cuts off + // prefix 1s (some of these ops already prefix a 1 to the size) auto neg_ones = torch::ones_like(a) * -1; auto neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0); a.index_put_({true}, neg_ones_expanded); @@ -412,8 +518,10 @@ TEST(TensorIndexingTest, TestIndexScalarWithBoolMask) { torch::Device device(torch::kCPU); auto a = torch::tensor(1, device); - auto uintMask = torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); - auto boolMask = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); + auto uintMask = + torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); + auto boolMask = + torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); @@ -426,8 +534,10 @@ TEST(TensorIndexingTest, TestIndexScalarWithBoolMask_CUDA) { torch::Device device(torch::kCUDA); auto a = torch::tensor(1, device); - auto uintMask = torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); - auto boolMask = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); + auto uintMask = + torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); + auto boolMask = + torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); @@ -441,10 +551,7 @@ TEST(TensorIndexingTest, TestSetitemExpansionError) { auto a = torch::randn({2, 3}); // check prefix with non-1s doesn't work std::vector tensor_sizes{5, 1}; - tensor_sizes.insert( - tensor_sizes.end(), - a.sizes().begin(), - a.sizes().end()); + tensor_sizes.insert(tensor_sizes.end(), a.sizes().begin(), a.sizes().end()); auto a_expanded = a.expand(tensor_sizes); // NumPy: ValueError // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) @@ -467,7 +574,8 @@ TEST(TensorIndexingTest, TestGetitemScalars) { // indexing by a scalar should slice (not copy) ASSERT_EQ(a.index({0, 1}).data_ptr(), a.index({zero, one}).data_ptr()); ASSERT_EQ(a.index({1}).data_ptr(), a.index({one.to(torch::kInt)}).data_ptr()); - ASSERT_EQ(a.index({1}).data_ptr(), a.index({one.to(torch::kShort)}).data_ptr()); + ASSERT_EQ( + a.index({1}).data_ptr(), a.index({one.to(torch::kShort)}).data_ptr()); // scalar indexed with scalar auto r = torch::randn({}); @@ -506,8 +614,11 @@ TEST(TensorIndexingTest, TestSetitemScalars) { TEST(TensorIndexingTest, TestBasicAdvancedCombined) { // From the NumPy indexing example auto x = torch::arange(0, 12).to(torch::kLong).view({4, 3}); - assert_tensor_equal(x.index({Slice(1, 2), Slice(1, 3)}), x.index({Slice(1, 2), torch::tensor({1, 2})})); - assert_tensor_equal(x.index({Slice(1, 2), Slice(1, 3)}), torch::tensor({{4, 5}})); + assert_tensor_equal( + x.index({Slice(1, 2), Slice(1, 3)}), + x.index({Slice(1, 2), torch::tensor({1, 2})})); + assert_tensor_equal( + x.index({Slice(1, 2), Slice(1, 3)}), torch::tensor({{4, 5}})); // Check that it is a copy { @@ -548,7 +659,11 @@ TEST(TensorIndexingTest, TestByteTensorAssignment) { x.index_put_({b}, value); - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 1); + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 1); } assert_tensor_equal(x.index({0}), value); @@ -568,19 +683,24 @@ TEST(TensorIndexingTest, TestVariableSlicing) { TEST(TensorIndexingTest, TestEllipsisTensor) { auto x = torch::arange(0, 9).to(torch::kLong).view({3, 3}); auto idx = torch::tensor({0, 2}); - assert_tensor_equal(x.index({"...", idx}), torch::tensor({{0, 2}, - {3, 5}, - {6, 8}})); - assert_tensor_equal(x.index({idx, "..."}), torch::tensor({{0, 1, 2}, - {6, 7, 8}})); + assert_tensor_equal( + x.index({"...", idx}), torch::tensor({{0, 2}, {3, 5}, {6, 8}})); + assert_tensor_equal( + x.index({idx, "..."}), torch::tensor({{0, 1, 2}, {6, 7, 8}})); } TEST(TensorIndexingTest, TestOutOfBoundIndex) { auto x = torch::arange(0, 100).view({2, 5, 10}); - ASSERT_THROWS_WITH(x.index({0, 5}), "index 5 is out of bounds for dimension 1 with size 5"); - ASSERT_THROWS_WITH(x.index({4, 5}), "index 4 is out of bounds for dimension 0 with size 2"); - ASSERT_THROWS_WITH(x.index({0, 1, 15}), "index 15 is out of bounds for dimension 2 with size 10"); - ASSERT_THROWS_WITH(x.index({Slice(), Slice(), 12}), "index 12 is out of bounds for dimension 2 with size 10"); + ASSERT_THROWS_WITH( + x.index({0, 5}), "index 5 is out of bounds for dimension 1 with size 5"); + ASSERT_THROWS_WITH( + x.index({4, 5}), "index 4 is out of bounds for dimension 0 with size 2"); + ASSERT_THROWS_WITH( + x.index({0, 1, 15}), + "index 15 is out of bounds for dimension 2 with size 10"); + ASSERT_THROWS_WITH( + x.index({Slice(), Slice(), 12}), + "index 12 is out of bounds for dimension 2 with size 10"); } TEST(TensorIndexingTest, TestZeroDimIndex) { @@ -595,7 +715,8 @@ TEST(TensorIndexingTest, TestZeroDimIndex) { } // The tests below are from NumPy test_indexing.py with some modifications to -// make them compatible with libtorch. It's licensed under the BDS license below: +// make them compatible with libtorch. It's licensed under the BDS license +// below: // // Copyright (c) 2005-2017, NumPy Developers. // All rights reserved. @@ -637,10 +758,13 @@ TEST(NumpyTests, TestNoneIndex) { TEST(NumpyTests, TestEmptyFancyIndex) { // Empty list index creates an empty array auto a = torch::tensor({1, 2, 3}); - assert_tensor_equal(a.index({torch::tensor({}, torch::kLong)}), torch::tensor({})); + assert_tensor_equal( + a.index({torch::tensor({}, torch::kLong)}), torch::tensor({})); auto b = torch::tensor({}).to(torch::kLong); - assert_tensor_equal(a.index({torch::tensor({}, torch::kLong)}), torch::tensor({}, torch::kLong)); + assert_tensor_equal( + a.index({torch::tensor({}, torch::kLong)}), + torch::tensor({}, torch::kLong)); b = torch::tensor({}).to(torch::kFloat); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) @@ -648,9 +772,7 @@ TEST(NumpyTests, TestEmptyFancyIndex) { } TEST(NumpyTests, TestEllipsisIndex) { - auto a = torch::tensor({{1, 2, 3}, - {4, 5, 6}, - {7, 8, 9}}); + auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); ASSERT_FALSE(a.index({"..."}).is_same(a)); assert_tensor_equal(a.index({"..."}), a); // `a[...]` was `a` in numpy <1.9. @@ -674,9 +796,7 @@ TEST(NumpyTests, TestEllipsisIndex) { TEST(NumpyTests, TestSingleIntIndex) { // Single integer index selects one row - auto a = torch::tensor({{1, 2, 3}, - {4, 5, 6}, - {7, 8, 9}}); + auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); assert_tensor_equal(a.index({0}), torch::tensor({1, 2, 3})); assert_tensor_equal(a.index({-1}), torch::tensor({7, 8, 9})); @@ -684,18 +804,18 @@ TEST(NumpyTests, TestSingleIntIndex) { // Index out of bounds produces IndexError // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_THROW(a.index({1 << 30}), c10::Error); - // NOTE: According to the standard (http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0543r0.html), - // for signed integers, if during the evaluation of an expression, the result is not mathematically defined - // or not in the range of representable values for its type, the behavior is undefined. - // Therefore, there is no way to check for index overflow case because it might not throw exception. + // NOTE: According to the standard + // (http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0543r0.html), for + // signed integers, if during the evaluation of an expression, the result is + // not mathematically defined or not in the range of representable values for + // its type, the behavior is undefined. Therefore, there is no way to check + // for index overflow case because it might not throw exception. // ASSERT_THROW(a(1 << 64), c10::Error); } TEST(NumpyTests, TestSingleBoolIndex) { // Single boolean index - auto a = torch::tensor({{1, 2, 3}, - {4, 5, 6}, - {7, 8, 9}}); + auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); assert_tensor_equal(a.index({true}), a.index({None})); assert_tensor_equal(a.index({false}), a.index({None}).index({Slice(0, 0)})); @@ -717,7 +837,11 @@ TEST(NumpyTests, TestBooleanShapeMismatch) { ASSERT_THROWS_WITH(arr.index({index}), "mask"); ASSERT_THROWS_WITH(arr.index({Slice(), index}), "mask"); - ASSERT_EQ(count_substr_occurrences(warnings.str(), "indexing with dtype torch.uint8 is now deprecated"), 2); + ASSERT_EQ( + count_substr_occurrences( + warnings.str(), + "indexing with dtype torch.uint8 is now deprecated"), + 2); } } @@ -749,28 +873,32 @@ TEST(NumpyTests, TestBooleanAssignmentValueMismatch) { TEST(NumpyTests, TestBooleanIndexingTwodim) { // Indexing a 2-dimensional array with // 2-dimensional boolean array - auto a = torch::tensor({{1, 2, 3}, - {4, 5, 6}, - {7, 8, 9}}); - auto b = torch::tensor({{true, false, true}, - {false, true, false}, - {true, false, true}}); + auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + auto b = torch::tensor( + {{true, false, true}, {false, true, false}, {true, false, true}}); assert_tensor_equal(a.index({b}), torch::tensor({1, 3, 5, 7, 9})); assert_tensor_equal(a.index({b.index({1})}), torch::tensor({{4, 5, 6}})); assert_tensor_equal(a.index({b.index({0})}), a.index({b.index({2})})); // boolean assignment a.index_put_({b}, 0); - assert_tensor_equal(a, torch::tensor({{0, 2, 0}, - {4, 0, 6}, - {0, 8, 0}})); + assert_tensor_equal(a, torch::tensor({{0, 2, 0}, {4, 0, 6}, {0, 8, 0}})); } TEST(NumpyTests, TestBooleanIndexingWeirdness) { // Weird boolean indexing things auto a = torch::ones({2, 3, 4}); - ASSERT_EQ(a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4})); - assert_tensor_equal(torch::ones({1, 2}), a.index({true, torch::tensor({0, 1}), true, true, torch::tensor({1}), torch::tensor({{2}})})); + ASSERT_EQ( + a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4})); + assert_tensor_equal( + torch::ones({1, 2}), + a.index( + {true, + torch::tensor({0, 1}), + true, + true, + torch::tensor({1}), + torch::tensor({{2}})})); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_THROW(a.index({false, torch::tensor({0, 1}), "..."}), c10::Error); } @@ -780,25 +908,35 @@ TEST(NumpyTests, TestBooleanIndexingWeirdnessTensors) { auto false_tensor = torch::tensor(false); auto true_tensor = torch::tensor(true); auto a = torch::ones({2, 3, 4}); - ASSERT_EQ(a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4})); - assert_tensor_equal(torch::ones({1, 2}), a.index({true_tensor, torch::tensor({0, 1}), true_tensor, true_tensor, torch::tensor({1}), torch::tensor({{2}})})); + ASSERT_EQ( + a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4})); + assert_tensor_equal( + torch::ones({1, 2}), + a.index( + {true_tensor, + torch::tensor({0, 1}), + true_tensor, + true_tensor, + torch::tensor({1}), + torch::tensor({{2}})})); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(a.index({false_tensor, torch::tensor({0, 1}), "..."}), c10::Error); + ASSERT_THROW( + a.index({false_tensor, torch::tensor({0, 1}), "..."}), c10::Error); } TEST(NumpyTests, TestBooleanIndexingAlldims) { auto true_tensor = torch::tensor(true); auto a = torch::ones({2, 3}); ASSERT_EQ(a.index({true, true}).sizes(), torch::IntArrayRef({1, 2, 3})); - ASSERT_EQ(a.index({true_tensor, true_tensor}).sizes(), torch::IntArrayRef({1, 2, 3})); + ASSERT_EQ( + a.index({true_tensor, true_tensor}).sizes(), + torch::IntArrayRef({1, 2, 3})); } TEST(NumpyTests, TestBooleanListIndexing) { // Indexing a 2-dimensional array with // boolean lists - auto a = torch::tensor({{1, 2, 3}, - {4, 5, 6}, - {7, 8, 9}}); + auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); auto b = torch::tensor({true, false, false}); auto c = torch::tensor({true, true, false}); assert_tensor_equal(a.index({b}), torch::tensor({{1, 2, 3}})); @@ -818,9 +956,12 @@ TEST(NumpyTests, TestEverythingReturnsViews) { TEST(NumpyTests, TestBroaderrorsIndexing) { auto a = torch::zeros({5, 5}); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(a.index({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}), c10::Error); + ASSERT_THROW( + a.index({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}), c10::Error); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(a.index_put_({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}, 0), c10::Error); + ASSERT_THROW( + a.index_put_({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}, 0), + c10::Error); } TEST(NumpyTests, TestTrivialFancyOutOfBounds) { @@ -842,9 +983,14 @@ TEST(NumpyTests, TestTrivialFancyOutOfBounds) { TEST(NumpyTests, TestIndexIsLarger) { // Simple case of fancy index broadcasting of the index. auto a = torch::zeros({5, 5}); - a.index_put_({torch::tensor({{0}, {1}, {2}}), torch::tensor({0, 1, 2})}, torch::tensor({2., 3., 4.})); - - ASSERT_TRUE((a.index({Slice(None, 3), Slice(None, 3)}) == torch::tensor({2., 3., 4.})).all().item()); + a.index_put_( + {torch::tensor({{0}, {1}, {2}}), torch::tensor({0, 1, 2})}, + torch::tensor({2., 3., 4.})); + + ASSERT_TRUE( + (a.index({Slice(None, 3), Slice(None, 3)}) == torch::tensor({2., 3., 4.})) + .all() + .item()); } TEST(NumpyTests, TestBroadcastSubspace) { diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp index 83bd4a6911a3d8..68486732b1ef4e 100644 --- a/test/cpp/api/tensor_options.cpp +++ b/test/cpp/api/tensor_options.cpp @@ -58,10 +58,12 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) { options = TensorOptions(kInt); REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided); - options = TensorOptions(getDeprecatedTypeProperties(Backend::SparseCPU, kFloat)); + options = + TensorOptions(getDeprecatedTypeProperties(Backend::SparseCPU, kFloat)); REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse); - options = TensorOptions(getDeprecatedTypeProperties(Backend::SparseCPU, kByte)); + options = + TensorOptions(getDeprecatedTypeProperties(Backend::SparseCPU, kByte)); REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse); } @@ -69,7 +71,8 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) { auto options = empty(5, kDouble).options(); REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided); - options = empty(5, getDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options(); + options = empty(5, getDeprecatedTypeProperties(Backend::SparseCPU, kByte)) + .options(); REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse); } diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp index 152fab550f461b..009b54322b73b2 100644 --- a/test/cpp/api/tensor_options_cuda.cpp +++ b/test/cpp/api/tensor_options_cuda.cpp @@ -18,14 +18,14 @@ at::Device CUDADevice(DeviceIndex index) { } // A macro so we don't lose location information when an assertion fails. -#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \ +#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \ ASSERT_EQ(options.device().type(), Device((device_), (index_)).type()); \ - ASSERT_TRUE( \ - options.device().index() == Device((device_), (index_)).index()); \ - ASSERT_EQ(typeMetaToScalarType(options.dtype()), (type_)); \ + ASSERT_TRUE( \ + options.device().index() == Device((device_), (index_)).index()); \ + ASSERT_EQ(typeMetaToScalarType(options.dtype()), (type_)); \ ASSERT_TRUE(options.layout() == (layout_)) -#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \ +#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \ ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \ ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \ ASSERT_EQ(tensor.scalar_type(), (type_)); \ @@ -50,7 +50,8 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) { options = // NOLINTNEXTLINE(bugprone-argument-comment,cppcoreguidelines-avoid-magic-numbers) - getDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5); + getDeprecatedTypeProperties(Backend::SparseCUDA, kFloat) + .options(/*device=*/5); REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse); } @@ -58,7 +59,8 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) { auto options = empty(5, device(kCUDA).dtype(kDouble)).options(); REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided); - options = empty(5, getDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options(); + options = empty(5, getDeprecatedTypeProperties(Backend::SparseCUDA, kByte)) + .options(); REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse); if (torch::cuda::device_count() > 1) { diff --git a/test/cpp/api/transformer.cpp b/test/cpp/api/transformer.cpp index fee99544ceafe6..6062c77f5917d3 100644 --- a/test/cpp/api/transformer.cpp +++ b/test/cpp/api/transformer.cpp @@ -8,9 +8,12 @@ using namespace torch::nn; struct TransformerTest : torch::test::SeedingFixture {}; -// a generic function to set constants for parameters so we have fixed result for deterministic test -template -void set_parameter_to_constants(Model& model, const torch::TensorOptions& tensor_options) { +// a generic function to set constants for parameters so we have fixed result +// for deterministic test +template +void set_parameter_to_constants( + Model& model, + const torch::TensorOptions& tensor_options) { torch::NoGradGuard guard; for (auto& p : model->parameters()) { auto sz = p.view(-1).size(0); @@ -18,21 +21,28 @@ void set_parameter_to_constants(Model& model, const torch::TensorOptions& tensor } } -// a generic function to provide consistent encoder/decoder layer for all the transformer tests -template -T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options, bool use_callable_activation) { +// a generic function to provide consistent encoder/decoder layer for all the +// transformer tests +template +T_LAYER get_a_test_layer( + const torch::TensorOptions& tensor_options, + bool use_callable_activation) { int64_t d_model = 4; int64_t nhead = 2; int64_t dim_feedforward = 16; double dropout = 0.0; - // activation is always ReLU here and it can be adjusted later depending on the usage - T_LAYER layer(T_OPTIONS(d_model, nhead).dim_feedforward(dim_feedforward).dropout(dropout)); + // activation is always ReLU here and it can be adjusted later depending on + // the usage + T_LAYER layer(T_OPTIONS(d_model, nhead) + .dim_feedforward(dim_feedforward) + .dropout(dropout)); if (tensor_options.device() == torch::kCUDA) { layer->to(torch::kCUDA); } if (use_callable_activation) { - layer.get()->options.activation([&](const torch::Tensor& t) {return torch::nn::functional::relu(t);}); + layer.get()->options.activation( + [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); }); } // set constant weights of the model @@ -41,509 +51,712 @@ T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options, bool use_ca return layer; } -void transformer_encoder_layer_test_helper(bool is_cuda, bool use_callable_activation) { +void transformer_encoder_layer_test_helper( + bool is_cuda, + bool use_callable_activation) { // this is a deterministic test for TransformerEncoderLayer torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; - torch::TensorOptions tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device); + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); TransformerEncoderLayer model = - get_a_test_layer( - tensor_options, use_callable_activation); + get_a_test_layer( + tensor_options, use_callable_activation); // relu test case 1 - torch::Tensor encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + torch::Tensor encoder_input = + torch::tensor({{{20, 30, 40, 50}}}, tensor_options); torch::Tensor result = model(encoder_input).detach(); - torch::Tensor ref_output = torch::tensor({{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options); + torch::Tensor ref_output = torch::tensor( + {{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // all 0 values are NOT masked. This should't mask anything torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - // all 1 values are masked. Since there is only 1 input embedding this will result in nan. + // all 1 values are masked. Since there is only 1 input embedding this will + // result in nan. mask = torch::tensor({{1}}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); ASSERT_TRUE(torch::isnan(result).all().item().to()); // relu test case 2 - encoder_input = torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); + encoder_input = + torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); result = model(encoder_input).detach(); - ref_output = torch::tensor({ - {{2.272644, 0.119035, -0.691669, 0.153486}}, - {{2.272644, 0.119035, -0.691669, 0.153486}}}, tensor_options); + ref_output = torch::tensor( + {{{2.272644, 0.119035, -0.691669, 0.153486}}, + {{2.272644, 0.119035, -0.691669, 0.153486}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // all 0 values are NOT masked mask = torch::tensor({{0, 0}}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // mask with 1 and 0 mask = torch::tensor({{1, 0}}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{2.301516, 0.092249, -0.679101, 0.103088}}, - {{2.301516, 0.092249, -0.679101, 0.103088}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{2.301516, 0.092249, -0.679101, 0.103088}}, + {{2.301516, 0.092249, -0.679101, 0.103088}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // relu test case 3 - encoder_input = torch::tensor({ - {{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, tensor_options); + encoder_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(encoder_input).detach(); - ref_output = torch::tensor({ - {{2.428589, 0.020835, -0.602055, -0.085249}, {2.427987, 0.021213, -0.602496, -0.084103}}, - {{2.424689, 0.019155, -0.604793, -0.085672}, {2.413863, 0.022211, -0.612486, -0.072490}}, - {{2.433774, 0.021598, -0.598343, -0.087548}, {2.425104, 0.019748, -0.604515, -0.084839}}, - {{2.436185, 0.022682, -0.596625, -0.087261}, {2.433556, 0.021891, -0.598509, -0.086832}}, - {{2.416246, 0.017512, -0.610712, -0.082961}, {2.422901, 0.024187, -0.606178, -0.074929}}}, tensor_options); + ref_output = torch::tensor( + {{{2.428589, 0.020835, -0.602055, -0.085249}, + {2.427987, 0.021213, -0.602496, -0.084103}}, + {{2.424689, 0.019155, -0.604793, -0.085672}, + {2.413863, 0.022211, -0.612486, -0.072490}}, + {{2.433774, 0.021598, -0.598343, -0.087548}, + {2.425104, 0.019748, -0.604515, -0.084839}}, + {{2.436185, 0.022682, -0.596625, -0.087261}, + {2.433556, 0.021891, -0.598509, -0.086832}}, + {{2.416246, 0.017512, -0.610712, -0.082961}, + {2.422901, 0.024187, -0.606178, -0.074929}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // all 0 values are NOT masked mask = torch::zeros({2, 5}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // mask with 0s and 1s mask[0][1] = 1; mask[1][3] = 1; mask[1][4] = 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{2.429026, 0.020793, -0.601741, -0.085642}, {2.428811, 0.021445, -0.601912, -0.084252}}, - {{2.425009, 0.019155, -0.604566, -0.085899}, {2.415408, 0.02249 , -0.611415, -0.073}}, - {{2.434199, 0.021682, -0.598039, -0.087699}, {2.42598, 0.019941, -0.603896, -0.085091}}, - {{2.436457, 0.022736, -0.59643 , -0.08736}, {2.434021, 0.022093, -0.598179, -0.08679}}, - {{2.416531, 0.017498, -0.610513, -0.083181}, {2.4242, 0.024653, -0.605266, -0.074959}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{2.429026, 0.020793, -0.601741, -0.085642}, + {2.428811, 0.021445, -0.601912, -0.084252}}, + {{2.425009, 0.019155, -0.604566, -0.085899}, + {2.415408, 0.02249, -0.611415, -0.073}}, + {{2.434199, 0.021682, -0.598039, -0.087699}, + {2.42598, 0.019941, -0.603896, -0.085091}}, + {{2.436457, 0.022736, -0.59643, -0.08736}, + {2.434021, 0.022093, -0.598179, -0.08679}}, + {{2.416531, 0.017498, -0.610513, -0.083181}, + {2.4242, 0.024653, -0.605266, -0.074959}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // gelu test case 1 model.get()->options.activation(torch::kGELU); encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); result = model(encoder_input).detach(); - ref_output = torch::tensor({{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options); + ref_output = torch::tensor( + {{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // gelu test case 2 - encoder_input = torch::tensor({ - {{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, tensor_options); + encoder_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(encoder_input); - ref_output = torch::tensor({ - {{2.42163188, 0.03227153, -0.60714219, -0.05908082}, {2.42151276, 0.03302179, -0.60722523, -0.05762651}}, - {{2.41926761, 0.02974034, -0.60879519, -0.0621269}, {2.41626395, 0.03539356, -0.61087842, -0.04978623}}, - {{2.42382808, 0.03218872, -0.6055963, -0.06073591}, {2.41983477, 0.03085259, -0.60840145, -0.06046414}}, - {{2.42500749, 0.03328855, -0.60476388, -0.0595334}, {2.4237977, 0.03290575, -0.60561789, -0.05940082}}, - {{2.41383916, 0.02686345, -0.61256377, -0.06380707}, {2.42000277, 0.03800944, -0.60824798, -0.04754947}}}, - tensor_options); + ref_output = torch::tensor( + {{{2.42163188, 0.03227153, -0.60714219, -0.05908082}, + {2.42151276, 0.03302179, -0.60722523, -0.05762651}}, + {{2.41926761, 0.02974034, -0.60879519, -0.0621269}, + {2.41626395, 0.03539356, -0.61087842, -0.04978623}}, + {{2.42382808, 0.03218872, -0.6055963, -0.06073591}, + {2.41983477, 0.03085259, -0.60840145, -0.06046414}}, + {{2.42500749, 0.03328855, -0.60476388, -0.0595334}, + {2.4237977, 0.03290575, -0.60561789, -0.05940082}}, + {{2.41383916, 0.02686345, -0.61256377, -0.06380707}, + {2.42000277, 0.03800944, -0.60824798, -0.04754947}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); } TEST_F(TransformerTest, TransformerEncoderLayer) { - transformer_encoder_layer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_encoder_layer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ true); + transformer_encoder_layer_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_encoder_layer_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/true); } TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) { - transformer_encoder_layer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_encoder_layer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ true); + transformer_encoder_layer_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_encoder_layer_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/true); } -void transformer_decoder_layer_test_helper(bool is_cuda, bool use_callable_activation){ - +void transformer_decoder_layer_test_helper( + bool is_cuda, + bool use_callable_activation) { torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; - torch::TensorOptions tensor_options = torch::TensorOptions() - .dtype(torch::kFloat32).device(device); + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); - TransformerDecoderLayer model = get_a_test_layer< - TransformerDecoderLayer, - TransformerDecoderLayerOptions>(tensor_options, use_callable_activation); + TransformerDecoderLayer model = + get_a_test_layer( + tensor_options, use_callable_activation); // deterministic input - torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, - tensor_options); - torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, - tensor_options); + torch::Tensor decoder_input = + torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + torch::Tensor memory_input = + torch::tensor({{{60, 70, 80, 90}}}, tensor_options); torch::Tensor result = model(decoder_input, memory_input).detach(); torch::Tensor ref_output = torch::tensor( - {{{2.314351, 0.094805, -0.671322, 0.101977}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, tensor_options); + decoder_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.422245, 0.051716, -0.606338, -0.024756}}, - {{2.422245, 0.051716, -0.606338, -0.024756}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.422245, 0.051716, -0.606338, -0.024756}}, + {{2.422245, 0.051716, -0.606338, -0.024756}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{1, 2, 3, 4}}, - {{5, 6, 7, 8}}}, tensor_options); - memory_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, tensor_options); + decoder_input = + torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); + memory_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.343536, 0.085561, -0.654954, 0.074991}}, - {{2.343536, 0.085561, -0.654954, 0.074991}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - - - // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + ref_output = torch::tensor( + {{{2.343536, 0.085561, -0.654954, 0.074991}}, + {{2.343536, 0.085561, -0.654954, 0.074991}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // key_padding_mask torch::Tensor t_mask = {}; torch::Tensor m_mask = {}; torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // key_padding_mask key_padding_mask[0][2] = 1; key_padding_mask[1][1] = 1; key_padding_mask[1][2] = 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430025, 0.027643, -0.601164, -0.073476}, - {2.4323, 0.029375, -0.599553, -0.071881}}, - {{2.428523, 0.026838, -0.602226, -0.07391}, - {2.432634, 0.029842, -0.599318, -0.071253}}, - {{2.432278, 0.028152, -0.599555, -0.074139}, - {2.432659, 0.029244, -0.599294, -0.072382}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430025, 0.027643, -0.601164, -0.073476}, + {2.4323, 0.029375, -0.599553, -0.071881}}, + {{2.428523, 0.026838, -0.602226, -0.07391}, + {2.432634, 0.029842, -0.599318, -0.071253}}, + {{2.432278, 0.028152, -0.599555, -0.074139}, + {2.432659, 0.029244, -0.599294, -0.072382}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // memory_key_padding_mask torch::Tensor t_key_padding_mask = {}; key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - t_key_padding_mask, key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model( + decoder_input, + memory_input, + t_mask, + m_mask, + t_key_padding_mask, + key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // memory_key_padding_mask key_padding_mask[0][4] = 1; key_padding_mask[1][3] = 1; key_padding_mask[1][4] = 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - t_key_padding_mask, key_padding_mask).detach(); - ref_output = torch::tensor({{{2.429757, 0.027358, -0.601351, -0.073816}, - {2.432692, 0.028583, -0.599263, -0.073634}}, - {{2.428247, 0.02662, -0.602419, -0.074123}, - {2.432657, 0.029055, -0.599293, -0.072732}}, - {{2.431515, 0.027687, -0.600096, -0.074459}, - {2.433075, 0.028543, -0.598987, -0.073985}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - + result = model( + decoder_input, + memory_input, + t_mask, + m_mask, + t_key_padding_mask, + key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.429757, 0.027358, -0.601351, -0.073816}, + {2.432692, 0.028583, -0.599263, -0.073634}}, + {{2.428247, 0.02662, -0.602419, -0.074123}, + {2.432657, 0.029055, -0.599293, -0.072732}}, + {{2.431515, 0.027687, -0.600096, -0.074459}, + {2.433075, 0.028543, -0.598987, -0.073985}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); } -TEST_F(TransformerTest, TransformerDecoderLayer){ - transformer_decoder_layer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_decoder_layer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ true); +TEST_F(TransformerTest, TransformerDecoderLayer) { + transformer_decoder_layer_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_decoder_layer_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/true); } -TEST_F(TransformerTest, TransformerDecoderLayer_CUDA){ - transformer_decoder_layer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_decoder_layer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ true); +TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) { + transformer_decoder_layer_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_decoder_layer_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/true); } -void transformer_decoder_layer_test_helper_gelu(bool is_cuda, bool use_callable_activation) { - +void transformer_decoder_layer_test_helper_gelu( + bool is_cuda, + bool use_callable_activation) { torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; - torch::TensorOptions tensor_options = torch::TensorOptions() - .dtype(torch::kFloat32).device(device); + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); - TransformerDecoderLayer model = get_a_test_layer< - TransformerDecoderLayer, - TransformerDecoderLayerOptions>(tensor_options, use_callable_activation); + TransformerDecoderLayer model = + get_a_test_layer( + tensor_options, use_callable_activation); if (use_callable_activation) { - model.get()->options.activation([&](const torch::Tensor& t) {return torch::nn::functional::gelu(t);}); + model.get()->options.activation( + [&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); }); } else { model.get()->options.activation(torch::kGELU); } // deterministic input - torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, - tensor_options); - torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, - tensor_options); + torch::Tensor decoder_input = + torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + torch::Tensor memory_input = + torch::tensor({{{60, 70, 80, 90}}}, tensor_options); torch::Tensor result = model(decoder_input, memory_input).detach(); torch::Tensor ref_output = torch::tensor( - {{{2.306435, 0.095946, -0.675796, 0.10687}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, - tensor_options); + decoder_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.415448, 0.054389, -0.610932, -0.0156613}}, - {{2.415448, 0.054389, -0.610932, -0.0156613}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.415448, 0.054389, -0.610932, -0.0156613}}, + {{2.415448, 0.054389, -0.610932, -0.0156613}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{1, 2, 3, 4}}, - {{5, 6, 7, 8}}}, - tensor_options); - memory_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, - tensor_options); + decoder_input = + torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); + memory_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.338531, 0.087709, -0.65776, 0.080646}}, - {{2.338531, 0.087709, -0.65776, 0.080646}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - + ref_output = torch::tensor( + {{{2.338531, 0.087709, -0.65776, 0.080646}}, + {{2.338531, 0.087709, -0.65776, 0.080646}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); ref_output = torch::tensor( - {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, - {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, - {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, - {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, - {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, - {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - + {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, + {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, + {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, + {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, + {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, + {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); } TEST_F(TransformerTest, TransformerDecoderLayer_gelu) { - transformer_decoder_layer_test_helper_gelu(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_decoder_layer_test_helper_gelu(/*is_cuda=*/ false, /*use_callable_activation=*/ true); + transformer_decoder_layer_test_helper_gelu( + /*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_decoder_layer_test_helper_gelu( + /*is_cuda=*/false, /*use_callable_activation=*/true); } TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) { - transformer_decoder_layer_test_helper_gelu(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_decoder_layer_test_helper_gelu(/*is_cuda=*/ true, /*use_callable_activation=*/ true); + transformer_decoder_layer_test_helper_gelu( + /*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_decoder_layer_test_helper_gelu( + /*is_cuda=*/true, /*use_callable_activation=*/true); } -void transformer_encoder_test_helper(bool is_cuda, bool use_callable_activation) { +void transformer_encoder_test_helper( + bool is_cuda, + bool use_callable_activation) { // this is a deterministic test for TransformerEncoderLayer torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; - torch::TensorOptions tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device); + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); TransformerEncoderLayer encoder_layer = - get_a_test_layer( - tensor_options, use_callable_activation); + get_a_test_layer( + tensor_options, use_callable_activation); TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1)); if (is_cuda) { model->to(torch::kCUDA); } - torch::Tensor encoder_input = torch::tensor({ - {{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, tensor_options); + torch::Tensor encoder_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); torch::Tensor result = model(encoder_input).detach(); - torch::Tensor ref_output = torch::tensor({ - {{2.428589, 0.020835, -0.602055, -0.085249}, {2.427987, 0.021213, -0.602496, -0.084103}}, - {{2.424689, 0.019155, -0.604793, -0.085672}, {2.413863, 0.022211, -0.612486, -0.072490}}, - {{2.433774, 0.021598, -0.598343, -0.087548}, {2.425104, 0.019748, -0.604515, -0.084839}}, - {{2.436185, 0.022682, -0.596625, -0.087261}, {2.433556, 0.021891, -0.598509, -0.086832}}, - {{2.416246, 0.017512, -0.610712, -0.082961}, {2.422901, 0.024187, -0.606178, -0.074929}}}, tensor_options); + torch::Tensor ref_output = torch::tensor( + {{{2.428589, 0.020835, -0.602055, -0.085249}, + {2.427987, 0.021213, -0.602496, -0.084103}}, + {{2.424689, 0.019155, -0.604793, -0.085672}, + {2.413863, 0.022211, -0.612486, -0.072490}}, + {{2.433774, 0.021598, -0.598343, -0.087548}, + {2.425104, 0.019748, -0.604515, -0.084839}}, + {{2.436185, 0.022682, -0.596625, -0.087261}, + {2.433556, 0.021891, -0.598509, -0.086832}}, + {{2.416246, 0.017512, -0.610712, -0.082961}, + {2.422901, 0.024187, -0.606178, -0.074929}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // all 0 values are NOT masked torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // mask with 0s and 1s mask[0][1] = 1; mask[1][3] = 1; mask[1][4] = 1; - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{2.429026, 0.020793, -0.601741, -0.085642}, {2.428811, 0.021445, -0.601912, -0.084252}}, - {{2.425009, 0.019155, -0.604566, -0.085899}, {2.415408, 0.02249 , -0.611415, -0.073}}, - {{2.434199, 0.021682, -0.598039, -0.087699}, {2.42598, 0.019941, -0.603896, -0.085091}}, - {{2.436457, 0.022736, -0.59643 , -0.08736}, {2.434021, 0.022093, -0.598179, -0.08679}}, - {{2.416531, 0.017498, -0.610513, -0.083181}, {2.4242, 0.024653, -0.605266, -0.074959}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{2.429026, 0.020793, -0.601741, -0.085642}, + {2.428811, 0.021445, -0.601912, -0.084252}}, + {{2.425009, 0.019155, -0.604566, -0.085899}, + {2.415408, 0.02249, -0.611415, -0.073}}, + {{2.434199, 0.021682, -0.598039, -0.087699}, + {2.42598, 0.019941, -0.603896, -0.085091}}, + {{2.436457, 0.022736, -0.59643, -0.08736}, + {2.434021, 0.022093, -0.598179, -0.08679}}, + {{2.416531, 0.017498, -0.610513, -0.083181}, + {2.4242, 0.024653, -0.605266, -0.074959}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // test case 2, multiple layers no norm model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2)); if (is_cuda) { model->to(torch::kCUDA); } - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{2.419051, 0.017446, -0.608738, -0.085003}, {2.419102, 0.017452, -0.608703, -0.085026}}, - {{2.419043, 0.017445, -0.608744, -0.084999}, {2.419052, 0.017446, -0.608738, -0.085004}}, - {{2.419067, 0.017448, -0.608727, -0.085010}, {2.419098, 0.017452, -0.608706, -0.085024}}, - {{2.419072, 0.017449, -0.608724, -0.085012}, {2.419119, 0.017455, -0.608691, -0.085034}}, - {{2.419019, 0.017442, -0.608761, -0.084989}, {2.419075, 0.017449, -0.608722, -0.085014}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{2.419051, 0.017446, -0.608738, -0.085003}, + {2.419102, 0.017452, -0.608703, -0.085026}}, + {{2.419043, 0.017445, -0.608744, -0.084999}, + {2.419052, 0.017446, -0.608738, -0.085004}}, + {{2.419067, 0.017448, -0.608727, -0.085010}, + {2.419098, 0.017452, -0.608706, -0.085024}}, + {{2.419072, 0.017449, -0.608724, -0.085012}, + {2.419119, 0.017455, -0.608691, -0.085034}}, + {{2.419019, 0.017442, -0.608761, -0.084989}, + {2.419075, 0.017449, -0.608722, -0.085014}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6)); if (is_cuda) { model->to(torch::kCUDA); } - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{2.419101, 0.017453, -0.608703, -0.085025}, {2.419101, 0.017453, -0.608704, -0.085025}}, - {{2.419101, 0.017453, -0.608703, -0.085025}, {2.419101, 0.017453, -0.608704, -0.085025}}, - {{2.419101, 0.017453, -0.608703, -0.085025}, {2.419101, 0.017453, -0.608704, -0.085025}}, - {{2.419101, 0.017453, -0.608703, -0.085025}, {2.419101, 0.017453, -0.608704, -0.085025}}, - {{2.419101, 0.017453, -0.608703, -0.085025}, {2.419101, 0.017453, -0.608704, -0.085025}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{2.419101, 0.017453, -0.608703, -0.085025}, + {2.419101, 0.017453, -0.608704, -0.085025}}, + {{2.419101, 0.017453, -0.608703, -0.085025}, + {2.419101, 0.017453, -0.608704, -0.085025}}, + {{2.419101, 0.017453, -0.608703, -0.085025}, + {2.419101, 0.017453, -0.608704, -0.085025}}, + {{2.419101, 0.017453, -0.608703, -0.085025}, + {2.419101, 0.017453, -0.608704, -0.085025}}, + {{2.419101, 0.017453, -0.608703, -0.085025}, + {2.419101, 0.017453, -0.608704, -0.085025}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); // test case 3, multiple layers with norm LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()})); - model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm))); + model = TransformerEncoder( + TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm))); if (is_cuda) { model->to(torch::kCUDA); } - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{1.695949, -0.357635, -0.893077, -0.445238}, {1.695955, -0.357639, -0.893050, -0.445266}}, - {{1.695948, -0.357634, -0.893082, -0.445233}, {1.695950, -0.357635, -0.893077, -0.445238}}, - {{1.695951, -0.357636, -0.893069, -0.445246}, {1.695955, -0.357639, -0.893052, -0.445264}}, - {{1.695952, -0.357636, -0.893066, -0.445249}, {1.695957, -0.357641, -0.893041, -0.445276}}, - {{1.695946, -0.357632, -0.893095, -0.445220}, {1.695952, -0.357637, -0.893065, -0.445251}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{1.695949, -0.357635, -0.893077, -0.445238}, + {1.695955, -0.357639, -0.893050, -0.445266}}, + {{1.695948, -0.357634, -0.893082, -0.445233}, + {1.695950, -0.357635, -0.893077, -0.445238}}, + {{1.695951, -0.357636, -0.893069, -0.445246}, + {1.695955, -0.357639, -0.893052, -0.445264}}, + {{1.695952, -0.357636, -0.893066, -0.445249}, + {1.695957, -0.357641, -0.893041, -0.445276}}, + {{1.695946, -0.357632, -0.893095, -0.445220}, + {1.695952, -0.357637, -0.893065, -0.445251}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm))); + model = TransformerEncoder( + TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm))); if (is_cuda) { model->to(torch::kCUDA); } - result = model(encoder_input, /*src_mask=*/torch::Tensor{}, /*src_key_padding_mask=*/mask).detach(); - ref_output = torch::tensor({ - {{1.695955, -0.357639, -0.893051, -0.445265}, {1.695955, -0.357639, -0.893051, -0.445265}}, - {{1.695955, -0.357639, -0.893051, -0.445265}, {1.695955, -0.357639, -0.893051, -0.445265}}, - {{1.695955, -0.357639, -0.893051, -0.445265}, {1.695955, -0.357639, -0.893051, -0.445265}}, - {{1.695955, -0.357639, -0.893051, -0.445265}, {1.695955, -0.357639, -0.893051, -0.445265}}, - {{1.695955, -0.357639, -0.893051, -0.445265}, {1.695955, -0.357639, -0.893051, -0.445265}}}, tensor_options); + result = model( + encoder_input, + /*src_mask=*/torch::Tensor{}, + /*src_key_padding_mask=*/mask) + .detach(); + ref_output = torch::tensor( + {{{1.695955, -0.357639, -0.893051, -0.445265}, + {1.695955, -0.357639, -0.893051, -0.445265}}, + {{1.695955, -0.357639, -0.893051, -0.445265}, + {1.695955, -0.357639, -0.893051, -0.445265}}, + {{1.695955, -0.357639, -0.893051, -0.445265}, + {1.695955, -0.357639, -0.893051, -0.445265}}, + {{1.695955, -0.357639, -0.893051, -0.445265}, + {1.695955, -0.357639, -0.893051, -0.445265}}, + {{1.695955, -0.357639, -0.893051, -0.445265}, + {1.695955, -0.357639, -0.893051, -0.445265}}}, + tensor_options); ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); } TEST_F(TransformerTest, TransformerEncoder) { - transformer_encoder_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_encoder_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ true); + transformer_encoder_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_encoder_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/true); } TEST_F(TransformerTest, TransformerEncoder_CUDA) { - transformer_encoder_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_encoder_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ true); + transformer_encoder_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_encoder_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/true); } TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) { @@ -566,7 +779,8 @@ TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) { TEST_F(TransformerTest, PrettyPrintTransformerEncoder) { LayerNorm norm = LayerNorm(LayerNormOptions({4})); TransformerEncoderOptions options( - TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2),2).norm(AnyModule(norm))); + TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2) + .norm(AnyModule(norm))); ASSERT_EQ( c10::str(TransformerEncoder(options)), "torch::nn::TransformerEncoderImpl(\n" @@ -600,7 +814,6 @@ TEST_F(TransformerTest, PrettyPrintTransformerEncoder) { ")"); } - TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) { ASSERT_EQ( c10::str(TransformerDecoderLayer(4, 2)), @@ -623,157 +836,201 @@ TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) { ")"); } -void transformer_decoder_test_helper(bool is_cuda, bool use_callable_activation) { +void transformer_decoder_test_helper( + bool is_cuda, + bool use_callable_activation) { // this is a deterministic test for TransformerDecoder torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; torch::TensorOptions tensor_options = - torch::TensorOptions().dtype(torch::kFloat32).device(device); + torch::TensorOptions().dtype(torch::kFloat32).device(device); - TransformerDecoderLayer decoder_layer = get_a_test_layer< - TransformerDecoderLayer, - TransformerDecoderLayerOptions>(tensor_options, use_callable_activation); + TransformerDecoderLayer decoder_layer = + get_a_test_layer( + tensor_options, use_callable_activation); TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1)); if (is_cuda) { model->to(torch::kCUDA); } - - torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, - tensor_options); - torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, - tensor_options); + torch::Tensor decoder_input = + torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + torch::Tensor memory_input = + torch::tensor({{{60, 70, 80, 90}}}, tensor_options); torch::Tensor result = model(decoder_input, memory_input).detach(); torch::Tensor ref_output = torch::tensor( - {{{2.314351, 0.094805, -0.671322, 0.101977}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - -// deterministic input - decoder_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, tensor_options); + {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.422245, 0.051716, -0.606338, -0.024756}}, - {{2.422245, 0.051716, -0.606338, -0.024756}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.422245, 0.051716, -0.606338, -0.024756}}, + {{2.422245, 0.051716, -0.606338, -0.024756}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{1, 2, 3, 4}}, - {{5, 6, 7, 8}}}, tensor_options); - memory_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, tensor_options); + decoder_input = + torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); + memory_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.343536, 0.085561, -0.654954, 0.074991}}, - {{2.343536, 0.085561, -0.654954, 0.074991}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - - - // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + ref_output = torch::tensor( + {{{2.343536, 0.085561, -0.654954, 0.074991}}, + {{2.343536, 0.085561, -0.654954, 0.074991}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // key_padding_mask torch::Tensor t_mask = {}; torch::Tensor m_mask = {}; torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // key_padding_mask key_padding_mask[0][2] = 1; key_padding_mask[1][1] = 1; key_padding_mask[1][2] = 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430025, 0.027643, -0.601164, -0.073476}, - {2.4323, 0.029375, -0.599553, -0.071881}}, - {{2.428523, 0.026838, -0.602226, -0.07391}, - {2.432634, 0.029842, -0.599318, -0.071253}}, - {{2.432278, 0.028152, -0.599555, -0.074139}, - {2.432659, 0.029244, -0.599294, -0.072382}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430025, 0.027643, -0.601164, -0.073476}, + {2.4323, 0.029375, -0.599553, -0.071881}}, + {{2.428523, 0.026838, -0.602226, -0.07391}, + {2.432634, 0.029842, -0.599318, -0.071253}}, + {{2.432278, 0.028152, -0.599555, -0.074139}, + {2.432659, 0.029244, -0.599294, -0.072382}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // memory_key_padding_mask torch::Tensor t_key_padding_mask = {}; key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - t_key_padding_mask, key_padding_mask).detach(); - ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, - {2.431935, 0.028907, -0.599809, -0.072488}}, - {{2.428457, 0.027053, -0.602275, -0.073462}, - {2.431970, 0.029387, -0.599789, -0.071621}}, - {{2.431934, 0.028196, -0.599802, -0.073809}, - {2.432306, 0.028858, -0.599542, -0.072846}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model( + decoder_input, + memory_input, + t_mask, + m_mask, + t_key_padding_mask, + key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // memory_key_padding_mask key_padding_mask[0][4] = 1; key_padding_mask[1][3] = 1; key_padding_mask[1][4] = 1; - result = model(decoder_input, memory_input, t_mask, m_mask, - t_key_padding_mask, key_padding_mask).detach(); - ref_output = torch::tensor({{{2.429757, 0.027358, -0.601351, -0.073816}, - {2.432692, 0.028583, -0.599263, -0.073634}}, - {{2.428247, 0.02662, -0.602419, -0.074123}, - {2.432657, 0.029055, -0.599293, -0.072732}}, - {{2.431515, 0.027687, -0.600096, -0.074459}, - {2.433075, 0.028543, -0.598987, -0.073985}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + result = model( + decoder_input, + memory_input, + t_mask, + m_mask, + t_key_padding_mask, + key_padding_mask) + .detach(); + ref_output = torch::tensor( + {{{2.429757, 0.027358, -0.601351, -0.073816}, + {2.432692, 0.028583, -0.599263, -0.073634}}, + {{2.428247, 0.02662, -0.602419, -0.074123}, + {2.432657, 0.029055, -0.599293, -0.072732}}, + {{2.431515, 0.027687, -0.600096, -0.074459}, + {2.433075, 0.028543, -0.598987, -0.073985}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // multiple layers no norm model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2)); @@ -785,53 +1042,54 @@ void transformer_decoder_test_helper(bool is_cuda, bool use_callable_activation) memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); result = model(decoder_input, memory_input).detach(); ref_output = torch::tensor( - {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // multiple layers no norm model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); if (is_cuda) { model->to(torch::kCUDA); } - // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + // deterministic input + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.42794, 0.026164, -0.60263, -0.0747591}, - {2.43113, 0.0279516, -0.600376, -0.0736896}}, - {{2.42794, 0.026164, -0.60263, -0.0747591}, - {2.43113, 0.0279516, -0.600376, -0.0736896}}, - {{2.42794, 0.026164, -0.60263, -0.0747591}, - {2.43113, 0.0279516, -0.600376, -0.0736896}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - + ref_output = torch::tensor( + {{{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}, + {{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}, + {{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // multiple layers with norm LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()})); model = TransformerDecoder( - TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm))); + TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm))); if (is_cuda) { model->to(torch::kCUDA); } @@ -840,50 +1098,52 @@ void transformer_decoder_test_helper(bool is_cuda, bool use_callable_activation) memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); result = model(decoder_input, memory_input).detach(); ref_output = torch::tensor( - {{{1.66166, -0.326986, -1.01466, -0.320017}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // multiple layers with norm model = TransformerDecoder( - TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); + TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); if (is_cuda) { model->to(torch::kCUDA); } - // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + // deterministic input + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{1.69559, -0.357291, -0.894741, -0.443553}, - {1.69571, -0.357363, -0.894154, -0.444196}}, - {{1.69559, -0.357291, -0.894741, -0.443553}, - {1.69571, -0.357363, -0.894154, -0.444196}}, - {{1.69559, -0.357291, -0.894741, -0.443553}, - {1.69571, -0.357363, -0.894154, -0.444196}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - - //gelu activation test cases + ref_output = torch::tensor( + {{{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}, + {{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}, + {{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); + + // gelu activation test cases decoder_layer.get()->options.activation(torch::kGELU); model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1)); if (is_cuda) { @@ -891,168 +1151,175 @@ void transformer_decoder_test_helper(bool is_cuda, bool use_callable_activation) } // deterministic input - decoder_input = torch::tensor({{{20, 30, 40, 50}}}, - tensor_options); - memory_input = torch::tensor({{{60, 70, 80, 90}}}, - tensor_options); + decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); result = model(decoder_input, memory_input).detach(); ref_output = torch::tensor( - {{{2.306435, 0.095946, -0.675796, 0.10687}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, - tensor_options); + decoder_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.415448, 0.054389, -0.610932, -0.0156613}}, - {{2.415448, 0.054389, -0.610932, -0.0156613}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.415448, 0.054389, -0.610932, -0.0156613}}, + {{2.415448, 0.054389, -0.610932, -0.0156613}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{1, 2, 3, 4}}, - {{5, 6, 7, 8}}}, - tensor_options); - memory_input = torch::tensor({{{9, 10, 11, 12}}, - {{11, 12, 13, 14}}}, - tensor_options); + decoder_input = + torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); + memory_input = + torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.338531, 0.087709, -0.65776, 0.080646}}, - {{2.338531, 0.087709, -0.65776, 0.080646}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.338531, 0.087709, -0.65776, 0.080646}}, + {{2.338531, 0.087709, -0.65776, 0.080646}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // deterministic input - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); ref_output = torch::tensor( - {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, - {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, - {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, - {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, - {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, - {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, + {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, + {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, + {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, + {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, + {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // Multiple layers no norm model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); if (is_cuda) { model->to(torch::kCUDA); } - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{2.41859, 0.0328114, -0.609269, -0.0560386}, - {2.42138, 0.034598, -0.607316, -0.0546574}}, - {{2.41859, 0.0328114, -0.609269, -0.0560386}, - {2.42138, 0.034598, -0.607316, -0.0546574}}, - {{2.41859, 0.0328114, -0.609269, -0.0560386}, - {2.42138, 0.034598, -0.607316, -0.0546574}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); + ref_output = torch::tensor( + {{{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}, + {{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}, + {{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); // Multiple layers with norm norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()})); - model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); + model = TransformerDecoder( + TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); if (is_cuda) { model->to(torch::kCUDA); } - decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, - {0.2678, 0.3677, 0.4459, 0.7166}}, - {{0.8100, 0.3716, 0.4096, 0.1976}, - {0.6958, 0.8844, 0.6081, 0.8315}}, - {{0.0494, 0.9343, 0.5955, 0.3830}, - {0.5404, 0.3464, 0.9378, 0.6200}}}, - tensor_options); - memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, - {0.5387, 0.1655, 0.3565, 0.0471}}, - {{0.8335, 0.2799, 0.5031, 0.2947}, - {0.1402, 0.0318, 0.7636, 0.1346}}, - {{0.6333, 0.9344, 0.1376, 0.9938}, - {0.8924, 0.2872, 0.6692, 0.2944}}, - {{0.9897, 0.6915, 0.3154, 0.1733}, - {0.8645, 0.3513, 0.3064, 0.0767}}, - {{0.8117, 0.2366, 0.4838, 0.7881}, - {0.3718, 0.4945, 0.9511, 0.0864}}}, - tensor_options); + decoder_input = torch::tensor( + {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor( + {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); result = model(decoder_input, memory_input).detach(); - ref_output = torch::tensor({{{1.69298, -0.355163, -0.906375, -0.431439}, - {1.69305, -0.355195, -0.906062, -0.431791}}, - {{1.69298, -0.355163, -0.906375, -0.431439}, - {1.69305, -0.355195, -0.906062, -0.431791}}, - {{1.69298, -0.355163, -0.906375, -0.431439}, - {1.69305, -0.355195, -0.906062, -0.431791}}}, - tensor_options); - ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, - /*equal_nan=*/true)); - + ref_output = torch::tensor( + {{{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}, + {{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}, + {{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose( + result, + ref_output, + 1e-7, + 1e-5, + /*equal_nan=*/true)); } TEST_F(TransformerTest, TransformerDecoder) { - transformer_decoder_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_decoder_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ true); + transformer_decoder_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_decoder_test_helper( + /*is_cuda=*/false, /*use_callable_activation=*/true); } TEST_F(TransformerTest, TransformerDecoder_CUDA) { - transformer_decoder_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_decoder_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ true); + transformer_decoder_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_decoder_test_helper( + /*is_cuda=*/true, /*use_callable_activation=*/true); } - TEST_F(TransformerTest, PrettyPrintTransformerDecoder) { LayerNorm norm = LayerNorm(LayerNormOptions({4})); TransformerDecoderOptions options( - TransformerDecoderOptions( - TransformerDecoderLayerOptions(4, 2),2).norm(AnyModule(norm))); + TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2) + .norm(AnyModule(norm))); ASSERT_EQ( c10::str(TransformerDecoder(options)), "torch::nn::TransformerDecoderImpl(\n" @@ -1097,118 +1364,159 @@ TEST_F(TransformerTest, PrettyPrintTransformerDecoder) { } void transformer_test_helper(bool is_cuda, bool use_callable_activation) { - // this is a deterministic test for Transformere - torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; - torch::TensorOptions tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device); - - // transformer created encoder/decoder - auto options = TransformerOptions() - .d_model(4) - .nhead(2) - .num_encoder_layers(2) - .num_decoder_layers(1) - .dim_feedforward(16) - .dropout(0.0) - .activation(torch::kReLU); - if (use_callable_activation) { - options.activation([&](const torch::Tensor& t) { return torch::nn::functional::relu(t); }); - } - Transformer model(options); - - set_parameter_to_constants(model, tensor_options); - if (tensor_options.device() == torch::kCUDA) { - model->to(torch::kCUDA); - } - - // transformer with customized encoder/decoder - LayerNorm enorm(LayerNormOptions({4})); - TransformerEncoder encoder(TransformerEncoderOptions( - TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 2).norm(AnyModule(enorm))); - - LayerNorm dnorm(LayerNormOptions({4})); - TransformerDecoder decoder(TransformerDecoderOptions( - TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 1).norm(AnyModule(dnorm))); - - Transformer model_cus(TransformerOptions() - .d_model(4) - .nhead(2) - .custom_encoder(AnyModule(encoder)) - .custom_decoder(AnyModule(decoder))); - - set_parameter_to_constants(model_cus, tensor_options); - if (tensor_options.device() == torch::kCUDA) { - model_cus->to(torch::kCUDA); - } - - // test cases - torch::Tensor src = torch::tensor({ - {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}}, - {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}, - {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}}, tensor_options); - - torch::Tensor tgt = torch::tensor({ - {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}}, - {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}}, tensor_options); - - torch::Tensor ref_output = torch::tensor({ - {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}, - {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options); - torch::Tensor result = model(src, tgt); - torch::Tensor result_cus = model_cus(src, tgt); - ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(result.equal(result_cus)); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - - torch::Tensor src_mask = Transformer::Impl::generate_square_subsequent_mask(src.size(0)).to(tensor_options); - ref_output = torch::tensor({ - {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}, - {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options); - result = model(src, tgt, src_mask); - result_cus = model_cus(src, tgt, src_mask); - ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(result.equal(result_cus)); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - - torch::Tensor tgt_key_padding_mask = torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1; - tgt_key_padding_mask[0][0] = 1; - tgt_key_padding_mask[1][1] = 1; - ref_output = torch::tensor({ - {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}}, - {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options); - result = model(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask); - result_cus = model_cus(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask); - ASSERT_EQ(result.sizes(), ref_output.sizes()); - ASSERT_TRUE(result.equal(result_cus)); - ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + // this is a deterministic test for Transformere + torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); + + // transformer created encoder/decoder + auto options = TransformerOptions() + .d_model(4) + .nhead(2) + .num_encoder_layers(2) + .num_decoder_layers(1) + .dim_feedforward(16) + .dropout(0.0) + .activation(torch::kReLU); + if (use_callable_activation) { + options.activation( + [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); }); + } + Transformer model(options); + + set_parameter_to_constants(model, tensor_options); + if (tensor_options.device() == torch::kCUDA) { + model->to(torch::kCUDA); + } + + // transformer with customized encoder/decoder + LayerNorm enorm(LayerNormOptions({4})); + TransformerEncoder encoder( + TransformerEncoderOptions( + TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), + 2) + .norm(AnyModule(enorm))); + + LayerNorm dnorm(LayerNormOptions({4})); + TransformerDecoder decoder( + TransformerDecoderOptions( + TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), + 1) + .norm(AnyModule(dnorm))); + + Transformer model_cus(TransformerOptions() + .d_model(4) + .nhead(2) + .custom_encoder(AnyModule(encoder)) + .custom_decoder(AnyModule(decoder))); + + set_parameter_to_constants(model_cus, tensor_options); + if (tensor_options.device() == torch::kCUDA) { + model_cus->to(torch::kCUDA); + } + + // test cases + torch::Tensor src = torch::tensor( + {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}}, + {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}, + {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}}, + tensor_options); + + torch::Tensor tgt = torch::tensor( + {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}}, + {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}}, + tensor_options); + + torch::Tensor ref_output = torch::tensor( + {{{2.695875, 0.347114, -0.044355, -0.549541}, + {2.696091, 0.347015, -0.044770, -0.548522}}, + {{2.695875, 0.347114, -0.044355, -0.549541}, + {2.696091, 0.347015, -0.044770, -0.548522}}}, + tensor_options); + torch::Tensor result = model(src, tgt); + torch::Tensor result_cus = model_cus(src, tgt); + ASSERT_EQ(result.sizes(), ref_output.sizes()); + ASSERT_TRUE(result.equal(result_cus)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + + torch::Tensor src_mask = + Transformer::Impl::generate_square_subsequent_mask(src.size(0)) + .to(tensor_options); + ref_output = torch::tensor( + {{{2.695875, 0.347114, -0.044355, -0.549541}, + {2.696091, 0.347015, -0.044770, -0.548522}}, + {{2.695875, 0.347114, -0.044355, -0.549541}, + {2.696091, 0.347015, -0.044770, -0.548522}}}, + tensor_options); + result = model(src, tgt, src_mask); + result_cus = model_cus(src, tgt, src_mask); + ASSERT_EQ(result.sizes(), ref_output.sizes()); + ASSERT_TRUE(result.equal(result_cus)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); + + torch::Tensor tgt_key_padding_mask = + torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1; + tgt_key_padding_mask[0][0] = 1; + tgt_key_padding_mask[1][1] = 1; + ref_output = torch::tensor( + {{{2.696114, 0.347004, -0.044813, -0.548417}, + {2.696091, 0.347015, -0.044770, -0.548522}}, + {{2.696114, 0.347004, -0.044813, -0.548417}, + {2.696091, 0.347015, -0.044770, -0.548522}}}, + tensor_options); + result = model( + src, + tgt, + src_mask, + torch::Tensor(), + torch::Tensor(), + torch::Tensor(), + tgt_key_padding_mask); + result_cus = model_cus( + src, + tgt, + src_mask, + torch::Tensor(), + torch::Tensor(), + torch::Tensor(), + tgt_key_padding_mask); + ASSERT_EQ(result.sizes(), ref_output.sizes()); + ASSERT_TRUE(result.equal(result_cus)); + ASSERT_TRUE( + torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); } TEST_F(TransformerTest, Transformer) { - transformer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ false); - transformer_test_helper(/*is_cuda=*/ false, /*use_callable_activation=*/ true); + transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/false); + transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/true); } TEST_F(TransformerTest, Transformer_CUDA) { - transformer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ false); - transformer_test_helper(/*is_cuda=*/ true, /*use_callable_activation=*/ true); + transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/false); + transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/true); } TEST_F(TransformerTest, TransformerArgsCorrectness) { Transformer model(TransformerOptions() - .d_model(4) - .nhead(2) - .num_encoder_layers(2) - .num_decoder_layers(1) - .dim_feedforward(16) - .dropout(0.0) - .activation(torch::kReLU)); + .d_model(4) + .nhead(2) + .num_encoder_layers(2) + .num_decoder_layers(1) + .dim_feedforward(16) + .dropout(0.0) + .activation(torch::kReLU)); torch::Tensor src = torch::randn({2, 3, 4}); torch::Tensor tgt = torch::randn({3, 2, 4}); - ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have equal batch size"); + ASSERT_THROWS_WITH( + model(src, tgt), "src and tgt should have equal batch size"); tgt = torch::randn({2, 3, 3}); - ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have same feature size as d_model"); + ASSERT_THROWS_WITH( + model(src, tgt), "src and tgt should have same feature size as d_model"); src = torch::randn({2, 3}); ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have 3 dimensions"); diff --git a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp index 3032afdc306a3e..6edc8d8ff0959a 100644 --- a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp +++ b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp @@ -4,9 +4,9 @@ #include #include #include +#include #include "CUDATest.hpp" #include "TestUtils.hpp" -#include using namespace c10d::test; @@ -99,7 +99,8 @@ class AsyncInputIsOutputTest : public AsyncTest { work->wait(); } - std::vector getCpuTensors(const std::vector& gpu_tensors) { + std::vector getCpuTensors( + const std::vector& gpu_tensors) { std::vector outputs(gpu_tensors.size()); // For the duration of this function, make THC use our streams @@ -117,7 +118,6 @@ class AsyncInputIsOutputTest : public AsyncTest { return getCpuTensors(inputs_); } - protected: const int numTensors_; const int numDevices_; @@ -156,7 +156,9 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run( + int rootRank, + int rootTensor) { // For the duration of this function, make THC use our streams c10::cuda::CUDAMultiStreamGuard guard(streams_); @@ -186,24 +188,24 @@ void runAsyncAllreduceTest( size_t numTensors = 2) { auto tests = initialize(path, numProcesses, numTensors); std::vector> work(numProcesses); - for(const auto i : c10::irange(numProcesses)) { + for (const auto i : c10::irange(numProcesses)) { work[i] = tests[i].run(); } // Wait for work to complete - for(const auto i : c10::irange(numProcesses)) { + for (const auto i : c10::irange(numProcesses)) { tests[i].wait(work[i]); } // Check results - for(const auto i : c10::irange(numProcesses)) { + for (const auto i : c10::irange(numProcesses)) { const auto size = numProcesses * numTensors; const auto expected = (size * (size - 1)) / 2; auto tensors = tests[i].getTensors(); auto results = tests[i].getCpuTensors(work[i]->result()); EXPECT_EQ(tensors.size(), results.size()); - for(const auto j : c10::irange(tensors.size())) { + for (const auto j : c10::irange(tensors.size())) { auto& tensor = tensors[j]; auto data = tensor.data_ptr(); @@ -227,24 +229,25 @@ void runAsyncBroadcastTest( auto tests = initialize(path, numProcesses, numTensors); // Try every permutation of root rank and root tensor - for(const auto rootRank : c10::irange(numProcesses)) { - for(const auto rootTensor : c10::irange(numTensors)) { - std::vector> work(numProcesses); - for(const auto i : c10::irange(numProcesses)) { + for (const auto rootRank : c10::irange(numProcesses)) { + for (const auto rootTensor : c10::irange(numTensors)) { + std::vector> work( + numProcesses); + for (const auto i : c10::irange(numProcesses)) { work[i] = tests[i].run(rootRank, rootTensor); } // Wait for work to complete - for(const auto i : c10::irange(numProcesses)) { + for (const auto i : c10::irange(numProcesses)) { tests[i].wait(work[i]); } // Check results const auto expected = (rootRank * numTensors + rootTensor); - for(const auto i : c10::irange(numProcesses)) { + for (const auto i : c10::irange(numProcesses)) { auto tensors = tests[i].getTensors(); - for (const auto & tensor : tensors) { - const auto *const data = tensor.data_ptr(); + for (const auto& tensor : tensors) { + const auto* const data = tensor.data_ptr(); for (const auto k : c10::irange(tensor.numel())) { EXPECT_EQ(data[k], expected); } diff --git a/test/cpp/c10d/ProcessGroupGlooTest.cpp b/test/cpp/c10d/ProcessGroupGlooTest.cpp index 88ac369af13225..394215b0b7e30d 100644 --- a/test/cpp/c10d/ProcessGroupGlooTest.cpp +++ b/test/cpp/c10d/ProcessGroupGlooTest.cpp @@ -180,10 +180,10 @@ class CollectiveTest { std::vector> copyTensors( const std::vector>& inputs) { std::vector> outputs(inputs.size()); - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { const auto& input = inputs[i]; std::vector output(input.size()); - for(const auto j : c10::irange(input.size())) { + for (const auto j : c10::irange(input.size())) { output[j] = input[j].cpu(); } outputs[i] = output; @@ -301,7 +301,9 @@ void testAllreduce(const std::string& path, const at::DeviceType b) { // UsingWorkAPI tests are to make sure we still properly support work API. // This should go away as we deprecate it. -void testAllreduceUsingWorkAPI(const std::string& path, const at::DeviceType b) { +void testAllreduceUsingWorkAPI( + const std::string& path, + const at::DeviceType b) { const auto size = 4; auto tests = CollectiveTest::initialize(path, size); @@ -445,18 +447,18 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { // Kick off work std::vector> work(size); - const char * GLOO_A2A_STR = "gloo:all_to_all"; + const char* GLOO_A2A_STR = "gloo:all_to_all"; std::vector> allShapes; - for (const auto & vec : inputSplits) { + for (const auto& vec : inputSplits) { // Due to concatenation of tensors, shape will actually be the sum int64_t sum = 0; - for (const auto & s : vec) { + for (const auto& s : vec) { sum += s; } allShapes.push_back({sum}); } enableProfilerLegacy(ProfilerConfig( - ProfilerState::CPU, /* report_input_shapes */ true, false)); + ProfilerState::CPU, /* report_input_shapes */ true, false)); for (const auto rank : c10::irange(size)) { work[rank] = tests[rank].getProcessGroup().alltoall_base( outputs[rank], inputs[rank], outputSplits[rank], inputSplits[rank]); @@ -468,8 +470,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { } auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_A2A_STR, size, allShapes); + checkProfiledEvents(std::move(event_lists), GLOO_A2A_STR, size, allShapes); // Verify outputs std::vector> expected = { {0, 1, 10, 11, 12, 20, 21, 30, 31}, @@ -493,7 +494,7 @@ void testBarrier(const std::string& path) { // Kick off work enableProfilerLegacy(ProfilerConfig( - ProfilerState::CPU, /* report_input_shapes */ true, false)); + ProfilerState::CPU, /* report_input_shapes */ true, false)); std::vector> work(size); for (const auto i : c10::irange(size)) { work[i] = tests[i].getProcessGroup().barrier(); @@ -503,11 +504,15 @@ void testBarrier(const std::string& path) { waitFuture(work); auto event_lists = disableProfilerLegacy(); - const char * GLOO_STR = "gloo:barrier"; + const char* GLOO_STR = "gloo:barrier"; std::vector> allShapes; // Barrier does not use tensors, so skip shape checking. - checkProfiledEvents( - std::move(event_lists), GLOO_STR, size, allShapes, /* verify_shapes */ false); + checkProfiledEvents( + std::move(event_lists), + GLOO_STR, + size, + allShapes, + /* verify_shapes */ false); } void testMonitoredBarrier(const std::string& path) { @@ -515,36 +520,38 @@ void testMonitoredBarrier(const std::string& path) { auto tests = CollectiveTest::initialize(path, size); // Non-failure case: all ranks pass the blocking monitored barrier. auto runMonitoredBarrier = [&](int i) { - tests[i].getProcessGroup().monitoredBarrier(); + tests[i].getProcessGroup().monitoredBarrier(); }; std::vector threads; threads.reserve(size); - for(const auto r : c10::irange(size)) { + for (const auto r : c10::irange(size)) { threads.emplace_back(std::thread([=]() { runMonitoredBarrier(r); })); } - for (auto & t : threads) { + for (auto& t : threads) { t.join(); } - // Failure case: Only rank 0 calls into monitored barrier, should result in error + // Failure case: Only rank 0 calls into monitored barrier, should result in + // error auto runMonitoredBarrierWithException = [&](int i) { - if (i != 0) { - return; - } + if (i != 0) { + return; + } - try { - tests[i].getProcessGroup().monitoredBarrier(); - FAIL() << "Exception should have been thrown."; - } catch (const std::exception& e) { - auto pos = std::string(e.what()).find("Rank 1"); - EXPECT_TRUE(pos != std::string::npos); - } + try { + tests[i].getProcessGroup().monitoredBarrier(); + FAIL() << "Exception should have been thrown."; + } catch (const std::exception& e) { + auto pos = std::string(e.what()).find("Rank 1"); + EXPECT_TRUE(pos != std::string::npos); + } }; threads.clear(); - for(const auto r : c10::irange(size)) { - threads.emplace_back(std::thread([=]() { runMonitoredBarrierWithException(r); })); + for (const auto r : c10::irange(size)) { + threads.emplace_back( + std::thread([=]() { runMonitoredBarrierWithException(r); })); } - for (auto & t : threads) { - t.join(); + for (auto& t : threads) { + t.join(); } } @@ -595,7 +602,8 @@ void testSend(const std::string& path) { }; auto& pg = tests[selfRank].getProcessGroup(); const char* GLOO_SEND_STR = "gloo:send"; - enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU, /* report_input_shapes */ true, false)); + enableProfilerLegacy(ProfilerConfig( + ProfilerState::CPU, /* report_input_shapes */ true, false)); auto sendWork = pg.send(tensors, dstRank, tag); bool sendCompleted; std::thread waitSendThreadAbort([&]() { sendCompleted = sendWork->wait(); }); @@ -604,8 +612,7 @@ void testSend(const std::string& path) { waitSendThreadAbort.join(); EXPECT_FALSE(sendCompleted); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_SEND_STR, 1, allShapes); + checkProfiledEvents(std::move(event_lists), GLOO_SEND_STR, 1, allShapes); // Now create a separate sender thread to ensure that future waitsends can // complete successfully. @@ -645,7 +652,8 @@ void testRecv(const std::string& path) { }; const char* GLOO_RECV_STR = "gloo:recv"; auto& pg = tests[selfRank].getProcessGroup(); - enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU, /* report_input_shapes */ true, false)); + enableProfilerLegacy(ProfilerConfig( + ProfilerState::CPU, /* report_input_shapes */ true, false)); auto recvWork = pg.recv(tensors, srcRank, tag); bool recvCompleted; std::thread waitRecvThreadAbort([&]() { recvCompleted = recvWork->wait(); }); @@ -654,8 +662,7 @@ void testRecv(const std::string& path) { waitRecvThreadAbort.join(); EXPECT_FALSE(recvCompleted); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_RECV_STR, 1, allShapes); + checkProfiledEvents(std::move(event_lists), GLOO_RECV_STR, 1, allShapes); // Now create a separate receiver thread to ensure that future waits can // complete successfully. @@ -680,7 +687,6 @@ void testRecv(const std::string& path) { EXPECT_TRUE(recvCompleted); } - void testStoreSetGet(const std::string& path) { const auto size = 2; auto tests = CollectiveTest::initialize(path, size); diff --git a/test/cpp/c10d/ProcessGroupMPITest.cpp b/test/cpp/c10d/ProcessGroupMPITest.cpp index b8538a016d5b7b..9a779e8fa122e6 100644 --- a/test/cpp/c10d/ProcessGroupMPITest.cpp +++ b/test/cpp/c10d/ProcessGroupMPITest.cpp @@ -170,7 +170,7 @@ void testAllgather(int iter = 10000) { // Queue the work. c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = - pg->allgather(outputs, tensors); + pg->allgather(outputs, tensors); works.push_back(std::move(work)); } @@ -264,7 +264,7 @@ void testScatter(int iter = 1) { // Queue the work. c10::intrusive_ptr<::c10d::ProcessGroup::Work> work = - pg->scatter(tensors, inputs); + pg->scatter(tensors, inputs); works.push_back(std::move(work)); } diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 5f114e1c4e17f1..f998dc1fd90a7e 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -3,9 +3,9 @@ #include #include #include +#include #include "CUDATest.hpp" #include "TestUtils.hpp" -#include #include @@ -62,7 +62,9 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { std::vector devices, int rank, c10d::OpType opType, - const char* profilingTitle, const c10::optional>& inputs = c10::nullopt) override { + const char* profilingTitle, + const c10::optional>& inputs = + c10::nullopt) override { return c10::make_intrusive( devices, simulate_error_, rank, opType, seq_); } @@ -120,7 +122,9 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { std::vector devices, int rank, c10d::OpType opType, - const char* profilingTitle, const c10::optional>& inputs = c10::nullopt) override { + const char* profilingTitle, + const c10::optional>& inputs = + c10::nullopt) override { return c10::make_intrusive( devices, set_timedout_error_, rank, opType, seq_); } @@ -182,8 +186,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(1000); - ProcessGroupNCCLSimulateErrors pg( - store_, 0, 1, options); + ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); auto work = pg.allreduce(tensors_); work->wait(); @@ -211,8 +214,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(3000); - ProcessGroupNCCLTimedOutErrors pg( - store_, 0, 1, options); + ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options); auto work = pg.allreduce(tensors_); work->wait(); @@ -234,8 +236,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(3000); - ProcessGroupNCCLSimulateErrors pg( - store_, 0, 1, options); + ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); auto work = pg.allreduce(tensors_); pg.barrier()->wait(); diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 6e57e92389f5ac..278525ac29f995 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -5,15 +5,15 @@ #include #include "CUDATest.hpp" #include "TestUtils.hpp" -#include "c10d/Types.hpp" #include "c10d/ProcessGroup.hpp" +#include "c10d/Types.hpp" #include #include #include -#include #include +#include using namespace c10d::test; @@ -23,9 +23,9 @@ using c10d::ProcessGroup; class NCCLTestBase { public: NCCLTestBase( - const std::string& path, - const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout - ) : path_(path), pgTimeout_(pgTimeout) {} + const std::string& path, + const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) + : path_(path), pgTimeout_(pgTimeout) {} NCCLTestBase(NCCLTestBase&& other) { path_ = std::move(other.path_); @@ -39,7 +39,8 @@ class NCCLTestBase { void initialize(int rank, int size) { auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); - c10::intrusive_ptr opts = c10::make_intrusive(); + c10::intrusive_ptr opts = + c10::make_intrusive(); opts->timeout = pgTimeout_; setenv("ENABLE_NCCL_HEALTH_CHECK", "1", /* overwrite */ 1); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( @@ -54,7 +55,10 @@ class NCCLTestBase { class NCCLTest : public NCCLTestBase { public: - NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) + NCCLTest( + const std::string& path, + int worldSize, + std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout) : NCCLTestBase(path, pgTimeout), numDevices_(cudaNumDevices()), worldSize_(worldSize) { @@ -126,7 +130,7 @@ class NCCLTest : public NCCLTestBase { std::vector> getTensorLists( std::vector>& tensor_lists) { std::vector> outputs(numDevices_); - for (auto & output : outputs) { + for (auto& output : outputs) { output = std::vector(worldSize_ * numDevices_); } @@ -198,7 +202,9 @@ class BroadcastNCCLTest : public NCCLTest { BroadcastNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run( + int rootRank, + int rootTensor) { // For the duration of this function, make THC use our streams c10::cuda::CUDAMultiStreamGuard guard(streams_); @@ -217,7 +223,9 @@ class ReduceNCCLTest : public NCCLTest { ReduceNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run( + int rootRank, + int rootTensor) { // For the duration of this function, make THC use our streams c10::cuda::CUDAMultiStreamGuard guard(streams_); @@ -251,8 +259,8 @@ class AllgatherBaseNCCLTest : public NCCLTest { public: AllgatherBaseNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) { - output_tensor_ = at::empty({worldSize_, 3, 3}, at::kCUDA); - } + output_tensor_ = at::empty({worldSize_, 3, 3}, at::kCUDA); + } c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams @@ -276,12 +284,10 @@ class AllgatherBaseNCCLTest : public NCCLTest { return tensors_[0].cpu(); } - - private: - at::Tensor output_tensor_; + private: + at::Tensor output_tensor_; }; - struct ReduceScatterNCCLTest : NCCLTest { ReduceScatterNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) {} @@ -310,12 +316,12 @@ class ReduceScatterBaseNCCLTest : public NCCLTest { public: ReduceScatterBaseNCCLTest(const std::string& path, int worldSize) : NCCLTest(path, worldSize) { - output_tensor_ = at::empty({1}, at::kCUDA); - input_tensor_ = at::empty({worldSize}, at::kCUDA); - for (const auto i : c10::irange(worldSize)) { - input_tensor_[i] = i; - } - } + output_tensor_ = at::empty({1}, at::kCUDA); + input_tensor_ = at::empty({worldSize}, at::kCUDA); + for (const auto i : c10::irange(worldSize)) { + input_tensor_[i] = i; + } + } c10::intrusive_ptr run() { // For the duration of this function, make THC use our streams @@ -335,9 +341,9 @@ class ReduceScatterBaseNCCLTest : public NCCLTest { return input_tensor_.cpu(); } - private: - at::Tensor output_tensor_; - at::Tensor input_tensor_; + private: + at::Tensor output_tensor_; + at::Tensor input_tensor_; }; void testAllreduce(const std::string& path, int rank, int size) { @@ -351,8 +357,8 @@ void testAllreduce(const std::string& path, int rank, int size) { const int totalNumGPUs = test.numDevices() * size; const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2; const auto tensors = test.getTensors(); - for (const auto & tensor : tensors) { - const auto *const data = tensor.data_ptr(); + for (const auto& tensor : tensors) { + const auto* const data = tensor.data_ptr(); for (const auto k : c10::irange(tensor.numel())) { EXPECT_EQ(data[k], expected) << "Allreduce ouputs do not match expected outputs"; @@ -376,8 +382,8 @@ void testBroadcast(const std::string& path, int rank, int size) { // Check results const auto expected = (rootRank * numDevices + rootTensor); const auto tensors = test.getTensors(); - for (const auto & tensor : tensors) { - const auto *const data = tensor.data_ptr(); + for (const auto& tensor : tensors) { + const auto* const data = tensor.data_ptr(); for (const auto k : c10::irange(tensor.numel())) { EXPECT_EQ(data[k], expected) << "Broadcast outputs do not match expected outputs"; @@ -426,7 +432,7 @@ void testAllgather(const std::string& path, int rank, int size) { // Validation auto tensors = test.getOutputTensors(); // device index - for (auto & device : tensors) { + for (auto& device : tensors) { // rank index for (const auto j : c10::irange(device.size())) { const auto expected = j; @@ -454,10 +460,11 @@ void testAllgatherBase(const std::string& path, int rank, int size) { // Rank index for (const auto i : c10::irange(output_tensor.numel())) { - // expected is i // input.numel() <- rank, and each rank contributed rank * num_gpu + // expected is i // input.numel() <- rank, and each rank contributed rank * + // num_gpu const auto expected = (i / input_tensor.numel()) * test.numDevices(); EXPECT_EQ(data[i], expected) - << "Allgather_base outputs do not match expected outputs"; + << "Allgather_base outputs do not match expected outputs"; } } void testReduceScatterBase(const std::string& path, int rank, int size) { @@ -474,10 +481,11 @@ void testReduceScatterBase(const std::string& path, int rank, int size) { // Rank index for (const auto i : c10::irange(output_tensor.numel())) { - // expected is i * input.numel() <- rank, and each rank contributed rank * num_gpu + // expected is i * input.numel() <- rank, and each rank contributed rank * + // num_gpu const auto expected = size * rank * test.numDevices(); EXPECT_EQ(data[i], expected) - << "Reducescatter_base outputs do not match expected outputs"; + << "Reducescatter_base outputs do not match expected outputs"; } } @@ -500,12 +508,15 @@ void testReduceScatter(const std::string& path, int rank, int size) { auto& tensor = tensors[i]; auto data = tensor.data_ptr(); for (const auto j : c10::irange(tensor.numel())) { - EXPECT_EQ(data[j], expected) << "ReduceScatter outputs do not match expected outputs!"; + EXPECT_EQ(data[j], expected) + << "ReduceScatter outputs do not match expected outputs!"; } } } -void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) { +void testProcessGroupNCCLHealthCheckFailHelper( + const std::string& path, + bool timeout) { // simulate world_size > 1 here via threads. const int worldSize = 4; std::unordered_set nums; @@ -515,9 +526,10 @@ void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool tim bool error_caught = false; try { test.initialize(timeout ? 0 : -1, worldSize); - } catch (const std::exception &e) { + } catch (const std::exception& e) { std::string errMsg = e.what(); - const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank"; + const std::string kTimeoutErr = + "Failed to initialize NCCL communicator on rank"; const std::string kInvalidRankErr = "Invalid rank"; std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr; bool cond = errMsg.find(expectedSubstr) != std::string::npos; @@ -536,15 +548,24 @@ void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool tim } } -void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) { +void testProcessGroupNCCLHealthCheckFailException( + const std::string& path, + int /* unused */, + int /* unused */) { testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false); } -void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) { +void testProcessGroupNCCLHealthCheckFailTimeout( + const std::string& path, + int /* unused */, + int /* unused */) { testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true); } -void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) { +void testSequenceNumInit( + const std::string& path, + int /* unused */, + int /* unused */) { // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we // simulate world_size > 1 here via threads. const int worldSize = 2; @@ -569,7 +590,7 @@ void testSequenceNumInit(const std::string& path, int /* unused */, int /* unuse EXPECT_EQ(nums.size(), 1); } -class ProcessGroupNCCLTest: public ::testing::Test { +class ProcessGroupNCCLTest : public ::testing::Test { protected: void SetUp() override { // Use WORLD_SIZE and RANK environmental variables to do multi-node @@ -673,23 +694,23 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) { } TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) { - if (skipTest()) { - return; - } - { - TemporaryFile file; - testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_); - } + if (skipTest()) { + return; + } + { + TemporaryFile file; + testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_); + } } TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) { - if (skipTest()) { - return; - } - { - TemporaryFile file; - testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_); - } + if (skipTest()) { + return; + } + { + TemporaryFile file; + testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_); + } } TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) { @@ -711,6 +732,7 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) { auto test = NCCLTestBase(file.path); test.initialize(rank_, size_); EXPECT_EQ( - test.getProcessGroup().getBackendName(), std::string(c10d::NCCL_BACKEND_NAME)); + test.getProcessGroup().getBackendName(), + std::string(c10d::NCCL_BACKEND_NAME)); } } diff --git a/test/cpp/c10d/TCPStoreTest.cpp b/test/cpp/c10d/TCPStoreTest.cpp index f01423fcb4b3ca..a8070246dc1d3a 100644 --- a/test/cpp/c10d/TCPStoreTest.cpp +++ b/test/cpp/c10d/TCPStoreTest.cpp @@ -90,8 +90,8 @@ void testHelper(const std::string& prefix = "") { std::vector> clientTCPStores; std::vector> clientStores; for (const auto i : c10::irange(numThreads)) { - clientTCPStores.push_back(c10::make_intrusive( - "127.0.0.1", opts)); + clientTCPStores.push_back( + c10::make_intrusive("127.0.0.1", opts)); clientStores.push_back( c10::make_intrusive(prefix, clientTCPStores[i])); } @@ -100,36 +100,33 @@ void testHelper(const std::string& prefix = "") { std::to_string(numThreads * numIterations + 1); for (const auto i : c10::irange(numThreads)) { - threads.emplace_back(std::thread([=, - &sem1, - &sem2, - &clientStores, - &expectedCounterRes] { - for (C10_UNUSED const auto j : c10::irange(numIterations)) { - clientStores[i]->add("counter", 1); - } - // Let each thread set and get key on its client store - std::string key = "thread_" + std::to_string(i); - for (const auto j : c10::irange(numIterations)) { - std::string val = "thread_val_" + std::to_string(j); - c10d::test::set(*clientStores[i], key, val); - c10d::test::check(*clientStores[i], key, val); - } - - sem1.post(); - sem2.wait(); - // Check the counter results - c10d::test::check(*clientStores[i], "counter", expectedCounterRes); - // Now check other threads' written data - for (const auto j : c10::irange(numThreads)) { - if (j == i) { - continue; - } - std::string key = "thread_" + std::to_string(i); - std::string val = "thread_val_" + std::to_string(numIterations - 1); - c10d::test::check(*clientStores[i], key, val); - } - })); + threads.emplace_back( + std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] { + for (C10_UNUSED const auto j : c10::irange(numIterations)) { + clientStores[i]->add("counter", 1); + } + // Let each thread set and get key on its client store + std::string key = "thread_" + std::to_string(i); + for (const auto j : c10::irange(numIterations)) { + std::string val = "thread_val_" + std::to_string(j); + c10d::test::set(*clientStores[i], key, val); + c10d::test::check(*clientStores[i], key, val); + } + + sem1.post(); + sem2.wait(); + // Check the counter results + c10d::test::check(*clientStores[i], "counter", expectedCounterRes); + // Now check other threads' written data + for (const auto j : c10::irange(numThreads)) { + if (j == i) { + continue; + } + std::string key = "thread_" + std::to_string(i); + std::string val = "thread_val_" + std::to_string(numIterations - 1); + c10d::test::check(*clientStores[i], key, val); + } + })); } sem1.wait(numThreads); @@ -164,9 +161,7 @@ void testWatchKeyCallback(const std::string& prefix = "") { constexpr int numThreads = 16; constexpr int keyChangeOperation = 3; c10d::WatchKeyCallback callback = - [=, - &numCallbacksExecuted, - &numCallbacksExecutedPromise]( + [=, &numCallbacksExecuted, &numCallbacksExecutedPromise]( c10::optional /* unused */, c10::optional /* unused */) { numCallbacksExecuted++; @@ -188,8 +183,8 @@ void testWatchKeyCallback(const std::string& prefix = "") { std::vector> clientTCPStores; std::vector> clientStores; for (const auto i : c10::irange(numThreads)) { - clientTCPStores.push_back(c10::make_intrusive( - "127.0.0.1", opts)); + clientTCPStores.push_back( + c10::make_intrusive("127.0.0.1", opts)); clientStores.push_back( c10::make_intrusive(prefix, clientTCPStores[i])); } @@ -379,7 +374,7 @@ TEST(TCPStoreTest, testCleanShutdown) { } TEST(TCPStoreTest, testMultiTenantStores) { - c10d::TCPStoreOptions opts {}; + c10d::TCPStoreOptions opts{}; opts.isServer = true; opts.multiTenant = true; diff --git a/test/cpp/dist_autograd/test_dist_autograd.cpp b/test/cpp/dist_autograd/test_dist_autograd.cpp index 3db8fa78e917f3..f46bc1446f613c 100644 --- a/test/cpp/dist_autograd/test_dist_autograd.cpp +++ b/test/cpp/dist_autograd/test_dist_autograd.cpp @@ -21,7 +21,8 @@ class DistAutogradTest : public ::testing::Test { } virtual void TearDown() { - autogradContainer_->releaseContext(autogradContainer_->currentContext()->contextId()); + autogradContainer_->releaseContext( + autogradContainer_->currentContext()->contextId()); } static DistAutogradContainer* autogradContainer_; @@ -39,8 +40,7 @@ TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) { // Attach the send autograd function to tensors. std::vector tensors = {in1, in2}; rpc::worker_id_t worker_id = 1; - addSendRpcBackward( - autogradContext, AutogradMetadata(1, 1), tensors); + addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors); autogradContext->addKnownWorkerId(worker_id); auto send_function = autogradContext->sendFunctions()[1]; diff --git a/test/cpp/lazy/test_backend_device.cpp b/test/cpp/lazy/test_backend_device.cpp index f8ce49b9e287dd..b9b337fb702cce 100644 --- a/test/cpp/lazy/test_backend_device.cpp +++ b/test/cpp/lazy/test_backend_device.cpp @@ -36,7 +36,9 @@ TEST(BackendDeviceTest, Basic2) { TEST(BackendDeviceTest, Basic3) { struct TestType : public BackendDeviceType { - std::string toString() const override { return "Test"; } + std::string toString() const override { + return "Test"; + } }; auto device = BackendDevice(std::make_shared(), 1); @@ -78,7 +80,8 @@ TEST(BackendDeviceTest, FromAten) { #ifndef FBCODE_CAFFE2 auto backend_device = atenDeviceToBackendDevice(device); #else - // Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g. sizes) in TensorImpl + // Lazy Tensor is disabled in FBCODE until addressing non-virtual methods + // (e.g. sizes) in TensorImpl EXPECT_THROW(atenDeviceToBackendDevice(device), c10::Error); #endif // FBCODE_CAFFE2 } @@ -90,7 +93,8 @@ TEST(BackendDeviceTest, ToAten) { EXPECT_EQ(device.index(), 0); } -// TODO(alanwaketan): Update the following test once we have TorchScript backend upstreamed. +// TODO(alanwaketan): Update the following test once we have TorchScript backend +// upstreamed. TEST(BackendDeviceTest, GetBackendDevice1) { auto tensor = torch::rand({0, 1, 3, 0}); EXPECT_FALSE(GetBackendDevice(tensor)); @@ -103,5 +107,5 @@ TEST(BackendDeviceTest, GetBackendDevice2) { EXPECT_FALSE(GetBackendDevice(tensor1, tensor2)); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_cache.cpp b/test/cpp/lazy/test_cache.cpp index ddbf6611d36a5b..f6b023f845c2a5 100644 --- a/test/cpp/lazy/test_cache.cpp +++ b/test/cpp/lazy/test_cache.cpp @@ -13,9 +13,7 @@ namespace lazy { class CacheNode : public Node { public: explicit CacheNode(const std::string& str) - : Node(OpKind(), /* num_outputs */ 1), - hash_(Hash(str)), - str_(str) {} + : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(str)), str_(str) {} ~CacheNode() override = default; const std::vector& operands() const override { @@ -26,8 +24,13 @@ class CacheNode : public Node { TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node"); } - hash_t hash() const override { return hash_; } - hash_t shapeHash() const override { return hash_; } + hash_t hash() const override { + return hash_; + } + hash_t shapeHash() const override { + return hash_; + } + private: hash_t hash_; std::string str_; @@ -68,7 +71,7 @@ TEST(CacheTest, BasicTest) { class CacheNodeWithShape : public TsNode { public: explicit CacheNodeWithShape(const Shape& shape) - : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0){} + : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0) {} }; TEST(CacheTest, ShapeCacheTestForDynamicShape) { @@ -76,16 +79,14 @@ TEST(CacheTest, ShapeCacheTestForDynamicShape) { FLAGS_ltc_enable_dynamic_shapes = true; CacheNodeWithShape nodes[] = { - CacheNodeWithShape(Shape(c10::kFloat, {2, 4})), - CacheNodeWithShape(Shape(c10::kFloat, {4, 2})) }; + CacheNodeWithShape(Shape(c10::kFloat, {2, 4})), + CacheNodeWithShape(Shape(c10::kFloat, {4, 2}))}; /* * Make sure the cached shape for node (2, 4) is not used for node (4, 2) */ for (auto& node : nodes) { - EXPECT_EQ(node.shape(), node.computeShape([&]() { - return node.shape(); - })); + EXPECT_EQ(node.shape(), node.computeShape([&]() { return node.shape(); })); } // reset the flag diff --git a/test/cpp/lazy/test_ir.cpp b/test/cpp/lazy/test_ir.cpp index 1ce666164a641c..117232cd97eb92 100644 --- a/test/cpp/lazy/test_ir.cpp +++ b/test/cpp/lazy/test_ir.cpp @@ -1,16 +1,16 @@ #include -#include +#include #include #include +#include #include #include -#include #include +#include +#include #include #include -#include -#include namespace torch { namespace lazy { @@ -35,8 +35,13 @@ class TestLeafNode : public Node { TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node"); } - hash_t hash() const override { return hash_; } - hash_t shapeHash() const override { return hash_; } + hash_t hash() const override { + return hash_; + } + hash_t shapeHash() const override { + return hash_; + } + private: hash_t hash_; size_t param_; @@ -79,8 +84,9 @@ TEST(IrTest, MetaDataTest) { dummySourceLocation.file = "file"; dummySourceLocation.function = "function"; dummySourceLocation.line = 10; - GetPythonFramesFunction() = - [&]() -> std::vector { return {dummySourceLocation}; }; + GetPythonFramesFunction() = [&]() -> std::vector { + return {dummySourceLocation}; + }; node = MakeNode(1); auto metaWithSourceLoc = node->metadata(); EXPECT_EQ(metaWithSourceLoc.scope.size(), 0); @@ -111,7 +117,6 @@ TEST(IrTest, TsNodeTest) { } TEST(IrTest, DimensionNodeTest) { - const size_t DIM0 = 5; const size_t DIM1 = 8; NodePtr node1 = MakeNode( @@ -120,16 +125,20 @@ TEST(IrTest, DimensionNodeTest) { /*num_outputs*/ 1, /*hash_seed*/ kHashSeed); - auto size0 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 0)); - auto size1 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 1)); + auto size0 = + std::dynamic_pointer_cast(MakeNode(Value{node1}, 0)); + auto size1 = + std::dynamic_pointer_cast(MakeNode(Value{node1}, 1)); ASSERT_EQ(DIM0, size0->getStaticValue()); ASSERT_EQ(DIM1, size1->getStaticValue()); - auto add_dim = std::dynamic_pointer_cast(MakeNode(Value{size0}, Value{size1})); + auto add_dim = std::dynamic_pointer_cast( + MakeNode(Value{size0}, Value{size1})); ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue()); - auto mul_dim = std::dynamic_pointer_cast(MakeNode(Value{size0}, Value{size1})); + auto mul_dim = std::dynamic_pointer_cast( + MakeNode(Value{size0}, Value{size1})); ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue()); } diff --git a/test/cpp/lazy/test_ir_util.cpp b/test/cpp/lazy/test_ir_util.cpp index ad951956db7dfe..2befb04236ab52 100644 --- a/test/cpp/lazy/test_ir_util.cpp +++ b/test/cpp/lazy/test_ir_util.cpp @@ -12,8 +12,7 @@ namespace lazy { class IrUtilNode : public Node { public: - explicit IrUtilNode() - : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(0)) {} + explicit IrUtilNode() : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(0)) {} ~IrUtilNode() override = default; void AddOperand(Value v) { @@ -24,8 +23,13 @@ class IrUtilNode : public Node { operands_.push_back(std::move(v.node)); } - hash_t hash() const override { return hash_; } - hash_t shapeHash() const override { return hash_; } + hash_t hash() const override { + return hash_; + } + hash_t shapeHash() const override { + return hash_; + } + private: hash_t hash_; }; diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp index 366f7f70815946..2363f88e5ee6ae 100644 --- a/test/cpp/lazy/test_lazy_ops.cpp +++ b/test/cpp/lazy/test_lazy_ops.cpp @@ -4,32 +4,32 @@ #include #include +#include #include #include -#include -#include #include +#include #include -#include #include +#include #include namespace torch { namespace lazy { - -// Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g. sizes) in TensorImpl +// Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g. +// sizes) in TensorImpl #ifndef FBCODE_CAFFE2 namespace { - // This registers the torchscript backend, without which lazy device won't work -static bool inline init_backend(){ +// This registers the torchscript backend, without which lazy device won't work +static bool inline init_backend() { torch::lazy::InitTorchScriptBackend(); return true; } static const bool backend_initialized = init_backend(); -} +} // namespace class LazyTsTest : public ::testing::Test { protected: @@ -43,9 +43,9 @@ class LazyTsTest : public ::testing::Test { const std::string& counter_regex, const std::unordered_set* ignore_set) {} - void ExpectCounterChanged(const std::string& counter_regex, - const std::unordered_set* ignore_set) { - } + void ExpectCounterChanged( + const std::string& counter_regex, + const std::unordered_set* ignore_set) {} void ResetCounters() {} @@ -59,9 +59,10 @@ class LazyOpsTestBase : public LazyTsTest { }; void LazyTsTest::SetUp() { - (void)backend_initialized; // avoid unused parameter warning + (void)backend_initialized; // avoid unused parameter warning at::manual_seed(42); - torch::lazy::LazyGraphExecutor::Get()->SetRngSeed(torch::lazy::BackendDevice(), 42); + torch::lazy::LazyGraphExecutor::Get()->SetRngSeed( + torch::lazy::BackendDevice(), 42); } void LazyTsTest::TearDown() {} @@ -79,8 +80,7 @@ static inline at::DeviceType DefaultDevice() { return torch::lazy::getBackend()->EagerFallbackDeviceType(); } - -} // namespace +} // namespace TEST(LazyDynamicOpsTest, NarrowCopy) { auto x = torch::rand({5, 10, 10}).to(kLazy); @@ -162,8 +162,8 @@ TEST_F(LazyOpsTest, TestIsSigned) { TEST_F(LazyOpsTest, TestCastByte) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Byte(a); ForEachDevice([&](const torch::Device& device) { @@ -175,8 +175,8 @@ TEST_F(LazyOpsTest, TestCastByte) { TEST_F(LazyOpsTest, TestCastChar) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Char(a); ForEachDevice([&](const torch::Device& device) { @@ -188,8 +188,8 @@ TEST_F(LazyOpsTest, TestCastChar) { TEST_F(LazyOpsTest, TestCastShort) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Short(a); ForEachDevice([&](const torch::Device& device) { @@ -201,8 +201,8 @@ TEST_F(LazyOpsTest, TestCastShort) { TEST_F(LazyOpsTest, TestCastInt) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Int(a); ForEachDevice([&](const torch::Device& device) { @@ -214,8 +214,8 @@ TEST_F(LazyOpsTest, TestCastInt) { TEST_F(LazyOpsTest, TestCastLong) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Long(a); ForEachDevice([&](const torch::Device& device) { @@ -227,8 +227,8 @@ TEST_F(LazyOpsTest, TestCastLong) { TEST_F(LazyOpsTest, TestCastFloat) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = torch::_cast_Float(a); ForEachDevice([&](const torch::Device& device) { @@ -248,12 +248,12 @@ TEST_F(LazyOpsTest, TestRetainType) { } TEST_F(LazyOpsTest, TestLogicalTypeWithInterop) { - torch::Tensor query = - torch::rand({2, 12, 20, 64}, - torch::TensorOptions(torch::kFloat).device(torch::kLazy)); - torch::Tensor key = - torch::rand({2, 12, 64, 20}, - torch::TensorOptions(torch::kFloat).device(torch::kLazy)); + torch::Tensor query = torch::rand( + {2, 12, 20, 64}, + torch::TensorOptions(torch::kFloat).device(torch::kLazy)); + torch::Tensor key = torch::rand( + {2, 12, 64, 20}, + torch::TensorOptions(torch::kFloat).device(torch::kLazy)); torch::Tensor scores = torch::matmul(query, key) / torch::scalar_tensor( @@ -468,21 +468,25 @@ TEST_F(LazyOpsTest, TestMulScalarInPlace) { TEST_F(LazyOpsTest, TestDiv) { for (torch::ScalarType scalar_type1 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type1) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) - : torch::randint(0, 100, {3, 4}, - torch::TensorOptions(scalar_type1)); + torch::Tensor a = isFloatingType(scalar_type1) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) + : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1)); for (torch::ScalarType scalar_type2 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor b = - isFloatingType(scalar_type2) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) - : torch::randint(1, 100, {3, 4}, - torch::TensorOptions(scalar_type2)); + torch::Tensor b = isFloatingType(scalar_type2) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) + : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2)); torch::Tensor c = torch::div(a, b); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -495,26 +499,32 @@ TEST_F(LazyOpsTest, TestDiv) { } TEST_F(LazyOpsTest, TestDivWithRoundingMode) { - c10::optional rounding_modes[] = {"trunc", "floor", - c10::nullopt}; + c10::optional rounding_modes[] = { + "trunc", "floor", c10::nullopt}; for (const auto& rounding_mode : rounding_modes) { for (torch::ScalarType scalar_type1 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { int lower_bound = (scalar_type1 == torch::kByte) ? 0 : -100; - torch::Tensor a = - isFloatingType(scalar_type1) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) - : torch::randint(lower_bound, 50, {3, 4}, - torch::TensorOptions(scalar_type1)); + torch::Tensor a = isFloatingType(scalar_type1) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) + : torch::randint( + lower_bound, 50, {3, 4}, torch::TensorOptions(scalar_type1)); for (torch::ScalarType scalar_type2 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, - torch::kInt, torch::kLong}) { - torch::Tensor b = - isFloatingType(scalar_type2) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) - : torch::randint(51, 100, {3, 4}, - torch::TensorOptions(scalar_type2)); + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, + torch::kLong}) { + torch::Tensor b = isFloatingType(scalar_type2) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) + : torch::randint( + 51, 100, {3, 4}, torch::TensorOptions(scalar_type2)); torch::Tensor c = torch::div(a, b, rounding_mode); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -529,17 +539,13 @@ TEST_F(LazyOpsTest, TestDivWithRoundingMode) { TEST_F(LazyOpsTest, TestDivInPlace) { for (torch::ScalarType scalar_type1 : {torch::kFloat}) { - torch::Tensor a = - isFloatingType(scalar_type1) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) - : torch::randint(0, 100, {3, 4}, - torch::TensorOptions(scalar_type1)); + torch::Tensor a = isFloatingType(scalar_type1) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) + : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1)); for (torch::ScalarType scalar_type2 : {torch::kFloat}) { - torch::Tensor b = - isFloatingType(scalar_type2) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) - : torch::randint(1, 100, {3, 4}, - torch::TensorOptions(scalar_type2)); + torch::Tensor b = isFloatingType(scalar_type2) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) + : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2)); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor c = a.div_(b); @@ -553,21 +559,19 @@ TEST_F(LazyOpsTest, TestDivInPlace) { } TEST_F(LazyOpsTest, TestDivInPlaceWithRoundingMode) { - c10::optional rounding_modes[] = {"trunc", "floor", - c10::nullopt}; + c10::optional rounding_modes[] = { + "trunc", "floor", c10::nullopt}; for (const auto& rounding_mode : rounding_modes) { for (torch::ScalarType scalar_type1 : {torch::kFloat}) { - torch::Tensor a = - isFloatingType(scalar_type1) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) - : torch::randint(-100, 100, {3, 4}, - torch::TensorOptions(scalar_type1)); + torch::Tensor a = isFloatingType(scalar_type1) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) + : torch::randint( + -100, 100, {3, 4}, torch::TensorOptions(scalar_type1)); for (torch::ScalarType scalar_type2 : {torch::kFloat}) { - torch::Tensor b = - isFloatingType(scalar_type2) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) - : torch::randint(1, 100, {3, 4}, - torch::TensorOptions(scalar_type2)); + torch::Tensor b = isFloatingType(scalar_type2) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) + : torch::randint( + 1, 100, {3, 4}, torch::TensorOptions(scalar_type2)); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor c = a.div_(b, rounding_mode); @@ -582,16 +586,20 @@ TEST_F(LazyOpsTest, TestDivInPlaceWithRoundingMode) { TEST_F(LazyOpsTest, TestDivScalar) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 1, 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 1, + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool is_float : {true, false}) { torch::Scalar b = is_float ? torch::Scalar(3.0) : torch::Scalar(3); torch::Tensor c = torch::div(a, b); @@ -606,14 +614,14 @@ TEST_F(LazyOpsTest, TestDivScalar) { TEST_F(LazyOpsTest, TestDivScalarInPlace) { for (torch::ScalarType scalar_type : {torch::kFloat}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 1, 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 1, + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool is_float : {true, false}) { torch::Scalar b = is_float ? torch::Scalar(3.0) : torch::Scalar(3); ForEachDevice([&](const torch::Device& device) { @@ -866,7 +874,9 @@ TEST_F(LazyOpsTest, TestGeScalar) { TEST_F(LazyOpsTest, TestGeScalarInplace) { torch::Tensor input = torch::arange( - -1., 1.5, 0.5, + -1., + 1.5, + 0.5, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar other(float(0)); torch::Tensor input_copy = input.clone(); @@ -891,7 +901,9 @@ TEST_F(LazyOpsTest, TestLeScalar) { TEST_F(LazyOpsTest, TestLeScalarInplace) { torch::Tensor input = torch::arange( - -1., 1.5, 0.5, + -1., + 1.5, + 0.5, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar other(float(0)); torch::Tensor input_copy = input.clone(); @@ -916,7 +928,9 @@ TEST_F(LazyOpsTest, TestGtScalar) { TEST_F(LazyOpsTest, TestGtScalarInplace) { torch::Tensor input = torch::arange( - -1., 1.5, 0.5, + -1., + 1.5, + 0.5, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar other(float(0)); torch::Tensor input_copy = input.clone(); @@ -941,7 +955,9 @@ TEST_F(LazyOpsTest, TestLtScalar) { TEST_F(LazyOpsTest, TestLtScalarInplace) { torch::Tensor input = torch::arange( - -1., 1.5, 0.5, + -1., + 1.5, + 0.5, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar other(float(0)); torch::Tensor input_copy = input.clone(); @@ -988,15 +1004,24 @@ TEST_F(LazyOpsTest, TestSVD) { auto lazy_b = torch::svd(lazy_a, /*some=*/true, /*compute_uv=*/true); // The U and V matrices might have different sign for column vectors, so // cannot be compared if not by absolute value. - AllClose(std::get<0>(b).abs(), std::get<0>(lazy_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); + AllClose( + std::get<0>(b).abs(), + std::get<0>(lazy_b).abs(), + /*rtol=*/1e-3, + /*atol=*/1e-4); torch::Tensor diag = std::get<1>(b); torch::Tensor lazy_diag = std::get<1>(lazy_b); ASSERT_EQ(diag.sizes(), lazy_diag.sizes()); - AllClose(diag, lazy_diag, /*rtol=*/1e-3, - /*atol=*/1e-4); - AllClose(std::get<2>(b).abs(), std::get<2>(lazy_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); + AllClose( + diag, + lazy_diag, + /*rtol=*/1e-3, + /*atol=*/1e-4); + AllClose( + std::get<2>(b).abs(), + std::get<2>(lazy_b).abs(), + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -1012,10 +1037,16 @@ TEST_F(LazyOpsTest, TestQR) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); auto lazy_b = torch::qr(lazy_a); - AllClose(std::get<0>(b).abs(), std::get<0>(lazy_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); - AllClose(std::get<1>(b).abs(), std::get<1>(lazy_b).abs(), /*rtol=*/1e-3, - /*atol=*/1e-4); + AllClose( + std::get<0>(b).abs(), + std::get<0>(lazy_b).abs(), + /*rtol=*/1e-3, + /*atol=*/1e-4); + AllClose( + std::get<1>(b).abs(), + std::get<1>(lazy_b).abs(), + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -1034,12 +1065,17 @@ TEST_F(LazyOpsTest, TestSymEig) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(sym_a, device); auto lazy_b = torch::symeig(lazy_a, eigenvectors, upper); - AllClose(std::get<0>(b), std::get<0>(lazy_b), /*rtol=*/3e-2, - /*atol=*/1e-2); + AllClose( + std::get<0>(b), + std::get<0>(lazy_b), + /*rtol=*/3e-2, + /*atol=*/1e-2); if (eigenvectors) { - AllClose(std::get<1>(b).abs(), std::get<1>(lazy_b).abs(), - /*rtol=*/3e-2, - /*atol=*/1e-2); + AllClose( + std::get<1>(b).abs(), + std::get<1>(lazy_b).abs(), + /*rtol=*/3e-2, + /*atol=*/1e-2); } else { EXPECT_EQ(std::get<1>(b).sizes(), std::get<1>(lazy_b).sizes()); } @@ -1075,8 +1111,7 @@ TEST_F(LazyOpsTest, TestLogDet) { for (auto m : dims) { torch::Tensor a = torch::rand( {3, m, m}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor pd_a = - torch::matmul(a, torch::transpose(a, 1, 2)) + + torch::Tensor pd_a = torch::matmul(a, torch::transpose(a, 1, 2)) + torch::eye(m, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor b = torch::logdet(pd_a); @@ -1097,27 +1132,41 @@ TEST_F(LazyOpsTest, TestTriangularSolve) { for (bool upper : {true, false}) { for (bool transpose : {true, false}) { for (bool unitriangular : {true, false}) { - torch::Tensor a = - torch::randn({m, m}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice())); - torch::Tensor b = - torch::randn({m, n}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice())); + torch::Tensor a = torch::randn( + {m, m}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice())); + torch::Tensor b = torch::randn( + {m, n}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice())); a = batched_a ? a.expand({3, m, m}).clone() : a; b = batched_b ? b.expand({3, m, n}).clone() : b; auto result = torch::triangular_solve( - b, a, /*upper=*/upper, /*transpose=*/transpose, + b, + a, + /*upper=*/upper, + /*transpose=*/transpose, /*unitriangular=*/unitriangular); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor lazy_b = CopyToDevice(b, device); auto lazy_result = torch::triangular_solve( - lazy_b, lazy_a, /*upper=*/upper, /*transpose=*/transpose, + lazy_b, + lazy_a, + /*upper=*/upper, + /*transpose=*/transpose, /*unitriangular=*/unitriangular); - AllClose(std::get<0>(result), std::get<0>(lazy_result), - /*rtol=*/1e-3, /*atol=*/1e-4); - AllClose(std::get<1>(result), std::get<1>(lazy_result), - /*rtol=*/1e-3, /*atol=*/1e-4); + AllClose( + std::get<0>(result), + std::get<0>(lazy_result), + /*rtol=*/1e-3, + /*atol=*/1e-4); + AllClose( + std::get<1>(result), + std::get<1>(lazy_result), + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -1266,16 +1315,19 @@ TEST_F(LazyOpsTest, TestUnaryMax) { TEST_F(LazyOpsTest, TestAll) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor b = torch::all(a); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -1287,7 +1339,9 @@ TEST_F(LazyOpsTest, TestAll) { TEST_F(LazyOpsTest, TestAllDim) { torch::Tensor a = torch::randint( - 0, 5, {2, 3, 4}, + 0, + 5, + {2, 3, 4}, torch::TensorOptions(torch::kByte).device(DefaultDevice())); int rank = a.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -1302,7 +1356,9 @@ TEST_F(LazyOpsTest, TestAllDim) { TEST_F(LazyOpsTest, TestAllDimKeep) { torch::Tensor a = torch::randint( - 0, 5, {2, 3, 4}, + 0, + 5, + {2, 3, 4}, torch::TensorOptions(torch::kByte).device(DefaultDevice())); int rank = a.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -1383,16 +1439,19 @@ TEST_F(LazyOpsTest, TestAmin) { TEST_F(LazyOpsTest, TestAny) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor b = torch::any(a); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -1404,7 +1463,9 @@ TEST_F(LazyOpsTest, TestAny) { TEST_F(LazyOpsTest, TestAnyDim) { torch::Tensor a = torch::randint( - 0, 5, {2, 3, 4}, + 0, + 5, + {2, 3, 4}, torch::TensorOptions(torch::kByte).device(DefaultDevice())); int rank = a.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -1419,7 +1480,9 @@ TEST_F(LazyOpsTest, TestAnyDim) { TEST_F(LazyOpsTest, TestAnyDimKeep) { torch::Tensor a = torch::randint( - 0, 5, {2, 3, 4}, + 0, + 5, + {2, 3, 4}, torch::TensorOptions(torch::kByte).device(DefaultDevice())); int rank = a.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -1899,8 +1962,14 @@ TEST_F(LazyOpsTest, TestUniformInPlace) { } TEST_F(LazyOpsTest, TestRandomInPlace) { - for (auto dtype : {torch::kFloat, torch::kDouble, torch::kByte, torch::kChar, - torch::kShort, torch::kInt, torch::kLong}) { + for (auto dtype : + {torch::kFloat, + torch::kDouble, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, + torch::kLong}) { const double eps = 0.2; torch::Tensor a = torch::zeros({10, 10, 10}, torch::TensorOptions(dtype)); ForEachDevice([&](const torch::Device& device) { @@ -1918,8 +1987,14 @@ TEST_F(LazyOpsTest, TestRandomInPlace) { } TEST_F(LazyOpsTest, TestRandomInPlaceDefaultFrom) { - for (auto dtype : {torch::kFloat, torch::kDouble, torch::kByte, torch::kChar, - torch::kShort, torch::kInt, torch::kLong}) { + for (auto dtype : + {torch::kFloat, + torch::kDouble, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, + torch::kLong}) { const double eps = 0.2; torch::Tensor a = torch::zeros({10, 10, 10}, torch::TensorOptions(dtype)); ForEachDevice([&](const torch::Device& device) { @@ -1937,8 +2012,14 @@ TEST_F(LazyOpsTest, TestRandomInPlaceDefaultFrom) { } TEST_F(LazyOpsTest, TestRandomInPlaceDefault) { - for (auto dtype : {torch::kFloat, torch::kDouble, torch::kByte, torch::kChar, - torch::kShort, torch::kInt, torch::kLong}) { + for (auto dtype : + {torch::kFloat, + torch::kDouble, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, + torch::kLong}) { auto input = torch::zeros({10}, torch::TensorOptions(dtype)); ForEachDevice([&](const torch::Device& device) { auto lazyInput = CopyToDevice(input, device); @@ -2012,27 +2093,35 @@ TEST_F(LazyOpsTest, TestFrobeniusNormInDims) { TEST_F(LazyOpsTest, TestGroupNorm) { int num_channels = 6; - torch::Tensor input = - torch::rand({20, num_channels, 10, 10}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor bias = - torch::rand({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {20, num_channels, 10, 10}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor bias = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double eps = 1e-05; for (int num_groups : {3, 6, 1}) { - torch::Tensor output = - torch::group_norm(input, num_groups, weight, bias, eps, - /*cudnn_enabled=*/false); + torch::Tensor output = torch::group_norm( + input, + num_groups, + weight, + bias, + eps, + /*cudnn_enabled=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_weight = CopyToDevice(weight, device); torch::Tensor lazy_bias = CopyToDevice(bias, device); - torch::Tensor lazy_output = - torch::group_norm(lazy_input, num_groups, lazy_weight, lazy_bias, eps, - /*cudnn_enabled=*/false); + torch::Tensor lazy_output = torch::group_norm( + lazy_input, + num_groups, + lazy_weight, + lazy_bias, + eps, + /*cudnn_enabled=*/false); AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5); }); } @@ -2040,25 +2129,31 @@ TEST_F(LazyOpsTest, TestGroupNorm) { TEST_F(LazyOpsTest, TestGroupNormBackward) { int num_channels = 6; - torch::Tensor input = - torch::rand({2, num_channels, 5, 5}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - torch::rand({num_channels}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - torch::rand({num_channels}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {2, num_channels, 5, 5}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); double eps = 1e-05; for (bool undef_weight : {true, false}) { for (int num_groups : {3, 6, 1}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::group_norm( - /*input=*/inputs[0], num_groups, inputs[1], inputs[2], + /*input=*/inputs[0], + num_groups, + inputs[1], + inputs[2], /*eps=*/eps, /*cudnn_enabled=*/false); }; @@ -2066,8 +2161,10 @@ TEST_F(LazyOpsTest, TestGroupNormBackward) { ForEachDevice([&](const torch::Device& device) { TestBackward( {input, undef_weight ? undef : weight, undef_weight ? undef : bias}, - device, testfn, - /*rtol=*/1e-3, /*atol=*/1e-3, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-3, /*derivative_level=*/2); }); } @@ -2077,26 +2174,33 @@ TEST_F(LazyOpsTest, TestGroupNormBackward) { TEST_F(LazyOpsTest, TestInstanceNorm) { int batch = 5; int num_channels = 20; - torch::Tensor input = - torch::rand({batch, num_channels, 10, 10}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor bias = - torch::rand({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_mean = - torch::zeros({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_var = - torch::ones({num_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, num_channels, 10, 10}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor bias = torch::rand( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_mean = torch::zeros( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_var = torch::ones( + {num_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double momentum = 0.1; double eps = 1e-05; torch::Tensor output = torch::instance_norm( - input, weight, bias, running_mean, running_var, - /*use_input_stats=*/true, momentum, eps, /*cudnn_enabled=*/false); + input, + weight, + bias, + running_mean, + running_var, + /*use_input_stats=*/true, + momentum, + eps, + /*cudnn_enabled=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_weight = CopyToDevice(weight, device); @@ -2104,16 +2208,23 @@ TEST_F(LazyOpsTest, TestInstanceNorm) { torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device); torch::Tensor lazy_running_var = CopyToDevice(running_var, device); torch::Tensor lazy_output = torch::instance_norm( - lazy_input, lazy_weight, lazy_bias, lazy_running_mean, lazy_running_var, - /*use_input_stats=*/true, momentum, eps, /*cudnn_enabled=*/false); + lazy_input, + lazy_weight, + lazy_bias, + lazy_running_mean, + lazy_running_var, + /*use_input_stats=*/true, + momentum, + eps, + /*cudnn_enabled=*/false); AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5); }); } TEST_F(LazyOpsTest, TestLayerNorm) { - torch::Tensor input = - torch::rand({20, 10, 10, 10}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {20, 10, 10, 10}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double eps = 1e-05; torch::Tensor undef; for (bool undef_weight : {true, false}) { @@ -2125,10 +2236,13 @@ TEST_F(LazyOpsTest, TestLayerNorm) { torch::Tensor bias = torch::rand( normalized_shape, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor output = torch::layer_norm(input, normalized_shape, - undef_weight ? undef : weight, - undef_weight ? undef : bias, eps, - /*cudnn_enabled=*/false); + torch::Tensor output = torch::layer_norm( + input, + normalized_shape, + undef_weight ? undef : weight, + undef_weight ? undef : bias, + eps, + /*cudnn_enabled=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_weight = @@ -2136,7 +2250,11 @@ TEST_F(LazyOpsTest, TestLayerNorm) { torch::Tensor lazy_bias = undef_weight ? undef : CopyToDevice(bias, device); torch::Tensor lazy_output = torch::layer_norm( - lazy_input, normalized_shape, lazy_weight, lazy_bias, eps, + lazy_input, + normalized_shape, + lazy_weight, + lazy_bias, + eps, /*cudnn_enabled=*/false); AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5); }); @@ -2145,10 +2263,11 @@ TEST_F(LazyOpsTest, TestLayerNorm) { } TEST_F(LazyOpsTest, TestLayerNormBackward) { - torch::Tensor input = - torch::rand({2, 3, 3, 3}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {2, 3, 3, 3}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); double eps = 1e-05; for (bool undef_weight : {true, false}) { for (int64_t normalized_size : {2, 3}) { @@ -2156,24 +2275,32 @@ TEST_F(LazyOpsTest, TestLayerNormBackward) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::layer_norm( - /*input=*/inputs[0], normalized_shape, inputs[1], inputs[2], + /*input=*/inputs[0], + normalized_shape, + inputs[1], + inputs[2], /*eps=*/eps, /*cudnn_enabled=*/false); }; - torch::Tensor weight = - torch::rand(normalized_shape, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - torch::rand(normalized_shape, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor weight = torch::rand( + normalized_shape, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = torch::rand( + normalized_shape, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor undef; ForEachDevice([&](const torch::Device& device) { TestBackward( {input, undef_weight ? undef : weight, undef_weight ? undef : bias}, - device, testfn, - /*rtol=*/1e-3, /*atol=*/1e-4, /*derivative_level=*/2); + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-4, + /*derivative_level=*/2); }); } } @@ -2295,7 +2422,13 @@ TEST_F(LazyOpsTest, TestTripletMarginLoss) { torch::Tensor lazy_positive = CopyToDevice(positive, device); torch::Tensor lazy_negative = CopyToDevice(negative, device); torch::Tensor lazy_output = torch::triplet_margin_loss( - lazy_anchor, lazy_positive, lazy_negative, margin, p, eps, swap, + lazy_anchor, + lazy_positive, + lazy_negative, + margin, + p, + eps, + swap, reduction); AllClose(output, lazy_output); }); @@ -2308,18 +2441,19 @@ TEST_F(LazyOpsTest, TestTripletMarginLoss) { TEST_F(LazyOpsTest, TestBinaryCrossEntropy) { int batch = 10; int classes = 5; - torch::Tensor input = - torch::rand({batch, classes}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor target = - torch::rand({batch, classes}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({batch, classes}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor target = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor undef; for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { for (bool undef_weight : {false, true}) { ForEachDevice([&](const torch::Device& device) { @@ -2364,12 +2498,12 @@ TEST_F(LazyOpsTest, TestMarginRankingLoss) { TEST_F(LazyOpsTest, TestBCEWithLogits) { int batch = 10; int classes = 5; - torch::Tensor input = - torch::rand({batch, classes}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor target = - torch::rand({batch, classes}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor target = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor weight = torch::rand( {classes}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor pos_weight = torch::rand( @@ -2381,8 +2515,11 @@ TEST_F(LazyOpsTest, TestBCEWithLogits) { for (bool undef_pos_weight : {false, true}) { ForEachDevice([&](const torch::Device& device) { torch::Tensor output = torch::binary_cross_entropy_with_logits( - input, target, undef_weight ? undef : weight, - undef_pos_weight ? undef : pos_weight, reduction); + input, + target, + undef_weight ? undef : weight, + undef_pos_weight ? undef : pos_weight, + reduction); torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_target = CopyToDevice(target, device); torch::Tensor lazy_weight = @@ -2505,7 +2642,8 @@ TEST_F(LazyOpsTest, TestCumSumCast) { torch::Tensor result = torch::cumsum(input, dim, torch::kDouble); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_result = torch::cumsum(lazy_input, dim, torch::kDouble); + torch::Tensor lazy_result = + torch::cumsum(lazy_input, dim, torch::kDouble); AllClose(result, lazy_result); }); } @@ -2513,7 +2651,8 @@ TEST_F(LazyOpsTest, TestCumSumCast) { TEST_F(LazyOpsTest, TestCumSumLong) { torch::Tensor input = torch::randint( - 1000, {4, 3, 4}, + 1000, + {4, 3, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -2556,15 +2695,17 @@ TEST_F(LazyOpsTest, TestCumProd) { TEST_F(LazyOpsTest, TestCumProdCast) { torch::Tensor input = torch::mul( - torch::rand({4, 3, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())), + torch::rand( + {4, 3, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())), 10); int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor result = torch::cumprod(input, dim, torch::kDouble); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_result = torch::cumprod(lazy_input, dim, torch::kDouble); + torch::Tensor lazy_result = + torch::cumprod(lazy_input, dim, torch::kDouble); AllClose(result, lazy_result); }); } @@ -2586,8 +2727,8 @@ TEST_F(LazyOpsTest, TestCumProdLong) { TEST_F(LazyOpsTest, TestCumProdCastLong) { torch::Tensor input = - torch::rand({2, 3}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 7; int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { @@ -2606,7 +2747,8 @@ TEST_F(LazyOpsTest, TestArgMin) { torch::Tensor b = torch::argmin(a, c10::nullopt, /*keepdim=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); - torch::Tensor lazy_b = torch::argmin(lazy_a, c10::nullopt, /*keepdim=*/false); + torch::Tensor lazy_b = + torch::argmin(lazy_a, c10::nullopt, /*keepdim=*/false); AllEqual(b, lazy_b); }); } @@ -2667,7 +2809,8 @@ TEST_F(LazyOpsTest, TestArgMax) { torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); - torch::Tensor lazy_b = torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false); + torch::Tensor lazy_b = + torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false); AllEqual(b, lazy_b); }); } @@ -2704,7 +2847,8 @@ TEST_F(LazyOpsTest, TestArgMaxSameValue) { torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); - torch::Tensor lazy_b = torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false); + torch::Tensor lazy_b = + torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false); AllEqual(b, lazy_b); }); } @@ -2791,8 +2935,8 @@ TEST_F(LazyOpsTest, TestAcos) { TEST_F(LazyOpsTest, TestAcosh) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100; torch::Tensor b = torch::acosh(a); ForEachDevice([&](const torch::Device& device) { @@ -2804,8 +2948,8 @@ TEST_F(LazyOpsTest, TestAcosh) { TEST_F(LazyOpsTest, TestAcoshInPlace) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -3270,9 +3414,10 @@ TEST_F(LazyOpsTest, TestOnesLikeOptions) { } TEST_F(LazyOpsTest, TestFull) { - torch::Tensor a = - torch::full({2, 2}, 3.1165, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor a = torch::full( + {2, 2}, + 3.1165, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = torch::full( {2, 2}, 3.1165, torch::TensorOptions(torch::kFloat).device(device)); @@ -3299,22 +3444,27 @@ TEST_F(LazyOpsTest, TestFullLikeOptions) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor lazy_b = torch::full_like( - lazy_a, 3.1165, + lazy_a, + 3.1165, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); AllClose(a, lazy_a); }); } TEST_F(LazyOpsTest, TestARange) { - for (auto& ranges : std::vector>{{0.0, 100.0, 0.5}, - {0.0, -100.0, -0.5}}) { + for (auto& ranges : std::vector>{ + {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}}) { torch::Tensor a = torch::arange( - ranges[0], ranges[1], ranges[2], + ranges[0], + ranges[1], + ranges[2], torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { - torch::Tensor lazy_a = - torch::arange(ranges[0], ranges[1], ranges[2], - torch::TensorOptions(torch::kFloat).device(device)); + torch::Tensor lazy_a = torch::arange( + ranges[0], + ranges[1], + ranges[2], + torch::TensorOptions(torch::kFloat).device(device)); AllClose(a, lazy_a); }); } @@ -3323,8 +3473,8 @@ TEST_F(LazyOpsTest, TestARange) { TEST_F(LazyOpsTest, TestARangeOut) { torch::Tensor a = torch::randn( {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - for (auto& ranges : std::vector>{{0.0, 100.0, 0.5}, - {0.0, -100.0, -0.5}}) { + for (auto& ranges : std::vector>{ + {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}}) { torch::Tensor b = torch::arange_out(a, ranges[0], ranges[1], ranges[2]); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -3351,11 +3501,13 @@ TEST_F(LazyOpsTest, TestBartlettWindow) { for (bool periodic : {false, true}) { ForEachDevice([&](const torch::Device& device) { torch::Tensor output = torch::bartlett_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_output = torch::bartlett_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(device)); AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-7); }); @@ -3367,10 +3519,12 @@ TEST_F(LazyOpsTest, TestBlackmanWindow) { for (bool periodic : {false, true}) { ForEachDevice([&](const torch::Device& device) { torch::Tensor output = torch::blackman_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_output = torch::blackman_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(device)); AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-7); }); @@ -3384,10 +3538,16 @@ TEST_F(LazyOpsTest, TestHammingWindow) { for (bool periodic : {false, true}) { ForEachDevice([&](const torch::Device& device) { torch::Tensor output = torch::hamming_window( - window_length, periodic, alpha, beta, + window_length, + periodic, + alpha, + beta, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_output = torch::hamming_window( - window_length, periodic, alpha, beta, + window_length, + periodic, + alpha, + beta, torch::TensorOptions(torch::kFloat).device(device)); AllClose(output, lazy_output); }); @@ -3399,10 +3559,12 @@ TEST_F(LazyOpsTest, TestHannWindow) { for (bool periodic : {false, true}) { ForEachDevice([&](const torch::Device& device) { torch::Tensor output = torch::hann_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_output = torch::hann_window( - window_length, periodic, + window_length, + periodic, torch::TensorOptions(torch::kFloat).device(device)); AllClose(output, lazy_output); }); @@ -3429,10 +3591,16 @@ TEST_F(LazyOpsTest, TestLogSigmoidForward) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); auto lazy_tuple = torch::log_sigmoid_forward(lazy_a); - AllClose(std::get<0>(tuple), std::get<0>(lazy_tuple), - /*rtol=*/1e-3, /*atol=*/1e-5); - AllClose(std::get<1>(tuple), std::get<1>(lazy_tuple), - /*rtol=*/1e-3, /*atol=*/1e-5); + AllClose( + std::get<0>(tuple), + std::get<0>(lazy_tuple), + /*rtol=*/1e-3, + /*atol=*/1e-5); + AllClose( + std::get<1>(tuple), + std::get<1>(lazy_tuple), + /*rtol=*/1e-3, + /*atol=*/1e-5); }); } @@ -3531,12 +3699,12 @@ TEST_F(LazyOpsTest, TestMatmul_2x2) { } TEST_F(LazyOpsTest, TestMatmulBcast) { - torch::Tensor a = - torch::rand({4, 2, 3, 2, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor b = - torch::rand({2, 1, 4, 3}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor a = torch::rand( + {4, 2, 3, 2, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor b = torch::rand( + {2, 1, 4, 3}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor c = torch::matmul(a, b); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -3712,8 +3880,11 @@ TEST_F(LazyOpsTest, TestLinear) { torch::Tensor lazy_result_with_bias = torch::linear(lazy_input, lazy_weight, lazy_bias); AllClose(result, lazy_result, /*rtol=*/1e-2, /*atol=*/1e-4); - AllClose(result_with_bias, lazy_result_with_bias, /*rtol=*/1e-2, - /*atol=*/1e-4); + AllClose( + result_with_bias, + lazy_result_with_bias, + /*rtol=*/1e-2, + /*atol=*/1e-4); }); } @@ -3744,12 +3915,16 @@ TEST_F(LazyOpsTest, TestEinsumOuter) { } TEST_F(LazyOpsTest, TestEinsumOuterBackward) { - torch::Tensor a = torch::rand({5}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor b = torch::rand({5}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor a = torch::rand( + {5}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor b = torch::rand( + {5}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); std::string equation = "i,j->ij"; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::einsum(equation, inputs); @@ -3817,9 +3992,9 @@ TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchDiagonal) { } TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchPermute) { - torch::Tensor input = - torch::rand({2, 3, 4, 5}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 3, 4, 5}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); std::string equation = "...ij->...ji"; torch::Tensor result = torch::einsum(equation, {input}); ForEachDevice([&](const torch::Device& device) { @@ -3849,18 +4024,18 @@ TEST_F(LazyOpsTest, TestBilinear) { int in1_features = 4; int in2_features = 6; int out_features = 8; - torch::Tensor input1 = - torch::rand({batch_size, in1_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor input2 = - torch::rand({batch_size, in2_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({out_features, in1_features, in2_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor bias = - torch::rand({out_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input1 = torch::rand( + {batch_size, in1_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input2 = torch::rand( + {batch_size, in2_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {out_features, in1_features, in2_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor bias = torch::rand( + {out_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input1 = CopyToDevice(input1, device); torch::Tensor lazy_input2 = CopyToDevice(input2, device); @@ -3880,9 +4055,9 @@ TEST_F(LazyOpsTest, TestUpsampleNearest2D) { int uh = 8; int uw = 8; int chans = 2; - torch::Tensor input = - torch::rand({batch_size, chans, h, w}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor result = torch::upsample_nearest2d(input, {uh, uw}); @@ -3902,11 +4077,14 @@ TEST_F(LazyOpsTest, TestUpsampleNearest2DBackward) { return torch::upsample_nearest2d(inputs[0], {uh, uw}); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({batch_size, chans, h, w}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -3917,9 +4095,9 @@ TEST_F(LazyOpsTest, TestUpsampleNearest2DWithScale) { int chans = 2; double scale_h = 2.5; double scale_w = 3.4; - torch::Tensor input = - torch::rand({batch_size, chans, h, w}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor result = torch::upsample_nearest2d( @@ -3938,15 +4116,18 @@ TEST_F(LazyOpsTest, TestUpsampleNearest2DBackwardWithScale) { double scale_h = 2.5; double scale_w = 3.4; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::upsample_nearest2d(inputs[0], c10::nullopt, - at::ArrayRef{scale_h, scale_w}); + return torch::upsample_nearest2d( + inputs[0], c10::nullopt, at::ArrayRef{scale_h, scale_w}); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({batch_size, chans, h, w}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -3985,11 +4166,14 @@ TEST_F(LazyOpsTest, TestUpsampleBilinear2DBackward) { return torch::upsample_bilinear2d(inputs[0], {uh, uw}, align_corners); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({batch_size, chans, h, w}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {batch_size, chans, h, w}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -4052,9 +4236,9 @@ TEST_F(LazyOpsTest, TestAddCDivWithBroadcast) { } TEST_F(LazyOpsTest, TestSize) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int rank = input.dim(); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -4073,9 +4257,12 @@ TEST_F(LazyOpsTest, TestSelect) { return torch::select(inputs[0], dim, 0); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand(input_sizes, torch::TensorOptions(torch::kFloat) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + input_sizes, + torch::TensorOptions(torch::kFloat).requires_grad(true))}, + device, + testfn); }); }; } @@ -4165,16 +4352,16 @@ TEST_F(LazyOpsTest, TestRandperm) { torch::Tensor shuffle = torch::randperm( n, torch::TensorOptions(torch::kLong).device(torch::kLazy)); torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU); - std::vector shuffle_data(shuffle_cpu.data_ptr(), - shuffle_cpu.data_ptr() + n); - EXPECT_TRUE(shuffle_data.size() == n && - torch::lazy::IsPermutation(shuffle_data)); + std::vector shuffle_data( + shuffle_cpu.data_ptr(), shuffle_cpu.data_ptr() + n); + EXPECT_TRUE( + shuffle_data.size() == n && torch::lazy::IsPermutation(shuffle_data)); } TEST_F(LazyOpsTest, TestSlice) { - torch::Tensor a = - torch::rand({32, 24, 16}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor a = torch::rand( + {32, 24, 16}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor b = torch::slice(a, 1, 0, 16, 1); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); @@ -4203,13 +4390,17 @@ TEST_F(LazyOpsTest, TestTakeBackward) { }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({4, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)), + {torch::rand( + {4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)), torch::randint( - 16, {5}, + 16, + {5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()))}, - device, testfn); + device, + testfn); }); } @@ -4500,16 +4691,19 @@ TEST_F(LazyOpsTest, TestScatterAddInPlace) { TEST_F(LazyOpsTest, TestIndexSelect) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { torch::Tensor b = torch::empty( {2}, torch::TensorOptions(index_scalar_type).device(DefaultDevice())); @@ -4521,8 +4715,10 @@ TEST_F(LazyOpsTest, TestIndexSelect) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor lazy_b = CopyToDevice(b, device); - torch::Tensor lazy_c0 = torch::index_select(lazy_a, 0 + offset, lazy_b); - torch::Tensor lazy_c1 = torch::index_select(lazy_a, 1 + offset, lazy_b); + torch::Tensor lazy_c0 = + torch::index_select(lazy_a, 0 + offset, lazy_b); + torch::Tensor lazy_c1 = + torch::index_select(lazy_a, 1 + offset, lazy_b); AllEqual(c0, lazy_c0); AllEqual(c1, lazy_c1); }); @@ -4533,16 +4729,19 @@ TEST_F(LazyOpsTest, TestIndexSelect) { TEST_F(LazyOpsTest, TestIndexSelectRank0) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor a = isFloatingType(scalar_type) + ? torch::rand( + {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor b = torch::scalar_tensor( 2, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor c0 = torch::index_select(a, 0, b); @@ -4638,9 +4837,10 @@ TEST_F(LazyOpsTest, TestEyeWide) { int lines = 3; int cols = 5; ForEachDevice([&](const torch::Device& device) { - torch::Tensor out = - torch::eye(lines, cols, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor out = torch::eye( + lines, + cols, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_out = torch::eye( lines, cols, torch::TensorOptions(torch::kFloat).device(device)); AllClose(out, lazy_out); @@ -4651,9 +4851,10 @@ TEST_F(LazyOpsTest, TestEyeNarrow) { int lines = 5; int cols = 3; ForEachDevice([&](const torch::Device& device) { - torch::Tensor out = - torch::eye(lines, cols, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor out = torch::eye( + lines, + cols, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor lazy_out = torch::eye( lines, cols, torch::TensorOptions(torch::kFloat).device(device)); AllClose(out, lazy_out); @@ -4669,7 +4870,8 @@ TEST_F(LazyOpsTest, TestBroadcastTensors) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor lazy_b = CopyToDevice(b, device); - std::vector lazy_c = torch::broadcast_tensors({lazy_a, lazy_b}); + std::vector lazy_c = + torch::broadcast_tensors({lazy_a, lazy_b}); ASSERT_EQ(c.size(), lazy_c.size()); for (size_t i = 0; i < c.size(); ++i) { AllClose(c[i], lazy_c[i]); @@ -4679,18 +4881,24 @@ TEST_F(LazyOpsTest, TestBroadcastTensors) { TEST_F(LazyOpsTest, TestOneIndex) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices}); ForEachDevice([&](const torch::Device& device) { @@ -4704,18 +4912,24 @@ TEST_F(LazyOpsTest, TestOneIndex) { TEST_F(LazyOpsTest, TestOneIndexTransfer) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices}); ForEachDevice([&](const torch::Device& device) { @@ -4793,22 +5007,30 @@ TEST_F(LazyOpsTest, TestMaskedScatter) { TEST_F(LazyOpsTest, TestMultiIndexHeadNull) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices_null; torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices_null, indices_0, indices_1}); @@ -4825,22 +5047,30 @@ TEST_F(LazyOpsTest, TestMultiIndexHeadNull) { TEST_F(LazyOpsTest, TestMultiIndexMiddleNull) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_null; torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices_0, indices_null, indices_1}); @@ -4857,22 +5087,30 @@ TEST_F(LazyOpsTest, TestMultiIndexMiddleNull) { TEST_F(LazyOpsTest, TestMultiIndexTailNull) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_null; torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices_0, indices_1, indices_null}); @@ -4889,21 +5127,29 @@ TEST_F(LazyOpsTest, TestMultiIndexTailNull) { TEST_F(LazyOpsTest, TestMultiIndexMiddleBroadcast) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 1, 3}, + -3, + 3, + {2, 1, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices_0, indices_1}); ForEachDevice([&](const torch::Device& device) { @@ -4919,21 +5165,29 @@ TEST_F(LazyOpsTest, TestMultiIndexMiddleBroadcast) { TEST_F(LazyOpsTest, TestMultiIndexTailBroadcast) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 1, 3}, + -3, + 3, + {2, 1, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 1}, + -3, + 3, + {2, 1}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices_0, indices_1}); ForEachDevice([&](const torch::Device& device) { @@ -4949,18 +5203,23 @@ TEST_F(LazyOpsTest, TestMultiIndexTailBroadcast) { TEST_F(LazyOpsTest, TestMaskIndex) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {2, 2}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {2, 2}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {2, 2}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {2, 2}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices = torch::randint( - 0, 2, {2, 2}, + 0, + 2, + {2, 2}, torch::TensorOptions(torch::kBool).device(DefaultDevice())); torch::Tensor result = torch::index(params, {indices}); ForEachDevice([&](const torch::Device& device) { @@ -4974,27 +5233,33 @@ TEST_F(LazyOpsTest, TestMaskIndex) { TEST_F(LazyOpsTest, TestOneIndexPut) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - torch::Tensor values = - isFloatingType(scalar_type) - ? torch::rand( - {3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor values = isFloatingType(scalar_type) + ? torch::rand( + {3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { if (accumulate && IsCuda()) { GTEST_SKIP(); @@ -5005,8 +5270,8 @@ TEST_F(LazyOpsTest, TestOneIndexPut) { torch::Tensor lazy_params = CopyToDevice(params, device); torch::Tensor lazy_indices = CopyToDevice(indices, device); torch::Tensor lazy_values = CopyToDevice(values, device); - torch::Tensor lazy_result = - torch::index_put(lazy_params, {lazy_indices}, lazy_values, accumulate); + torch::Tensor lazy_result = torch::index_put( + lazy_params, {lazy_indices}, lazy_values, accumulate); AllEqual(result, lazy_result); }); } @@ -5015,34 +5280,40 @@ TEST_F(LazyOpsTest, TestOneIndexPut) { TEST_F(LazyOpsTest, TestOneIndexPutInPlace) { torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor values = - torch::ones({3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor values = torch::ones( + {3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { if (accumulate && IsCuda()) { GTEST_SKIP(); } ForEachDevice([&](const torch::Device& device) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_params = CopyToDevice(params.clone(), device); torch::Tensor result = torch::index_put_(params, {indices}, values, accumulate); torch::Tensor lazy_indices = CopyToDevice(indices, device); torch::Tensor lazy_values = CopyToDevice(values, device); - torch::Tensor lazy_result = torch::index_put_(lazy_params, {lazy_indices}, - lazy_values, accumulate); + torch::Tensor lazy_result = torch::index_put_( + lazy_params, {lazy_indices}, lazy_values, accumulate); AllEqual(result, lazy_result); AllEqual(params, lazy_params); }); @@ -5052,22 +5323,28 @@ TEST_F(LazyOpsTest, TestOneIndexPutInPlace) { TEST_F(LazyOpsTest, TestOneIndexPutTransfer) { torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); - torch::Tensor values = - torch::ones({3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor values = torch::ones( + {3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { if (accumulate && IsCuda()) { GTEST_SKIP(); @@ -5087,22 +5364,30 @@ TEST_F(LazyOpsTest, TestOneIndexPutTransfer) { TEST_F(LazyOpsTest, TestMultiIndexPut) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5117,7 +5402,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPut) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {lazy_indices_0, lazy_indices_1}, lazy_values, accumulate); + lazy_params, + {lazy_indices_0, lazy_indices_1}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5126,23 +5414,31 @@ TEST_F(LazyOpsTest, TestMultiIndexPut) { TEST_F(LazyOpsTest, TestMultiIndexPutHeadNull) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_null; torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5157,8 +5453,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPutHeadNull) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {indices_null, lazy_indices_0, lazy_indices_1}, - lazy_values, accumulate); + lazy_params, + {indices_null, lazy_indices_0, lazy_indices_1}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5167,23 +5465,31 @@ TEST_F(LazyOpsTest, TestMultiIndexPutHeadNull) { TEST_F(LazyOpsTest, TestMultiIndexPutMiddleNull) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_null; torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5198,8 +5504,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPutMiddleNull) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {lazy_indices_0, indices_null, lazy_indices_1}, - lazy_values, accumulate); + lazy_params, + {lazy_indices_0, indices_null, lazy_indices_1}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5208,23 +5516,31 @@ TEST_F(LazyOpsTest, TestMultiIndexPutMiddleNull) { TEST_F(LazyOpsTest, TestMultiIndexPutTailNull) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_null; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 3, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 3, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5239,8 +5555,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPutTailNull) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {lazy_indices_0, lazy_indices_1, indices_null}, - lazy_values, accumulate); + lazy_params, + {lazy_indices_0, lazy_indices_1, indices_null}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5249,22 +5567,30 @@ TEST_F(LazyOpsTest, TestMultiIndexPutTailNull) { TEST_F(LazyOpsTest, TestMultiIndexPutMiddleBroadcast) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 1, 3}, + -3, + 3, + {2, 1, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5279,7 +5605,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPutMiddleBroadcast) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {lazy_indices_0, lazy_indices_1}, lazy_values, accumulate); + lazy_params, + {lazy_indices_0, lazy_indices_1}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5288,22 +5617,30 @@ TEST_F(LazyOpsTest, TestMultiIndexPutMiddleBroadcast) { TEST_F(LazyOpsTest, TestMultiIndexPutTailBroadcast) { torch::Tensor indices_0 = torch::randint( - -3, 3, {2, 1, 3}, + -3, + 3, + {2, 1, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor indices_1 = torch::randint( - -3, 3, {2, 1}, + -3, + 3, + {2, 1}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5318,7 +5655,10 @@ TEST_F(LazyOpsTest, TestMultiIndexPutTailBroadcast) { torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::index_put( - lazy_params, {lazy_indices_0, lazy_indices_1}, lazy_values, accumulate); + lazy_params, + {lazy_indices_0, lazy_indices_1}, + lazy_values, + accumulate); AllEqual(result, lazy_result); }); } @@ -5327,20 +5667,23 @@ TEST_F(LazyOpsTest, TestMultiIndexPutTailBroadcast) { TEST_F(LazyOpsTest, TestMaskIndexPut) { torch::Tensor indices = - torch::tensor({0, 1}, - torch::TensorOptions(torch::kByte).device(DefaultDevice())) + torch::tensor( + {0, 1}, torch::TensorOptions(torch::kByte).device(DefaultDevice())) .to(torch::kBool); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {2, 2}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {2, 2}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {2, 2}, torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {2, 2}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor values = torch::ones( {2}, torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { @@ -5350,8 +5693,8 @@ TEST_F(LazyOpsTest, TestMaskIndexPut) { torch::Tensor lazy_params = CopyToDevice(params, device); torch::Tensor lazy_indices = CopyToDevice(indices, device); torch::Tensor lazy_values = CopyToDevice(values, device); - torch::Tensor lazy_result = - torch::index_put(lazy_params, {lazy_indices}, lazy_values, accumulate); + torch::Tensor lazy_result = torch::index_put( + lazy_params, {lazy_indices}, lazy_values, accumulate); AllEqual(result, lazy_result); }); } @@ -5360,34 +5703,44 @@ TEST_F(LazyOpsTest, TestMaskIndexPut) { TEST_F(LazyOpsTest, TestIndexPutImpl) { torch::Tensor indices = torch::randint( - -3, 3, {2, 4, 3}, + -3, + 3, + {2, 4, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor values = - torch::ones({3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor values = torch::ones( + {3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); for (bool accumulate : {false, true}) { if (accumulate && IsCuda()) { GTEST_SKIP(); } ForEachDevice([&](const torch::Device& device) { - torch::Tensor params = - isFloatingType(scalar_type) - ? torch::rand( - {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor params = isFloatingType(scalar_type) + ? torch::rand( + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_params = CopyToDevice(params.clone(), device); torch::Tensor result = torch::_index_put_impl_( params, {indices}, values, accumulate, /*unsafe=*/true); torch::Tensor lazy_indices = CopyToDevice(indices, device); torch::Tensor lazy_values = CopyToDevice(values, device); torch::Tensor lazy_result = torch::_index_put_impl_( - lazy_params, {lazy_indices}, lazy_values, accumulate, /*unsafe=*/true); + lazy_params, + {lazy_indices}, + lazy_values, + accumulate, + /*unsafe=*/true); AllEqual(result, lazy_result); AllEqual(params, lazy_params); }); @@ -5400,16 +5753,20 @@ TEST_F(LazyOpsTest, TestIndexFillWithScalar) { {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Scalar value = 42; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor result = torch::index_fill(base, dim, index, value); @@ -5430,22 +5787,27 @@ TEST_F(LazyOpsTest, TestIndexFillWithScalarInPlace) { torch::Scalar value = 42; int rank = 3; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { for (int dim = -rank; dim < rank; ++dim) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_base = CopyToDevice(base.clone(), device); torch::Tensor result = base.index_fill_(dim, index, value); torch::Tensor lazy_index = CopyToDevice(index, device); - torch::Tensor lazy_result = lazy_base.index_fill_(dim, lazy_index, value); + torch::Tensor lazy_result = + lazy_base.index_fill_(dim, lazy_index, value); AllEqual(result, lazy_result); AllEqual(base, lazy_base); }); @@ -5457,16 +5819,20 @@ TEST_F(LazyOpsTest, TestIndexFillWithTensor) { torch::Tensor index = torch::tensor( {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor value = torch::scalar_tensor( 42, torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); @@ -5488,21 +5854,25 @@ TEST_F(LazyOpsTest, TestIndexFillWithTensorInPlace) { torch::Tensor index = torch::tensor( {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { torch::Tensor value = torch::scalar_tensor( 42, torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = 3; for (int dim = -rank; dim < rank; ++dim) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_base = CopyToDevice(base.clone(), device); torch::Tensor result = base.index_fill_(dim, index, value); torch::Tensor lazy_index = CopyToDevice(index, device); @@ -5520,16 +5890,20 @@ TEST_F(LazyOpsTest, TestIndexFillRank0) { torch::Tensor index = torch::scalar_tensor( 2, torch::TensorOptions(torch::kLong).device(DefaultDevice())); for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {3, 4, 5}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {3, 4, 5}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor value = torch::scalar_tensor( 42, torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); @@ -5550,34 +5924,40 @@ TEST_F(LazyOpsTest, TestIndexFillRank0) { TEST_F(LazyOpsTest, TestIndexAdd) { int index_size = 10; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); for (int dim = -rank; dim < rank; ++dim) { for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { torch::Tensor index = torch::randint( - 0, base.size(dim), {index_size}, + 0, + base.size(dim), + {index_size}, torch::TensorOptions(index_scalar_type).device(DefaultDevice())); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); + std::vector value_sizes( + base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; value_sizes[canonical_dim] = index_size; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor result = torch::index_add(base, dim, index, value); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_base = CopyToDevice(base, device); @@ -5596,33 +5976,39 @@ TEST_F(LazyOpsTest, TestIndexAddInPlace) { int index_size = 10; int rank = 3; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { for (int dim = -rank; dim < rank; ++dim) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {5, 3, 7}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor index = torch::randint( - 0, base.size(dim), {index_size}, + 0, + base.size(dim), + {index_size}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); + std::vector value_sizes( + base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; value_sizes[canonical_dim] = index_size; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_base = CopyToDevice(base.clone(), device); torch::Tensor result = base.index_add_(dim, index, value); torch::Tensor lazy_index = CopyToDevice(index, device); @@ -5638,33 +6024,39 @@ TEST_F(LazyOpsTest, TestIndexAddInPlace) { TEST_F(LazyOpsTest, TestIndexAddRank0) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor index = torch::randint( - 0, base.size(dim), at::IntArrayRef{}, + 0, + base.size(dim), + at::IntArrayRef{}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); + std::vector value_sizes( + base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor result = torch::index_add(base, dim, index, value); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_base = CopyToDevice(base, device); @@ -5680,29 +6072,33 @@ TEST_F(LazyOpsTest, TestIndexAddRank0) { TEST_F(LazyOpsTest, TestIndexCopy) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor index = torch::randperm( base.size(dim), torch::TensorOptions(torch::kLong).device(DefaultDevice())); - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - base.sizes(), - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, base.sizes(), - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + base.sizes(), + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + base.sizes(), + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor result = torch::index_copy(base, dim, index, value); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_base = CopyToDevice(base, device); @@ -5723,33 +6119,39 @@ TEST_F(LazyOpsTest, TestIndexCopyInPlace) { int index_size = 10; int rank = 3; for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { for (int dim = -rank; dim < rank; ++dim) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, {5, 3, 7}, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor index = torch::randint( - 0, base.size(dim), {index_size}, + 0, + base.size(dim), + {index_size}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); + std::vector value_sizes( + base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; value_sizes[canonical_dim] = index_size; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type) - .device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor lazy_base = CopyToDevice(base.clone(), device); torch::Tensor result = base.index_copy_(dim, index, value); torch::Tensor lazy_index = CopyToDevice(index, device); @@ -5765,33 +6167,39 @@ TEST_F(LazyOpsTest, TestIndexCopyInPlace) { TEST_F(LazyOpsTest, TestIndexCopyRank0) { for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand( - {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, {5, 3, 7}, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor base = isFloatingType(scalar_type) + ? torch::rand( + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + {5, 3, 7}, + torch::TensorOptions(scalar_type).device(DefaultDevice())); int rank = base.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor index = torch::randint( - 0, base.size(dim), at::IntArrayRef{}, + 0, + base.size(dim), + at::IntArrayRef{}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); + std::vector value_sizes( + base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand( - value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())) - : torch::randint( - 100, value_sizes, - torch::TensorOptions(scalar_type).device(DefaultDevice())); + torch::Tensor value = isFloatingType(scalar_type) + ? torch::rand( + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())) + : torch::randint( + 100, + value_sizes, + torch::TensorOptions(scalar_type).device(DefaultDevice())); torch::Tensor result = torch::index_copy(base, dim, index, value); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_base = CopyToDevice(base, device); @@ -5806,9 +6214,9 @@ TEST_F(LazyOpsTest, TestIndexCopyRank0) { } TEST_F(LazyOpsTest, TestRelu) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::relu(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -5818,9 +6226,9 @@ TEST_F(LazyOpsTest, TestRelu) { } TEST_F(LazyOpsTest, TestReluInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor output = torch::relu_(input); @@ -5869,10 +6277,14 @@ TEST_F(LazyOpsTest, TestHardSigmoidBackward) { return torch::hardsigmoid(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::randn({10}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::randn( + {10}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -5911,9 +6323,9 @@ TEST_F(LazyOpsTest, TestHardtanhInPlace) { } TEST_F(LazyOpsTest, TestLeakyRelu) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double negative_slope = 0.01; torch::Tensor output = torch::leaky_relu(input, negative_slope); ForEachDevice([&](const torch::Device& device) { @@ -5924,9 +6336,9 @@ TEST_F(LazyOpsTest, TestLeakyRelu) { } TEST_F(LazyOpsTest, TestLeakyReluInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double negative_slope = 0.01; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6162,8 +6574,8 @@ TEST_F(LazyOpsTest, TestPowIntExponent) { TEST_F(LazyOpsTest, TestFmodScalar) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Scalar divisor = 2.0; torch::Tensor b = torch::fmod(a, divisor); @@ -6192,12 +6604,12 @@ TEST_F(LazyOpsTest, TestFmodScalarInPlace) { TEST_F(LazyOpsTest, TestFmodTensor) { torch::Tensor a = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 100.0; torch::Tensor b = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 10.0; torch::Tensor c = torch::fmod(a, b); ForEachDevice([&](const torch::Device& device) { @@ -6210,8 +6622,8 @@ TEST_F(LazyOpsTest, TestFmodTensor) { TEST_F(LazyOpsTest, TestFmodTensorInPlace) { torch::Tensor b = - torch::rand({2, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * + torch::rand( + {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) * 10.0; ForEachDevice([&](const torch::Device& device) { torch::Tensor a = @@ -6341,9 +6753,9 @@ TEST_F(LazyOpsTest, TestWhereBroadcast) { } TEST_F(LazyOpsTest, TestThreshold) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); float threshold = 0.4; float value = 20; torch::Tensor output = torch::threshold(input, threshold, value); @@ -6358,22 +6770,27 @@ TEST_F(LazyOpsTest, TestThresholdBackward) { float threshold = 0.4; float value = 20; - auto testFunction = [&](const std::vector& inputs) -> torch::Tensor { + auto testFunction = + [&](const std::vector& inputs) -> torch::Tensor { return torch::threshold(inputs[0], threshold, value); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testFunction); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testFunction); }); } TEST_F(LazyOpsTest, TestThresholdInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = input.clone(); float threshold = 0.4; float value = 20; @@ -6386,24 +6803,25 @@ TEST_F(LazyOpsTest, TestThresholdInPlace) { } TEST_F(LazyOpsTest, TestElu) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar alpha = 0.5; torch::Scalar scale = 2.5; torch::Scalar input_scale = 1.5; torch::Tensor output = torch::elu(input, alpha, scale, input_scale); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = torch::elu(lazy_input, alpha, scale, input_scale); + torch::Tensor lazy_output = + torch::elu(lazy_input, alpha, scale, input_scale); AllClose(output, lazy_output); }); } TEST_F(LazyOpsTest, TestEluInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar alpha = 0.5; torch::Scalar scale = 2.5; torch::Scalar input_scale = 1.5; @@ -6418,9 +6836,9 @@ TEST_F(LazyOpsTest, TestEluInPlace) { } TEST_F(LazyOpsTest, TestSelu) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::selu(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6430,9 +6848,9 @@ TEST_F(LazyOpsTest, TestSelu) { } TEST_F(LazyOpsTest, TestSeluInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor output = torch::selu_(input); @@ -6443,9 +6861,9 @@ TEST_F(LazyOpsTest, TestSeluInPlace) { } TEST_F(LazyOpsTest, TestCelu) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar alpha = 2.5; torch::Tensor output = torch::celu(input, alpha); ForEachDevice([&](const torch::Device& device) { @@ -6456,9 +6874,9 @@ TEST_F(LazyOpsTest, TestCelu) { } TEST_F(LazyOpsTest, TestCeluInPlace) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Scalar alpha = 2.5; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6484,12 +6902,12 @@ TEST_F(LazyOpsTest, TestAddMatMul) { int in_channels = 32; int out_channels = 320; int labels = 50; - torch::Tensor input = - torch::rand({in_channels, out_channels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({out_channels, labels}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {in_channels, out_channels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {out_channels, labels}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor bias = torch::rand( {labels}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test beta != 1. through the CPU interop. @@ -6510,17 +6928,25 @@ TEST_F(LazyOpsTest, TestEmbedding) { torch::Tensor a = torch::rand( {32, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor i = torch::randint( - 0, 31, {3, 4}, + 0, + 31, + {3, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - torch::Tensor b = - torch::embedding(a, i, /*padding_idx=*/0, /*scale_grad_by_freq=*/false, - /*sparse=*/false); + torch::Tensor b = torch::embedding( + a, + i, + /*padding_idx=*/0, + /*scale_grad_by_freq=*/false, + /*sparse=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_a = CopyToDevice(a, device); torch::Tensor lazy_i = CopyToDevice(i, device); - torch::Tensor lazy_b = torch::embedding(lazy_a, lazy_i, /*padding_idx=*/0, - /*scale_grad_by_freq=*/false, - /*sparse=*/false); + torch::Tensor lazy_b = torch::embedding( + lazy_a, + lazy_i, + /*padding_idx=*/0, + /*scale_grad_by_freq=*/false, + /*sparse=*/false); AllClose(b, lazy_b); }); } @@ -6528,7 +6954,9 @@ TEST_F(LazyOpsTest, TestEmbedding) { TEST_F(LazyOpsTest, TestOneHot) { int num_classes = 5; torch::Tensor input = torch::randint( - 0, num_classes, {10}, + 0, + num_classes, + {10}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor output = torch::one_hot(input, num_classes); ForEachDevice([&](const torch::Device& device) { @@ -6563,9 +6991,9 @@ TEST_F(LazyOpsTest, TestTransposeInPlace) { } TEST_F(LazyOpsTest, TestReshape) { - torch::Tensor input = - torch::rand({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::reshape(input, {-1, 320}); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6604,9 +7032,9 @@ TEST_F(LazyOpsTest, TestViewResize) { } TEST_F(LazyOpsTest, TestView) { - torch::Tensor input = - torch::rand({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = input.view({-1, 320}); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6616,9 +7044,9 @@ TEST_F(LazyOpsTest, TestView) { } TEST_F(LazyOpsTest, TestViewMod) { - torch::Tensor input = - torch::zeros({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::zeros( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor one = torch::tensor( 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = input.view({-1, 320}); @@ -6639,9 +7067,9 @@ TEST_F(LazyOpsTest, TestViewMod) { } TEST_F(LazyOpsTest, TestViewModComplex) { - torch::Tensor input = - torch::zeros({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::zeros( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor one = torch::tensor( 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output1 = input.view({-1, 320}); @@ -6664,9 +7092,9 @@ TEST_F(LazyOpsTest, TestViewModComplex) { } TEST_F(LazyOpsTest, TestViewOfViewMod) { - torch::Tensor input = - torch::zeros({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::zeros( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor one = torch::tensor( 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output1 = input.view({-1, 320}); @@ -6710,9 +7138,9 @@ TEST_F(LazyOpsTest, TestViewSqueezeAddInPlace) { } TEST_F(LazyOpsTest, TestUnsafeView) { - torch::Tensor input = - torch::rand({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::_unsafe_view(input, {-1, 320}); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -6722,9 +7150,9 @@ TEST_F(LazyOpsTest, TestUnsafeView) { } TEST_F(LazyOpsTest, TestNarrow) { - torch::Tensor a = - torch::rand({8, 10, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor a = torch::rand( + {8, 10, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int64_t dim : {1, -3}) { for (int64_t start : {2, -8}) { torch::Tensor b = a.narrow(dim, start, 6); @@ -6884,9 +7312,9 @@ TEST_F(LazyOpsTest, TestNarrowCopy) { } TEST_F(LazyOpsTest, TestViewAs) { - torch::Tensor input = - torch::rand({32, 20, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor empty = torch::empty({32, 320}); torch::Tensor output = input.view_as(empty); ForEachDevice([&](const torch::Device& device) { @@ -6898,9 +7326,9 @@ TEST_F(LazyOpsTest, TestViewAs) { } TEST_F(LazyOpsTest, TestLogSoftmax) { - torch::Tensor input = - torch::rand({5, 3, 4, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {5, 3, 4, 2}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); @@ -6913,9 +7341,9 @@ TEST_F(LazyOpsTest, TestLogSoftmax) { } TEST_F(LazyOpsTest, TestLogSoftmaxCast) { - torch::Tensor input = - torch::rand({5, 3, 4, 2}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {5, 3, 4, 2}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); @@ -6929,9 +7357,9 @@ TEST_F(LazyOpsTest, TestLogSoftmaxCast) { } TEST_F(LazyOpsTest, TestLogSoftmaxWrapper) { - torch::Tensor input = - torch::rand({10, 2, 6, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {10, 2, 6, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); @@ -6946,9 +7374,9 @@ TEST_F(LazyOpsTest, TestLogSoftmaxWrapper) { } TEST_F(LazyOpsTest, TestSoftmax) { - torch::Tensor input = - torch::rand({10, 2, 6, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {10, 2, 6, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); @@ -6961,24 +7389,25 @@ TEST_F(LazyOpsTest, TestSoftmax) { } TEST_F(LazyOpsTest, TestSoftmaxCast) { - torch::Tensor input = - torch::rand({10, 2, 6, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {10, 2, 6, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor output = torch::softmax(input, dim, torch::kDouble); - torch::Tensor lazy_output = torch::softmax(lazy_input, dim, torch::kDouble); + torch::Tensor lazy_output = + torch::softmax(lazy_input, dim, torch::kDouble); AllClose(output, lazy_output, /*rtol=*/1e-3); } }); } TEST_F(LazyOpsTest, TestSoftmaxWrapper) { - torch::Tensor input = - torch::rand({10, 2, 6, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {10, 2, 6, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); int rank = input.dim(); @@ -6993,9 +7422,9 @@ TEST_F(LazyOpsTest, TestSoftmaxWrapper) { } TEST_F(LazyOpsTest, TestSoftplus) { - torch::Tensor input = - torch::rand({2, 1, 4, 6}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::softplus(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7014,20 +7443,22 @@ TEST_F(LazyOpsTest, TestMaxPool1D) { for (bool ceil_mode : {false, true}) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { - torch::Tensor output = - torch::max_pool1d(input, /*kernel_size=*/{kernel_size}, - /*stride=*/{stride}, - /*padding=*/{padding}, /*dilation=*/{dilation}, - /*ceil_mode=*/ceil_mode); + torch::Tensor output = torch::max_pool1d( + input, + /*kernel_size=*/{kernel_size}, + /*stride=*/{stride}, + /*padding=*/{padding}, + /*dilation=*/{dilation}, + /*ceil_mode=*/ceil_mode); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::max_pool1d(lazy_input, - /*kernel_size=*/{kernel_size}, - /*stride=*/{stride}, - /*padding=*/{padding}, - /*dilation=*/{dilation}, - /*ceil_mode=*/ceil_mode); + torch::Tensor lazy_output = torch::max_pool1d( + lazy_input, + /*kernel_size=*/{kernel_size}, + /*stride=*/{stride}, + /*padding=*/{padding}, + /*dilation=*/{dilation}, + /*ceil_mode=*/ceil_mode); AllClose(output, lazy_output); }); } @@ -7037,9 +7468,9 @@ TEST_F(LazyOpsTest, TestMaxPool1D) { } TEST_F(LazyOpsTest, TestMaxPool2D) { - torch::Tensor input = - torch::rand({1, 4, 14, 14}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 4, 14, 14}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7048,19 +7479,21 @@ TEST_F(LazyOpsTest, TestMaxPool2D) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, /*ceil_mode=*/ceil_mode); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::max_pool2d(lazy_input, - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}, - /*ceil_mode=*/ceil_mode); + torch::Tensor lazy_output = torch::max_pool2d( + lazy_input, + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + /*ceil_mode=*/ceil_mode); AllClose(output, lazy_output); }); } @@ -7070,9 +7503,9 @@ TEST_F(LazyOpsTest, TestMaxPool2D) { } TEST_F(LazyOpsTest, TestMaxPool2DWithIndices) { - torch::Tensor input = - torch::rand({1, 4, 14, 14}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 4, 14, 14}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7081,9 +7514,11 @@ TEST_F(LazyOpsTest, TestMaxPool2DWithIndices) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { auto outputs = torch::max_pool2d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, /*ceil_mode=*/ceil_mode); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7104,9 +7539,9 @@ TEST_F(LazyOpsTest, TestMaxPool2DWithIndices) { } TEST_F(LazyOpsTest, TestMaxPool2DNonSquare) { - torch::Tensor input = - torch::rand({1, 4, 14, 14}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 4, 14, 14}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 4; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7115,7 +7550,8 @@ TEST_F(LazyOpsTest, TestMaxPool2DNonSquare) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size + 1}, + input, + /*kernel_size=*/{kernel_size, kernel_size + 1}, /*stride=*/{stride, stride + 1}, /*padding=*/{padding, padding + 1}, /*dilation=*/{dilation, dilation}, @@ -7138,9 +7574,9 @@ TEST_F(LazyOpsTest, TestMaxPool2DNonSquare) { } TEST_F(LazyOpsTest, TestMaxPool3D) { - torch::Tensor input = - torch::rand({1, 1, 8, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 8, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7149,7 +7585,8 @@ TEST_F(LazyOpsTest, TestMaxPool3D) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -7172,9 +7609,9 @@ TEST_F(LazyOpsTest, TestMaxPool3D) { } TEST_F(LazyOpsTest, TestMaxPool3DWithIndices) { - torch::Tensor input = - torch::rand({1, 1, 8, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 8, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7183,7 +7620,8 @@ TEST_F(LazyOpsTest, TestMaxPool3DWithIndices) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { auto outputs = torch::max_pool3d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -7208,9 +7646,9 @@ TEST_F(LazyOpsTest, TestMaxPool3DWithIndices) { } TEST_F(LazyOpsTest, TestMaxPool3DIncompleteAttributes) { - torch::Tensor input = - torch::rand({1, 1, 8, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 8, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7219,7 +7657,8 @@ TEST_F(LazyOpsTest, TestMaxPool3DIncompleteAttributes) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{}, /*padding=*/{padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -7242,9 +7681,9 @@ TEST_F(LazyOpsTest, TestMaxPool3DIncompleteAttributes) { } TEST_F(LazyOpsTest, TestMaxPool3DNonSquare) { - torch::Tensor input = - torch::rand({1, 1, 8, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 8, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 4; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7287,19 +7726,21 @@ TEST_F(LazyOpsTest, TestMaxPool2DNoBatch) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, /*ceil_mode=*/ceil_mode); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::max_pool2d(lazy_input, - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}, - /*ceil_mode=*/ceil_mode); + torch::Tensor lazy_output = torch::max_pool2d( + lazy_input, + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + /*ceil_mode=*/ceil_mode); AllClose(output, lazy_output); }); } @@ -7309,9 +7750,9 @@ TEST_F(LazyOpsTest, TestMaxPool2DNoBatch) { } TEST_F(LazyOpsTest, TestMaxPool3DNoBatch) { - torch::Tensor input = - torch::rand({1, 8, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 8, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7320,7 +7761,8 @@ TEST_F(LazyOpsTest, TestMaxPool3DNoBatch) { // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { torch::Tensor output = torch::max_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -7351,20 +7793,22 @@ TEST_F(LazyOpsTest, TestAvgPool1D) { for (bool count_include_pad : {true, false}) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { - torch::Tensor output = - torch::avg_pool1d(input, /*kernel_size=*/{kernel_size}, - /*stride=*/{stride}, - /*padding=*/{padding}, /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + torch::Tensor output = torch::avg_pool1d( + input, + /*kernel_size=*/{kernel_size}, + /*stride=*/{stride}, + /*padding=*/{padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::avg_pool1d(lazy_input, - /*kernel_size=*/{kernel_size}, - /*stride=*/{stride}, - /*padding=*/{padding}, - /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + torch::Tensor lazy_output = torch::avg_pool1d( + lazy_input, + /*kernel_size=*/{kernel_size}, + /*stride=*/{stride}, + /*padding=*/{padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); AllClose(output, lazy_output); }); } @@ -7374,9 +7818,9 @@ TEST_F(LazyOpsTest, TestAvgPool1D) { } TEST_F(LazyOpsTest, TestAvgPool2D) { - torch::Tensor input = - torch::rand({2, 1, 14, 14}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 14, 14}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7384,20 +7828,22 @@ TEST_F(LazyOpsTest, TestAvgPool2D) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { // torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::avg_pool2d(lazy_input, - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + torch::Tensor lazy_output = torch::avg_pool2d( + lazy_input, + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); AllClose(output, lazy_output.to(torch::kCPU)); }); } @@ -7407,9 +7853,9 @@ TEST_F(LazyOpsTest, TestAvgPool2D) { } TEST_F(LazyOpsTest, TestAvgPool2DNonSquare) { - torch::Tensor input = - torch::rand({2, 1, 14, 14}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 14, 14}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 4; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7417,9 +7863,11 @@ TEST_F(LazyOpsTest, TestAvgPool2DNonSquare) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size + 1}, + input, + /*kernel_size=*/{kernel_size, kernel_size + 1}, /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding + 1}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7439,9 +7887,9 @@ TEST_F(LazyOpsTest, TestAvgPool2DNonSquare) { } TEST_F(LazyOpsTest, TestAvgPool3D) { - torch::Tensor input = - torch::rand({1, 1, 7, 7, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7449,9 +7897,11 @@ TEST_F(LazyOpsTest, TestAvgPool3D) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding, padding}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7471,9 +7921,9 @@ TEST_F(LazyOpsTest, TestAvgPool3D) { } TEST_F(LazyOpsTest, TestAvgPool3DIncompleteAttributes) { - torch::Tensor input = - torch::rand({1, 1, 7, 7, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7481,9 +7931,11 @@ TEST_F(LazyOpsTest, TestAvgPool3DIncompleteAttributes) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{}, - /*padding=*/{padding, padding, padding}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding, padding}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7503,9 +7955,9 @@ TEST_F(LazyOpsTest, TestAvgPool3DIncompleteAttributes) { } TEST_F(LazyOpsTest, TestAvgPool3DNonSquare) { - torch::Tensor input = - torch::rand({1, 1, 7, 7, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 4; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7546,19 +7998,21 @@ TEST_F(LazyOpsTest, TestAvgPool2DNoBatch) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool2d( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::avg_pool2d(lazy_input, - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + torch::Tensor lazy_output = torch::avg_pool2d( + lazy_input, + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); AllClose(output, lazy_output); }); } @@ -7568,9 +8022,9 @@ TEST_F(LazyOpsTest, TestAvgPool2DNoBatch) { } TEST_F(LazyOpsTest, TestAvgPool3DNoBatch) { - torch::Tensor input = - torch::rand({1, 7, 7, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -7578,9 +8032,11 @@ TEST_F(LazyOpsTest, TestAvgPool3DNoBatch) { // Test ceil_mode=true through the CPU interop. for (bool ceil_mode : {false, true}) { torch::Tensor output = torch::avg_pool3d( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*ceil_mode=*/ceil_mode, + /*padding=*/{padding, padding, padding}, + /*ceil_mode=*/ceil_mode, /*count_include_pad=*/count_include_pad); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7600,9 +8056,9 @@ TEST_F(LazyOpsTest, TestAvgPool3DNoBatch) { } TEST_F(LazyOpsTest, TestAdaptiveAvgPool2D) { - torch::Tensor input = - torch::rand({4, 1, 28, 28}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {4, 1, 28, 28}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int64_t output_size : {7, 4}) { torch::Tensor output = torch::adaptive_avg_pool2d(input, {output_size, output_size}); @@ -7616,9 +8072,9 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool2D) { } TEST_F(LazyOpsTest, TestAdaptiveAvgPool3D) { - torch::Tensor input = - torch::rand({9, 4, 56, 28, 28}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {9, 4, 56, 28, 28}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int64_t output_size : {7, 4}) { torch::Tensor output = torch::adaptive_avg_pool3d( input, {output_size, output_size, output_size}); @@ -7632,9 +8088,9 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool3D) { } TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatch) { - torch::Tensor input = - torch::rand({3, 56, 28, 28}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {3, 56, 28, 28}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int64_t output_size : {7, 4}) { torch::Tensor output = torch::adaptive_avg_pool3d( input, {output_size, output_size, output_size}); @@ -7664,9 +8120,9 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatch) { TEST_F(LazyOpsTest, TestMaxUnpool2D) { int kernel_size = 2; - torch::Tensor input = - torch::rand({2, 2, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 2, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { // Test ceil_mode=true through the CPU interop. @@ -7676,9 +8132,11 @@ TEST_F(LazyOpsTest, TestMaxUnpool2D) { torch::Tensor output; torch::Tensor indices; std::tie(output, indices) = torch::max_pool2d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, /*ceil_mode=*/ceil_mode); std::vector output_size({input.size(2), input.size(3)}); @@ -7700,9 +8158,9 @@ TEST_F(LazyOpsTest, TestMaxUnpool2D) { TEST_F(LazyOpsTest, TestMaxUnpool3D) { int kernel_size = 2; - torch::Tensor input = - torch::rand({1, 1, 4, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 4, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { // Test ceil_mode=true through the CPU interop. @@ -7712,7 +8170,8 @@ TEST_F(LazyOpsTest, TestMaxUnpool3D) { torch::Tensor output; torch::Tensor indices; std::tie(output, indices) = torch::max_pool3d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -7721,16 +8180,21 @@ TEST_F(LazyOpsTest, TestMaxUnpool3D) { std::vector output_size( {input.size(2), input.size(3), input.size(4)}); at::Tensor utensor = torch::max_unpool3d( - output, indices, output_size, /*stride=*/{stride, stride, stride}, + output, + indices, + output_size, + /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_output = CopyToDevice(output, device); torch::Tensor lazy_indices = CopyToDevice(indices, device); - at::Tensor lazy_utensor = - torch::max_unpool3d(lazy_output, lazy_indices, output_size, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}); + at::Tensor lazy_utensor = torch::max_unpool3d( + lazy_output, + lazy_indices, + output_size, + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}); AllClose(utensor, lazy_utensor); }); } @@ -7740,7 +8204,6 @@ TEST_F(LazyOpsTest, TestMaxUnpool3D) { } TEST_F(LazyOpsTest, TestNllLoss) { - // TODO(whc) debug divide-by-zero failure under ASAN GTEST_SKIP(); @@ -7750,11 +8213,13 @@ TEST_F(LazyOpsTest, TestNllLoss) { for (auto dtype : {torch::kFloat}) { for (int ignore_index : {-1, 0, 1, 5}) { for (bool def_weight : {false, true}) { - torch::Tensor input = - torch::rand({batch, classes}, - torch::TensorOptions(dtype).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, classes}, + torch::TensorOptions(dtype).device(DefaultDevice())); torch::Tensor target = torch::randint( - std::min(ignore_index, 0), classes, {batch}, + std::min(ignore_index, 0), + classes, + {batch}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor weight; if (def_weight) { @@ -7762,13 +8227,15 @@ TEST_F(LazyOpsTest, TestNllLoss) { {classes}, torch::TensorOptions(dtype).device(DefaultDevice())); } for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { - torch::Tensor output = - torch::nll_loss(/*self=*/input, /*target=*/target, - /*weight=*/weight, - /*reduction=*/reduction, - /*ignore_index=*/ignore_index); + torch::Tensor output = torch::nll_loss( + /*self=*/input, + /*target=*/target, + /*weight=*/weight, + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7776,9 +8243,11 @@ TEST_F(LazyOpsTest, TestNllLoss) { torch::Tensor lazy_weight = def_weight ? CopyToDevice(weight, device) : torch::Tensor(); torch::Tensor lazy_output = torch::nll_loss( - /*self=*/lazy_input, /*target=*/lazy_target, + /*self=*/lazy_input, + /*target=*/lazy_target, /*weight=*/lazy_weight, - /*reduction=*/reduction, /*ignore_index=*/ignore_index); + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); AllClose(output, lazy_output); }); } @@ -7796,11 +8265,13 @@ TEST_F(LazyOpsTest, TestNllLoss2d) { for (auto dtype : {torch::kFloat}) { for (int ignore_index : {-1, 0, 1, 5}) { for (bool def_weight : {false, true}) { - torch::Tensor input = - torch::rand({batch, classes, height, width}, - torch::TensorOptions(dtype).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, classes, height, width}, + torch::TensorOptions(dtype).device(DefaultDevice())); torch::Tensor target = torch::randint( - std::min(ignore_index, 0), classes, {batch, height, width}, + std::min(ignore_index, 0), + classes, + {batch, height, width}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor weight; if (def_weight) { @@ -7808,13 +8279,15 @@ TEST_F(LazyOpsTest, TestNllLoss2d) { {classes}, torch::TensorOptions(dtype).device(DefaultDevice())); } for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { - torch::Tensor output = - torch::nll_loss2d(/*self=*/input, /*target=*/target, - /*weight=*/weight, - /*reduction=*/reduction, - /*ignore_index=*/ignore_index); + torch::Tensor output = torch::nll_loss2d( + /*self=*/input, + /*target=*/target, + /*weight=*/weight, + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7822,9 +8295,11 @@ TEST_F(LazyOpsTest, TestNllLoss2d) { torch::Tensor lazy_weight = def_weight ? CopyToDevice(weight, device) : torch::Tensor(); torch::Tensor lazy_output = torch::nll_loss2d( - /*self=*/lazy_input, /*target=*/lazy_target, + /*self=*/lazy_input, + /*target=*/lazy_target, /*weight=*/lazy_weight, - /*reduction=*/reduction, /*ignore_index=*/ignore_index); + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); AllClose(output, lazy_output); }); } @@ -7839,7 +8314,8 @@ TEST_F(LazyOpsTest, TestSmoothL1Loss) { torch::Tensor target = torch::randn( {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { for (double beta : {0.25, 1.}) { torch::Tensor output = @@ -7861,7 +8337,8 @@ TEST_F(LazyOpsTest, TestL1Loss) { torch::Tensor target = torch::randn( {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { torch::Tensor output = torch::l1_loss(input, target, reduction); ForEachDevice([&](const torch::Device& device) { @@ -7876,19 +8353,25 @@ TEST_F(LazyOpsTest, TestL1Loss) { TEST_F(LazyOpsTest, TestL1LossBackward) { for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::l1_loss(inputs[0], inputs[1], reduction); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {2, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice()))}, + device, + testfn); }); } } @@ -7899,7 +8382,8 @@ TEST_F(LazyOpsTest, TestMseLoss) { torch::Tensor target = torch::randn( {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { torch::Tensor output = torch::mse_loss(input, target, reduction); ForEachDevice([&](const torch::Device& device) { @@ -7914,50 +8398,60 @@ TEST_F(LazyOpsTest, TestMseLoss) { TEST_F(LazyOpsTest, TestMseLossBackward) { for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::mse_loss(inputs[0], inputs[1], reduction); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {2, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice()))}, + device, + testfn); }); } } TEST_F(LazyOpsTest, TestBatchNorm1D) { int num_features = 3; - torch::Tensor input = - torch::rand({2, num_features, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor bias = - torch::rand({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_mean = - torch::zeros({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_var = - torch::ones({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, num_features, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor bias = torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_mean = torch::zeros( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_var = torch::ones( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double momentum = 0.1; double eps = 0.5; torch::Tensor undef; for (bool training : {true, false}) { for (bool undef_weight_bias : {false, true}) { torch::Tensor output = torch::batch_norm( - /*input=*/input, /*weight=*/undef_weight_bias ? undef : weight, + /*input=*/input, + /*weight=*/undef_weight_bias ? undef : weight, /*bias=*/undef_weight_bias ? undef : bias, - /*running_mean=*/running_mean, /*running_var=*/running_var, - /*training=*/training, /*momentum=*/momentum, /*eps=*/eps, + /*running_mean=*/running_mean, + /*running_var=*/running_var, + /*training=*/training, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -7968,9 +8462,14 @@ TEST_F(LazyOpsTest, TestBatchNorm1D) { torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device); torch::Tensor lazy_running_var = CopyToDevice(running_var, device); torch::Tensor lazy_output = torch::batch_norm( - /*input=*/lazy_input, /*weight=*/lazy_weight, /*bias=*/lazy_bias, - /*running_mean=*/lazy_running_mean, /*running_var=*/lazy_running_var, - /*training=*/training, /*momentum=*/momentum, /*eps=*/eps, + /*input=*/lazy_input, + /*weight=*/lazy_weight, + /*bias=*/lazy_bias, + /*running_mean=*/lazy_running_mean, + /*running_var=*/lazy_running_var, + /*training=*/training, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5); }); @@ -7980,31 +8479,35 @@ TEST_F(LazyOpsTest, TestBatchNorm1D) { TEST_F(LazyOpsTest, TestBatchNorm2D) { int num_features = 3; - torch::Tensor input = - torch::rand({2, num_features, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor weight = - torch::rand({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor bias = - torch::rand({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_mean = - torch::zeros({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - torch::Tensor running_var = - torch::ones({num_features}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, num_features, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor weight = torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor bias = torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_mean = torch::zeros( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor running_var = torch::ones( + {num_features}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); double momentum = 0.1; double eps = 0.5; torch::Tensor undef; for (bool training : {true, false}) { for (bool undef_weight_bias : {false, true}) { torch::Tensor output = torch::batch_norm( - /*input=*/input, /*weight=*/undef_weight_bias ? undef : weight, + /*input=*/input, + /*weight=*/undef_weight_bias ? undef : weight, /*bias=*/undef_weight_bias ? undef : bias, - /*running_mean=*/running_mean, /*running_var=*/running_var, - /*training=*/training, /*momentum=*/momentum, /*eps=*/eps, + /*running_mean=*/running_mean, + /*running_var=*/running_var, + /*training=*/training, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8015,9 +8518,14 @@ TEST_F(LazyOpsTest, TestBatchNorm2D) { torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device); torch::Tensor lazy_running_var = CopyToDevice(running_var, device); torch::Tensor lazy_output = torch::batch_norm( - /*input=*/lazy_input, /*weight=*/lazy_weight, /*bias=*/lazy_bias, - /*running_mean=*/lazy_running_mean, /*running_var=*/lazy_running_var, - /*training=*/training, /*momentum=*/momentum, /*eps=*/eps, + /*input=*/lazy_input, + /*weight=*/lazy_weight, + /*bias=*/lazy_bias, + /*running_mean=*/lazy_running_mean, + /*running_var=*/lazy_running_var, + /*training=*/training, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5); }); @@ -8046,9 +8554,9 @@ TEST_F(LazyOpsTest, TestContiguous) { } TEST_F(LazyOpsTest, TestSqueezeAll) { - torch::Tensor input = - torch::rand({2, 1, 3, 1}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 3, 1}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::squeeze(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8075,9 +8583,9 @@ TEST_F(LazyOpsTest, TestSqueezeAllInPlace) { } TEST_F(LazyOpsTest, TestSqueezeOne) { - torch::Tensor input = - torch::rand({2, 1, 3, 1}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 1, 3, 1}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { torch::Tensor output = torch::squeeze(input, dim); @@ -8152,7 +8660,8 @@ TEST_F(LazyOpsTest, TestMaskedFill) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_mask = CopyToDevice(mask, device); - torch::Tensor lazy_result = torch::masked_fill(lazy_input, lazy_mask, value); + torch::Tensor lazy_result = + torch::masked_fill(lazy_input, lazy_mask, value); AllClose(result, lazy_result); }); } @@ -8174,9 +8683,9 @@ TEST_F(LazyOpsTest, TestMaskedFillInPlace) { } TEST_F(LazyOpsTest, TestMaskedFillBroadcast) { - torch::Tensor input = - torch::rand({2, 5, 4, 3}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 5, 4, 3}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor mask = torch::randint( 0, 2, {4, 1}, torch::TensorOptions(torch::kBool).device(DefaultDevice())); torch::Scalar value(42); @@ -8184,7 +8693,8 @@ TEST_F(LazyOpsTest, TestMaskedFillBroadcast) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_mask = CopyToDevice(mask, device); - torch::Tensor lazy_result = torch::masked_fill(lazy_input, lazy_mask, value); + torch::Tensor lazy_result = + torch::masked_fill(lazy_input, lazy_mask, value); AllClose(result, lazy_result); }); } @@ -8225,8 +8735,10 @@ TEST_F(LazyOpsTest, TestPermute) { for (std::vector dims_permutation : dims_permutations) { for (bool negative_dims : {false, true}) { if (negative_dims) { - std::for_each(dims_permutation.begin(), dims_permutation.end(), - [rank](int64_t& dim) { dim -= rank; }); + std::for_each( + dims_permutation.begin(), + dims_permutation.end(), + [rank](int64_t& dim) { dim -= rank; }); } torch::Tensor output = input.permute(dims_permutation); ForEachDevice([&](const torch::Device& device) { @@ -8246,8 +8758,10 @@ TEST_F(LazyOpsTest, TestPermuteMod) { for (std::vector dims_permutation : dims_permutations) { for (bool negative_dims : {false, true}) { if (negative_dims) { - std::for_each(dims_permutation.begin(), dims_permutation.end(), - [rank](int64_t& dim) { dim -= rank; }); + std::for_each( + dims_permutation.begin(), + dims_permutation.end(), + [rank](int64_t& dim) { dim -= rank; }); } torch::Tensor input = torch::zeros( input_sizes, @@ -8281,8 +8795,8 @@ TEST_F(LazyOpsTest, TestFlip) { for (std::vector flip_dims : dim_powerset) { for (bool negative_dims : {false, true}) { if (negative_dims) { - std::for_each(flip_dims.begin(), flip_dims.end(), - [](int64_t& dim) { dim -= 3; }); + std::for_each( + flip_dims.begin(), flip_dims.end(), [](int64_t& dim) { dim -= 3; }); } torch::Tensor output = torch::flip(input, flip_dims); ForEachDevice([&](const torch::Device& device) { @@ -8295,22 +8809,23 @@ TEST_F(LazyOpsTest, TestFlip) { } TEST_F(LazyOpsTest, TestPixelShuffle) { - torch::Tensor input = - torch::rand({5, 18, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {5, 18, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int upscale_factor = 3; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor output = torch::pixel_shuffle(input, upscale_factor); - torch::Tensor lazy_output = torch::pixel_shuffle(lazy_input, upscale_factor); + torch::Tensor lazy_output = + torch::pixel_shuffle(lazy_input, upscale_factor); AllClose(output, lazy_output); }); } TEST_F(LazyOpsTest, TestSumToSize) { - torch::Tensor input = - torch::rand({4, 6, 3, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {4, 6, 3, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); std::vector out_size = {4, 1, 1, 7}; ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8410,9 +8925,9 @@ TEST_F(LazyOpsTest, TestSplitEmpty) { } TEST_F(LazyOpsTest, TestSplitWithSizes) { - torch::Tensor input = - torch::rand({15, 15, 15}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {15, 15, 15}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); int rank = input.dim(); for (int dim = -rank; dim < rank; ++dim) { std::vector outputs = @@ -8466,9 +8981,9 @@ TEST_F(LazyOpsTest, TestCrossExplicitDim) { } TEST_F(LazyOpsTest, TestCrossZeroDim) { - torch::Tensor input = - torch::rand({0, 1, 3, 0}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {0, 1, 3, 0}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor result = torch::cross(input, input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8479,9 +8994,9 @@ TEST_F(LazyOpsTest, TestCrossZeroDim) { TEST_F(LazyOpsTest, TestTriu) { int size = 5; - torch::Tensor input = - torch::rand({size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::triu(input, diagonal); @@ -8495,9 +9010,9 @@ TEST_F(LazyOpsTest, TestTriu) { TEST_F(LazyOpsTest, TestTriuNonSquare) { int size = 5; - torch::Tensor input = - torch::rand({size, size + 1}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size + 1}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::triu(input, diagonal); @@ -8512,9 +9027,9 @@ TEST_F(LazyOpsTest, TestTriuNonSquare) { TEST_F(LazyOpsTest, TestTriuBatch) { int size = 5; int batch_size = 3; - torch::Tensor input = - torch::rand({batch_size, size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch_size, size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::triu(input, diagonal); @@ -8528,9 +9043,9 @@ TEST_F(LazyOpsTest, TestTriuBatch) { TEST_F(LazyOpsTest, TestTril) { int size = 5; - torch::Tensor input = - torch::rand({size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::tril(input, diagonal); @@ -8544,9 +9059,9 @@ TEST_F(LazyOpsTest, TestTril) { TEST_F(LazyOpsTest, TestTrilNonSquare) { int size = 5; - torch::Tensor input = - torch::rand({size, size + 1}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size + 1}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::tril(input, diagonal); @@ -8561,9 +9076,9 @@ TEST_F(LazyOpsTest, TestTrilNonSquare) { TEST_F(LazyOpsTest, TestTrilBatch) { int size = 5; int batch_size = 3; - torch::Tensor input = - torch::rand({batch_size, size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch_size, size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::tril(input, diagonal); @@ -8624,9 +9139,9 @@ TEST_F(LazyOpsTest, TestTrace) { TEST_F(LazyOpsTest, TestTraceWide) { int lines = 3; int cols = 5; - torch::Tensor input = - torch::rand({lines, cols}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {lines, cols}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::trace(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8638,9 +9153,9 @@ TEST_F(LazyOpsTest, TestTraceWide) { TEST_F(LazyOpsTest, TestTraceNarrow) { int lines = 5; int cols = 3; - torch::Tensor input = - torch::rand({lines, cols}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {lines, cols}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor output = torch::trace(input); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); @@ -8666,9 +9181,9 @@ TEST_F(LazyOpsTest, TestDiagRank1) { TEST_F(LazyOpsTest, TestDiagRank2) { int size = 7; - torch::Tensor input = - torch::rand({size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::diag(input, diagonal); @@ -8681,9 +9196,9 @@ TEST_F(LazyOpsTest, TestDiagRank2) { } TEST_F(LazyOpsTest, TestDiagFlat) { - torch::Tensor input = - torch::rand({4, 3, 6, 7}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {4, 3, 6, 7}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int diagonal = -10; diagonal < 10; ++diagonal) { torch::Tensor output = torch::diagflat(input, diagonal); ForEachDevice([&](const torch::Device& device) { @@ -8696,9 +9211,9 @@ TEST_F(LazyOpsTest, TestDiagFlat) { TEST_F(LazyOpsTest, TestDiagonal) { int size = 5; - torch::Tensor input = - torch::rand({size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::diagonal(input, diagonal); @@ -8714,7 +9229,8 @@ TEST_F(LazyOpsTest, TestDiagonalUpdate) { int size = 5; // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { - auto input = torch::rand({size, size}, + auto input = torch::rand( + {size, size}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); auto input_clone = input.clone(); auto output = torch::diagonal(input, diagonal); @@ -8733,9 +9249,9 @@ TEST_F(LazyOpsTest, TestDiagonalUpdate) { TEST_F(LazyOpsTest, TestDiagonalNonSquare) { int size = 5; - torch::Tensor input = - torch::rand({size, size + 1}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {size, size + 1}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = torch::diagonal(input, diagonal); @@ -8752,9 +9268,9 @@ TEST_F(LazyOpsTest, TestDiagonalBatch) { int batch_size = 3; int dim1 = 1; int dim2 = 2; - torch::Tensor input = - torch::rand({batch_size, size, size}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch_size, size, size}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); // Test all diagonals and out of bounds (must be no-op). for (int diagonal = -size; diagonal <= size; ++diagonal) { torch::Tensor output = @@ -8793,21 +9309,25 @@ TEST_F(LazyOpsTest, TestFlatten) { TEST_F(LazyOpsTest, TestLogicalAnd) { for (torch::ScalarType scalar_type1 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor lhs = - isFloatingType(scalar_type1) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) - : torch::randint(0, 100, {3, 4}, - torch::TensorOptions(scalar_type1)); + torch::Tensor lhs = isFloatingType(scalar_type1) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1)) + : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1)); for (torch::ScalarType scalar_type2 : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, + {torch::kFloat, + torch::kByte, + torch::kChar, + torch::kShort, + torch::kInt, torch::kLong}) { - torch::Tensor rhs = - isFloatingType(scalar_type2) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) - : torch::randint(1, 100, {3, 4}, - torch::TensorOptions(scalar_type2)); + torch::Tensor rhs = isFloatingType(scalar_type2) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2)) + : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2)); torch::Tensor result = torch::logical_and(lhs, rhs); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8823,10 +9343,16 @@ TEST_F(LazyOpsTest, TestLogicalAnd) { } TEST_F(LazyOpsTest, TestBitwiseAnd) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Tensor result = lhs.__and__(rhs); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8837,10 +9363,16 @@ TEST_F(LazyOpsTest, TestBitwiseAnd) { } TEST_F(LazyOpsTest, TestBitwiseAndInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); torch::Tensor result = lhs.__iand__(rhs); @@ -8852,8 +9384,11 @@ TEST_F(LazyOpsTest, TestBitwiseAndInPlace) { } TEST_F(LazyOpsTest, TestBitwiseAndScalar) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); torch::Tensor result = lhs.__and__(rhs); ForEachDevice([&](const torch::Device& device) { @@ -8864,8 +9399,11 @@ TEST_F(LazyOpsTest, TestBitwiseAndScalar) { } TEST_F(LazyOpsTest, TestBitwiseAndScalarInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8884,16 +9422,23 @@ TEST_F(LazyOpsTest, TestBitwiseAndPromotion) { ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_view = lazy_input.reshape(-1); - torch::Tensor lazy_result = torch::__and__(lazy_view.gt(0), lazy_view.ne(0)); + torch::Tensor lazy_result = + torch::__and__(lazy_view.gt(0), lazy_view.ne(0)); AllEqual(result, lazy_result); }); } TEST_F(LazyOpsTest, TestBitwiseOr) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Tensor result = lhs.__or__(rhs); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8904,10 +9449,16 @@ TEST_F(LazyOpsTest, TestBitwiseOr) { } TEST_F(LazyOpsTest, TestBitwiseOrInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); torch::Tensor result = lhs.__ior__(rhs); @@ -8919,8 +9470,11 @@ TEST_F(LazyOpsTest, TestBitwiseOrInPlace) { } TEST_F(LazyOpsTest, TestBitwiseOrScalar) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); torch::Tensor result = lhs.__or__(rhs); ForEachDevice([&](const torch::Device& device) { @@ -8931,8 +9485,11 @@ TEST_F(LazyOpsTest, TestBitwiseOrScalar) { } TEST_F(LazyOpsTest, TestBitwiseOrScalarInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8944,10 +9501,16 @@ TEST_F(LazyOpsTest, TestBitwiseOrScalarInPlace) { } TEST_F(LazyOpsTest, TestBitwiseXor) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Tensor result = lhs.__xor__(rhs); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -8958,10 +9521,16 @@ TEST_F(LazyOpsTest, TestBitwiseXor) { } TEST_F(LazyOpsTest, TestBitwiseXorInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); - torch::Tensor rhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); + torch::Tensor rhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); torch::Tensor result = lhs.__ixor__(rhs); @@ -8973,8 +9542,11 @@ TEST_F(LazyOpsTest, TestBitwiseXorInPlace) { } TEST_F(LazyOpsTest, TestBitwiseXorScalar) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); torch::Tensor result = lhs.__xor__(rhs); ForEachDevice([&](const torch::Device& device) { @@ -8985,8 +9557,11 @@ TEST_F(LazyOpsTest, TestBitwiseXorScalar) { } TEST_F(LazyOpsTest, TestBitwiseXorScalarInPlace) { - torch::Tensor lhs = torch::randint(0, std::numeric_limits::max(), - {4, 2}, torch::TensorOptions(torch::kInt)); + torch::Tensor lhs = torch::randint( + 0, + std::numeric_limits::max(), + {4, 2}, + torch::TensorOptions(torch::kInt)); torch::Scalar rhs(123456789); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_lhs = CopyToDevice(lhs, device); @@ -9176,9 +9751,9 @@ TEST_F(LazyOpsTest, TestReflectionPad2dRank3) { } TEST_F(LazyOpsTest, TestReflectionPad2dRank4) { - torch::Tensor input = - torch::rand({2, 2, 3, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 2, 3, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); std::vector pad{2, 2, 2, 2}; torch::Tensor output = torch::reflection_pad2d(input, pad); ForEachDevice([&](const torch::Device& device) { @@ -9194,10 +9769,14 @@ TEST_F(LazyOpsTest, TestReflectionPad2dBackward) { return torch::reflection_pad2d(inputs[0], pad); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({1, 2, 4, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {1, 2, 4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -9231,10 +9810,14 @@ TEST_F(LazyOpsTest, TestReplicationPad1dBackward) { return torch::replication_pad1d(inputs[0], pad); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -9268,10 +9851,14 @@ TEST_F(LazyOpsTest, TestReplicationPad2dBackward) { return torch::replication_pad2d(inputs[0], pad); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 3, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 3, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -9312,14 +9899,18 @@ TEST_F(LazyOpsTest, TestAsStridedWithOffset) { std::vector size = {4, 4, 2}; std::vector stride = {8, 2, 1}; int64_t storage_offset = 4; - torch::Tensor output = - torch::as_strided(input, /*size=*/size, /*stride=*/stride, - /*storage_offset=*/storage_offset); + torch::Tensor output = torch::as_strided( + input, + /*size=*/size, + /*stride=*/stride, + /*storage_offset=*/storage_offset); ForEachDevice([&](const torch::Device& device) { torch::Tensor lazy_input = CopyToDevice(input, device); - torch::Tensor lazy_output = - torch::as_strided(lazy_input, /*size=*/size, /*stride=*/stride, - /*storage_offset=*/storage_offset); + torch::Tensor lazy_output = torch::as_strided( + lazy_input, + /*size=*/size, + /*stride=*/stride, + /*storage_offset=*/storage_offset); AllClose(output, lazy_output); }); } @@ -9360,20 +9951,24 @@ TEST_F(LazyOpsTest, TestAvgPool2DBackward) { for (bool ceil_mode : {false, true}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::avg_pool2d(inputs[0], - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + return torch::avg_pool2d( + inputs[0], + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 1, 7, 7}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 1, 7, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9400,11 +9995,14 @@ TEST_F(LazyOpsTest, TestAvgPool3DBackward) { }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({1, 1, 7, 7, 7}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {1, 1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9421,20 +10019,24 @@ TEST_F(LazyOpsTest, TestAvgPool2DNoBatchBackward) { for (bool ceil_mode : {false, true}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::avg_pool2d(inputs[0], - /*kernel_size=*/{kernel_size, kernel_size}, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*ceil_mode=*/ceil_mode, - /*count_include_pad=*/count_include_pad); + return torch::avg_pool2d( + inputs[0], + /*kernel_size=*/{kernel_size, kernel_size}, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*ceil_mode=*/ceil_mode, + /*count_include_pad=*/count_include_pad); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 7, 7}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 7, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9462,10 +10064,13 @@ TEST_F(LazyOpsTest, TestAvgPool3DNoBatchBackward) { ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 7, 7, 7}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 7, 7, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9485,10 +10090,13 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatchBackward) { }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 56, 28, 28}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 56, 28, 28}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9505,10 +10113,13 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DBackward) { }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({4, 1, 56, 28, 28}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {4, 1, 56, 28, 28}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9521,10 +10132,13 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DBackward) { }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({4, 1, 56, 56}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {4, 1, 56, 56}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9536,9 +10150,12 @@ TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatchBackward) { return torch::adaptive_avg_pool2d(inputs[0], {output_size, output_size}); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({1, 56, 56}, torch::TensorOptions(torch::kFloat) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {1, 56, 56}, + torch::TensorOptions(torch::kFloat).requires_grad(true))}, + device, + testfn); }); } } @@ -9552,36 +10169,45 @@ TEST_F(LazyOpsTest, TestConv2D) { for (bool with_bias : {true, false}) { for (int dilation = 1; dilation <= 3; ++dilation) { for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. + {1, 2, 4}) { // covers normal, grouped, depthwise conv. ForEachDevice([&](const torch::Device& device) { torch::Tensor input = torch::rand( {1, in_channels, 7, 7}, torch::TensorOptions(torch::kDouble).device(DefaultDevice())); torch::Tensor weight = torch::rand( - {out_channels, in_channels / groups, kernel_size, + {out_channels, + in_channels / groups, + kernel_size, kernel_size}, torch::TensorOptions(torch::kDouble).device(DefaultDevice())); - torch::Tensor bias = - with_bias ? torch::rand({out_channels}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice())) - : torch::Tensor(); + torch::Tensor bias = with_bias + ? torch::rand( + {out_channels}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice())) + : torch::Tensor(); torch::Tensor lazy_input = CopyToDevice(input, device); torch::Tensor lazy_weight = CopyToDevice(weight, device); torch::Tensor lazy_bias = with_bias ? CopyToDevice(bias, device) : torch::Tensor(); - torch::Tensor output = - torch::conv2d(input, weight, bias, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}, groups); - torch::Tensor lazy_output = - torch::conv2d(lazy_input, lazy_weight, lazy_bias, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}, groups); + torch::Tensor output = torch::conv2d( + input, + weight, + bias, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + groups); + torch::Tensor lazy_output = torch::conv2d( + lazy_input, + lazy_weight, + lazy_bias, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + groups); AllClose(output, lazy_output); }); } @@ -9600,32 +10226,43 @@ TEST_F(LazyOpsTest, TestConv2DBackward) { for (bool with_bias : {true, false}) { for (int dilation = 1; dilation <= 3; ++dilation) { for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. + {1, 2, 4}) { // covers normal, grouped, depthwise conv. auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::conv2d(inputs[0], inputs[1], inputs[2], - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}, groups); + return torch::conv2d( + inputs[0], + inputs[1], + inputs[2], + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + groups); }; ForEachDevice([&](const torch::Device& device) { - torch::Tensor bias = - with_bias ? torch::rand({out_channels}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice())) - : torch::Tensor(); - TestBackward({torch::rand({1, in_channels, 7, 7}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({out_channels, in_channels / groups, - kernel_size, kernel_size}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)), - bias}, - device, testfn); + torch::Tensor bias = with_bias + ? torch::rand( + {out_channels}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice())) + : torch::Tensor(); + TestBackward( + {torch::rand( + {1, in_channels, 7, 7}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {out_channels, + in_channels / groups, + kernel_size, + kernel_size}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)), + bias}, + device, + testfn); }); } }; @@ -9642,14 +10279,17 @@ TEST_F(LazyOpsTest, TestTransposedConv2DBackward) { for (int padding = 0; padding <= 1; ++padding) { for (int dilation = 1; dilation <= 2; ++dilation) { for (int output_padding = 0; - output_padding < std::max(stride, dilation); ++output_padding) { + output_padding < std::max(stride, dilation); + ++output_padding) { for (bool with_bias : {true, false}) { for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. + {1, 2, 4}) { // covers normal, grouped, depthwise conv. auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::conv_transpose2d( - inputs[0], inputs[1], inputs[2], + inputs[0], + inputs[1], + inputs[2], /*stride=*/{stride, stride + 1}, /*padding=*/{padding, padding + 1}, /*output_padding=*/output_padding, @@ -9658,23 +10298,31 @@ TEST_F(LazyOpsTest, TestTransposedConv2DBackward) { }; ForEachDevice([&](const torch::Device& device) { torch::Tensor input = torch::rand( - {4, out_channels, 7, 7}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - torch::rand({out_channels, in_channels / groups, - kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - with_bias ? torch::rand({in_channels}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)) - : torch::Tensor(); - TestBackward({input, weight, bias}, device, testfn, - /*rtol=*/1e-5, /*atol=*/1e-5); + {4, out_channels, 7, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = torch::rand( + {out_channels, + in_channels / groups, + kernel_size, + kernel_size}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = with_bias + ? torch::rand( + {in_channels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)) + : torch::Tensor(); + TestBackward( + {input, weight, bias}, + device, + testfn, + /*rtol=*/1e-5, + /*atol=*/1e-5); }); } }; @@ -9693,33 +10341,44 @@ TEST_F(LazyOpsTest, TestConv3DBackward) { for (bool with_bias : {true, false}) { for (int dilation = 1; dilation <= 2; ++dilation) { for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. + {1, 2, 4}) { // covers normal, grouped, depthwise conv. auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::conv3d(inputs[0], inputs[1], inputs[2], - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, - /*dilation=*/{dilation, dilation, dilation}, - groups); + return torch::conv3d( + inputs[0], + inputs[1], + inputs[2], + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{dilation, dilation, dilation}, + groups); }; ForEachDevice([&](const torch::Device& device) { - torch::Tensor bias = - with_bias ? torch::rand({out_channels}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice())) - : torch::Tensor(); - TestBackward({torch::rand({4, in_channels, 7, 7, 7}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({out_channels, in_channels / groups, - kernel_size, kernel_size, kernel_size}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)), - bias}, - device, testfn); + torch::Tensor bias = with_bias + ? torch::rand( + {out_channels}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice())) + : torch::Tensor(); + TestBackward( + {torch::rand( + {4, in_channels, 7, 7, 7}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {out_channels, + in_channels / groups, + kernel_size, + kernel_size, + kernel_size}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)), + bias}, + device, + testfn); }); } }; @@ -9736,14 +10395,17 @@ TEST_F(LazyOpsTest, TestTransposedConv3DBackward) { for (int padding = 0; padding <= 1; ++padding) { for (int dilation = 1; dilation <= 2; ++dilation) { for (int output_padding = 0; - output_padding < std::max(stride, dilation); ++output_padding) { + output_padding < std::max(stride, dilation); + ++output_padding) { for (bool with_bias : {true, false}) { for (int groups : - {1, 2, 4}) { // covers normal, grouped, depthwise conv. + {1, 2, 4}) { // covers normal, grouped, depthwise conv. auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::conv_transpose3d( - inputs[0], inputs[1], inputs[2], + inputs[0], + inputs[1], + inputs[2], /*stride=*/{stride, stride + 1, stride}, /*padding=*/{padding, padding + 1, stride}, /*output_padding=*/output_padding, @@ -9751,23 +10413,27 @@ TEST_F(LazyOpsTest, TestTransposedConv3DBackward) { /*dilation=*/{dilation, dilation + 1, dilation}); }; ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = - torch::rand({4, out_channels, 7, 7, 7}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - torch::rand({out_channels, in_channels / groups, - kernel_size, kernel_size, kernel_size}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - with_bias ? torch::rand({in_channels}, - torch::TensorOptions(torch::kDouble) - .device(DefaultDevice()) - .requires_grad(true)) - : torch::Tensor(); + torch::Tensor input = torch::rand( + {4, out_channels, 7, 7, 7}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = torch::rand( + {out_channels, + in_channels / groups, + kernel_size, + kernel_size, + kernel_size}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = with_bias + ? torch::rand( + {in_channels}, + torch::TensorOptions(torch::kDouble) + .device(DefaultDevice()) + .requires_grad(true)) + : torch::Tensor(); TestBackward({input, weight, bias}, device, testfn); }); } @@ -9787,18 +10453,23 @@ TEST_F(LazyOpsTest, TestMaxPool2DBackward) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::max_pool2d( - inputs[0], /*kernel_size=*/{kernel_size, kernel_size}, + inputs[0], + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, + /*padding=*/{padding, padding}, + /*dilation=*/{1, 1}, /*ceil_mode=*/ceil_mode); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 2, 8, 8}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 2, 8, 8}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9817,16 +10488,20 @@ TEST_F(LazyOpsTest, TestMaxPool3DBackward) { inputs[0], /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*dilation=*/{1, 1, 1}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{1, 1, 1}, /*ceil_mode=*/ceil_mode); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({1, 2, 4, 4, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {1, 2, 4, 4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9842,18 +10517,23 @@ TEST_F(LazyOpsTest, TestMaxPool2DNoBatchBackward) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::max_pool2d( - inputs[0], /*kernel_size=*/{kernel_size, kernel_size}, + inputs[0], + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, + /*padding=*/{padding, padding}, + /*dilation=*/{1, 1}, /*ceil_mode=*/ceil_mode); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({2, 8, 8}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {2, 8, 8}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9872,16 +10552,20 @@ TEST_F(LazyOpsTest, TestMaxPool3DNoBatchBackward) { inputs[0], /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}, /*dilation=*/{1, 1, 1}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{1, 1, 1}, /*ceil_mode=*/ceil_mode); }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({2, 4, 4, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {2, 4, 4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -9890,9 +10574,9 @@ TEST_F(LazyOpsTest, TestMaxPool3DNoBatchBackward) { TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) { int kernel_size = 2; - torch::Tensor input = - torch::rand({2, 2, 8, 8}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {2, 2, 8, 8}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { // Test ceil_mode=true through the CPU interop. @@ -9901,9 +10585,11 @@ TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) { torch::Tensor output; torch::Tensor indices; std::tie(output, indices) = torch::max_pool2d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size}, /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, /*ceil_mode=*/ceil_mode); std::vector output_size({input.size(2), input.size(3)}); @@ -9913,8 +10599,8 @@ TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) { }; ForEachDevice([&](const torch::Device& device) { - TestBackward({output.requires_grad_(true), indices}, device, - testfn); + TestBackward( + {output.requires_grad_(true), indices}, device, testfn); }); } } @@ -9924,9 +10610,9 @@ TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) { TEST_F(LazyOpsTest, TestMaxUnpool3DBackward) { int kernel_size = 2; - torch::Tensor input = - torch::rand({1, 1, 4, 4, 4}, - torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor input = torch::rand( + {1, 1, 4, 4, 4}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { // Test ceil_mode=true through the CPU interop. @@ -9935,7 +10621,8 @@ TEST_F(LazyOpsTest, TestMaxUnpool3DBackward) { torch::Tensor output; torch::Tensor indices; std::tie(output, indices) = torch::max_pool3d_with_indices( - input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, + input, + /*kernel_size=*/{kernel_size, kernel_size, kernel_size}, /*stride=*/{stride, stride, stride}, /*padding=*/{padding, padding, padding}, /*dilation=*/{dilation, dilation, dilation}, @@ -9945,14 +10632,17 @@ TEST_F(LazyOpsTest, TestMaxUnpool3DBackward) { {input.size(2), input.size(3), input.size(4)}); auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::max_unpool3d(inputs[0], inputs[1], output_size, - /*stride=*/{stride, stride, stride}, - /*padding=*/{padding, padding, padding}); + return torch::max_unpool3d( + inputs[0], + inputs[1], + output_size, + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({output.requires_grad_(true), indices}, device, - testfn); + TestBackward( + {output.requires_grad_(true), indices}, device, testfn); }); } } @@ -9965,10 +10655,16 @@ TEST_F(LazyOpsTest, TestTanhBackward) { return torch::tanh(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 2}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-5); + TestBackward( + {torch::rand( + {2, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-5); }); } @@ -9977,10 +10673,14 @@ TEST_F(LazyOpsTest, TestSigmoidBackward) { return torch::sigmoid(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 2}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -9989,10 +10689,16 @@ TEST_F(LazyOpsTest, TestLogSigmoidBackward) { return torch::log_sigmoid(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 2}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-5); + TestBackward( + {torch::rand( + {2, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-5); }); } @@ -10005,10 +10711,15 @@ TEST_F(LazyOpsTest, TestLogSoftmaxBackward) { ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({5, 3, 4, 2}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + {torch::rand( + {5, 3, 4, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -10022,10 +10733,15 @@ TEST_F(LazyOpsTest, TestSoftmaxBackward) { ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({5, 3, 4, 2}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4); + {torch::rand( + {5, 3, 4, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -10035,10 +10751,15 @@ TEST_F(LazyOpsTest, TestSoftplusBackward) { return torch::softplus(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn, /*rtol=*/1e-4); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn, + /*rtol=*/1e-4); }); } @@ -10047,10 +10768,14 @@ TEST_F(LazyOpsTest, TestReluBackward) { return torch::relu(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10059,10 +10784,14 @@ TEST_F(LazyOpsTest, TestRreluBackward) { return torch::rrelu(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10071,10 +10800,14 @@ TEST_F(LazyOpsTest, TestHardshrinkBackward) { return torch::hardshrink(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::randn({100}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::randn( + {100}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10083,10 +10816,14 @@ TEST_F(LazyOpsTest, TestSoftshrinkBackward) { return torch::softshrink(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::randn({100}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::randn( + {100}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10095,10 +10832,14 @@ TEST_F(LazyOpsTest, TestHardtanhBackward) { return torch::hardtanh(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::randn({100}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::randn( + {100}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10110,10 +10851,14 @@ TEST_F(LazyOpsTest, TestEluBackward) { return torch::elu(inputs[0], alpha, scale, input_scale); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10122,10 +10867,14 @@ TEST_F(LazyOpsTest, TestGeluBackward) { return torch::gelu(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 3}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 3}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); ExpectCounterChanged("lazy::gelu_backward", GetIgnoredCounters()); } @@ -10136,10 +10885,14 @@ TEST_F(LazyOpsTest, TestLeakyReluBackward) { return torch::leaky_relu(inputs[0], negative_slope); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 1, 4, 6}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10148,10 +10901,14 @@ TEST_F(LazyOpsTest, TestTransposeBackward) { return torch::t(inputs[0]); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({2, 3}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {2, 3}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10166,18 +10923,24 @@ TEST_F(LazyOpsTest, TestAddMatMulBackward) { return torch::addmm(inputs[0], inputs[1], inputs[2], /*beta=*/beta); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({torch::rand({labels}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({in_channels, out_channels}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)), - torch::rand({out_channels, labels}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + TestBackward( + {torch::rand( + {labels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {in_channels, out_channels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)), + torch::rand( + {out_channels, labels}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } } @@ -10197,18 +10960,24 @@ TEST_F(LazyOpsTest, TestBinaryCrossEntropyBackward) { weight = torch::rand({batch, classes}, torch::TensorOptions(dtype)); } for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::binary_cross_entropy( - /*self=*/inputs[0], /*target=*/inputs[1], + /*self=*/inputs[0], + /*target=*/inputs[1], /*weight=*/inputs[2], /*reduction=*/reduction); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-4, - /*atol=*/1e-7); + TestBackward( + {input, target, weight}, + device, + testfn, + /*rtol=*/1e-4, + /*atol=*/1e-7); }); } } @@ -10225,12 +10994,15 @@ TEST_F(LazyOpsTest, TestNllLossBackward) { for (auto dtype : {torch::kFloat}) { for (int ignore_index : {-1, 0, 1, 5}) { for (bool def_weight : {false, true}) { - torch::Tensor input = - torch::rand({batch, classes}, torch::TensorOptions(dtype) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {batch, classes}, + torch::TensorOptions(dtype) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor target = torch::randint( - std::min(ignore_index, 0), classes, {batch}, + std::min(ignore_index, 0), + classes, + {batch}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor weight; if (def_weight) { @@ -10238,18 +11010,25 @@ TEST_F(LazyOpsTest, TestNllLossBackward) { {classes}, torch::TensorOptions(dtype).device(DefaultDevice())); } for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::nll_loss( - /*self=*/inputs[0], /*target=*/inputs[1], + /*self=*/inputs[0], + /*target=*/inputs[1], /*weight=*/inputs[2], - /*reduction=*/reduction, /*ignore_index=*/ignore_index); + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); + TestBackward( + {input, target, weight}, + device, + testfn, + /*rtol=*/1e-5, + /*atol=*/1e-8); }); } } @@ -10266,12 +11045,15 @@ TEST_F(LazyOpsTest, TestNllLoss2dBackward) { for (auto dtype : {torch::kFloat}) { for (int ignore_index : {-1, 0, 1, 5}) { for (bool def_weight : {false, true}) { - torch::Tensor input = torch::rand({batch, classes, height, width}, - torch::TensorOptions(dtype) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {batch, classes, height, width}, + torch::TensorOptions(dtype) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor target = torch::randint( - std::min(ignore_index, 0), classes, {batch, height, width}, + std::min(ignore_index, 0), + classes, + {batch, height, width}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); torch::Tensor weight; if (def_weight) { @@ -10279,18 +11061,25 @@ TEST_F(LazyOpsTest, TestNllLoss2dBackward) { {classes}, torch::TensorOptions(dtype).device(DefaultDevice())); } for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::nll_loss2d( - /*self=*/inputs[0], /*target=*/inputs[1], + /*self=*/inputs[0], + /*target=*/inputs[1], /*weight=*/inputs[2], - /*reduction=*/reduction, /*ignore_index=*/ignore_index); + /*reduction=*/reduction, + /*ignore_index=*/ignore_index); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); + TestBackward( + {input, target, weight}, + device, + testfn, + /*rtol=*/1e-5, + /*atol=*/1e-8); }); } } @@ -10299,23 +11088,33 @@ TEST_F(LazyOpsTest, TestNllLoss2dBackward) { } TEST_F(LazyOpsTest, TestSmoothL1LossBackward) { - torch::Tensor input = torch::randn({2, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::randn( + {2, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor target = torch::randn( {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { for (double beta : {0.25, 1.}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::smooth_l1_loss(/*input=*/inputs[0], /*target=*/inputs[1], - /*reduction=*/reduction, /*beta=*/beta); + return torch::smooth_l1_loss( + /*input=*/inputs[0], + /*target=*/inputs[1], + /*reduction=*/reduction, + /*beta=*/beta); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); + TestBackward( + {input, target}, + device, + testfn, + /*rtol=*/1e-5, + /*atol=*/1e-8); }); } } @@ -10327,10 +11126,13 @@ TEST_F(LazyOpsTest, TestViewBackward) { }; ForEachDevice([&](const torch::Device& device) { TestBackward( - {torch::rand({32, 20, 4, 4}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true))}, - device, testfn); + {torch::rand( + {32, 20, 4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true))}, + device, + testfn); }); } @@ -10339,40 +11141,51 @@ TEST_F(LazyOpsTest, TestBatchNorm2DBackward) { double eps = 0.5; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::batch_norm( - /*input=*/inputs[0], /*weight=*/inputs[1], /*bias=*/inputs[2], - /*running_mean=*/inputs[3], /*running_var=*/inputs[4], - /*training=*/true, /*momentum=*/momentum, /*eps=*/eps, + /*input=*/inputs[0], + /*weight=*/inputs[1], + /*bias=*/inputs[2], + /*running_mean=*/inputs[3], + /*running_var=*/inputs[4], + /*training=*/true, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); }; int num_features = 3; torch::Tensor undef; for (bool undef_weight_bias : {false, true}) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = torch::rand({2, num_features, 4, 4}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - undef_weight_bias - ? undef - : torch::rand({num_features}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - undef_weight_bias - ? undef - : torch::rand({num_features}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {2, num_features, 4, 4}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = undef_weight_bias + ? undef + : torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = undef_weight_bias + ? undef + : torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor running_mean = torch::zeros( {num_features}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor running_var = torch::ones( {num_features}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - TestBackward({input, weight, bias, running_mean, running_var}, device, - testfn, - /*rtol=*/1e-3, /*atol=*/1e-4); + TestBackward( + {input, weight, bias, running_mean, running_var}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-4); }); } } @@ -10382,40 +11195,51 @@ TEST_F(LazyOpsTest, TestBatchNorm3DBackward) { double eps = 0.5; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::batch_norm( - /*input=*/inputs[0], /*weight=*/inputs[1], /*bias=*/inputs[2], - /*running_mean=*/inputs[3], /*running_var=*/inputs[4], - /*training=*/true, /*momentum=*/momentum, /*eps=*/eps, + /*input=*/inputs[0], + /*weight=*/inputs[1], + /*bias=*/inputs[2], + /*running_mean=*/inputs[3], + /*running_var=*/inputs[4], + /*training=*/true, + /*momentum=*/momentum, + /*eps=*/eps, /*cudnn_enabled=*/false); }; int num_features = 3; torch::Tensor undef; for (bool undef_weight_bias : {false, true}) { ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = torch::rand({2, num_features, 4, 4, 2}, - torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - undef_weight_bias - ? undef - : torch::rand({num_features}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor bias = - undef_weight_bias - ? undef - : torch::rand({num_features}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {2, num_features, 4, 4, 2}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = undef_weight_bias + ? undef + : torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor bias = undef_weight_bias + ? undef + : torch::rand( + {num_features}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor running_mean = torch::zeros( {num_features}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); torch::Tensor running_var = torch::ones( {num_features}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - TestBackward({input, weight, bias, running_mean, running_var}, device, - testfn, - /*rtol=*/1e-3, /*atol=*/1e-3); + TestBackward( + {input, weight, bias, running_mean, running_var}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-3); }); } } @@ -10425,38 +11249,47 @@ TEST_F(LazyOpsTest, TestBCEWithLogitsBackward) { int classes = 5; torch::Tensor undef; for (torch::Reduction::Reduction reduction : - {torch::Reduction::None, torch::Reduction::Mean, + {torch::Reduction::None, + torch::Reduction::Mean, torch::Reduction::Sum}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::binary_cross_entropy_with_logits( - /*input=*/inputs[0], /*target=*/inputs[1], /*weight=*/inputs[2], + /*input=*/inputs[0], + /*target=*/inputs[1], + /*weight=*/inputs[2], /*pos_weight=*/inputs[3], /*reduction=*/reduction); }; for (bool undef_weight : {false, true}) { for (bool undef_pos_weight : {false, true}) { - torch::Tensor input = - torch::rand({batch, classes}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor target = - torch::rand({batch, classes}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor weight = - undef_weight - ? undef - : torch::rand({classes}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice())); - torch::Tensor pos_weight = - undef_pos_weight - ? undef - : torch::rand({classes}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice())); + torch::Tensor input = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor target = torch::rand( + {batch, classes}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor weight = undef_weight + ? undef + : torch::rand( + {classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + torch::Tensor pos_weight = undef_pos_weight + ? undef + : torch::rand( + {classes}, + torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target, weight, pos_weight}, device, testfn, - /*rtol=*/1e-3, /*atol=*/1e-5); + TestBackward( + {input, target, weight, pos_weight}, + device, + testfn, + /*rtol=*/1e-3, + /*atol=*/1e-5); }); } } @@ -10464,22 +11297,31 @@ TEST_F(LazyOpsTest, TestBCEWithLogitsBackward) { } TEST_F(LazyOpsTest, TestKlDivBackward) { - torch::Tensor input = torch::rand({4, 3}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); - torch::Tensor target = torch::rand({4, 3}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor input = torch::rand( + {4, 3}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); + torch::Tensor target = torch::rand( + {4, 3}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); for (torch::Reduction::Reduction reduction : - {torch::Reduction::Mean, torch::Reduction::Sum, + {torch::Reduction::Mean, + torch::Reduction::Sum, torch::Reduction::None}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::kl_div(/*self=*/inputs[0], /*target=*/inputs[1], reduction); }; ForEachDevice([&](const torch::Device& device) { - TestBackward({input, target}, device, testfn, /*rtol=*/1e-4, - /*atol=*/1e-5); + TestBackward( + {input, target}, + device, + testfn, + /*rtol=*/1e-4, + /*atol=*/1e-5); }); } } @@ -10490,21 +11332,29 @@ TEST_F(LazyOpsTest, TestEmbeddingBackward) { for (bool scale_grad_by_freq : {false, true}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::embedding(inputs[0], inputs[1], - /*padding_idx=*/padding_idx, - /*scale_grad_by_freq=*/scale_grad_by_freq, - /*sparse=*/false); + return torch::embedding( + inputs[0], + inputs[1], + /*padding_idx=*/padding_idx, + /*scale_grad_by_freq=*/scale_grad_by_freq, + /*sparse=*/false); }; ForEachDevice([&](const torch::Device& device) { - torch::Tensor weight = - torch::rand({num_weights, 7}, torch::TensorOptions(torch::kFloat) - .device(DefaultDevice()) - .requires_grad(true)); + torch::Tensor weight = torch::rand( + {num_weights, 7}, + torch::TensorOptions(torch::kFloat) + .device(DefaultDevice()) + .requires_grad(true)); torch::Tensor indices = torch::randint( - num_weights, {3, 9, 4}, + num_weights, + {3, 9, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice())); - TestBackward({weight, indices}, device, testfn, /*rtol=*/1e-5, - /*atol=*/1e-8); + TestBackward( + {weight, indices}, + device, + testfn, + /*rtol=*/1e-5, + /*atol=*/1e-8); }); } } @@ -10538,14 +11388,14 @@ TEST_F(LazyOpsTest, TestAmpForeachNonFiniteCheckAndUnscale) { torch::Tensor lazy_grads0 = CopyToDevice(grads0, device); torch::Tensor lazy_inv_scale = CopyToDevice(inv_scale, device); torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device); - torch::_amp_foreach_non_finite_check_and_unscale_(lazy_grads0, lazy_found_inf, - lazy_inv_scale); + torch::_amp_foreach_non_finite_check_and_unscale_( + lazy_grads0, lazy_found_inf, lazy_inv_scale); AllClose(grads_output0, lazy_grads0, /*rtol=*/1e-2, /*atol=*/1e-4); AllEqual(found_inf_output0, lazy_found_inf); torch::Tensor lazy_grads1 = CopyToDevice(grads1, device); - torch::_amp_foreach_non_finite_check_and_unscale_(lazy_grads1, lazy_found_inf, - lazy_inv_scale); + torch::_amp_foreach_non_finite_check_and_unscale_( + lazy_grads1, lazy_found_inf, lazy_inv_scale); AllEqual(found_inf_output1, lazy_found_inf); }); } @@ -10589,38 +11439,65 @@ TEST_F(LazyOpsTest, TestAmpUpdateScale) { torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device); torch::Tensor lazy_not_found_inf = CopyToDevice(not_found_inf, device); - torch::_amp_update_scale_(lazy_current_scale, lazy_growth_tracker, - lazy_not_found_inf, scale_growth_factor, - scale_backoff_factor, growth_interval); - AllClose(current_scale_result0, lazy_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); + torch::_amp_update_scale_( + lazy_current_scale, + lazy_growth_tracker, + lazy_not_found_inf, + scale_growth_factor, + scale_backoff_factor, + growth_interval); + AllClose( + current_scale_result0, + lazy_current_scale, + /*rtol=*/1e-2, + /*atol=*/1e-4); AllEqual(growth_tracker_result0, lazy_growth_tracker); - torch::_amp_update_scale_(lazy_current_scale, lazy_growth_tracker, - lazy_not_found_inf, scale_growth_factor, - scale_backoff_factor, growth_interval); - AllClose(current_scale_result1, lazy_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); + torch::_amp_update_scale_( + lazy_current_scale, + lazy_growth_tracker, + lazy_not_found_inf, + scale_growth_factor, + scale_backoff_factor, + growth_interval); + AllClose( + current_scale_result1, + lazy_current_scale, + /*rtol=*/1e-2, + /*atol=*/1e-4); AllEqual(growth_tracker_result1, lazy_growth_tracker); // torch::_amp_update_scale_ returns the reference of current_scale lazy_current_scale = torch::_amp_update_scale_( - lazy_current_scale, lazy_growth_tracker, lazy_not_found_inf, - scale_growth_factor, scale_backoff_factor, growth_interval); - AllClose(current_scale_result2, lazy_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); + lazy_current_scale, + lazy_growth_tracker, + lazy_not_found_inf, + scale_growth_factor, + scale_backoff_factor, + growth_interval); + AllClose( + current_scale_result2, + lazy_current_scale, + /*rtol=*/1e-2, + /*atol=*/1e-4); AllEqual(growth_tracker_result2, lazy_growth_tracker); lazy_current_scale = torch::_amp_update_scale_( - lazy_current_scale, lazy_growth_tracker, lazy_found_inf, - scale_growth_factor, scale_backoff_factor, growth_interval); - AllClose(current_scale_result3, lazy_current_scale, /*rtol=*/1e-2, - /*atol=*/1e-4); + lazy_current_scale, + lazy_growth_tracker, + lazy_found_inf, + scale_growth_factor, + scale_backoff_factor, + growth_interval); + AllClose( + current_scale_result3, + lazy_current_scale, + /*rtol=*/1e-2, + /*atol=*/1e-4); AllEqual(growth_tracker_result3, lazy_growth_tracker); }); ExpectCounterNotChanged("aten::.*", GetIgnoredCounters()); - ExpectCounterChanged("lazy::_amp_update_scale_", - GetIgnoredCounters()); + ExpectCounterChanged("lazy::_amp_update_scale_", GetIgnoredCounters()); } TEST_F(LazyOpsTest, TestEarlySyncLiveTensors) { @@ -10633,14 +11510,11 @@ TEST_F(LazyOpsTest, TestEarlySyncLiveTensors) { ASSERT_EQ(scalar1.to(), scalar2.to()); }); if (DebugUtil::ExperimentEnabled("early_sync")) { - ExpectCounterChanged("EarlySyncLiveTensorsCount", - GetIgnoredCounters()); + ExpectCounterChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters()); } else { - ExpectCounterNotChanged("EarlySyncLiveTensorsCount", - GetIgnoredCounters()); + ExpectCounterNotChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters()); } - ExpectCounterChanged("aten::_local_scalar_dense", - GetIgnoredCounters()); + ExpectCounterChanged("aten::_local_scalar_dense", GetIgnoredCounters()); } TEST_F(LazyOpsTest, TestLerp) { @@ -10761,8 +11635,10 @@ TEST_F(LazyOpsTest, TestLerpScalarOut) { } TEST_F(LazyOpsTest, IsAliasOf) { - auto a = torch::empty(4, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); - auto b = torch::empty(4, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + auto a = torch::empty( + 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); + auto b = torch::empty( + 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice())); ForEachDevice([&](const torch::Device& device) { auto lazy_a = CopyToDevice(a, device); @@ -10782,5 +11658,5 @@ TEST_F(LazyOpsTest, IsAliasOf) { #endif // FBCODE_CAFFE2 -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_lazy_ops_util.cpp b/test/cpp/lazy/test_lazy_ops_util.cpp index 91c9b653e04131..1045c3ffc9c698 100644 --- a/test/cpp/lazy/test_lazy_ops_util.cpp +++ b/test/cpp/lazy/test_lazy_ops_util.cpp @@ -1,20 +1,20 @@ #include #include -#include #include +#include #include #include #include - namespace torch { namespace lazy { namespace { bool IsLtcTensor(const at::Tensor& tensor) { - return dynamic_cast(tensor.unsafeGetTensorImpl()); + return dynamic_cast( + tensor.unsafeGetTensorImpl()); } std::unordered_set* CreateIgnoredCounters() { @@ -26,7 +26,7 @@ std::unordered_set* CreateIgnoredCounters() { return icounters; } -} // namespace +} // namespace const std::unordered_set* GetIgnoredCounters() { static const std::unordered_set* icounters = @@ -39,8 +39,9 @@ at::Tensor ToCpuTensor(const at::Tensor& tensor) { return tensor.to(torch::kCPU); } -torch::Tensor CopyToDevice(const torch::Tensor& tensor, - const torch::Device& device) { +torch::Tensor CopyToDevice( + const torch::Tensor& tensor, + const torch::Device& device) { return tensor.clone().to(device, /*non_blocking=*/false, /*copy=*/true); } @@ -87,16 +88,19 @@ bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) { } void ForEachDevice(const std::function& devfn) { - // Currently TorchScript backend only supports one type of hardware per process, - // which is set by env. And the ordinal is always 0 given distributed training/ - // multi-device is not supported yet. + // Currently TorchScript backend only supports one type of hardware per + // process, which is set by env. And the ordinal is always 0 given distributed + // training/ multi-device is not supported yet. auto device = torch::lazy::BackendDevice(); torch::Device torch_device = torch::lazy::backendDeviceToAtenDevice(device); devfn(torch_device); } -bool CloseValues(at::Tensor tensor1, at::Tensor tensor2, double rtol, - double atol) { +bool CloseValues( + at::Tensor tensor1, + at::Tensor tensor2, + double rtol, + double atol) { tensor1 = ToCpuTensor(tensor1); tensor2 = ToCpuTensor(tensor2); if (torch::isnan(tensor1).any().item()) { @@ -126,10 +130,13 @@ std::string GetTensorDotGraph(at::Tensor tensor) { } void TestBackward( - const std::vector& inputs, const torch::Device& device, + const std::vector& inputs, + const torch::Device& device, const std::function&)>& testfn, - double rtol, double atol, int derivative_level) { + double rtol, + double atol, + int derivative_level) { std::vector input_vars; std::vector xinput_vars; std::vector inputs_w_grad; @@ -173,14 +180,20 @@ void TestBackward( } // Calculating higher order derivative requires create_graph=true bool create_graph = d != derivative_level; - outs = torch::autograd::grad({sum}, inputs_w_grad, /*grad_outputs=*/{}, - /*retain_graph=*/c10::nullopt, - /*create_graph=*/create_graph, - /*allow_unused=*/true); - xouts = torch::autograd::grad({xsum}, xinputs_w_grad, /*grad_outputs=*/{}, - /*retain_graph=*/c10::nullopt, - /*create_graph=*/create_graph, - /*allow_unused=*/true); + outs = torch::autograd::grad( + {sum}, + inputs_w_grad, + /*grad_outputs=*/{}, + /*retain_graph=*/c10::nullopt, + /*create_graph=*/create_graph, + /*allow_unused=*/true); + xouts = torch::autograd::grad( + {xsum}, + xinputs_w_grad, + /*grad_outputs=*/{}, + /*retain_graph=*/c10::nullopt, + /*create_graph=*/create_graph, + /*allow_unused=*/true); for (size_t i = 0; i < outs.size(); ++i) { ASSERT_EQ(outs[i].defined(), xouts[i].defined()); if (outs[i].defined()) { @@ -190,5 +203,5 @@ void TestBackward( } } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_lazy_ops_util.h b/test/cpp/lazy/test_lazy_ops_util.h index 6dc26b48be9518..84860794dd485c 100644 --- a/test/cpp/lazy/test_lazy_ops_util.h +++ b/test/cpp/lazy/test_lazy_ops_util.h @@ -25,23 +25,33 @@ const std::unordered_set* GetIgnoredCounters(); at::Tensor ToCpuTensor(const at::Tensor& tensor); // Helper function to copy a tensor to device. -torch::Tensor CopyToDevice(const torch::Tensor& tensor, - const torch::Device& device); +torch::Tensor CopyToDevice( + const torch::Tensor& tensor, + const torch::Device& device); bool EqualValues(at::Tensor tensor1, at::Tensor tensor2); bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2); -bool CloseValues(at::Tensor tensor1, at::Tensor tensor2, double rtol = 1e-5, - double atol = 1e-8); - -static inline void AllClose(at::Tensor tensor, at::Tensor xla_tensor, - double rtol = 1e-5, double atol = 1e-8) { +bool CloseValues( + at::Tensor tensor1, + at::Tensor tensor2, + double rtol = 1e-5, + double atol = 1e-8); + +static inline void AllClose( + at::Tensor tensor, + at::Tensor xla_tensor, + double rtol = 1e-5, + double atol = 1e-8) { EXPECT_TRUE(CloseValues(tensor, xla_tensor, rtol, atol)); } -static inline void AllClose(at::Tensor tensor, torch::lazy::LazyTensor& xla_tensor, - double rtol = 1e-5, double atol = 1e-8) { +static inline void AllClose( + at::Tensor tensor, + torch::lazy::LazyTensor& xla_tensor, + double rtol = 1e-5, + double atol = 1e-8) { EXPECT_TRUE( CloseValues(tensor, xla_tensor.ToTensor(/*detached=*/false), rtol, atol)); } @@ -59,10 +69,13 @@ std::string GetTensorDotGraph(at::Tensor tensor); std::string GetTensorHloGraph(at::Tensor tensor); void TestBackward( - const std::vector& inputs, const torch::Device& device, + const std::vector& inputs, + const torch::Device& device, const std::function&)>& testfn, - double rtol = 1e-5, double atol = 1e-8, int derivative_level = 1); + double rtol = 1e-5, + double atol = 1e-8, + int derivative_level = 1); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_misc.cpp b/test/cpp/lazy/test_misc.cpp index b2f941c42dd6ba..7c16ae7991cc32 100644 --- a/test/cpp/lazy/test_misc.cpp +++ b/test/cpp/lazy/test_misc.cpp @@ -1,8 +1,8 @@ #include #include -#include #include +#include namespace torch { namespace lazy { @@ -27,7 +27,7 @@ TEST(HashTest, Scalar) { // simulate some garbage in the unused bits of the // the tagged union that is c10::Scalar, which is bigger // than the size of the int64_t we're currently using it with - *((uint8_t*)&b) = 1; + *((uint8_t*)&b) = 1; // actual 'value' of the Scalar as a 64 bit int shouldn't have changed EXPECT_EQ(a.toLong(), b.toLong()); // and hash should ignore this garbage @@ -70,7 +70,8 @@ TEST(HashTest, Sanity) { auto a = std::vector({0, 1, 1, 2, 3, 5, 8}); auto b = std::vector({1, 1, 2, 3, 5, 8, 12}); test_hash_repeatable_sensitive(a, b); - test_hash_repeatable_sensitive(c10::ArrayRef(a), c10::ArrayRef(b)); + test_hash_repeatable_sensitive( + c10::ArrayRef(a), c10::ArrayRef(b)); // vector is a special case bc it is implemented as vector auto bool_a = std::vector({true, false, false, true}); diff --git a/test/cpp/lazy/test_shape.cpp b/test/cpp/lazy/test_shape.cpp index a2f1125f5eb949..cc12dfe4118a63 100644 --- a/test/cpp/lazy/test_shape.cpp +++ b/test/cpp/lazy/test_shape.cpp @@ -80,5 +80,5 @@ TEST(ShapeTest, Ostream) { EXPECT_EQ(shape.to_string(), ss.str()); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_symbolic_shape.cpp b/test/cpp/lazy/test_symbolic_shape.cpp index b2224aec0d1c1d..ef344c7f6d92ac 100644 --- a/test/cpp/lazy/test_symbolic_shape.cpp +++ b/test/cpp/lazy/test_symbolic_shape.cpp @@ -52,9 +52,7 @@ class LazyShapeTest : public ::testing::Test { class DynamicInputShapeNode : public Node { public: explicit DynamicInputShapeNode(Shape& shape) - : Node(OpKind(), /* num_outputs */ 1), - hash_(0), - shape_(shape) {} + : Node(OpKind(), /* num_outputs */ 1), hash_(0), shape_(shape) {} ~DynamicInputShapeNode() override = default; const std::vector& operands() const override { @@ -71,8 +69,12 @@ class DynamicInputShapeNode : public Node { return {shape_}; } - hash_t hash() const override { return hash_; } - hash_t shapeHash() const override { return hash_; } + hash_t hash() const override { + return hash_; + } + hash_t shapeHash() const override { + return hash_; + } private: hash_t hash_; diff --git a/test/cpp/lazy/test_tensor_impl.cpp b/test/cpp/lazy/test_tensor_impl.cpp index 8d968f620b6b24..415e03d0faa992 100644 --- a/test/cpp/lazy/test_tensor_impl.cpp +++ b/test/cpp/lazy/test_tensor_impl.cpp @@ -7,13 +7,17 @@ namespace torch { namespace lazy { #ifdef FBCODE_CAFFE2 -// Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g. sizes) in TensorImpl +// Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g. +// sizes) in TensorImpl TEST(LazyTensorImplTest, BasicThrow) { - EXPECT_THROW({ - auto input = torch::rand({0, 1, 3, 0}, torch::TensorOptions(torch::kFloat).device("lazy")); - }, ::c10::Error); + EXPECT_THROW( + { + auto input = torch::rand( + {0, 1, 3, 0}, torch::TensorOptions(torch::kFloat).device("lazy")); + }, + ::c10::Error); } #endif // FBCODE_CAFFE2 -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/test/cpp/lazy/test_trie_cache.cpp b/test/cpp/lazy/test_trie_cache.cpp index df7d578b94b470..0f57690575c596 100644 --- a/test/cpp/lazy/test_trie_cache.cpp +++ b/test/cpp/lazy/test_trie_cache.cpp @@ -33,8 +33,13 @@ class TrieCacheNode : public Node { operands_.push_back(std::move(v.node)); } - hash_t hash() const override { return hash_; } - hash_t shapeHash() const override { return hash_; } + hash_t hash() const override { + return hash_; + } + hash_t shapeHash() const override { + return hash_; + } + private: size_t id_; hash_t hash_; @@ -56,12 +61,12 @@ TEST(TrieCacheTest, TestSinglePath) { } /* -* 0 -* | -* 1 -* / \ -* 2 3 -*/ + * 0 + * | + * 1 + * / \ + * 2 3 + */ TEST(TrieCacheTest, TestTwoPaths) { FLAGS_torch_lazy_reuse_ir = true; TrieCache::Get()->Clear(); diff --git a/test/cpp/lazy/test_util.cpp b/test/cpp/lazy/test_util.cpp index 6f6541619f0d60..ed9a31f15712ad 100644 --- a/test/cpp/lazy/test_util.cpp +++ b/test/cpp/lazy/test_util.cpp @@ -12,9 +12,8 @@ TEST(UtilTest, ExceptionCleanup) { EXPECT_EQ(exception, nullptr); { - ExceptionCleanup cleanup([&](std::exception_ptr&& e) { - exception = std::move(e); - }); + ExceptionCleanup cleanup( + [&](std::exception_ptr&& e) { exception = std::move(e); }); cleanup.SetStatus(std::make_exception_ptr(std::runtime_error("Oops!"))); } @@ -22,15 +21,14 @@ TEST(UtilTest, ExceptionCleanup) { try { std::rethrow_exception(exception); - } catch(const std::exception& e) { + } catch (const std::exception& e) { EXPECT_STREQ(e.what(), "Oops!"); } exception = nullptr; { - ExceptionCleanup cleanup([&](std::exception_ptr&& e) { - exception = std::move(e); - }); + ExceptionCleanup cleanup( + [&](std::exception_ptr&& e) { exception = std::move(e); }); cleanup.SetStatus(std::make_exception_ptr(std::runtime_error(""))); cleanup.Release(); diff --git a/test/cpp/lite_interpreter_runtime/main.cpp b/test/cpp/lite_interpreter_runtime/main.cpp index ebd8ae8ee31558..2266ac1dbadee2 100644 --- a/test/cpp/lite_interpreter_runtime/main.cpp +++ b/test/cpp/lite_interpreter_runtime/main.cpp @@ -1,9 +1,9 @@ +#include #include +#include +#include #include #include -#include -#include -#include std::string add_negative_flag(const std::string& flag) { std::string filter = ::testing::GTEST_FLAG(filter); @@ -16,8 +16,8 @@ std::string add_negative_flag(const std::string& flag) { return filter; } int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); + ::testing::InitGoogleTest(&argc, argv); + ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index f25fb123bda826..95ba2b7b853edf 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -6,6 +5,7 @@ #include #include #include +#include #include @@ -21,9 +21,9 @@ bool checkMetaData( const std::string& metadata_val, std::ifstream& trace_file) { std::string line; - while (std::getline(trace_file, line) ) { + while (std::getline(trace_file, line)) { if (line.find(op_name) != std::string::npos) { - while (std::getline(trace_file, line) ) { + while (std::getline(trace_file, line)) { if (line.find(metadata_name) != std::string::npos) { return (line.find(metadata_val) != std::string::npos); } @@ -61,13 +61,30 @@ TEST(MobileProfiler, ModuleHierarchy) { ASSERT_TRUE(trace_file.is_open()); trace_file.seekg(0, std::ios_base::beg); const std::string metadata_name("Module Hierarchy"); - ASSERT_TRUE(checkMetaData("aten::sub", metadata_name, "top(C)::.A0(A)::forward.aten::sub", trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::sub", + metadata_name, + "top(C)::.A0(A)::forward.aten::sub", + trace_file)); trace_file.seekg(0, std::ios_base::beg); - ASSERT_TRUE(checkMetaData("aten::mul", metadata_name, "top(C)::.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method.aten::mul", trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::mul", + metadata_name, + "top(C)::.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method.aten::mul", + trace_file)); trace_file.seekg(0, std::ios_base::beg); - ASSERT_TRUE(checkMetaData("aten::add", metadata_name, "top(C)::.A0(A)::forward.SELF(A)::forward_impl_.aten::add", trace_file)); - ASSERT_TRUE(checkMetaData("aten::add", metadata_name, "top(C)::.SELF(C)::call_b.B0(B)::forward.aten::add", trace_file)); - ASSERT_TRUE(checkMetaData("aten::add", metadata_name, "top(C)::.aten::add", trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::add", + metadata_name, + "top(C)::.A0(A)::forward.SELF(A)::forward_impl_.aten::add", + trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::add", + metadata_name, + "top(C)::.SELF(C)::call_b.B0(B)::forward.aten::add", + trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::add", metadata_name, "top(C)::.aten::add", trace_file)); } TEST(MobileProfiler, Backend) { @@ -97,10 +114,12 @@ TEST(MobileProfiler, Backend) { ASSERT_TRUE(trace_file.is_open()); trace_file.seekg(0, std::ios_base::beg); std::string metadata_name("Module Hierarchy"); - ASSERT_TRUE(checkMetaData("aten::add", metadata_name, "top(m)::.aten::add", trace_file)); + ASSERT_TRUE(checkMetaData( + "aten::add", metadata_name, "top(m)::.aten::add", trace_file)); trace_file.seekg(0, std::ios_base::beg); metadata_name = "Backend"; - ASSERT_TRUE(checkMetaData("aten::add", metadata_name, "test_backend", trace_file)); + ASSERT_TRUE( + checkMetaData("aten::add", metadata_name, "test_backend", trace_file)); } } // namespace mobile diff --git a/test/cpp/profiler/containers.cpp b/test/cpp/profiler/containers.cpp index 60e6d0f238b185..c0e0bf14745c8e 100644 --- a/test/cpp/profiler/containers.cpp +++ b/test/cpp/profiler/containers.cpp @@ -10,67 +10,70 @@ #include TEST(ProfilerTest, AppendOnlyList) { - const int n = 4096; - torch::profiler::impl::AppendOnlyList list; - for (const auto i : c10::irange(n)) { - list.emplace_back(i); - ASSERT_EQ(list.size(), i + 1); - } + const int n = 4096; + torch::profiler::impl::AppendOnlyList list; + for (const auto i : c10::irange(n)) { + list.emplace_back(i); + ASSERT_EQ(list.size(), i + 1); + } - int expected = 0; - for (const auto i : list) { - ASSERT_EQ(i, expected++); - } - ASSERT_EQ(expected, n); + int expected = 0; + for (const auto i : list) { + ASSERT_EQ(i, expected++); + } + ASSERT_EQ(expected, n); - list.clear(); - ASSERT_EQ(list.size(), 0); + list.clear(); + ASSERT_EQ(list.size(), 0); } TEST(ProfilerTest, AppendOnlyList_ref) { - const int n = 512; - torch::profiler::impl::AppendOnlyList, 64> list; - std::vector*> refs; - for (const auto _ : c10::irange(n)) { - refs.push_back(list.emplace_back()); - } + const int n = 512; + torch::profiler::impl::AppendOnlyList, 64> list; + std::vector*> refs; + for (const auto _ : c10::irange(n)) { + refs.push_back(list.emplace_back()); + } - for (const auto i : c10::irange(n)) { - *refs.at(i) = {i, 0}; - } + for (const auto i : c10::irange(n)) { + *refs.at(i) = {i, 0}; + } - int expected = 0; - for (const auto& i : list) { - ASSERT_EQ(i.first, expected++); - } + int expected = 0; + for (const auto& i : list) { + ASSERT_EQ(i.first, expected++); + } } // Test that we can convert TSC measurements back to wall clock time. TEST(ProfilerTest, clock_converter) { - const int n = 10001; - torch::profiler::impl::ApproximateClockToUnixTimeConverter converter; - std::vector pairs; - for (const auto i : c10::irange(n)) { - pairs.push_back(torch::profiler::impl::ApproximateClockToUnixTimeConverter::measurePair()); - } - auto count_to_ns = converter.makeConverter(); - std::vector deltas; - for (const auto& i : pairs) { - deltas.push_back(i.t_ - count_to_ns(i.approx_t_)); - } - std::sort(deltas.begin(), deltas.end()); + const int n = 10001; + torch::profiler::impl::ApproximateClockToUnixTimeConverter converter; + std::vector + pairs; + for (const auto i : c10::irange(n)) { + pairs.push_back(torch::profiler::impl::ApproximateClockToUnixTimeConverter:: + measurePair()); + } + auto count_to_ns = converter.makeConverter(); + std::vector deltas; + for (const auto& i : pairs) { + deltas.push_back(i.t_ - count_to_ns(i.approx_t_)); + } + std::sort(deltas.begin(), deltas.end()); - // In general it's not a good idea to put clocks in unit tests as it leads - // to flakiness. We mitigate this by: - // 1) Testing the clock itself. While the time to complete a task may - // vary, two clocks measuring the same time should be much more - // consistent. - // 2) Only testing the interquartile range. Context switches between - // calls to the two timers do occur and can result in hundreds of - // nanoseconds of noise, but such switches are only a few percent - // of cases. - // 3) We're willing to accept a somewhat large bias which can emerge from - // differences in the cost of calling each clock. - EXPECT_LT(std::abs(deltas[n / 2]), 200); - EXPECT_LT(deltas[n * 3 / 4] - deltas[n / 4], 50); + // In general it's not a good idea to put clocks in unit tests as it leads + // to flakiness. We mitigate this by: + // 1) Testing the clock itself. While the time to complete a task may + // vary, two clocks measuring the same time should be much more + // consistent. + // 2) Only testing the interquartile range. Context switches between + // calls to the two timers do occur and can result in hundreds of + // nanoseconds of noise, but such switches are only a few percent + // of cases. + // 3) We're willing to accept a somewhat large bias which can emerge from + // differences in the cost of calling each clock. + EXPECT_LT(std::abs(deltas[n / 2]), 200); + EXPECT_LT(deltas[n * 3 / 4] - deltas[n / 4], 50); } diff --git a/test/cpp/profiler/record_function.cpp b/test/cpp/profiler/record_function.cpp index ba76c5af588889..2e652ed038b647 100644 --- a/test/cpp/profiler/record_function.cpp +++ b/test/cpp/profiler/record_function.cpp @@ -162,9 +162,7 @@ TEST(RecordFunctionTest, ThreadMigration) { at::RecordFunctionCallback( [](const at::RecordFunction&) -> std::unique_ptr { return nullptr; }, - [](const at::RecordFunction&, at::ObserverContext*) { - ++call_count; - }) + [](const at::RecordFunction&, at::ObserverContext*) { ++call_count; }) .scopes({at::RecordScope::FUNCTION})); EXPECT_EQ(call_count, 0); @@ -176,7 +174,7 @@ TEST(RecordFunctionTest, ThreadMigration) { cv.notify_all(); }); auto guard = std::unique_lock(lock); - cv.wait(guard, []{ return call_count > 0; }); + cv.wait(guard, [] { return call_count > 0; }); EXPECT_EQ(call_count, 1); diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h index 5762a2f0cfaea8..b4b28ceb57c7a8 100644 --- a/test/cpp/rpc/e2e_test_base.h +++ b/test/cpp/rpc/e2e_test_base.h @@ -88,9 +88,10 @@ class TestE2EBase : public ::testing::Test { // Builtin operators does not return py::object, and hence does not require // GIL for destructing the potentially deleted OwerRRef. - jitFuture->addCallback([ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) { - callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId); - }); + jitFuture->addCallback( + [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) { + callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId); + }); return ownerRRef; } diff --git a/test/cpp/rpc/test_tensorpipe_serialization.cpp b/test/cpp/rpc/test_tensorpipe_serialization.cpp index b561e71aec9954..63b44f1707b689 100644 --- a/test/cpp/rpc/test_tensorpipe_serialization.cpp +++ b/test/cpp/rpc/test_tensorpipe_serialization.cpp @@ -1,8 +1,8 @@ #include +#include #include #include -#include #include #include @@ -20,8 +20,9 @@ TEST(TensorpipeSerialize, Base) { torch::distributed::rpc::MessageType mtype = torch::distributed::rpc::MessageType::UNKNOWN; int64_t mId = 100; - auto sendingRpcMessage = c10::make_intrusive( - std::move(payload), std::move(tensors), mtype); + auto sendingRpcMessage = + c10::make_intrusive( + std::move(payload), std::move(tensors), mtype); sendingRpcMessage->setId(mId); tensorpipe::Message sendingTpMessage; torch::distributed::rpc::TensorpipeWriteBuffers sendingTpBuffers; @@ -90,8 +91,7 @@ TEST(TensorpipeSerialize, Base) { // - Unpickle c10::intrusive_ptr recvingRpcMessage = torch::distributed::rpc::tensorpipeDeserialize( - std::move(recvingTpDescriptor), - std::move(recvingTpBuffers)); + std::move(recvingTpDescriptor), std::move(recvingTpBuffers)); // Data is ready EXPECT_EQ(mtype, recvingRpcMessage->type()); @@ -113,8 +113,9 @@ TEST(TensorpipeSerialize, RecopySparseTensors) { std::vector payload = {'1', '2', '3'}; torch::distributed::rpc::MessageType mtype = torch::distributed::rpc::MessageType::UNKNOWN; - auto sendingRpcMessage = c10::make_intrusive( - std::move(payload), std::move(tensors), mtype); + auto sendingRpcMessage = + c10::make_intrusive( + std::move(payload), std::move(tensors), mtype); tensorpipe::Message sendingTpMessage; torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers; @@ -145,8 +146,9 @@ TEST(TensorpipeSerialize, NoDeleterTensors) { std::vector payload = {'1', '2', '3'}; torch::distributed::rpc::MessageType mtype = torch::distributed::rpc::MessageType::UNKNOWN; - auto sendingRpcMessage = c10::make_intrusive( - std::move(payload), std::move(tensors), mtype); + auto sendingRpcMessage = + c10::make_intrusive( + std::move(payload), std::move(tensors), mtype); tensorpipe::Message sendingTpMessage; torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers; diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp index a4673f023ffbd4..9a2c47a5f7a84e 100644 --- a/torch/csrc/CudaIPCTypes.cpp +++ b/torch/csrc/CudaIPCTypes.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include #include @@ -32,7 +32,9 @@ struct CudaIPCGlobalEntities { ref_counters_files_; std::shared_ptr next_available_ref_counters_file_; CudaIPCSentDataLimbo CudaIPCSentDataLimbo_; - CudaIPCGlobalEntities() { alive = true; } + CudaIPCGlobalEntities() { + alive = true; + } ~CudaIPCGlobalEntities() { CudaIPCSentDataLimbo_.collect(); safe_clean_current_file(); @@ -77,7 +79,8 @@ bool CudaIPCSentDataLimbo::collect() { } shared_blocks_ = std::move(kept_blocks); } - // Need to reset blocks out of the critical section here, otherwise it deadlocks. + // Need to reset blocks out of the critical section here, otherwise it + // deadlocks. for (auto& sd : reset_blocks) { sd.reset(); } @@ -107,7 +110,7 @@ uint64_t CudaIPCSentDataLimbo::size() { void CudaIPCSentDataDelete(void* ptr) { std::unique_ptr sent_data( static_cast(ptr)); - if(!CudaIPCGlobalEntities::alive) { + if (!CudaIPCGlobalEntities::alive) { return; } if (sent_data->counter_value() > 0) { @@ -117,7 +120,7 @@ void CudaIPCSentDataDelete(void* ptr) { } void ReturnRefCounter(const std::string& handle, uint64_t offset /* unused */) { - if(!CudaIPCGlobalEntities::alive) { + if (!CudaIPCGlobalEntities::alive) { return; } std::lock_guard lock( @@ -145,23 +148,25 @@ CudaIPCSentData::CudaIPCSentData( original_ptr_(), device_(device) { #if !defined(USE_ROCM) - // CUDA have the unofficial limit on the number of recorded blocking interprocess - // events, to prevent using of all events, we are switching to StreamSync - // before limit reached. + // CUDA have the unofficial limit on the number of recorded blocking + // interprocess events, to prevent using of all events, we are switching to + // StreamSync before limit reached. // // ```python // import torch // a = [ torch.cuda.Event( - // enable_timing=False, blocking=True, interprocess=True) for i in range(30000) ] + // enable_timing=False, blocking=True, interprocess=True) for i in + // range(30000) ] // [i.record() for i in a] // ``` // - if (cuda_ipc_global_entities.sync_events_used_.load() < CUDA_IPC_MAXIMUM_EVENTS_TO_USE) { + if (cuda_ipc_global_entities.sync_events_used_.load() < + CUDA_IPC_MAXIMUM_EVENTS_TO_USE) { // TODO: More efficient would be to create event inside of main thread (at // the moment of the queue.put). The reason this is more efficient is // because the main thread may have queued extra work on the stream, which // this event will consequently wait for (uselessly). - cuda_ipc_global_entities.sync_events_used_ ++; + cuda_ipc_global_entities.sync_events_used_++; C10_CUDA_CHECK(cudaEventCreateWithFlags( &event_, cudaEventDisableTiming | cudaEventInterprocess | @@ -191,10 +196,10 @@ CudaIPCSentData::~CudaIPCSentData() { if (event_sync_required_) { at::cuda::CUDAGuard device_guard(device_.index()); cudaEventDestroy(event_); - if(!CudaIPCGlobalEntities::alive) { + if (!CudaIPCGlobalEntities::alive) { return; } - cuda_ipc_global_entities.sync_events_used_ --; + cuda_ipc_global_entities.sync_events_used_--; } } catch (...) { /* No throw */ } @@ -212,7 +217,8 @@ at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) { if (!cuda_ipc_global_entities.next_available_ref_counters_file_) { std::string ref_counter_handle = at::NewProcessWideShmHandle(); - int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE; + int flags = + at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE; at::DataPtr sptr = at::RefcountedMapAllocator::makeDataPtr( ref_counter_handle.c_str(), flags, @@ -240,7 +246,7 @@ at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) { } bool CudaIPCCollect() { - if(!CudaIPCGlobalEntities::alive) { + if (!CudaIPCGlobalEntities::alive) { return true; } bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect(); diff --git a/torch/csrc/CudaIPCTypes.h b/torch/csrc/CudaIPCTypes.h index 941716e6fbfa4d..bedbe34ad633b3 100644 --- a/torch/csrc/CudaIPCTypes.h +++ b/torch/csrc/CudaIPCTypes.h @@ -1,6 +1,5 @@ #pragma once #ifdef USE_CUDA -#include #include #include #include @@ -8,6 +7,7 @@ #include #include #include +#include #include namespace torch { @@ -48,7 +48,9 @@ struct CudaIPCSentData final { } }; -TORCH_CUDA_CU_API at::DataPtr GetNewRefCountedSentData(void* data, at::Device device); +TORCH_CUDA_CU_API at::DataPtr GetNewRefCountedSentData( + void* data, + at::Device device); namespace { diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index ec2ffa50883ad7..415b15164aaec3 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ -4,7 +4,8 @@ // is an effort to do our best to provide some error message to users when a // worker dies due to error / critical signals. // -// See NOTE [ Signal handling in multiprocessing data loading ] for more details. +// See NOTE [ Signal handling in multiprocessing data loading ] for more +// details. // TODO: The following don't work on Windows. Specifically, sigaction, waitid // calls, and SIGCHLD handler. Currently, dummy implementations are provided @@ -18,12 +19,12 @@ #include #include +#include #include +#include #include #include -#include #include -#include using namespace torch; @@ -32,39 +33,53 @@ using namespace torch; // The handler will raise default handler so that the kill information will be // retrieved from main process. // Python handle is _set_worker_signal_handlers(). -#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \ -static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \ -{ \ - auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char));\ - (void)_w; \ - struct sigaction sa{}; \ - sa.sa_handler = SIG_DFL; \ - sa.sa_flags = 0; \ - if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, nullptr) != 0) { \ - _exit(EXIT_FAILURE); \ - } else { \ - raise(SIGNAL); \ - } \ -} +#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \ + static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) { \ + auto _w = \ + write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \ + (void)_w; \ + struct sigaction sa {}; \ + sa.sa_handler = SIG_DFL; \ + sa.sa_flags = 0; \ + if (sigemptyset(&sa.sa_mask) != 0 || \ + sigaction(SIGNAL, &sa, nullptr) != 0) { \ + _exit(EXIT_FAILURE); \ + } else { \ + raise(SIGNAL); \ + } \ + } // signal(2) is really not portable. So use sigaction. // http://man7.org/linux/man-pages/man2/signal.2.html -static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr) -{ - struct sigaction sa{}; +static inline void setSignalHandler( + int signal, + void (*handler)(int, siginfo_t*, void*), + struct sigaction* old_sa_ptr) { + struct sigaction sa {}; sa.sa_sigaction = handler; - sa.sa_flags = SA_RESTART|SA_SIGINFO|SA_NOCLDSTOP|SA_NODEFER; - if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) { + sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER; + if (sigemptyset(&sa.sa_mask) != 0 || + sigaction(signal, &sa, old_sa_ptr) != 0) { std::ostringstream oss; - oss << "An error occurred while setting handler for " << strsignal(signal) << "."; + oss << "An error occurred while setting handler for " << strsignal(signal) + << "."; throw std::runtime_error(oss.str()); } } -SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. " - "This might be caused by insufficient shared memory (shm).\n"); -SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n"); -SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n"); +SIGNAL_HANDLER( + SIGBUS, + handler_SIGBUS, + "ERROR: Unexpected bus error encountered in worker. " + "This might be caused by insufficient shared memory (shm).\n"); +SIGNAL_HANDLER( + SIGSEGV, + handler_SIGSEGV, + "ERROR: Unexpected segmentation fault encountered in worker.\n"); +SIGNAL_HANDLER( + SIGFPE, + handler_SIGFPE, + "ERROR: Unexpected floating-point exception encountered in worker.\n"); // When an error happened in DataLoader methods and Python starts to exit, the // error trace will keep the loader alive, and Python may kill the children @@ -74,12 +89,11 @@ SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point excepti // loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we // exit with nonzero code, the loader SIGCHLD handler may report RuntimeError // again, and then it defeats the whole purpose. -static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx) -{ +static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) { if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); } - struct sigaction sa{}; + struct sigaction sa {}; sa.sa_handler = SIG_DFL; sa.sa_flags = 0; if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) { @@ -89,7 +103,9 @@ static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx) } } -static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) { +static PyObject* THPModule_setWorkerSignalHandlers( + PyObject* module, + PyObject* arg) { HANDLE_TH_ERRORS setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); @@ -101,12 +117,14 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a static std::map> worker_pids = {}; -static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *noargs) { +static PyObject* THPModule_errorIfAnyWorkerFails( + PyObject* module, + PyObject* noargs) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int error; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::set *pid_set; + std::set* pid_set; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) pid_t worker_pid; siginfo_t infop; @@ -116,14 +134,16 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *noa pid_set = &(w.second); for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) { worker_pid = *pid_it; - // Use waitid rather than waitpid so that we can set NOWAIT, and that Python - // and other handlers can get whatever info they want about the child. + // Use waitid rather than waitpid so that we can set NOWAIT, and that + // Python and other handlers can get whatever info they want about the + // child. infop.si_pid = 0; - error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT); + error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT); // ignore errors and case with no waitable child if (error < 0 || infop.si_pid == 0) continue; - if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error + if (infop.si_code == CLD_EXITED && + infop.si_status != EXIT_SUCCESS) { // exit with error std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") exited " << "unexpectedly with exit code " << infop.si_status << ". " @@ -133,13 +153,15 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *noa // workers, and trigger this again. pid_set->clear(); throw std::runtime_error(oss.str()); - } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal + } else if ( + infop.si_code == CLD_KILLED || + infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; if (infop.si_status == SIGBUS) { - oss << "It is possible that dataloader's workers are out of shared memory. " - << "Please try to raise your shared memory limit."; + oss << "It is possible that dataloader's workers are out of shared memory. " + << "Please try to raise your shared memory limit."; } // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. @@ -154,24 +176,26 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *noa // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple // of pids we are interested in. -static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) { +static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) { HANDLE_TH_ERRORS if (PyTuple_GET_SIZE(args) != 2) { throw TypeError("_set_worker_pids expects exactly 2 arguments."); } int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); if (worker_pids.find(key) != worker_pids.end()) { - throw ValueError("_set_worker_pids should be called only once for each _BaseDataLoaderIter."); + throw ValueError( + "_set_worker_pids should be called only once for each _BaseDataLoaderIter."); } - PyObject *child_pids = PyTuple_GET_ITEM(args, 1); + PyObject* child_pids = PyTuple_GET_ITEM(args, 1); if (!PyTuple_Check(child_pids)) { - throw TypeError("_set_worker_pids expects a tuple for child_pids, but got %s.", + throw TypeError( + "_set_worker_pids expects a tuple for child_pids, but got %s.", Py_TYPE(child_pids)->tp_name); } std::set pids_set = {}; auto size = PyTuple_GET_SIZE(child_pids); - for(const auto idx : c10::irange(size)) { + for (const auto idx : c10::irange(size)) { PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); pids_set.insert(static_cast(THPUtils_unpackLong(obj))); } @@ -182,13 +206,17 @@ static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) { END_HANDLE_TH_ERRORS } -static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) { +static PyObject* THPModule_removeWorkerPIDs( + PyObject* module, + PyObject* loader_id) { HANDLE_TH_ERRORS int64_t key = THPUtils_unpackLong(loader_id); auto it = worker_pids.find(key); if (it == worker_pids.end()) { - throw ValueError(fmt::format("Cannot find worker information for _BaseDataLoaderIter with id {}", key)); + throw ValueError(fmt::format( + "Cannot find worker information for _BaseDataLoaderIter with id {}", + key)); } worker_pids.erase(it); @@ -201,19 +229,25 @@ static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_i #else // dummy implementations for windows -static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored) { +static PyObject* THPModule_setWorkerSignalHandlers( + PyObject* module, + PyObject* _ignored) { Py_RETURN_NONE; } -static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *_ignored) { +static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } -static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) { +static PyObject* THPModule_removeWorkerPIDs( + PyObject* module, + PyObject* _ignored) { Py_RETURN_NONE; } -static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ignored) { +static PyObject* THPModule_errorIfAnyWorkerFails( + PyObject* module, + PyObject* _ignored) { Py_RETURN_NONE; } @@ -221,9 +255,14 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ig // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) PyMethodDef DataLoaderMethods[] = { - {"_set_worker_signal_handlers", THPModule_setWorkerSignalHandlers, METH_NOARGS, nullptr}, - {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, - {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr}, - {"_error_if_any_worker_fails", THPModule_errorIfAnyWorkerFails, METH_NOARGS, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; + {"_set_worker_signal_handlers", + THPModule_setWorkerSignalHandlers, + METH_NOARGS, + nullptr}, + {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, + {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr}, + {"_error_if_any_worker_fails", + THPModule_errorIfAnyWorkerFails, + METH_NOARGS, + nullptr}, + {nullptr, nullptr, 0, nullptr}}; diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 8bd5c0ed47710c..9ea495f2c3fd42 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -2,36 +2,35 @@ #include #include +#include #include -#include #include -#include +#include #include #include +#include #include #include -#include #include -PyObject *THPDevice_New(const at::Device& device) -{ +PyObject* THPDevice_New(const at::Device& device) { auto type = (PyTypeObject*)&THPDeviceType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->device = device; return self.release(); } -PyObject *THPDevice_repr(THPDevice *self) -{ +PyObject* THPDevice_repr(THPDevice* self) { std::ostringstream oss; oss << "device(type=\'" << self->device.type() << "\'"; if (self->device.has_index()) { - // `self->device.index()` returns uint8_t which is treated as ascii while printing, - // hence casting it to uint16_t. + // `self->device.index()` returns uint8_t which is treated as ascii while + // printing, hence casting it to uint16_t. // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout oss << ", index=" << static_cast(self->device.index()); } @@ -39,31 +38,33 @@ PyObject *THPDevice_repr(THPDevice *self) return THPUtils_packString(oss.str().c_str()); } -PyObject *THPDevice_str(THPDevice *self) -{ +PyObject* THPDevice_str(THPDevice* self) { std::ostringstream oss; oss << self->device; return THPUtils_packString(oss.str().c_str()); } -PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +PyObject* THPDevice_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static torch::PythonArgParser parser({ - "Device(Device device)", - "Device(c10::string_view type, int64_t? index=-1)" - }); + static torch::PythonArgParser parser( + {"Device(Device device)", + "Device(c10::string_view type, int64_t? index=-1)"}); torch::ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { auto device = r.device(0); return THPDevice_New(device); } else if (r.idx == 1) { - auto as_device = r.device(0); // this works, because device can take strings + auto as_device = r.device(0); // this works, because device can take strings auto device_type = r.string(0); if (as_device.has_index()) { - throw std::runtime_error("type (string) must not include an index because index " - "was passed explicitly: " + device_type); + throw std::runtime_error( + "type (string) must not include an index because index " + "was passed explicitly: " + + device_type); } int32_t device_index = -1; if (!r.isNone(1)) { @@ -79,8 +80,7 @@ PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) END_HANDLE_TH_ERRORS } -PyObject *THPDevice_type(THPDevice *self, PyObject *noargs) -{ +PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) { HANDLE_TH_ERRORS std::ostringstream oss; oss << self->device.type(); @@ -89,8 +89,7 @@ PyObject *THPDevice_type(THPDevice *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject *THPDevice_index(THPDevice *self, PyObject *noargs) -{ +PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) { HANDLE_TH_ERRORS if (self->device.has_index()) { return THPUtils_packInt64(self->device.index()); @@ -100,24 +99,25 @@ PyObject *THPDevice_index(THPDevice *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static Py_ssize_t THPDevice_hash(THPDevice *self) -{ +static Py_ssize_t THPDevice_hash(THPDevice* self) { HANDLE_TH_ERRORS - return static_cast(std::hash{}(self->device) % std::numeric_limits::max()); + return static_cast( + std::hash{}(self->device) % + std::numeric_limits::max()); END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { +PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) { HANDLE_TH_ERRORS if (!THPDevice_Check(a) || !THPDevice_Check(b)) { // Py_RETURN_NOTIMPLEMENTED not in python 2. Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - THPDevice *da = reinterpret_cast(a); - THPDevice *db = reinterpret_cast(b); + THPDevice* da = reinterpret_cast(a); + THPDevice* db = reinterpret_cast(b); - switch(op) { + switch (op) { case Py_EQ: if (da->device == db->device) { Py_RETURN_TRUE; @@ -141,12 +141,12 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { END_HANDLE_TH_ERRORS } -PyObject *THPDevice_reduce(PyObject *_self, PyObject *noargs) -{ +PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPDevice*)_self; auto ret = THPObjectPtr{PyTuple_New(2)}; - if (!ret) throw python_error(); + if (!ret) + throw python_error(); py::object torch_module = py::module::import("torch"); py::object torch_device = torch_module.attr("device"); @@ -156,82 +156,81 @@ PyObject *THPDevice_reduce(PyObject *_self, PyObject *noargs) std::ostringstream oss; oss << self->device.type(); if (self->device.has_index()) { - args = THPObjectPtr{Py_BuildValue("(si)", oss.str().c_str(), self->device.index())}; + args = THPObjectPtr{ + Py_BuildValue("(si)", oss.str().c_str(), self->device.index())}; } else { args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())}; } - if (!args) throw python_error(); + if (!args) + throw python_error(); PyTuple_SET_ITEM(ret.get(), 1, args.release()); return ret.release(); END_HANDLE_TH_ERRORS } -typedef PyObject *(*getter)(PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyGetSetDef THPDevice_properties[] = { - {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr}, - {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr}, + {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static PyMethodDef THPDevice_methods[] = { - {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr}, - {nullptr} /* Sentinel */ + {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr}, + {nullptr} /* Sentinel */ }; PyTypeObject THPDeviceType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.device", /* tp_name */ - sizeof(THPDevice), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPDevice_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - (hashfunc)THPDevice_hash, /* tp_hash */ - nullptr, /* tp_call */ - (reprfunc)THPDevice_str, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - (richcmpfunc)THPDevice_rc, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPDevice_methods, /* tp_methods */ - nullptr, /* tp_members */ - THPDevice_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPDevice_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */ + sizeof(THPDevice), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPDevice_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + (hashfunc)THPDevice_hash, /* tp_hash */ + nullptr, /* tp_call */ + (reprfunc)THPDevice_str, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + (richcmpfunc)THPDevice_rc, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPDevice_methods, /* tp_methods */ + nullptr, /* tp_members */ + THPDevice_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPDevice_pynew, /* tp_new */ }; -void THPDevice_init(PyObject *module) -{ +void THPDevice_init(PyObject* module) { if (PyType_Ready(&THPDeviceType) < 0) { throw python_error(); } Py_INCREF(&THPDeviceType); - if (PyModule_AddObject(module, "device", (PyObject *)&THPDeviceType) != 0) { + if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) { throw python_error(); } } diff --git a/torch/csrc/Device.h b/torch/csrc/Device.h index d85b92b1f2dc49..665c38bf035d45 100644 --- a/torch/csrc/Device.h +++ b/torch/csrc/Device.h @@ -1,22 +1,21 @@ #pragma once -#include #include +#include #include // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct TORCH_API THPDevice { - PyObject_HEAD - at::Device device; + PyObject_HEAD at::Device device; }; TORCH_API extern PyTypeObject THPDeviceType; -inline bool THPDevice_Check(PyObject *obj) { +inline bool THPDevice_Check(PyObject* obj) { return Py_TYPE(obj) == &THPDeviceType; } -TORCH_API PyObject * THPDevice_New(const at::Device& device); +TORCH_API PyObject* THPDevice_New(const at::Device& device); -TORCH_API void THPDevice_init(PyObject *module); +TORCH_API void THPDevice_init(PyObject* module); diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index 268fe22f648bc2..24b67376f8ae5f 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -1,29 +1,28 @@ #include -#include #include #include #include #include #include #include +#include #include -PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name) -{ +PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) { AT_ASSERT(name.length() < DTYPE_NAME_LEN); auto type = (PyTypeObject*)&THPDtypeType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->scalar_type = scalar_type; std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN); return self.release(); } -PyObject *THPDtype_is_floating_point(THPDtype *self, PyObject *noargs) -{ +PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) { if (at::isFloatingType(self->scalar_type)) { Py_RETURN_TRUE; } else { @@ -31,8 +30,7 @@ PyObject *THPDtype_is_floating_point(THPDtype *self, PyObject *noargs) } } -PyObject *THPDtype_is_complex(THPDtype *self, PyObject *noargs) -{ +PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) { if (at::isComplexType(self->scalar_type)) { Py_RETURN_TRUE; } else { @@ -40,8 +38,7 @@ PyObject *THPDtype_is_complex(THPDtype *self, PyObject *noargs) } } -PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs) -{ +PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) { HANDLE_TH_ERRORS if (at::isSignedType(self->scalar_type)) { Py_RETURN_TRUE; @@ -51,81 +48,80 @@ PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject *THPDtype_reduce(PyObject *_self, PyObject *noargs) -{ +PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) { /* - * For singletons, a string is returned. The string should be interpreted - * as the name of a global variable. - */ + * For singletons, a string is returned. The string should be interpreted + * as the name of a global variable. + */ auto self = (THPDtype*)_self; return THPUtils_packString(self->name); } -typedef PyObject *(*getter)(PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyGetSetDef THPDtype_properties[] = { - {"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr}, - {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr}, - {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"is_floating_point", + (getter)THPDtype_is_floating_point, + nullptr, + nullptr, + nullptr}, + {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr}, + {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static PyMethodDef THPDtype_methods[] = { - {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr}, - {nullptr} /* Sentinel */ + {"__reduce__", THPDtype_reduce, METH_NOARGS, nullptr}, + {nullptr} /* Sentinel */ }; -PyObject *THPDtype_repr(THPDtype *self) -{ +PyObject* THPDtype_repr(THPDtype* self) { std::string name = self->name; return THPUtils_packString("torch." + name); } PyTypeObject THPDtypeType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.dtype", /* tp_name */ - sizeof(THPDtype), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPDtype_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPDtype_methods, /* tp_methods */ - nullptr, /* tp_members */ - THPDtype_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.dtype", /* tp_name */ + sizeof(THPDtype), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPDtype_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPDtype_methods, /* tp_methods */ + nullptr, /* tp_members */ + THPDtype_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; -void THPDtype_init(PyObject *module) -{ +void THPDtype_init(PyObject* module) { // Set a __dict__ with `__module__` = `torch`. This means // `__module__` value will be inherited by instances // (i.e. `torch.float32.__module__ == "torch"`). This will prevent @@ -138,11 +134,13 @@ void THPDtype_init(PyObject *module) // See https://github.com/pytorch/pytorch/issues/65077 TORCH_INTERNAL_ASSERT(THPDtypeType.tp_dict == nullptr); auto dict = THPObjectPtr(PyDict_New()); - if (!dict) throw python_error(); + if (!dict) + throw python_error(); auto torch = THPUtils_packString("torch"); - if (!torch) throw python_error(); - if (PyDict_SetItemString(dict, "__module__", torch) < 0 ) { - throw python_error(); + if (!torch) + throw python_error(); + if (PyDict_SetItemString(dict, "__module__", torch) < 0) { + throw python_error(); } THPDtypeType.tp_dict = dict.release(); @@ -150,7 +148,7 @@ void THPDtype_init(PyObject *module) throw python_error(); } Py_INCREF(&THPDtypeType); - if (PyModule_AddObject(module, "dtype", (PyObject *)&THPDtypeType) != 0) { + if (PyModule_AddObject(module, "dtype", (PyObject*)&THPDtypeType) != 0) { throw python_error(); } } diff --git a/torch/csrc/Dtype.h b/torch/csrc/Dtype.h index 1b60a566d4c598..648a8091d93d6e 100644 --- a/torch/csrc/Dtype.h +++ b/torch/csrc/Dtype.h @@ -1,30 +1,28 @@ #pragma once #include -#include #include +#include const int DTYPE_NAME_LEN = 64; struct TORCH_API THPDtype { - PyObject_HEAD - at::ScalarType scalar_type; + PyObject_HEAD at::ScalarType scalar_type; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char name[DTYPE_NAME_LEN + 1]; }; TORCH_API extern PyTypeObject THPDtypeType; -inline bool THPDtype_Check(PyObject *obj) { +inline bool THPDtype_Check(PyObject* obj) { return Py_TYPE(obj) == &THPDtypeType; } -inline bool THPPythonScalarType_Check(PyObject *obj) { +inline bool THPPythonScalarType_Check(PyObject* obj) { return obj == (PyObject*)(&PyFloat_Type) || - obj == (PyObject*)(&PyBool_Type) || - obj == (PyObject*)(&PyLong_Type); + obj == (PyObject*)(&PyBool_Type) || obj == (PyObject*)(&PyLong_Type); } -PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name); +PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name); -void THPDtype_init(PyObject *module); +void THPDtype_init(PyObject* module); diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index c929f3832116c3..b3cb8103f820bb 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -1,16 +1,16 @@ #include +#include #include #include #include #include #include +#include #include #include #include #include -#include -#include #include @@ -24,11 +24,15 @@ namespace torch { namespace { -std::array(at::ScalarType::NumOptions)> dtype_registry = {}; +std::array(at::ScalarType::NumOptions)> + dtype_registry = {}; -std::array(at::Layout::NumOptions)> layout_registry = {}; +std::array(at::Layout::NumOptions)> + layout_registry = {}; -at::DeprecatedTypeProperties* get_type_properties(at::DeviceType device_type, at::ScalarType scalarType) { +at::DeprecatedTypeProperties* get_type_properties( + at::DeviceType device_type, + at::ScalarType scalarType) { at::Backend backend; if (device_type == at::kCPU) { backend = at::Backend::CPU; @@ -43,11 +47,11 @@ at::DeprecatedTypeProperties* get_type_properties(at::DeviceType device_type, at } } // namespace -void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType) { +void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType) { dtype_registry[static_cast(scalarType)] = dtype; } -void registerLayoutObject(THPLayout *thp_layout, at::Layout layout) { +void registerLayoutObject(THPLayout* thp_layout, at::Layout layout) { layout_registry[static_cast(layout)] = thp_layout; } @@ -68,13 +72,18 @@ THPLayout* getTHPLayout(at::Layout layout) { } PyObject* createPyObject(const at::Storage& storage) { - if (storage.device_type() != at::DeviceType::Meta && storage.data() == nullptr && storage.nbytes() != 0) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details"); + if (storage.device_type() != at::DeviceType::Meta && + storage.data() == nullptr && storage.nbytes() != 0) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details"); } PyTypeObject* type = reinterpret_cast(THPStorageClass); auto obj = THPObjectPtr(type->tp_alloc(type, 0)); - if (!obj) throw python_error(); - ((THPVoidStorage*)obj.get())->cdata = at::Storage(/* copy */ storage).unsafeReleaseStorageImpl(); + if (!obj) + throw python_error(); + ((THPVoidStorage*)obj.get())->cdata = + at::Storage(/* copy */ storage).unsafeReleaseStorageImpl(); return obj.release(); } @@ -82,7 +91,8 @@ PyTypeObject* loadTypedStorageTypeObject() { PyObject* storage_module = PyImport_ImportModule("torch.storage"); TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module)); - PyObject* typed_storage_obj = PyObject_GetAttrString(storage_module, "_TypedStorage"); + PyObject* typed_storage_obj = + PyObject_GetAttrString(storage_module, "_TypedStorage"); TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj)); return reinterpret_cast( PyObject_GetAttrString(storage_module, "_TypedStorage")); @@ -94,8 +104,7 @@ PyTypeObject* getTypedStorageTypeObject() { return typed_storage_type_obj; } -bool isStorage(PyObject* obj) -{ +bool isStorage(PyObject* obj) { if (PyObject_TypeCheck(obj, getTypedStorageTypeObject())) { return true; } @@ -104,8 +113,10 @@ bool isStorage(PyObject* obj) return obj_type == reinterpret_cast(THPStorageClass); } -at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage) -{ +at::Storage createStorageGetType( + PyObject* obj, + at::ScalarType& scalar_type, + bool& is_typed_storage) { is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject()); PyObject* untyped_storage_obj; @@ -129,16 +140,19 @@ at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, boo untyped_storage_obj = obj; } - if (Py_TYPE(untyped_storage_obj) != reinterpret_cast(THPStorageClass)) { + if (Py_TYPE(untyped_storage_obj) != + reinterpret_cast(THPStorageClass)) { throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name); } - c10::StorageImpl* impl = static_cast(((THPVoidStorage*)untyped_storage_obj)->cdata); + c10::StorageImpl* impl = static_cast( + ((THPVoidStorage*)untyped_storage_obj)->cdata); c10::DeviceType device_type = impl->device().type(); auto type_properties = get_type_properties(device_type, at::kByte); - return type_properties->unsafeStorageFromTH(((THPVoidStorage*)untyped_storage_obj)->cdata, true); + return type_properties->unsafeStorageFromTH( + ((THPVoidStorage*)untyped_storage_obj)->cdata, true); } at::Storage createStorage(PyObject* obj) { @@ -147,4 +161,4 @@ at::Storage createStorage(PyObject* obj) { return createStorageGetType(obj, scalar_type, is_typed_storage); } -} // namespace +} // namespace torch diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 105744c2f68112..6765916634c539 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -5,10 +5,10 @@ #include #include -#include -#include #include #include +#include +#include #include #include @@ -21,14 +21,17 @@ struct Storage; } namespace torch { -void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType); -void registerLayoutObject(THPLayout *thp_layout, at::Layout layout); +void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType); +void registerLayoutObject(THPLayout* thp_layout, at::Layout layout); PyObject* createPyObject(const at::Storage& storage); at::Storage createStorage(PyObject* obj); -at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage); +at::Storage createStorageGetType( + PyObject* obj, + at::ScalarType& scalar_type, + bool& is_typed_storage); bool isStorage(PyObject* obj); THPDtype* getTHPDtype(at::ScalarType scalarType); THPLayout* getTHPLayout(at::Layout layout); -} // namespace torch +} // namespace torch diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index 2eb8ae41898eb7..a1506a16252335 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -1,26 +1,32 @@ #include #include -#include -#include #include #include #include +#include +#include #include PyObject *THPException_FatalError, *THPException_LinAlgError; -#define ASSERT_TRUE(cond) if (!(cond)) return false -bool THPException_init(PyObject *module) -{ - ASSERT_TRUE(THPException_FatalError = PyErr_NewException("torch.FatalError", nullptr, nullptr)); - ASSERT_TRUE(PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0); +#define ASSERT_TRUE(cond) \ + if (!(cond)) \ + return false +bool THPException_init(PyObject* module) { + ASSERT_TRUE( + THPException_FatalError = + PyErr_NewException("torch.FatalError", nullptr, nullptr)); + ASSERT_TRUE( + PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0); - // Set the doc string here since _add_docstr throws malloc errors if tp_doc is modified - // for an error class. - ASSERT_TRUE(THPException_LinAlgError = PyErr_NewExceptionWithDoc("torch._C._LinAlgError", - "Error raised by torch.linalg function when the cause of error is a numerical inconsistency in the data.\n \ + // Set the doc string here since _add_docstr throws malloc errors if tp_doc is + // modified for an error class. + ASSERT_TRUE( + THPException_LinAlgError = PyErr_NewExceptionWithDoc( + "torch._C._LinAlgError", + "Error raised by torch.linalg function when the cause of error is a numerical inconsistency in the data.\n \ For example, you can the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \ a matrix is not invertible.\n \ \n\ @@ -35,17 +41,22 @@ Example:\n \ Traceback (most recent call last):\n \ File \"\", line 1, in \n \ torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \ -could not be completed because the input matrix is singular.", PyExc_RuntimeError, nullptr)); - ASSERT_TRUE(PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) == 0); +could not be completed because the input matrix is singular.", + PyExc_RuntimeError, + nullptr)); + ASSERT_TRUE( + PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) == + 0); return true; } namespace torch { -void replaceAll(std::string & str, - const std::string & old_str, - const std::string & new_str) { +void replaceAll( + std::string& str, + const std::string& old_str, + const std::string& new_str) { std::string::size_type pos = 0u; while ((pos = str.find(old_str, pos)) != std::string::npos) { str.replace(pos, old_str.length(), new_str); @@ -53,83 +64,82 @@ void replaceAll(std::string & str, } std::string processErrorMsg(std::string str) { - // Translate Aten types to their respective pytorch ones - std::vector> changes { - {"Variable[SparseCUDAByteType]", "torch.cuda.sparse.ByteTensor"}, - {"Variable[SparseCUDACharType]", "torch.cuda.sparse.CharTensor"}, - {"Variable[SparseCUDADoubleType]", "torch.cuda.sparse.DoubleTensor"}, - {"Variable[SparseCUDAFloatType]", "torch.cuda.sparse.FloatTensor"}, - {"Variable[SparseCUDAIntType]", "torch.cuda.sparse.IntTensor"}, - {"Variable[SparseCUDALongType]", "torch.cuda.sparse.LongTensor"}, - {"Variable[SparseCUDAShortType]", "torch.cuda.sparse.ShortTensor"}, - {"Variable[SparseCUDAHalfType]", "torch.cuda.sparse.HalfTensor"}, - {"Variable[SparseCPUByteType]", "torch.sparse.ByteTensor"}, - {"Variable[SparseCPUCharType]", "torch.sparse.CharTensor"}, - {"Variable[SparseCPUDoubleType]", "torch.sparse.DoubleTensor"}, - {"Variable[SparseCPUFloatType]", "torch.sparse.FloatTensor"}, - {"Variable[SparseCPUIntType]", "torch.sparse.IntTensor"}, - {"Variable[SparseCPULongType]", "torch.sparse.LongTensor"}, - {"Variable[SparseCPUShortType]", "torch.sparse.ShortTensor"}, - {"Variable[SparseCPUHalfType]", "torch.sparse.HalfTensor"}, - {"Variable[CUDAByteType]", "torch.cuda.ByteTensor"}, - {"Variable[CUDACharType]", "torch.cuda.CharTensor"}, - {"Variable[CUDADoubleType]", "torch.cuda.DoubleTensor"}, - {"Variable[CUDAFloatType]", "torch.cuda.FloatTensor"}, - {"Variable[CUDAIntType]", "torch.cuda.IntTensor"}, - {"Variable[CUDALongType]", "torch.cuda.LongTensor"}, - {"Variable[CUDAShortType]", "torch.cuda.ShortTensor"}, - {"Variable[CUDAHalfType]", "torch.cuda.HalfTensor"}, - {"Variable[CPUByteType]", "torch.ByteTensor"}, - {"Variable[CPUCharType]", "torch.CharTensor"}, - {"Variable[CPUDoubleType]", "torch.DoubleTensor"}, - {"Variable[CPUFloatType]", "torch.FloatTensor"}, - {"Variable[CPUIntType]", "torch.IntTensor"}, - {"Variable[CPULongType]", "torch.LongTensor"}, - {"Variable[CPUShortType]", "torch.ShortTensor"}, - {"Variable[CPUHalfType]", "torch.HalfTensor"}, - {"SparseCUDAByteType", "torch.cuda.sparse.ByteTensor"}, - {"SparseCUDACharType", "torch.cuda.sparse.CharTensor"}, - {"SparseCUDADoubleType", "torch.cuda.sparse.DoubleTensor"}, - {"SparseCUDAFloatType", "torch.cuda.sparse.FloatTensor"}, - {"SparseCUDAIntType", "torch.cuda.sparse.IntTensor"}, - {"SparseCUDALongType", "torch.cuda.sparse.LongTensor"}, - {"SparseCUDAShortType", "torch.cuda.sparse.ShortTensor"}, - {"SparseCUDAHalfType", "torch.cuda.sparse.HalfTensor"}, - {"SparseCPUByteType", "torch.sparse.ByteTensor"}, - {"SparseCPUCharType", "torch.sparse.CharTensor"}, - {"SparseCPUDoubleType", "torch.sparse.DoubleTensor"}, - {"SparseCPUFloatType", "torch.sparse.FloatTensor"}, - {"SparseCPUIntType", "torch.sparse.IntTensor"}, - {"SparseCPULongType", "torch.sparse.LongTensor"}, - {"SparseCPUShortType", "torch.sparse.ShortTensor"}, - {"SparseCPUHalfType", "torch.sparse.HalfTensor"}, - {"CUDAByteType", "torch.cuda.ByteTensor"}, - {"CUDACharType", "torch.cuda.CharTensor"}, - {"CUDADoubleType", "torch.cuda.DoubleTensor"}, - {"CUDAFloatType", "torch.cuda.FloatTensor"}, - {"CUDAIntType", "torch.cuda.IntTensor"}, - {"CUDALongType", "torch.cuda.LongTensor"}, - {"CUDAShortType", "torch.cuda.ShortTensor"}, - {"CUDAHalfType", "torch.cuda.HalfTensor"}, - {"CPUByteType", "torch.ByteTensor"}, - {"CPUCharType", "torch.CharTensor"}, - {"CPUDoubleType", "torch.DoubleTensor"}, - {"CPUFloatType", "torch.FloatTensor"}, - {"CPUIntType", "torch.IntTensor"}, - {"CPULongType", "torch.LongTensor"}, - {"CPUShortType", "torch.ShortTensor"}, - {"CPUHalfType", "torch.HalfTensor"}, + std::vector> changes{ + {"Variable[SparseCUDAByteType]", "torch.cuda.sparse.ByteTensor"}, + {"Variable[SparseCUDACharType]", "torch.cuda.sparse.CharTensor"}, + {"Variable[SparseCUDADoubleType]", "torch.cuda.sparse.DoubleTensor"}, + {"Variable[SparseCUDAFloatType]", "torch.cuda.sparse.FloatTensor"}, + {"Variable[SparseCUDAIntType]", "torch.cuda.sparse.IntTensor"}, + {"Variable[SparseCUDALongType]", "torch.cuda.sparse.LongTensor"}, + {"Variable[SparseCUDAShortType]", "torch.cuda.sparse.ShortTensor"}, + {"Variable[SparseCUDAHalfType]", "torch.cuda.sparse.HalfTensor"}, + {"Variable[SparseCPUByteType]", "torch.sparse.ByteTensor"}, + {"Variable[SparseCPUCharType]", "torch.sparse.CharTensor"}, + {"Variable[SparseCPUDoubleType]", "torch.sparse.DoubleTensor"}, + {"Variable[SparseCPUFloatType]", "torch.sparse.FloatTensor"}, + {"Variable[SparseCPUIntType]", "torch.sparse.IntTensor"}, + {"Variable[SparseCPULongType]", "torch.sparse.LongTensor"}, + {"Variable[SparseCPUShortType]", "torch.sparse.ShortTensor"}, + {"Variable[SparseCPUHalfType]", "torch.sparse.HalfTensor"}, + {"Variable[CUDAByteType]", "torch.cuda.ByteTensor"}, + {"Variable[CUDACharType]", "torch.cuda.CharTensor"}, + {"Variable[CUDADoubleType]", "torch.cuda.DoubleTensor"}, + {"Variable[CUDAFloatType]", "torch.cuda.FloatTensor"}, + {"Variable[CUDAIntType]", "torch.cuda.IntTensor"}, + {"Variable[CUDALongType]", "torch.cuda.LongTensor"}, + {"Variable[CUDAShortType]", "torch.cuda.ShortTensor"}, + {"Variable[CUDAHalfType]", "torch.cuda.HalfTensor"}, + {"Variable[CPUByteType]", "torch.ByteTensor"}, + {"Variable[CPUCharType]", "torch.CharTensor"}, + {"Variable[CPUDoubleType]", "torch.DoubleTensor"}, + {"Variable[CPUFloatType]", "torch.FloatTensor"}, + {"Variable[CPUIntType]", "torch.IntTensor"}, + {"Variable[CPULongType]", "torch.LongTensor"}, + {"Variable[CPUShortType]", "torch.ShortTensor"}, + {"Variable[CPUHalfType]", "torch.HalfTensor"}, + {"SparseCUDAByteType", "torch.cuda.sparse.ByteTensor"}, + {"SparseCUDACharType", "torch.cuda.sparse.CharTensor"}, + {"SparseCUDADoubleType", "torch.cuda.sparse.DoubleTensor"}, + {"SparseCUDAFloatType", "torch.cuda.sparse.FloatTensor"}, + {"SparseCUDAIntType", "torch.cuda.sparse.IntTensor"}, + {"SparseCUDALongType", "torch.cuda.sparse.LongTensor"}, + {"SparseCUDAShortType", "torch.cuda.sparse.ShortTensor"}, + {"SparseCUDAHalfType", "torch.cuda.sparse.HalfTensor"}, + {"SparseCPUByteType", "torch.sparse.ByteTensor"}, + {"SparseCPUCharType", "torch.sparse.CharTensor"}, + {"SparseCPUDoubleType", "torch.sparse.DoubleTensor"}, + {"SparseCPUFloatType", "torch.sparse.FloatTensor"}, + {"SparseCPUIntType", "torch.sparse.IntTensor"}, + {"SparseCPULongType", "torch.sparse.LongTensor"}, + {"SparseCPUShortType", "torch.sparse.ShortTensor"}, + {"SparseCPUHalfType", "torch.sparse.HalfTensor"}, + {"CUDAByteType", "torch.cuda.ByteTensor"}, + {"CUDACharType", "torch.cuda.CharTensor"}, + {"CUDADoubleType", "torch.cuda.DoubleTensor"}, + {"CUDAFloatType", "torch.cuda.FloatTensor"}, + {"CUDAIntType", "torch.cuda.IntTensor"}, + {"CUDALongType", "torch.cuda.LongTensor"}, + {"CUDAShortType", "torch.cuda.ShortTensor"}, + {"CUDAHalfType", "torch.cuda.HalfTensor"}, + {"CPUByteType", "torch.ByteTensor"}, + {"CPUCharType", "torch.CharTensor"}, + {"CPUDoubleType", "torch.DoubleTensor"}, + {"CPUFloatType", "torch.FloatTensor"}, + {"CPUIntType", "torch.IntTensor"}, + {"CPULongType", "torch.LongTensor"}, + {"CPUShortType", "torch.ShortTensor"}, + {"CPUHalfType", "torch.HalfTensor"}, }; - for (const auto & it : changes) { + for (const auto& it : changes) { replaceAll(str, it.first, it.second); } return str; } -static std::string formatMessage(const char *format, va_list fmt_args) { +static std::string formatMessage(const char* format, va_list fmt_args) { static const size_t ERROR_BUF_SIZE = 1024; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) char error_buf[ERROR_BUF_SIZE]; @@ -141,30 +151,32 @@ static std::string formatMessage(const char *format, va_list fmt_args) { return std::string(error_buf); } -void translate_exception_to_python(const std::exception_ptr &e_ptr) { +void translate_exception_to_python(const std::exception_ptr& e_ptr) { try { - TORCH_INTERNAL_ASSERT(e_ptr, "translate_exception_to_python " - "called with invalid exception pointer"); + TORCH_INTERNAL_ASSERT( + e_ptr, + "translate_exception_to_python " + "called with invalid exception pointer"); std::rethrow_exception(e_ptr); } - CATCH_ALL_ERRORS(return) + CATCH_ALL_ERRORS(return ) } -IndexError::IndexError(const char *format, ...) { +IndexError::IndexError(const char* format, ...) { va_list fmt_args; va_start(fmt_args, format); msg = formatMessage(format, fmt_args); va_end(fmt_args); } -TypeError::TypeError(const char *format, ...) { +TypeError::TypeError(const char* format, ...) { va_list fmt_args; va_start(fmt_args, format); msg = formatMessage(format, fmt_args); va_end(fmt_args); } -ValueError::ValueError(const char *format, ...) { +ValueError::ValueError(const char* format, ...) { va_list fmt_args; va_start(fmt_args, format); msg = formatMessage(format, fmt_args); @@ -192,9 +204,8 @@ void PyWarningHandler::InternalHandler::process( warning_buffer_.push_back({source_location, msg, verbatim}); }; -PyWarningHandler::PyWarningHandler() noexcept(true): - prev_handler_(c10::Warning::get_warning_handler()), - in_exception_(false) { +PyWarningHandler::PyWarningHandler() noexcept(true) + : prev_handler_(c10::Warning::get_warning_handler()), in_exception_(false) { c10::Warning::set_warning_handler(&internal_handler_); } @@ -202,7 +213,7 @@ PyWarningHandler::PyWarningHandler() noexcept(true): /// NOLINTNEXTLINE(bugprone-exception-escape) PyWarningHandler::~PyWarningHandler() noexcept(false) { c10::Warning::set_warning_handler(prev_handler_); - auto &warning_buffer = internal_handler_.warning_buffer_; + auto& warning_buffer = internal_handler_.warning_buffer_; if (warning_buffer.size() > 0) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 4a644a0b45e8fd..a9a7df43d63e13 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -1,20 +1,20 @@ #pragma once #include -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include #include #include +#include #include #include -#include -#include -#include #if defined(USE_DISTRIBUTED) && defined(USE_C10D) #include @@ -47,118 +47,123 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) { /// This advanced handler will only be used in the current thread. /// If any other thread is used, warnings will be processed as /// cpp warnings. -#define HANDLE_TH_ERRORS \ - try { \ - torch::PyWarningHandler __enforce_warning_buffer; \ +#define HANDLE_TH_ERRORS \ + try { \ + torch::PyWarningHandler __enforce_warning_buffer; \ try { - // Only catch torch-specific exceptions -#define CATCH_CORE_ERRORS(retstmnt) \ - catch (python_error & e) { \ - e.restore(); \ - retstmnt; \ - } \ - catch (const c10::IndexError& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(PyExc_IndexError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (const c10::ValueError& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(PyExc_ValueError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (const c10::TypeError& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(PyExc_TypeError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (const c10::NotImplementedError& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(PyExc_NotImplementedError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (const c10::LinAlgError& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(THPException_LinAlgError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (const c10::Error& e) { \ - auto msg = torch::get_cpp_stacktraces_enabled() ? \ - e.what() : e.what_without_backtrace(); \ - PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(msg)); \ - retstmnt; \ - } \ - catch (torch::PyTorchError& e) { \ - auto msg = torch::processErrorMsg(e.what()); \ - PyErr_SetString(e.python_type(), msg); \ - retstmnt; \ - } +#define CATCH_CORE_ERRORS(retstmnt) \ + catch (python_error & e) { \ + e.restore(); \ + retstmnt; \ + } \ + catch (const c10::IndexError& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(PyExc_IndexError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (const c10::ValueError& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(PyExc_ValueError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (const c10::TypeError& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(PyExc_TypeError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (const c10::NotImplementedError& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(PyExc_NotImplementedError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (const c10::LinAlgError& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(THPException_LinAlgError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (const c10::Error& e) { \ + auto msg = torch::get_cpp_stacktraces_enabled() \ + ? e.what() \ + : e.what_without_backtrace(); \ + PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(msg)); \ + retstmnt; \ + } \ + catch (torch::PyTorchError & e) { \ + auto msg = torch::processErrorMsg(e.what()); \ + PyErr_SetString(e.python_type(), msg); \ + retstmnt; \ + } #if defined(USE_DISTRIBUTED) && defined(USE_C10D) -#define CATCH_C10D_ERRORS(retstmnt) \ - catch (const c10d::TimeoutError& e) { \ - auto msg = torch::processErrorMsg(e.what()); \ - PyErr_SetString(PyExc_TimeoutError, msg); \ - retstmnt; \ - } \ - catch (const c10d::C10dError& e) { \ - auto msg = torch::processErrorMsg(e.what()); \ - PyErr_SetString(PyExc_RuntimeError, msg); \ - retstmnt; \ - } +#define CATCH_C10D_ERRORS(retstmnt) \ + catch (const c10d::TimeoutError& e) { \ + auto msg = torch::processErrorMsg(e.what()); \ + PyErr_SetString(PyExc_TimeoutError, msg); \ + retstmnt; \ + } \ + catch (const c10d::C10dError& e) { \ + auto msg = torch::processErrorMsg(e.what()); \ + PyErr_SetString(PyExc_RuntimeError, msg); \ + retstmnt; \ + } #else #define CATCH_C10D_ERRORS(retstmnt) #endif -#define CATCH_TH_ERRORS(retstmnt) \ - CATCH_CORE_ERRORS(retstmnt) \ - CATCH_C10D_ERRORS(retstmnt) +#define CATCH_TH_ERRORS(retstmnt) \ + CATCH_CORE_ERRORS(retstmnt) \ + CATCH_C10D_ERRORS(retstmnt) -#define CATCH_ALL_ERRORS(retstmnt) \ - CATCH_TH_ERRORS(retstmnt) \ - catch (const std::exception& e) { \ - auto msg = torch::processErrorMsg(e.what()); \ - PyErr_SetString(PyExc_RuntimeError, msg); \ - retstmnt; \ - } +#define CATCH_ALL_ERRORS(retstmnt) \ + CATCH_TH_ERRORS(retstmnt) \ + catch (const std::exception& e) { \ + auto msg = torch::processErrorMsg(e.what()); \ + PyErr_SetString(PyExc_RuntimeError, msg); \ + retstmnt; \ + } -#define END_HANDLE_TH_ERRORS_PYBIND \ - } \ - catch(...) { \ - __enforce_warning_buffer.set_in_exception(); \ - throw; \ - } \ - } \ - catch (py::error_already_set & e) { \ - throw; \ - } \ - catch (py::builtin_exception & e) { \ - throw; \ - } \ - catch (torch::jit::JITException & e) { \ - throw; \ - } \ - catch (const std::exception & e) { \ - torch::translate_exception_to_python(std::current_exception()); \ - throw py::error_already_set(); \ +#define END_HANDLE_TH_ERRORS_PYBIND \ + } \ + catch (...) { \ + __enforce_warning_buffer.set_in_exception(); \ + throw; \ + } \ + } \ + catch (py::error_already_set & e) { \ + throw; \ + } \ + catch (py::builtin_exception & e) { \ + throw; \ + } \ + catch (torch::jit::JITException & e) { \ + throw; \ + } \ + catch (const std::exception& e) { \ + torch::translate_exception_to_python(std::current_exception()); \ + throw py::error_already_set(); \ } -#define END_HANDLE_TH_ERRORS_RET(retval) \ - } \ - catch(...) { \ - __enforce_warning_buffer.set_in_exception(); \ - throw; \ - } \ - } \ - catch (const std::exception & e) { \ - torch::translate_exception_to_python(std::current_exception()); \ - return retval; \ +#define END_HANDLE_TH_ERRORS_RET(retval) \ + } \ + catch (...) { \ + __enforce_warning_buffer.set_in_exception(); \ + throw; \ + } \ + } \ + catch (const std::exception& e) { \ + torch::translate_exception_to_python(std::current_exception()); \ + return retval; \ } #define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr) @@ -200,7 +205,7 @@ struct python_error : public std::exception { } } - const char* what() const noexcept override { + const char* what() const noexcept override { return message.c_str(); } @@ -243,7 +248,8 @@ struct python_error : public std::exception { /** Saves the exception so that it can be re-thrown on a different thread */ inline void persist() { - if (type) return; // Don't overwrite exceptions + if (type) + return; // Don't overwrite exceptions // PyErr_Fetch overwrites the pointers pybind11::gil_scoped_acquire gil; Py_XDECREF(type); @@ -255,7 +261,8 @@ struct python_error : public std::exception { /** Sets the current Python error from this exception */ inline void restore() { - if (!type) return; + if (!type) + return; // PyErr_Restore steals references pybind11::gil_scoped_acquire gil; Py_XINCREF(type); @@ -272,19 +279,19 @@ struct python_error : public std::exception { std::string message; }; -bool THPException_init(PyObject *module); +bool THPException_init(PyObject* module); namespace torch { // Set python current exception from a C++ exception -TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr &); +TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&); TORCH_PYTHON_API std::string processErrorMsg(std::string str); // Abstract base class for exceptions which translate to specific Python types struct PyTorchError : public std::exception { // NOLINTNEXTLINE(modernize-pass-by-value) - PyTorchError(const std::string& msg_ = std::string()): msg(msg_) {} + PyTorchError(const std::string& msg_ = std::string()) : msg(msg_) {} virtual PyObject* python_type() = 0; const char* what() const noexcept override { return msg.c_str(); @@ -295,16 +302,16 @@ struct PyTorchError : public std::exception { // Declare a printf-like function on gcc & clang // The compiler can then warn on invalid format specifiers #ifdef __GNUC__ -# define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \ - __attribute__((format (printf, FORMAT_INDEX, VA_ARGS_INDEX))) +#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \ + __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX))) #else -# define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) +#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) #endif // Translates to Python IndexError struct IndexError : public PyTorchError { using PyTorchError::PyTorchError; - IndexError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3); + IndexError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); PyObject* python_type() override { return PyExc_IndexError; } @@ -313,7 +320,7 @@ struct IndexError : public PyTorchError { // Translates to Python TypeError struct TypeError : public PyTorchError { using PyTorchError::PyTorchError; - TORCH_API TypeError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3); + TORCH_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); PyObject* python_type() override { return PyExc_TypeError; } @@ -322,7 +329,7 @@ struct TypeError : public PyTorchError { // Translates to Python ValueError struct ValueError : public PyTorchError { using PyTorchError::PyTorchError; - ValueError(const char *format, ...) TORCH_FORMAT_FUNC(2, 3); + ValueError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); PyObject* python_type() override { return PyExc_ValueError; } @@ -353,14 +360,12 @@ struct LinAlgError : public PyTorchError { }; struct WarningMeta { - WarningMeta(const c10::SourceLocation& _source_location, + WarningMeta( + const c10::SourceLocation& _source_location, // NOLINTNEXTLINE(modernize-pass-by-value) const std::string& _msg, - const bool _verbatim) : - source_location_{_source_location}, - msg_{_msg}, - verbatim_{_verbatim} { } - + const bool _verbatim) + : source_location_{_source_location}, msg_{_msg}, verbatim_{_verbatim} {} const c10::SourceLocation source_location_; const std::string msg_; @@ -372,23 +377,22 @@ struct PyWarningHandler { // Move actual handler into a separate class with a noexcept // destructor. Otherwise, we need to force all WarningHandler // subclasses to have a noexcept(false) destructor. - struct InternalHandler: at::WarningHandler { + struct InternalHandler : at::WarningHandler { ~InternalHandler() override = default; - void process(const at::SourceLocation &source_location, - const std::string &msg, - const bool verbatim) override; + void process( + const at::SourceLocation& source_location, + const std::string& msg, + const bool verbatim) override; std::vector warning_buffer_; }; - -public: + public: /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification TORCH_API PyWarningHandler() noexcept(true); // NOLINTNEXTLINE(bugprone-exception-escape) TORCH_API ~PyWarningHandler() noexcept(false); - /** Call if an exception has been thrown * Necessary to determine if it is safe to throw from the desctructor since @@ -399,7 +403,7 @@ struct PyWarningHandler { in_exception_ = true; } -private: + private: InternalHandler internal_handler_; at::WarningHandler* prev_handler_; bool in_exception_; @@ -409,19 +413,19 @@ namespace detail { template using Arg = typename function_traits::template arg::type; -template +template auto wrap_pybind_function_impl_(Func&& f, std::index_sequence) { using traits = function_traits; namespace py = pybind11; // f=f is needed to handle function references on older compilers - return [f=f](Arg ...args) -> typename traits::result_type { + return [f = f](Arg... args) -> typename traits::result_type { HANDLE_TH_ERRORS return f(std::forward>(args)...); END_HANDLE_TH_ERRORS_PYBIND }; } -} // namespace detail +} // namespace detail // Wrap a function with TH error and warning handling. // Returns a function object suitable for registering with pybind11. @@ -429,7 +433,7 @@ template auto wrap_pybind_function(Func&& f) { using traits = function_traits; return torch::detail::wrap_pybind_function_impl_( - std::forward(f), std::make_index_sequence{}); + std::forward(f), std::make_index_sequence{}); } } // namespace torch diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 3c6b98c5911e88..31dcfefaea8d8f 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -1,17 +1,17 @@ #include -#include #include #include +#include -#include #include #include -#include +#include #include -#include -#include #include +#include +#include +#include #ifdef USE_CUDA #include @@ -20,20 +20,19 @@ using namespace at; using namespace torch; -PyObject *THPGeneratorClass = nullptr; +PyObject* THPGeneratorClass = nullptr; -PyObject * THPGenerator_initDefaultGenerator(at::Generator cdata) -{ +PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) { auto type = (PyTypeObject*)THPGeneratorClass; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->cdata = cdata; return self.release(); } -static void THPGenerator_dealloc(PyObject* _self) -{ +static void THPGenerator_dealloc(PyObject* _self) { auto self = reinterpret_cast(_self); if (self->cdata.defined()) { self->cdata.set_pyobj(nullptr); @@ -42,38 +41,41 @@ static void THPGenerator_dealloc(PyObject* _self) Py_TYPE(_self)->tp_free(_self); } -static PyObject * THPGenerator_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +static PyObject* THPGenerator_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static torch::PythonArgParser parser({ - "Generator(Device device=None)" - }); + static torch::PythonArgParser parser({"Generator(Device device=None)"}); torch::ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); auto device = r.deviceWithDefault(0, at::Device(at::kCPU)); - THPGeneratorPtr self((THPGenerator *)type->tp_alloc(type, 0)); + THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0)); #ifdef USE_CUDA if (device.type() == at::kCPU) { self->cdata = make_generator(); - } else if (device.type() == at::kCUDA){ + } else if (device.type() == at::kCUDA) { self->cdata = make_generator(device.index()); } else { - AT_ERROR("Device type ", c10::DeviceTypeName(device.type()), - " is not supported for torch.Generator() api."); + AT_ERROR( + "Device type ", + c10::DeviceTypeName(device.type()), + " is not supported for torch.Generator() api."); } #else - TORCH_CHECK(device.type() == at::kCPU, - "Device type ", c10::DeviceTypeName(device.type()), - " is not supported for torch.Generator() api."); + TORCH_CHECK( + device.type() == at::kCPU, + "Device type ", + c10::DeviceTypeName(device.type()), + " is not supported for torch.Generator() api."); self->cdata = make_generator(); #endif return (PyObject*)self.release(); END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_getState(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPGenerator_getState(PyObject* _self, PyObject* noargs) { using namespace torch::autograd; HANDLE_TH_ERRORS auto& gen = ((THPGenerator*)_self)->cdata; @@ -86,13 +88,14 @@ static PyObject * THPGenerator_getState(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_setState(PyObject *_self, PyObject *_new_state) -{ +static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) { using namespace torch::autograd; HANDLE_TH_ERRORS if (!THPVariable_Check(_new_state)) { - throw torch::TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name); + throw torch::TypeError( + "expected a torch.ByteTensor, but got %s", + Py_TYPE(_new_state)->tp_name); } auto self = (THPGenerator*)_self; auto& gen = self->cdata; @@ -107,13 +110,15 @@ static PyObject * THPGenerator_setState(PyObject *_self, PyObject *_new_state) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_manualSeed(PyObject *_self, PyObject *seed) -{ +static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) { HANDLE_TH_ERRORS auto self = (THPGenerator*)_self; auto generator = self->cdata; - THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, " - "but got %s", THPUtils_typename(seed)); + THPUtils_assert( + THPUtils_checkLong(seed), + "manual_seed expected a long, " + "but got %s", + THPUtils_typename(seed)); // See Note [Acquire lock when using random generators] std::lock_guard lock(generator.mutex()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -121,7 +126,7 @@ static PyObject * THPGenerator_manualSeed(PyObject *_self, PyObject *seed) try { // First try to interpret as unsigned long seed_unpacked = THPUtils_unpackUInt64(seed); - } catch(...) { + } catch (...) { if (PyErr_ExceptionMatches(PyExc_OverflowError)) { // If an overflow happened, then the seed could be negative, // so try to interpret it as signed long @@ -139,8 +144,7 @@ static PyObject * THPGenerator_manualSeed(PyObject *_self, PyObject *seed) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_seed(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPGenerator_seed(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS // See Note [Acquire lock when using random generators] auto self = (THPGenerator*)_self; @@ -150,15 +154,14 @@ static PyObject * THPGenerator_seed(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_initialSeed(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPGenerator_initialSeed(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPGenerator*)_self; return THPUtils_packUInt64(self->cdata.current_seed()); END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_get_device(THPGenerator *self, void *unused) { +static PyObject* THPGenerator_get_device(THPGenerator* self, void* unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cdata.device()); END_HANDLE_TH_ERRORS @@ -166,74 +169,73 @@ static PyObject * THPGenerator_get_device(THPGenerator *self, void *unused) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPGenerator_properties[] = { - {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static PyMethodDef THPGenerator_methods[] = { - {"get_state", THPGenerator_getState, METH_NOARGS, nullptr}, - {"set_state", THPGenerator_setState, METH_O, nullptr}, - {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr}, - {"seed", THPGenerator_seed, METH_NOARGS, nullptr}, - {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr}, - {nullptr} -}; + {"get_state", THPGenerator_getState, METH_NOARGS, nullptr}, + {"set_state", THPGenerator_setState, METH_O, nullptr}, + {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr}, + {"seed", THPGenerator_seed, METH_NOARGS, nullptr}, + {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMemberDef THPGenerator_members[] = { - {(char*)"_cdata", T_ULONGLONG, offsetof(THPGenerator, cdata), READONLY, nullptr}, - {nullptr} -}; + {(char*)"_cdata", + T_ULONGLONG, + offsetof(THPGenerator, cdata), + READONLY, + nullptr}, + {nullptr}}; PyTypeObject THPGeneratorType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C.Generator", /* tp_name */ - sizeof(THPGenerator), /* tp_basicsize */ - 0, /* tp_itemsize */ - THPGenerator_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPGenerator_methods, /* tp_methods */ - THPGenerator_members, /* tp_members */ - THPGenerator_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPGenerator_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.Generator", /* tp_name */ + sizeof(THPGenerator), /* tp_basicsize */ + 0, /* tp_itemsize */ + THPGenerator_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPGenerator_methods, /* tp_methods */ + THPGenerator_members, /* tp_members */ + THPGenerator_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPGenerator_pynew, /* tp_new */ }; -bool THPGenerator_init(PyObject *module) -{ +bool THPGenerator_init(PyObject* module) { THPGeneratorClass = (PyObject*)&THPGeneratorType; if (PyType_Ready(&THPGeneratorType) < 0) return false; Py_INCREF(&THPGeneratorType); - PyModule_AddObject(module, "Generator", (PyObject *)&THPGeneratorType); + PyModule_AddObject(module, "Generator", (PyObject*)&THPGeneratorType); return true; } @@ -247,8 +249,7 @@ PyObject* pyobj(const Generator& self) { return self.pyobj(); } -PyObject * THPGenerator_Wrap(Generator gen) -{ +PyObject* THPGenerator_Wrap(Generator gen) { if (!gen.defined()) { Py_RETURN_NONE; } @@ -258,16 +259,16 @@ PyObject * THPGenerator_Wrap(Generator gen) return obj; } - return THPGenerator_NewWithVar((PyTypeObject *)THPGeneratorClass, std::move(gen)); + return THPGenerator_NewWithVar( + (PyTypeObject*)THPGeneratorClass, std::move(gen)); } // Creates a new Python object for a Generator. The Generator must not already // have a PyObject* associated with it. -PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) -{ +PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) { PyObject* obj = type->tp_alloc(type, 0); if (obj) { - auto g = (THPGenerator*) obj; + auto g = (THPGenerator*)obj; new (&g->cdata) Generator(std::move(gen)); set_pyobj(g->cdata, obj); } diff --git a/torch/csrc/Generator.h b/torch/csrc/Generator.h index 17f0a55dfe4914..f5b7b4661eb585 100644 --- a/torch/csrc/Generator.h +++ b/torch/csrc/Generator.h @@ -1,29 +1,27 @@ #pragma once +#include #include #include -#include - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct THPGenerator { - PyObject_HEAD - at::Generator cdata; + PyObject_HEAD at::Generator cdata; }; -// Creates a new Python object wrapping the default at::Generator. The reference is -// borrowed. The caller should ensure that the at::Generator object lifetime +// Creates a new Python object wrapping the default at::Generator. The reference +// is borrowed. The caller should ensure that the at::Generator object lifetime // last at least as long as the Python wrapper. -TORCH_PYTHON_API PyObject * THPGenerator_initDefaultGenerator(at::Generator cdata); +TORCH_PYTHON_API PyObject* THPGenerator_initDefaultGenerator( + at::Generator cdata); -#define THPGenerator_Check(obj) \ - PyObject_IsInstance(obj, THPGeneratorClass) +#define THPGenerator_Check(obj) PyObject_IsInstance(obj, THPGeneratorClass) -TORCH_PYTHON_API extern PyObject *THPGeneratorClass; +TORCH_PYTHON_API extern PyObject* THPGeneratorClass; -bool THPGenerator_init(PyObject *module); +bool THPGenerator_init(PyObject* module); -TORCH_PYTHON_API PyObject * THPGenerator_Wrap(at::Generator gen); +TORCH_PYTHON_API PyObject* THPGenerator_Wrap(at::Generator gen); // Creates a new Python object for a Generator. The Generator must not already // have a PyObject* associated with it. diff --git a/torch/csrc/Layout.cpp b/torch/csrc/Layout.cpp index 552ca6121078bd..4b56805b0b5a27 100644 --- a/torch/csrc/Layout.cpp +++ b/torch/csrc/Layout.cpp @@ -10,71 +10,68 @@ #include #include -PyObject *THPLayout_New(at::Layout layout, const std::string& name) -{ +PyObject* THPLayout_New(at::Layout layout, const std::string& name) { auto type = (PyTypeObject*)&THPLayoutType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->layout = layout; - std::strncpy (self_->name, name.c_str(), LAYOUT_NAME_LEN); + std::strncpy(self_->name, name.c_str(), LAYOUT_NAME_LEN); self_->name[LAYOUT_NAME_LEN] = '\0'; return self.release(); } -PyObject *THPLayout_repr(THPLayout *self) -{ +PyObject* THPLayout_repr(THPLayout* self) { return THPUtils_packString(self->name); } PyTypeObject THPLayoutType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.layout", /* tp_name */ - sizeof(THPLayout), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPLayout_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.layout", /* tp_name */ + sizeof(THPLayout), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPLayout_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; -void THPLayout_init(PyObject *module) -{ +void THPLayout_init(PyObject* module) { if (PyType_Ready(&THPLayoutType) < 0) { throw python_error(); } Py_INCREF(&THPLayoutType); - if (PyModule_AddObject(module, "layout", (PyObject *)&THPLayoutType) != 0) { + if (PyModule_AddObject(module, "layout", (PyObject*)&THPLayoutType) != 0) { throw python_error(); } } diff --git a/torch/csrc/Layout.h b/torch/csrc/Layout.h index 74d863fe59282b..265582e0ddfaea 100644 --- a/torch/csrc/Layout.h +++ b/torch/csrc/Layout.h @@ -9,18 +9,17 @@ const int LAYOUT_NAME_LEN = 64; struct THPLayout { - PyObject_HEAD - at::Layout layout; + PyObject_HEAD at::Layout layout; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char name[LAYOUT_NAME_LEN + 1]; }; extern PyTypeObject THPLayoutType; -inline bool THPLayout_Check(PyObject *obj) { +inline bool THPLayout_Check(PyObject* obj) { return Py_TYPE(obj) == &THPLayoutType; } -PyObject * THPLayout_New(at::Layout layout, const std::string& name); +PyObject* THPLayout_New(at::Layout layout, const std::string& name); -void THPLayout_init(PyObject *module); +void THPLayout_init(PyObject* module); diff --git a/torch/csrc/MemoryFormat.cpp b/torch/csrc/MemoryFormat.cpp index 85d8750c659cac..cb4a4e387feab8 100644 --- a/torch/csrc/MemoryFormat.cpp +++ b/torch/csrc/MemoryFormat.cpp @@ -10,71 +10,71 @@ #include #include -PyObject *THPMemoryFormat_New(at::MemoryFormat memory_format, const std::string& name) -{ +PyObject* THPMemoryFormat_New( + at::MemoryFormat memory_format, + const std::string& name) { auto type = (PyTypeObject*)&THPMemoryFormatType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->memory_format = memory_format; - std::strncpy (self_->name, name.c_str(), MEMORY_FORMAT_NAME_LEN); + std::strncpy(self_->name, name.c_str(), MEMORY_FORMAT_NAME_LEN); self_->name[MEMORY_FORMAT_NAME_LEN] = '\0'; return self.release(); } -PyObject *THPMemoryFormat_repr(THPMemoryFormat *self) -{ +PyObject* THPMemoryFormat_repr(THPMemoryFormat* self) { return THPUtils_packString(self->name); } PyTypeObject THPMemoryFormatType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.memory_format", /* tp_name */ - sizeof(THPMemoryFormat), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPMemoryFormat_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.memory_format", /* tp_name */ + sizeof(THPMemoryFormat), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPMemoryFormat_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; -void THPMemoryFormat_init(PyObject *module) -{ +void THPMemoryFormat_init(PyObject* module) { if (PyType_Ready(&THPMemoryFormatType) < 0) { throw python_error(); } Py_INCREF(&THPMemoryFormatType); - if (PyModule_AddObject(module, "memory_format", (PyObject *)&THPMemoryFormatType) != 0) { + if (PyModule_AddObject( + module, "memory_format", (PyObject*)&THPMemoryFormatType) != 0) { throw python_error(); } } diff --git a/torch/csrc/MemoryFormat.h b/torch/csrc/MemoryFormat.h index 78b0728bb8676f..7f60a0ba0282c3 100644 --- a/torch/csrc/MemoryFormat.h +++ b/torch/csrc/MemoryFormat.h @@ -9,18 +9,19 @@ const int MEMORY_FORMAT_NAME_LEN = 64; struct THPMemoryFormat { - PyObject_HEAD - at::MemoryFormat memory_format; + PyObject_HEAD at::MemoryFormat memory_format; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char name[MEMORY_FORMAT_NAME_LEN + 1]; }; extern PyTypeObject THPMemoryFormatType; -inline bool THPMemoryFormat_Check(PyObject *obj) { +inline bool THPMemoryFormat_Check(PyObject* obj) { return Py_TYPE(obj) == &THPMemoryFormatType; } -PyObject * THPMemoryFormat_New(at::MemoryFormat memory_format, const std::string& name); +PyObject* THPMemoryFormat_New( + at::MemoryFormat memory_format, + const std::string& name); -void THPMemoryFormat_init(PyObject *module); +void THPMemoryFormat_init(PyObject* module); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index a6e6e733fd502b..ac2f753d89ee9d 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1,71 +1,71 @@ -#include #include +#include #ifndef _MSC_VER #include #endif #include -#include #include #include #include #include #include #include -#include #include -#include +#include +#include #include #include -#include #include #include #include +#include +#include #include -#include -#include +#include #include -#include #include -#include +#include #include #include #include #include +#include +#include #include -#include +#include +#include #include +#include #include +#include +#include #include #include -#include -#include -#include #include +#include +#include +#include +#include +#include #include +#include #include #include -#include +#include +#include +#include #include +#include #include +#include #include #include -#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #ifdef USE_DISTRIBUTED #ifdef USE_C10D @@ -84,17 +84,17 @@ namespace py = pybind11; PyObject* module; -THPGenerator *THPDefaultCPUGenerator = nullptr; +THPGenerator* THPDefaultCPUGenerator = nullptr; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) -{ +static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) { static std::vector names; THPObjectPtr types(PySequence_Fast(arg, "expected a sequence")); - if (!types) return nullptr; + if (!types) + return nullptr; // NOLINTNEXTLINE(bugprone-branch-clone) auto num_classes = PySequence_Fast_GET_SIZE(types.get()); @@ -105,8 +105,10 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) PyTypeObject* type = (PyTypeObject*)obj; THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__")); - if (!module_name) return nullptr; - THPUtils_assert(THPUtils_checkString(module_name.get()), + if (!module_name) + return nullptr; + THPUtils_assert( + THPUtils_checkString(module_name.get()), "expected __module__ to be a string"); std::string name = THPUtils_unpackString(module_name.get()); names.emplace_back(name + "." + type->tp_name); @@ -115,12 +117,15 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) Py_RETURN_NONE; } // -// Callback for python part. Used for additional initialization of python classes -static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manager_path) -{ +// Callback for python part. Used for additional initialization of python +// classes +static PyObject* THPModule_initExtension( + PyObject* _unused, + PyObject* shm_manager_path) { HANDLE_TH_ERRORS if (!THPUtils_checkString(shm_manager_path)) { - THPUtils_setError("initialization error - expected bytes/string object as shm_manager_path!"); + THPUtils_setError( + "initialization error - expected bytes/string object as shm_manager_path!"); return nullptr; } torch::utils::initializeLayouts(); @@ -132,7 +137,8 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag libshm_init(path.c_str()); auto module = THPObjectPtr(PyImport_ImportModule("torch")); - if (!module) throw python_error(); + if (!module) + throw python_error(); THPStorage_postInit(module); THPAutograd_initFunctions(); @@ -145,82 +151,95 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag // to trigger ASAN if it is enabled. This lets us run a "canary" tests which // checks if our build environment is misconfigured. -static PyObject * THPModule_crashIfCsrcASAN(PyObject *module, PyObject *arg) { - THPUtils_assert(THPUtils_checkLong(arg), "crash_if_csrc_asan expects an int, " - "but got %s", THPUtils_typename(arg)); - //NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays) +static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "crash_if_csrc_asan expects an int, " + "but got %s", + THPUtils_typename(arg)); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays) volatile char x[3]; x[THPUtils_unpackInt(arg)] = 0; - //NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) return THPUtils_packInt32(x[0]); } -static PyObject * THPModule_crashIfCsrcUBSAN(PyObject *module, PyObject *arg) { - THPUtils_assert(THPUtils_checkLong(arg), "crash_if_csrc_ubsan expects an int, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "crash_if_csrc_ubsan expects an int, " + "but got %s", + THPUtils_typename(arg)); int32_t x = THPUtils_unpackInt(arg); double y = 1.0 / x; return THPUtils_packInt32((int)y); } -static PyObject * THPModule_crashIfATenASAN(PyObject *module, PyObject *arg) { - THPUtils_assert(THPUtils_checkLong(arg), "crash_if_aten_asan expects an int, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "crash_if_aten_asan expects an int, " + "but got %s", + THPUtils_typename(arg)); return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg))); } -static PyObject * THPModule_getNumThreads(PyObject *module, PyObject *noargs) -{ +static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) { return THPUtils_packInt32(at::get_num_threads()); } -static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) -{ - THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "set_num_threads expects an int, " + "but got %s", + THPUtils_typename(arg)); int nthreads = (int)THPUtils_unpackLong(arg); THPUtils_assert(nthreads > 0, "set_num_threads expects a positive integer"); at::set_num_threads(nthreads); Py_RETURN_NONE; } -static PyObject * THPModule_getNumInteropThreads(PyObject *module, PyObject *noargs) -{ +static PyObject* THPModule_getNumInteropThreads( + PyObject* module, + PyObject* noargs) { return THPUtils_packInt32(at::get_num_interop_threads()); } -static PyObject * THPModule_setNumInteropThreads(PyObject *module, PyObject *arg) -{ - THPUtils_assert(THPUtils_checkLong(arg), "set_num_interop_threads expects an int, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_setNumInteropThreads( + PyObject* module, + PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "set_num_interop_threads expects an int, " + "but got %s", + THPUtils_typename(arg)); int nthreads = (int)THPUtils_unpackLong(arg); - THPUtils_assert(nthreads > 0, "set_num_interop_threads expects a positive integer"); + THPUtils_assert( + nthreads > 0, "set_num_interop_threads expects a positive integer"); at::set_num_interop_threads(nthreads); Py_RETURN_NONE; } -PyObject * THPModule_setDefaultTensorType(PyObject *_unused, PyObject *type) -{ +PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) { HANDLE_TH_ERRORS torch::tensors::py_set_default_tensor_type(type); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THPModule_setDefaultDtype(PyObject *_unused, PyObject *dtype) -{ +PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) { HANDLE_TH_ERRORS torch::tensors::py_set_default_dtype(dtype); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args) -{ +PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; - PyObject *obj = nullptr; - PyObject *doc_obj = nullptr; + PyObject* obj = nullptr; + PyObject* doc_obj = nullptr; if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) { return nullptr; } @@ -232,56 +251,62 @@ PyObject *THPModule_addDocStr(PyObject *_unused, PyObject *args) } if (Py_TYPE(obj) == &PyCFunction_Type) { - PyCFunctionObject* f = (PyCFunctionObject *)obj; + PyCFunctionObject* f = (PyCFunctionObject*)obj; if (f->m_ml->ml_doc) { - return PyErr_Format(PyExc_RuntimeError, - "function '%s' already has a docstring", f->m_ml->ml_name); + return PyErr_Format( + PyExc_RuntimeError, + "function '%s' already has a docstring", + f->m_ml->ml_name); } f->m_ml->ml_doc = doc_str; } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) { - PyMethodDescrObject* m = (PyMethodDescrObject *)obj; + PyMethodDescrObject* m = (PyMethodDescrObject*)obj; if (m->d_method->ml_doc) { - return PyErr_Format(PyExc_RuntimeError, - "method '%s' already has a docstring", m->d_method->ml_name); + return PyErr_Format( + PyExc_RuntimeError, + "method '%s' already has a docstring", + m->d_method->ml_name); } m->d_method->ml_doc = doc_str; } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) { - //NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) - PyGetSetDescrObject* m = (PyGetSetDescrObject *)obj; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj; if (m->d_getset->doc) { - //NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) - return PyErr_Format(PyExc_RuntimeError, - "attribute '%s' already has a docstring", m->d_getset->name); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) + return PyErr_Format( + PyExc_RuntimeError, + "attribute '%s' already has a docstring", + m->d_getset->name); } // This field is not const for python < 3.7 yet the content is // never modified. - //NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - m->d_getset->doc = const_cast(doc_str); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + m->d_getset->doc = const_cast(doc_str); } else if (Py_TYPE(obj) == &PyType_Type) { - PyTypeObject* t = (PyTypeObject *)obj; + PyTypeObject* t = (PyTypeObject*)obj; if (t->tp_doc) { - return PyErr_Format(PyExc_RuntimeError, - "Type '%s' already has a docstring", t->tp_name); + return PyErr_Format( + PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name); } t->tp_doc = doc_str; } else { - return PyErr_Format(PyExc_TypeError, - "don't know how to add docstring to type '%s'", Py_TYPE(obj)->tp_name); + return PyErr_Format( + PyExc_TypeError, + "don't know how to add docstring to type '%s'", + Py_TYPE(obj)->tp_name); } Py_INCREF(obj); return obj; } - -PyObject *THPModule_inferSize(PyObject *_unused, PyObject *args) -{ +PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS - Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0; + Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0; THPUtils_assert(num_args == 2, "expected exactly 2 arguments"); - PyObject *arg1 = PyTuple_GET_ITEM(args, 0); + PyObject* arg1 = PyTuple_GET_ITEM(args, 0); THPUtils_assert(THPSize_Check(arg1), "expected a torch.Size as argument 1"); - PyObject *arg2 = PyTuple_GET_ITEM(args, 1); + PyObject* arg2 = PyTuple_GET_ITEM(args, 1); THPUtils_assert(THPSize_Check(arg2), "expected a torch.Size as argument 2"); auto size1 = THPUtils_unpackLongs(arg1); @@ -291,34 +316,49 @@ PyObject *THPModule_inferSize(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject *THPModule_setBackcompatBroadcastWarn(PyObject *module, PyObject *arg) { - THPUtils_assert(PyBool_Check(arg), "set_backcompat_broadcast_warn expects a bool, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_setBackcompatBroadcastWarn( + PyObject* module, + PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_backcompat_broadcast_warn expects a bool, " + "but got %s", + THPUtils_typename(arg)); setBackCompatBroadcastWarn(arg == Py_True); Py_RETURN_NONE; } -static PyObject *THPModule_getBackcompatBroadcastWarn(PyObject *module, PyObject *noargs) -{ - if (getBackCompatBroadcastWarn()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +static PyObject* THPModule_getBackcompatBroadcastWarn( + PyObject* module, + PyObject* noargs) { + if (getBackCompatBroadcastWarn()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -static PyObject *THPModule_setBackcompatKeepdimWarn(PyObject *module, PyObject *arg) { - THPUtils_assert(PyBool_Check(arg), "set_backcompat_keepdim_warn expects a bool, " - "but got %s", THPUtils_typename(arg)); +static PyObject* THPModule_setBackcompatKeepdimWarn( + PyObject* module, + PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_backcompat_keepdim_warn expects a bool, " + "but got %s", + THPUtils_typename(arg)); setBackCompatKeepdimWarn(arg == Py_True); Py_RETURN_NONE; } -static PyObject *THPModule_getBackcompatKeepdimWarn(PyObject *module, PyObject *noargs) -{ - if (getBackCompatKeepdimWarn()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +static PyObject* THPModule_getBackcompatKeepdimWarn( + PyObject* module, + PyObject* noargs) { + if (getBackCompatKeepdimWarn()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_hasDistributed(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) { #ifdef USE_DISTRIBUTED Py_RETURN_TRUE; #else @@ -326,22 +366,19 @@ PyObject *THPModule_hasDistributed(PyObject *_unused, PyObject *noargs) #endif } -static PyObject *THPModule_showConfig(PyObject *module, PyObject *noargs) -{ +static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packString(at::show_config()); END_HANDLE_TH_ERRORS } -static PyObject *THPModule_cxxFlags(PyObject *module, PyObject *noargs) -{ +static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packString(at::get_cxx_flags()); END_HANDLE_TH_ERRORS } -static PyObject *THPModule_parallelInfo(PyObject *module, PyObject *noargs) -{ +static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packString(at::get_parallel_info()); END_HANDLE_TH_ERRORS @@ -349,7 +386,8 @@ static PyObject *THPModule_parallelInfo(PyObject *module, PyObject *noargs) void DLPack_Capsule_Destructor(PyObject* data) { HANDLE_TH_ERRORS - DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor"); + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); if (dlMTensor) { // the dlMTensor has not been consumed, call deleter ourselves // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -362,8 +400,7 @@ void DLPack_Capsule_Destructor(PyObject* data) { END_HANDLE_TH_ERRORS_RET() } -PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data) -{ +PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { HANDLE_TH_ERRORS THPUtils_assert(THPVariable_Check(data), "data must be a Tensor"); DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data)); @@ -371,8 +408,7 @@ PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data) END_HANDLE_TH_ERRORS } -PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data) -{ +PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { using namespace torch::autograd; HANDLE_TH_ERRORS auto tensor = torch::utils::tensor_fromDLPack(data); @@ -380,31 +416,39 @@ PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data) END_HANDLE_TH_ERRORS } -PyObject *THPModule_setAllowTF32CuDNN(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_allow_tf32_cublas expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_allow_tf32_cublas expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setAllowTF32CuDNN(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_allowTF32CuDNN(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().allowTF32CuDNN()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().allowTF32CuDNN()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_setFloat32MatmulPrecision(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(THPUtils_checkString(arg), "set_float32_matmul_precision expects a str, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setFloat32MatmulPrecision( + PyObject* _unused, + PyObject* arg) { + THPUtils_assert( + THPUtils_checkString(arg), + "set_float32_matmul_precision expects a str, " + "but got %s", + THPUtils_typename(arg)); std::string s = THPUtils_unpackString(arg); at::globalContext().setFloat32MatmulPrecision(s); Py_RETURN_NONE; } -PyObject *THPModule_float32MatmulPrecision(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_float32MatmulPrecision( + PyObject* _unused, + PyObject* noargs) { std::string s = "highest"; auto p = at::globalContext().float32MatmulPrecision(); if (p == at::Float32MatmulPrecision::HIGH) { @@ -415,55 +459,66 @@ PyObject *THPModule_float32MatmulPrecision(PyObject *_unused, PyObject *noargs) return THPUtils_packString(s); } -PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_enabled_cudnn expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_enabled_cudnn expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setUserEnabledCuDNN(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_userEnabledCuDNN(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().userEnabledCuDNN()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().userEnabledCuDNN()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_setUserEnabledMkldnn(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_enabled_mkldnn expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_enabled_mkldnn expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setUserEnabledMkldnn(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_userEnabledMkldnn(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().userEnabledMkldnn()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().userEnabledMkldnn()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_setDeterministicCuDNN(PyObject *_unused, PyObject *arg) -{ +PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(PyBool_Check(arg), "set_deterministic_cudnn expects a bool, " - "but got %s", THPUtils_typename(arg)); + THPUtils_assert( + PyBool_Check(arg), + "set_deterministic_cudnn expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setDeterministicCuDNN(arg == Py_True); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject *THPModule_deterministicCuDNN(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().deterministicCuDNN()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().deterministicCuDNN()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_setDeterministicAlgorithms(PyObject *_unused, PyObject *args, PyObject* kwargs) -{ +PyObject* THPModule_setDeterministicAlgorithms( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static torch::PythonArgParser parser({ - "_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"}); + static torch::PythonArgParser parser( + {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"}); torch::ParsedArgs<2> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); bool mode = r.toBool(0); @@ -473,96 +528,109 @@ PyObject *THPModule_setDeterministicAlgorithms(PyObject *_unused, PyObject *args END_HANDLE_TH_ERRORS } -PyObject *THPModule_deterministicAlgorithms(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_deterministicAlgorithms( + PyObject* _unused, + PyObject* noargs) { if (at::globalContext().deterministicAlgorithms()) { - Py_RETURN_TRUE; + Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_deterministicAlgorithmsWarnOnly(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_deterministicAlgorithmsWarnOnly( + PyObject* _unused, + PyObject* noargs) { if (at::globalContext().deterministicAlgorithmsWarnOnly()) { - Py_RETURN_TRUE; + Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_setWarnAlways(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "setWarnOnlyOnce expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "setWarnOnlyOnce expects a bool, " + "but got %s", + THPUtils_typename(arg)); c10::Warning::set_warnAlways(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_warnAlways(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) { if (c10::Warning::get_warnAlways()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_setBenchmarkCuDNN(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_benchmark_cudnn expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_benchmark_cudnn expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setBenchmarkCuDNN(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_benchmarkCuDNN(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) { if (at::globalContext().benchmarkCuDNN()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_setAllowTF32CuBLAS(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_allow_tf32_cublas expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_allow_tf32_cublas expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setAllowTF32CuBLAS(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_allowTF32CuBLAS(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) { if (at::globalContext().allowTF32CuBLAS()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_setAllowFP16ReductionCuBLAS(PyObject *_unused, PyObject *arg) -{ - THPUtils_assert(PyBool_Check(arg), "set_allow_fp16_reduction_cublas expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setAllowFP16ReductionCuBLAS( + PyObject* _unused, + PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_allow_fp16_reduction_cublas expects a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True); Py_RETURN_NONE; } -PyObject *THPModule_allowFP16ReductionCuBLAS(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_allowFP16ReductionCuBLAS( + PyObject* _unused, + PyObject* noargs) { if (at::globalContext().allowFP16ReductionCuBLAS()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } -PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) { - THPUtils_assert(PyBool_Check(arg), "flush_denormal expects a bool, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "flush_denormal expects a bool, " + "but got %s", + THPUtils_typename(arg)); if (!at::globalContext().setFlushDenormal(arg == Py_True)) { Py_RETURN_FALSE; }; Py_RETURN_TRUE; } -PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) { +PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS auto scalar_type = torch::tensors::get_default_scalar_type(); auto dtype = (PyObject*)torch::getTHPDtype(scalar_type); @@ -571,18 +639,20 @@ PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -PyObject *THPModule_getDefaultDevice(PyObject *_unused, PyObject *arg) { +PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - return THPUtils_packString( - c10::DeviceTypeName(dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()), - /*lower_case=*/true)); + return THPUtils_packString(c10::DeviceTypeName( + dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()), + /*lower_case=*/true)); END_HANDLE_TH_ERRORS } -PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg) -{ - THPUtils_assert(THPUtils_checkLong(arg), "set_qengine expects an int, " - "but got %s", THPUtils_typename(arg)); +PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) { + THPUtils_assert( + THPUtils_checkLong(arg), + "set_qengine expects an int, " + "but got %s", + THPUtils_typename(arg)); HANDLE_TH_ERRORS auto qengine = static_cast(THPUtils_unpackLong(arg)); at::globalContext().setQEngine(static_cast(qengine)); @@ -590,68 +660,82 @@ PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject *THPModule_qEngine(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) { return THPUtils_packInt64(static_cast(at::globalContext().qEngine())); } -PyObject *THPModule_supportedQEngines(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) { auto qengines = at::globalContext().supportedQEngines(); auto list = THPObjectPtr(PyList_New(qengines.size())); - if (!list) return nullptr; + if (!list) + return nullptr; for (const auto i : c10::irange(qengines.size())) { - PyObject *i64 = THPUtils_packInt64(static_cast(qengines[i])); - if (!i64) return nullptr; + PyObject* i64 = THPUtils_packInt64(static_cast(qengines[i])); + if (!i64) + return nullptr; PyList_SET_ITEM(list.get(), i, i64); } return list.release(); } -PyObject *THPModule_isEnabledXNNPACK(PyObject *_unused, PyObject *noargs) -{ - if (at::globalContext().isXNNPACKAvailable()) Py_RETURN_TRUE; - else Py_RETURN_FALSE; +PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().isXNNPACKAvailable()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; } -PyObject *THPModule_setDefaultMobileCPUAllocator(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_setDefaultMobileCPUAllocator( + PyObject* _unused, + PyObject* noargs) { HANDLE_TH_ERRORS at::globalContext().setDefaultMobileCPUAllocator(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject *THPModule_unsetDefaultMobileCPUAllocator(PyObject *_unused, PyObject *noargs) -{ +PyObject* THPModule_unsetDefaultMobileCPUAllocator( + PyObject* _unused, + PyObject* noargs) { HANDLE_TH_ERRORS at::globalContext().unsetDefaultMobileCPUAllocator(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THPModule_vmapmode_increment_nesting(PyObject* _unused, PyObject *arg) { +static PyObject* THPModule_vmapmode_increment_nesting( + PyObject* _unused, + PyObject* arg) { HANDLE_TH_ERRORS return THPUtils_packInt64(at::impl::VmapMode::increment_nesting()); END_HANDLE_TH_ERRORS } -static PyObject * THPModule_vmapmode_decrement_nesting(PyObject* _unused, PyObject *arg) { +static PyObject* THPModule_vmapmode_decrement_nesting( + PyObject* _unused, + PyObject* arg) { HANDLE_TH_ERRORS return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting()); END_HANDLE_TH_ERRORS } -static PyObject * THPModule_set_display_vmap_fallback_warnings_mode(PyObject* _unused, PyObject *arg) { +static PyObject* THPModule_set_display_vmap_fallback_warnings_mode( + PyObject* _unused, + PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(PyBool_Check(arg), "enabled must be a bool, " - "but got %s", THPUtils_typename(arg)); + THPUtils_assert( + PyBool_Check(arg), + "enabled must be a bool, " + "but got %s", + THPUtils_typename(arg)); at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THPModule_are_vmap_fallback_warnings_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* THPModule_are_vmap_fallback_warnings_enabled( + PyObject* _unused, + PyObject* arg) { HANDLE_TH_ERRORS if (at::globalContext().areVmapFallbackWarningsEnabled()) { Py_RETURN_TRUE; @@ -661,86 +745,169 @@ static PyObject * THPModule_are_vmap_fallback_warnings_enabled(PyObject* _unused END_HANDLE_TH_ERRORS } -//NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) static PyMethodDef TorchMethods[] = { - {"_initExtension", THPModule_initExtension, METH_O, nullptr}, - {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, - {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, - {"_init_names", THPModule_initNames, METH_O, nullptr}, - {"_has_distributed",THPModule_hasDistributed, METH_NOARGS, nullptr}, - {"_set_default_tensor_type", THPModule_setDefaultTensorType, METH_O, nullptr}, - {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr}, - {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr}, - {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr}, - {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr}, - {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr}, - {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr}, - {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr}, - {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr}, - {"_set_backcompat_broadcast_warn", THPModule_setBackcompatBroadcastWarn, METH_O, nullptr}, - {"_get_backcompat_broadcast_warn", THPModule_getBackcompatBroadcastWarn, METH_NOARGS, nullptr}, - {"_set_backcompat_keepdim_warn", THPModule_setBackcompatKeepdimWarn, METH_O, nullptr}, - {"_get_backcompat_keepdim_warn", THPModule_getBackcompatKeepdimWarn, METH_NOARGS, nullptr}, - {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr}, - {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr}, - {"get_num_interop_threads", THPModule_getNumInteropThreads, METH_NOARGS, nullptr}, - {"set_num_interop_threads", THPModule_setNumInteropThreads, METH_O, nullptr}, - {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr}, - {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, - {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr}, - {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr}, - {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, - {"_get_cudnn_deterministic", THPModule_deterministicCuDNN, METH_NOARGS, nullptr}, - {"_set_cudnn_deterministic", THPModule_setDeterministicCuDNN, METH_O, nullptr}, - {"_get_deterministic_algorithms", THPModule_deterministicAlgorithms, METH_NOARGS, nullptr}, - {"_get_deterministic_algorithms_warn_only", THPModule_deterministicAlgorithmsWarnOnly, METH_NOARGS, nullptr}, - {"_set_deterministic_algorithms", castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms), METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr}, - {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr}, - {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, - {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, - {"_get_float32_matmul_precision", THPModule_float32MatmulPrecision, METH_NOARGS, nullptr}, - {"_set_float32_matmul_precision", THPModule_setFloat32MatmulPrecision, METH_O, nullptr}, - {"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr}, - {"_set_cublas_allow_fp16_reduced_precision_reduction", THPModule_setAllowFP16ReductionCuBLAS, METH_O, nullptr}, - {"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr}, - {"_vmapmode_decrement_nesting", THPModule_vmapmode_decrement_nesting, METH_NOARGS, nullptr}, - {"_debug_only_display_vmap_fallback_warnings", THPModule_set_display_vmap_fallback_warnings_mode, METH_O, nullptr}, - {"_debug_only_are_vmap_fallback_warnings_enabled", THPModule_are_vmap_fallback_warnings_enabled, METH_NOARGS, nullptr}, - {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, - {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, - {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr}, - {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr}, - {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr}, - {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr}, - {"_set_qengine", THPModule_setQEngine, METH_O, nullptr}, - {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr}, - {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, - {"_set_default_mobile_cpu_allocator", THPModule_setDefaultMobileCPUAllocator, METH_NOARGS, nullptr}, - {"_unset_default_mobile_cpu_allocator", THPModule_unsetDefaultMobileCPUAllocator, METH_NOARGS, nullptr}, - {"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr}, - {"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr}, - {"_disabled_torch_dispatch_impl", THPModule_disable_torch_dispatch, METH_VARARGS, nullptr}, - {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr}, - {"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr}, - {"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; - -void THCPStream_init(PyObject *module); -void THCPEvent_init(PyObject *module); -void THCPGraph_init(PyObject *module); + {"_initExtension", THPModule_initExtension, METH_O, nullptr}, + {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, + {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, + {"_init_names", THPModule_initNames, METH_O, nullptr}, + {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, + {"_set_default_tensor_type", + THPModule_setDefaultTensorType, + METH_O, + nullptr}, + {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr}, + {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr}, + {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr}, + {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr}, + {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr}, + {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr}, + {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr}, + {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr}, + {"_set_backcompat_broadcast_warn", + THPModule_setBackcompatBroadcastWarn, + METH_O, + nullptr}, + {"_get_backcompat_broadcast_warn", + THPModule_getBackcompatBroadcastWarn, + METH_NOARGS, + nullptr}, + {"_set_backcompat_keepdim_warn", + THPModule_setBackcompatKeepdimWarn, + METH_O, + nullptr}, + {"_get_backcompat_keepdim_warn", + THPModule_getBackcompatKeepdimWarn, + METH_NOARGS, + nullptr}, + {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr}, + {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr}, + {"get_num_interop_threads", + THPModule_getNumInteropThreads, + METH_NOARGS, + nullptr}, + {"set_num_interop_threads", + THPModule_setNumInteropThreads, + METH_O, + nullptr}, + {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr}, + {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, + {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr}, + {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr}, + {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, + {"_get_cudnn_deterministic", + THPModule_deterministicCuDNN, + METH_NOARGS, + nullptr}, + {"_set_cudnn_deterministic", + THPModule_setDeterministicCuDNN, + METH_O, + nullptr}, + {"_get_deterministic_algorithms", + THPModule_deterministicAlgorithms, + METH_NOARGS, + nullptr}, + {"_get_deterministic_algorithms_warn_only", + THPModule_deterministicAlgorithmsWarnOnly, + METH_NOARGS, + nullptr}, + {"_set_deterministic_algorithms", + castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr}, + {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr}, + {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr}, + {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr}, + {"_get_float32_matmul_precision", + THPModule_float32MatmulPrecision, + METH_NOARGS, + nullptr}, + {"_set_float32_matmul_precision", + THPModule_setFloat32MatmulPrecision, + METH_O, + nullptr}, + {"_get_cublas_allow_fp16_reduced_precision_reduction", + THPModule_allowFP16ReductionCuBLAS, + METH_NOARGS, + nullptr}, + {"_set_cublas_allow_fp16_reduced_precision_reduction", + THPModule_setAllowFP16ReductionCuBLAS, + METH_O, + nullptr}, + {"_vmapmode_increment_nesting", + THPModule_vmapmode_increment_nesting, + METH_NOARGS, + nullptr}, + {"_vmapmode_decrement_nesting", + THPModule_vmapmode_decrement_nesting, + METH_NOARGS, + nullptr}, + {"_debug_only_display_vmap_fallback_warnings", + THPModule_set_display_vmap_fallback_warnings_mode, + METH_O, + nullptr}, + {"_debug_only_are_vmap_fallback_warnings_enabled", + THPModule_are_vmap_fallback_warnings_enabled, + METH_NOARGS, + nullptr}, + {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, + {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, + {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr}, + {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr}, + {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr}, + {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr}, + {"_set_qengine", THPModule_setQEngine, METH_O, nullptr}, + {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr}, + {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, + {"_set_default_mobile_cpu_allocator", + THPModule_setDefaultMobileCPUAllocator, + METH_NOARGS, + nullptr}, + {"_unset_default_mobile_cpu_allocator", + THPModule_unsetDefaultMobileCPUAllocator, + METH_NOARGS, + nullptr}, + {"_is_torch_function_enabled", + THPModule_isEnabledTorchFunction, + METH_NOARGS, + nullptr}, + {"_disabled_torch_function_impl", + THPModule_disable_torch_function, + METH_VARARGS, + nullptr}, + {"_disabled_torch_dispatch_impl", + THPModule_disable_torch_dispatch, + METH_VARARGS, + nullptr}, + {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr}, + {"_has_torch_function_unary", + THPModule_has_torch_function_unary, + METH_O, + nullptr}, + {"_has_torch_function_variadic", + MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), + MAYBE_METH_FASTCALL, + nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +void THCPStream_init(PyObject* module); +void THCPEvent_init(PyObject* module); +void THCPGraph_init(PyObject* module); #ifdef USE_CUDA PyMethodDef* THCPModule_methods(); -namespace torch { namespace cuda { +namespace torch { +namespace cuda { -void initModule(PyObject *module); +void initModule(PyObject* module); -}} // namespace torch::cuda +} +} // namespace torch #endif static std::vector methods; @@ -759,10 +926,8 @@ static void LogAPIUsageOnceFromPython(const std::string& event) { class WeakTensorRef { c10::weak_intrusive_ptr weakref_; -public: - WeakTensorRef(const at::Tensor& t): - weakref_(t.getIntrusivePtr()) { - } + public: + WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {} bool expired() { return weakref_.expired(); @@ -771,9 +936,9 @@ class WeakTensorRef { extern "C" #ifdef _WIN32 -__declspec(dllexport) + __declspec(dllexport) #endif -TORCH_API PyObject* initModule(); + TORCH_API PyObject* initModule(); // separate decl and defn for msvc error C2491 PyObject* initModule() { HANDLE_TH_ERRORS @@ -785,7 +950,9 @@ PyObject* initModule() { C10_LOG_API_USAGE_ONCE("torch.python.import"); // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define ASSERT_TRUE(cmd) if (!(cmd)) return nullptr +#define ASSERT_TRUE(cmd) \ + if (!(cmd)) \ + return nullptr THPUtils_addPyMethodDefs(methods, TorchMethods); THPUtils_addPyMethodDefs(methods, DataLoaderMethods); @@ -795,22 +962,20 @@ PyObject* initModule() { THPUtils_addPyMethodDefs(methods, THCPModule_methods()); #endif #if defined(USE_DISTRIBUTED) && defined(USE_C10D) - THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions()); + THPUtils_addPyMethodDefs( + methods, torch::distributed::c10d::python_functions()); #ifndef _WIN32 - THPUtils_addPyMethodDefs(methods, torch::distributed::rpc::python_functions()); + THPUtils_addPyMethodDefs( + methods, torch::distributed::rpc::python_functions()); THPUtils_addPyMethodDefs( methods, torch::distributed::autograd::python_functions()); - THPUtils_addPyMethodDefs(methods, torch::distributed::rpc::testing::python_functions()); + THPUtils_addPyMethodDefs( + methods, torch::distributed::rpc::testing::python_functions()); #endif #endif static struct PyModuleDef torchmodule = { - PyModuleDef_HEAD_INIT, - "torch._C", - nullptr, - -1, - methods.data() - }; + PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()}; ASSERT_TRUE(module = PyModule_Create(&torchmodule)); ASSERT_TRUE(THPGenerator_init(module)); ASSERT_TRUE(THPException_init(module)); @@ -851,34 +1016,35 @@ PyObject* initModule() { #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace // They won't be ready for real usage until importing cuda module, that will - // complete the process (but it defines Python classes before calling back into - // C, so these lines have to execute first).. + // complete the process (but it defines Python classes before calling back + // into C, so these lines have to execute first).. THCPStream_init(module); THCPEvent_init(module); THCPGraph_init(module); #endif - auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) { - // PyModule_AddObject steals reference - if (incref) { - Py_INCREF(v); - } - return PyModule_AddObject(module, name, v) == 0; - }; + auto set_module_attr = + [&](const char* name, PyObject* v, bool incref = true) { + // PyModule_AddObject steals reference + if (incref) { + Py_INCREF(v); + } + return PyModule_AddObject(module, name, v) == 0; + }; #if defined(USE_CUDNN) || defined(USE_ROCM) - PyObject *has_cudnn = Py_True; + PyObject* has_cudnn = Py_True; #else - PyObject *has_cudnn = Py_False; + PyObject* has_cudnn = Py_False; #endif - ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn)); + ASSERT_TRUE(set_module_attr("has_cudnn", has_cudnn)); #if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() - PyObject *has_spectral = Py_True; + PyObject* has_spectral = Py_True; #else - PyObject *has_spectral = Py_False; + PyObject* has_spectral = Py_False; #endif - ASSERT_TRUE(set_module_attr("has_spectral", has_spectral)); + ASSERT_TRUE(set_module_attr("has_spectral", has_spectral)); // force ATen to initialize because it handles // setting up TH Errors so that they throw C++ exceptions @@ -899,17 +1065,20 @@ PyObject* initModule() { py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython); py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled); - py_module.def("set_vital", [](const std::string &vital, const std::string &attr, const std::string value){ - return at::vitals::VitalsAPI.setVital(vital, attr, value); - }); - py_module.def("read_vitals", [](){ - return at::vitals::VitalsAPI.readVitals(); - }); + py_module.def( + "set_vital", + [](const std::string& vital, + const std::string& attr, + const std::string value) { + return at::vitals::VitalsAPI.setVital(vital, attr, value); + }); + py_module.def( + "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); }); py_module.def( - "init_num_threads", - torch::wrap_pybind_function(at::init_num_threads), - R"( + "init_num_threads", + torch::wrap_pybind_function(at::init_num_threads), + R"( init_num_threads() Initializes the number of parallel threads used on the current thread. @@ -918,84 +1087,96 @@ Call this whenever a new thread is created in order to propagate values from :func:`torch.set_num_threads` onto the new thread. )"); - ASSERT_TRUE(set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False)); + ASSERT_TRUE( + set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False)); - ASSERT_TRUE(set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); + ASSERT_TRUE( + set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); - py_module.def( - "_valgrind_supported_platform", [](){ - #if defined(USE_VALGRIND) - return true; - #else + py_module.def("_valgrind_supported_platform", []() { +#if defined(USE_VALGRIND) + return true; +#else return false; - #endif - } - ); +#endif + }); - py_module.def( - "_valgrind_toggle", [](){ - #if defined(USE_VALGRIND) - CALLGRIND_TOGGLE_COLLECT; - #else + py_module.def("_valgrind_toggle", []() { +#if defined(USE_VALGRIND) + CALLGRIND_TOGGLE_COLLECT; +#else TORCH_CHECK(false, "Valgrind is not supported."); - #endif - } - ); +#endif + }); - py_module.def( - "_valgrind_toggle_and_dump_stats", [](){ - #if defined(USE_VALGRIND) - // NB: If we don't toggle collect around dump stats, callgrind_annotate - // won't process the results correctly. Specifically, - // `callgrind_annotate --inclusive=no` will be almost completely empty. - CALLGRIND_TOGGLE_COLLECT; - CALLGRIND_DUMP_STATS; - #else + py_module.def("_valgrind_toggle_and_dump_stats", []() { +#if defined(USE_VALGRIND) + // NB: If we don't toggle collect around dump stats, callgrind_annotate + // won't process the results correctly. Specifically, + // `callgrind_annotate --inclusive=no` will be almost completely empty. + CALLGRIND_TOGGLE_COLLECT; + CALLGRIND_DUMP_STATS; +#else TORCH_CHECK(false, "Valgrind is not supported."); - #endif - } - ); +#endif + }); py::class_(py_module, "_WeakTensorRef") - .def(py::init([](py::object tensor) { - return WeakTensorRef(THPVariable_Unpack(tensor.ptr())); - })) - .def("expired", &WeakTensorRef::expired); + .def(py::init([](py::object tensor) { + return WeakTensorRef(THPVariable_Unpack(tensor.ptr())); + })) + .def("expired", &WeakTensorRef::expired); py::enum_(py_module, "_ConvBackend") - .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d) - .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d) - .value("Cudnn", at::native::ConvBackend::Cudnn) - .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose) - .value("Empty", at::native::ConvBackend::Empty) - .value("Miopen", at::native::ConvBackend::Miopen) - .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise) - .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose) - .value("Mkldnn", at::native::ConvBackend::Mkldnn) - .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty) - .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial) - .value("Overrideable", at::native::ConvBackend::Overrideable) - .value("Slow2d", at::native::ConvBackend::Slow2d) - .value("Slow3d", at::native::ConvBackend::Slow3d) - .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d) - .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d) - .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d) - .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d) - .value("Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise) - .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d); - - py_module.def("_select_conv_backend", []( - const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias_opt, - at::IntArrayRef stride_, at::IntArrayRef padding_, at::IntArrayRef dilation_, - bool transposed_, at::IntArrayRef output_padding_, int64_t groups_) { - return at::native::select_conv_backend( - input, weight, bias_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_); - }); + .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d) + .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d) + .value("Cudnn", at::native::ConvBackend::Cudnn) + .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose) + .value("Empty", at::native::ConvBackend::Empty) + .value("Miopen", at::native::ConvBackend::Miopen) + .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise) + .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose) + .value("Mkldnn", at::native::ConvBackend::Mkldnn) + .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty) + .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial) + .value("Overrideable", at::native::ConvBackend::Overrideable) + .value("Slow2d", at::native::ConvBackend::Slow2d) + .value("Slow3d", at::native::ConvBackend::Slow3d) + .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d) + .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d) + .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d) + .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d) + .value( + "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise) + .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d); + + py_module.def( + "_select_conv_backend", + [](const at::Tensor& input, + const at::Tensor& weight, + const c10::optional& bias_opt, + at::IntArrayRef stride_, + at::IntArrayRef padding_, + at::IntArrayRef dilation_, + bool transposed_, + at::IntArrayRef output_padding_, + int64_t groups_) { + return at::native::select_conv_backend( + input, + weight, + bias_opt, + stride_, + padding_, + dilation_, + transposed_, + output_padding_, + groups_); + }); py::enum_(py_module, "_LinalgBackend") - .value("Default", at::LinalgBackend::Default) - .value("Cusolver", at::LinalgBackend::Cusolver) - .value("Magma", at::LinalgBackend::Magma); + .value("Default", at::LinalgBackend::Default) + .value("Cusolver", at::LinalgBackend::Cusolver) + .value("Magma", at::LinalgBackend::Magma); py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) { at::globalContext().setLinalgPreferredBackend(b); @@ -1005,40 +1186,40 @@ Call this whenever a new thread is created in order to propagate values from }); #ifdef USE_CUDA - PyObject *has_cuda = Py_True; + PyObject* has_cuda = Py_True; #else - PyObject *has_cuda = Py_False; + PyObject* has_cuda = Py_False; #endif #ifdef USE_MPS - PyObject *has_mps = Py_True; + PyObject* has_mps = Py_True; #else - PyObject *has_mps = Py_False; + PyObject* has_mps = Py_False; #endif ASSERT_TRUE(set_module_attr("has_cuda", has_cuda)); ASSERT_TRUE(set_module_attr("has_mps", has_mps)); - py_module.def("_is_mps_available", []() { - return at::hasMPS(); - }); + py_module.def("_is_mps_available", []() { return at::hasMPS(); }); - - ASSERT_TRUE(set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False)); + ASSERT_TRUE( + set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False)); #ifdef _GLIBCXX_USE_CXX11_ABI - ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False)); + ASSERT_TRUE(set_module_attr( + "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False)); #else ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False)); #endif // See note [Pybind11 ABI constants] #define SET_STR_DEFINE(name) \ - ASSERT_TRUE(set_module_attr("_" # name, THPUtils_packString(name))) + ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name))) #ifdef PYBIND11_COMPILER_TYPE SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); #else - ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); + ASSERT_TRUE( + set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); #endif #ifdef PYBIND11_STDLIB @@ -1054,24 +1235,31 @@ Call this whenever a new thread is created in order to propagate values from #endif #undef SET_STR_DEFINE - py_module.def("_set_conj", [](const at::Tensor & x, bool conj) { - x._set_conj(conj); - }); - py_module.def("_set_neg", [](const at::Tensor & x, bool neg) { - x._set_neg(neg); - }); - py_module.def("_dispatch_key_set", [](const at::Tensor & x) { + py_module.def( + "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); }); + py_module.def( + "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); }); + py_module.def("_dispatch_key_set", [](const at::Tensor& x) { return toString(x.key_set()); }); const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); - THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); + THPDefaultCPUGenerator = + (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); // This reference is meant to be given away, so no need to incref here. - ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); - ASSERT_TRUE(set_module_attr("DisableTorchFunction", (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false)); - torch::set_disabled_torch_function_impl(PyObject_GetAttrString(module, "_disabled_torch_function_impl")); + ASSERT_TRUE(set_module_attr( + "default_generator", + (PyObject*)THPDefaultCPUGenerator, + /* incref= */ false)); + ASSERT_TRUE(set_module_attr( + "DisableTorchFunction", + (PyObject*)THPModule_DisableTorchFunctionType(), + /* incref= */ false)); + torch::set_disabled_torch_function_impl( + PyObject_GetAttrString(module, "_disabled_torch_function_impl")); ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); - torch::set_disabled_torch_dispatch_impl(PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl")); + torch::set_disabled_torch_dispatch_impl( + PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl")); ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr); return module; END_HANDLE_TH_ERRORS @@ -1087,10 +1275,13 @@ inline void pytorch_duplicate_guard() { abort(); } initialized = 1; -;} + ; +} struct call_duplicate_guard { - call_duplicate_guard() { pytorch_duplicate_guard(); } + call_duplicate_guard() { + pytorch_duplicate_guard(); + } }; static call_duplicate_guard _call_duplicate_guard; diff --git a/torch/csrc/PythonTypes.h b/torch/csrc/PythonTypes.h index f94d41ac5aee61..e457ab891da014 100644 --- a/torch/csrc/PythonTypes.h +++ b/torch/csrc/PythonTypes.h @@ -1,22 +1,20 @@ #pragma once -#include -#include #include #include +#include +#include namespace torch { struct THPVoidTensor { - PyObject_HEAD - c10::TensorImpl *cdata; + PyObject_HEAD c10::TensorImpl* cdata; char device_type; char data_type; }; struct THPVoidStorage { - PyObject_HEAD - c10::StorageImpl *cdata; + PyObject_HEAD c10::StorageImpl* cdata; }; } // namespace torch diff --git a/torch/csrc/QScheme.cpp b/torch/csrc/QScheme.cpp index 08e69dce025f08..abcfc5b4e9eec2 100644 --- a/torch/csrc/QScheme.cpp +++ b/torch/csrc/QScheme.cpp @@ -10,83 +10,80 @@ #include #include -PyObject *THPQScheme_New(at::QScheme qscheme, const std::string& name) -{ +PyObject* THPQScheme_New(at::QScheme qscheme, const std::string& name) { auto type = (PyTypeObject*)&THPQSchemeType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; - if (!self) throw python_error(); + if (!self) + throw python_error(); auto self_ = reinterpret_cast(self.get()); self_->qscheme = qscheme; - std::strncpy (self_->name, name.c_str(), QSCHEME_NAME_LEN); + std::strncpy(self_->name, name.c_str(), QSCHEME_NAME_LEN); self_->name[QSCHEME_NAME_LEN] = '\0'; return self.release(); } -PyObject *THPQScheme_reduce(PyObject *_self, PyObject *noargs) { +PyObject* THPQScheme_reduce(PyObject* _self, PyObject* noargs) { auto self = (THPQScheme*)_self; return THPUtils_packString(self->name); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static PyMethodDef THPQScheme_methods[] = { - {"__reduce__", THPQScheme_reduce, METH_NOARGS, nullptr}, - {nullptr} /* Sentinel */ + {"__reduce__", THPQScheme_reduce, METH_NOARGS, nullptr}, + {nullptr} /* Sentinel */ }; -PyObject *THPQScheme_repr(THPQScheme *self) -{ +PyObject* THPQScheme_repr(THPQScheme* self) { std::string name = self->name; return THPUtils_packString("torch." + name); } PyTypeObject THPQSchemeType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.qscheme", /* tp_name */ - sizeof(THPQScheme), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPQScheme_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPQScheme_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.qscheme", /* tp_name */ + sizeof(THPQScheme), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPQScheme_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPQScheme_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; -void THPQScheme_init(PyObject *module) -{ +void THPQScheme_init(PyObject* module) { if (PyType_Ready(&THPQSchemeType) < 0) { throw python_error(); } Py_INCREF(&THPQSchemeType); - if (PyModule_AddObject(module, "qscheme", (PyObject *)&THPQSchemeType) != 0) { + if (PyModule_AddObject(module, "qscheme", (PyObject*)&THPQSchemeType) != 0) { throw python_error(); } } diff --git a/torch/csrc/QScheme.h b/torch/csrc/QScheme.h index 60444be272be51..fcb75304c0ed0b 100644 --- a/torch/csrc/QScheme.h +++ b/torch/csrc/QScheme.h @@ -9,18 +9,17 @@ constexpr int QSCHEME_NAME_LEN = 64; struct THPQScheme { - PyObject_HEAD - at::QScheme qscheme; + PyObject_HEAD at::QScheme qscheme; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char name[QSCHEME_NAME_LEN + 1]; }; extern PyTypeObject THPQSchemeType; -inline bool THPQScheme_Check(PyObject *obj) { +inline bool THPQScheme_Check(PyObject* obj) { return Py_TYPE(obj) == &THPQSchemeType; } -PyObject * THPQScheme_New(at::QScheme qscheme, const std::string& name); +PyObject* THPQScheme_New(at::QScheme qscheme, const std::string& name); -void THPQScheme_init(PyObject *module); +void THPQScheme_init(PyObject* module); diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index e67faba8989acb..2869e9dc53c308 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -1,11 +1,11 @@ #include #include -#include #include #include #include #include +#include #include #include @@ -14,45 +14,50 @@ struct THPSize { PyTupleObject tuple; }; -PyObject * THPSize_New(const torch::autograd::Variable& var) -{ +PyObject* THPSize_New(const torch::autograd::Variable& var) { if (!torch::jit::tracer::isTracing()) { auto sizes = var.sizes(); return THPSize_NewFromSizes(var.dim(), sizes.data()); } auto self = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, var.dim())); - if (!self) throw python_error(); + if (!self) + throw python_error(); for (const auto i : c10::irange(var.dim())) { - PyObject *py_size_tensor = THPVariable_Wrap(torch::jit::tracer::getSizeOf(var, i)); - if (!py_size_tensor) throw python_error(); + PyObject* py_size_tensor = + THPVariable_Wrap(torch::jit::tracer::getSizeOf(var, i)); + if (!py_size_tensor) + throw python_error(); PyTuple_SET_ITEM(self.get(), i, py_size_tensor); } return self.release(); } -PyObject * THPSize_NewFromSizes(int dim, const int64_t *sizes) -{ +PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes) { auto self = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, dim)); - if (!self) throw python_error(); + if (!self) + throw python_error(); THPUtils_packInt64Array(self, dim, sizes); return self.release(); } -static bool isTracedZeroDimVar(PyObject *item) { - if (!THPVariable_Check(item)) return false; - auto & var = THPVariable_Unpack(item); +static bool isTracedZeroDimVar(PyObject* item) { + if (!THPVariable_Check(item)) + return false; + auto& var = THPVariable_Unpack(item); return var.dim() == 0 && torch::jit::tracer::getValueTrace(var); } -static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +static PyObject* THPSize_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS THPObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs)); if (self) { for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) { - PyObject *item = PyTuple_GET_ITEM(self.get(), i); + PyObject* item = PyTuple_GET_ITEM(self.get(), i); if (THPUtils_checkLong(item)) { continue; } @@ -69,17 +74,18 @@ static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kw } continue; } - return PyErr_Format(PyExc_TypeError, - "torch.Size() takes an iterable of 'int' (item %zd is '%s')", - i, Py_TYPE(item)->tp_name); + return PyErr_Format( + PyExc_TypeError, + "torch.Size() takes an iterable of 'int' (item %zd is '%s')", + i, + Py_TYPE(item)->tp_name); } } return self.release(); END_HANDLE_TH_ERRORS } -static PyObject * THPSize_repr(THPSize *self) -{ +static PyObject* THPSize_repr(THPSize* self) { HANDLE_TH_ERRORS std::string repr("torch.Size(["); for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) { @@ -95,13 +101,14 @@ static PyObject * THPSize_repr(THPSize *self) extern PyTypeObject THPSizeType; -template -static PyObject* wrap_tuple_fn(Args ... args) -{ +template +static PyObject* wrap_tuple_fn(Args... args) { THPObjectPtr result((*fn)(std::forward(args)...)); - if (!result) return nullptr; + if (!result) + return nullptr; if (PyTuple_Check(result.get())) { - return PyObject_CallFunctionObjArgs((PyObject*)&THPSizeType, result.get(), nullptr); + return PyObject_CallFunctionObjArgs( + (PyObject*)&THPSizeType, result.get(), nullptr); } return result.release(); } @@ -109,31 +116,28 @@ static PyObject* wrap_tuple_fn(Args ... args) // We use an anonymous namespace instead of static to work around // (what @peterjc123 think is) a bug in Visual Studio namespace { - auto sq_concat = PyTuple_Type.tp_as_sequence->sq_concat; - auto sq_repeat = PyTuple_Type.tp_as_sequence->sq_repeat; - binaryfunc mp_subscript = PyTuple_Type.tp_as_mapping->mp_subscript; -} - +auto sq_concat = PyTuple_Type.tp_as_sequence -> sq_concat; +auto sq_repeat = PyTuple_Type.tp_as_sequence -> sq_repeat; +binaryfunc mp_subscript = PyTuple_Type.tp_as_mapping->mp_subscript; +} // namespace static PySequenceMethods THPSize_as_sequence = { - nullptr, /* sq_length */ - wrap_tuple_fn, - wrap_tuple_fn, - nullptr, /* sq_item */ - nullptr, /* sq_slice */ - nullptr, /* sq_ass_item */ - nullptr, /* sq_ass_slice */ - nullptr /* sq_contains */ + nullptr, /* sq_length */ + wrap_tuple_fn, + wrap_tuple_fn, + nullptr, /* sq_item */ + nullptr, /* sq_slice */ + nullptr, /* sq_ass_item */ + nullptr, /* sq_ass_slice */ + nullptr /* sq_contains */ }; static PyMappingMethods THPSize_as_mapping = { - nullptr, /* mp_length */ + nullptr, /* mp_length */ wrap_tuple_fn, - nullptr -}; + nullptr}; -static PyObject *THPSize_numel(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPSize*)_self; int64_t numel = 1; @@ -144,19 +148,20 @@ static PyObject *THPSize_numel(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject *THPSize_reduce(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPSize_reduce(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPSize*)_self; auto ret = THPObjectPtr{PyTuple_New(2)}; - if (!ret) throw python_error(); + if (!ret) + throw python_error(); auto obj = (PyObject*)(&THPSizeType); Py_INCREF(&THPSizeType); PyTuple_SET_ITEM(ret.get(), 0, obj); THPObjectPtr t(PyTuple_New(PyTuple_Size((PyObject*)self))); - if (!t) throw python_error(); + if (!t) + throw python_error(); for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) { auto d = PyTuple_GET_ITEM(self, i); Py_INCREF(d); @@ -164,7 +169,8 @@ static PyObject *THPSize_reduce(PyObject *_self, PyObject *noargs) } THPObjectPtr dims(Py_BuildValue("(O)", t.get())); - if (!dims) throw python_error(); + if (!dims) + throw python_error(); PyTuple_SET_ITEM(ret.get(), 1, dims.release()); return ret.release(); @@ -173,55 +179,51 @@ static PyObject *THPSize_reduce(PyObject *_self, PyObject *noargs) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) static PyMethodDef THPSize_methods[] = { - {"numel", THPSize_numel, METH_NOARGS, nullptr}, - {"__reduce__", THPSize_reduce, METH_NOARGS, nullptr}, - {nullptr} -}; - + {"numel", THPSize_numel, METH_NOARGS, nullptr}, + {"__reduce__", THPSize_reduce, METH_NOARGS, nullptr}, + {nullptr}}; PyTypeObject THPSizeType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.Size", /* tp_name */ - sizeof(THPSize), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPSize_repr, /* tp_repr */ - nullptr, /* tp_as_number */ - &THPSize_as_sequence, /* tp_as_sequence */ - &THPSize_as_mapping, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPSize_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - &PyTuple_Type, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPSize_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.Size", /* tp_name */ + sizeof(THPSize), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPSize_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + &THPSize_as_sequence, /* tp_as_sequence */ + &THPSize_as_mapping, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPSize_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + &PyTuple_Type, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPSize_pynew, /* tp_new */ }; -void THPSize_init(PyObject *module) -{ +void THPSize_init(PyObject* module) { if (PyType_Ready(&THPSizeType) < 0) { throw python_error(); } diff --git a/torch/csrc/Size.h b/torch/csrc/Size.h index 690b66b1d416c8..31cf6d369df59c 100644 --- a/torch/csrc/Size.h +++ b/torch/csrc/Size.h @@ -1,14 +1,14 @@ #pragma once -#include #include +#include #include extern PyTypeObject THPSizeType; #define THPSize_Check(obj) (Py_TYPE(obj) == &THPSizeType) -PyObject * THPSize_New(const torch::autograd::Variable& t); -PyObject * THPSize_NewFromSizes(int dim, const int64_t *sizes); +PyObject* THPSize_New(const torch::autograd::Variable& t); +PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes); -void THPSize_init(PyObject *module); +void THPSize_init(PyObject* module); diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 1df0fa16c5ea7b..5170fe746caf39 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -4,57 +4,59 @@ #endif #include +#include #include -#include -#include -#include #include #include -#include -#include +#include #include #include -#include +#include +#include +#include +#include -#include #include +#include -template<> +template <> void THPPointer::free() { if (ptr) { c10::raw::intrusive_ptr::decref(ptr); } } -PyObject *THPStorageClass = nullptr; +PyObject* THPStorageClass = nullptr; -PyObject * THPStorage_New(c10::intrusive_ptr ptr) -{ +PyObject* THPStorage_New(c10::intrusive_ptr ptr) { AT_ASSERT(ptr); - PyTypeObject *type = (PyTypeObject *)THPStorageClass; - PyObject *obj = type->tp_alloc(type, 0); + PyTypeObject* type = (PyTypeObject*)THPStorageClass; + PyObject* obj = type->tp_alloc(type, 0); if (obj) { - ((THPStorage *)obj)->cdata = ptr.release(); + ((THPStorage*)obj)->cdata = ptr.release(); } return obj; } -static void THPStorage_dealloc(THPStorage* self) -{ +static void THPStorage_dealloc(THPStorage* self) { if (self->cdata) { c10::raw::intrusive_ptr::decref(self->cdata); } Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +static PyObject* THPStorage_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static torch::PythonArgParser parser({ THPStorageStr "(*, int64_t allocator=None, Device device=None)", - THPStorageStr "(int64_t size, *, int64_t allocator=None, Device device=None)", - THPStorageStr "(PyObject* sequence, *, int64_t allocator=None, Device device=None)", + THPStorageStr + "(int64_t size, *, int64_t allocator=None, Device device=None)", + THPStorageStr + "(PyObject* sequence, *, int64_t allocator=None, Device device=None)", }); torch::ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -70,11 +72,13 @@ static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject c10::optional allocator_opt = r.toInt64Optional(allocator_arg_idx); c10::optional device_opt = r.deviceOptional(device_arg_idx); - TORCH_CHECK(!allocator_opt.has_value() || !device_opt.has_value(), - THPStorageStr, "(): only one or neither of 'allocator' or 'device' can ", - "be given, but not both"); + TORCH_CHECK( + !allocator_opt.has_value() || !device_opt.has_value(), + THPStorageStr, + "(): only one or neither of 'allocator' or 'device' can ", + "be given, but not both"); - THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0)); + THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0)); THPUtils_assert(self, "failed to allocate a " THPStorageStr " object"); c10::Allocator* allocator = nullptr; at::OptionalDeviceGuard device_guard; @@ -93,8 +97,11 @@ static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject } else if (device.type() == at::DeviceType::Meta) { allocator = c10::GetAllocator(device.type()); } else { - TORCH_CHECK(false, - THPStorageStr, "(): Storage device not recognized: ", device.type()); + TORCH_CHECK( + false, + THPStorageStr, + "(): Storage device not recognized: ", + device.type()); } device_guard.reset_device(device); } else { @@ -104,38 +111,44 @@ static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject // torch.Storage(*, ...) if (r.idx == 0) { self->cdata = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - 0, - allocator, - /*resizable=*/true).release(); + c10::StorageImpl::use_byte_size_t(), + 0, + allocator, + /*resizable=*/true) + .release(); return (PyObject*)self.release(); - // torch.Storage(size, *, ...) + // torch.Storage(size, *, ...) } else if (r.idx == 1) { int64_t size = r.toInt64(0); self->cdata = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size, - allocator, - /*resizable=*/true).release(); + c10::StorageImpl::use_byte_size_t(), + size, + allocator, + /*resizable=*/true) + .release(); return (PyObject*)self.release(); - // torch.Storage(sequence, *, ...) + // torch.Storage(sequence, *, ...) } else if (r.idx == 2) { - PyObject *sequence = r.pyobject(0); + PyObject* sequence = r.pyobject(0); Py_ssize_t length = PySequence_Length(sequence); - TORCH_CHECK(PySequence_Check(sequence), - THPStorageStr, "(): Expected a sequence type, but got ", - THPUtils_typename(sequence)); - TORCH_CHECK(length >= 0, - THPStorageStr, "(): Could not obtain the length of sequence of type ", - THPUtils_typename(sequence)); + TORCH_CHECK( + PySequence_Check(sequence), + THPStorageStr, + "(): Expected a sequence type, but got ", + THPUtils_typename(sequence)); + TORCH_CHECK( + length >= 0, + THPStorageStr, + "(): Could not obtain the length of sequence of type ", + THPUtils_typename(sequence)); self->cdata = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - length, - allocator, - /*resizable=*/true) - .release(); + c10::StorageImpl::use_byte_size_t(), + length, + allocator, + /*resizable=*/true) + .release(); THPObjectPtr item; try { for (Py_ssize_t i = 0; i < length; i++) { @@ -147,13 +160,12 @@ static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject } else { // TODO: this might be slow - consider batched updates? storage_set( - at::unsafeStorageFromTH(self->cdata, /*retain=*/true), - i, - value); + at::unsafeStorageFromTH(self->cdata, /*retain=*/true), i, value); } } - } catch (const std::exception &e) { - THPUtils_setError(THPStorageStr + } catch (const std::exception& e) { + THPUtils_setError( + THPStorageStr "(): tried to construct a storage from a sequence (%s), " "but one of the items was of type %s instead of %s", THPUtils_typename(sequence), @@ -167,30 +179,34 @@ static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject END_HANDLE_TH_ERRORS } -static Py_ssize_t THPStorage_length(THPStorage *self) -{ +static Py_ssize_t THPStorage_length(THPStorage* self) { HANDLE_TH_ERRORS return self->cdata->nbytes() / sizeof(uint8_t); END_HANDLE_TH_ERRORS_RET(-1) } -static PyObject * THPStorage_get(THPStorage *self, PyObject *index) -{ +static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { HANDLE_TH_ERRORS /* Integer index */ if (THPUtils_checkLong(index)) { int64_t nindex = THPUtils_unpackLong(index); if (nindex < 0) nindex += (self->cdata->nbytes() / sizeof(uint8_t)); - if (nindex < 0 || nindex >= static_cast(self->cdata->nbytes() / sizeof(uint8_t))) { - PyErr_SetString(PyExc_IndexError, fmt::format( - "index {} out of range for storage of size {}", - nindex, self->cdata->nbytes() / sizeof(uint8_t))); + if (nindex < 0 || + nindex >= + static_cast(self->cdata->nbytes() / sizeof(uint8_t))) { + PyErr_SetString( + PyExc_IndexError, + fmt::format( + "index {} out of range for storage of size {}", + nindex, + self->cdata->nbytes() / sizeof(uint8_t))); return nullptr; } - uint8_t value = storage_get(at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex); + uint8_t value = storage_get( + at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex); return THPByteUtils_newReal(value); - /* Slice index */ + /* Slice index */ } else if (PySlice_Check(index)) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Py_ssize_t start, stop, slicelength, step; @@ -198,12 +214,14 @@ static PyObject * THPStorage_get(THPStorage *self, PyObject *index) if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength)) return nullptr; if (step != 1) { - THPUtils_setError("Trying to slice with a step of %lld, but only a step of " - "1 is supported", (long long)step); + THPUtils_setError( + "Trying to slice with a step of %lld, but only a step of " + "1 is supported", + (long long)step); return nullptr; } - uint8_t *data = self->cdata->data(); + uint8_t* data = self->cdata->data(); at::StorageImpl* old_storage = self->cdata; c10::raw::intrusive_ptr::incref(old_storage); @@ -224,21 +242,24 @@ static PyObject * THPStorage_get(THPStorage *self, PyObject *index) old_storage->allocator(), /* resizable */ false); - PyObject *_ret = THPStorage_New(std::move(new_storage)); + PyObject* _ret = THPStorage_New(std::move(new_storage)); return _ret; } - PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s", + PyErr_Format( + PyExc_TypeError, + "can't index a " THPStorageStr " with %s", THPUtils_typename(index)); return nullptr; END_HANDLE_TH_ERRORS } -static int THPStorage_set(THPStorage *self, PyObject *index, PyObject *value) -{ +static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { HANDLE_TH_ERRORS if (!THPByteUtils_checkReal(value)) { - THPUtils_setError("can only set storage content with a %s, but got " - "%s instead", THPUtils_typeTraits::python_type_str, + THPUtils_setError( + "can only set storage content with a %s, but got " + "%s instead", + THPUtils_typeTraits::python_type_str, THPUtils_typename(value)); return -1; } @@ -247,9 +268,7 @@ static int THPStorage_set(THPStorage *self, PyObject *index, PyObject *value) if (THPUtils_checkLong(index)) { int64_t nindex = THPUtils_unpackLong(index); storage_set( - at::unsafeStorageFromTH(self->cdata, /*retain=*/true), - nindex, - rvalue); + at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex, rvalue); return 0; } else if (PySlice_Check(index)) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -258,95 +277,98 @@ static int THPStorage_set(THPStorage *self, PyObject *index, PyObject *value) if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength)) return -1; if (step != 1) { - THPUtils_setError("Trying to slice with a step of %lld, but only a step of " - "1 is supported", (long long)step); + THPUtils_setError( + "Trying to slice with a step of %lld, but only a step of " + "1 is supported", + (long long)step); return 0; } // TODO: check the bounds only once // TODO: fill? - for (;start < stop; start++) + for (; start < stop; start++) storage_set( - at::unsafeStorageFromTH(self->cdata, /*retain=*/true), - start, - rvalue); + at::unsafeStorageFromTH(self->cdata, /*retain=*/true), start, rvalue); return 0; } - THPUtils_setError("can't index a " THPStorageStr " with %s", - THPUtils_typename(index)); + THPUtils_setError( + "can't index a " THPStorageStr " with %s", THPUtils_typename(index)); return -1; END_HANDLE_TH_ERRORS_RET(-1) } static PyMappingMethods THPStorage_mappingmethods = { - (lenfunc)THPStorage_length, - (binaryfunc)THPStorage_get, - (objobjargproc)THPStorage_set -}; + (lenfunc)THPStorage_length, + (binaryfunc)THPStorage_get, + (objobjargproc)THPStorage_set}; // TODO: implement equality PyTypeObject THPStorageType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C." THPStorageBaseStr, /* tp_name */ - sizeof(THPStorage), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPStorage_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - &THPStorage_mappingmethods, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* will be assigned in init */ /* tp_methods */ - nullptr, /* will be assigned in init */ /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPStorage_pynew, /* tp_new */ + PyVarObject_HEAD_INIT( + nullptr, + 0) "torch._C." THPStorageBaseStr, /* tp_name */ + sizeof(THPStorage), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THPStorage_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + &THPStorage_mappingmethods, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, + /* will be assigned in init */ /* tp_methods */ + nullptr, + /* will be assigned in init */ /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPStorage_pynew, /* tp_new */ }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMemberDef THPStorage_members[] = { - {(char*)"_cdata", T_ULONGLONG, offsetof(THPStorage, cdata), READONLY, nullptr}, - {nullptr} -}; - -static PyObject * THPStorage_device(THPStorage* self, void *unused) { + {(char*)"_cdata", + T_ULONGLONG, + offsetof(THPStorage, cdata), + READONLY, + nullptr}, + {nullptr}}; + +static PyObject* THPStorage_device(THPStorage* self, void* unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cdata->device()); END_HANDLE_TH_ERRORS } -typedef PyObject *(*getter)(PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPStorage_properties[] = { - {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr}, + {nullptr}}; -bool THPStorage_init(PyObject *module) -{ +bool THPStorage_init(PyObject* module) { static std::vector methods; THPUtils_addPyMethodDefs(methods, THPStorage_getMethods()); THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods()); @@ -357,12 +379,12 @@ bool THPStorage_init(PyObject *module) if (PyType_Ready(&THPStorageType) < 0) return false; Py_INCREF(&THPStorageType); - PyModule_AddObject(module, THPStorageBaseStr, (PyObject *)&THPStorageType); + PyModule_AddObject(module, THPStorageBaseStr, (PyObject*)&THPStorageType); return true; } -void THPStorage_postInit(PyObject *module) -{ +void THPStorage_postInit(PyObject* module) { THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage"); - if (!THPStorageClass) throw python_error(); + if (!THPStorageClass) + throw python_error(); } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index a7dbb0f3bfa529..593a96921baf10 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -7,15 +7,15 @@ #define THPStorageBaseStr "StorageBase" struct THPStorage { - PyObject_HEAD - c10::StorageImpl *cdata; + PyObject_HEAD c10::StorageImpl* cdata; }; -TORCH_PYTHON_API PyObject * THPStorage_New(c10::intrusive_ptr ptr); -extern PyObject *THPStorageClass; +TORCH_PYTHON_API PyObject* THPStorage_New( + c10::intrusive_ptr ptr); +extern PyObject* THPStorageClass; -bool THPStorage_init(PyObject *module); -void THPStorage_postInit(PyObject *module); +bool THPStorage_init(PyObject* module); +void THPStorage_postInit(PyObject* module); extern PyTypeObject THPStorageType; diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 2fdf9c813ccf3a..cf6fca98c213cc 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -4,30 +4,30 @@ #endif #include +#include #include -#include -#include -#include #include #include +#include +#include #include -#include +#include -#include #include +#include -#include #include +#include #include #include #include -#include #include +#include #ifdef USE_CUDA -#include #include +#include #endif #include @@ -38,30 +38,30 @@ #define LSEEK lseek #endif -static PyObject * THPStorage_nbytes(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_nbytes(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; return THPUtils_packUInt64(self->cdata->nbytes()); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_dataPtr(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_dataPtr(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; return PyLong_FromVoidPtr(self->cdata->data()); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_copy_(PyObject *self, PyObject *args, PyObject *kwargs) -{ +static PyObject* THPStorage_copy_( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS at::Storage self_ = torch::createStorage(self); static torch::PythonArgParser parser({ - "copy_(Storage src, bool? non_blocking=None)", + "copy_(Storage src, bool? non_blocking=None)", }); torch::ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -79,47 +79,47 @@ static PyObject * THPStorage_copy_(PyObject *self, PyObject *args, PyObject *kwa END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_isPinned(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_isPinned(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS #if defined(USE_CUDA) auto self = (THPStorage*)_self; - return PyBool_FromLong(at::globalContext().isPinnedPtr(self->cdata->data())); + return PyBool_FromLong( + at::globalContext().isPinnedPtr(self->cdata->data())); #else Py_RETURN_FALSE; #endif END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_elementSize(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packInt64(sizeof(uint8_t)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_new(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_new(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; c10::Allocator* allocator = self->cdata->allocator(); auto new_storage = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - 0, - allocator, - /*resizable=*/true); + c10::StorageImpl::use_byte_size_t(), + 0, + allocator, + /*resizable=*/true); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) return THPStorage_New(std::move(new_storage)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_resize_(PyObject *_self, PyObject *number_arg) -{ +static PyObject* THPStorage_resize_(PyObject* _self, PyObject* number_arg) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, " - "but got %s", THPUtils_typename(number_arg)); + THPUtils_assert( + THPUtils_checkLong(number_arg), + "resize_ expects an int, " + "but got %s", + THPUtils_typename(number_arg)); int64_t newsize = THPUtils_unpackLong(number_arg); c10::DeviceType device_type = self->cdata->device_type(); if (device_type == at::kCPU) { @@ -127,63 +127,83 @@ static PyObject * THPStorage_resize_(PyObject *_self, PyObject *number_arg) #ifdef USE_CUDA } else if (device_type == at::kCUDA) { ptrdiff_t size_bytes_i = newsize; - TORCH_CHECK(!c10::overflows(size_bytes_i), - "Requested storage size (", size_bytes_i, - ") cannot be represented as a size_t"); + TORCH_CHECK( + !c10::overflows(size_bytes_i), + "Requested storage size (", + size_bytes_i, + ") cannot be represented as a size_t"); const auto size_bytes = static_cast(size_bytes_i); at::native::resize_bytes_cuda(self->cdata, size_bytes); #endif } else { - TORCH_CHECK(false, - "_UntypedStorage.resize_: got unexpected device type ", device_type); + TORCH_CHECK( + false, + "_UntypedStorage.resize_: got unexpected device type ", + device_type); } Py_INCREF(self); return (PyObject*)self; END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_fill_(PyObject *_self, PyObject *number_arg) -{ +static PyObject* THPStorage_fill_(PyObject* _self, PyObject* number_arg) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - THPUtils_assert(THPByteUtils_checkReal(number_arg), "fill_ expects %s, " - "but got %s", THPUtils_typeTraits::python_type_str, + THPUtils_assert( + THPByteUtils_checkReal(number_arg), + "fill_ expects %s, " + "but got %s", + THPUtils_typeTraits::python_type_str, THPUtils_typename(number_arg)); storage_fill( - at::unsafeStorageFromTH(self->cdata, /*retain=*/true), - THPByteUtils_unpackReal(number_arg)); + at::unsafeStorageFromTH(self->cdata, /*retain=*/true), + THPByteUtils_unpackReal(number_arg)); Py_INCREF(self); return (PyObject*)self; END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObject *keywds) -{ +static PyObject* THPStorage_fromBuffer( + PyObject* _unused, + PyObject* args, + PyObject* keywds) { HANDLE_TH_ERRORS - PyObject *obj = nullptr; + PyObject* obj = nullptr; const char* byte_order_str = nullptr; Py_ssize_t count = -1, offset = 0; PyObject* dtype_obj = nullptr; c10::ScalarType scalar_type = at::kByte; Py_buffer buffer = {}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings) - static char *kwlist[] = {"buffer", "byte_order", "count", "offset", "dtype", nullptr}; + static char* kwlist[] = { + "buffer", "byte_order", "count", "offset", "dtype", nullptr}; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const char* argtypes; argtypes = "O|snnO"; - if (!PyArg_ParseTupleAndKeywords(args, keywds, argtypes, kwlist, - &obj, &byte_order_str, &count, &offset, &dtype_obj)) { + if (!PyArg_ParseTupleAndKeywords( + args, + keywds, + argtypes, + kwlist, + &obj, + &byte_order_str, + &count, + &offset, + &dtype_obj)) { return nullptr; } TORCH_CHECK(dtype_obj != nullptr, "argument 'dtype' cannot be None"); - TORCH_CHECK(THPDtype_Check(dtype_obj), "argument 'dtype' must be of type torch.dtype"); + TORCH_CHECK( + THPDtype_Check(dtype_obj), + "argument 'dtype' must be of type torch.dtype"); auto dtype = reinterpret_cast(dtype_obj); scalar_type = dtype->scalar_type; TORCH_CHECK( - (scalar_type == at::kByte) || (scalar_type == at::kChar) || (byte_order_str != nullptr), - "function missing required argument 'byte_order' (pos 2)"); + (scalar_type == at::kByte) || (scalar_type == at::kChar) || + (byte_order_str != nullptr), + "function missing required argument 'byte_order' (pos 2)"); size_t element_size = c10::elementSize(scalar_type); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -196,9 +216,10 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj } else if (strcmp(byte_order_str, "little") == 0) { byte_order = torch::utils::THP_LITTLE_ENDIAN; } else { - PyErr_Format(PyExc_ValueError, - "invalid byte_order '%s' (expected 'big', 'little', or 'native')", - byte_order_str); + PyErr_Format( + PyExc_ValueError, + "invalid byte_order '%s' (expected 'big', 'little', or 'native')", + byte_order_str); return nullptr; } } @@ -207,9 +228,12 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj return nullptr; if (offset < 0 || offset > buffer.len) { - PyErr_SetString(PyExc_ValueError, fmt::format( - "offset must be non-negative and no greater than buffer length ({}) , but got {}", - offset, buffer.len)); + PyErr_SetString( + PyExc_ValueError, + fmt::format( + "offset must be non-negative and no greater than buffer length ({}) , but got {}", + offset, + buffer.len)); PyBuffer_Release(&buffer); return nullptr; } @@ -218,9 +242,12 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj size_t size_bytes; if (count < 0) { if ((buffer.len - offset) % element_size != 0) { - PyErr_SetString(PyExc_ValueError, fmt::format( - "buffer size ({}) must be a multiple of element size ({})", - buffer.len, element_size)); + PyErr_SetString( + PyExc_ValueError, + fmt::format( + "buffer size ({}) must be a multiple of element size ({})", + buffer.len, + element_size)); PyBuffer_Release(&buffer); return nullptr; } @@ -231,19 +258,23 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj } if (offset + (count * (Py_ssize_t)element_size) > buffer.len) { - PyErr_SetString(PyExc_ValueError, fmt::format( - "buffer has only {} elements after offset {}, but specified a size of {}", - buffer.len - offset, offset, count)); + PyErr_SetString( + PyExc_ValueError, + fmt::format( + "buffer has only {} elements after offset {}, but specified a size of {}", + buffer.len - offset, + offset, + count)); PyBuffer_Release(&buffer); return nullptr; } - uint8_t* src = (uint8_t*) buffer.buf; + uint8_t* src = (uint8_t*)buffer.buf; auto storage = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size_bytes, - c10::GetDefaultCPUAllocator(), - /*resizable=*/true); + c10::StorageImpl::use_byte_size_t(), + size_bytes, + c10::GetDefaultCPUAllocator(), + /*resizable=*/true); if (scalar_type == at::kByte || scalar_type == at::kChar) { memcpy(storage->data(), src + offset, count); @@ -252,44 +283,34 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj // we are trying to get a value which is not 0 or 1, we have to manually // convert original values to boolean ones. torch::utils::THP_decodeBoolBuffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kShort) { torch::utils::THP_decodeInt16Buffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kInt) { torch::utils::THP_decodeInt32Buffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kLong) { torch::utils::THP_decodeInt64Buffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kHalf) { torch::utils::THP_decodeHalfBuffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kBFloat16) { torch::utils::THP_decodeBFloat16Buffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kFloat) { torch::utils::THP_decodeFloatBuffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kDouble) { torch::utils::THP_decodeDoubleBuffer( - storage->data(), - src + offset, byte_order, count); + storage->data(), src + offset, byte_order, count); } else if (scalar_type == at::kComplexFloat) { torch::utils::THP_decodeComplexFloatBuffer( - storage->data>(), - src + offset, byte_order, count); + storage->data>(), src + offset, byte_order, count); } else if (scalar_type == at::kComplexDouble) { torch::utils::THP_decodeComplexDoubleBuffer( - storage->data>(), - src + offset, byte_order, count); + storage->data>(), src + offset, byte_order, count); } else { TORCH_CHECK(false, "Unknown type: ", scalar_type); } @@ -299,17 +320,19 @@ static PyObject * THPStorage_fromBuffer(PyObject *_unused, PyObject *args, PyObj END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_fromFile(PyObject *_unused, PyObject *args, PyObject *keywds) -{ +static PyObject* THPStorage_fromFile( + PyObject* _unused, + PyObject* args, + PyObject* keywds) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - const char *filename; + const char* filename; Py_ssize_t nbytes = 0; int shared = 0; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings) - static char *kwlist[] = {"filename", "shared", "nbytes", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, keywds, "s|in", kwlist, - &filename, &shared, &nbytes)) { + static char* kwlist[] = {"filename", "shared", "nbytes", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, keywds, "s|in", kwlist, &filename, &shared, &nbytes)) { return nullptr; } if (shared) @@ -317,12 +340,11 @@ static PyObject * THPStorage_fromFile(PyObject *_unused, PyObject *args, PyObjec size_t actual_nbytes = -1; auto storage = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - nbytes, - at::MapAllocator::makeDataPtr( - filename, shared, nbytes, &actual_nbytes), - /*allocator=*/nullptr, - /*resizable=*/false); + c10::StorageImpl::use_byte_size_t(), + nbytes, + at::MapAllocator::makeDataPtr(filename, shared, nbytes, &actual_nbytes), + /*allocator=*/nullptr, + /*resizable=*/false); if (nbytes <= 0) { storage->set_nbytes(actual_nbytes); @@ -332,43 +354,47 @@ static PyObject * THPStorage_fromFile(PyObject *_unused, PyObject *args, PyObjec END_HANDLE_TH_ERRORS } -PyObject * THPStorage_writeFile(PyObject *_self, PyObject *args) -{ +PyObject* THPStorage_writeFile(PyObject* _self, PyObject* args) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - PyObject *file = PyTuple_GetItem(args, 0); + PyObject* file = PyTuple_GetItem(args, 0); bool is_real_file = PyTuple_GetItem(args, 1) == Py_True; bool save_size = PyTuple_GetItem(args, 2) == Py_True; - PyObject *element_size_obj = PyTuple_GET_ITEM(args, 3); + PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3); - THPUtils_assert(element_size_obj != Py_None, - "_write_file: need to specify element size"); + THPUtils_assert( + element_size_obj != Py_None, "_write_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); if (!is_real_file) { - THPStorage_writeFileRaw(self->cdata, file, save_size, element_size); + THPStorage_writeFileRaw( + self->cdata, file, save_size, element_size); Py_RETURN_NONE; } int fd = PyObject_AsFileDescriptor(file); - THPUtils_assert(fd != -1, "_write_file couldn't retrieve a file descriptor " + THPUtils_assert( + fd != -1, + "_write_file couldn't retrieve a file descriptor " "from given object"); THPStorage_writeFileRaw(self->cdata, fd, save_size, element_size); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THPStorage_newWithFile(PyObject *_unused, PyObject *args) -{ +PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS - TORCH_CHECK(PyTuple_Size(args) == 2, - "_new_with_file takes exactly two arguments"); + TORCH_CHECK( + PyTuple_Size(args) == 2, "_new_with_file takes exactly two arguments"); int fd = PyObject_AsFileDescriptor(PyTuple_GetItem(args, 0)); - THPUtils_assert(fd != -1, "_new_with_file couldn't retrieve a file " + THPUtils_assert( + fd != -1, + "_new_with_file couldn't retrieve a file " "descriptor from given object"); - PyObject *element_size_obj = PyTuple_GET_ITEM(args, 1); - THPUtils_assert(element_size_obj != Py_None, - "_new_with_file: need to specify element size"); + PyObject* element_size_obj = PyTuple_GET_ITEM(args, 1); + THPUtils_assert( + element_size_obj != Py_None, + "_new_with_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); auto storage = THPStorage_readFileRaw(fd, {}, element_size); @@ -378,33 +404,36 @@ PyObject * THPStorage_newWithFile(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject *THPStorage_setFromFile(PyObject *_self, PyObject *args) -{ +static PyObject* THPStorage_setFromFile(PyObject* _self, PyObject* args) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - PyObject *file = PyTuple_GET_ITEM(args, 0); - PyObject *offset = PyTuple_GET_ITEM(args, 1); + PyObject* file = PyTuple_GET_ITEM(args, 0); + PyObject* offset = PyTuple_GET_ITEM(args, 1); bool is_real_file = PyTuple_GET_ITEM(args, 2) == Py_True; - PyObject *element_size_obj = PyTuple_GET_ITEM(args, 3); + PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3); - THPUtils_assert(element_size_obj != Py_None, - "_set_from_file: need to specify element size"); + THPUtils_assert( + element_size_obj != Py_None, + "_set_from_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); if (!is_real_file) { // offset can be implemented with a call to the Python object's seek() // but it is currently unnecessary to support this. - THPUtils_assert(offset == Py_None, - "_set_from_file: offset is NYI for filelike objects"); - - auto self_storage = c10::intrusive_ptr::reclaim_copy(self->cdata); - auto storage = THPStorage_readFileRaw(file, std::move(self_storage), element_size); + THPUtils_assert( + offset == Py_None, + "_set_from_file: offset is NYI for filelike objects"); + + auto self_storage = + c10::intrusive_ptr::reclaim_copy(self->cdata); + auto storage = THPStorage_readFileRaw( + file, std::move(self_storage), element_size); if (!storage.defined()) { return nullptr; } Py_INCREF(self); - return (PyObject *) self; + return (PyObject*)self; } // file is backed by a fd @@ -413,9 +442,12 @@ static PyObject *THPStorage_setFromFile(PyObject *_self, PyObject *args) if (offset != Py_None) { LSEEK(fd, THPUtils_unpackLong(offset), SEEK_SET); } - THPUtils_assert(fd != -1, "_set_from_file couldn't retrieve a file " + THPUtils_assert( + fd != -1, + "_set_from_file couldn't retrieve a file " "descriptor from given object"); - auto self_storage = c10::intrusive_ptr::reclaim_copy(self->cdata); + auto self_storage = + c10::intrusive_ptr::reclaim_copy(self->cdata); auto storage = THPStorage_readFileRaw(fd, self_storage, element_size); if (!storage.defined()) return nullptr; @@ -426,24 +458,26 @@ static PyObject *THPStorage_setFromFile(PyObject *_self, PyObject *args) // advanced position const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR); LSEEK(fd, fd_original_pos, SEEK_SET); - const auto seek_return = PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0); + const auto seek_return = + PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0); if (seek_return == nullptr) { - return nullptr; + return nullptr; } Py_DECREF(seek_return); - return (PyObject *) self; + return (PyObject*)self; END_HANDLE_TH_ERRORS } -PyObject * THPStorage__setCdata(PyObject *_self, PyObject *new_cdata) -{ +PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - THPUtils_assert(THPUtils_checkLong(new_cdata), "given an invalid argument to " + THPUtils_assert( + THPUtils_checkLong(new_cdata), + "given an invalid argument to " "_set_cdata - expected an int or long, but got %s", THPUtils_typename(new_cdata)); - c10::StorageImpl *ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(new_cdata); + c10::StorageImpl* ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(new_cdata); if (ptr) { c10::raw::intrusive_ptr::incref(ptr); } @@ -458,25 +492,33 @@ PyObject * THPStorage__setCdata(PyObject *_self, PyObject *new_cdata) // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static PyMethodDef THPStorage_methods[] = { - {"copy_", castPyCFunctionWithKeywords(THPStorage_copy_), - METH_VARARGS | METH_KEYWORDS, nullptr}, - {"element_size", THPStorage_elementSize, METH_NOARGS, nullptr}, - {"fill_", THPStorage_fill_, METH_O, nullptr}, - {"new", THPStorage_new, METH_NOARGS, nullptr}, - {"resize_", THPStorage_resize_, METH_O, nullptr}, - {"nbytes", THPStorage_nbytes, METH_NOARGS, nullptr}, - {"data_ptr", THPStorage_dataPtr, METH_NOARGS, nullptr}, - {"is_pinned", THPStorage_isPinned, METH_NOARGS, nullptr}, - {"_write_file", THPStorage_writeFile, METH_VARARGS, nullptr}, - {"_new_with_file", THPStorage_newWithFile, METH_VARARGS | METH_STATIC, nullptr}, - {"_set_from_file", THPStorage_setFromFile, METH_VARARGS, nullptr}, - {"from_buffer", castPyCFunctionWithKeywords(THPStorage_fromBuffer), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"from_file", castPyCFunctionWithKeywords(THPStorage_fromFile), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_set_cdata", THPStorage__setCdata, METH_O, nullptr}, - {nullptr} -}; + {"copy_", + castPyCFunctionWithKeywords(THPStorage_copy_), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"element_size", THPStorage_elementSize, METH_NOARGS, nullptr}, + {"fill_", THPStorage_fill_, METH_O, nullptr}, + {"new", THPStorage_new, METH_NOARGS, nullptr}, + {"resize_", THPStorage_resize_, METH_O, nullptr}, + {"nbytes", THPStorage_nbytes, METH_NOARGS, nullptr}, + {"data_ptr", THPStorage_dataPtr, METH_NOARGS, nullptr}, + {"is_pinned", THPStorage_isPinned, METH_NOARGS, nullptr}, + {"_write_file", THPStorage_writeFile, METH_VARARGS, nullptr}, + {"_new_with_file", + THPStorage_newWithFile, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_set_from_file", THPStorage_setFromFile, METH_VARARGS, nullptr}, + {"from_buffer", + castPyCFunctionWithKeywords(THPStorage_fromBuffer), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"from_file", + castPyCFunctionWithKeywords(THPStorage_fromFile), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_set_cdata", THPStorage__setCdata, METH_O, nullptr}, + {nullptr}}; PyMethodDef* THPStorage_getMethods() { return THPStorage_methods; diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index 0353e42f040d23..c35747a0ae6c6c 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -4,57 +4,57 @@ #endif #include +#include #include -#include -#include -#include #include #include +#include +#include #include -#include +#include -#include #include +#include -#include #include +#include #ifdef USE_CUDA +#include #include #include -#include #endif -#include #include +#include #include #include -static PyObject * THPStorage_sharedDecref(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_sharedDecref(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; c10::DeviceType device_type = self->cdata->device_type(); if (device_type == at::kCPU) { - c10::StorageImpl *storage = self->cdata; - THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); + c10::StorageImpl* storage = self->cdata; + THManagedMapAllocator* ctx = + THManagedMapAllocator::fromDataPtr(storage->data_ptr()); if (ctx) { ctx->decref(); } } Py_INCREF(self); - return (PyObject *)self; + return (PyObject*)self; END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_sharedIncref(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_sharedIncref(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; c10::DeviceType device_type = self->cdata->device_type(); if (device_type == at::kCPU) { - c10::StorageImpl *storage = self->cdata; - THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()); + c10::StorageImpl* storage = self->cdata; + THManagedMapAllocator* ctx = + THManagedMapAllocator::fromDataPtr(storage->data_ptr()); if (ctx) { ctx->incref(); } @@ -63,8 +63,9 @@ static PyObject * THPStorage_sharedIncref(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_pyNewFilenameStorage(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_pyNewFilenameStorage( + PyObject* _unused, + PyObject* args) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) long long size; @@ -75,23 +76,23 @@ static PyObject * THPStorage_pyNewFilenameStorage(PyObject *_unused, PyObject *a int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE; std::string handle = at::NewProcessWideShmHandle(); return THPStorage_New(c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size, - THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, size), - /*allocator=*/nullptr, - /*resizable=*/false)); + c10::StorageImpl::use_byte_size_t(), + size, + THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, size), + /*allocator=*/nullptr, + /*resizable=*/false)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_shareFilename(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_shareFilename(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS - TORCH_CHECK(reinterpret_cast(_self)->cdata->device_type() == at::kCPU, + TORCH_CHECK( + reinterpret_cast(_self)->cdata->device_type() == at::kCPU, "_share_filename_: only available on CPU"); auto self = (THPStorage*)_self; - c10::StorageImpl *storage = self->cdata; + c10::StorageImpl* storage = self->cdata; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - THManagedMapAllocator *ctx; + THManagedMapAllocator* ctx; // Storage is already in shared memory, just return a handle if ((ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()))) { // done @@ -101,11 +102,12 @@ static PyObject * THPStorage_shareFilename(PyObject *_self, PyObject *noargs) int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE; std::string handle = at::NewProcessWideShmHandle(); at::Storage new_storage(c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - storage->nbytes(), - THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, storage->nbytes()), - /*allocator=*/nullptr, - /*resizable=*/false)); + c10::StorageImpl::use_byte_size_t(), + storage->nbytes(), + THManagedMapAllocator::makeDataPtr( + "", handle.c_str(), flags, storage->nbytes()), + /*allocator=*/nullptr, + /*resizable=*/false)); at::Storage _self_aten = torch::createStorage(_self); storage_copy(new_storage, _self_aten); @@ -116,14 +118,18 @@ static PyObject * THPStorage_shareFilename(PyObject *_self, PyObject *noargs) } THPObjectPtr manager_handle(PyBytes_FromString(ctx->manager_handle())); - if (!manager_handle) return nullptr; + if (!manager_handle) + return nullptr; THPObjectPtr storage_handle(PyBytes_FromString(ctx->filename())); - if (!storage_handle) return nullptr; + if (!storage_handle) + return nullptr; THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(uint8_t))); - if (!size) return nullptr; + if (!size) + return nullptr; THPObjectPtr tuple(PyTuple_New(3)); - if (!tuple) return nullptr; + if (!tuple) + return nullptr; PyTuple_SET_ITEM(tuple.get(), 0, manager_handle.release()); PyTuple_SET_ITEM(tuple.get(), 1, storage_handle.release()); PyTuple_SET_ITEM(tuple.get(), 2, size.release()); @@ -131,51 +137,54 @@ static PyObject * THPStorage_shareFilename(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_newSharedFilename(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_newSharedFilename( + PyObject* _unused, + PyObject* args) { HANDLE_TH_ERRORS THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected"); - PyObject *_manager_handle = PyTuple_GET_ITEM(args, 0); - PyObject *_object_handle = PyTuple_GET_ITEM(args, 1); - PyObject *_size = PyTuple_GET_ITEM(args, 2); - if (!PyBytes_Check(_manager_handle) || !PyBytes_Check(_object_handle) || !THPUtils_checkLong(_size)) { - THPUtils_invalidArguments(args, nullptr, "_new_shared in file system mode", 1, + PyObject* _manager_handle = PyTuple_GET_ITEM(args, 0); + PyObject* _object_handle = PyTuple_GET_ITEM(args, 1); + PyObject* _size = PyTuple_GET_ITEM(args, 2); + if (!PyBytes_Check(_manager_handle) || !PyBytes_Check(_object_handle) || + !THPUtils_checkLong(_size)) { + THPUtils_invalidArguments( + args, + nullptr, + "_new_shared in file system mode", + 1, "a handle (string/bytes) and storage size (int)"); return nullptr; } - const char *manager_handle = PyBytes_AS_STRING(_manager_handle); - const char *object_handle = PyBytes_AS_STRING(_object_handle); + const char* manager_handle = PyBytes_AS_STRING(_manager_handle); + const char* object_handle = PyBytes_AS_STRING(_object_handle); int64_t size = THPUtils_unpackLong(_size); - int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | - at::ALLOCATOR_MAPPED_NOCREATE; - return THPStorage_New( - c10::make_intrusive( + int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE; + return THPStorage_New(c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size, - THManagedMapAllocator::makeDataPtr(manager_handle, object_handle, flags, size), + THManagedMapAllocator::makeDataPtr( + manager_handle, object_handle, flags, size), /*allocator=*/nullptr, /*resizable=*/false)); END_HANDLE_TH_ERRORS } -static c10::intrusive_ptr THPStorage_newFdStorage(ptrdiff_t size) -{ - int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | - at::ALLOCATOR_MAPPED_EXCLUSIVE | - at::ALLOCATOR_MAPPED_KEEPFD | - at::ALLOCATOR_MAPPED_UNLINK; +static c10::intrusive_ptr THPStorage_newFdStorage( + ptrdiff_t size) { + int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE | + at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_UNLINK; std::string handle = at::NewProcessWideShmHandle(); - auto sptr = at::MapAllocator::makeDataPtr(handle.c_str(), flags, size * sizeof(uint8_t), nullptr); + auto sptr = at::MapAllocator::makeDataPtr( + handle.c_str(), flags, size * sizeof(uint8_t), nullptr); return c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size, - std::move(sptr), - /*allocator=*/nullptr, - /*resizable=*/false); + c10::StorageImpl::use_byte_size_t(), + size, + std::move(sptr), + /*allocator=*/nullptr, + /*resizable=*/false); } -static PyObject * THPStorage_pyNewFdStorage(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) long long size; @@ -186,15 +195,15 @@ static PyObject * THPStorage_pyNewFdStorage(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_shareFd(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_shareFd(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS - TORCH_CHECK(reinterpret_cast(_self)->cdata->device_type() == at::kCPU, + TORCH_CHECK( + reinterpret_cast(_self)->cdata->device_type() == at::kCPU, "_share_fd_: only available on CPU"); auto self = (THPStorage*)_self; - c10::StorageImpl *storage = self->cdata; + c10::StorageImpl* storage = self->cdata; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - at::MapAllocator *ctx; + at::MapAllocator* ctx; // Storage is already in shared memory, just return a handle if ((ctx = at::MapAllocator::fromDataPtr(storage->data_ptr()))) { // done @@ -209,44 +218,47 @@ static PyObject * THPStorage_shareFd(PyObject *_self, PyObject *noargs) } THPObjectPtr storage_handle(THPUtils_packInt32(ctx->fd())); - if (!storage_handle) return nullptr; + if (!storage_handle) + return nullptr; THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(uint8_t))); - if (!size) return nullptr; + if (!size) + return nullptr; THPObjectPtr tuple(PyTuple_New(2)); - if (!tuple) return nullptr; + if (!tuple) + return nullptr; PyTuple_SET_ITEM(tuple.get(), 0, storage_handle.release()); PyTuple_SET_ITEM(tuple.get(), 1, size.release()); return tuple.release(); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_newSharedFd(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected"); - PyObject *_tmp_fd = PyTuple_GET_ITEM(args, 0); - PyObject *_size = PyTuple_GET_ITEM(args, 1); + PyObject* _tmp_fd = PyTuple_GET_ITEM(args, 0); + PyObject* _size = PyTuple_GET_ITEM(args, 1); if (!THPUtils_checkLong(_tmp_fd) || !THPUtils_checkLong(_size)) { - THPUtils_invalidArguments(args, nullptr, "_new_shared in file descriptor mode", - 1, "a file descriptor (int) and storage size (int)"); + THPUtils_invalidArguments( + args, + nullptr, + "_new_shared in file descriptor mode", + 1, + "a file descriptor (int) and storage size (int)"); return nullptr; } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int fd; - int tmp_fd = (int) THPUtils_unpackLong(_tmp_fd); + int tmp_fd = (int)THPUtils_unpackLong(_tmp_fd); int64_t size = THPUtils_unpackLong(_size); if ((fd = dup(tmp_fd)) == -1) { THPUtils_setError("could not duplicate a shared memory file descriptor"); return nullptr; } - int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | - at::ALLOCATOR_MAPPED_NOCREATE | - at::ALLOCATOR_MAPPED_KEEPFD | - at::ALLOCATOR_MAPPED_FROMFD; - return THPStorage_New( - c10::make_intrusive( + int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE | + at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD; + return THPStorage_New(c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size, at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr), @@ -255,14 +267,14 @@ static PyObject * THPStorage_newSharedFd(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_shareCuda(PyObject *_self, PyObject *noargs) -{ +static PyObject* THPStorage_shareCuda(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS #ifdef USE_CUDA - TORCH_CHECK(reinterpret_cast(_self)->cdata->device_type() == at::kCUDA, + TORCH_CHECK( + reinterpret_cast(_self)->cdata->device_type() == at::kCUDA, "_share_cuda_: only available on CUDA"); auto self = (THPStorage*)_self; - c10::StorageImpl *storage = self->cdata; + c10::StorageImpl* storage = self->cdata; if (storage->received_cuda()) { AT_ERROR( @@ -286,21 +298,24 @@ static PyObject * THPStorage_shareCuda(PyObject *_self, PyObject *noargs) if (storage->data()) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) size_t base_size; - void *base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(storage->data(), &base_size); + void* base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation( + storage->data(), &base_size); ptrdiff_t offset_bytes = (char*)storage->data() - (char*)base_ptr; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cudaIpcMemHandle_t handle; C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_ptr)); - _handle = PyBytes_FromStringAndSize((char *)&handle, CUDA_IPC_HANDLE_SIZE); + _handle = PyBytes_FromStringAndSize((char*)&handle, CUDA_IPC_HANDLE_SIZE); _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes); // Put Storage Data behind new ref counting context // See Note [CUDA IPC Refcounting implementation explained] - at::DataPtr sent_data_ptr = torch::GetNewRefCountedSentData(storage->data(), storage->device()); + at::DataPtr sent_data_ptr = + torch::GetNewRefCountedSentData(storage->data(), storage->device()); auto old_data_ptr = storage->set_data_ptr(std::move(sent_data_ptr)); - auto sent_data = static_cast(storage->data_ptr().get_context()); + auto sent_data = + static_cast(storage->data_ptr().get_context()); sent_data->set_original_ptr(std::move(old_data_ptr)); _ref_counter = PyBytes_FromString((sent_data->handle()).c_str()); _ref_counter_offset = THPUtils_packInt64(sent_data->offset()); @@ -309,24 +324,28 @@ static PyObject * THPStorage_shareCuda(PyObject *_self, PyObject *noargs) cudaIpcEventHandle_t ipc_event_handle; if (sent_data->event_sync_required_) { - C10_CUDA_CHECK(cudaIpcGetEventHandle(&ipc_event_handle, sent_data->event_)); + C10_CUDA_CHECK( + cudaIpcGetEventHandle(&ipc_event_handle, sent_data->event_)); } - _event_handle = PyBytes_FromStringAndSize((char *)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE); + _event_handle = PyBytes_FromStringAndSize( + (char*)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE); _event_sync_required = PyBool_FromLong(sent_data->event_sync_required_); - } - if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes || !_event_handle) { + if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes || + !_event_handle) { return nullptr; } PyTuple_SET_ITEM(tuple.get(), 0, device.release()); // cudaIpcMemHandle_t(of basePtr) PyTuple_SET_ITEM(tuple.get(), 1, _handle.release()); - // Size(in bytes) of the real storage, note this is not the size of basePtr memory block. + // Size(in bytes) of the real storage, note this is not the size of basePtr + // memory block. PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release()); // Offset(in bytes) of the real storage in the basePtr memory block. - // NB: this offset MUST be in bytes instead of numel, since we use (storage_handle, offset) + // NB: this offset MUST be in bytes instead of numel, since we use + // (storage_handle, offset) // as key in shared_cache(multiprocessing/reduction.py). // Offset in numel cannot uniquely represent a storage. PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release()); @@ -341,13 +360,14 @@ static PyObject * THPStorage_shareCuda(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_releaseIPCCounter(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_releaseIPCCounter( + PyObject* _unused, + PyObject* args) { HANDLE_TH_ERRORS #ifdef USE_CUDA THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected"); - PyObject *_ref_counter = PyTuple_GET_ITEM(args, 0); - PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 1); + PyObject* _ref_counter = PyTuple_GET_ITEM(args, 0); + PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 1); if (!(PyBytes_Check(_ref_counter) && THPUtils_checkLong(_ref_counter_offset))) { THPUtils_invalidArguments( @@ -364,8 +384,7 @@ static PyObject * THPStorage_releaseIPCCounter(PyObject *_unused, PyObject *args // We don't want to break existing code, so resource deletion is best // effort basis. Exception expected if producer process terminated // before consumer released data. - int flags = - at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE; + int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE; try { auto sptr = at::RefcountedMapAllocator::makeDataPtr( ref_counter_handle.c_str(), @@ -384,7 +403,7 @@ static PyObject * THPStorage_releaseIPCCounter(PyObject *_unused, PyObject *args } #ifdef USE_CUDA -static std::string THPStorage_bytesAsHandleString(PyObject *handle) { +static std::string THPStorage_bytesAsHandleString(PyObject* handle) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) char* buffer; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -394,29 +413,28 @@ static std::string THPStorage_bytesAsHandleString(PyObject *handle) { return nullptr; } // NOLINTNEXTLINE(bugprone-string-constructor) - THPUtils_assert( - handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size"); + THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size"); return std::string(buffer, handle_size); } #endif -static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) -{ +static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS #ifdef USE_CUDA THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected"); - PyObject *_device = PyTuple_GET_ITEM(args, 0); - PyObject *_handle = PyTuple_GET_ITEM(args, 1); - PyObject *_size_bytes = PyTuple_GET_ITEM(args, 2); - PyObject *_offset_bytes = PyTuple_GET_ITEM(args, 3); - PyObject *_ref_counter = PyTuple_GET_ITEM(args, 4); - PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 5); - PyObject *_event_handle = PyTuple_GET_ITEM(args, 6); - PyObject *_event_sync_required = PyTuple_GET_ITEM(args, 7); + PyObject* _device = PyTuple_GET_ITEM(args, 0); + PyObject* _handle = PyTuple_GET_ITEM(args, 1); + PyObject* _size_bytes = PyTuple_GET_ITEM(args, 2); + PyObject* _offset_bytes = PyTuple_GET_ITEM(args, 3); + PyObject* _ref_counter = PyTuple_GET_ITEM(args, 4); + PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 5); + PyObject* _event_handle = PyTuple_GET_ITEM(args, 6); + PyObject* _event_sync_required = PyTuple_GET_ITEM(args, 7); if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes) && PyBytes_Check(_handle) && PyBytes_Check(_ref_counter) && PyBytes_Check(_event_handle) && THPUtils_checkLong(_offset_bytes) && - THPUtils_checkLong(_ref_counter_offset) && PyBool_Check(_event_sync_required))) { + THPUtils_checkLong(_ref_counter_offset) && + PyBool_Check(_event_sync_required))) { THPUtils_invalidArguments( args, nullptr, @@ -426,8 +444,10 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) return nullptr; } - size_t storage_size = (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t); - ptrdiff_t storage_offset_bytes = (ptrdiff_t)THPUtils_unpackLong(_offset_bytes); + size_t storage_size = + (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t); + ptrdiff_t storage_offset_bytes = + (ptrdiff_t)THPUtils_unpackLong(_offset_bytes); int64_t device = THPUtils_unpackLong(_device); at::cuda::CUDAGuard device_guard(device); @@ -446,7 +466,8 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) } std::string s_handle = THPStorage_bytesAsHandleString(_handle); - std::shared_ptr basePtr = c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle); + std::shared_ptr basePtr = + c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle); // Offset the basePtr to reconstruct the real storage // devPtr = basePtr + storage_offset @@ -455,7 +476,8 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) devPtr = (char*)devPtr + storage_offset_bytes; std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter); - ptrdiff_t ref_counter_offset = (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset); + ptrdiff_t ref_counter_offset = + (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset); struct IpcDeleterContext { std::string ref_counter_handle; @@ -474,23 +496,26 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) c10::DataPtr data_ptr( devPtr, ctx.release(), - +[](void *ctx_) { - std::unique_ptr ctx(static_cast(ctx_)); + +[](void* ctx_) { + std::unique_ptr ctx( + static_cast(ctx_)); ctx->received_data.shared_ptr_.reset(); - // Sync default stream to make sure all operations related to the storage is - // finished (otherwise another process may reuse memory and corrupt - // data) + // Sync default stream to make sure all operations related to the + // storage is finished (otherwise another process may reuse memory and + // corrupt data) // Ideally all shared memory reference counting could be replaced by // sending untriggered CUDA event from the producer to consumer and - // using this event as the criteria of memory release. However, CUDA (atm 10.1) - // does not support the creation of untriggered events and performance - // impact of having thousands of shared events is unknown. + // using this event as the criteria of memory release. However, CUDA + // (atm 10.1) does not support the creation of untriggered events and + // performance impact of having thousands of shared events is unknown. // TODO: Instead of cudaStreamSynchronize it is possible to add Stream - // Callback and release counter inside of it (need to check performance impact) - at::cuda::stream_synchronize(c10::cuda::getCurrentCUDAStream(ctx->device)); + // Callback and release counter inside of it (need to check performance + // impact) + at::cuda::stream_synchronize( + c10::cuda::getCurrentCUDAStream(ctx->device)); // We don't want to break existing code, so resource deletion is best // effort basis. Exception expected if producer process terminated @@ -511,11 +536,11 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) at::Device(at::DeviceType::CUDA, cur_device)); auto base = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - storage_size, - std::move(data_ptr), - /*allocator=*/nullptr, - /*resizable=*/false); + c10::StorageImpl::use_byte_size_t(), + storage_size, + std::move(data_ptr), + /*allocator=*/nullptr, + /*resizable=*/false); base->set_resizable(false); base->set_received_cuda(true); @@ -532,7 +557,7 @@ static PyObject * THPStorage_newSharedCuda(PyObject *_unused, PyObject *args) // pointer. // // NB: This does NOT preserve object identity when you call it multiple times -static PyObject * THPStorage_weakRef(PyObject *_self, PyObject *args) { +static PyObject* THPStorage_weakRef(PyObject* _self, PyObject* args) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; c10::StorageImpl* storage = self->cdata; @@ -540,50 +565,48 @@ static PyObject * THPStorage_weakRef(PyObject *_self, PyObject *args) { END_HANDLE_TH_ERRORS } -PyObject * THPStorage_newWithWeakPtr(PyObject *_unused, PyObject *arg) -{ +PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(arg), - "_new_with_weak_ptr(): arg must be an 'int'"); - c10::StorageImpl *weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); + THPUtils_assert( + THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'"); + c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) { - return THPStorage_New(c10::intrusive_ptr::reclaim(storage)); + return THPStorage_New( + c10::intrusive_ptr::reclaim(storage)); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THPStorage_freeWeakRef(PyObject *_unused, PyObject *arg) -{ +PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (arg == Py_None) { Py_RETURN_NONE; } - THPUtils_assert(THPUtils_checkLong(arg), - "_free_weak_ref(): arg must be an 'int'"); - c10::StorageImpl *weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); + THPUtils_assert( + THPUtils_checkLong(arg), "_free_weak_ref(): arg must be an 'int'"); + c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); c10::raw::weak_intrusive_ptr::decref(weak_storage); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THPStorage_expired(PyObject *_unused, PyObject *arg) -{ +PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "_expired(): arg must be an 'int'"); - c10::StorageImpl *weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); - return PyBool_FromLong(c10::raw::weak_intrusive_ptr::use_count(weak_storage) == 0); + c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg); + return PyBool_FromLong( + c10::raw::weak_intrusive_ptr::use_count(weak_storage) == 0); END_HANDLE_TH_ERRORS } -PyObject * THPStorage_sharedFd(PyObject *_self, PyObject *noargs) -{ +PyObject* THPStorage_sharedFd(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - at::MapAllocator *ctx = nullptr; + at::MapAllocator* ctx = nullptr; if (self->cdata->device_type() == at::kCPU) { - c10::StorageImpl *storage = self->cdata; + c10::StorageImpl* storage = self->cdata; ctx = at::MapAllocator::fromDataPtr(storage->data_ptr()); } @@ -592,8 +615,7 @@ PyObject * THPStorage_sharedFd(PyObject *_self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject * THPStorage_isShared(PyObject *_self, PyObject *noargs) -{ +PyObject* THPStorage_isShared(PyObject* _self, PyObject* noargs) { auto self = (THPStorage*)_self; if (self->cdata->device_type() == at::kCUDA) { Py_RETURN_TRUE; @@ -608,25 +630,45 @@ PyObject * THPStorage_isShared(PyObject *_self, PyObject *noargs) // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static PyMethodDef THPStorage_sharingMethods[] = { - {"_new_with_weak_ptr", THPStorage_newWithWeakPtr, METH_O | METH_CLASS, nullptr}, - {"_share_cuda_", THPStorage_shareCuda, METH_NOARGS, nullptr}, - {"_new_shared_cuda", THPStorage_newSharedCuda, METH_VARARGS | METH_STATIC, nullptr}, - {"_release_ipc_counter_cuda", THPStorage_releaseIPCCounter, METH_VARARGS | METH_STATIC, nullptr}, - {"_share_fd_cpu_", THPStorage_shareFd, METH_NOARGS, nullptr}, - {"_new_shared_fd_cpu", THPStorage_newSharedFd, METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_fd_cpu", THPStorage_pyNewFdStorage, METH_VARARGS | METH_STATIC, nullptr}, - {"_share_filename_cpu_", THPStorage_shareFilename, METH_NOARGS, nullptr}, - {"_new_shared_filename_cpu", THPStorage_newSharedFilename, METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_filename_cpu", THPStorage_pyNewFilenameStorage, METH_VARARGS | METH_STATIC, nullptr}, - {"_weak_ref", THPStorage_weakRef, METH_NOARGS, nullptr}, - {"_free_weak_ref", THPStorage_freeWeakRef, METH_O | METH_STATIC, nullptr}, - {"_expired", THPStorage_expired, METH_O | METH_STATIC, nullptr}, - {"_shared_decref", THPStorage_sharedDecref, METH_NOARGS, nullptr}, - {"_shared_incref", THPStorage_sharedIncref, METH_NOARGS, nullptr}, - {"_get_shared_fd", THPStorage_sharedFd, METH_NOARGS, nullptr}, - {"is_shared", THPStorage_isShared, METH_NOARGS, nullptr}, - {nullptr} -}; + {"_new_with_weak_ptr", + THPStorage_newWithWeakPtr, + METH_O | METH_CLASS, + nullptr}, + {"_share_cuda_", THPStorage_shareCuda, METH_NOARGS, nullptr}, + {"_new_shared_cuda", + THPStorage_newSharedCuda, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_release_ipc_counter_cuda", + THPStorage_releaseIPCCounter, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_share_fd_cpu_", THPStorage_shareFd, METH_NOARGS, nullptr}, + {"_new_shared_fd_cpu", + THPStorage_newSharedFd, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_new_using_fd_cpu", + THPStorage_pyNewFdStorage, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_share_filename_cpu_", THPStorage_shareFilename, METH_NOARGS, nullptr}, + {"_new_shared_filename_cpu", + THPStorage_newSharedFilename, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_new_using_filename_cpu", + THPStorage_pyNewFilenameStorage, + METH_VARARGS | METH_STATIC, + nullptr}, + {"_weak_ref", THPStorage_weakRef, METH_NOARGS, nullptr}, + {"_free_weak_ref", THPStorage_freeWeakRef, METH_O | METH_STATIC, nullptr}, + {"_expired", THPStorage_expired, METH_O | METH_STATIC, nullptr}, + {"_shared_decref", THPStorage_sharedDecref, METH_NOARGS, nullptr}, + {"_shared_incref", THPStorage_sharedIncref, METH_NOARGS, nullptr}, + {"_get_shared_fd", THPStorage_sharedFd, METH_NOARGS, nullptr}, + {"is_shared", THPStorage_isShared, METH_NOARGS, nullptr}, + {nullptr}}; PyMethodDef* THPStorage_getSharingMethods() { return THPStorage_sharingMethods; diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp index d21e2124ea5ac6..7d0cd473af94d8 100644 --- a/torch/csrc/Stream.cpp +++ b/torch/csrc/Stream.cpp @@ -5,16 +5,17 @@ #include -PyTypeObject *THPStreamClass = nullptr; +PyTypeObject* THPStreamClass = nullptr; static PyObject* THPStream_pynew( - PyTypeObject *type, PyObject *args, PyObject *kwargs) { + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS uint64_t cdata = 0; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-diagnostic-writable-strings) - static char *kwlist[] = {"_cdata", nullptr}; - if (!PyArg_ParseTupleAndKeywords( - args, kwargs, "|K", kwlist, &cdata)) { + static char* kwlist[] = {"_cdata", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|K", kwlist, &cdata)) { return nullptr; } @@ -23,23 +24,23 @@ static PyObject* THPStream_pynew( return nullptr; } - THPStream* self = (THPStream *)ptr.get(); + THPStream* self = (THPStream*)ptr.get(); self->cdata = cdata; - return (PyObject *)ptr.release(); + return (PyObject*)ptr.release(); END_HANDLE_TH_ERRORS } -static void THPStream_dealloc(THPStream *self) { +static void THPStream_dealloc(THPStream* self) { Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THPStream_get_device(THPStream *self, void *unused) { +static PyObject* THPStream_get_device(THPStream* self, void* unused) { HANDLE_TH_ERRORS return THPDevice_New(c10::Stream::unpack(self->cdata).device()); END_HANDLE_TH_ERRORS } -static PyObject * THPStream_eq(THPStream *self, THPStream *other) { +static PyObject* THPStream_eq(THPStream* self, THPStream* other) { HANDLE_TH_ERRORS return PyBool_FromLong(self->cdata == other->cdata); END_HANDLE_TH_ERRORS @@ -47,75 +48,71 @@ static PyObject * THPStream_eq(THPStream *self, THPStream *other) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMemberDef THPStream_members[] = { - {(char*)"_cdata", - T_ULONGLONG, offsetof(THPStream, cdata), READONLY, nullptr}, - {nullptr} -}; + {(char*)"_cdata", + T_ULONGLONG, + offsetof(THPStream, cdata), + READONLY, + nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPStream_properties[] = { - {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"device", (getter)THPStream_get_device, nullptr, nullptr, nullptr}, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static PyMethodDef THPStream_methods[] = { - {(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr}, - {nullptr} -}; + {(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr}, + {nullptr}}; PyTypeObject THPStreamType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.Stream", /* tp_name */ - sizeof(THPStream), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPStream_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPStream_methods, /* tp_methods */ - THPStream_members, /* tp_members */ - THPStream_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPStream_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.Stream", /* tp_name */ + sizeof(THPStream), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THPStream_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPStream_methods, /* tp_methods */ + THPStream_members, /* tp_members */ + THPStream_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPStream_pynew, /* tp_new */ }; - -void THPStream_init(PyObject *module) -{ +void THPStream_init(PyObject* module) { THPStreamClass = &THPStreamType; Py_TYPE(&THPStreamType) = &PyType_Type; if (PyType_Ready(&THPStreamType) < 0) { throw python_error(); } Py_INCREF(&THPStreamType); - if (PyModule_AddObject( - module, "Stream", (PyObject *)&THPStreamType) < 0) { + if (PyModule_AddObject(module, "Stream", (PyObject*)&THPStreamType) < 0) { throw python_error(); } } diff --git a/torch/csrc/Stream.h b/torch/csrc/Stream.h index 8d507977e12aa0..6a61e9ca085c97 100644 --- a/torch/csrc/Stream.h +++ b/torch/csrc/Stream.h @@ -4,12 +4,11 @@ #include struct THPStream { - PyObject_HEAD - uint64_t cdata; + PyObject_HEAD uint64_t cdata; }; -extern PyTypeObject *THPStreamClass; +extern PyTypeObject* THPStreamClass; -void THPStream_init(PyObject *module); +void THPStream_init(PyObject* module); inline bool THPStream_Check(PyObject* obj) { return THPStreamClass && PyObject_IsInstance(obj, (PyObject*)THPStreamClass); diff --git a/torch/csrc/THConcat.h b/torch/csrc/THConcat.h index cfc82b45353008..23512f1bce4248 100644 --- a/torch/csrc/THConcat.h +++ b/torch/csrc/THConcat.h @@ -1,19 +1,19 @@ #pragma once -#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) -#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y +#define TH_CONCAT_STRING_2(x, y) TH_CONCAT_STRING_2_EXPAND(x, y) +#define TH_CONCAT_STRING_2_EXPAND(x, y) #x #y -#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) -#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z +#define TH_CONCAT_STRING_3(x, y, z) TH_CONCAT_STRING_3_EXPAND(x, y, z) +#define TH_CONCAT_STRING_3_EXPAND(x, y, z) #x #y #z -#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) -#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w +#define TH_CONCAT_STRING_4(x, y, z, w) TH_CONCAT_STRING_4_EXPAND(x, y, z, w) +#define TH_CONCAT_STRING_4_EXPAND(x, y, z, w) #x #y #z #w -#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) -#define TH_CONCAT_2_EXPAND(x,y) x ## y +#define TH_CONCAT_2(x, y) TH_CONCAT_2_EXPAND(x, y) +#define TH_CONCAT_2_EXPAND(x, y) x##y -#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) -#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z +#define TH_CONCAT_3(x, y, z) TH_CONCAT_3_EXPAND(x, y, z) +#define TH_CONCAT_3_EXPAND(x, y, z) x##y##z -#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w -#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) +#define TH_CONCAT_4_EXPAND(x, y, z, w) x##y##z##w +#define TH_CONCAT_4(x, y, z, w) TH_CONCAT_4_EXPAND(x, y, z, w) diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h index 05c042f29c48ae..88d8489ba7b8f9 100644 --- a/torch/csrc/THP.h +++ b/torch/csrc/THP.h @@ -1,17 +1,17 @@ #ifndef THP_H #define THP_H -#include #include +#include // Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/ // define PyInt_* macros for Python 3.x. NB: We must include Python.h first, // otherwise we'll incorrectly conclude PyInt_Check isn't defined! #ifndef PyInt_Check -#define PyInt_Check PyLong_Check -#define PyInt_FromLong PyLong_FromLong -#define PyInt_AsLong PyLong_AsLong -#define PyInt_Type PyLong_Type +#define PyInt_Check PyLong_Check +#define PyInt_FromLong PyLong_FromLong +#define PyInt_AsLong PyLong_AsLong +#define PyInt_Type PyLong_Type #endif #include diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index 5af75063b0b031..5e08a96109e004 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -75,10 +75,10 @@ PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { at::ScalarType scalar_type = r.scalartype(0); if (scalar_type == at::ScalarType::Bool) { return PyErr_Format( - PyExc_TypeError, - "torch.bool is not supported by torch.iinfo"); + PyExc_TypeError, "torch.bool is not supported by torch.iinfo"); } - if (!at::isIntegralType(scalar_type, /*includeBool=*/false) && !at::isQIntType(scalar_type)) { + if (!at::isIntegralType(scalar_type, /*includeBool=*/false) && + !at::isQIntType(scalar_type)) { return PyErr_Format( PyExc_TypeError, "torch.iinfo() requires an integer input type. Use torch.finfo to handle '%s'", @@ -113,8 +113,8 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) { } static PyObject* THPFInfo_eps(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, - self->type, "epsilon", [] { + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] { return PyFloat_FromDouble( std::numeric_limits< at::scalar_value_type::type>::epsilon()); @@ -122,17 +122,20 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) { } static PyObject* THPFInfo_max(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] { - return PyFloat_FromDouble( - std::numeric_limits::type>::max()); - }); + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] { + return PyFloat_FromDouble( + std::numeric_limits::type>::max()); + }); } static PyObject* THPFInfo_min(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] { - return PyFloat_FromDouble( - std::numeric_limits::type>::lowest()); - }); + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] { + return PyFloat_FromDouble( + std::numeric_limits< + at::scalar_value_type::type>::lowest()); + }); } static PyObject* THPIInfo_max(THPIInfo* self, void*) { @@ -143,7 +146,7 @@ static PyObject* THPIInfo_max(THPIInfo* self, void*) { } // Quantized Type return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max", [] { - return THPUtils_packInt64(std::numeric_limits::max()); + return THPUtils_packInt64(std::numeric_limits::max()); }); } @@ -155,7 +158,7 @@ static PyObject* THPIInfo_min(THPIInfo* self, void*) { } // Quantized Type return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min", [] { - return THPUtils_packInt64(std::numeric_limits::lowest()); + return THPUtils_packInt64(std::numeric_limits::lowest()); }); } @@ -169,10 +172,11 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) { } static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] { - return PyFloat_FromDouble( - std::numeric_limits::type>::min()); - }); + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] { + return PyFloat_FromDouble( + std::numeric_limits::type>::min()); + }); } static PyObject* THPFInfo_tiny(THPFInfo* self, void*) { @@ -181,28 +185,34 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) { } static PyObject* THPFInfo_resolution(THPFInfo* self, void*) { - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] { - return PyFloat_FromDouble( - std::pow(10, -std::numeric_limits::type>::digits10)); - }); + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] { + return PyFloat_FromDouble(std::pow( + 10, + -std::numeric_limits< + at::scalar_value_type::type>::digits10)); + }); } static PyObject* THPFInfo_dtype(THPFInfo* self, void*) { std::string primary_name, legacy_name; std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type); // NOLINTNEXTLINE(clang-diagnostic-unused-local-typedef) - return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "dtype", [primary_name] { - return PyUnicode_FromString((char*)primary_name.data()); - }); + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::ScalarType::BFloat16, self->type, "dtype", [primary_name] { + return PyUnicode_FromString((char*)primary_name.data()); + }); } PyObject* THPFInfo_str(THPFInfo* self) { std::ostringstream oss; - oss << "finfo(resolution=" << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr)); + oss << "finfo(resolution=" + << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr)); oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr)); oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr)); oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr)); - oss << ", smallest_normal=" << PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr)); + oss << ", smallest_normal=" + << PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr)); oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr)); oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")"; @@ -228,7 +238,11 @@ static struct PyGetSetDef THPFInfo_properties[] = { {"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr}, {"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr}, {"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr}, - {"smallest_normal", (getter)THPFInfo_smallest_normal, nullptr, nullptr, nullptr}, + {"smallest_normal", + (getter)THPFInfo_smallest_normal, + nullptr, + nullptr, + nullptr}, {"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr}, {"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr}, {"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr}, @@ -241,42 +255,42 @@ static PyMethodDef THPFInfo_methods[] = { PyTypeObject THPFInfoType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch.finfo", /* tp_name */ - sizeof(THPFInfo), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - (reprfunc)THPFInfo_str, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - (reprfunc)THPFInfo_str, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPFInfo_methods, /* tp_methods */ - nullptr, /* tp_members */ - THPFInfo_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPFInfo_pynew, /* tp_new */ + sizeof(THPFInfo), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPFInfo_str, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + (reprfunc)THPFInfo_str, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPFInfo_methods, /* tp_methods */ + nullptr, /* tp_members */ + THPFInfo_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPFInfo_pynew, /* tp_new */ }; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays) diff --git a/torch/csrc/Types.h b/torch/csrc/Types.h index a2147c5079733f..01a20cb01dff68 100644 --- a/torch/csrc/Types.h +++ b/torch/csrc/Types.h @@ -7,6 +7,7 @@ #include #endif -template struct THPTypeInfo {}; +template +struct THPTypeInfo {}; #endif diff --git a/torch/csrc/api/include/torch/all.h b/torch/csrc/api/include/torch/all.h index c1839ddb2df6e0..5228dbd8637a79 100644 --- a/torch/csrc/api/include/torch/all.h +++ b/torch/csrc/api/include/torch/all.h @@ -4,6 +4,7 @@ #error C++14 or later compatible compiler is required to use PyTorch. #endif +#include #include #include #include @@ -17,5 +18,4 @@ #include #include #include -#include #include diff --git a/torch/csrc/api/include/torch/arg.h b/torch/csrc/api/include/torch/arg.h index 5ed77424a033c7..a09362ee582dcc 100644 --- a/torch/csrc/api/include/torch/arg.h +++ b/torch/csrc/api/include/torch/arg.h @@ -2,21 +2,22 @@ #include -#define TORCH_ARG(T, name) \ - public: \ +#define TORCH_ARG(T, name) \ + public: \ inline auto name(const T& new_##name)->decltype(*this) { /* NOLINT */ \ - this->name##_ = new_##name; \ - return *this; \ - } \ + this->name##_ = new_##name; \ + return *this; \ + } \ inline auto name(T&& new_##name)->decltype(*this) { /* NOLINT */ \ - this->name##_ = std::move(new_##name); \ - return *this; \ - } \ + this->name##_ = std::move(new_##name); \ + return *this; \ + } \ inline const T& name() const noexcept { /* NOLINT */ \ - return this->name##_; \ - } \ - inline T& name() noexcept { /* NOLINT */ \ - return this->name##_; \ - } \ - private: \ + return this->name##_; \ + } \ + inline T& name() noexcept { /* NOLINT */ \ + return this->name##_; \ + } \ + \ + private: \ T name##_ /* NOLINT */ diff --git a/torch/csrc/api/include/torch/autograd.h b/torch/csrc/api/include/torch/autograd.h index 809fbe8bd33502..cf0608fa01bbf5 100644 --- a/torch/csrc/api/include/torch/autograd.h +++ b/torch/csrc/api/include/torch/autograd.h @@ -1,5 +1,5 @@ #pragma once #include -#include #include +#include diff --git a/torch/csrc/api/include/torch/cuda.h b/torch/csrc/api/include/torch/cuda.h index 3db1e128825817..537ddf02479c26 100644 --- a/torch/csrc/api/include/torch/cuda.h +++ b/torch/csrc/api/include/torch/cuda.h @@ -2,8 +2,8 @@ #include -#include #include +#include namespace torch { namespace cuda { diff --git a/torch/csrc/api/include/torch/data/dataloader.h b/torch/csrc/api/include/torch/data/dataloader.h index 881b42915d7eca..1825615df90263 100644 --- a/torch/csrc/api/include/torch/data/dataloader.h +++ b/torch/csrc/api/include/torch/data/dataloader.h @@ -25,7 +25,9 @@ torch::disable_if_t< make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { return torch::make_unique>( // NOLINTNEXTLINE(performance-move-const-arg) - std::move(dataset), std::move(sampler), std::move(options)); + std::move(dataset), + std::move(sampler), + std::move(options)); } /// Creates a `DataLoader` instance for a stateless `dataset` and some @@ -45,7 +47,9 @@ make_data_loader( "order to construct the Sampler"); return make_data_loader( // NOLINTNEXTLINE(performance-move-const-arg) - std::move(dataset), Sampler(*size), std::move(options)); + std::move(dataset), + Sampler(*size), + std::move(options)); } /// Creates a `DataLoader` for a stateful `dataset` and some `options`. @@ -55,7 +59,8 @@ std::unique_ptr> make_data_loader( DataLoaderOptions options = DataLoaderOptions()) { return torch::make_unique>( // NOLINTNEXTLINE(performance-move-const-arg) - std::move(dataset), std::move(options)); + std::move(dataset), + std::move(options)); } } // namespace data } // namespace torch diff --git a/torch/csrc/api/include/torch/data/dataloader/stateful.h b/torch/csrc/api/include/torch/data/dataloader/stateful.h index 482c5b9c975c9d..a8270a8ec6b76f 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateful.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include @@ -17,8 +17,8 @@ namespace data { /// this dataset is itself responsible for producing batches rather than /// depending on a sampler. The statefulness here actually refers to the /// dataset. The StatefulDataLoader simply alters the data loading algorithm to -/// accommodate the stateful, shared nature of the dataset. Note that the dataset -/// must be thread safe if more than one worker thread is used. +/// accommodate the stateful, shared nature of the dataset. Note that the +/// dataset must be thread safe if more than one worker thread is used. /// /// A stateful dataloader is created by calling `make_data_loader` with a /// stateful dataset. @@ -40,7 +40,7 @@ class StatefulDataLoader : public DataLoaderBase< // NOLINTNEXTLINE(performance-move-const-arg) std::move(options), torch::make_unique(std::move(dataset))) { - for (const auto w : c10::irange(this->options_.workers)) { + for (const auto w : c10::irange(this->options_.workers)) { // As opposed to the stateless case, here all worker threads access the // same underlying dataset. this->workers_.emplace_back( diff --git a/torch/csrc/api/include/torch/data/dataloader/stateless.h b/torch/csrc/api/include/torch/data/dataloader/stateless.h index 8440fffbc0a7d8..7940a0c6cfbd7c 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateless.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateless.h @@ -42,12 +42,12 @@ class StatelessDataLoader : public DataLoaderBase< DataLoaderOptions options) // NOLINTNEXTLINE(performance-move-const-arg) : super(std::move(options)), sampler_(std::move(sampler)) { - for (const auto w : c10::irange(this->options_.workers)) { + for (const auto w : c10::irange(this->options_.workers)) { // Here we copy the dataset into the worker thread closure. Each worker // has its own copy of the dataset. This means the dataset must be // trivially copiable, or else we don't expect more than one worker to // be in use. - (void)w; //Suppress unused variable warning + (void)w; // Suppress unused variable warning this->workers_.emplace_back( [this, dataset]() mutable { this->worker_thread(dataset); }); } diff --git a/torch/csrc/api/include/torch/data/datasets/base.h b/torch/csrc/api/include/torch/data/datasets/base.h index 9cd318df5222d8..ba0f0ec839c8a3 100644 --- a/torch/csrc/api/include/torch/data/datasets/base.h +++ b/torch/csrc/api/include/torch/data/datasets/base.h @@ -93,9 +93,9 @@ class Dataset : public BatchDataset> { } }; -/// A `StreamDataset` represents a dataset that is a potentially infinite stream. -/// It takes as batch index only a number, which is the batch size, and yields -/// that many elements from the stream. +/// A `StreamDataset` represents a dataset that is a potentially infinite +/// stream. It takes as batch index only a number, which is the batch size, and +/// yields that many elements from the stream. template >> using StreamDataset = BatchDataset; } // namespace datasets diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h index 449df78b69cc44..cbead90b55ce56 100644 --- a/torch/csrc/api/include/torch/data/datasets/chunk.h +++ b/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -20,7 +20,9 @@ namespace datasets { /// A chunk could be an entire file, such as an audio data file or an image, /// or part of a file in the case of a large text-file split based on seek /// positions. -template > +template < + typename ExampleType_, + typename ChunkType_ = std::vector> class ChunkDataReader { public: virtual ~ChunkDataReader() = default; @@ -69,8 +71,7 @@ class BatchDataBuffer { // wait till there is available data in the queue or if all chunks are // loaded (i.e. the dataset is exhausted for this epoch) return ( - this->total_example_count_in_queue_ >= batch_size_ || - this->stop_); + this->total_example_count_in_queue_ >= batch_size_ || this->stop_); }); if (batch_queue_.empty()) { AT_ASSERT(stop_); @@ -125,7 +126,8 @@ class BatchDataBuffer { if (!batch_queue_.empty()) { // if the queue has existing data, and the last batch doesn't have enough - // examples to fill a batch_size batch, add more example to this batch first. + // examples to fill a batch_size batch, add more example to this batch + // first. auto& batch = batch_queue_.back(); size_t current_count = batch.batch_data.size(); if (current_count < batch_size_) { @@ -160,10 +162,10 @@ class BatchDataBuffer { cv_write_.wait(lock, [this] { // stop loading if we have preloaded enough data. return ( - this->total_example_count_in_queue_ < this->queue_capacity_ || - this->stop_); + this->total_example_count_in_queue_ < this->queue_capacity_ || + this->stop_); }); - if (stop_){ + if (stop_) { // When stop_ is true, it means this current thread needs to be tore down, // the batch buffer will be discarded, so no need to enqueue any new // exceptions. @@ -175,19 +177,19 @@ class BatchDataBuffer { cv_read_.notify_all(); } - void stop(){ + void stop() { { - // Hold the lock before changing stop_ to prevent a race condition which can - // cause a deadlock. - // To be more specific, conditional variable cv_write_ waits on predicate - // stop_ in add_chunk_data(). The wait happens in two steps: 1) while still - // holding the lock, check if predicate is true; 2) if it is true, proceeds, - // otherwise, release the lock and wait until notified. Without holding a - // lock, cv_write_'s notification can happen in between step 1) and 2). In - // that case, as cv_write_ is not in waiting status yet, so the notification - // is lost and cv_write_ will sleep forever. - // By taking a lock before changing predicate stop_, it is ensured updating - // and evaluating stop_ always happen in a synchronized way + // Hold the lock before changing stop_ to prevent a race condition which + // can cause a deadlock. To be more specific, conditional variable + // cv_write_ waits on predicate stop_ in add_chunk_data(). The wait + // happens in two steps: 1) while still holding the lock, check if + // predicate is true; 2) if it is true, proceeds, otherwise, release the + // lock and wait until notified. Without holding a lock, cv_write_'s + // notification can happen in between step 1) and 2). In that case, as + // cv_write_ is not in waiting status yet, so the notification is lost and + // cv_write_ will sleep forever. By taking a lock before changing + // predicate stop_, it is ensured updating and evaluating stop_ always + // happen in a synchronized way std::lock_guard lock(queue_mutex_); stop_ = true; } @@ -205,11 +207,12 @@ class BatchDataBuffer { /// count of total example stored in the queue size_t total_example_count_in_queue_ = 0; - /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit is - /// the raw data without 'optional' wrapper. It can be a collection of images, - /// utterances, e.t.c. + /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit + /// is the raw data without 'optional' wrapper. It can be a collection of + /// images, utterances, e.t.c. struct UnwrappedBatchData { - explicit UnwrappedBatchData(UnwrappedBatchType data) : batch_data(std::move(data)) {} + explicit UnwrappedBatchData(UnwrappedBatchType data) + : batch_data(std::move(data)) {} // NOLINTNEXTLINE(modernize-pass-by-value) explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {} @@ -217,8 +220,8 @@ class BatchDataBuffer { /// batch data to return UnwrappedBatchType batch_data; - /// exception pointer which captures any abnormal exceptions while creating the - /// batch. + /// exception pointer which captures any abnormal exceptions while creating + /// the batch. std::exception_ptr exception; }; @@ -236,10 +239,10 @@ class BatchDataBuffer { // configurable maximun number of elements the queue can hold at one time. size_t queue_capacity_; - // When set to true, it wakes the writer threads from the wait and exit current - // function call. This is needed when ChunkDataSet.Reset is called while the - // previous epoch is not exhausted yet. When ChunkDataset is waiting its - // preloader to finish previous work before tearing down the thread, the + // When set to true, it wakes the writer threads from the wait and exit + // current function call. This is needed when ChunkDataSet.Reset is called + // while the previous epoch is not exhausted yet. When ChunkDataset is waiting + // its preloader to finish previous work before tearing down the thread, the // preloader could be still waiting for the conditional variable, thus cause // the program to hang. This boolean is used to break this waiting condition. bool stop_ = false; @@ -339,7 +342,7 @@ class ChunkDataset final running_preloaders_(0), load_checkpoint_(false) {} - ~ChunkDataset() override { + ~ChunkDataset() override { // stop batch buffer first. if (batch_buffer_) { batch_buffer_->stop(); @@ -353,14 +356,16 @@ class ChunkDataset final /// datasets. BatchType get_batch(size_t batch_size) override { TORCH_CHECK( - batch_buffer_ != nullptr, - "Dataset needs to call reset() before calling get_batch()."); + batch_buffer_ != nullptr, + "Dataset needs to call reset() before calling get_batch()."); TORCH_CHECK( - batch_size == options_.batch_size(), - "The requested batch size does not match with the initialized batch size.\n" - " The requested batch size is ", batch_size, - ", while the dataset is created with batch size equal to ", options_.batch_size()); + batch_size == options_.batch_size(), + "The requested batch size does not match with the initialized batch size.\n" + " The requested batch size is ", + batch_size, + ", while the dataset is created with batch size equal to ", + options_.batch_size()); return batch_buffer_->get_batch(); } @@ -380,7 +385,7 @@ class ChunkDataset final free_workers(); preload_threads_.clear(); - if (!load_checkpoint_){ + if (!load_checkpoint_) { chunk_reader_.reset(); chunk_sampler_.reset(chunk_reader_.chunk_count()); load_checkpoint_ = false; @@ -390,9 +395,7 @@ class ChunkDataset final // chunk buffer. batch_buffer_ = torch::make_unique< detail::BatchDataBuffer>( - options_.batch_size(), - example_sampler_, - options_.cache_size()); + options_.batch_size(), example_sampler_, options_.cache_size()); // create new workers for this new epoch. quit_worker_ = false; @@ -420,7 +423,7 @@ class ChunkDataset final chunk_sampler_.save(archive); } - void load(serialize::InputArchive& archive) override{ + void load(serialize::InputArchive& archive) override { std::lock_guard lock(chunk_index_guard_); chunk_sampler_.load(archive); load_checkpoint_ = true; @@ -434,7 +437,8 @@ class ChunkDataset final std::vector chunk_idx; { std::lock_guard lock(chunk_index_guard_); - if (auto chunk_sampler_result = chunk_sampler_.next(this->options_.cross_chunk_shuffle_count())) { + if (auto chunk_sampler_result = chunk_sampler_.next( + this->options_.cross_chunk_shuffle_count())) { chunk_idx = chunk_sampler_result.value(); } else { break; @@ -487,7 +491,8 @@ class ChunkDataset final ExampleSamplerType example_sampler_; // batch data buffer which holds chunk data from preloading thread. - std::shared_ptr> + std::shared_ptr< + detail::BatchDataBuffer> batch_buffer_; // worker thread pool @@ -496,8 +501,8 @@ class ChunkDataset final /// The options the Dataset was configured with. const ChunkDatasetOptions options_; - // function pointer wrapper to apply custom processing over chunk data. This is - // considered an advanced parameter for developers who want to apply a + // function pointer wrapper to apply custom processing over chunk data. This + // is considered an advanced parameter for developers who want to apply a // pre-process to the chunk data before sampling into minibatch. // Different than the collate function, this policy is applied on the chunk // level, instead of minibatch level. When a chunk of data is loaded (multiple @@ -518,7 +523,8 @@ class ChunkDataset final // mutex to synchronize chunk sampler next() call. mutable std::mutex chunk_index_guard_; - // boolean value to indicate whether we need to load the checkpoint for chunk_sampler_. + // boolean value to indicate whether we need to load the checkpoint for + // chunk_sampler_. bool load_checkpoint_; }; } // namespace datasets diff --git a/torch/csrc/api/include/torch/data/datasets/stateful.h b/torch/csrc/api/include/torch/data/datasets/stateful.h index 8ebbc7eaf2cd81..0b1518b27da39d 100644 --- a/torch/csrc/api/include/torch/data/datasets/stateful.h +++ b/torch/csrc/api/include/torch/data/datasets/stateful.h @@ -27,7 +27,9 @@ namespace datasets { /// /// Note that when subclassing a from `StatefulDataset`, the return /// type of `get_batch()`, which the subclass must override, will be -/// `optional` (i.e. the type specified in the `StatefulDataset` specialization is automatically boxed into an `optional` for the dataset's `BatchType`). +/// `optional` (i.e. the type specified in the `StatefulDataset` +/// specialization is automatically boxed into an `optional` for the dataset's +/// `BatchType`). template < typename Self, typename Batch = std::vector>, diff --git a/torch/csrc/api/include/torch/data/detail/sequencers.h b/torch/csrc/api/include/torch/data/detail/sequencers.h index f8c26f7acb1aec..7412263c336ce3 100644 --- a/torch/csrc/api/include/torch/data/detail/sequencers.h +++ b/torch/csrc/api/include/torch/data/detail/sequencers.h @@ -11,7 +11,7 @@ namespace data { namespace detail { namespace sequencers { namespace detail { -template +template bool buffer_contains_result(const std::vector>& buffer) { return std::any_of( buffer.begin(), buffer.end(), [](const optional& result) { diff --git a/torch/csrc/api/include/torch/data/iterator.h b/torch/csrc/api/include/torch/data/iterator.h index 9e17109045f272..82091b59e046a7 100644 --- a/torch/csrc/api/include/torch/data/iterator.h +++ b/torch/csrc/api/include/torch/data/iterator.h @@ -18,9 +18,9 @@ namespace detail { // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A // `ValidIterator` yields new batches until the `DataLoader` is exhausted. While // the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are -// the same object. When the `ValidIterator` becomes exhausted, it compares equal -// to the `SentinelIterator`, but not before. Half the code here is to implement -// double dispatch for the comparison. Got damnit, C++. +// the same object. When the `ValidIterator` becomes exhausted, it compares +// equal to the `SentinelIterator`, but not before. Half the code here is to +// implement double dispatch for the comparison. Got damnit, C++. template struct ValidIterator; diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h index 5469a8ea07a9ec..780f19035c6a3d 100644 --- a/torch/csrc/api/include/torch/data/samplers/base.h +++ b/torch/csrc/api/include/torch/data/samplers/base.h @@ -4,8 +4,8 @@ #include #include -#include #include +#include namespace torch { namespace serialize { diff --git a/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h b/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h index ec48c4450482ca..a553a5d9cd67c7 100644 --- a/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h +++ b/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include namespace torch { namespace data { diff --git a/torch/csrc/api/include/torch/data/samplers/random.h b/torch/csrc/api/include/torch/data/samplers/random.h index b5bc8ea0a11026..e8f57b88cc2313 100644 --- a/torch/csrc/api/include/torch/data/samplers/random.h +++ b/torch/csrc/api/include/torch/data/samplers/random.h @@ -26,9 +26,7 @@ class TORCH_API RandomSampler : public Sampler<> { /// The constructor will eagerly allocate all required indices, which is the /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored /// indices. You can change it to influence memory usage. - explicit RandomSampler( - int64_t size, - Dtype index_dtype = torch::kInt64); + explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); ~RandomSampler() override; diff --git a/torch/csrc/api/include/torch/detail/TensorDataContainer.h b/torch/csrc/api/include/torch/detail/TensorDataContainer.h index 1aa8f5f433d691..519cd20256df68 100644 --- a/torch/csrc/api/include/torch/detail/TensorDataContainer.h +++ b/torch/csrc/api/include/torch/detail/TensorDataContainer.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include +#include #include @@ -23,7 +23,9 @@ enum class TensorDataContainerType { Scalar, InitList, Tensor }; struct TensorDataContainer; -inline std::ostream& operator<<(std::ostream& stream, const TensorDataContainer& tensor_data_container); +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container); // FIXME: There is no `operator<<` overload for `at::kBFloat16` type, // and we need to convert it to `float` type using `operator float()` function @@ -36,22 +38,24 @@ inline std::ostream& operator<<(std::ostream& stream, c10::BFloat16 value) { inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) { if (scalar_type == at::kInt || scalar_type == at::kLong) { - // C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` / - // (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong` - // (aka. int64_t), matching Python `torch.tensor` behavior. + // C++ `torch::tensor` with an integer type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of integer types always + // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python + // `torch.tensor` behavior. return at::kLong; } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) { - // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` / - // (nested) braced-init-list of floating-point types always produces a tensor of dtype - // `torch::get_default_dtype()`, matching Python `torch.tensor` behavior. + // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of floating-point types always + // produces a tensor of dtype `torch::get_default_dtype()`, matching Python + // `torch.tensor` behavior. return at::typeMetaToScalarType(at::get_default_dtype()); } else { return scalar_type; } } -// We use `TensorDataContainer` to support converting the following data container types -// into the equivalent Tensor: +// We use `TensorDataContainer` to support converting the following data +// container types into the equivalent Tensor: // // 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`). // 2. `at::ArrayRef` of supported tensor data types. @@ -61,113 +65,132 @@ inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) { // // 1. A scalar with value `scalar()` and type `scalar_type()`. // 2. A Tensor represented in `std::initializer_list` form, -// with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor sizes `sizes()`. -// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar type `scalar_type()`, +// with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor +// sizes `sizes()`. +// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar +// type `scalar_type()`, // and Tensor sizes `sizes()`. // -// All the infrastructure here is mostly to support converting an arbitrarily nested braced-init-list -// to the equivalent Tensor successfully. Consider the following example: +// All the infrastructure here is mostly to support converting an arbitrarily +// nested braced-init-list to the equivalent Tensor successfully. Consider the +// following example: // // `torch::tensor({{1}, {2}})` // // this will call into the `torch::tensor` function: // -// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const at::TensorOptions& options = {})` +// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const +// at::TensorOptions& options = {})` // -// the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer` type: +// the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer` +// type: // // `TensorDataContainer({{1}, {2}})` // -// which matches to the `TensorDataContainer(std::initializer_list)` constructor, -// and in an attempt to convert `{1}` and `{2}` to `TensorDataContainer`, it calls the following: +// which matches to the +// `TensorDataContainer(std::initializer_list)` +// constructor, and in an attempt to convert `{1}` and `{2}` to +// `TensorDataContainer`, it calls the following: // -// `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just focus on `{1}` here) +// `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just +// focus on `{1}` here) // -// At this point, theoretically there are two plausible ways for `{1}` to be matched to one of the -// constructors of `TensorDataContainer`: +// At this point, theoretically there are two plausible ways for `{1}` to be +// matched to one of the constructors of `TensorDataContainer`: // -// 1. It can be a list-initialization of a scalar value, thus matching `TensorDataContainer(int value)`. -// 2. It can be converted to `std::initializer_list`, thus matching +// 1. It can be a list-initialization of a scalar value, thus matching +// `TensorDataContainer(int value)`. +// 2. It can be converted to `std::initializer_list`, thus +// matching // `TensorDataContainer(std::initializer_list)`. // // How does the compiler decide which one to choose? According to -// `https://en.cppreference.com/w/cpp/language/list_initialization`, braced-init-list always prefers -// the constructor that takes `std::initializer_list`. Hence we happily move forward with constructor #2, +// `https://en.cppreference.com/w/cpp/language/list_initialization`, +// braced-init-list always prefers the constructor that takes +// `std::initializer_list`. Hence we happily move forward with constructor #2, // and it calls the following: // // `TensorDataContainer(1)` // -// Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar value. All is good. +// Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar +// value. All is good. struct TensorDataContainer { - // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{}, {}})`), - // the innermost empty braced-init-list `{}` matches the default constructor of - // the innermost `TensorDataContainer`. + // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{}, + // {}})`), the innermost empty braced-init-list `{}` matches the default + // constructor of the innermost `TensorDataContainer`. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + TensorDataContainer() + : sizes_({0}), + // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. + // `torch.tensor([[], []])`) depends on the value of + // `torch.get_default_dtype()`, and we should do the same for the C++ + // equivalent. + scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())), + type_(TensorDataContainerType::InitList) {} +#define TENSOR(T, S) \ + TensorDataContainer(T value) \ + : sizes_(), \ + scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Scalar), \ + scalar_(value) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - TensorDataContainer() : - sizes_({0}), - // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. `torch.tensor([[], []])`) - // depends on the value of `torch.get_default_dtype()`, and we should do the same for the C++ equivalent. - scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())), - type_(TensorDataContainerType::InitList) {} -#define TENSOR(T, S) \ - TensorDataContainer(T value) : \ - sizes_(), \ - scalar_type_(at::k##S), \ - type_(TensorDataContainerType::Scalar), \ - scalar_(value) {} -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_COMPLEX_TYPES(TENSOR) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - TensorDataContainer(std::initializer_list init_list) : - sizes_(), - scalar_type_(init_list.begin()->scalar_type()), - type_(TensorDataContainerType::InitList), - init_list_(init_list) { + TensorDataContainer(std::initializer_list init_list) + : sizes_(), + scalar_type_(init_list.begin()->scalar_type()), + type_(TensorDataContainerType::InitList), + init_list_(init_list) { const TensorDataContainer& first_elem = *(init_list.begin()); for (const auto& elem : init_list) { - TORCH_CHECK(elem.sizes() == first_elem.sizes(), - "Expected all sub-lists to have sizes: ", - first_elem.sizes(), - " (e.g. ", first_elem, "), ", - "but got sub-list ", - elem, - " with sizes: ", - elem.sizes()); - TORCH_CHECK(elem.scalar_type() == first_elem.scalar_type(), - "Expected all elements of the tensor to have the same scalar type: ", - first_elem.scalar_type(), - ", but got element of scalar type: ", - elem.scalar_type()); + TORCH_CHECK( + elem.sizes() == first_elem.sizes(), + "Expected all sub-lists to have sizes: ", + first_elem.sizes(), + " (e.g. ", + first_elem, + "), ", + "but got sub-list ", + elem, + " with sizes: ", + elem.sizes()); + TORCH_CHECK( + elem.scalar_type() == first_elem.scalar_type(), + "Expected all elements of the tensor to have the same scalar type: ", + first_elem.scalar_type(), + ", but got element of scalar type: ", + elem.scalar_type()); } sizes_.reserve(first_elem.sizes().size() + 1); sizes_.push_back(init_list.size()); - sizes_.insert(sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end()); + sizes_.insert( + sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end()); } -#define TENSOR(T, S) \ - TensorDataContainer(at::ArrayRef values) : \ - sizes_({(int64_t)values.size()}), \ - scalar_type_(at::k##S), \ - type_(TensorDataContainerType::Tensor) { \ - at::AutoDispatchBelowAutograd mode; \ - if (scalar_type_ == at::kBool) { \ - tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \ - } else { \ +#define TENSOR(T, S) \ + TensorDataContainer(at::ArrayRef values) \ + : sizes_({(int64_t)values.size()}), \ + scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Tensor) { \ + at::AutoDispatchBelowAutograd mode; \ + if (scalar_type_ == at::kBool) { \ + tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \ + } else { \ tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \ - } \ + } \ } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_COMPLEX_TYPES(TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR - // NOTE: We need to handle `std::vector` explicitly instead of relying on an implicit conversion - // to `at::ArrayRef`, otherwise the following error can be thrown when calling - // `torch::tensor(std::vector({1, 2}))`: + // NOTE: We need to handle `std::vector` explicitly instead of relying on an + // implicit conversion to `at::ArrayRef`, otherwise the following error can be + // thrown when calling `torch::tensor(std::vector({1, 2}))`: // ``` // error: no matching function for call to 'tensor(const std::vector&)' // no known conversion for argument 1 from 'const std::vector' to @@ -176,12 +199,13 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) // // NOTE: `torch::tensor(std::vector)` is not supported for now, because // ArrayRef cannot be constructed from a std::vector bitfield. -#define TENSOR(T, S) \ - TensorDataContainer(const std::vector& values) : TensorDataContainer(at::ArrayRef(values)) {} -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR) -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -AT_FORALL_COMPLEX_TYPES(TENSOR) +#define TENSOR(T, S) \ + TensorDataContainer(const std::vector& values) \ + : TensorDataContainer(at::ArrayRef(values)) {} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR bool is_scalar() const { @@ -190,8 +214,8 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) const c10::Scalar& scalar() const { TORCH_CHECK( - is_scalar(), - "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); + is_scalar(), + "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); return scalar_; } @@ -201,8 +225,8 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) const std::initializer_list& init_list() const { TORCH_CHECK( - is_init_list(), - "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); + is_init_list(), + "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); return init_list_; } @@ -212,8 +236,8 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) const at::Tensor& tensor() const { TORCH_CHECK( - is_tensor(), - "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); + is_tensor(), + "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); return tensor_; } @@ -235,11 +259,11 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) return at::scalar_tensor(scalar_, options); } else if (is_init_list()) { // NOTE: Here we explicitly choose to initialize the tensor on CPU first, - // fill each element of the tensor, and then move the tensor to the desired - // device. For CUDA device, this approach only involves 1 CUDA kernel launch, - // and is much faster than initializing the tensor on CUDA first and then - // filling each element of it (which involves `N` CUDA kernel launches where - // `N` is the number of the elements in the tensor). + // fill each element of the tensor, and then move the tensor to the + // desired device. For CUDA device, this approach only involves 1 CUDA + // kernel launch, and is much faster than initializing the tensor on CUDA + // first and then filling each element of it (which involves `N` CUDA + // kernel launches where `N` is the number of the elements in the tensor). at::Tensor tensor = ([&]() { at::AutoDispatchBelowAutograd mode; return at::empty(sizes_, options.device(at::kCPU)); @@ -248,7 +272,9 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) return tensor.to(options.device()); } else if (is_tensor()) { auto output = tensor_.to(options); - TORCH_CHECK(!tensor_.is_complex() || output.is_complex(), "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); + TORCH_CHECK( + !tensor_.is_complex() || output.is_complex(), + "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); return output; } else { TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); @@ -262,50 +288,54 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) at::kHalf, at::kBFloat16, scalar_type_, - "TensorDataContainer_pretty_print_scalar", [&] { - stream << scalar_.to(); - }); + "TensorDataContainer_pretty_print_scalar", + [&] { stream << scalar_.to(); }); } else if (is_init_list()) { stream << "{"; - for (const TensorDataContainer* it = init_list_.begin(); it != init_list_.end(); it++) { + for (const TensorDataContainer* it = init_list_.begin(); + it != init_list_.end(); + it++) { stream << *it; - if (std::next(it) != init_list_.end()) stream << ", "; + if (std::next(it) != init_list_.end()) + stream << ", "; } stream << "}"; } else if (is_tensor()) { stream << "{"; - for (const auto i : c10::irange(tensor_.sizes()[0])) { + for (const auto i : c10::irange(tensor_.sizes()[0])) { AT_DISPATCH_ALL_TYPES_AND3( at::kBool, at::kHalf, at::kBFloat16, scalar_type_, - "TensorDataContainer_pretty_print_tensor_item", [&] { - stream << tensor_[i].item(); - }); - if (i != tensor_.sizes()[0] - 1) stream << ", "; + "TensorDataContainer_pretty_print_tensor_item", + [&] { stream << tensor_[i].item(); }); + if (i != tensor_.sizes()[0] - 1) + stream << ", "; } stream << "}"; } else { TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); } } + private: void fill_tensor(at::Tensor& tensor) const { if (is_scalar()) { TORCH_INTERNAL_ASSERT( - tensor.dim() == 0, - "Expected a 0-dim Tensor, but got Tensor with dimensions: ", tensor.dim()); + tensor.dim() == 0, + "Expected a 0-dim Tensor, but got Tensor with dimensions: ", + tensor.dim()); at::NoGradGuard guard; tensor.fill_(scalar_); } else if (is_init_list()) { TORCH_INTERNAL_ASSERT( - tensor.sizes()[0] == (int64_t)init_list_.size(), - "Expected a Tensor with size ", - init_list_.size(), - " in its first dimension, but got Tensor with size ", - tensor.sizes()[0], - " in its first dimension"); + tensor.sizes()[0] == (int64_t)init_list_.size(), + "Expected a Tensor with size ", + init_list_.size(), + " in its first dimension, but got Tensor with size ", + tensor.sizes()[0], + " in its first dimension"); size_t index = 0; for (const auto& elem : init_list_) { at::Tensor slice = tensor[index]; @@ -314,8 +344,8 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) } } else if (is_tensor()) { TORCH_INTERNAL_ASSERT( - false, - "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called"); + false, + "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called"); } else { TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); } @@ -329,7 +359,9 @@ AT_FORALL_COMPLEX_TYPES(TENSOR) at::Tensor tensor_; }; -inline std::ostream& operator<<(std::ostream& stream, const TensorDataContainer& tensor_data_container) { +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container) { tensor_data_container.pretty_print_recursive(stream); return stream; } diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h index 8ae692af1a6ead..0e52d22b21c9f9 100644 --- a/torch/csrc/api/include/torch/enum.h +++ b/torch/csrc/api/include/torch/enum.h @@ -7,32 +7,34 @@ #include #include -#define TORCH_ENUM_DECLARE(name) \ -namespace torch { \ -namespace enumtype { \ - /* - NOTE: We need to provide the default constructor for each struct, - otherwise Clang 3.8 would complain: - ``` - error: default initialization of an object of const type 'const enumtype::Enum1' - without a user-provided default constructor - ``` - */ \ - struct k##name { k##name() {} }; \ -} \ -TORCH_API extern const enumtype::k##name k##name; \ -} +#define TORCH_ENUM_DECLARE(name) \ + namespace torch { \ + namespace enumtype { \ + /* \ + NOTE: We need to provide the default constructor for each struct, \ + otherwise Clang 3.8 would complain: \ + ``` \ + error: default initialization of an object of const type 'const \ + enumtype::Enum1' without a user-provided default constructor \ + ``` \ + */ \ + struct k##name { \ + k##name() {} \ + }; \ + } \ + TORCH_API extern const enumtype::k##name k##name; \ + } -#define TORCH_ENUM_DEFINE(name) \ -namespace torch { \ -const enumtype::k##name k##name; \ -} +#define TORCH_ENUM_DEFINE(name) \ + namespace torch { \ + const enumtype::k##name k##name; \ + } -#define TORCH_ENUM_PRETTY_PRINT(name) \ -std::string operator()(const enumtype::k##name& v) const { \ - std::string k("k"); \ - return k + #name; \ -} +#define TORCH_ENUM_PRETTY_PRINT(name) \ + std::string operator()(const enumtype::k##name& v) const { \ + std::string k("k"); \ + return k + #name; \ + } // NOTE: Backstory on why we need the following two macros: // @@ -40,8 +42,9 @@ std::string operator()(const enumtype::k##name& v) const { \ // // ``` // struct TORCH_API SomeOptions { -// typedef c10::variant reduction_t; -// SomeOptions(reduction_t reduction = torch::kMean) : reduction_(reduction) {} +// typedef c10::variant +// reduction_t; SomeOptions(reduction_t reduction = torch::kMean) : +// reduction_(reduction) {} // // TORCH_ARG(reduction_t, reduction); // }; @@ -64,10 +67,12 @@ std::string operator()(const enumtype::k##name& v) const { \ // However, it throws the following error instead: // // ``` -// error: could not convert `torch::kNone` from `const torch::enumtype::kNone` to `torch::nn::SomeOptions` +// error: could not convert `torch::kNone` from `const torch::enumtype::kNone` +// to `torch::nn::SomeOptions` // ``` // -// To get around this problem, we explicitly provide the following constructors for `SomeOptions`: +// To get around this problem, we explicitly provide the following constructors +// for `SomeOptions`: // // ``` // SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {} @@ -79,18 +84,20 @@ std::string operator()(const enumtype::k##name& v) const { \ // // Note that we also provide the default constructor `SomeOptions() {}`, so that // `SomeOptions options = {}` can work. -#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3) \ -OPTIONS_NAME() {} \ -OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ -OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ -OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} +#define TORCH_OPTIONS_CTOR_VARIANT_ARG3( \ + OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3) \ + OPTIONS_NAME() {} \ + OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ + OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ + OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} -#define TORCH_OPTIONS_CTOR_VARIANT_ARG4(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4) \ -OPTIONS_NAME() {} \ -OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ -OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ -OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} \ -OPTIONS_NAME(torch::enumtype::TYPE4 ARG_NAME) : ARG_NAME##_(torch::TYPE4) {} +#define TORCH_OPTIONS_CTOR_VARIANT_ARG4( \ + OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4) \ + OPTIONS_NAME() {} \ + OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ + OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ + OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} \ + OPTIONS_NAME(torch::enumtype::TYPE4 ARG_NAME) : ARG_NAME##_(torch::TYPE4) {} TORCH_ENUM_DECLARE(Linear) TORCH_ENUM_DECLARE(Conv1D) @@ -194,8 +201,9 @@ at::Reduction::Reduction reduction_get_enum(V variant_enum) { return at::Reduction::Sum; } else { TORCH_CHECK( - false, - get_enum_name(variant_enum), " is not a valid value for reduction"); + false, + get_enum_name(variant_enum), + " is not a valid value for reduction"); return at::Reduction::END; } } diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h index 7814e0036fc931..aa4fecf4ff37c3 100644 --- a/torch/csrc/api/include/torch/expanding_array.h +++ b/torch/csrc/api/include/torch/expanding_array.h @@ -103,31 +103,34 @@ std::ostream& operator<<( return stream << static_cast>(expanding_array); } -/// A utility class that accepts either a container of `D`-many `c10::optional` values, -/// or a single `c10::optional` value, which is internally repeated `D` times. -/// It has the additional ability to accept containers of the underlying type `T` and -/// convert them to a container of `c10::optional`. +/// A utility class that accepts either a container of `D`-many +/// `c10::optional` values, or a single `c10::optional` value, which is +/// internally repeated `D` times. It has the additional ability to accept +/// containers of the underlying type `T` and convert them to a container of +/// `c10::optional`. template -class ExpandingArrayWithOptionalElem : public ExpandingArray> { +class ExpandingArrayWithOptionalElem + : public ExpandingArray> { public: using ExpandingArray>::ExpandingArray; - /// Constructs an `ExpandingArrayWithOptionalElem` from an `initializer_list` of the underlying type `T`. - /// The extent of the length is checked against the `ExpandingArrayWithOptionalElem`'s extent parameter `D` - /// at runtime. + /// Constructs an `ExpandingArrayWithOptionalElem` from an `initializer_list` + /// of the underlying type `T`. The extent of the length is checked against + /// the `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list list) : ExpandingArrayWithOptionalElem(at::ArrayRef(list)) {} - /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of the underlying type `T`. - /// The extent of the length is checked against the `ExpandingArrayWithOptionalElem`'s extent parameter `D` - /// at runtime. + /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of + /// the underlying type `T`. The extent of the length is checked against the + /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. /*implicit*/ ExpandingArrayWithOptionalElem(std::vector vec) : ExpandingArrayWithOptionalElem(at::ArrayRef(vec)) {} - /// Constructs an `ExpandingArrayWithOptionalElem` from an `at::ArrayRef` of the underlying type `T`. - /// The extent of the length is checked against the `ExpandingArrayWithOptionalElem`'s extent parameter `D` - /// at runtime. - /*implicit*/ ExpandingArrayWithOptionalElem(at::ArrayRef values) : ExpandingArray>(0) { + /// Constructs an `ExpandingArrayWithOptionalElem` from an `at::ArrayRef` of + /// the underlying type `T`. The extent of the length is checked against the + /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. + /*implicit*/ ExpandingArrayWithOptionalElem(at::ArrayRef values) + : ExpandingArray>(0) { // clang-format off TORCH_CHECK( values.size() == D, @@ -138,16 +141,20 @@ class ExpandingArrayWithOptionalElem : public ExpandingArray } } - /// Constructs an `ExpandingArrayWithOptionalElem` from a single value of the underlying type `T`, - /// which is repeated `D` times (where `D` is the extent parameter of the `ExpandingArrayWithOptionalElem`). - /*implicit*/ ExpandingArrayWithOptionalElem(T single_size) : ExpandingArray>(0) { + /// Constructs an `ExpandingArrayWithOptionalElem` from a single value of the + /// underlying type `T`, which is repeated `D` times (where `D` is the extent + /// parameter of the `ExpandingArrayWithOptionalElem`). + /*implicit*/ ExpandingArrayWithOptionalElem(T single_size) + : ExpandingArray>(0) { for (const auto i : c10::irange(this->values_.size())) { this->values_[i] = single_size; } } - /// Constructs an `ExpandingArrayWithOptionalElem` from a correctly sized `std::array` of the underlying type `T`. - /*implicit*/ ExpandingArrayWithOptionalElem(const std::array& values) : ExpandingArray>(0) { + /// Constructs an `ExpandingArrayWithOptionalElem` from a correctly sized + /// `std::array` of the underlying type `T`. + /*implicit*/ ExpandingArrayWithOptionalElem(const std::array& values) + : ExpandingArray>(0) { for (const auto i : c10::irange(this->values_.size())) { this->values_[i] = values[i]; } @@ -164,7 +171,8 @@ std::ostream& operator<<( } else { std::vector str_array; for (const auto& elem : *expanding_array_with_opt_elem) { - str_array.emplace_back(elem.has_value() ? c10::str(elem.value()) : "None"); + str_array.emplace_back( + elem.has_value() ? c10::str(elem.value()) : "None"); } stream << at::ArrayRef(str_array); } diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index 71a3146c990f18..673449ad056e79 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -13,10 +13,11 @@ namespace fft { /// auto t = torch::randn(128, dtype=kComplexDouble); /// torch::fft::fft(t); /// ``` -inline Tensor fft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor fft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_fft(self, n, dim, norm); } @@ -28,10 +29,11 @@ inline Tensor fft(const Tensor& self, /// auto t = torch::randn(128, dtype=kComplexDouble); /// torch::fft::ifft(t); /// ``` -inline Tensor ifft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor ifft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_ifft(self, n, dim, norm); } @@ -43,10 +45,11 @@ inline Tensor ifft(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::fft2(t); /// ``` -inline Tensor fft2(const Tensor& self, - OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor fft2( + const Tensor& self, + OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_fft2(self, s, dim, norm); } @@ -58,10 +61,11 @@ inline Tensor fft2(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::ifft2(t); /// ``` -inline Tensor ifft2(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor ifft2( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_ifft2(self, s, dim, norm); } @@ -73,10 +77,11 @@ inline Tensor ifft2(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::fftn(t); /// ``` -inline Tensor fftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - at::OptionalIntArrayRef dim=c10::nullopt, - c10::optional norm=c10::nullopt) { +inline Tensor fftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + c10::optional norm = c10::nullopt) { return torch::fft_fftn(self, s, dim, norm); } @@ -88,10 +93,11 @@ inline Tensor fftn(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::ifftn(t); /// ``` -inline Tensor ifftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - at::OptionalIntArrayRef dim=c10::nullopt, - c10::optional norm=c10::nullopt) { +inline Tensor ifftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + c10::optional norm = c10::nullopt) { return torch::fft_ifftn(self, s, dim, norm); } @@ -104,10 +110,11 @@ inline Tensor ifftn(const Tensor& self, /// auto T = torch::fft::rfft(t); /// assert(T.is_complex() && T.numel() == 128 / 2 + 1); /// ``` -inline Tensor rfft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor rfft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_rfft(self, n, dim, norm); } @@ -122,25 +129,27 @@ inline Tensor rfft(const Tensor& self, /// auto t = torch::fft::irfft(t, /*n=*/128); /// assert(t.is_floating_point() && T.numel() == 128); /// ``` -inline Tensor irfft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor irfft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_irfft(self, n, dim, norm); } -/// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian output. -/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfft2 +/// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian +/// output. See https://pytorch.org/docs/master/fft.html#torch.fft.rfft2 /// /// Example: /// ``` /// auto t = torch::randn({128, 128}, dtype=kDouble); /// torch::fft::rfft2(t); /// ``` -inline Tensor rfft2(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor rfft2( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_rfft2(self, s, dim, norm); } @@ -152,10 +161,11 @@ inline Tensor rfft2(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::irfft2(t); /// ``` -inline Tensor irfft2(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor irfft2( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_irfft2(self, s, dim, norm); } @@ -167,10 +177,11 @@ inline Tensor irfft2(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kDouble); /// torch::fft::rfftn(t); /// ``` -inline Tensor rfftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - at::OptionalIntArrayRef dim=c10::nullopt, - c10::optional norm=c10::nullopt) { +inline Tensor rfftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + c10::optional norm = c10::nullopt) { return torch::fft_rfftn(self, s, dim, norm); } @@ -182,10 +193,11 @@ inline Tensor rfftn(const Tensor& self, /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); /// torch::fft::irfftn(t); /// ``` -inline Tensor irfftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - at::OptionalIntArrayRef dim=c10::nullopt, - c10::optional norm=c10::nullopt) { +inline Tensor irfftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + c10::optional norm = c10::nullopt) { return torch::fft_irfftn(self, s, dim, norm); } @@ -201,10 +213,11 @@ inline Tensor irfftn(const Tensor& self, /// auto T = torch::fft::hfft(t, /*n=*/128); /// assert(T.is_floating_point() && T.numel() == 128); /// ``` -inline Tensor hfft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor hfft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_hfft(self, n, dim, norm); } @@ -219,10 +232,11 @@ inline Tensor hfft(const Tensor& self, /// auto t = torch::fft::ihfft(T); /// assert(t.is_complex() && T.numel() == 128 / 2 + 1); /// ``` -inline Tensor ihfft(const Tensor& self, - c10::optional n=c10::nullopt, - int64_t dim=-1, - c10::optional norm=c10::nullopt) { +inline Tensor ihfft( + const Tensor& self, + c10::optional n = c10::nullopt, + int64_t dim = -1, + c10::optional norm = c10::nullopt) { return torch::fft_ihfft(self, n, dim, norm); } @@ -237,17 +251,19 @@ inline Tensor ihfft(const Tensor& self, /// auto T = torch::fft::hfft2(t, /*s=*/{128, 128}); /// assert(T.is_floating_point() && T.numel() == 128 * 128); /// ``` -inline Tensor hfft2(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor hfft2( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_hfft2(self, s, dim, norm); } /// Computes the 2-dimensional IFFT of a real input signal. /// /// The output is a onesided representation of the Hermitian symmetric time -/// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.ihfft2. +/// domain signal. See +/// https://pytorch.org/docs/master/fft.html#torch.fft.ihfft2. /// /// Example: /// ``` @@ -255,10 +271,11 @@ inline Tensor hfft2(const Tensor& self, /// auto t = torch::fft::hfft2(T); /// assert(t.is_complex() && t.size(1) == 65); /// ``` -inline Tensor ihfft2(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor ihfft2( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_ihfft2(self, s, dim, norm); } @@ -273,17 +290,19 @@ inline Tensor ihfft2(const Tensor& self, /// auto T = torch::fft::hfftn(t, /*s=*/{128, 128}); /// assert(T.is_floating_point() && T.numel() == 128 * 128); /// ``` -inline Tensor hfftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor hfftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_hfftn(self, s, dim, norm); } /// Computes the N-dimensional IFFT of a real input signal. /// /// The output is a onesided representation of the Hermitian symmetric time -/// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.ihfftn. +/// domain signal. See +/// https://pytorch.org/docs/master/fft.html#torch.fft.ihfftn. /// /// Example: /// ``` @@ -291,14 +310,16 @@ inline Tensor hfftn(const Tensor& self, /// auto t = torch::fft::hfft2(T); /// assert(t.is_complex() && t.size(1) == 65); /// ``` -inline Tensor ihfftn(const Tensor& self, - at::OptionalIntArrayRef s=c10::nullopt, - IntArrayRef dim={-2, -1}, - c10::optional norm=c10::nullopt) { +inline Tensor ihfftn( + const Tensor& self, + at::OptionalIntArrayRef s = c10::nullopt, + IntArrayRef dim = {-2, -1}, + c10::optional norm = c10::nullopt) { return torch::fft_ihfftn(self, s, dim, norm); } -/// Computes the discrete Fourier Transform sample frequencies for a signal of size n. +/// Computes the discrete Fourier Transform sample frequencies for a signal of +/// size n. /// /// See https://pytorch.org/docs/master/fft.html#torch.fft.fftfreq /// @@ -306,11 +327,11 @@ inline Tensor ihfftn(const Tensor& self, /// ``` /// auto frequencies = torch::fft::fftfreq(128, torch::kDouble); /// ``` -inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options={}) { +inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options = {}) { return torch::fft_fftfreq(n, d, options); } -inline Tensor fftfreq(int64_t n, const TensorOptions& options={}) { +inline Tensor fftfreq(int64_t n, const TensorOptions& options = {}) { return torch::fft_fftfreq(n, /*d=*/1.0, options); } @@ -341,7 +362,9 @@ inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { /// auto x = torch::randn({127, 4}); /// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x)); /// ``` -inline Tensor fftshift(const Tensor& x, at::OptionalIntArrayRef dim=c10::nullopt) { +inline Tensor fftshift( + const Tensor& x, + at::OptionalIntArrayRef dim = c10::nullopt) { return torch::fft_fftshift(x, dim); } @@ -356,8 +379,11 @@ inline Tensor fftshift(const Tensor& x, at::OptionalIntArrayRef dim=c10::nullopt /// auto unshift = torch::fft::ifftshift(shift); /// assert(torch::allclose(x, unshift)); /// ``` -inline Tensor ifftshift(const Tensor& x, at::OptionalIntArrayRef dim=c10::nullopt) { +inline Tensor ifftshift( + const Tensor& x, + at::OptionalIntArrayRef dim = c10::nullopt) { return torch::fft_ifftshift(x, dim); } -}} // torch::fft +} // namespace fft +} // namespace torch diff --git a/torch/csrc/api/include/torch/imethod.h b/torch/csrc/api/include/torch/imethod.h index 5ab9b838882141..1ee84b729c24fd 100644 --- a/torch/csrc/api/include/torch/imethod.h +++ b/torch/csrc/api/include/torch/imethod.h @@ -37,10 +37,11 @@ class TORCH_API IMethod { const std::vector& getArgumentNames() const; protected: - virtual void setArgumentNames(std::vector& argumentNames) const = 0; + virtual void setArgumentNames( + std::vector& argumentNames) const = 0; private: - mutable bool isArgumentNamesInitialized_ { false }; + mutable bool isArgumentNamesInitialized_{false}; mutable std::vector argumentNames_; }; diff --git a/torch/csrc/api/include/torch/jit.h b/torch/csrc/api/include/torch/jit.h index f2364a14b1a75d..703eed0d042480 100644 --- a/torch/csrc/api/include/torch/jit.h +++ b/torch/csrc/api/include/torch/jit.h @@ -3,8 +3,8 @@ #include #include -#include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 2fc8065f2e60c1..20abf1572f283c 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -24,7 +24,10 @@ inline std::tuple slogdet(const Tensor& input) { return torch::linalg_slogdet(input); } -inline std::tuple slogdet_out(Tensor& sign, Tensor& logabsdet, const Tensor& input) { +inline std::tuple slogdet_out( + Tensor& sign, + Tensor& logabsdet, + const Tensor& input) { return torch::linalg_slogdet_out(sign, logabsdet, input); } @@ -32,7 +35,10 @@ inline std::tuple eig(const Tensor& self) { return torch::linalg_eig(self); } -inline std::tuple eig_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self) { +inline std::tuple eig_out( + Tensor& eigvals, + Tensor& eigvecs, + const Tensor& self) { return torch::linalg_eig_out(eigvals, eigvecs, self); } @@ -44,11 +50,17 @@ inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { return torch::linalg_eigvals_out(result, self); } -inline std::tuple eigh(const Tensor& self, c10::string_view uplo) { +inline std::tuple eigh( + const Tensor& self, + c10::string_view uplo) { return torch::linalg_eigh(self, uplo); } -inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, c10::string_view uplo) { +inline std::tuple eigh_out( + Tensor& eigvals, + Tensor& eigvecs, + const Tensor& self, + c10::string_view uplo) { return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); } @@ -56,7 +68,10 @@ inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { return torch::linalg_eigvalsh(self, uplo); } -inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, c10::string_view uplo) { +inline Tensor& eigvalsh_out( + Tensor& result, + const Tensor& self, + c10::string_view uplo) { return torch::linalg_eigvalsh_out(result, self, uplo); } @@ -64,27 +79,47 @@ inline Tensor householder_product(const Tensor& input, const Tensor& tau) { return torch::linalg_householder_product(input, tau); } -inline Tensor& householder_product_out(Tensor& result, const Tensor& input, const Tensor& tau) { +inline Tensor& householder_product_out( + Tensor& result, + const Tensor& input, + const Tensor& tau) { return torch::linalg_householder_product_out(result, input, tau); } -inline std::tuple lu_factor(const Tensor& self, const bool pivot) { +inline std::tuple lu_factor( + const Tensor& self, + const bool pivot) { return torch::linalg_lu_factor(self, pivot); } -inline std::tuple lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot) { +inline std::tuple lu_factor_out( + Tensor& LU, + Tensor& pivots, + const Tensor& self, + const bool pivot) { return torch::linalg_lu_factor_out(LU, pivots, self, pivot); } -inline std::tuple lu(const Tensor& self, const bool pivot) { +inline std::tuple lu( + const Tensor& self, + const bool pivot) { return torch::linalg_lu(self, pivot); } -inline std::tuple lu_out(Tensor& P, Tensor& L, Tensor& U, const Tensor& self, const bool pivot) { +inline std::tuple lu_out( + Tensor& P, + Tensor& L, + Tensor& U, + const Tensor& self, + const bool pivot) { return torch::linalg_lu_out(P, L, U, self, pivot); } -inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { +inline std::tuple lstsq( + const Tensor& self, + const Tensor& b, + c10::optional cond, + c10::optional driver) { return torch::linalg_lstsq(self, b, cond, driver); } @@ -92,43 +127,100 @@ inline Tensor matrix_exp(const Tensor& self) { return torch::linalg_matrix_exp(self); } -inline Tensor norm(const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor norm( + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } -inline Tensor norm(const Tensor& self, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor norm( + const Tensor& self, + c10::string_view ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& norm_out(Tensor& result, const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { - return torch::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +inline Tensor& norm_out( + Tensor& result, + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { + return torch::linalg_norm_out( + result, self, opt_ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& norm_out(Tensor& result, const Tensor& self, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor& norm_out( + Tensor& result, + const Tensor& self, + c10::string_view ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } -inline Tensor vector_norm(const Tensor& self, Scalar ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor vector_norm( + const Tensor& self, + Scalar ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { - return torch::linalg_vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor matrix_norm(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional dtype) { +inline Tensor& vector_norm_out( + Tensor& result, + const Tensor& self, + Scalar ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { + return torch::linalg_vector_norm_out( + result, self, ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor matrix_norm( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional dtype) { return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); } -inline Tensor& matrix_norm_out(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional dtype, Tensor& result) { +inline Tensor& matrix_norm_out( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional dtype, + Tensor& result) { return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); } -inline Tensor matrix_norm(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional dtype) { +inline Tensor matrix_norm( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional dtype) { return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); } -inline Tensor& matrix_norm_out(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional dtype, Tensor& result) { +inline Tensor& matrix_norm_out( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional dtype, + Tensor& result) { return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); } @@ -144,31 +236,60 @@ inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { return torch::linalg_matrix_rank(input, tol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, const Tensor& tol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + const Tensor& tol, + bool hermitian) { return torch::linalg_matrix_rank(input, tol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, c10::optional atol, c10::optional rtol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + c10::optional atol, + c10::optional rtol, + bool hermitian) { return torch::linalg_matrix_rank(input, atol, rtol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, const c10::optional& atol, const c10::optional& rtol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + const c10::optional& atol, + const c10::optional& rtol, + bool hermitian) { return torch::linalg_matrix_rank(input, atol, rtol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, double tol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + double tol, + bool hermitian) { return torch::linalg_matrix_rank_out(result, input, tol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, const Tensor& tol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + const Tensor& tol, + bool hermitian) { return torch::linalg_matrix_rank_out(result, input, tol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, c10::optional atol, c10::optional rtol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + c10::optional atol, + c10::optional rtol, + bool hermitian) { return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, const c10::optional& atol, const c10::optional& rtol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + const c10::optional& atol, + const c10::optional& rtol, + bool hermitian) { return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); } @@ -184,15 +305,25 @@ inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) { return torch::linalg_pinv(input, rcond, hermitian); } -inline Tensor& pinv_out(Tensor& result, const Tensor& input, double rcond, bool hermitian) { +inline Tensor& pinv_out( + Tensor& result, + const Tensor& input, + double rcond, + bool hermitian) { return torch::linalg_pinv_out(result, input, rcond, hermitian); } -inline std::tuple qr(const Tensor& input, c10::string_view mode) { +inline std::tuple qr( + const Tensor& input, + c10::string_view mode) { return torch::linalg_qr(input, mode); } -inline std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& input, c10::string_view mode) { +inline std::tuple qr_out( + Tensor& Q, + Tensor& R, + const Tensor& input, + c10::string_view mode) { return torch::linalg_qr_out(Q, R, input, mode); } @@ -200,32 +331,62 @@ inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { return torch::linalg_solve(input, other, left); } -inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) { +inline Tensor& solve_out( + Tensor& result, + const Tensor& input, + const Tensor& other, + bool left) { return torch::linalg_solve_out(result, input, other, left); } -inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) { - return torch::linalg_solve_triangular(input, other, upper, left, unitriangular); +inline Tensor solve_triangular( + const Tensor& input, + const Tensor& other, + bool upper, + bool left, + bool unitriangular) { + return torch::linalg_solve_triangular( + input, other, upper, left, unitriangular); } -inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) { - return torch::linalg_solve_triangular_out(result, input, other, upper, left, unitriangular); +inline Tensor& solve_triangular_out( + Tensor& result, + const Tensor& input, + const Tensor& other, + bool upper, + bool left, + bool unitriangular) { + return torch::linalg_solve_triangular_out( + result, input, other, upper, left, unitriangular); } -inline std::tuple svd(const Tensor& input, bool full_matrices, c10::optional driver) { +inline std::tuple svd( + const Tensor& input, + bool full_matrices, + c10::optional driver) { return torch::linalg_svd(input, full_matrices, driver); } -inline std::tuple svd_out(Tensor& U, Tensor& S, Tensor& Vh, const Tensor& input, bool full_matrices, +inline std::tuple svd_out( + Tensor& U, + Tensor& S, + Tensor& Vh, + const Tensor& input, + bool full_matrices, c10::optional driver) { return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver); } -inline Tensor svdvals(const Tensor& input, c10::optional driver) { +inline Tensor svdvals( + const Tensor& input, + c10::optional driver) { return torch::linalg_svdvals(input, driver); } -inline Tensor& svdvals_out(Tensor& result, const Tensor& input, c10::optional driver) { +inline Tensor& svdvals_out( + Tensor& result, + const Tensor& input, + c10::optional driver) { return torch::linalg_svdvals_out(result, input, driver); } @@ -233,15 +394,22 @@ inline Tensor tensorinv(const Tensor& self, int64_t ind) { return torch::linalg_tensorinv(self, ind); } -inline Tensor& tensorinv_out(Tensor& result,const Tensor& self, int64_t ind) { +inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { return torch::linalg_tensorinv_out(result, self, ind); } -inline Tensor tensorsolve(const Tensor& self, const Tensor& other, OptionalIntArrayRef dims) { +inline Tensor tensorsolve( + const Tensor& self, + const Tensor& other, + OptionalIntArrayRef dims) { return torch::linalg_tensorsolve(self, other, dims); } -inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, OptionalIntArrayRef dims) { +inline Tensor& tensorsolve_out( + Tensor& result, + const Tensor& self, + const Tensor& other, + OptionalIntArrayRef dims) { return torch::linalg_tensorsolve_out(result, self, other, dims); } @@ -292,18 +460,25 @@ inline std::tuple slogdet(const Tensor& input) { return detail::slogdet(input); } -inline std::tuple slogdet_out(Tensor& sign, Tensor& logabsdet, const Tensor& input) { +inline std::tuple slogdet_out( + Tensor& sign, + Tensor& logabsdet, + const Tensor& input) { return detail::slogdet_out(sign, logabsdet, input); } -/// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian matrices +/// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian +/// matrices /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eig inline std::tuple eig(const Tensor& self) { return detail::eig(self); } -inline std::tuple eig_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self) { +inline std::tuple eig_out( + Tensor& eigvals, + Tensor& eigvecs, + const Tensor& self) { return detail::eig_out(eigvals, eigvecs, self); } @@ -321,11 +496,17 @@ inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { /// Computes eigenvalues and eigenvectors /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh -inline std::tuple eigh(const Tensor& self, c10::string_view uplo) { +inline std::tuple eigh( + const Tensor& self, + c10::string_view uplo) { return detail::eigh(self, uplo); } -inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, c10::string_view uplo) { +inline std::tuple eigh_out( + Tensor& eigvals, + Tensor& eigvecs, + const Tensor& self, + c10::string_view uplo) { return detail::eigh_out(eigvals, eigvecs, self, uplo); } @@ -336,22 +517,33 @@ inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { return detail::eigvalsh(self, uplo); } -inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, c10::string_view uplo) { +inline Tensor& eigvalsh_out( + Tensor& result, + const Tensor& self, + c10::string_view uplo) { return detail::eigvalsh_out(result, self, uplo); } /// Computes the product of Householder matrices /// -/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.householder_product +/// See +/// https://pytorch.org/docs/master/linalg.html#torch.linalg.householder_product inline Tensor householder_product(const Tensor& input, const Tensor& tau) { return detail::householder_product(input, tau); } -inline Tensor& householder_product_out(Tensor& result, const Tensor& input, const Tensor& tau) { +inline Tensor& householder_product_out( + Tensor& result, + const Tensor& input, + const Tensor& tau) { return detail::householder_product_out(result, input, tau); } -inline std::tuple lstsq(const Tensor& self, const Tensor& b, c10::optional cond, c10::optional driver) { +inline std::tuple lstsq( + const Tensor& self, + const Tensor& b, + c10::optional cond, + c10::optional driver) { return detail::lstsq(self, b, cond, driver); } @@ -363,86 +555,179 @@ inline Tensor matrix_exp(const Tensor& input) { } // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") -inline Tensor linalg_norm(const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor linalg_norm( + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") -inline Tensor linalg_norm(const Tensor& self, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor linalg_norm( + const Tensor& self, + c10::string_view ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); } -// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out instead.") -inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out +// instead.") +inline Tensor& linalg_norm_out( + Tensor& result, + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); } -// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out instead.") -inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out +// instead.") +inline Tensor& linalg_norm_out( + Tensor& result, + const Tensor& self, + c10::string_view ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } /// Computes the LU factorization with partial pivoting /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu_factor -inline std::tuple lu_factor(const Tensor& input, const bool pivot=true) { +inline std::tuple lu_factor( + const Tensor& input, + const bool pivot = true) { return detail::lu_factor(input, pivot); } -inline std::tuple lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot=true) { +inline std::tuple lu_factor_out( + Tensor& LU, + Tensor& pivots, + const Tensor& self, + const bool pivot = true) { return detail::lu_factor_out(LU, pivots, self, pivot); } /// Computes the LU factorization with partial pivoting /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu -inline std::tuple lu(const Tensor& input, const bool pivot=true) { +inline std::tuple lu( + const Tensor& input, + const bool pivot = true) { return detail::lu(input, pivot); } -inline std::tuple lu_out(Tensor& P, Tensor& L, Tensor& U, const Tensor& self, const bool pivot=true) { +inline std::tuple lu_out( + Tensor& P, + Tensor& L, + Tensor& U, + const Tensor& self, + const bool pivot = true) { return detail::lu_out(P, L, U, self, pivot); } -inline Tensor norm(const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor norm( + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } -inline Tensor norm(const Tensor& self, std::string ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor norm( + const Tensor& self, + std::string ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& norm_out(Tensor& result, const Tensor& self, const optional& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor& norm_out( + Tensor& result, + const Tensor& self, + const optional& opt_ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor& norm_out( + Tensor& result, + const Tensor& self, + std::string ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.vector_norm -inline Tensor vector_norm(const Tensor& self, Scalar ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { +inline Tensor vector_norm( + const Tensor& self, + Scalar ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype); } -inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { - return detail::vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +inline Tensor& vector_norm_out( + Tensor& result, + const Tensor& self, + Scalar ord, + OptionalIntArrayRef opt_dim, + bool keepdim, + optional opt_dtype) { + return detail::vector_norm_out( + result, self, ord, opt_dim, keepdim, opt_dtype); } /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_norm -inline Tensor matrix_norm(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional dtype) { +inline Tensor matrix_norm( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional dtype) { return detail::matrix_norm(self, ord, dim, keepdim, dtype); } -inline Tensor& matrix_norm_out(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional dtype, Tensor& result) { +inline Tensor& matrix_norm_out( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional dtype, + Tensor& result) { return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); } -inline Tensor matrix_norm(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional dtype) { +inline Tensor matrix_norm( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional dtype) { return detail::matrix_norm(self, ord, dim, keepdim, dtype); } -inline Tensor& matrix_norm_out(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional dtype, Tensor& result) { +inline Tensor& matrix_norm_out( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional dtype, + Tensor& result) { return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); } @@ -460,31 +745,60 @@ inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { return detail::matrix_rank(input, tol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, const Tensor& tol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + const Tensor& tol, + bool hermitian) { return detail::matrix_rank(input, tol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, c10::optional atol, c10::optional rtol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + c10::optional atol, + c10::optional rtol, + bool hermitian) { return detail::matrix_rank(input, atol, rtol, hermitian); } -inline Tensor matrix_rank(const Tensor& input, const c10::optional& atol, const c10::optional& rtol, bool hermitian) { +inline Tensor matrix_rank( + const Tensor& input, + const c10::optional& atol, + const c10::optional& rtol, + bool hermitian) { return detail::matrix_rank(input, atol, rtol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, double tol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + double tol, + bool hermitian) { return detail::matrix_rank_out(result, input, tol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, const Tensor& tol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + const Tensor& tol, + bool hermitian) { return detail::matrix_rank_out(result, input, tol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, c10::optional atol, c10::optional rtol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + c10::optional atol, + c10::optional rtol, + bool hermitian) { return detail::matrix_rank_out(result, input, atol, rtol, hermitian); } -inline Tensor& matrix_rank_out(Tensor& result, const Tensor& input, const c10::optional& atol, const c10::optional& rtol, bool hermitian) { +inline Tensor& matrix_rank_out( + Tensor& result, + const Tensor& input, + const c10::optional& atol, + const c10::optional& rtol, + bool hermitian) { return detail::matrix_rank_out(result, input, atol, rtol, hermitian); } @@ -500,24 +814,37 @@ inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) { /// Computes the pseudo-inverse /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv -inline Tensor pinv(const Tensor& input, double rcond=1e-15, bool hermitian=false) { +inline Tensor pinv( + const Tensor& input, + double rcond = 1e-15, + bool hermitian = false) { return detail::pinv(input, rcond, hermitian); } -inline Tensor& pinv_out(Tensor& result, const Tensor& input, double rcond=1e-15, bool hermitian=false) { +inline Tensor& pinv_out( + Tensor& result, + const Tensor& input, + double rcond = 1e-15, + bool hermitian = false) { return detail::pinv_out(result, input, rcond, hermitian); } /// Computes the QR decomposition /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.qr -inline std::tuple qr(const Tensor& input, c10::string_view mode="reduced") { +inline std::tuple qr( + const Tensor& input, + c10::string_view mode = "reduced") { // C++17 Change the initialisation to "reduced"sv // Same for qr_out return detail::qr(input, mode); } -inline std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& input, c10::string_view mode="reduced") { +inline std::tuple qr_out( + Tensor& Q, + Tensor& R, + const Tensor& input, + c10::string_view mode = "reduced") { return detail::qr_out(Q, R, input, mode); } @@ -559,8 +886,7 @@ inline Tensor& ldl_solve_out( const Tensor& pivots, const Tensor& B, bool hermitian) { - return torch::linalg_ldl_solve_out( - result, LD, pivots, B, hermitian); + return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian); } /// Computes a tensor `x` such that `matmul(input, x) = other`. @@ -570,41 +896,73 @@ inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { return detail::solve(input, other, left); } -inline Tensor& solve_out(Tensor& result, const Tensor& input, const Tensor& other, bool left) { +inline Tensor& solve_out( + Tensor& result, + const Tensor& input, + const Tensor& other, + bool left) { return detail::solve_out(result, input, other, left); } -/// Computes a solution of a linear system AX = B for input = A and other = B whenever A is square -/// upper or lower triangular and does not have zeros in the diagonal +/// Computes a solution of a linear system AX = B for input = A and other = B +/// whenever A is square upper or lower triangular and does not have zeros in +/// the diagonal /// -/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular -inline Tensor solve_triangular(const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) { +/// See +/// https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular +inline Tensor solve_triangular( + const Tensor& input, + const Tensor& other, + bool upper, + bool left, + bool unitriangular) { return detail::solve_triangular(input, other, upper, left, unitriangular); } -inline Tensor& solve_triangular_out(Tensor& result, const Tensor& input, const Tensor& other, bool upper, bool left, bool unitriangular) { - return detail::solve_triangular_out(result, input, other, upper, left, unitriangular); +inline Tensor& solve_triangular_out( + Tensor& result, + const Tensor& input, + const Tensor& other, + bool upper, + bool left, + bool unitriangular) { + return detail::solve_triangular_out( + result, input, other, upper, left, unitriangular); } /// Computes the singular values and singular vectors /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svd -inline std::tuple svd(const Tensor& input, bool full_matrices, c10::optional driver) { +inline std::tuple svd( + const Tensor& input, + bool full_matrices, + c10::optional driver) { return detail::svd(input, full_matrices, driver); } -inline std::tuple svd_out(Tensor& U, Tensor& S, Tensor& Vh, const Tensor& input, bool full_matrices, c10::optional driver) { +inline std::tuple svd_out( + Tensor& U, + Tensor& S, + Tensor& Vh, + const Tensor& input, + bool full_matrices, + c10::optional driver) { return detail::svd_out(U, S, Vh, input, full_matrices, driver); } /// Computes the singular values /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svdvals -inline Tensor svdvals(const Tensor& input, c10::optional driver) { +inline Tensor svdvals( + const Tensor& input, + c10::optional driver) { return detail::svdvals(input, driver); } -inline Tensor& svdvals_out(Tensor& result, const Tensor& input, c10::optional driver) { +inline Tensor& svdvals_out( + Tensor& result, + const Tensor& input, + c10::optional driver) { return detail::svdvals_out(result, input, driver); } @@ -636,15 +994,23 @@ inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { /// auto b = torch::randn(2*3, 4); /// auto x = torch::linalg::tensorsolve(a, b); /// ``` -inline Tensor tensorsolve(const Tensor& input, const Tensor& other, OptionalIntArrayRef dims) { +inline Tensor tensorsolve( + const Tensor& input, + const Tensor& other, + OptionalIntArrayRef dims) { return detail::tensorsolve(input, other, dims); } -inline Tensor& tensorsolve_out(Tensor& result, const Tensor& input, const Tensor& other, OptionalIntArrayRef dims) { +inline Tensor& tensorsolve_out( + Tensor& result, + const Tensor& input, + const Tensor& other, + OptionalIntArrayRef dims) { return detail::tensorsolve_out(result, input, other, dims); } -/// Computes a tensor `inverse_input` such that `dot(input, inverse_input) = eye(input.size(0))`. +/// Computes a tensor `inverse_input` such that `dot(input, inverse_input) = +/// eye(input.size(0))`. /// /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.inv inline Tensor inv(const Tensor& input) { @@ -655,4 +1021,5 @@ inline Tensor& inv_out(Tensor& result, const Tensor& input) { return detail::inv_out(result, input); } -}} // torch::linalg +} // namespace linalg +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/cloneable.h b/torch/csrc/api/include/torch/nn/cloneable.h index cc735d63cdd83f..bd41c8d727108b 100644 --- a/torch/csrc/api/include/torch/nn/cloneable.h +++ b/torch/csrc/api/include/torch/nn/cloneable.h @@ -50,8 +50,9 @@ class Cloneable : public virtual Module { "and not the constructor?"); for (const auto& parameter : named_parameters(/*recurse=*/false)) { auto& tensor = *parameter; - auto data = device && tensor.device() != *device ? - tensor.to(*device) : autograd::Variable(tensor).clone(); + auto data = device && tensor.device() != *device + ? tensor.to(*device) + : autograd::Variable(tensor).clone(); copy->parameters_[parameter.key()].set_data(data); } TORCH_CHECK( @@ -62,8 +63,9 @@ class Cloneable : public virtual Module { "and not the constructor?"); for (const auto& buffer : named_buffers(/*recurse=*/false)) { auto& tensor = *buffer; - auto data = device && tensor.device() != *device ? - tensor.to(*device) : autograd::Variable(tensor).clone(); + auto data = device && tensor.device() != *device + ? tensor.to(*device) + : autograd::Variable(tensor).clone(); copy->buffers_[buffer.key()].set_data(data); } TORCH_CHECK( diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h index 0a9db5c2d1260c..b148edc68173f4 100644 --- a/torch/csrc/api/include/torch/nn/functional.h +++ b/torch/csrc/api/include/torch/nn/functional.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,4 +15,3 @@ #include #include #include -#include diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 2258dd0c4317ee..4d8879c682d76f 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -1,13 +1,13 @@ #pragma once +#include +#include +#include #include -#include #include +#include #include -#include -#include #include -#include namespace torch { namespace nn { @@ -25,11 +25,12 @@ inline Tensor elu(Tensor input, double alpha, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.elu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.elu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ELUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::ELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -54,11 +55,12 @@ inline Tensor selu(Tensor input, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.selu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.selu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SELUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -73,26 +75,27 @@ inline Tensor selu(Tensor input, const SELUFuncOptions& options = {}) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor hardshrink(const Tensor& input, - double lambda) { +inline Tensor hardshrink(const Tensor& input, double lambda) { return torch::hardshrink(input, lambda); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hardshrink +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hardshrink /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::HardshrinkFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::HardshrinkFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42)); /// ``` -inline Tensor hardshrink(const Tensor& input, - const HardshrinkFuncOptions& options = {}) { +inline Tensor hardshrink( + const Tensor& input, + const HardshrinkFuncOptions& options = {}) { return detail::hardshrink(input, options.lambda()); } @@ -100,10 +103,11 @@ inline Tensor hardshrink(const Tensor& input, #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor hardtanh(Tensor input, - double min_val, - double max_val, - bool inplace) { +inline Tensor hardtanh( + Tensor input, + double min_val, + double max_val, + bool inplace) { if (inplace) { return torch::hardtanh_(input, min_val, max_val); } else { @@ -113,28 +117,29 @@ inline Tensor hardtanh(Tensor input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hardtanh +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hardtanh /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::HardtanhFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::HardtanhFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::hardtanh(x, F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); +/// F::hardtanh(x, +/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); /// ``` inline Tensor hardtanh(Tensor input, const HardtanhFuncOptions& options = {}) { - return detail::hardtanh(input, options.min_val(), options.max_val(), options.inplace()); + return detail::hardtanh( + input, options.min_val(), options.max_val(), options.inplace()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor leaky_relu(Tensor input, - double negative_slope, - bool inplace) { +inline Tensor leaky_relu(Tensor input, double negative_slope, bool inplace) { if (inplace) { return torch::leaky_relu_(input, negative_slope); } else { @@ -144,18 +149,22 @@ inline Tensor leaky_relu(Tensor input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.leaky_relu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.leaky_relu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LeakyReLUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::LeakyReLUFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::leaky_relu(x, F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); +/// F::leaky_relu(x, +/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); /// ``` -inline Tensor leaky_relu(Tensor input, const LeakyReLUFuncOptions& options = {}) { +inline Tensor leaky_relu( + Tensor input, + const LeakyReLUFuncOptions& options = {}) { return detail::leaky_relu(input, options.negative_slope(), options.inplace()); } @@ -169,12 +178,14 @@ inline Tensor logsigmoid(const Tensor& input) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor gumbel_softmax(const Tensor& logits, - double tau, - bool hard, - int dim) { - auto gumbels = -torch::empty_like(logits).exponential_().log(); // ~Gumbel(0,1) - gumbels = (logits + gumbels) / tau; // ~Gumbel(logits, tau) +inline Tensor gumbel_softmax( + const Tensor& logits, + double tau, + bool hard, + int dim) { + auto gumbels = + -torch::empty_like(logits).exponential_().log(); // ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau; // ~Gumbel(logits, tau) auto y_soft = gumbels.softmax(dim); torch::Tensor ret; @@ -191,27 +202,33 @@ inline Tensor gumbel_softmax(const Tensor& logits, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.gumbel_softmax +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.gumbel_softmax /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::GumbelSoftmaxFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::GumbelSoftmaxFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1)); /// ``` -inline Tensor gumbel_softmax(const Tensor& logits, const GumbelSoftmaxFuncOptions& options = {}) { - return detail::gumbel_softmax(logits, options.tau(), options.hard(), options.dim()); +inline Tensor gumbel_softmax( + const Tensor& logits, + const GumbelSoftmaxFuncOptions& options = {}) { + return detail::gumbel_softmax( + logits, options.tau(), options.hard(), options.dim()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor softmax(const Tensor& input, int64_t dim, - c10::optional dtype) { +inline Tensor softmax( + const Tensor& input, + int64_t dim, + c10::optional dtype) { Tensor ret; if (dtype == c10::nullopt) { @@ -225,11 +242,12 @@ inline Tensor softmax(const Tensor& input, int64_t dim, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softmax +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softmax /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SoftmaxFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SoftmaxFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -244,8 +262,10 @@ inline Tensor softmax(const Tensor& input, const SoftmaxFuncOptions& options) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor softmin(const Tensor& input, int64_t dim, - c10::optional dtype) { +inline Tensor softmin( + const Tensor& input, + int64_t dim, + c10::optional dtype) { Tensor ret; if (dtype == c10::nullopt) { @@ -259,11 +279,12 @@ inline Tensor softmin(const Tensor& input, int64_t dim, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softmin +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softmin /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SoftminFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SoftminFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -278,8 +299,10 @@ inline Tensor softmin(const Tensor& input, const SoftminFuncOptions& options) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor log_softmax(const Tensor& input, int64_t dim, - c10::optional dtype) { +inline Tensor log_softmax( + const Tensor& input, + int64_t dim, + c10::optional dtype) { Tensor ret; if (dtype == c10::nullopt) { @@ -293,18 +316,21 @@ inline Tensor log_softmax(const Tensor& input, int64_t dim, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.log_softmax +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.log_softmax /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LogSoftmaxFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::LogSoftmaxFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::log_softmax(input, LogSoftmaxFuncOptions(1)); /// ``` -inline Tensor log_softmax(const Tensor& input, const LogSoftmaxFuncOptions& options) { +inline Tensor log_softmax( + const Tensor& input, + const LogSoftmaxFuncOptions& options) { return detail::log_softmax(input, options.dim(), options.dtype()); } @@ -313,17 +339,20 @@ inline Tensor log_softmax(const Tensor& input, const LogSoftmaxFuncOptions& opti #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor glu(const Tensor& input, int64_t dim) { - TORCH_CHECK(input.dim() != 0, "glu does not suppport scalars because halving size must be even"); + TORCH_CHECK( + input.dim() != 0, + "glu does not suppport scalars because halving size must be even"); return torch::glu(input, dim); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.glu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.glu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::GLUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::GLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -380,11 +409,12 @@ inline Tensor relu(Tensor input, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.relu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.relu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ReLUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::ReLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -409,11 +439,12 @@ inline Tensor relu6(Tensor input, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.relu6 +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.relu6 /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ReLU6FuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::ReLU6FuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -428,11 +459,12 @@ inline Tensor relu6(Tensor input, const ReLU6FuncOptions& options = {}) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor rrelu(Tensor input, - double lower, - double upper, - bool training, - bool inplace) { +inline Tensor rrelu( + Tensor input, + double lower, + double upper, + bool training, + bool inplace) { if (inplace) { return torch::rrelu_(input, lower, upper, training); } else { @@ -442,11 +474,12 @@ inline Tensor rrelu(Tensor input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.rrelu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.rrelu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::RReLUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::RReLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -454,16 +487,19 @@ inline Tensor rrelu(Tensor input, /// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true)); /// ``` inline Tensor rrelu(Tensor input, const RReLUFuncOptions& options = {}) { - return detail::rrelu(input, options.lower(), options.upper(), options.training(), options.inplace()); + return detail::rrelu( + input, + options.lower(), + options.upper(), + options.training(), + options.inplace()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor celu(Tensor input, - double alpha, - bool inplace) { +inline Tensor celu(Tensor input, double alpha, bool inplace) { if (inplace) { return torch::celu_(input, alpha); } else { @@ -473,11 +509,12 @@ inline Tensor celu(Tensor input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.celu +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.celu /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::CELUFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::CELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -492,55 +529,55 @@ inline Tensor celu(Tensor input, const CELUFuncOptions& options = {}) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor softplus(const Tensor& input, - double beta, - double threshold) { +inline Tensor softplus(const Tensor& input, double beta, double threshold) { return torch::softplus(input, beta, threshold); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softplus +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softplus /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SoftplusFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SoftplusFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0)); /// ``` -inline Tensor softplus(const Tensor& input, - const SoftplusFuncOptions& options = {}) { +inline Tensor softplus( + const Tensor& input, + const SoftplusFuncOptions& options = {}) { return detail::softplus(input, options.beta(), options.threshold()); } - // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor softshrink(const Tensor& input, - double lambda) { +inline Tensor softshrink(const Tensor& input, double lambda) { return torch::softshrink(input, lambda); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softshrink +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.softshrink /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SoftshrinkFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SoftshrinkFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::softshrink(x, F::SoftshrinkFuncOptions(0.42)); /// ``` -inline Tensor softshrink(const Tensor& input, - const SoftshrinkFuncOptions& options = {}) { +inline Tensor softshrink( + const Tensor& input, + const SoftshrinkFuncOptions& options = {}) { return detail::softshrink(input, options.lambda()); } @@ -560,10 +597,11 @@ inline Tensor tanhshrink(const Tensor& input) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor threshold(Tensor input, - double threshold, - double value, - bool inplace) { +inline Tensor threshold( + Tensor input, + double threshold, + double value, + bool inplace) { if (inplace) { return torch::threshold_(input, threshold, value); } else { @@ -573,11 +611,12 @@ inline Tensor threshold(Tensor input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.threshold +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.threshold /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ThresholdFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::ThresholdFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -585,7 +624,8 @@ inline Tensor threshold(Tensor input, /// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true)); /// ``` inline Tensor threshold(Tensor input, const ThresholdFuncOptions& options) { - return detail::threshold(input, options.threshold(), options.value(), options.inplace()); + return detail::threshold( + input, options.threshold(), options.value(), options.inplace()); } // ============================================================================ @@ -593,30 +633,30 @@ inline Tensor threshold(Tensor input, const ThresholdFuncOptions& options) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple multi_head_attention_forward( - const Tensor& query, - const Tensor& key, - const Tensor& value, - int64_t embed_dim_to_check, - int64_t num_heads, - const Tensor& in_proj_weight, - const Tensor& in_proj_bias, - const Tensor& bias_k, - const Tensor& bias_v, - bool add_zero_attn, - double dropout_p, - const Tensor& out_proj_weight, - const Tensor& out_proj_bias, - bool training = true, - const Tensor& key_padding_mask = {}, - bool need_weights = true, - const Tensor& attn_mask = {}, - bool use_separate_proj_weight = false, - const Tensor& q_proj_weight = {}, - const Tensor& k_proj_weight = {}, - const Tensor& v_proj_weight = {}, - const Tensor& static_k = {}, - const Tensor& static_v = {}, - bool average_attn_weights = true) { + const Tensor& query, + const Tensor& key, + const Tensor& value, + int64_t embed_dim_to_check, + int64_t num_heads, + const Tensor& in_proj_weight, + const Tensor& in_proj_bias, + const Tensor& bias_k, + const Tensor& bias_v, + bool add_zero_attn, + double dropout_p, + const Tensor& out_proj_weight, + const Tensor& out_proj_bias, + bool training = true, + const Tensor& key_padding_mask = {}, + bool need_weights = true, + const Tensor& attn_mask = {}, + bool use_separate_proj_weight = false, + const Tensor& q_proj_weight = {}, + const Tensor& k_proj_weight = {}, + const Tensor& v_proj_weight = {}, + const Tensor& static_k = {}, + const Tensor& static_v = {}, + bool average_attn_weights = true) { namespace F = torch::nn::functional; const auto query_sizes = query.sizes(); @@ -627,8 +667,9 @@ inline std::tuple multi_head_attention_forward( TORCH_INTERNAL_ASSERT(key.sizes() == value.sizes()); const auto head_dim = embed_dim / num_heads; - TORCH_CHECK(head_dim * num_heads == embed_dim, - "embed_dim must be divisible by num_heads"); + TORCH_CHECK( + head_dim * num_heads == embed_dim, + "embed_dim must be divisible by num_heads"); const auto scaling = 1 / std::sqrt(head_dim); Tensor q, k, v; @@ -636,7 +677,7 @@ inline std::tuple multi_head_attention_forward( if (torch::equal(query, key) && torch::equal(key, value)) { // self-attention const auto chunks = - F::linear(query, in_proj_weight, in_proj_bias).chunk(3, /*dim=*/-1); + F::linear(query, in_proj_weight, in_proj_bias).chunk(3, /*dim=*/-1); q = chunks[0]; k = chunks[1]; v = chunks[2]; @@ -724,9 +765,18 @@ inline std::tuple multi_head_attention_forward( } if (in_proj_bias.defined()) { - q = F::linear(query, q_proj_weight_non_opt, in_proj_bias.slice(/*dim=*/0, 0, embed_dim)); - k = F::linear(key, k_proj_weight_non_opt, in_proj_bias.slice(/*dim=*/0, embed_dim, (embed_dim * 2))); - v = F::linear(value, v_proj_weight_non_opt, in_proj_bias.slice(/*dim=*/0, (embed_dim * 2))); + q = F::linear( + query, + q_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, 0, embed_dim)); + k = F::linear( + key, + k_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, embed_dim, (embed_dim * 2))); + v = F::linear( + value, + v_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, (embed_dim * 2))); } else { q = F::linear(query, q_proj_weight_non_opt, in_proj_bias); k = F::linear(key, k_proj_weight_non_opt, in_proj_bias); @@ -741,23 +791,22 @@ inline std::tuple multi_head_attention_forward( k = torch::cat({k, bias_k.repeat({1, bsz, 1})}); v = torch::cat({v, bias_v.repeat({1, bsz, 1})}); if (attn_mask_.defined()) { - attn_mask_ = torch::cat({ - attn_mask_, - torch::zeros( - {attn_mask_.size(0), 1}, - at::TensorOptions(attn_mask_.dtype()) - .device(attn_mask_.device()) - )}, /*dim=*/1 - ); + attn_mask_ = torch::cat( + {attn_mask_, + torch::zeros( + {attn_mask_.size(0), 1}, + at::TensorOptions(attn_mask_.dtype()) + .device(attn_mask_.device()))}, + /*dim=*/1); } if (key_padding_mask_.defined()) { - key_padding_mask_ = torch::cat({ - key_padding_mask_, - torch::zeros( - {key_padding_mask_.size(0), 1}, - at::TensorOptions(key_padding_mask_.dtype()) - .device(key_padding_mask_.device()) - )}, /*dim=*/1); + key_padding_mask_ = torch::cat( + {key_padding_mask_, + torch::zeros( + {key_padding_mask_.size(0), 1}, + at::TensorOptions(key_padding_mask_.dtype()) + .device(key_padding_mask_.device()))}, + /*dim=*/1); } } else { TORCH_CHECK(!static_k.defined(), "bias cannot be added to static key."); @@ -794,67 +843,72 @@ inline std::tuple multi_head_attention_forward( auto k_sizes = k.sizes().vec(); k_sizes[1] = 1; k = torch::cat( - { - k, - torch::zeros(k_sizes, at::TensorOptions(k.dtype()).device(k.device())) - }, /*dim=*/1); + {k, + torch::zeros( + k_sizes, at::TensorOptions(k.dtype()).device(k.device()))}, + /*dim=*/1); auto v_sizes = v.sizes().vec(); v_sizes[1] = 1; v = torch::cat( - { - v, - torch::zeros(v_sizes, at::TensorOptions(v.dtype()).device(v.device())) - }, /*dim=*/1); + {v, + torch::zeros( + v_sizes, at::TensorOptions(v.dtype()).device(v.device()))}, + /*dim=*/1); if (attn_mask_.defined()) { attn_mask_ = torch::cat( - { - attn_mask_, - torch::zeros( - {attn_mask_.size(0), 1}, - at::TensorOptions(attn_mask_.dtype()) - .device(attn_mask_.device())) - }, /*dim=*/1); + {attn_mask_, + torch::zeros( + {attn_mask_.size(0), 1}, + at::TensorOptions(attn_mask_.dtype()) + .device(attn_mask_.device()))}, + /*dim=*/1); } if (key_padding_mask_.defined()) { key_padding_mask_ = torch::cat( - { - key_padding_mask_, - torch::zeros( - {key_padding_mask_.size(0), 1}, - at::TensorOptions(key_padding_mask_.dtype()) - .device(key_padding_mask_.device())) - }, /*dim=*/1); + {key_padding_mask_, + torch::zeros( + {key_padding_mask_.size(0), 1}, + at::TensorOptions(key_padding_mask_.dtype()) + .device(key_padding_mask_.device()))}, + /*dim=*/1); } } auto attn_output_weights = torch::bmm(q, k.transpose(1, 2)); - TORCH_CHECK(attn_output_weights.sizes() == IntArrayRef({bsz * num_heads, tgt_len, src_len})); + TORCH_CHECK( + attn_output_weights.sizes() == + IntArrayRef({bsz * num_heads, tgt_len, src_len})); if (attn_mask_.defined()) { attn_mask_ = attn_mask_.unsqueeze(0); attn_output_weights += attn_mask_; } if (key_padding_mask_.defined()) { - attn_output_weights = attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); + attn_output_weights = + attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); attn_output_weights = AT_DISPATCH_FLOATING_TYPES( - attn_output_weights.scalar_type(), - "attn_output_weights.masked_fill", - [&]() { - return attn_output_weights.masked_fill( - key_padding_mask_.unsqueeze(1).unsqueeze(2), - -std::numeric_limits::infinity() - ); - } - ); - attn_output_weights = attn_output_weights.view({bsz * num_heads, tgt_len, src_len}); + attn_output_weights.scalar_type(), + "attn_output_weights.masked_fill", + [&]() { + return attn_output_weights.masked_fill( + key_padding_mask_.unsqueeze(1).unsqueeze(2), + -std::numeric_limits::infinity()); + }); + attn_output_weights = + attn_output_weights.view({bsz * num_heads, tgt_len, src_len}); } // NOLINTNEXTLINE(bugprone-argument-comment) attn_output_weights = F::softmax(attn_output_weights, /*dim=*/-1); - attn_output_weights = F::dropout(attn_output_weights, F::DropoutFuncOptions().p(dropout_p).training(training)); + attn_output_weights = F::dropout( + attn_output_weights, + F::DropoutFuncOptions().p(dropout_p).training(training)); auto attn_output = torch::bmm(attn_output_weights, v); - TORCH_CHECK(attn_output.sizes() == IntArrayRef({bsz * num_heads, tgt_len, head_dim})); - attn_output = attn_output.transpose(0, 1).contiguous().view({tgt_len, bsz, embed_dim}); + TORCH_CHECK( + attn_output.sizes() == IntArrayRef({bsz * num_heads, tgt_len, head_dim})); + attn_output = + attn_output.transpose(0, 1).contiguous().view({tgt_len, bsz, embed_dim}); attn_output = F::linear(attn_output, out_proj_weight, out_proj_bias); if (need_weights) { - attn_output_weights = attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); + attn_output_weights = + attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); if (average_attn_weights) { // average attention weights over heads attn_output_weights = attn_output_weights.sum(/*dim=*/1) / num_heads; @@ -868,34 +922,35 @@ inline std::tuple multi_head_attention_forward( #endif /* DOXYGEN_SHOULD_SKIP_THIS */ inline std::tuple multi_head_attention_forward( - const Tensor& query, const Tensor& key, const Tensor& value, - const MultiheadAttentionForwardFuncOptions& options) { + const Tensor& query, + const Tensor& key, + const Tensor& value, + const MultiheadAttentionForwardFuncOptions& options) { return detail::multi_head_attention_forward( - query, - key, - value, - options.embed_dim_to_check(), - options.num_heads(), - options.in_proj_weight(), - options.in_proj_bias(), - options.bias_k(), - options.bias_v(), - options.add_zero_attn(), - options.dropout_p(), - options.out_proj_weight(), - options.out_proj_bias(), - options.training(), - options.key_padding_mask(), - options.need_weights(), - options.attn_mask(), - options.use_separate_proj_weight(), - options.q_proj_weight(), - options.k_proj_weight(), - options.v_proj_weight(), - options.static_k(), - options.static_v(), - options.average_attn_weights() - ); + query, + key, + value, + options.embed_dim_to_check(), + options.num_heads(), + options.in_proj_weight(), + options.in_proj_bias(), + options.bias_k(), + options.bias_v(), + options.add_zero_attn(), + options.dropout_p(), + options.out_proj_weight(), + options.out_proj_bias(), + options.training(), + options.key_padding_mask(), + options.need_weights(), + options.attn_mask(), + options.use_separate_proj_weight(), + options.q_proj_weight(), + options.k_proj_weight(), + options.v_proj_weight(), + options.static_k(), + options.static_v(), + options.average_attn_weights()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/batchnorm.h b/torch/csrc/api/include/torch/nn/functional/batchnorm.h index 5603ec189e91e3..168a3a2e331c1f 100644 --- a/torch/csrc/api/include/torch/nn/functional/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/functional/batchnorm.h @@ -10,60 +10,68 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor batch_norm(const Tensor& input, - const Tensor& running_mean, - const Tensor& running_var, - Tensor weight, - Tensor bias, - bool training, - c10::optional momentum, - double eps) { +inline Tensor batch_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + Tensor weight, + Tensor bias, + bool training, + c10::optional momentum, + double eps) { if (training) { auto size = input.sizes(); int64_t size_prods = size[0]; for (const auto i : c10::irange(size.size() - 2)) { size_prods *= size[i + 2]; } - TORCH_CHECK(size_prods != 1, - "Expected more than 1 value per channel when training, got input size ", size); + TORCH_CHECK( + size_prods != 1, + "Expected more than 1 value per channel when training, got input size ", + size); } return torch::batch_norm( - input, - weight, - bias, - running_mean, - running_var, - training, - momentum.value(), - eps, - at::globalContext().userEnabledCuDNN()); + input, + weight, + bias, + running_mean, + running_var, + training, + momentum.value(), + eps, + at::globalContext().userEnabledCuDNN()); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.batch_norm +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.batch_norm /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::BatchNormFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::BatchNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::batch_norm(input, mean, variance, F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); +/// F::batch_norm(input, mean, variance, +/// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); /// ``` -inline Tensor batch_norm(const Tensor& input, const Tensor& running_mean, - const Tensor& running_var, const BatchNormFuncOptions& options = {}) { +inline Tensor batch_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + const BatchNormFuncOptions& options = {}) { return detail::batch_norm( - input, - running_mean, - running_var, - options.weight(), - options.bias(), - options.training(), - options.momentum(), - options.eps()); + input, + running_mean, + running_var, + options.weight(), + options.bias(), + options.training(), + options.momentum(), + options.eps()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/conv.h b/torch/csrc/api/include/torch/nn/functional/conv.h index 60637a8daac542..305a1ac77138e3 100644 --- a/torch/csrc/api/include/torch/nn/functional/conv.h +++ b/torch/csrc/api/include/torch/nn/functional/conv.h @@ -18,12 +18,11 @@ inline std::string padding_unwrap(enumtype::kSame) { return "same"; } -template +template IntArrayRef padding_unwrap(const ExpandingArray& array) { return array; } - inline Tensor conv1d( const Tensor& input, const Tensor& weight, @@ -32,25 +31,22 @@ inline Tensor conv1d( const Conv1dFuncOptions::padding_t& padding, ExpandingArray<1> dilation, int64_t groups) { - return c10::visit([&](const auto & pad) { - return torch::conv1d( - input, - weight, - bias, - stride, - padding_unwrap(pad), - dilation, - groups); - }, padding); + return c10::visit( + [&](const auto& pad) { + return torch::conv1d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::Conv1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::Conv1dFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -62,13 +58,13 @@ inline Tensor conv1d( const Tensor& weight, const Conv1dFuncOptions& options = {}) { return detail::conv1d( - input, - weight, - options.bias(), - options.stride(), - options.padding(), - options.dilation(), - options.groups()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -81,25 +77,22 @@ inline Tensor conv2d( const Conv2dFuncOptions::padding_t& padding, ExpandingArray<2> dilation, int64_t groups) { - return c10::visit([&](const auto & pad) { - return torch::conv2d( - input, - weight, - bias, - stride, - padding_unwrap(pad), - dilation, - groups); - }, padding); + return c10::visit( + [&](const auto& pad) { + return torch::conv2d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::Conv2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::Conv2dFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -111,13 +104,13 @@ inline Tensor conv2d( const Tensor& weight, const Conv2dFuncOptions& options = {}) { return detail::conv2d( - input, - weight, - options.bias(), - options.stride(), - options.padding(), - options.dilation(), - options.groups()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -130,25 +123,22 @@ inline Tensor conv3d( const Conv3dFuncOptions::padding_t& padding, ExpandingArray<3> dilation, int64_t groups) { - return c10::visit([&](const auto & pad) { - return torch::conv3d( - input, - weight, - bias, - stride, - padding_unwrap(pad), - dilation, - groups); - }, padding); + return c10::visit( + [&](const auto& pad) { + return torch::conv3d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::Conv3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::Conv3dFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -160,33 +150,40 @@ inline Tensor conv3d( const Tensor& weight, const Conv3dFuncOptions& options = {}) { return detail::conv3d( - input, - weight, - options.bias(), - options.stride(), - options.padding(), - options.dilation(), - options.groups()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor conv_transpose1d(const Tensor& input, const Tensor& weight, - const Tensor& bias, IntArrayRef stride, - IntArrayRef padding, IntArrayRef output_padding, - int64_t groups, IntArrayRef dilation) { +inline Tensor conv_transpose1d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { return torch::conv_transpose1d( - input, weight, bias, stride, padding, output_padding, groups, dilation); + input, weight, bias, stride, padding, output_padding, groups, dilation); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ConvTranspose1dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::ConvTranspose1dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -194,31 +191,44 @@ inline Tensor conv_transpose1d(const Tensor& input, const Tensor& weight, /// namespace F = torch::nn::functional; /// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); /// ``` -inline Tensor conv_transpose1d(const Tensor& input, const Tensor& weight, - const ConvTranspose1dFuncOptions& options = {}) { +inline Tensor conv_transpose1d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose1dFuncOptions& options = {}) { return detail::conv_transpose1d( - input, weight, - options.bias(), options.stride(), - options.padding(), options.output_padding(), - options.groups(), options.dilation()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor conv_transpose2d(const Tensor& input, const Tensor& weight, - const Tensor& bias, IntArrayRef stride, - IntArrayRef padding, IntArrayRef output_padding, - int64_t groups, IntArrayRef dilation) { +inline Tensor conv_transpose2d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { return torch::conv_transpose2d( - input, weight, bias, stride, padding, output_padding, groups, dilation); + input, weight, bias, stride, padding, output_padding, groups, dilation); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ConvTranspose2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::ConvTranspose2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -226,31 +236,44 @@ inline Tensor conv_transpose2d(const Tensor& input, const Tensor& weight, /// namespace F = torch::nn::functional; /// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); /// ``` -inline Tensor conv_transpose2d(const Tensor& input, const Tensor& weight, - const ConvTranspose2dFuncOptions& options = {}) { +inline Tensor conv_transpose2d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose2dFuncOptions& options = {}) { return detail::conv_transpose2d( - input, weight, - options.bias(), options.stride(), - options.padding(), options.output_padding(), - options.groups(), options.dilation()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor conv_transpose3d(const Tensor& input, const Tensor& weight, - const Tensor& bias, IntArrayRef stride, - IntArrayRef padding, IntArrayRef output_padding, - int64_t groups, IntArrayRef dilation) { +inline Tensor conv_transpose3d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { return torch::conv_transpose3d( - input, weight, bias, stride, padding, output_padding, groups, dilation); + input, weight, bias, stride, padding, output_padding, groups, dilation); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.conv_transpose3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::ConvTranspose3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::ConvTranspose3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -258,13 +281,19 @@ inline Tensor conv_transpose3d(const Tensor& input, const Tensor& weight, /// namespace F = torch::nn::functional; /// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); /// ``` -inline Tensor conv_transpose3d(const Tensor& input, const Tensor& weight, - const ConvTranspose3dFuncOptions& options = {}) { +inline Tensor conv_transpose3d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose3dFuncOptions& options = {}) { return detail::conv_transpose3d( - input, weight, - options.bias(), options.stride(), - options.padding(), options.output_padding(), - options.groups(), options.dilation()); + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/distance.h b/torch/csrc/api/include/torch/nn/functional/distance.h index 41d2d6fa11d163..27914017fef22d 100644 --- a/torch/csrc/api/include/torch/nn/functional/distance.h +++ b/torch/csrc/api/include/torch/nn/functional/distance.h @@ -13,25 +13,24 @@ inline Tensor cosine_similarity( const Tensor& x2, int64_t dim, double eps) { - return torch::cosine_similarity( - x1, - x2, - dim, - eps); + return torch::cosine_similarity(x1, x2, dim, eps); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cosine_similarity +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cosine_similarity /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::CosineSimilarityFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::CosineSimilarityFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cosine_similarity(input1, input2, F::CosineSimilarityFuncOptions().dim(1)); +/// F::cosine_similarity(input1, input2, +/// F::CosineSimilarityFuncOptions().dim(1)); /// ``` inline Tensor cosine_similarity( const Tensor& x1, @@ -50,20 +49,17 @@ inline Tensor pairwise_distance( double p, double eps, bool keepdim) { - return torch::pairwise_distance( - x1, - x2, - p, - eps, - keepdim); + return torch::pairwise_distance(x1, x2, p, eps, keepdim); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pairwise_distance +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pairwise_distance /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::PairwiseDistanceFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::PairwiseDistanceFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -75,7 +71,8 @@ inline Tensor pairwise_distance( const Tensor& x1, const Tensor& x2, const PairwiseDistanceFuncOptions& options = {}) { - return detail::pairwise_distance(x1, x2, options.p(), options.eps(), options.keepdim()); + return detail::pairwise_distance( + x1, x2, options.p(), options.eps(), options.keepdim()); } // ============================================================================ diff --git a/torch/csrc/api/include/torch/nn/functional/dropout.h b/torch/csrc/api/include/torch/nn/functional/dropout.h index 0906adc4686672..71db0361cf2bda 100644 --- a/torch/csrc/api/include/torch/nn/functional/dropout.h +++ b/torch/csrc/api/include/torch/nn/functional/dropout.h @@ -11,9 +11,9 @@ namespace detail { inline Tensor dropout(Tensor input, double p, bool training, bool inplace) { TORCH_CHECK( - p >= 0. && p <= 1., - "dropout probability has to be between 0 and 1, but got ", - p); + p >= 0. && p <= 1., + "dropout probability has to be between 0 and 1, but got ", + p); if (inplace) { return torch::dropout_(input, p, training); } else { @@ -24,19 +24,19 @@ inline Tensor dropout(Tensor input, double p, bool training, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::DropoutFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::DropoutFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::dropout(input, F::DropoutFuncOptions().p(0.5)); /// ``` -inline Tensor dropout(Tensor input, - const DropoutFuncOptions& options = {}) { +inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) { return detail::dropout( input, options.p(), options.training(), options.inplace()); } @@ -46,12 +46,17 @@ inline Tensor dropout(Tensor input, #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -template -inline Tensor _dropoutNd_helper(Tensor input, double p, bool training, bool inplace, const char* fn_name) { +template +inline Tensor _dropoutNd_helper( + Tensor input, + double p, + bool training, + bool inplace, + const char* fn_name) { TORCH_CHECK( - p >= 0. && p <= 1., - "dropout probability has to be between 0 and 1, but got ", - p); + p >= 0. && p <= 1., + "dropout probability has to be between 0 and 1, but got ", + p); auto inp_dim = input.dim(); auto is_batched = inp_dim == batched_dim; @@ -87,18 +92,20 @@ inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::Dropout2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::Dropout2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5)); /// ``` -inline Tensor dropout2d(Tensor input, +inline Tensor dropout2d( + Tensor input, const Dropout2dFuncOptions& options = {}) { return detail::dropout2d( input, options.p(), options.training(), options.inplace()); @@ -116,18 +123,20 @@ inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.dropout3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::Dropout3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::Dropout3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5)); /// ``` -inline Tensor dropout3d(Tensor input, +inline Tensor dropout3d( + Tensor input, const Dropout3dFuncOptions& options = {}) { return detail::dropout3d( input, options.p(), options.training(), options.inplace()); @@ -138,29 +147,40 @@ inline Tensor dropout3d(Tensor input, #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor alpha_dropout(Tensor input, double p, bool training, bool inplace) { +inline Tensor alpha_dropout( + Tensor input, + double p, + bool training, + bool inplace) { if (p < 0. || p > 1.) { - TORCH_CHECK(false, "dropout probability has to be between 0 and 1, but got ", p); + TORCH_CHECK( + false, "dropout probability has to be between 0 and 1, but got ", p); } - return inplace ? torch::alpha_dropout_(input, p, training) : torch::alpha_dropout(input, p, training); + return inplace ? torch::alpha_dropout_(input, p, training) + : torch::alpha_dropout(input, p, training); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.alpha_dropout +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.alpha_dropout /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AlphaDropoutFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::AlphaDropoutFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::alpha_dropout(input, F::AlphaDropoutFuncOptions().p(0.5).training(false)); +/// F::alpha_dropout(input, +/// F::AlphaDropoutFuncOptions().p(0.5).training(false)); /// ``` -inline Tensor alpha_dropout(Tensor input, const AlphaDropoutFuncOptions& options = {}) { - return detail::alpha_dropout(input, options.p(), options.training(), options.inplace()); +inline Tensor alpha_dropout( + Tensor input, + const AlphaDropoutFuncOptions& options = {}) { + return detail::alpha_dropout( + input, options.p(), options.training(), options.inplace()); } // ============================================================================ @@ -168,29 +188,41 @@ inline Tensor alpha_dropout(Tensor input, const AlphaDropoutFuncOptions& options #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor feature_alpha_dropout(Tensor input, double p, bool training, bool inplace) { +inline Tensor feature_alpha_dropout( + Tensor input, + double p, + bool training, + bool inplace) { if (p < 0. || p > 1.) { - TORCH_CHECK(false, "dropout probability has to be between 0 and 1, but got ", p); + TORCH_CHECK( + false, "dropout probability has to be between 0 and 1, but got ", p); } - return inplace ? torch::feature_alpha_dropout_(input, p, training) : torch::feature_alpha_dropout(input, p, training); + return inplace ? torch::feature_alpha_dropout_(input, p, training) + : torch::feature_alpha_dropout(input, p, training); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.feature_alpha_dropout +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.feature_alpha_dropout /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::FeatureAlphaDropoutFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::FeatureAlphaDropoutFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::feature_alpha_dropout(input, F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); +/// F::feature_alpha_dropout(input, +/// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); /// ``` -inline Tensor feature_alpha_dropout(Tensor input, const FeatureAlphaDropoutFuncOptions& options = {}) { - return detail::feature_alpha_dropout(input, options.p(), options.training(), options.inplace()); +inline Tensor feature_alpha_dropout( + Tensor input, + const FeatureAlphaDropoutFuncOptions& options = {}) { + return detail::feature_alpha_dropout( + input, options.p(), options.training(), options.inplace()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h index 0c8a814cd38d9a..dc5452b39907aa 100644 --- a/torch/csrc/api/include/torch/nn/functional/embedding.h +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -12,26 +12,34 @@ inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline void _no_grad_embedding_renorm_(Tensor weight, const Tensor& input, float max_norm, float norm_type) { +inline void _no_grad_embedding_renorm_( + Tensor weight, + const Tensor& input, + float max_norm, + float norm_type) { torch::NoGradGuard no_grad; torch::embedding_renorm_(weight, input, max_norm, norm_type); } -inline Tensor embedding(const Tensor& input, - const Tensor& weight, - c10::optional padding_idx, - c10::optional max_norm, - double norm_type, - bool scale_grad_by_freq, - bool sparse) { +inline Tensor embedding( + const Tensor& input, + const Tensor& weight, + c10::optional padding_idx, + c10::optional max_norm, + double norm_type, + bool scale_grad_by_freq, + bool sparse) { auto input_ = input; if (padding_idx != c10::nullopt) { if (*padding_idx > 0) { - TORCH_CHECK(*padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"); - } - else if (*padding_idx < 0) { - TORCH_CHECK(*padding_idx >= -weight.size(0), "Padding_idx must be within num_embedding"); + TORCH_CHECK( + *padding_idx < weight.size(0), + "Padding_idx must be within num_embeddings"); + } else if (*padding_idx < 0) { + TORCH_CHECK( + *padding_idx >= -weight.size(0), + "Padding_idx must be within num_embedding"); padding_idx = weight.size(0) + *padding_idx; } } else { @@ -43,31 +51,37 @@ inline Tensor embedding(const Tensor& input, // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); } - return torch::embedding(weight, input_, *padding_idx, scale_grad_by_freq, sparse); + return torch::embedding( + weight, input_, *padding_idx, scale_grad_by_freq, sparse); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.embedding +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.embedding /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::EmbeddingFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::EmbeddingFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::embedding(input, weight, F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// F::embedding(input, weight, +/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); /// ``` -inline Tensor embedding(const Tensor& input, const Tensor& weight, const EmbeddingFuncOptions& options = {}) { +inline Tensor embedding( + const Tensor& input, + const Tensor& weight, + const EmbeddingFuncOptions& options = {}) { return detail::embedding( - input, - weight, - options.padding_idx(), - options.max_norm(), - options.norm_type(), - options.scale_grad_by_freq(), - options.sparse()); + input, + weight, + options.padding_idx(), + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.sparse()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -87,25 +101,47 @@ inline Tensor embedding_bag( auto input_ = input; auto offsets_ = offsets; auto per_sample_weights_ = per_sample_weights; - TORCH_CHECK(!per_sample_weights_.defined() || input_.sizes() == per_sample_weights_.sizes(), - "embedding_bag: If per_sample_weights (", per_sample_weights_.sizes(), ") is not null, then it must have the same shape as the input (", input_.sizes(), ")"); + TORCH_CHECK( + !per_sample_weights_.defined() || + input_.sizes() == per_sample_weights_.sizes(), + "embedding_bag: If per_sample_weights (", + per_sample_weights_.sizes(), + ") is not null, then it must have the same shape as the input (", + input_.sizes(), + ")"); if (input_.dim() == 2) { - TORCH_CHECK(!offsets_.defined(), - "If input is 2D, then offsets has to be null, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type Tensor"); - offsets_ = torch::arange(0, input_.numel(), input_.size(1), torch::TensorOptions().dtype(torch::kLong).device(input_.device())); + TORCH_CHECK( + !offsets_.defined(), + "If input is 2D, then offsets has to be null, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type Tensor"); + offsets_ = torch::arange( + 0, + input_.numel(), + input_.size(1), + torch::TensorOptions().dtype(torch::kLong).device(input_.device())); input_ = input_.reshape(-1); if (per_sample_weights_.defined()) { per_sample_weights_ = per_sample_weights_.reshape(-1); } } else if (input_.dim() == 1) { - TORCH_CHECK(offsets_.defined(), "offsets has to be a 1D Tensor but got null"); + TORCH_CHECK( + offsets_.defined(), "offsets has to be a 1D Tensor but got null"); TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor"); - TORCH_CHECK(offsets_[0].item() == 0, "offsets[0] has to be 0, i.e., the first sequence in the mini-batch has to start from position 0. However, got ", - offsets_[0].item()); - TORCH_CHECK(offsets_[-1].item() <= input_.size(0), "offsets[-1] can not be greater than input's length({", - input_.size(0), "}), but got offsets[-1] of {", offsets_[-1].item(), "}"); + TORCH_CHECK( + offsets_[0].item() == 0, + "offsets[0] has to be 0, i.e., the first sequence in the mini-batch has to start from position 0. However, got ", + offsets_[0].item()); + TORCH_CHECK( + offsets_[-1].item() <= input_.size(0), + "offsets[-1] can not be greater than input's length({", + input_.size(0), + "}), but got offsets[-1] of {", + offsets_[-1].item(), + "}"); } else { - TORCH_CHECK(false, "input has to be 1D or 2D Tensor, but got Tensor of dimension ", input_.dim()); + TORCH_CHECK( + false, + "input has to be 1D or 2D Tensor, but got Tensor of dimension ", + input_.dim()); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -116,7 +152,9 @@ inline Tensor embedding_bag( mode_enum = 1; } else if (c10::get_if(&mode)) { mode_enum = 2; - TORCH_CHECK(!scale_grad_by_freq, "max mode does not support scaling the gradient by the frequency"); + TORCH_CHECK( + !scale_grad_by_freq, + "max mode does not support scaling the gradient by the frequency"); TORCH_CHECK(!sparse, "max mode does not support sparse weights"); } else { TORCH_CHECK(false, "mode has to be one of sum, mean or max"); @@ -128,13 +166,13 @@ inline Tensor embedding_bag( } TORCH_CHECK( - !per_sample_weights_.defined() || c10::get_if(&mode), - "embedding_bag: per_sample_weights was not null. ", - "per_sample_weights is only supported for mode='kSum' (got mode='", - torch::enumtype::get_enum_name(mode), "').Please open a feature request on GitHub."); + !per_sample_weights_.defined() || c10::get_if(&mode), + "embedding_bag: per_sample_weights was not null. ", + "per_sample_weights is only supported for mode='kSum' (got mode='", + torch::enumtype::get_enum_name(mode), + "').Please open a feature request on GitHub."); - return std::get<0>( - torch::embedding_bag( + return std::get<0>(torch::embedding_bag( weight, input_, offsets_, @@ -148,30 +186,35 @@ inline Tensor embedding_bag( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.embedding_bag +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.embedding_bag /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::EmbeddingBagFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::EmbeddingBagFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); +/// F::embedding_bag(input, weight, +/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); /// ``` -inline Tensor embedding_bag(const Tensor& input, const Tensor& weight, const EmbeddingBagFuncOptions& options = {}) { +inline Tensor embedding_bag( + const Tensor& input, + const Tensor& weight, + const EmbeddingBagFuncOptions& options = {}) { return detail::embedding_bag( - input, - weight, - options.offsets(), - options.max_norm(), - options.norm_type(), - options.scale_grad_by_freq(), - options.mode(), - options.sparse(), - options.per_sample_weights(), - options.include_last_offset(), - options.padding_idx()); + input, + weight, + options.offsets(), + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.mode(), + options.sparse(), + options.per_sample_weights(), + options.include_last_offset(), + options.padding_idx()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/fold.h b/torch/csrc/api/include/torch/nn/functional/fold.h index 382c0458d6942b..cd47138e32e9e6 100644 --- a/torch/csrc/api/include/torch/nn/functional/fold.h +++ b/torch/csrc/api/include/torch/nn/functional/fold.h @@ -8,35 +8,34 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor fold(const Tensor& input, - ExpandingArray<2> output_size, - ExpandingArray<2> kernel_size, - ExpandingArray<2> dilation, - ExpandingArray<2> padding, - ExpandingArray<2> stride) { +inline Tensor fold( + const Tensor& input, + ExpandingArray<2> output_size, + ExpandingArray<2> kernel_size, + ExpandingArray<2> dilation, + ExpandingArray<2> padding, + ExpandingArray<2> stride) { if (input.dim() == 3 || input.dim() == 2) { return torch::col2im( - input, - output_size, - kernel_size, - dilation, - padding, - stride); + input, output_size, kernel_size, dilation, padding, stride); } else { TORCH_CHECK( false, "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported " - "(got ", input.dim(), "D)"); + "(got ", + input.dim(), + "D)"); } } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.fold +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.fold /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::FoldFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::FoldFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -45,45 +44,44 @@ inline Tensor fold(const Tensor& input, /// ``` inline Tensor fold(const Tensor& input, const FoldFuncOptions& options) { return detail::fold( - input, - options.output_size(), - options.kernel_size(), - options.dilation(), - options.padding(), - options.stride()); + input, + options.output_size(), + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor unfold(const Tensor& input, - ExpandingArray<2> kernel_size, - ExpandingArray<2> dilation, - ExpandingArray<2> padding, - ExpandingArray<2> stride) { +inline Tensor unfold( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> dilation, + ExpandingArray<2> padding, + ExpandingArray<2> stride) { if (input.dim() == 4) { - return torch::im2col( - input, - kernel_size, - dilation, - padding, - stride); + return torch::im2col(input, kernel_size, dilation, padding, stride); } else { TORCH_CHECK( false, "Input Error: Only 4D input Tensors are supported " - "(got ", input.dim(), "D)"); + "(got ", + input.dim(), + "D)"); } } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.unfold +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.unfold /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::UnfoldFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::UnfoldFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -91,7 +89,12 @@ inline Tensor unfold(const Tensor& input, /// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); /// ``` inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) { - return detail::unfold(input, options.kernel_size(), options.dilation(), options.padding(), options.stride()); + return detail::unfold( + input, + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/instancenorm.h b/torch/csrc/api/include/torch/nn/functional/instancenorm.h index ba93c8143b625b..bfa42a32f7940e 100644 --- a/torch/csrc/api/include/torch/nn/functional/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/functional/instancenorm.h @@ -8,34 +8,54 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor instance_norm(const Tensor& input, const Tensor& running_mean, - const Tensor& running_var, const Tensor& weight, const Tensor& bias, - bool use_input_stats, double momentum, double eps) { - +inline Tensor instance_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + const Tensor& weight, + const Tensor& bias, + bool use_input_stats, + double momentum, + double eps) { return torch::instance_norm( - input, weight, bias, running_mean, running_var, - use_input_stats, momentum, eps, at::globalContext().userEnabledCuDNN() - ); + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + at::globalContext().userEnabledCuDNN()); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.instance_norm +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.instance_norm /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::InstanceNormFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::InstanceNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::instance_norm(input, F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); +/// F::instance_norm(input, +/// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); /// ``` -inline Tensor instance_norm(const Tensor& input, const InstanceNormFuncOptions& options = {}) { +inline Tensor instance_norm( + const Tensor& input, + const InstanceNormFuncOptions& options = {}) { return detail::instance_norm( - input, options.running_mean(), - options.running_var(), options.weight(), options.bias(), - options.use_input_stats(), options.momentum(), options.eps()); + input, + options.running_mean(), + options.running_var(), + options.weight(), + options.bias(), + options.use_input_stats(), + options.momentum(), + options.eps()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/linear.h b/torch/csrc/api/include/torch/nn/functional/linear.h index 979c4944c9460f..ffeafcd712af04 100644 --- a/torch/csrc/api/include/torch/nn/functional/linear.h +++ b/torch/csrc/api/include/torch/nn/functional/linear.h @@ -6,21 +6,27 @@ namespace torch { namespace nn { namespace functional { -inline Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias=Tensor()) { - return torch::bilinear(input1, input2, weight, bias); +inline Tensor bilinear( + const Tensor& input1, + const Tensor& input2, + const Tensor& weight, + const Tensor& bias = Tensor()) { + return torch::bilinear(input1, input2, weight, bias); } // ============================================================================ -inline Tensor linear(const Tensor& input, const Tensor& weight, - const Tensor& bias = {}) { +inline Tensor linear( + const Tensor& input, + const Tensor& weight, + const Tensor& bias = {}) { if (input.dim() == 2 && bias.defined()) { // fused op is marginally faster return torch::addmm(bias, input, weight.t()); } else { auto output = input.matmul(weight.t()); if (bias.defined()) { - output += bias; + output += bias; } return output; } diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index f6c7ddc97e4291..99269d2c2c4f22 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -14,19 +14,17 @@ inline Tensor l1_loss( const Tensor& input, const Tensor& target, L1LossFuncOptions::reduction_t reduction) { - return torch::l1_loss( - input, - target, - enumtype::reduction_get_enum(reduction)); + return torch::l1_loss(input, target, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.l1_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.l1_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::L1LossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::L1LossFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -53,9 +51,10 @@ inline Tensor kl_div( torch::Reduction::Reduction reduction_enum; if (c10::get_if(&reduction)) { - TORCH_WARN("reduction: 'mean' divides the total loss by both the batch size and the support size." - "'batchmean' divides only by the batch size, and aligns with the KL div math definition." - "'mean' will be changed to behave the same as 'batchmean' in the next major release."); + TORCH_WARN( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release."); } // special case for batchmean @@ -76,22 +75,25 @@ inline Tensor kl_div( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.kl_div +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.kl_div /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::KLDivFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::KLDivFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::kl_div(input, target, F::KLDivFuncOptions.reduction(torch::kNone).log_target(false)); +/// F::kl_div(input, target, +/// F::KLDivFuncOptions.reduction(torch::kNone).log_target(false)); /// ``` inline Tensor kl_div( const Tensor& input, const Tensor& target, const KLDivFuncOptions& options = {}) { - return detail::kl_div(input, target, options.reduction(), options.log_target()); + return detail::kl_div( + input, target, options.reduction(), options.log_target()); } // ============================================================================ @@ -103,27 +105,31 @@ inline Tensor mse_loss( const Tensor& target, MSELossFuncOptions::reduction_t reduction) { if (!(target.sizes() == input.sizes())) { - TORCH_WARN("Using a target size (", target.sizes(), - ") that is different to the input size (", input.sizes(), "). ", - "This will likely lead to incorrect results due to broadcasting. ", - "Please ensure they have the same size."); + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); } - std::vector broadcast_tensors = torch::broadcast_tensors({input, target}); + std::vector broadcast_tensors = + torch::broadcast_tensors({input, target}); auto expanded_input = broadcast_tensors[0]; auto expanded_target = broadcast_tensors[1]; return torch::mse_loss( - expanded_input, - expanded_target, - enumtype::reduction_get_enum(reduction)); + expanded_input, expanded_target, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.mse_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.mse_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MSELossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MSELossFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -150,10 +156,14 @@ inline Tensor binary_cross_entropy( if (target.sizes() != input.sizes()) { TORCH_CHECK( - false, - "Using a target size (", target.sizes(), ") ", - "that is different to the input size (", input.sizes(), ") is deprecated. ", - "Please ensure they have the same size."); + false, + "Using a target size (", + target.sizes(), + ") ", + "that is different to the input size (", + input.sizes(), + ") is deprecated. ", + "Please ensure they have the same size."); } auto weight_ = weight; @@ -167,22 +177,26 @@ inline Tensor binary_cross_entropy( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.binary_cross_entropy +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.binary_cross_entropy /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::BinaryCrossEntropyFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::BinaryCrossEntropyFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::binary_cross_entropy(input, target, F::BinaryCrossEntropyFuncOptions().weight(weight)); +/// F::binary_cross_entropy(input, target, +/// F::BinaryCrossEntropyFuncOptions().weight(weight)); /// ``` inline Tensor binary_cross_entropy( const Tensor& input, const Tensor& target, const BinaryCrossEntropyFuncOptions& options = {}) { - return detail::binary_cross_entropy(input, target, options.weight(), options.reduction()); + return detail::binary_cross_entropy( + input, target, options.weight(), options.reduction()); } // ============================================================================ @@ -195,30 +209,31 @@ inline Tensor hinge_embedding_loss( double margin, HingeEmbeddingLossFuncOptions::reduction_t reduction) { return torch::hinge_embedding_loss( - input, - target, - margin, - enumtype::reduction_get_enum(reduction)); + input, target, margin, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hinge_embedding_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.hinge_embedding_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::HingeEmbeddingLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::HingeEmbeddingLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::hinge_embedding_loss(input, target, F::HingeEmbeddingLossFuncOptions().margin(2)); +/// F::hinge_embedding_loss(input, target, +/// F::HingeEmbeddingLossFuncOptions().margin(2)); /// ``` inline Tensor hinge_embedding_loss( const Tensor& input, const Tensor& target, const HingeEmbeddingLossFuncOptions& options = {}) { - return detail::hinge_embedding_loss(input, target, options.margin(), options.reduction()); + return detail::hinge_embedding_loss( + input, target, options.margin(), options.reduction()); } // ============================================================================ @@ -238,33 +253,41 @@ inline Tensor multi_margin_loss( } return torch::multi_margin_loss( - input, - target, - p, - margin, - weight, - enumtype::reduction_get_enum(reduction) - ); + input, + target, + p, + margin, + weight, + enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multi_margin_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multi_margin_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MultiMarginLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::MultiMarginLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multi_margin_loss(input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight)); +/// F::multi_margin_loss(input, target, +/// F::MultiMarginLossFuncOptions().margin(2).weight(weight)); /// ``` inline Tensor multi_margin_loss( const Tensor& input, const Tensor& target, const MultiMarginLossFuncOptions& options = {}) { - return detail::multi_margin_loss(input, target, options.p(), options.margin(), options.weight(), options.reduction()); + return detail::multi_margin_loss( + input, + target, + options.p(), + options.margin(), + options.weight(), + options.reduction()); } // ============================================================================ @@ -278,39 +301,42 @@ inline Tensor cosine_embedding_loss( double margin, CosineEmbeddingLossFuncOptions::reduction_t reduction) { return torch::cosine_embedding_loss( - input1, - input2, - target, - margin, - enumtype::reduction_get_enum(reduction)); + input1, input2, target, margin, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cosine_embedding_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cosine_embedding_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::CosineEmbeddingLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::CosineEmbeddingLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cosine_embedding_loss(input1, input2, target, F::CosineEmbeddingLossFuncOptions().margin(0.5)); +/// F::cosine_embedding_loss(input1, input2, target, +/// F::CosineEmbeddingLossFuncOptions().margin(0.5)); /// ``` inline Tensor cosine_embedding_loss( const Tensor& input1, const Tensor& input2, const Tensor& target, const CosineEmbeddingLossFuncOptions& options = {}) { - return detail::cosine_embedding_loss(input1, input2, target, options.margin(), options.reduction()); + return detail::cosine_embedding_loss( + input1, input2, target, options.margin(), options.reduction()); } // ============================================================================ -inline Tensor _smooth_l1_loss(const Tensor& input, const Tensor& target, double beta = 1.) { - auto t = torch::abs(input - target); - return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta); +inline Tensor _smooth_l1_loss( + const Tensor& input, + const Tensor& target, + double beta = 1.) { + auto t = torch::abs(input - target); + return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta); } #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -321,22 +347,33 @@ inline Tensor smooth_l1_loss( SmoothL1LossFuncOptions::reduction_t reduction, double beta = 1.) { if (target.sizes() != input.sizes()) { - TORCH_WARN("Using a target size (", target.sizes(), ") that is different to the input size (", input.sizes(), "). ", - "This will likely lead to incorrect results due to broadcasting. ", - "Please ensure they have the same size."); + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); } - std::vector expanded_tensors = torch::broadcast_tensors({input, target}); - return torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction), beta); + std::vector expanded_tensors = + torch::broadcast_tensors({input, target}); + return torch::smooth_l1_loss( + expanded_tensors[0], + expanded_tensors[1], + enumtype::reduction_get_enum(reduction), + beta); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.smooth_l1_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.smooth_l1_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SmoothL1LossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SmoothL1LossFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -361,33 +398,46 @@ inline Tensor huber_loss( HuberLossFuncOptions::reduction_t reduction, double delta = 1.) { if (target.sizes() != input.sizes()) { - TORCH_WARN("Using a target size (", target.sizes(), ") that is different to the input size (", input.sizes(), "). ", - "This will likely lead to incorrect results due to broadcasting. ", - "Please ensure they have the same size."); + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); } - std::vector expanded_tensors = torch::broadcast_tensors({input, target}); - return torch::huber_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction), delta); + std::vector expanded_tensors = + torch::broadcast_tensors({input, target}); + return torch::huber_loss( + expanded_tensors[0], + expanded_tensors[1], + enumtype::reduction_get_enum(reduction), + delta); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.huber_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.huber_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::HuberLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::HuberLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::huber_loss(input, target, F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5)); +/// F::huber_loss(input, target, +/// F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5)); /// ``` inline Tensor huber_loss( const Tensor& input, const Tensor& target, const HuberLossFuncOptions& options = {}) { - return detail::huber_loss(input, target, options.reduction(), options.delta()); + return detail::huber_loss( + input, target, options.reduction(), options.delta()); } // ============================================================================ @@ -399,23 +449,24 @@ inline Tensor multilabel_margin_loss( const Tensor& target, MultilabelMarginLossFuncOptions::reduction_t reduction) { return torch::multilabel_margin_loss( - input, - target, - enumtype::reduction_get_enum(reduction)); + input, target, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multilabel_margin_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multilabel_margin_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MultilabelMarginLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::MultilabelMarginLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multilabel_margin_loss(input, target, F::MultilabelMarginLossFuncOptions(torch::kNone)); +/// F::multilabel_margin_loss(input, target, +/// F::MultilabelMarginLossFuncOptions(torch::kNone)); /// ``` inline Tensor multilabel_margin_loss( const Tensor& input, @@ -433,23 +484,23 @@ inline Tensor soft_margin_loss( const Tensor& target, SoftMarginLossFuncOptions::reduction_t reduction) { return torch::soft_margin_loss( - input, - target, - enumtype::reduction_get_enum(reduction)); + input, target, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.soft_margin_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.soft_margin_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::SoftMarginLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::SoftMarginLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::soft_margin_loss(input, target, F::SoftMarginLossFuncOptions(torch::kNone)); +/// F::soft_margin_loss(input, target, +/// F::SoftMarginLossFuncOptions(torch::kNone)); /// ``` inline Tensor soft_margin_loss( const Tensor& input, @@ -467,7 +518,9 @@ inline Tensor multilabel_soft_margin_loss( const Tensor& target, const Tensor& weight, MultilabelSoftMarginLossFuncOptions::reduction_t reduction) { - auto loss = -(target * torch::log_sigmoid(input) + (1 - target) * torch::log_sigmoid(-input)); + auto loss = + -(target * torch::log_sigmoid(input) + + (1 - target) * torch::log_sigmoid(-input)); if (weight.defined()) { loss = loss * weight; } @@ -487,31 +540,33 @@ inline Tensor multilabel_soft_margin_loss( } else { ret = input; TORCH_INTERNAL_ASSERT( - false, - enumtype::get_enum_name(reduction), - " is not valid"); + false, enumtype::get_enum_name(reduction), " is not valid"); } return ret; } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multilabel_soft_margin_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.multilabel_soft_margin_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MultilabelSoftMarginLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for +/// `torch::nn::functional::MultilabelSoftMarginLossFuncOptions` class to learn +/// what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multilabel_soft_margin_loss(input, target, F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); +/// F::multilabel_soft_margin_loss(input, target, +/// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); /// ``` inline Tensor multilabel_soft_margin_loss( const Tensor& input, const Tensor& target, const MultilabelSoftMarginLossFuncOptions& options = {}) { - return detail::multilabel_soft_margin_loss(input, target, options.weight(), options.reduction()); + return detail::multilabel_soft_margin_loss( + input, target, options.weight(), options.reduction()); } // ============================================================================ @@ -540,16 +595,19 @@ inline Tensor triplet_margin_loss( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::TripletMarginLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::TripletMarginLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::triplet_margin_loss(anchor, positive, negative, F::TripletMarginLossFuncOptions().margin(1.0)); +/// F::triplet_margin_loss(anchor, positive, negative, +/// F::TripletMarginLossFuncOptions().margin(1.0)); /// ``` inline Tensor triplet_margin_loss( const Tensor& anchor, @@ -557,14 +615,14 @@ inline Tensor triplet_margin_loss( const Tensor& negative, const TripletMarginLossFuncOptions& options = {}) { return detail::triplet_margin_loss( - anchor, - positive, - negative, - options.margin(), - options.p(), - options.eps(), - options.swap(), - options.reduction()); + anchor, + positive, + negative, + options.margin(), + options.p(), + options.eps(), + options.swap(), + options.reduction()); } // ============================================================================ @@ -575,7 +633,8 @@ inline Tensor triplet_margin_with_distance_loss( const Tensor& anchor, const Tensor& positive, const Tensor& negative, - c10::optional distance_function, + c10::optional + distance_function, double margin, bool swap, TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) { @@ -611,25 +670,26 @@ inline Tensor triplet_margin_with_distance_loss( } else { ret = anchor; TORCH_INTERNAL_ASSERT( - false, - enumtype::get_enum_name(reduction), - " is not valid"); + false, enumtype::get_enum_name(reduction), " is not valid"); } return ret; } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for +/// `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, +/// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); /// ``` inline Tensor triplet_margin_with_distance_loss( const Tensor& anchor, @@ -637,133 +697,162 @@ inline Tensor triplet_margin_with_distance_loss( const Tensor& negative, const TripletMarginWithDistanceLossFuncOptions& options = {}) { return detail::triplet_margin_with_distance_loss( - anchor, - positive, - negative, - options.distance_function(), - options.margin(), - options.swap(), - options.reduction()); + anchor, + positive, + negative, + options.distance_function(), + options.margin(), + options.swap(), + options.reduction()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor ctc_loss(const Tensor& log_probs, - const Tensor& targets, - const Tensor& input_lengths, - const Tensor& target_lengths, - int64_t blank, - CTCLossFuncOptions::reduction_t reduction, - bool zero_infinity) { +inline Tensor ctc_loss( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths, + int64_t blank, + CTCLossFuncOptions::reduction_t reduction, + bool zero_infinity) { return torch::ctc_loss( - log_probs, - targets, - input_lengths, - target_lengths, - blank, - enumtype::reduction_get_enum(reduction), - zero_infinity); + log_probs, + targets, + input_lengths, + target_lengths, + blank, + enumtype::reduction_get_enum(reduction), + zero_infinity); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.ctc_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.ctc_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::CTCLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::CTCLossFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, F::CTCLossFuncOptions().reduction(torch::kNone)); -/// ``` -inline Tensor ctc_loss(const Tensor& log_probs, - const Tensor& targets, - const Tensor& input_lengths, - const Tensor& target_lengths, - const CTCLossFuncOptions& options = {}) { +/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, +/// F::CTCLossFuncOptions().reduction(torch::kNone)); +/// ``` +inline Tensor ctc_loss( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths, + const CTCLossFuncOptions& options = {}) { return detail::ctc_loss( - log_probs, - targets, - input_lengths, - target_lengths, - options.blank(), - options.reduction(), - options.zero_infinity()); + log_probs, + targets, + input_lengths, + target_lengths, + options.blank(), + options.reduction(), + options.zero_infinity()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor poisson_nll_loss(const Tensor& input, - const Tensor& target, - bool log_input, - bool full, - double eps, - PoissonNLLLossFuncOptions::reduction_t reduction) { +inline Tensor poisson_nll_loss( + const Tensor& input, + const Tensor& target, + bool log_input, + bool full, + double eps, + PoissonNLLLossFuncOptions::reduction_t reduction) { return torch::poisson_nll_loss( - input, target, - log_input, full, eps, enumtype::reduction_get_enum(reduction)); + input, + target, + log_input, + full, + eps, + enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.poisson_nll_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.poisson_nll_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::PoissonNLLLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::PoissonNLLLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::poisson_nll_loss(input, target, F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); +/// F::poisson_nll_loss(input, target, +/// F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); /// ``` -inline Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, - const PoissonNLLLossFuncOptions& options = {}) { +inline Tensor poisson_nll_loss( + const Tensor& input, + const Tensor& target, + const PoissonNLLLossFuncOptions& options = {}) { return detail::poisson_nll_loss( - input, target, - options.log_input(), options.full(), options.eps(), options.reduction()); + input, + target, + options.log_input(), + options.full(), + options.eps(), + options.reduction()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor margin_ranking_loss(const Tensor& input1, - const Tensor& input2, - const Tensor& target, - double margin, - MarginRankingLossFuncOptions::reduction_t reduction) { +inline Tensor margin_ranking_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + double margin, + MarginRankingLossFuncOptions::reduction_t reduction) { TORCH_CHECK( - input1.dim() == input2.dim() && input1.dim() == target.dim(), - "margin_ranking_loss : All input tensors should have same dimension but got sizes: " - "input1: ", input1.sizes(), ", input2: ", input2.sizes(), - ", target: ", target.sizes()); - return torch::margin_ranking_loss(input1, input2, target, margin, - enumtype::reduction_get_enum(reduction)); + input1.dim() == input2.dim() && input1.dim() == target.dim(), + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + "input1: ", + input1.sizes(), + ", input2: ", + input2.sizes(), + ", target: ", + target.sizes()); + return torch::margin_ranking_loss( + input1, input2, target, margin, enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.margin_ranking_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.margin_ranking_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MarginRankingLossFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::MarginRankingLossFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::margin_ranking_loss(input1, input2, target, F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); +/// F::margin_ranking_loss(input1, input2, target, +/// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); /// ``` -inline Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, - const Tensor& target, const MarginRankingLossFuncOptions& options = {}) { - return detail::margin_ranking_loss(input1, input2, target, options.margin(), options.reduction()); +inline Tensor margin_ranking_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + const MarginRankingLossFuncOptions& options = {}) { + return detail::margin_ranking_loss( + input1, input2, target, options.margin(), options.reduction()); } // ============================================================================ @@ -776,12 +865,18 @@ inline Tensor nll_loss( const Tensor& weight, int64_t ignore_index, const NLLLossFuncOptions::reduction_t reduction) { - if (input.dim() < 2){ + if (input.dim() < 2) { TORCH_CHECK(false, "Expected 2 or more dimensions (got ", input.dim(), ")"); } if (input.sizes()[0] != target.sizes()[0]) { - TORCH_CHECK(false, "Expected input batch_size (", input.sizes()[0], ") to match target batch_size (", target.sizes()[0], ")."); + TORCH_CHECK( + false, + "Expected input batch_size (", + input.sizes()[0], + ") to match target batch_size (", + target.sizes()[0], + ")."); } return torch::nll_loss_nd( @@ -794,27 +889,29 @@ inline Tensor nll_loss( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.nll_loss +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.nll_loss /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::NLLLossFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::NLLLossFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::nll_loss(input, target, F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// F::nll_loss(input, target, +/// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` inline Tensor nll_loss( const Tensor& input, const Tensor& target, const NLLLossFuncOptions& options = {}) { return detail::nll_loss( - input, - target, - options.weight(), - options.ignore_index(), - options.reduction()); + input, + target, + options.weight(), + options.ignore_index(), + options.reduction()); } // ============================================================================ @@ -839,16 +936,18 @@ inline Tensor cross_entropy( } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cross_entropy +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cross_entropy /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::CrossEntropyFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::CrossEntropyFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cross_entropy(input, target, F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// F::cross_entropy(input, target, +/// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` inline Tensor cross_entropy( const Tensor& input, @@ -868,37 +967,53 @@ inline Tensor cross_entropy( #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor binary_cross_entropy_with_logits( - const Tensor& input, const Tensor& target, const Tensor& weight, - BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction, const Tensor& pos_weight) { - - TORCH_CHECK(target.sizes() == input.sizes(), - "Target size (", target.sizes(), - ") must be the same as input size (", - input.sizes(), ")" - ); - - return torch::binary_cross_entropy_with_logits(input, target, - weight, pos_weight, enumtype::reduction_get_enum(reduction)); + const Tensor& input, + const Tensor& target, + const Tensor& weight, + BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction, + const Tensor& pos_weight) { + TORCH_CHECK( + target.sizes() == input.sizes(), + "Target size (", + target.sizes(), + ") must be the same as input size (", + input.sizes(), + ")"); + + return torch::binary_cross_entropy_with_logits( + input, + target, + weight, + pos_weight, + enumtype::reduction_get_enum(reduction)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.binary_cross_entropy_with_logits +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.binary_cross_entropy_with_logits /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::BinaryCrossEntropyWithLogitsFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for +/// `torch::nn::functional::BinaryCrossEntropyWithLogitsFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::binary_cross_entropy_with_logits(input, target, F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); +/// F::binary_cross_entropy_with_logits(input, target, +/// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); /// ``` inline Tensor binary_cross_entropy_with_logits( - const Tensor& input, const Tensor& target, - const BinaryCrossEntropyWithLogitsFuncOptions& options = {}) { - return detail::binary_cross_entropy_with_logits(input, target, - options.weight(), options.reduction(), options.pos_weight()); + const Tensor& input, + const Tensor& target, + const BinaryCrossEntropyWithLogitsFuncOptions& options = {}) { + return detail::binary_cross_entropy_with_logits( + input, + target, + options.weight(), + options.reduction(), + options.pos_weight()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h index b4a4df6e27fe7e..80e1a933dfb3c7 100644 --- a/torch/csrc/api/include/torch/nn/functional/normalization.h +++ b/torch/csrc/api/include/torch/nn/functional/normalization.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include +#include #include namespace torch { @@ -17,22 +17,23 @@ inline Tensor normalize( int64_t dim, double eps, c10::optional out) { - if (out == c10::nullopt) { - auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); - return input / denom; - } else { - auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); - return torch::div_out(*out, input, denom); - } + if (out == c10::nullopt) { + auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); + return input / denom; + } else { + auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); + return torch::div_out(*out, input, denom); + } } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.normalize +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.normalize /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::NormalizeFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::NormalizeFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -42,37 +43,46 @@ inline Tensor normalize( inline Tensor normalize( const Tensor& input, NormalizeFuncOptions options = {}) { - return detail::normalize(input, options.p(), options.dim(), options.eps(), options.out()); + return detail::normalize( + input, options.p(), options.dim(), options.eps(), options.out()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor layer_norm(const Tensor& input, - const std::vector& normalized_shape, - const Tensor& weight, - const Tensor& bias, - double eps) { - return torch::layer_norm(input, normalized_shape, weight, bias, eps); +inline Tensor layer_norm( + const Tensor& input, + const std::vector& normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps) { + return torch::layer_norm(input, normalized_shape, weight, bias, eps); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.layer_norm +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.layer_norm /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LayerNormFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::LayerNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); /// ``` -inline Tensor layer_norm(const Tensor& input, +inline Tensor layer_norm( + const Tensor& input, const LayerNormFuncOptions& options) { - return detail::layer_norm(input, options.normalized_shape(), options.weight(), options.bias(), options.eps()); + return detail::layer_norm( + input, + options.normalized_shape(), + options.weight(), + options.bias(), + options.eps()); } // ============================================================================ @@ -85,43 +95,59 @@ inline Tensor local_response_norm( double alpha, double beta, double k) { - auto dim = input.dim(); - TORCH_CHECK(dim >=3, "Expected 3D or higher dimensionality input (got ", dim, " dimensions)"); - auto div = input.mul(input).unsqueeze(1); - if (dim == 3) { - div = detail::pad(div, /*pad=*/{0, 0, size / 2, (size - 1) / 2}, /*mode=*/torch::kConstant, /*value=*/0); - div = detail::avg_pool2d( + auto dim = input.dim(); + TORCH_CHECK( + dim >= 3, + "Expected 3D or higher dimensionality input (got ", + dim, + " dimensions)"); + auto div = input.mul(input).unsqueeze(1); + if (dim == 3) { + div = detail::pad( div, - /*kernel_size=*/{size, 1}, - /*stride=*/1, - /*padding=*/0, - /*ceil_mode=*/false, - /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt).squeeze(1); - } else { - auto sizes = input.sizes(); - div = div.view({sizes[0], 1, sizes[1], sizes[2], -1}); - div = detail::pad(div, /*pad=*/{0, 0, 0, 0, size / 2, (size - 1) / 2}, /*mode=*/torch::kConstant, /*value=*/0); - div = detail::avg_pool3d( + /*pad=*/{0, 0, size / 2, (size - 1) / 2}, + /*mode=*/torch::kConstant, + /*value=*/0); + div = detail::avg_pool2d( + div, + /*kernel_size=*/{size, 1}, + /*stride=*/1, + /*padding=*/0, + /*ceil_mode=*/false, + /*count_include_pad=*/true, + /*divisor_override=*/c10::nullopt) + .squeeze(1); + } else { + auto sizes = input.sizes(); + div = div.view({sizes[0], 1, sizes[1], sizes[2], -1}); + div = detail::pad( div, - /*kernel_size=*/{size, 1, 1}, - /*stride=*/1, - /*padding=*/0, - /*ceil_mode=*/false, - /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt).squeeze(1); - div = div.view(sizes); - } - div = div.mul(alpha).add(k).pow(beta); - return input / div; + /*pad=*/{0, 0, 0, 0, size / 2, (size - 1) / 2}, + /*mode=*/torch::kConstant, + /*value=*/0); + div = detail::avg_pool3d( + div, + /*kernel_size=*/{size, 1, 1}, + /*stride=*/1, + /*padding=*/0, + /*ceil_mode=*/false, + /*count_include_pad=*/true, + /*divisor_override=*/c10::nullopt) + .squeeze(1); + div = div.view(sizes); + } + div = div.mul(alpha).add(k).pow(beta); + return input / div; } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.local_response_norm +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.local_response_norm /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LocalResponseNormFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::LocalResponseNormFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -132,7 +158,8 @@ inline Tensor local_response_norm( inline Tensor local_response_norm( const Tensor& input, const LocalResponseNormFuncOptions& options) { - return detail::local_response_norm(input, options.size(), options.alpha(), options.beta(), options.k()); + return detail::local_response_norm( + input, options.size(), options.alpha(), options.beta(), options.k()); } // ============================================================================ @@ -145,17 +172,23 @@ inline Tensor group_norm( const Tensor& weight, const Tensor& bias, double eps) { - return torch::group_norm(input, num_groups, weight, bias, eps, - at::globalContext().userEnabledCuDNN()); + return torch::group_norm( + input, + num_groups, + weight, + bias, + eps, + at::globalContext().userEnabledCuDNN()); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.group_norm +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.group_norm /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::GroupNormFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::GroupNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` @@ -166,11 +199,11 @@ inline Tensor group_norm( const Tensor& input, const GroupNormFuncOptions& options) { return detail::group_norm( - input, - options.num_groups(), - options.weight(), - options.bias(), - options.eps()); + input, + options.num_groups(), + options.weight(), + options.bias(), + options.eps()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/padding.h b/torch/csrc/api/include/torch/nn/functional/padding.h index 1b2f77626cdbbd..49150ce8271eec 100644 --- a/torch/csrc/api/include/torch/nn/functional/padding.h +++ b/torch/csrc/api/include/torch/nn/functional/padding.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include namespace torch { namespace nn { @@ -9,10 +9,11 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor pad(const Tensor& input, - IntArrayRef pad, - PadFuncOptions::mode_t mode, - double value) { +inline Tensor pad( + const Tensor& input, + IntArrayRef pad, + PadFuncOptions::mode_t mode, + double value) { const auto mode_enum = [&] { if (c10::get_if(&mode)) { return at::padding_mode::constant; @@ -35,16 +36,18 @@ inline Tensor pad(const Tensor& input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pad +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pad /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::PadFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::PadFuncOptions` class to +/// learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate)); +/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, +/// 2}).mode(torch::kReplicate)); /// ``` inline Tensor pad(const Tensor& input, const PadFuncOptions& options) { return detail::pad(input, options.pad(), options.mode(), options.value()); diff --git a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h index 32161d04d806f3..6b962cf814b105 100644 --- a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h @@ -8,13 +8,8 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor pixel_shuffle( - const Tensor& input, - int64_t upscale_factor) { - return torch::pixel_shuffle( - input, - upscale_factor - ); +inline Tensor pixel_shuffle(const Tensor& input, int64_t upscale_factor) { + return torch::pixel_shuffle(input, upscale_factor); } inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) { @@ -23,11 +18,12 @@ inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) { } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pixel_shuffle +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.pixel_shuffle /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::PixelShuffleFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::PixelShuffleFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h index ae325fc8113e45..c5b42c6f8e1e69 100644 --- a/torch/csrc/api/include/torch/nn/functional/pooling.h +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include namespace torch { namespace nn { @@ -11,53 +11,53 @@ namespace functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor avg_pool1d(const Tensor& input, - ExpandingArray<1> kernel_size, - ExpandingArray<1> stride, - ExpandingArray<1> padding, - bool ceil_mode, - bool count_include_pad) { +inline Tensor avg_pool1d( + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + bool ceil_mode, + bool count_include_pad) { return torch::avg_pool1d( - input, - kernel_size, - stride, - padding, - ceil_mode, - count_include_pad); + input, kernel_size, stride, padding, ceil_mode, count_include_pad); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AvgPool1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::AvgPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2)); /// ``` -inline Tensor avg_pool1d(const Tensor& input, const AvgPool1dFuncOptions& options) { +inline Tensor avg_pool1d( + const Tensor& input, + const AvgPool1dFuncOptions& options) { return avg_pool1d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor avg_pool2d(const Tensor& input, - ExpandingArray<2> kernel_size, - ExpandingArray<2> stride, - ExpandingArray<2> padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override) { +inline Tensor avg_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override) { return torch::avg_pool2d( input, kernel_size, @@ -70,37 +70,41 @@ inline Tensor avg_pool2d(const Tensor& input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AvgPool2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::AvgPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2)); /// ``` -inline Tensor avg_pool2d(const Tensor& input, const AvgPool2dFuncOptions& options) { +inline Tensor avg_pool2d( + const Tensor& input, + const AvgPool2dFuncOptions& options) { return detail::avg_pool2d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad(), - options.divisor_override()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor avg_pool3d(const Tensor& input, - ExpandingArray<3> kernel_size, - ExpandingArray<3> stride, - ExpandingArray<3> padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override) { +inline Tensor avg_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override) { return torch::avg_pool3d( input, kernel_size, @@ -113,62 +117,64 @@ inline Tensor avg_pool3d(const Tensor& input, } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.avg_pool3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AvgPool3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::AvgPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2)); /// ``` -inline Tensor avg_pool3d(const Tensor& input, const AvgPool3dFuncOptions& options) { +inline Tensor avg_pool3d( + const Tensor& input, + const AvgPool3dFuncOptions& options) { return detail::avg_pool3d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad(), - options.divisor_override()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor max_pool1d(const Tensor& input, - ExpandingArray<1> kernel_size, - ExpandingArray<1> stride, - ExpandingArray<1> padding, - ExpandingArray<1> dilation, - bool ceil_mode) { - return torch::max_pool1d( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); +inline Tensor max_pool1d( + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + ExpandingArray<1> dilation, + bool ceil_mode) { + return torch::max_pool1d( + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2)); /// ``` -inline Tensor max_pool1d(const Tensor& input, const MaxPool1dFuncOptions& options) { - return detail::max_pool1d( +inline Tensor max_pool1d( + const Tensor& input, + const MaxPool1dFuncOptions& options) { + return detail::max_pool1d( input, options.kernel_size(), options.stride(), @@ -180,32 +186,29 @@ inline Tensor max_pool1d(const Tensor& input, const MaxPool1dFuncOptions& option #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple max_pool1d_with_indices( - const Tensor& input, - ExpandingArray<1> kernel_size, - ExpandingArray<1> stride, - ExpandingArray<1> padding, - ExpandingArray<1> dilation, - bool ceil_mode) { + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + ExpandingArray<1> dilation, + bool ceil_mode) { return torch::max_pool1d_with_indices( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool1d_with_indices(x, F::MaxPool1dFuncOptions(3).stride(2)); /// ``` -inline std::tuple max_pool1d_with_indices(const Tensor& input, const MaxPool1dFuncOptions& options) { +inline std::tuple max_pool1d_with_indices( + const Tensor& input, + const MaxPool1dFuncOptions& options) { return detail::max_pool1d_with_indices( input, options.kernel_size(), @@ -217,35 +220,34 @@ inline std::tuple max_pool1d_with_indices(const Tensor& input, c #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor max_pool2d(const Tensor& input, - ExpandingArray<2> kernel_size, - ExpandingArray<2> stride, - ExpandingArray<2> padding, - ExpandingArray<2> dilation, - bool ceil_mode) { +inline Tensor max_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + ExpandingArray<2> dilation, + bool ceil_mode) { return torch::max_pool2d( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2)); /// ``` -inline Tensor max_pool2d(const Tensor& input, const MaxPool2dFuncOptions& options) { +inline Tensor max_pool2d( + const Tensor& input, + const MaxPool2dFuncOptions& options) { return detail::max_pool2d( input, options.kernel_size(), @@ -258,32 +260,29 @@ inline Tensor max_pool2d(const Tensor& input, const MaxPool2dFuncOptions& option #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple max_pool2d_with_indices( - const Tensor& input, - ExpandingArray<2> kernel_size, - ExpandingArray<2> stride, - ExpandingArray<2> padding, - ExpandingArray<2> dilation, - bool ceil_mode) { + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + ExpandingArray<2> dilation, + bool ceil_mode) { return torch::max_pool2d_with_indices( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool2d_with_indices(x, F::MaxPool2dFuncOptions(3).stride(2)); /// ``` -inline std::tuple max_pool2d_with_indices(const Tensor& input, const MaxPool2dFuncOptions& options) { +inline std::tuple max_pool2d_with_indices( + const Tensor& input, + const MaxPool2dFuncOptions& options) { return detail::max_pool2d_with_indices( input, options.kernel_size(), @@ -295,35 +294,34 @@ inline std::tuple max_pool2d_with_indices(const Tensor& input, c #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor max_pool3d(const Tensor& input, - ExpandingArray<3> kernel_size, - ExpandingArray<3> stride, - ExpandingArray<3> padding, - ExpandingArray<3> dilation, - bool ceil_mode) { +inline Tensor max_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + ExpandingArray<3> dilation, + bool ceil_mode) { return torch::max_pool3d( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_pool3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2)); /// ``` -inline Tensor max_pool3d(const Tensor& input, const MaxPool3dFuncOptions& options) { +inline Tensor max_pool3d( + const Tensor& input, + const MaxPool3dFuncOptions& options) { return detail::max_pool3d( input, options.kernel_size(), @@ -336,32 +334,29 @@ inline Tensor max_pool3d(const Tensor& input, const MaxPool3dFuncOptions& option #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple max_pool3d_with_indices( - const Tensor& input, - ExpandingArray<3> kernel_size, - ExpandingArray<3> stride, - ExpandingArray<3> padding, - ExpandingArray<3> dilation, - bool ceil_mode) { + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + ExpandingArray<3> dilation, + bool ceil_mode) { return torch::max_pool3d_with_indices( - input, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + input, kernel_size, stride, padding, dilation, ceil_mode); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_pool3d_with_indices(x, F::MaxPool3dFuncOptions(3).stride(2)); /// ``` -inline std::tuple max_pool3d_with_indices(const Tensor& input, const MaxPool3dFuncOptions& options) { +inline std::tuple max_pool3d_with_indices( + const Tensor& input, + const MaxPool3dFuncOptions& options) { return detail::max_pool3d_with_indices( input, options.kernel_size(), @@ -376,12 +371,14 @@ inline std::tuple max_pool3d_with_indices(const Tensor& input, c #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple adaptive_max_pool1d_with_indices( - const Tensor& input, ExpandingArray<1> output_size) { + const Tensor& input, + ExpandingArray<1> output_size) { return torch::adaptive_max_pool1d(input, output_size); } } // namespace detail -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -390,22 +387,26 @@ inline std::tuple adaptive_max_pool1d_with_indices( /// F::adaptive_max_pool1d_with_indices(x, F::AdaptiveMaxPool1dFuncOptions(3)); /// ``` inline std::tuple adaptive_max_pool1d_with_indices( - const Tensor& input, const AdaptiveMaxPool1dFuncOptions& options) { + const Tensor& input, + const AdaptiveMaxPool1dFuncOptions& options) { return detail::adaptive_max_pool1d_with_indices(input, options.output_size()); } namespace detail { -inline Tensor adaptive_max_pool1d(const Tensor& input, +inline Tensor adaptive_max_pool1d( + const Tensor& input, ExpandingArray<1> output_size) { return std::get<0>(adaptive_max_pool1d_with_indices(input, output_size)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -413,7 +414,8 @@ inline Tensor adaptive_max_pool1d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3)); /// ``` -inline Tensor adaptive_max_pool1d(const Tensor& input, +inline Tensor adaptive_max_pool1d( + const Tensor& input, const AdaptiveMaxPool1dFuncOptions& options) { return detail::adaptive_max_pool1d(input, options.output_size()); } @@ -421,14 +423,17 @@ inline Tensor adaptive_max_pool1d(const Tensor& input, #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple adaptive_max_pool2d_with_indices( - const Tensor& input, ExpandingArrayWithOptionalElem<2> output_size) { - auto output_size_ = torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + const Tensor& input, + ExpandingArrayWithOptionalElem<2> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); return torch::adaptive_max_pool2d(input, output_size_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -437,23 +442,27 @@ inline std::tuple adaptive_max_pool2d_with_indices( /// F::adaptive_max_pool2d_with_indices(x, F::AdaptiveMaxPool2dFuncOptions(3)); /// ``` inline std::tuple adaptive_max_pool2d_with_indices( - const Tensor& input, const AdaptiveMaxPool2dFuncOptions& options) { + const Tensor& input, + const AdaptiveMaxPool2dFuncOptions& options) { return detail::adaptive_max_pool2d_with_indices(input, options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor adaptive_max_pool2d(const Tensor& input, +inline Tensor adaptive_max_pool2d( + const Tensor& input, ExpandingArrayWithOptionalElem<2> output_size) { return std::get<0>(adaptive_max_pool2d_with_indices(input, output_size)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -461,7 +470,8 @@ inline Tensor adaptive_max_pool2d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3)); /// ``` -inline Tensor adaptive_max_pool2d(const Tensor& input, +inline Tensor adaptive_max_pool2d( + const Tensor& input, const AdaptiveMaxPool2dFuncOptions& options) { return detail::adaptive_max_pool2d(input, options.output_size()); } @@ -469,14 +479,17 @@ inline Tensor adaptive_max_pool2d(const Tensor& input, #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline std::tuple adaptive_max_pool3d_with_indices( - const Tensor& input, ExpandingArrayWithOptionalElem<3> output_size) { - auto output_size_ = torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + const Tensor& input, + ExpandingArrayWithOptionalElem<3> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); return torch::adaptive_max_pool3d(input, output_size_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -485,23 +498,27 @@ inline std::tuple adaptive_max_pool3d_with_indices( /// F::adaptive_max_pool3d_with_indices(x, F::AdaptiveMaxPool3dFuncOptions(3)); /// ``` inline std::tuple adaptive_max_pool3d_with_indices( - const Tensor& input, const AdaptiveMaxPool3dFuncOptions& options) { + const Tensor& input, + const AdaptiveMaxPool3dFuncOptions& options) { return detail::adaptive_max_pool3d_with_indices(input, options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor adaptive_max_pool3d(const Tensor& input, +inline Tensor adaptive_max_pool3d( + const Tensor& input, ExpandingArrayWithOptionalElem<3> output_size) { return std::get<0>(adaptive_max_pool3d_with_indices(input, output_size)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_max_pool3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -509,26 +526,30 @@ inline Tensor adaptive_max_pool3d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3)); /// ``` -inline Tensor adaptive_max_pool3d(const Tensor& input, - const AdaptiveMaxPool3dFuncOptions& options) { - return detail::adaptive_max_pool3d(input, options.output_size()); +inline Tensor adaptive_max_pool3d( + const Tensor& input, + const AdaptiveMaxPool3dFuncOptions& options) { + return detail::adaptive_max_pool3d(input, options.output_size()); } // ============================================================================ #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor adaptive_avg_pool1d(const Tensor& input, +inline Tensor adaptive_avg_pool1d( + const Tensor& input, ExpandingArray<1> output_size) { return torch::adaptive_avg_pool1d(input, output_size); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveAvgPool1dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool1dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -536,25 +557,30 @@ inline Tensor adaptive_avg_pool1d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3)); /// ``` -inline Tensor adaptive_avg_pool1d(const Tensor& input, +inline Tensor adaptive_avg_pool1d( + const Tensor& input, const AdaptiveAvgPool1dFuncOptions& options) { return detail::adaptive_avg_pool1d(input, options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor adaptive_avg_pool2d(const Tensor& input, +inline Tensor adaptive_avg_pool2d( + const Tensor& input, ExpandingArrayWithOptionalElem<2> output_size) { - auto output_size_ = torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); return torch::adaptive_avg_pool2d(input, output_size_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveAvgPool2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -562,25 +588,30 @@ inline Tensor adaptive_avg_pool2d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3)); /// ``` -inline Tensor adaptive_avg_pool2d(const Tensor& input, +inline Tensor adaptive_avg_pool2d( + const Tensor& input, const AdaptiveAvgPool2dFuncOptions& options) { return detail::adaptive_avg_pool2d(input, options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor adaptive_avg_pool3d(const Tensor& input, +inline Tensor adaptive_avg_pool3d( + const Tensor& input, ExpandingArrayWithOptionalElem<3> output_size) { - auto output_size_ = torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); return torch::adaptive_avg_pool3d(input, output_size_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.adaptive_avg_pool3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::AdaptiveAvgPool3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: @@ -588,21 +619,25 @@ inline Tensor adaptive_avg_pool3d(const Tensor& input, /// namespace F = torch::nn::functional; /// F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3)); /// ``` -inline Tensor adaptive_avg_pool3d(const Tensor& input, +inline Tensor adaptive_avg_pool3d( + const Tensor& input, const AdaptiveAvgPool3dFuncOptions& options) { return detail::adaptive_avg_pool3d(input, options.output_size()); } // ============================================================================ -inline std::vector _unpool_output_size(const Tensor& input, - const IntArrayRef& kernel_size, const IntArrayRef& stride, - const IntArrayRef& padding, const c10::optional>& output_size) { +inline std::vector _unpool_output_size( + const Tensor& input, + const IntArrayRef& kernel_size, + const IntArrayRef& stride, + const IntArrayRef& padding, + const c10::optional>& output_size) { auto input_size = input.sizes(); std::vector default_size; for (const auto d : c10::irange(kernel_size.size())) { - default_size.push_back((input_size[d + 2] - 1) * stride[d] + - kernel_size[d] - 2 * padding[d]); + default_size.push_back( + (input_size[d + 2] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]); } if (!output_size) { return default_size; @@ -612,17 +647,31 @@ inline std::vector _unpool_output_size(const Tensor& input, output_size_ = IntArrayRef(*output_size).slice(2).vec(); } if (output_size_.size() != kernel_size.size()) { - TORCH_CHECK(false, "output_size should be a sequence containing ", - kernel_size.size(), " or ", kernel_size.size() + 2, - " elements, but it has a length of '", - output_size_.size(), "'"); + TORCH_CHECK( + false, + "output_size should be a sequence containing ", + kernel_size.size(), + " or ", + kernel_size.size() + 2, + " elements, but it has a length of '", + output_size_.size(), + "'"); } for (const auto d : c10::irange(kernel_size.size())) { const auto min_size = default_size[d] - stride[d]; const auto max_size = default_size[d] + stride[d]; if (!(min_size <= output_size_[d] && output_size_[d] <= max_size)) { - TORCH_CHECK(false, "invalid output_size ", output_size_, " (dim ", d, - " must be between ", min_size, " and ", max_size, ")"); + TORCH_CHECK( + false, + "invalid output_size ", + output_size_, + " (dim ", + d, + " must be between ", + min_size, + " and ", + max_size, + ")"); } } return output_size_; @@ -638,117 +687,125 @@ inline Tensor max_unpool1d( ExpandingArray<1> stride, ExpandingArray<1> padding, const c10::optional>& output_size) { - auto output_size_ = _unpool_output_size(input, kernel_size, - stride, padding, - output_size); + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); output_size_.push_back(1); - return torch::max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), - output_size_).squeeze(3); + return torch::max_unpool2d( + input.unsqueeze(3), indices.unsqueeze(3), output_size_) + .squeeze(3); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxUnpool1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxUnpool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); +/// F::max_unpool1d(x, indices, +/// F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); /// ``` -inline Tensor max_unpool1d(const Tensor& input, const Tensor& indices, +inline Tensor max_unpool1d( + const Tensor& input, + const Tensor& indices, const MaxUnpool1dFuncOptions& options) { return detail::max_unpool1d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - options.output_size()); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor max_unpool2d( - const Tensor& input, - const Tensor& indices, - ExpandingArray<2> kernel_size, - ExpandingArray<2> stride, - ExpandingArray<2> padding, - const c10::optional>& output_size) { - auto output_size_ = _unpool_output_size(input, kernel_size, - stride, padding, - output_size); + const Tensor& input, + const Tensor& indices, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + const c10::optional>& output_size) { + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); return torch::max_unpool2d(input, indices, output_size_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxUnpool2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxUnpool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::max_unpool2d(x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); +/// F::max_unpool2d(x, indices, +/// F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); /// ``` -inline Tensor max_unpool2d(const Tensor& input, const Tensor& indices, - const MaxUnpool2dFuncOptions& options) { +inline Tensor max_unpool2d( + const Tensor& input, + const Tensor& indices, + const MaxUnpool2dFuncOptions& options) { return detail::max_unpool2d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - options.output_size()); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor max_unpool3d( - const Tensor& input, - const Tensor& indices, - ExpandingArray<3> kernel_size, - ExpandingArray<3> stride, - ExpandingArray<3> padding, - const c10::optional>& output_size) { - auto output_size_ = _unpool_output_size(input, kernel_size, - stride, padding, - output_size); - - return torch::max_unpool3d(input, indices, output_size_, - stride, padding); + const Tensor& input, + const Tensor& indices, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + const c10::optional>& output_size) { + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); + + return torch::max_unpool3d(input, indices, output_size_, stride, padding); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool3d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.max_unpool3d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::MaxUnpool3dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::MaxUnpool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3)); /// ``` -inline Tensor max_unpool3d(const Tensor& input, const Tensor& indices, - const MaxUnpool3dFuncOptions& options) { +inline Tensor max_unpool3d( + const Tensor& input, + const Tensor& indices, + const MaxUnpool3dFuncOptions& options) { return detail::max_unpool3d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - options.output_size()); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); } // ============================================================================ @@ -759,40 +816,48 @@ inline std::tuple fractional_max_pool2d_with_indices( const Tensor& input, const ExpandingArray<2>& kernel_size, const c10::optional>& output_size, - const c10::optional>& output_ratio, + const c10::optional>& output_ratio, const Tensor& _random_samples) { if (output_size == c10::nullopt && output_ratio == c10::nullopt) { TORCH_CHECK( - false, - "fractional_max_pool2d requires specifying either ", - "an output_size or an output_ratio"); + false, + "fractional_max_pool2d requires specifying either ", + "an output_size or an output_ratio"); } c10::optional> output_size_ = output_size; if (output_size_ == c10::nullopt) { TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); - output_size_ = {(int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[0]), - (int64_t)(static_cast(input.size(-1)) * (*output_ratio.value())[1])}; + output_size_ = { + (int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[0]), + (int64_t)(static_cast(input.size(-1)) * (*output_ratio.value())[1])}; } Tensor _random_samples_ = _random_samples; if (!_random_samples_.defined()) { auto n_batch = input.dim() == 3 ? 1 : input.size(0); - _random_samples_ = torch::rand({n_batch, input.size(-3), 2}, torch::TensorOptions().dtype(input.dtype()).device(input.device())); + _random_samples_ = torch::rand( + {n_batch, input.size(-3), 2}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); } - return torch::fractional_max_pool2d(input, kernel_size, *output_size_, _random_samples_); + return torch::fractional_max_pool2d( + input, kernel_size, *output_size_, _random_samples_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool2d_with_indices(x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool2d_with_indices(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); /// ``` -inline std::tuple fractional_max_pool2d_with_indices(const Tensor& input, const FractionalMaxPool2dFuncOptions& options) { +inline std::tuple fractional_max_pool2d_with_indices( + const Tensor& input, + const FractionalMaxPool2dFuncOptions& options) { return detail::fractional_max_pool2d_with_indices( input, options.kernel_size(), @@ -803,26 +868,31 @@ inline std::tuple fractional_max_pool2d_with_indices(const Tenso #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor fractional_max_pool2d(const Tensor& input, - ExpandingArray<2> kernel_size, - c10::optional> output_size, - c10::optional> output_ratio, - const Tensor& _random_samples) { - return std::get<0>(fractional_max_pool2d_with_indices(input, kernel_size, output_size, - output_ratio, _random_samples)); +inline Tensor fractional_max_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + c10::optional> output_size, + c10::optional> output_ratio, + const Tensor& _random_samples) { + return std::get<0>(fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, _random_samples)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool2d(x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool2d(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); /// ``` -inline Tensor fractional_max_pool2d(const Tensor& input, const FractionalMaxPool2dFuncOptions& options) { +inline Tensor fractional_max_pool2d( + const Tensor& input, + const FractionalMaxPool2dFuncOptions& options) { return detail::fractional_max_pool2d( input, options.kernel_size(), @@ -837,42 +907,50 @@ inline std::tuple fractional_max_pool3d_with_indices( const Tensor& input, const ExpandingArray<3>& kernel_size, const c10::optional>& output_size, - const c10::optional>& output_ratio, + const c10::optional>& output_ratio, const Tensor& _random_samples) { if (output_size == c10::nullopt && output_ratio == c10::nullopt) { TORCH_CHECK( - false, - "fractional_max_pool3d requires specifying either ", - "an output_size or an output_ratio"); + false, + "fractional_max_pool3d requires specifying either ", + "an output_size or an output_ratio"); } c10::optional> output_size_ = output_size; if (output_size_ == c10::nullopt) { TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); - output_size_ = {(int64_t)(static_cast(input.size(-3)) * (*output_ratio.value())[0]), - (int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[1]), - (int64_t)(static_cast(input.size(-1)) * (*output_ratio.value())[2])}; + output_size_ = { + (int64_t)(static_cast(input.size(-3)) * (*output_ratio.value())[0]), + (int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[1]), + (int64_t)(static_cast(input.size(-1)) * (*output_ratio.value())[2])}; } Tensor _random_samples_ = _random_samples; if (!_random_samples_.defined()) { auto n_batch = input.dim() == 4 ? 1 : input.size(0); - _random_samples_ = torch::rand({n_batch, input.size(-4), 3}, torch::TensorOptions().dtype(input.dtype()).device(input.device())); + _random_samples_ = torch::rand( + {n_batch, input.size(-4), 3}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); } - return torch::fractional_max_pool3d(input, kernel_size, *output_size_, _random_samples_); + return torch::fractional_max_pool3d( + input, kernel_size, *output_size_, _random_samples_); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool3d_with_indices(x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool3d_with_indices(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); /// ``` -inline std::tuple fractional_max_pool3d_with_indices(const Tensor& input, const FractionalMaxPool3dFuncOptions& options) { +inline std::tuple fractional_max_pool3d_with_indices( + const Tensor& input, + const FractionalMaxPool3dFuncOptions& options) { return detail::fractional_max_pool3d_with_indices( input, options.kernel_size(), @@ -883,26 +961,31 @@ inline std::tuple fractional_max_pool3d_with_indices(const Tenso #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor fractional_max_pool3d(const Tensor& input, - ExpandingArray<3> kernel_size, - c10::optional> output_size, - c10::optional> output_ratio, - const Tensor& _random_samples) { - return std::get<0>(fractional_max_pool3d_with_indices(input, kernel_size, output_size, - output_ratio, _random_samples)); +inline Tensor fractional_max_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + c10::optional> output_size, + c10::optional> output_ratio, + const Tensor& _random_samples) { + return std::get<0>(fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, _random_samples)); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See the documentation for `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what /// optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool3d(x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool3d(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); /// ``` -inline Tensor fractional_max_pool3d(const Tensor& input, const FractionalMaxPool3dFuncOptions& options) { +inline Tensor fractional_max_pool3d( + const Tensor& input, + const FractionalMaxPool3dFuncOptions& options) { return detail::fractional_max_pool3d( input, options.kernel_size(), @@ -916,86 +999,96 @@ inline Tensor fractional_max_pool3d(const Tensor& input, const FractionalMaxPool #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor lp_pool1d( - const Tensor& input, - double norm_type, - ExpandingArray<1> kernel_size, - ExpandingArray<1> stride, - bool ceil_mode) { + const Tensor& input, + double norm_type, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + bool ceil_mode) { Tensor out = detail::avg_pool1d( - input.pow(norm_type), - kernel_size, - stride, - /*padding=*/0, - ceil_mode, - /*count_include_pad=*/true); + input.pow(norm_type), + kernel_size, + stride, + /*padding=*/0, + ceil_mode, + /*count_include_pad=*/true); - return (torch::sign(out) * relu(torch::abs(out))).mul((*kernel_size)[0]).pow(1. / norm_type); + return (torch::sign(out) * relu(torch::abs(out))) + .mul((*kernel_size)[0]) + .pow(1. / norm_type); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool1d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool1d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LPPool1dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::LPPool1dFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::lp_pool1d(x, F::LPPool1dFuncOptions(2, 3).stride(2)); /// ``` -inline Tensor lp_pool1d(const Tensor& input, const LPPool1dFuncOptions& options) { +inline Tensor lp_pool1d( + const Tensor& input, + const LPPool1dFuncOptions& options) { return detail::lp_pool1d( - input, - options.norm_type(), - options.kernel_size(), - options.stride(), - options.ceil_mode()); + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); } #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor lp_pool2d( - const Tensor& input, - double norm_type, - ExpandingArray<2> kernel_size, - ExpandingArray<2> stride, - bool ceil_mode) { + const Tensor& input, + double norm_type, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + bool ceil_mode) { int kw = (*kernel_size)[0]; int kh = (*kernel_size)[1]; Tensor out = detail::avg_pool2d( - input.pow(norm_type), - kernel_size, - stride, - /*padding=*/0, - ceil_mode, - /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt); + input.pow(norm_type), + kernel_size, + stride, + /*padding=*/0, + ceil_mode, + /*count_include_pad=*/true, + /*divisor_override=*/c10::nullopt); - return (torch::sign(out) * relu(torch::abs(out))).mul(kw * kh).pow(1. / norm_type); + return (torch::sign(out) * relu(torch::abs(out))) + .mul(kw * kh) + .pow(1. / norm_type); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool2d +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.lp_pool2d /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::LPPool2dFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::LPPool2dFuncOptions` class +/// to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; /// F::lp_pool2d(x, F::LPPool2dFuncOptions(2, {2, 3}).stride(2)); /// ``` -inline Tensor lp_pool2d(const Tensor& input, const LPPool2dFuncOptions& options) { +inline Tensor lp_pool2d( + const Tensor& input, + const LPPool2dFuncOptions& options) { return detail::lp_pool2d( - input, - options.norm_type(), - options.kernel_size(), - options.stride(), - options.ceil_mode()); + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/upsampling.h b/torch/csrc/api/include/torch/nn/functional/upsampling.h index fac6a9c6239bad..37d6fab99480c4 100644 --- a/torch/csrc/api/include/torch/nn/functional/upsampling.h +++ b/torch/csrc/api/include/torch/nn/functional/upsampling.h @@ -11,17 +11,18 @@ namespace nn { namespace functional { inline std::vector _interp_output_size( - int64_t dim, - std::tuple< - Tensor, - c10::optional>, - c10::optional>, - c10::optional> closed_over_args) { + int64_t dim, + std::tuple< + Tensor, + c10::optional>, + c10::optional>, + c10::optional> closed_over_args) { Tensor input; c10::optional> size; c10::optional> scale_factor; c10::optional recompute_scale_factor; - std::tie(input, size, scale_factor, recompute_scale_factor) = closed_over_args; + std::tie(input, size, scale_factor, recompute_scale_factor) = + closed_over_args; if (size == c10::nullopt && scale_factor == c10::nullopt) { TORCH_CHECK(false, "either size or scale_factor should be defined"); } @@ -31,9 +32,12 @@ inline std::vector _interp_output_size( if (scale_factor != c10::nullopt) { if (static_cast(scale_factor.value().size()) != dim) { TORCH_CHECK( - false, - "scale_factor shape must match input shape. ", - "Input is ", dim, "D, scale_factor size is ", torch::ArrayRef(*scale_factor)); + false, + "scale_factor shape must match input shape. ", + "Input is ", + dim, + "D, scale_factor size is ", + torch::ArrayRef(*scale_factor)); } } if (size != c10::nullopt) { @@ -54,17 +58,19 @@ inline std::vector _interp_output_size( } } if (is_float_scale_factor) { - TORCH_WARN("The default behavior for interpolate/upsample with float scale_factor changed " - "in 1.6.0 to align with other frameworks/libraries, and uses scale_factor directly, " - "instead of relying on the computed output size. " - "If you wish to keep the old behavior, please set recompute_scale_factor=True. " - "See the documentation of nn.Upsample for details. "); + TORCH_WARN( + "The default behavior for interpolate/upsample with float scale_factor changed " + "in 1.6.0 to align with other frameworks/libraries, and uses scale_factor directly, " + "instead of relying on the computed output size. " + "If you wish to keep the old behavior, please set recompute_scale_factor=True. " + "See the documentation of nn.Upsample for details. "); } } std::vector ret; for (const auto i : c10::irange(dim)) { - ret.emplace_back(static_cast(floor(static_cast(input.size(i + 2)) * scale_factors[i]))); + ret.emplace_back(static_cast( + floor(static_cast(input.size(i + 2)) * scale_factors[i]))); } return ret; } @@ -72,13 +78,13 @@ inline std::vector _interp_output_size( #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { inline Tensor interpolate( - const Tensor& input, - const c10::optional>& size, - const c10::optional>& scale_factor, - InterpolateFuncOptions::mode_t mode, - c10::optional align_corners, - c10::optional recompute_scale_factor, - bool antialias) { + const Tensor& input, + const c10::optional>& size, + const c10::optional>& scale_factor, + InterpolateFuncOptions::mode_t mode, + c10::optional align_corners, + c10::optional recompute_scale_factor, + bool antialias) { if (c10::get_if(&mode) || c10::get_if(&mode)) { if (align_corners != c10::nullopt) { @@ -90,7 +96,9 @@ inline Tensor interpolate( } else { if (align_corners == c10::nullopt) { TORCH_WARN( - "Default upsampling behavior when mode=", enumtype::get_enum_name(mode), " is changed " + "Default upsampling behavior when mode=", + enumtype::get_enum_name(mode), + " is changed " "to align_corners=False since 0.4.0. Please specify " "align_corners=True if the old behavior is desired. " "See the documentation of nn.Upsample for details."); @@ -99,7 +107,8 @@ inline Tensor interpolate( } auto scale_factor_len = input.dim() - 2; - std::vector> scale_factor_list(scale_factor_len, c10::nullopt); + std::vector> scale_factor_list( + scale_factor_len, c10::nullopt); if (scale_factor != c10::nullopt && !recompute_scale_factor.value_or(false)) { auto _scale_factor_repeated = *scale_factor; scale_factor_list = {}; @@ -108,38 +117,65 @@ inline Tensor interpolate( } } - if (antialias && !(input.dim() == 4 && (c10::get_if(&mode) ) ) ) { - TORCH_CHECK( - false, - "Anti-alias option is only supported for bilinear mode"); + if (antialias && + !(input.dim() == 4 && (c10::get_if(&mode)))) { + TORCH_CHECK(false, "Anti-alias option is only supported for bilinear mode"); } - auto closed_over_args = std::make_tuple(input, size, scale_factor, recompute_scale_factor); + auto closed_over_args = + std::make_tuple(input, size, scale_factor, recompute_scale_factor); if (input.dim() == 3 && c10::get_if(&mode)) { - return torch::upsample_nearest1d(input, _interp_output_size(1, closed_over_args), scale_factor_list.at(0)); + return torch::upsample_nearest1d( + input, + _interp_output_size(1, closed_over_args), + scale_factor_list.at(0)); } else if (input.dim() == 4 && c10::get_if(&mode)) { - return torch::upsample_nearest2d(input, _interp_output_size(2, closed_over_args), - scale_factor_list.at(0), scale_factor_list.at(1)); + return torch::upsample_nearest2d( + input, + _interp_output_size(2, closed_over_args), + scale_factor_list.at(0), + scale_factor_list.at(1)); } else if (input.dim() == 5 && c10::get_if(&mode)) { - return torch::upsample_nearest3d(input, _interp_output_size(3, closed_over_args), - scale_factor_list.at(0), scale_factor_list.at(1), scale_factor_list.at(2)); + return torch::upsample_nearest3d( + input, + _interp_output_size(3, closed_over_args), + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); } else if (input.dim() == 3 && c10::get_if(&mode)) { - return torch::_upsample_nearest_exact1d(input, _interp_output_size(1, closed_over_args), scale_factor_list.at(0)); + return torch::_upsample_nearest_exact1d( + input, + _interp_output_size(1, closed_over_args), + scale_factor_list.at(0)); } else if (input.dim() == 4 && c10::get_if(&mode)) { - return torch::_upsample_nearest_exact2d(input, _interp_output_size(2, closed_over_args), - scale_factor_list.at(0), scale_factor_list.at(1)); + return torch::_upsample_nearest_exact2d( + input, + _interp_output_size(2, closed_over_args), + scale_factor_list.at(0), + scale_factor_list.at(1)); } else if (input.dim() == 5 && c10::get_if(&mode)) { - return torch::_upsample_nearest_exact3d(input, _interp_output_size(3, closed_over_args), - scale_factor_list.at(0), scale_factor_list.at(1), scale_factor_list.at(2)); + return torch::_upsample_nearest_exact3d( + input, + _interp_output_size(3, closed_over_args), + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); } else if (input.dim() == 3 && c10::get_if(&mode)) { - return detail::adaptive_avg_pool1d(input, _interp_output_size(1, closed_over_args)); + return detail::adaptive_avg_pool1d( + input, _interp_output_size(1, closed_over_args)); } else if (input.dim() == 4 && c10::get_if(&mode)) { - return detail::adaptive_avg_pool2d(input, _interp_output_size(2, closed_over_args)); + return detail::adaptive_avg_pool2d( + input, _interp_output_size(2, closed_over_args)); } else if (input.dim() == 5 && c10::get_if(&mode)) { - return detail::adaptive_avg_pool3d(input, _interp_output_size(3, closed_over_args)); + return detail::adaptive_avg_pool3d( + input, _interp_output_size(3, closed_over_args)); } else if (input.dim() == 3 && c10::get_if(&mode)) { TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt); - return torch::upsample_linear1d(input, _interp_output_size(1, closed_over_args), *align_corners, scale_factor_list.at(0)); + return torch::upsample_linear1d( + input, + _interp_output_size(1, closed_over_args), + *align_corners, + scale_factor_list.at(0)); } else if (input.dim() == 3 && c10::get_if(&mode)) { TORCH_CHECK(false, "Got 3D input, but bilinear mode needs 4D input"); } else if (input.dim() == 3 && c10::get_if(&mode)) { @@ -149,11 +185,19 @@ inline Tensor interpolate( } else if (input.dim() == 4 && c10::get_if(&mode)) { TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt); if (antialias) { - return torch::_upsample_bilinear2d_aa(input, _interp_output_size(2, closed_over_args), *align_corners, - scale_factor_list.at(0), scale_factor_list.at(1)); + return torch::_upsample_bilinear2d_aa( + input, + _interp_output_size(2, closed_over_args), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); } - return torch::upsample_bilinear2d(input, _interp_output_size(2, closed_over_args), *align_corners, - scale_factor_list.at(0), scale_factor_list.at(1)); + return torch::upsample_bilinear2d( + input, + _interp_output_size(2, closed_over_args), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); } else if (input.dim() == 4 && c10::get_if(&mode)) { TORCH_CHECK(false, "Got 4D input, but trilinear mode needs 5D input"); } else if (input.dim() == 5 && c10::get_if(&mode)) { @@ -162,43 +206,60 @@ inline Tensor interpolate( TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input"); } else if (input.dim() == 5 && c10::get_if(&mode)) { TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt); - return torch::upsample_trilinear3d(input, _interp_output_size(3, closed_over_args), *align_corners, - scale_factor_list.at(0), scale_factor_list.at(1), scale_factor_list.at(2)); + return torch::upsample_trilinear3d( + input, + _interp_output_size(3, closed_over_args), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); } else if (input.dim() == 4 && c10::get_if(&mode)) { TORCH_INTERNAL_ASSERT(align_corners != c10::nullopt); - return torch::upsample_bicubic2d(input, _interp_output_size(2, closed_over_args), *align_corners, - scale_factor_list.at(0), scale_factor_list.at(1)); + return torch::upsample_bicubic2d( + input, + _interp_output_size(2, closed_over_args), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); } else { TORCH_CHECK( false, "Input Error: Only 3D, 4D and 5D input Tensors supported " - "(got ", input.dim(), "D) for the modes: nearest | linear | bilinear | bicubic | trilinear " - "(got ", enumtype::get_enum_name(mode), ")"); + "(got ", + input.dim(), + "D) for the modes: nearest | linear | bilinear | bicubic | trilinear " + "(got ", + enumtype::get_enum_name(mode), + ")"); } } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.interpolate +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.interpolate /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::InterpolateFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::InterpolateFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::interpolate(input, F::InterpolateFuncOptions().size({4}).mode(torch::kNearest)); +/// F::interpolate(input, +/// F::InterpolateFuncOptions().size({4}).mode(torch::kNearest)); /// ``` -inline Tensor interpolate(const Tensor& input, const InterpolateFuncOptions& options = {}) { +inline Tensor interpolate( + const Tensor& input, + const InterpolateFuncOptions& options = {}) { return detail::interpolate( - input, - options.size(), - options.scale_factor(), - options.mode(), - options.align_corners(), - options.recompute_scale_factor(), - options.antialias()); + input, + options.size(), + options.scale_factor(), + options.mode(), + options.align_corners(), + options.recompute_scale_factor(), + options.antialias()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/functional/vision.h b/torch/csrc/api/include/torch/nn/functional/vision.h index a19866fed6c646..2ebd61354f269a 100644 --- a/torch/csrc/api/include/torch/nn/functional/vision.h +++ b/torch/csrc/api/include/torch/nn/functional/vision.h @@ -24,20 +24,23 @@ inline Tensor affine_grid( "Expected a batch of 2D affine matrices of shape Nx2x3 for size ", size, ". Got ", - theta.sizes(), "."); + theta.sizes(), + "."); } else if (size.size() == 5) { TORCH_CHECK( theta.dim() == 3 && theta.size(-2) == 3 && theta.size(-1) == 4, "Expected a batch of 3D affine matrices of shape Nx3x4 for size ", size, ". Got ", - theta.sizes(), "."); + theta.sizes(), + "."); } else { TORCH_CHECK( false, "affine_grid only supports 4D and 5D sizes, ", "for 2D and 3D affine transforms, respectively. ", - "Got size ", size); + "Got size ", + size); } if (*std::min_element(size.begin(), size.end()) <= 0) { @@ -77,39 +80,43 @@ inline Tensor grid_sample( } if (!align_corners.has_value()) { - TORCH_WARN("Default grid_sample and affine_grid behavior has changed ", - "to align_corners=False since 1.3.0. Please specify ", - "align_corners=True if the old behavior is desired. ", - "See the documentation of grid_sample for details."); + TORCH_WARN( + "Default grid_sample and affine_grid behavior has changed ", + "to align_corners=False since 1.3.0. Please specify ", + "align_corners=True if the old behavior is desired. ", + "See the documentation of grid_sample for details."); align_corners = false; } - return torch::grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners.value()); + return torch::grid_sampler( + input, grid, mode_enum, padding_mode_enum, align_corners.value()); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ -/// See https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.grid_sample +/// See +/// https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.grid_sample /// about the exact behavior of this functional. /// -/// See the documentation for `torch::nn::functional::GridSampleFuncOptions` class to learn what -/// optional arguments are supported for this functional. +/// See the documentation for `torch::nn::functional::GridSampleFuncOptions` +/// class to learn what optional arguments are supported for this functional. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::grid_sample(input, grid, F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); +/// F::grid_sample(input, grid, +/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); /// ``` inline Tensor grid_sample( const Tensor& input, const Tensor& grid, const GridSampleFuncOptions& options = {}) { return detail::grid_sample( - input, - grid, - options.mode(), - options.padding_mode(), - options.align_corners()); + input, + grid, + options.mode(), + options.padding_mode(), + options.align_corners()); } } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/init.h b/torch/csrc/api/include/torch/nn/init.h index 20a773d16a7675..bffb76f24fe712 100644 --- a/torch/csrc/api/include/torch/nn/init.h +++ b/torch/csrc/api/include/torch/nn/init.h @@ -9,32 +9,30 @@ namespace nn { namespace init { using NonlinearityType = c10::variant< - enumtype::kLinear, - enumtype::kConv1D, - enumtype::kConv2D, - enumtype::kConv3D, - enumtype::kConvTranspose1D, - enumtype::kConvTranspose2D, - enumtype::kConvTranspose3D, - enumtype::kSigmoid, - enumtype::kTanh, - enumtype::kReLU, - enumtype::kLeakyReLU ->; - -using FanModeType = c10::variant< - enumtype::kFanIn, - enumtype::kFanOut ->; + enumtype::kLinear, + enumtype::kConv1D, + enumtype::kConv2D, + enumtype::kConv3D, + enumtype::kConvTranspose1D, + enumtype::kConvTranspose2D, + enumtype::kConvTranspose3D, + enumtype::kSigmoid, + enumtype::kTanh, + enumtype::kReLU, + enumtype::kLeakyReLU>; + +using FanModeType = c10::variant; } // namespace init -} // nn +} // namespace nn namespace nn { namespace init { /// Return the recommended gain value for the given nonlinearity function. -TORCH_API double calculate_gain(NonlinearityType nonlinearity, double param = 0.01); +TORCH_API double calculate_gain( + NonlinearityType nonlinearity, + double param = 0.01); /// Fills the given `tensor` with the provided `value` in-place, and returns it. /// No gradient will be recorded for this operation. @@ -118,7 +116,8 @@ TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0); /// No gradient will be recorded for this operation. TORCH_API Tensor zeros_(Tensor tensor); -TORCH_API std::tuple _calculate_fan_in_and_fan_out(const Tensor& tensor); +TORCH_API std::tuple _calculate_fan_in_and_fan_out( + const Tensor& tensor); } // namespace init } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h index 91fad2aab767b0..ff0348eb841bf3 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -385,15 +385,15 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Serializes the `Module` into the given `OutputArchive`. /// - /// If the `Module` contains unserializable submodules (e.g. `nn::Functional`), - /// those submodules are skipped when serializing. + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), those submodules are skipped when serializing. virtual void save(serialize::OutputArchive& archive) const; /// Deserializes the `Module` from the given `InputArchive`. /// - /// If the `Module` contains unserializable submodules (e.g. `nn::Functional`), - /// we don't check the existence of those submodules in the `InputArchive` when - /// deserializing. + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), we don't check the existence of those submodules in the + /// `InputArchive` when deserializing. virtual void load(serialize::InputArchive& archive); /// Streams a pretty representation of the `Module` into the given `stream`. @@ -414,8 +414,9 @@ class TORCH_API Module : public std::enable_shared_from_this { /// implementation of your `Module`. Registering it makes it available to /// methods such as `parameters()`, `clone()` or `to().` /// - /// Note that registering an undefined Tensor (e.g. `module.register_parameter("param", Tensor())`) - /// is allowed, and is equivalent to `module.register_parameter("param", None)` in Python API. + /// Note that registering an undefined Tensor (e.g. + /// `module.register_parameter("param", Tensor())`) is allowed, and is + /// equivalent to `module.register_parameter("param", None)` in Python API. /// /// \rst /// .. code-block:: cpp @@ -482,9 +483,11 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Replaces a registered submodule with this `Module`. /// - /// This takes care of the registration, if you used submodule members, you should + /// This takes care of the registration, if you used submodule members, you + /// should // assign the submodule as well, i.e. use as - /// module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4)); + /// module->submodule_ = module->replace_module("linear", + /// torch::nn::Linear(3, 4)); /// It only works when a module of the name is already registered. /// /// This is useful for replacing a module after initialization, e.g. @@ -497,7 +500,8 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Replaces a registered submodule with this `Module`. /// This method deals with `ModuleHolder`s. /// - /// This takes care of the registration, if you used submodule members, you should + /// This takes care of the registration, if you used submodule members, you + /// should // assign the submodule as well, i.e. use as /// module->submodule_ = module->replace_module("linear", linear_holder); /// It only works when a module of the name is already registered. @@ -516,26 +520,27 @@ class TORCH_API Module : public std::enable_shared_from_this { protected: /// The following three functions allow a module with default arguments in its /// forward method to be used in a Sequential module. - /// You should NEVER override these functions manually. Instead, you should use the - /// `FORWARD_HAS_DEFAULT_ARGS` macro. + /// You should NEVER override these functions manually. Instead, you should + /// use the `FORWARD_HAS_DEFAULT_ARGS` macro. virtual bool _forward_has_default_args() { return false; } virtual unsigned int _forward_num_required_args() { TORCH_CHECK( - false, - "torch::nn::Module subclass that has default arguments in `forward` method ", - "must override `_forward_num_required_args` method. Please use ", - "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + false, + "torch::nn::Module subclass that has default arguments in `forward` method ", + "must override `_forward_num_required_args` method. Please use ", + "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); } - virtual std::vector _forward_populate_default_args(std::vector&& arguments) { + virtual std::vector _forward_populate_default_args( + std::vector&& arguments) { TORCH_CHECK( - false, - "torch::nn::Module subclass that has default arguments in `forward` method ", - "must override `_forward_populate_default_args` method. Please use ", - "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + false, + "torch::nn::Module subclass that has default arguments in `forward` method ", + "must override `_forward_populate_default_args` method. Please use ", + "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); } /// The registered parameters of this `Module`. diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h index cbb0f9bc10ddf0..e037d52a853549 100644 --- a/torch/csrc/api/include/torch/nn/modules.h +++ b/torch/csrc/api/include/torch/nn/modules.h @@ -9,28 +9,28 @@ #include #include #include -#include #include #include +#include // Layers +#include #include #include -#include #include -#include #include +#include #include #include +#include #include #include +#include #include +#include #include #include -#include -#include -#include -#include -#include -#include #include +#include +#include +#include diff --git a/torch/csrc/api/include/torch/nn/modules/_functions.h b/torch/csrc/api/include/torch/nn/modules/_functions.h index 2c6b7ee4df1d02..5bf1ce2dcb2853 100644 --- a/torch/csrc/api/include/torch/nn/modules/_functions.h +++ b/torch/csrc/api/include/torch/nn/modules/_functions.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include namespace torch { namespace nn { @@ -12,12 +12,13 @@ namespace functions { class CrossMapLRN2d : public torch::autograd::Function { public: static torch::autograd::Variable forward( - torch::autograd::AutogradContext *ctx, + torch::autograd::AutogradContext* ctx, const torch::autograd::Variable& input, const CrossMapLRN2dOptions& options); static torch::autograd::variable_list backward( - torch::autograd::AutogradContext *ctx, torch::autograd::variable_list grad_output); + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output); }; } // namespace functions diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h index e4fc02f310d566..b52c2b12d8f658 100644 --- a/torch/csrc/api/include/torch/nn/modules/activation.h +++ b/torch/csrc/api/include/torch/nn/modules/activation.h @@ -1,10 +1,10 @@ #pragma once #include -#include #include #include #include +#include #include @@ -114,9 +114,9 @@ class TORCH_API HardshrinkImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `HardshrinkImpl`. /// See the documentation for `HardshrinkImpl` class to learn what methods it -/// provides, and examples of how to use `Hardshrink` with `torch::nn::HardshrinkOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Hardshrink` with +/// `torch::nn::HardshrinkOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Hardshrink); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -130,7 +130,8 @@ TORCH_MODULE(Hardshrink); /// /// Example: /// ``` -/// Hardtanh model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); +/// Hardtanh +/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API HardtanhImpl : public torch::nn::Cloneable { @@ -150,9 +151,9 @@ class TORCH_API HardtanhImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `HardtanhImpl`. /// See the documentation for `HardtanhImpl` class to learn what methods it -/// provides, and examples of how to use `Hardtanh` with `torch::nn::HardtanhOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Hardtanh` with +/// `torch::nn::HardtanhOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Hardtanh); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -186,9 +187,9 @@ class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `LeakyReLUImpl`. /// See the documentation for `LeakyReLUImpl` class to learn what methods it -/// provides, and examples of how to use `LeakyReLU` with `torch::nn::LeakyReLUOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LeakyReLU` with +/// `torch::nn::LeakyReLUOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LeakyReLU); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -244,9 +245,9 @@ class TORCH_API SoftmaxImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `SoftmaxImpl`. /// See the documentation for `SoftmaxImpl` class to learn what methods it -/// provides, and examples of how to use `Softmax` with `torch::nn::SoftmaxOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Softmax` with +/// `torch::nn::SoftmaxOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Softmax); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -280,9 +281,9 @@ class TORCH_API SoftminImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `SoftminImpl`. /// See the documentation for `SoftminImpl` class to learn what methods it -/// provides, and examples of how to use `Softmin` with `torch::nn::SoftminOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Softmin` with +/// `torch::nn::SoftminOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Softmin); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -301,7 +302,8 @@ TORCH_MODULE(Softmin); // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable { public: - explicit LogSoftmaxImpl(int64_t dim) : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {} + explicit LogSoftmaxImpl(int64_t dim) + : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {} explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_); Tensor forward(const Tensor& input); @@ -316,9 +318,9 @@ class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `LogSoftmaxImpl`. /// See the documentation for `LogSoftmaxImpl` class to learn what methods it -/// provides, and examples of how to use `LogSoftmax` with `torch::nn::LogSoftmaxOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LogSoftmax` with +/// `torch::nn::LogSoftmaxOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LogSoftmax); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -686,9 +688,9 @@ class TORCH_API SoftplusImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `SoftplusImpl`. /// See the documentation for `SoftplusImpl` class to learn what methods it -/// provides, and examples of how to use `Softplus` with `torch::nn::SoftplusOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Softplus` with +/// `torch::nn::SoftplusOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Softplus); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -722,9 +724,9 @@ class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `SoftshrinkImpl`. /// See the documentation for `SoftshrinkImpl` class to learn what methods it -/// provides, and examples of how to use `Softshrink` with `torch::nn::SoftshrinkOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Softshrink` with +/// `torch::nn::SoftshrinkOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Softshrink); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -810,7 +812,7 @@ TORCH_MODULE(Tanhshrink); class TORCH_API ThresholdImpl : public torch::nn::Cloneable { public: ThresholdImpl(double threshold, double value) - : ThresholdImpl(ThresholdOptions(threshold, value)) {} + : ThresholdImpl(ThresholdOptions(threshold, value)) {} explicit ThresholdImpl(const ThresholdOptions& options_); Tensor forward(Tensor input); @@ -826,9 +828,9 @@ class TORCH_API ThresholdImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `ThresholdImpl`. /// See the documentation for `ThresholdImpl` class to learn what methods it -/// provides, and examples of how to use `Threshold` with `torch::nn::ThresholdOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Threshold` with +/// `torch::nn::ThresholdOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Threshold); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -837,8 +839,8 @@ TORCH_MODULE(Threshold); /// See https://pytorch.org/docs/master/nn.html#torch.nn.MultiheadAttention /// to learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MultiheadAttentionOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MultiheadAttentionOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -846,18 +848,28 @@ TORCH_MODULE(Threshold); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API MultiheadAttentionImpl - : public torch::nn::Cloneable { + : public torch::nn::Cloneable { public: MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads) - : MultiheadAttentionImpl(MultiheadAttentionOptions(embed_dim, num_heads)) {} + : MultiheadAttentionImpl( + MultiheadAttentionOptions(embed_dim, num_heads)) {} explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_); - std::tuple forward(const Tensor& query, const Tensor& key, - const Tensor& value, const Tensor& key_padding_mask = {}, - bool need_weights = true, const Tensor& attn_mask = {}, - bool average_attn_weights = true); + std::tuple forward( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& key_padding_mask = {}, + bool need_weights = true, + const Tensor& attn_mask = {}, + bool average_attn_weights = true); + protected: - FORWARD_HAS_DEFAULT_ARGS({3, AnyValue(Tensor())}, {4, AnyValue(true)}, {5, AnyValue(Tensor())}, {6, AnyValue(true)}) + FORWARD_HAS_DEFAULT_ARGS( + {3, AnyValue(Tensor())}, + {4, AnyValue(true)}, + {5, AnyValue(Tensor())}, + {6, AnyValue(true)}) public: void reset() override; @@ -880,10 +892,10 @@ class TORCH_API MultiheadAttentionImpl }; /// A `ModuleHolder` subclass for `MultiheadAttentionImpl`. -/// See the documentation for `MultiheadAttentionImpl` class to learn what methods it -/// provides, and examples of how to use `MultiheadAttention` with `torch::nn::MultiheadAttentionOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `MultiheadAttentionImpl` class to learn what +/// methods it provides, and examples of how to use `MultiheadAttention` with +/// `torch::nn::MultiheadAttentionOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(MultiheadAttention); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/adaptive.h b/torch/csrc/api/include/torch/nn/modules/adaptive.h index 150091884ef65c..b8b5170d177acb 100644 --- a/torch/csrc/api/include/torch/nn/modules/adaptive.h +++ b/torch/csrc/api/include/torch/nn/modules/adaptive.h @@ -1,11 +1,11 @@ #pragma once #include +#include #include -#include #include #include -#include +#include #include namespace torch { @@ -23,28 +23,39 @@ struct TORCH_API ASMoutput { double loss; }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Efficient softmax approximation as described in /// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin, /// Moustapha Cissé, David Grangier, and Hervé Jégou. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss to learn -/// about the exact behavior of this module. +/// See +/// https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss +/// to learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions` +/// class to learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true)); +/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, +/// {4, 8}).div_value(2.).head_bias(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveLogSoftmaxWithLossImpl : public Cloneable { +class TORCH_API AdaptiveLogSoftmaxWithLossImpl + : public Cloneable { public: - AdaptiveLogSoftmaxWithLossImpl(int64_t in_features, int64_t n_classes, std::vector cutoffs) - : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(in_features, n_classes, cutoffs)) {} - - explicit AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_); + AdaptiveLogSoftmaxWithLossImpl( + int64_t in_features, + int64_t n_classes, + std::vector cutoffs) + : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions( + in_features, + n_classes, + cutoffs)) {} + + explicit AdaptiveLogSoftmaxWithLossImpl( + AdaptiveLogSoftmaxWithLossOptions options_); ASMoutput forward(const Tensor& input, const Tensor& target); @@ -52,23 +63,26 @@ class TORCH_API AdaptiveLogSoftmaxWithLossImpl : public Cloneable cutoffs; int64_t shortlist_size; @@ -85,11 +99,12 @@ class TORCH_API AdaptiveLogSoftmaxWithLossImpl : public Cloneable #include -#include #include +#include #include #include @@ -12,7 +12,8 @@ namespace torch { namespace nn { -/// Base class for all (dimension-specialized) batchnorm and instancenorm modules. +/// Base class for all (dimension-specialized) batchnorm and instancenorm +/// modules. template // NOLINTNEXTLINE(bugprone-exception-escape) class NormImplBase : public torch::nn::Cloneable { @@ -27,20 +28,28 @@ class NormImplBase : public torch::nn::Cloneable { void reset() override { if (options.affine()) { - weight = this->register_parameter("weight", torch::empty({options.num_features()})); - bias = this->register_parameter("bias", torch::empty({options.num_features()})); + weight = this->register_parameter( + "weight", torch::empty({options.num_features()})); + bias = this->register_parameter( + "bias", torch::empty({options.num_features()})); } else { - weight = this->register_parameter("weight", Tensor(), /*requires_grad=*/false); - bias = this->register_parameter("bias", Tensor(), /*requires_grad=*/false); + weight = + this->register_parameter("weight", Tensor(), /*requires_grad=*/false); + bias = + this->register_parameter("bias", Tensor(), /*requires_grad=*/false); } if (options.track_running_stats()) { - running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()})); - running_var = this->register_buffer("running_var", torch::ones({options.num_features()})); - num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong))); + running_mean = this->register_buffer( + "running_mean", torch::zeros({options.num_features()})); + running_var = this->register_buffer( + "running_var", torch::ones({options.num_features()})); + num_batches_tracked = this->register_buffer( + "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong))); } else { running_mean = this->register_buffer("running_mean", Tensor()); running_var = this->register_buffer("running_var", Tensor()); - num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor()); + num_batches_tracked = + this->register_buffer("num_batches_tracked", Tensor()); } reset_parameters(); } @@ -73,15 +82,18 @@ class NormImplBase : public torch::nn::Cloneable { Tensor bias; /// The running mean. - /// Only defined if the `track_running_stats` option was `true` upon construction. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. Tensor running_mean; /// The running variance. - /// Only defined if the `track_running_stats` option was `true` upon construction. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. Tensor running_var; /// The number of the forward call. - /// Only defined if the `track_running_stats` option was `true` upon construction. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. Tensor num_batches_tracked; }; @@ -105,9 +117,11 @@ class BatchNormImplBase : public NormImplBase { if (this->is_training() && this->options.track_running_stats()) { if (this->num_batches_tracked.defined()) { this->num_batches_tracked += 1; - if (this->options.momentum() == c10::nullopt) { // use cumulative moving average - exponential_average_factor = 1.0 / this->num_batches_tracked.template item(); - } else { // use exponential moving average + if (this->options.momentum() == + c10::nullopt) { // use cumulative moving average + exponential_average_factor = + 1.0 / this->num_batches_tracked.template item(); + } else { // use exponential moving average exponential_average_factor = this->options.momentum().value(); } } @@ -128,23 +142,25 @@ class BatchNormImplBase : public NormImplBase { void pretty_print(std::ostream& stream) const override; }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the BatchNorm1d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::BatchNorm1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::BatchNorm1dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// BatchNorm1d model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm1d +/// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase; @@ -152,28 +168,30 @@ class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> { /// A `ModuleHolder` subclass for `BatchNorm1dImpl`. /// See the documentation for `BatchNorm1dImpl` class to learn what methods it -/// provides, and examples of how to use `BatchNorm1d` with `torch::nn::BatchNorm1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `BatchNorm1d` with +/// `torch::nn::BatchNorm1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(BatchNorm1d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the BatchNorm2d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm2d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::BatchNorm2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::BatchNorm2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// BatchNorm2d model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm2d +/// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase; @@ -181,28 +199,30 @@ class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> { /// A `ModuleHolder` subclass for `BatchNorm2dImpl`. /// See the documentation for `BatchNorm2dImpl` class to learn what methods it -/// provides, and examples of how to use `BatchNorm2d` with `torch::nn::BatchNorm2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `BatchNorm2d` with +/// `torch::nn::BatchNorm2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(BatchNorm2d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the BatchNorm3d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm3d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::BatchNorm3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::BatchNorm3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// BatchNorm3d model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm3d +/// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase; @@ -210,9 +230,9 @@ class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> { /// A `ModuleHolder` subclass for `BatchNorm3dImpl`. /// See the documentation for `BatchNorm3dImpl` class to learn what methods it -/// provides, and examples of how to use `BatchNorm3d` with `torch::nn::BatchNorm3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `BatchNorm3d` with +/// `torch::nn::BatchNorm3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(BatchNorm3d); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/common.h b/torch/csrc/api/include/torch/nn/modules/common.h index 1ad202e0948831..11728db8a66180 100644 --- a/torch/csrc/api/include/torch/nn/modules/common.h +++ b/torch/csrc/api/include/torch/nn/modules/common.h @@ -35,8 +35,8 @@ /// `FORWARD_HAS_DEFAULT_ARGS` macro. /// ``` /// -/// The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS` macro when -/// declaring the module: +/// The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS` +/// macro when declaring the module: /// ``` /// struct MImpl : torch::nn::Module { /// public: @@ -56,9 +56,11 @@ /// 1 | torch::nn::AnyValue(2) /// 2 | torch::nn::AnyValue(3.0) /// ---------------------------------------------------------------- -/// Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS` macro: +/// Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS` +/// macro: /// */ -/// FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, torch::nn::AnyValue(3.0)}) +/// FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, +/// torch::nn::AnyValue(3.0)}) /// private: /// int value; /// }; @@ -67,29 +69,36 @@ /// Now, running the following would work: /// ``` /// torch::nn::Sequential seq(M(1)); -/// seq->forward(1); // This correctly populates the default arguments for `MImpl::forward` +/// seq->forward(1); // This correctly populates the default arguments for +/// `MImpl::forward` /// ``` -#define FORWARD_HAS_DEFAULT_ARGS(...) \ - template \ - friend struct torch::nn::AnyModuleHolder; \ - bool _forward_has_default_args() override { \ - return true; \ - } \ - unsigned int _forward_num_required_args() override { \ - std::vector> args_info = {__VA_ARGS__}; \ - return args_info[0].first; \ - } \ - std::vector _forward_populate_default_args(std::vector&& arguments) override { \ - std::vector> args_info = {__VA_ARGS__}; \ - unsigned int num_all_args = args_info[args_info.size() - 1].first + 1; \ - TORCH_INTERNAL_ASSERT(arguments.size() >= _forward_num_required_args() && arguments.size() <= num_all_args); \ - std::vector ret; \ - ret.reserve(num_all_args); \ - for (const auto i : c10::irange(arguments.size())) { \ - ret.emplace_back(std::move(arguments[i])); \ - } \ - for (auto& arg_info : args_info) { \ - if (arg_info.first > ret.size() - 1) ret.emplace_back(std::move(arg_info.second)); \ - } \ - return ret; \ +#define FORWARD_HAS_DEFAULT_ARGS(...) \ + template \ + friend struct torch::nn::AnyModuleHolder; \ + bool _forward_has_default_args() override { \ + return true; \ + } \ + unsigned int _forward_num_required_args() override { \ + std::vector> args_info = { \ + __VA_ARGS__}; \ + return args_info[0].first; \ + } \ + std::vector _forward_populate_default_args( \ + std::vector&& arguments) override { \ + std::vector> args_info = { \ + __VA_ARGS__}; \ + unsigned int num_all_args = args_info[args_info.size() - 1].first + 1; \ + TORCH_INTERNAL_ASSERT( \ + arguments.size() >= _forward_num_required_args() && \ + arguments.size() <= num_all_args); \ + std::vector ret; \ + ret.reserve(num_all_args); \ + for (const auto i : c10::irange(arguments.size())) { \ + ret.emplace_back(std::move(arguments[i])); \ + } \ + for (auto& arg_info : args_info) { \ + if (arg_info.first > ret.size() - 1) \ + ret.emplace_back(std::move(arg_info.second)); \ + } \ + return ret; \ } diff --git a/torch/csrc/api/include/torch/nn/modules/container/any.h b/torch/csrc/api/include/torch/nn/modules/container/any.h index d15fca1b9d53ea..b96efc4812c848 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -151,8 +151,8 @@ class AnyModule { AnyValue any_forward(ArgumentTypes&&... arguments); /// Invokes `forward()` on the contained module with the given arguments, and - /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults to - /// `torch::Tensor`). + /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults + /// to `torch::Tensor`). template ReturnType forward(ArgumentTypes&&... arguments); @@ -186,9 +186,9 @@ class AnyModule { bool is_empty() const noexcept; private: - /// Creates a `unique_ptr` pointing to a `AnyModuleHolder` of the correct - /// type. This method is used to deduce the arguments of the module's - /// `forward()` method. + /// Creates a `unique_ptr` pointing to a + /// `AnyModuleHolder` of the correct type. This method is used to deduce the + /// arguments of the module's `forward()` method. template < typename ModuleType, typename Class, @@ -340,7 +340,8 @@ std::unique_ptr AnyModule::make_holder( !std::is_void::value, "AnyModule cannot store modules that return void " "(you can return a dummy value)."); - return torch::make_unique, ArgumentTypes...>>( + return torch::make_unique< + AnyModuleHolder, ArgumentTypes...>>( std::move(module)); } @@ -357,7 +358,8 @@ template ModuleType& AnyModule::get_( ReturnType (ModuleType::*)(ArgumentTypes...)) const { if (typeid(ModuleType).hash_code() == type_info().hash_code()) { - return *static_cast&>(*content_) + return *static_cast&>( + *content_) .module; } AT_ERROR( diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h index 63a3927ee79edd..ef9d483178c8f8 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -25,7 +25,8 @@ struct AnyModulePlaceholder : public AnyValue::Placeholder { virtual std::unique_ptr copy() const = 0; /// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`. - virtual std::unique_ptr clone_module(optional device) const = 0; + virtual std::unique_ptr clone_module( + optional device) const = 0; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -74,7 +75,8 @@ struct AnyModuleHolder : public AnyModulePlaceholder { AnyValue forward(std::vector&& arguments) override { if (module->_forward_has_default_args()) { TORCH_CHECK( - arguments.size() >= module->_forward_num_required_args() && arguments.size() <= sizeof...(ArgumentTypes), + arguments.size() >= module->_forward_num_required_args() && + arguments.size() <= sizeof...(ArgumentTypes), c10::demangle(type_info.name()), "'s forward() method expects at least ", module->_forward_num_required_args(), @@ -83,13 +85,13 @@ struct AnyModuleHolder : public AnyModulePlaceholder { " argument(s), but received ", arguments.size(), "."); - arguments = std::move(module->_forward_populate_default_args(std::move(arguments))); + arguments = std::move( + module->_forward_populate_default_args(std::move(arguments))); } else { - std::string use_default_args_macro_prompt = \ - " If " + \ - c10::demangle(type_info.name()) + \ - "'s forward() method has default arguments, " + \ - "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."; + std::string use_default_args_macro_prompt = " If " + + c10::demangle(type_info.name()) + + "'s forward() method has default arguments, " + + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."; TORCH_CHECK( arguments.size() == sizeof...(ArgumentTypes), c10::demangle(type_info.name()), @@ -98,7 +100,9 @@ struct AnyModuleHolder : public AnyModulePlaceholder { " argument(s), but received ", arguments.size(), ".", - (arguments.size() < sizeof...(ArgumentTypes)) ? use_default_args_macro_prompt : ""); + (arguments.size() < sizeof...(ArgumentTypes)) + ? use_default_args_macro_prompt + : ""); } // FYI: During invocation of a module's `forward()` method, the values live @@ -115,7 +119,8 @@ struct AnyModuleHolder : public AnyModulePlaceholder { return torch::make_unique(*this); } - std::unique_ptr clone_module(optional device) const override { + std::unique_ptr clone_module( + optional device) const override { return torch::make_unique( std::dynamic_pointer_cast(module->clone(device))); } diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/torch/csrc/api/include/torch/nn/modules/container/any_value.h index 6e60b320fc301d..3a836f75923069 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_value.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -44,9 +44,9 @@ class AnyValue { : content_( torch::make_unique>>(std::forward(value))) {} - /// Returns a pointer to the value contained in the `AnyValue` if the type passed - /// as template parameter matches the type of the value stored, and returns a - /// null pointer otherwise. + /// Returns a pointer to the value contained in the `AnyValue` if the type + /// passed as template parameter matches the type of the value stored, and + /// returns a null pointer otherwise. template T* try_get() { static_assert( @@ -61,9 +61,9 @@ class AnyValue { return nullptr; } - /// Returns the value contained in the `AnyValue` if the type passed as template - /// parameter matches the type of the value stored, and throws an exception - /// otherwise. + /// Returns the value contained in the `AnyValue` if the type passed as + /// template parameter matches the type of the value stored, and throws an + /// exception otherwise. template T get() { if (auto* maybe_value = try_get()) { diff --git a/torch/csrc/api/include/torch/nn/modules/container/functional.h b/torch/csrc/api/include/torch/nn/modules/container/functional.h index 024ef16d210d21..d1af0e0fd50424 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/functional.h +++ b/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -1,10 +1,10 @@ #pragma once +#include #include #include #include #include -#include #include #include @@ -41,11 +41,11 @@ namespace nn { /// Functional(torch::leaky_relu, /*slope=*/0.5) /// \endrst /// -/// The value of `0.5` is then stored within the `Functional` object and supplied -/// to the function call at invocation time. Note that such bound values are -/// evaluated eagerly and stored a single time. See the documentation of -/// [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind) for -/// more information on the semantics of argument binding. +/// The value of `0.5` is then stored within the `Functional` object and +/// supplied to the function call at invocation time. Note that such bound +/// values are evaluated eagerly and stored a single time. See the documentation +/// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind) +/// for more information on the semantics of argument binding. /// /// \rst /// .. attention:: diff --git a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h index 15a843ab783c73..fe02643338511e 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h +++ b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -40,10 +40,10 @@ namespace nn { /// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`? /// The value a `ModuleDict` provides over manually calling an ordered map of /// modules is that it allows treating the whole container *as a single module*, -/// such that performing a transformation on the `ModuleDict` applies to each of the -/// modules it stores (which are each a registered submodule of the `ModuleDict`). -/// For example, calling `.to(torch::kCUDA)` on a `ModuleDict` will move each module -/// in the map to CUDA memory. For example: +/// such that performing a transformation on the `ModuleDict` applies to each of +/// the modules it stores (which are each a registered submodule of the +/// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict` +/// will move each module in the map to CUDA memory. For example: /// /// \rst /// .. code-block:: cpp @@ -61,20 +61,23 @@ namespace nn { /// \endrst /// /// Finally, `ModuleDict` provides a lightweight container API, such as allowing -/// iteration over submodules, positional access, adding new modules from a vector -/// of key-module pairs or an `OrderedDict` or another `ModuleDict` after +/// iteration over submodules, positional access, adding new modules from a +/// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after /// construction via `update`. // NOLINTNEXTLINE(bugprone-exception-escape) class ModuleDictImpl : public Cloneable { public: - using Iterator = torch::OrderedDict>::Iterator; - using ConstIterator = torch::OrderedDict>::ConstIterator; + using Iterator = + torch::OrderedDict>::Iterator; + using ConstIterator = + torch::OrderedDict>::ConstIterator; ModuleDictImpl() = default; /// Constructs the `ModuleDict` from a list of string-Module pairs. explicit ModuleDictImpl( - const std::vector>>& modules) { + const std::vector>>& + modules) { update(modules); } @@ -136,7 +139,8 @@ class ModuleDictImpl : public Cloneable { /// Remove all items from the `ModuleDict`. void clear() { - // Not remove the registration of modules to make it consistent with python version. + // Not remove the registration of modules to make it consistent with python + // version. modules_.clear(); } @@ -195,19 +199,22 @@ class ModuleDictImpl : public Cloneable { std::shared_ptr pop(const std::string& key) { auto module = modules_[key]; modules_.erase(key); - // Not remove the registration of the module to make it consistent with python version. + // Not remove the registration of the module to make it consistent with + // python version. return module; } /// Updated the `ModuleDict` with a vector of key-module pairs. void update( - const std::vector>>& modules) { + const std::vector>>& + modules) { for (auto& item : modules) { insert(item.first, item.second); } } - /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or `ModuleDict`. + /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or + /// `ModuleDict`. template void update(const Container& container) { for (auto& item : container) { @@ -215,7 +222,7 @@ class ModuleDictImpl : public Cloneable { } } -private: + private: /// Private `OrderedDict` holding the key-Module pairs. torch::OrderedDict> modules_; @@ -225,13 +232,11 @@ class ModuleDictImpl : public Cloneable { if (contains(key)) { modules_[key] = std::move(module); replace_module(key, modules_[key]); - } - else { + } else { modules_.insert(key, std::move(module)); register_module(key, modules_.back().value()); } } - }; /// A `ModuleHolder` subclass for `ModuleDictImpl`. diff --git a/torch/csrc/api/include/torch/nn/modules/container/named_any.h b/torch/csrc/api/include/torch/nn/modules/container/named_any.h index 489bb6fca471ee..5342ba39cd21f7 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/named_any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/named_any.h @@ -38,8 +38,8 @@ namespace nn { /// }; /// /// Sequential sequential({ -/// {"m1", std::make_shared(1)}, // shared pointer to `Module` is supported -/// {std::string("m2"), M(2)}, // `Module` is supported +/// {"m1", std::make_shared(1)}, // shared pointer to `Module` is +/// supported {std::string("m2"), M(2)}, // `Module` is supported /// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported /// }); /// \endrst @@ -57,9 +57,9 @@ class NamedAnyModule { template > NamedAnyModule(std::string name, M&& module) : NamedAnyModule( - std::move(name), - std::make_shared::type>( - std::forward(module))) {} + std::move(name), + std::make_shared::type>( + std::forward(module))) {} /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from /// a `ModuleHolder`. @@ -69,7 +69,7 @@ class NamedAnyModule { /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. NamedAnyModule(std::string name, AnyModule any_module) - : name_(std::move(name)), module_(std::move(any_module)) {} + : name_(std::move(name)), module_(std::move(any_module)) {} /// Returns a reference to the name. const std::string& name() const noexcept { diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index 477194276c8a3b..adc4af1d74139d 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -105,7 +105,8 @@ class SequentialImpl : public Cloneable { } /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s. - explicit SequentialImpl(torch::OrderedDict&& ordered_dict) { + explicit SequentialImpl( + torch::OrderedDict&& ordered_dict) { modules_.reserve(ordered_dict.size()); for (auto& item : ordered_dict) { push_back(item.key(), std::move(item.value())); @@ -119,7 +120,8 @@ class SequentialImpl : public Cloneable { modules_.reserve(named_modules.size()); for (const auto& named_module : named_modules) { // NOLINTNEXTLINE(performance-move-const-arg) - push_back(std::move(named_module.name()), std::move(named_module.module())); + push_back( + std::move(named_module.name()), std::move(named_module.module())); } } @@ -215,8 +217,8 @@ class SequentialImpl : public Cloneable { push_back(c10::to_string(modules_.size()), std::forward(module)); } - /// Adds a new named `Module` to the `Sequential` container, moving or copying it - /// into a `shared_ptr` internally. This method allows passing value types, + /// Adds a new named `Module` to the `Sequential` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, /// and letting the container deal with the boxing. template > void push_back(std::string name, M&& module) { @@ -342,13 +344,19 @@ class SequentialImpl : public Cloneable { /// pack has only one type, in which case the template would be preferred, /// even if the other `push_back` functions are better fits (e.g. `unique_ptr` /// -> `shared_ptr` overload). - /// NOTE: We explicitly avoid matching this template with `push_back(std::string("name"), module)` - /// or `push_back("name", module)`, since they should be handled by their respective - /// `push_back` functions. - template ::value || - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - std::is_same::type, std::decay::type>::value>> + /// NOTE: We explicitly avoid matching this template with + /// `push_back(std::string("name"), module)` or `push_back("name", module)`, + /// since they should be handled by their respective `push_back` functions. + template < + typename First, + typename Second, + typename... Rest, + typename = torch::disable_if_t< + std::is_same::value || + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + std::is_same< + typename std::decay::type, + std::decay::type>::value>> void push_back(First&& first, Second&& second, Rest&&... rest) { push_back(std::forward(first)); // Recursively calls this method, until the parameter pack only thas this @@ -379,7 +387,9 @@ class Sequential : public torch::nn::ModuleHolder { /// It enables the following use case: /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` // NOLINTNEXTLINE(performance-move-const-arg) - Sequential(std::initializer_list named_modules) : ModuleHolder(std::make_shared(std::move(named_modules))) {} + Sequential(std::initializer_list named_modules) + : ModuleHolder( + std::make_shared(std::move(named_modules))) {} }; } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index a644b52a9c7cc9..a53edf702d8e52 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -1,6 +1,5 @@ #pragma once - #include #include @@ -26,67 +25,72 @@ template // NOLINTNEXTLINE(bugprone-exception-escape) class ConvNdImpl : public torch::nn::Cloneable { public: - explicit ConvNdImpl(detail::ConvNdOptions options_) : options(std::move(options_)) { + explicit ConvNdImpl(detail::ConvNdOptions options_) + : options(std::move(options_)) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void reset() override { TORCH_CHECK( - options.in_channels() % options.groups() == 0, - "in_channels must be divisible by groups"); + options.in_channels() % options.groups() == 0, + "in_channels must be divisible by groups"); TORCH_CHECK( - options.out_channels() % options.groups() == 0, - "out_channels must be divisible by groups"); - - c10::visit(c10::overloaded( - [&](enumtype::kValid) { - _reversed_padding_repeated_twice.resize(2 * D); - std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0); - }, - [&](enumtype::kSame) { - for (const auto i : c10::irange(D)) { - const auto stride = (*options.stride())[i]; - TORCH_CHECK(stride == 1, "padding='same' is not supported for strided convolutions"); - } - - _reversed_padding_repeated_twice.resize(2 * D); - for (const auto i : c10::irange(D)) { - const auto dilation = (*options.dilation())[i]; - const auto kernel_size = (*options.kernel_size())[i]; - const auto total_padding = dilation * (kernel_size - 1); - auto left_pad = total_padding / 2; - auto right_pad = total_padding - left_pad; - _reversed_padding_repeated_twice[2 * i] = left_pad; - _reversed_padding_repeated_twice[2 * i + 1] = right_pad; - } - }, - [&](const ExpandingArray & pad) { - _reversed_padding_repeated_twice = - torch::nn::modules::utils::_reverse_repeat_vector(pad, 2); - }), + options.out_channels() % options.groups() == 0, + "out_channels must be divisible by groups"); + + c10::visit( + c10::overloaded( + [&](enumtype::kValid) { + _reversed_padding_repeated_twice.resize(2 * D); + std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0); + }, + [&](enumtype::kSame) { + for (const auto i : c10::irange(D)) { + const auto stride = (*options.stride())[i]; + TORCH_CHECK( + stride == 1, + "padding='same' is not supported for strided convolutions"); + } + + _reversed_padding_repeated_twice.resize(2 * D); + for (const auto i : c10::irange(D)) { + const auto dilation = (*options.dilation())[i]; + const auto kernel_size = (*options.kernel_size())[i]; + const auto total_padding = dilation * (kernel_size - 1); + auto left_pad = total_padding / 2; + auto right_pad = total_padding - left_pad; + _reversed_padding_repeated_twice[2 * i] = left_pad; + _reversed_padding_repeated_twice[2 * i + 1] = right_pad; + } + }, + [&](const ExpandingArray& pad) { + _reversed_padding_repeated_twice = + torch::nn::modules::utils::_reverse_repeat_vector(pad, 2); + }), options.padding()); if (options.transposed()) { std::vector weight_sizes = { - options.in_channels(), - options.out_channels() / options.groups()}; - weight_sizes.insert(weight_sizes.end(), (*options.kernel_size()).begin(), (*options.kernel_size()).end()); - weight = this->register_parameter( - "weight", - torch::empty(weight_sizes)); + options.in_channels(), options.out_channels() / options.groups()}; + weight_sizes.insert( + weight_sizes.end(), + (*options.kernel_size()).begin(), + (*options.kernel_size()).end()); + weight = this->register_parameter("weight", torch::empty(weight_sizes)); } else { std::vector weight_sizes = { - options.out_channels(), - options.in_channels() / options.groups()}; - weight_sizes.insert(weight_sizes.end(), (*options.kernel_size()).begin(), (*options.kernel_size()).end()); - weight = this->register_parameter( - "weight", - torch::empty(weight_sizes)); + options.out_channels(), options.in_channels() / options.groups()}; + weight_sizes.insert( + weight_sizes.end(), + (*options.kernel_size()).begin(), + (*options.kernel_size()).end()); + weight = this->register_parameter("weight", torch::empty(weight_sizes)); } if (options.bias()) { - bias = this->register_parameter("bias", torch::empty({options.out_channels()})); + bias = this->register_parameter( + "bias", torch::empty({options.out_channels()})); } else { this->register_parameter("bias", Tensor(), /*requires_grad=*/false); } @@ -95,7 +99,9 @@ class ConvNdImpl : public torch::nn::Cloneable { } void reset_parameters() { - init::kaiming_uniform_(weight, /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) + init::kaiming_uniform_( + weight, + /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) if (bias.defined()) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -109,22 +115,18 @@ class ConvNdImpl : public torch::nn::Cloneable { /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`. void pretty_print(std::ostream& stream) const override { stream << "torch::nn::Conv" << D << "d" - << "(" << options.in_channels() - << ", " << options.out_channels() + << "(" << options.in_channels() << ", " << options.out_channels() << ", kernel_size=" << options.kernel_size() << ", stride=" << options.stride(); - c10::visit(c10::overloaded( - [&](enumtype::kValid) { - stream << ", padding='valid'"; - }, - [&](enumtype::kSame) { - stream << ", padding='same'"; - }, - [&](const ExpandingArray & pad) { - if (*pad != *ExpandingArray(0)) { - stream << ", padding=" << pad; - } - }), + c10::visit( + c10::overloaded( + [&](enumtype::kValid) { stream << ", padding='valid'"; }, + [&](enumtype::kSame) { stream << ", padding='same'"; }, + [&](const ExpandingArray& pad) { + if (*pad != *ExpandingArray(0)) { + stream << ", padding=" << pad; + } + }), options.padding()); if (*options.dilation() != *ExpandingArray(1)) { stream << ", dilation=" << options.dilation(); @@ -139,7 +141,8 @@ class ConvNdImpl : public torch::nn::Cloneable { stream << ", bias=" << std::boolalpha << false; } if (!c10::get_if(&options.padding_mode())) { - stream << ", padding_mode=" << enumtype::get_enum_name(options.padding_mode()); + stream << ", padding_mode=" + << enumtype::get_enum_name(options.padding_mode()); } stream << ")"; } @@ -181,17 +184,17 @@ class TORCH_API Conv1dImpl : public ConvNdImpl<1, Conv1dImpl> { int64_t input_channels, int64_t output_channels, ExpandingArray<1> kernel_size) - : Conv1dImpl(Conv1dOptions(input_channels, output_channels, kernel_size)) { - } + : Conv1dImpl( + Conv1dOptions(input_channels, output_channels, kernel_size)) {} explicit Conv1dImpl(Conv1dOptions options_); Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `Conv1dImpl`. /// See the documentation for `Conv1dImpl` class to learn what methods it -/// provides, and examples of how to use `Conv1d` with `torch::nn::Conv1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Conv1d` with +/// `torch::nn::Conv1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Conv1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -214,8 +217,8 @@ class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> { int64_t input_channels, int64_t output_channels, ExpandingArray<2> kernel_size) - : Conv2dImpl(Conv2dOptions(input_channels, output_channels, kernel_size)) { - } + : Conv2dImpl( + Conv2dOptions(input_channels, output_channels, kernel_size)) {} explicit Conv2dImpl(Conv2dOptions options_); Tensor forward(const Tensor& input); @@ -225,9 +228,9 @@ class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> { /// A `ModuleHolder` subclass for `Conv2dImpl`. /// See the documentation for `Conv2dImpl` class to learn what methods it -/// provides, and examples of how to use `Conv2d` with `torch::nn::Conv2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Conv2d` with +/// `torch::nn::Conv2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Conv2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -250,17 +253,17 @@ class TORCH_API Conv3dImpl : public ConvNdImpl<3, Conv3dImpl> { int64_t input_channels, int64_t output_channels, ExpandingArray<3> kernel_size) - : Conv3dImpl(Conv3dOptions(input_channels, output_channels, kernel_size)) { - } + : Conv3dImpl( + Conv3dOptions(input_channels, output_channels, kernel_size)) {} explicit Conv3dImpl(Conv3dOptions options_); Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `Conv3dImpl`. /// See the documentation for `Conv3dImpl` class to learn what methods it -/// provides, and examples of how to use `Conv3d` with `torch::nn::Conv3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Conv3d` with +/// `torch::nn::Conv3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Conv3d); // ~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -272,19 +275,20 @@ class ConvTransposeNdImpl : public ConvNdImpl { public: using torch::nn::ConvNdImpl::ConvNdImpl; explicit ConvTransposeNdImpl(detail::ConvNdOptions options_) - : ConvNdImpl(options_) { - TORCH_INTERNAL_ASSERT(c10::holds_alternative>(this->options.padding()), - "ConvTranspose padding cannot be a string"); + : ConvNdImpl(options_) { + TORCH_INTERNAL_ASSERT( + c10::holds_alternative>(this->options.padding()), + "ConvTranspose padding cannot be a string"); } /// Pretty prints the `ConvTranspose{1,2,3}d` module into the given `stream`. void pretty_print(std::ostream& stream) const override { stream << "torch::nn::ConvTranspose" << D << "d" - << "(" << this->options.in_channels() - << ", " << this->options.out_channels() + << "(" << this->options.in_channels() << ", " + << this->options.out_channels() << ", kernel_size=" << this->options.kernel_size() << ", stride=" << this->options.stride(); - const auto & pad = padding(); + const auto& pad = padding(); if (*pad != *ExpandingArray(0)) { stream << ", padding=" << pad; } @@ -301,7 +305,8 @@ class ConvTransposeNdImpl : public ConvNdImpl { stream << ", bias=" << std::boolalpha << false; } if (!c10::get_if(&this->options.padding_mode())) { - stream << ", padding_mode=" << enumtype::get_enum_name(this->options.padding_mode()); + stream << ", padding_mode=" + << enumtype::get_enum_name(this->options.padding_mode()); } stream << ")"; } @@ -312,117 +317,140 @@ class ConvTransposeNdImpl : public ConvNdImpl { } std::vector _output_padding( - const Tensor& input, const c10::optional& output_size, - const ExpandingArray& stride, const ExpandingArray& padding, + const Tensor& input, + const c10::optional& output_size, + const ExpandingArray& stride, + const ExpandingArray& padding, const ExpandingArray& kernel_size); }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the ConvTranspose1d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose1d to /// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConvTranspose1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConvTranspose1dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)); +/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, +/// 3).stride(1).bias(false)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConvTranspose1dImpl : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> { +class TORCH_API ConvTranspose1dImpl + : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> { public: ConvTranspose1dImpl( int64_t input_channels, int64_t output_channels, ExpandingArray<1> kernel_size) - : ConvTranspose1dImpl(ConvTranspose1dOptions(input_channels, output_channels, kernel_size)) { - } + : ConvTranspose1dImpl(ConvTranspose1dOptions( + input_channels, + output_channels, + kernel_size)) {} explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_); - Tensor forward(const Tensor& input, - const c10::optional& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const c10::optional& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(c10::optional())}) }; /// A `ModuleHolder` subclass for `ConvTranspose1dImpl`. -/// See the documentation for `ConvTranspose1dImpl` class to learn what methods it -/// provides, and examples of how to use `ConvTranspose1d` with `torch::nn::ConvTranspose1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ConvTranspose1dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose1d` with +/// `torch::nn::ConvTranspose1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ConvTranspose1d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the ConvTranspose2d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose2d to /// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConvTranspose2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConvTranspose2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)); +/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, +/// 3).stride(1).bias(false)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConvTranspose2dImpl : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> { +class TORCH_API ConvTranspose2dImpl + : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> { public: ConvTranspose2dImpl( int64_t input_channels, int64_t output_channels, ExpandingArray<2> kernel_size) - : ConvTranspose2dImpl(ConvTranspose2dOptions(input_channels, output_channels, kernel_size)) { - } + : ConvTranspose2dImpl(ConvTranspose2dOptions( + input_channels, + output_channels, + kernel_size)) {} explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_); - Tensor forward(const Tensor& input, - const c10::optional& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const c10::optional& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(c10::optional())}) }; /// A `ModuleHolder` subclass for `ConvTranspose2dImpl`. -/// See the documentation for `ConvTranspose2dImpl` class to learn what methods it -/// provides, and examples of how to use `ConvTranspose2d` with `torch::nn::ConvTranspose2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ConvTranspose2dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose2d` with +/// `torch::nn::ConvTranspose2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ConvTranspose2d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the ConvTranspose3d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose3d to /// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConvTranspose3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConvTranspose3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)); +/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, +/// 2).stride(1).bias(false)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConvTranspose3dImpl : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> { +class TORCH_API ConvTranspose3dImpl + : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> { public: ConvTranspose3dImpl( int64_t input_channels, int64_t output_channels, ExpandingArray<3> kernel_size) - : ConvTranspose3dImpl(ConvTranspose3dOptions(input_channels, output_channels, kernel_size)) { - } + : ConvTranspose3dImpl(ConvTranspose3dOptions( + input_channels, + output_channels, + kernel_size)) {} explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_); - Tensor forward(const Tensor& input, - const c10::optional& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const c10::optional& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(c10::optional())}) }; /// A `ModuleHolder` subclass for `ConvTranspose3dImpl`. -/// See the documentation for `ConvTranspose3dImpl` class to learn what methods it -/// provides, and examples of how to use `ConvTranspose3d` with `torch::nn::ConvTranspose3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ConvTranspose3dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose3d` with +/// `torch::nn::ConvTranspose3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ConvTranspose3d); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/distance.h b/torch/csrc/api/include/torch/nn/modules/distance.h index 053800d6ec7805..6cf0b044eb39f9 100644 --- a/torch/csrc/api/include/torch/nn/modules/distance.h +++ b/torch/csrc/api/include/torch/nn/modules/distance.h @@ -16,8 +16,8 @@ namespace nn { /// See https://pytorch.org/docs/master/nn.html#torch.nn.CosineSimilarity to /// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::CosineSimilarityOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::CosineSimilarityOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -40,10 +40,10 @@ class TORCH_API CosineSimilarityImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `CosineSimilarityImpl`. -/// See the documentation for `CosineSimilarityImpl` class to learn what methods it -/// provides, and examples of how to use `CosineSimilarity` with `torch::nn::CosineSimilarityOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `CosineSimilarityImpl` class to learn what methods +/// it provides, and examples of how to use `CosineSimilarity` with +/// `torch::nn::CosineSimilarityOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(CosineSimilarity); // ============================================================================ @@ -53,12 +53,13 @@ TORCH_MODULE(CosineSimilarity); /// See https://pytorch.org/docs/master/nn.html#torch.nn.PairwiseDistance to /// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// PairwiseDistance model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); +/// PairwiseDistance +/// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API PairwiseDistanceImpl : public Cloneable { @@ -77,10 +78,10 @@ class TORCH_API PairwiseDistanceImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `PairwiseDistanceImpl`. -/// See the documentation for `PairwiseDistanceImpl` class to learn what methods it -/// provides, and examples of how to use `PairwiseDistance` with `torch::nn::PairwiseDistanceOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `PairwiseDistanceImpl` class to learn what methods +/// it provides, and examples of how to use `PairwiseDistance` with +/// `torch::nn::PairwiseDistanceOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(PairwiseDistance); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h index e6f2ef4ea61ad3..af49b1e98791e9 100644 --- a/torch/csrc/api/include/torch/nn/modules/dropout.h +++ b/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -19,10 +19,9 @@ template // NOLINTNEXTLINE(bugprone-exception-escape) class _DropoutNd : public torch::nn::Cloneable { public: - _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {}; + _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){}; - explicit _DropoutNd(const DropoutOptions& options_ = {}) - : options(options_) { + explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -38,7 +37,7 @@ class _DropoutNd : public torch::nn::Cloneable { DropoutOptions options; }; -} +} // namespace detail // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -55,7 +54,7 @@ class _DropoutNd : public torch::nn::Cloneable { /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API DropoutImpl : public detail::_DropoutNd { -public: + public: using detail::_DropoutNd::_DropoutNd; Tensor forward(Tensor input); @@ -66,9 +65,9 @@ class TORCH_API DropoutImpl : public detail::_DropoutNd { /// A `ModuleHolder` subclass for `DropoutImpl`. /// See the documentation for `DropoutImpl` class to learn what methods it -/// provides, and examples of how to use `Dropout` with `torch::nn::DropoutOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Dropout` with +/// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Dropout); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -86,7 +85,7 @@ TORCH_MODULE(Dropout); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API Dropout2dImpl : public detail::_DropoutNd { -public: + public: using detail::_DropoutNd::_DropoutNd; Tensor forward(Tensor input); @@ -97,9 +96,9 @@ class TORCH_API Dropout2dImpl : public detail::_DropoutNd { /// A `ModuleHolder` subclass for `Dropout2dImpl`. /// See the documentation for `Dropout2dImpl` class to learn what methods it -/// provides, and examples of how to use `Dropout2d` with `torch::nn::Dropout2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Dropout2d` with +/// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Dropout2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -117,7 +116,7 @@ TORCH_MODULE(Dropout2d); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API Dropout3dImpl : public detail::_DropoutNd { -public: + public: using detail::_DropoutNd::_DropoutNd; Tensor forward(Tensor input); @@ -128,9 +127,9 @@ class TORCH_API Dropout3dImpl : public detail::_DropoutNd { /// A `ModuleHolder` subclass for `Dropout3dImpl`. /// See the documentation for `Dropout3dImpl` class to learn what methods it -/// provides, and examples of how to use `Dropout3d` with `torch::nn::Dropout3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Dropout3d` with +/// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Dropout3d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -139,16 +138,15 @@ TORCH_MODULE(Dropout3d); /// See https://pytorch.org/docs/master/nn.html#torch.nn.AlphaDropout to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AlphaDropoutImpl - : public detail::_DropoutNd { +class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd { public: using detail::_DropoutNd::_DropoutNd; @@ -160,15 +158,16 @@ class TORCH_API AlphaDropoutImpl /// A `ModuleHolder` subclass for `AlphaDropoutImpl`. /// See the documentation for `AlphaDropoutImpl` class to learn what methods it -/// provides, and examples of how to use `AlphaDropout` with `torch::nn::AlphaDropoutOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `AlphaDropout` with +/// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(AlphaDropout); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -187,10 +186,10 @@ class TORCH_API FeatureAlphaDropoutImpl }; /// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`. -/// See the documentation for `FeatureAlphaDropoutImpl` class to learn what methods it -/// provides, and examples of how to use `FeatureAlphaDropout` with `torch::nn::FeatureAlphaDropoutOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `FeatureAlphaDropoutImpl` class to learn what +/// methods it provides, and examples of how to use `FeatureAlphaDropout` with +/// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(FeatureAlphaDropout); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index 849c296c3dd3ca..3bf305c4cbd86a 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -1,9 +1,9 @@ #pragma once #include -#include #include #include +#include #include #include @@ -12,7 +12,8 @@ namespace torch { namespace nn { -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Performs a lookup in a fixed size embedding table. /// See https://pytorch.org/docs/master/nn.html#torch.nn.Embedding to learn @@ -23,13 +24,14 @@ namespace nn { /// /// Example: /// ``` -/// Embedding model(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// Embedding model(EmbeddingOptions(10, +/// 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { public: EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim) - : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {} + : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {} explicit EmbeddingImpl(const EmbeddingOptions& options_); void reset() override; @@ -53,55 +55,61 @@ class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `EmbeddingImpl`. /// See the documentation for `EmbeddingImpl` class to learn what methods it -/// provides, and examples of how to use `Embedding` with `torch::nn::EmbeddingOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Embedding` with +/// `torch::nn::EmbeddingOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. class Embedding : public torch::nn::ModuleHolder { public: using torch::nn::ModuleHolder::ModuleHolder; - /// See the documentation for `torch::nn::EmbeddingFromPretrainedOptions` class to learn what - /// optional arguments are supported for this function. - static Embedding from_pretrained(const torch::Tensor& embeddings, const EmbeddingFromPretrainedOptions& options = {}) { - TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); + /// See the documentation for `torch::nn::EmbeddingFromPretrainedOptions` + /// class to learn what optional arguments are supported for this function. + static Embedding from_pretrained( + const torch::Tensor& embeddings, + const EmbeddingFromPretrainedOptions& options = {}) { + TORCH_CHECK( + embeddings.dim() == 2, + "Embeddings parameter is expected to be 2-dimensional"); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t rows, cols; rows = embeddings.size(0); cols = embeddings.size(1); - Embedding embedding( - EmbeddingOptions(rows, cols) - ._weight(embeddings) - .padding_idx(options.padding_idx()) - .max_norm(options.max_norm()) - .norm_type(options.norm_type()) - .scale_grad_by_freq(options.scale_grad_by_freq()) - .sparse(options.sparse())); + Embedding embedding(EmbeddingOptions(rows, cols) + ._weight(embeddings) + .padding_idx(options.padding_idx()) + .max_norm(options.max_norm()) + .norm_type(options.norm_type()) + .scale_grad_by_freq(options.scale_grad_by_freq()) + .sparse(options.sparse())); embedding->weight.set_requires_grad(!options.freeze()); return embedding; } }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Computes sums or means of 'bags' of embeddings, without instantiating the /// intermediate embeddings. /// See https://pytorch.org/docs/master/nn.html#torch.nn.EmbeddingBag to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::EmbeddingBagOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::EmbeddingBagOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// EmbeddingBag model(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(1)); +/// EmbeddingBag model(EmbeddingBagOptions(10, +/// 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(1)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API EmbeddingBagImpl : public torch::nn::Cloneable { +class TORCH_API EmbeddingBagImpl + : public torch::nn::Cloneable { public: EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim) - : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {} + : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {} explicit EmbeddingBagImpl(const EmbeddingBagOptions& options_); void reset() override; @@ -116,24 +124,32 @@ class TORCH_API EmbeddingBagImpl : public torch::nn::Cloneable /// The embedding table. Tensor weight; - Tensor forward(const Tensor& input, const Tensor& offsets = {}, const Tensor& per_sample_weights = {}); + Tensor forward( + const Tensor& input, + const Tensor& offsets = {}, + const Tensor& per_sample_weights = {}); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) }; /// A `ModuleHolder` subclass for `EmbeddingBagImpl`. /// See the documentation for `EmbeddingBagImpl` class to learn what methods it -/// provides, and examples of how to use `EmbeddingBag` with `torch::nn::EmbeddingBagOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `EmbeddingBag` with +/// `torch::nn::EmbeddingBagOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. class EmbeddingBag : public torch::nn::ModuleHolder { public: using torch::nn::ModuleHolder::ModuleHolder; - /// See the documentation for `torch::nn::EmbeddingBagFromPretrainedOptions` class to learn what - /// optional arguments are supported for this function. - static EmbeddingBag from_pretrained(const torch::Tensor& embeddings, const EmbeddingBagFromPretrainedOptions& options = {}) { - TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); + /// See the documentation for `torch::nn::EmbeddingBagFromPretrainedOptions` + /// class to learn what optional arguments are supported for this function. + static EmbeddingBag from_pretrained( + const torch::Tensor& embeddings, + const EmbeddingBagFromPretrainedOptions& options = {}) { + TORCH_CHECK( + embeddings.dim() == 2, + "Embeddings parameter is expected to be 2-dimensional"); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t rows, cols; @@ -141,14 +157,14 @@ class EmbeddingBag : public torch::nn::ModuleHolder { cols = embeddings.size(1); EmbeddingBag embeddingbag( - EmbeddingBagOptions(rows, cols) - ._weight(embeddings) - .max_norm(options.max_norm()) - .norm_type(options.norm_type()) - .scale_grad_by_freq(options.scale_grad_by_freq()) - .mode(options.mode()) - .sparse(options.sparse()) - .padding_idx(options.padding_idx())); + EmbeddingBagOptions(rows, cols) + ._weight(embeddings) + .max_norm(options.max_norm()) + .norm_type(options.norm_type()) + .scale_grad_by_freq(options.scale_grad_by_freq()) + .mode(options.mode()) + .sparse(options.sparse()) + .padding_idx(options.padding_idx())); embeddingbag->weight.set_requires_grad(!options.freeze()); return embeddingbag; } diff --git a/torch/csrc/api/include/torch/nn/modules/fold.h b/torch/csrc/api/include/torch/nn/modules/fold.h index 43207f14bf033d..ff2d9e331d0fa0 100644 --- a/torch/csrc/api/include/torch/nn/modules/fold.h +++ b/torch/csrc/api/include/torch/nn/modules/fold.h @@ -19,7 +19,8 @@ namespace nn { /// /// Example: /// ``` -/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2)); +/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, +/// 1}).stride(2)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API FoldImpl : public torch::nn::Cloneable { @@ -79,9 +80,9 @@ class TORCH_API UnfoldImpl : public Cloneable { /// A `ModuleHolder` subclass for `UnfoldImpl`. /// See the documentation for `UnfoldImpl` class to learn what methods it -/// provides, and examples of how to use `Unfold` with `torch::nn::UnfoldOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Unfold` with +/// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Unfold); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/torch/csrc/api/include/torch/nn/modules/instancenorm.h index ca4d43ebba86da..83b0ea1fbfbe60 100644 --- a/torch/csrc/api/include/torch/nn/modules/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/modules/instancenorm.h @@ -9,7 +9,8 @@ namespace nn { /// Base class for all (dimension-specialized) instance norm modules template // NOLINTNEXTLINE(bugprone-exception-escape) -class InstanceNormImpl : public torch::nn::NormImplBase { +class InstanceNormImpl + : public torch::nn::NormImplBase { private: inline Tensor apply_instance_norm(const Tensor& input) { return torch::nn::functional::detail::instance_norm( @@ -48,91 +49,100 @@ class InstanceNormImpl : public torch::nn::NormImplBase { +class TORCH_API InstanceNorm1dImpl + : public InstanceNormImpl<1, InstanceNorm1dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl; }; /// A `ModuleHolder` subclass for `InstanceNorm1dImpl`. -/// See the documentation for `InstanceNorm1dImpl` class to learn what methods it -/// provides, and examples of how to use `InstanceNorm1d` with `torch::nn::InstanceNorm1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `InstanceNorm1dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm1d` with +/// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(InstanceNorm1d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the InstanceNorm2d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm2d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// InstanceNorm2d model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// InstanceNorm2d +/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API InstanceNorm2dImpl : public InstanceNormImpl<2, InstanceNorm2dImpl> { +class TORCH_API InstanceNorm2dImpl + : public InstanceNormImpl<2, InstanceNorm2dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl; }; /// A `ModuleHolder` subclass for `InstanceNorm2dImpl`. -/// See the documentation for `InstanceNorm2dImpl` class to learn what methods it -/// provides, and examples of how to use `InstanceNorm2d` with `torch::nn::InstanceNorm2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `InstanceNorm2dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm2d` with +/// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(InstanceNorm2d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies the InstanceNorm3d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm3d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// InstanceNorm3d model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// InstanceNorm3d +/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API InstanceNorm3dImpl : public InstanceNormImpl<3, InstanceNorm3dImpl> { +class TORCH_API InstanceNorm3dImpl + : public InstanceNormImpl<3, InstanceNorm3dImpl> { protected: - void _check_input_dim(const Tensor& input) override; + void _check_input_dim(const Tensor& input) override; public: using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl; }; /// A `ModuleHolder` subclass for `InstanceNorm3dImpl`. -/// See the documentation for `InstanceNorm3dImpl` class to learn what methods it -/// provides, and examples of how to use `InstanceNorm3d` with `torch::nn::InstanceNorm3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `InstanceNorm3dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm3d` with +/// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(InstanceNorm3d); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h index 6f857d95457fd6..6ba9e35eef65d4 100644 --- a/torch/csrc/api/include/torch/nn/modules/linear.h +++ b/torch/csrc/api/include/torch/nn/modules/linear.h @@ -1,10 +1,10 @@ #pragma once #include +#include #include #include #include -#include #include #include @@ -16,8 +16,8 @@ namespace nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// A placeholder identity operator that is argument-insensitive. -/// See https://pytorch.org/docs/master/generated/torch.nn.Identity.html to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/generated/torch.nn.Identity.html to +/// learn about the exact behavior of this module. // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API IdentityImpl : public Cloneable { public: @@ -52,7 +52,7 @@ TORCH_MODULE(Identity); class TORCH_API LinearImpl : public Cloneable { public: LinearImpl(int64_t in_features, int64_t out_features) - : LinearImpl(LinearOptions(in_features, out_features)) {} + : LinearImpl(LinearOptions(in_features, out_features)) {} explicit LinearImpl(const LinearOptions& options_); void reset() override; @@ -79,9 +79,9 @@ class TORCH_API LinearImpl : public Cloneable { /// A `ModuleHolder` subclass for `LinearImpl`. /// See the documentation for `LinearImpl` class to learn what methods it -/// provides, and examples of how to use `Linear` with `torch::nn::LinearOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Linear` with +/// `torch::nn::LinearOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Linear); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -116,16 +116,17 @@ class TORCH_API FlattenImpl : public Cloneable { /// A `ModuleHolder` subclass for `FlattenImpl`. /// See the documentation for `FlattenImpl` class to learn what methods it -/// provides, and examples of how to use `Flatten` with `torch::nn::FlattenOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Flatten` with +/// `torch::nn::FlattenOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Flatten); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// A placeholder for unflatten operator -/// See https://pytorch.org/docs/master/generated/torch.nn.Unflatten.html to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/generated/torch.nn.Unflatten.html to +/// learn about the exact behavior of this module. /// /// See the documentation for `torch::nn::UnflattenOptions` class to learn what /// constructor arguments are supported for this module. @@ -139,9 +140,9 @@ TORCH_MODULE(Flatten); class TORCH_API UnflattenImpl : public Cloneable { public: UnflattenImpl(int64_t dim, std::vector sizes) - : UnflattenImpl(UnflattenOptions(dim, sizes)) {} + : UnflattenImpl(UnflattenOptions(dim, sizes)) {} UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape) - : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {} + : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {} explicit UnflattenImpl(UnflattenOptions options_); void reset() override; @@ -158,16 +159,16 @@ class TORCH_API UnflattenImpl : public Cloneable { /// A `ModuleHolder` subclass for `UnflattenImpl`. /// See the documentation for `UnflattenImpl` class to learn what methods it -/// provides, and examples of how to use `Unflatten` with `torch::nn::UnflattenOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Unflatten` with +/// `torch::nn::UnflattenOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Unflatten); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies a billinear transformation with optional bias. -/// See https://pytorch.org/docs/master/generated/torch.nn.Bilinear.html to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/generated/torch.nn.Bilinear.html to +/// learn about the exact behavior of this module. /// /// See the documentation for `torch::nn::BilinearOptions` class to learn what /// constructor arguments are supported for this module. @@ -179,7 +180,9 @@ TORCH_MODULE(Unflatten); // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API BilinearImpl : public Cloneable { public: - BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features) : BilinearImpl(BilinearOptions(in1_features, in2_features, out_features)) {} + BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features) + : BilinearImpl( + BilinearOptions(in1_features, in2_features, out_features)) {} explicit BilinearImpl(const BilinearOptions& options_); void reset() override; @@ -189,9 +192,9 @@ class TORCH_API BilinearImpl : public Cloneable { /// Pretty prints the `Bilinear` module into the given `stream`. void pretty_print(std::ostream& stream) const override; - /// Applies a bilinear transform on the `input1` and `input2` tensor by multiplying - /// with the `weight` and optionally adding the `bias`, if `with_bias` - /// is true in the options. + /// Applies a bilinear transform on the `input1` and `input2` tensor by + /// multiplying with the `weight` and optionally adding the `bias`, if + /// `with_bias` is true in the options. Tensor forward(const Tensor& input1, const Tensor& input2); /// The options used to configure this module. @@ -207,9 +210,9 @@ class TORCH_API BilinearImpl : public Cloneable { /// A `ModuleHolder` subclass for `BilinearImpl`. /// See the documentation for `BilinearImpl` class to learn what methods it -/// provides, and examples of how to use `Bilinear` with `torch::nn::BilinearOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Bilinear` with +/// `torch::nn::BilinearOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Bilinear); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h index 9d6958b8ca3f41..cabc1a0ed811b3 100644 --- a/torch/csrc/api/include/torch/nn/modules/loss.h +++ b/torch/csrc/api/include/torch/nn/modules/loss.h @@ -46,12 +46,13 @@ struct TORCH_API L1LossImpl : Cloneable { /// A `ModuleHolder` subclass for `L1LossImpl`. /// See the documentation for `L1LossImpl` class to learn what methods it -/// provides, and examples of how to use `L1Loss` with `torch::nn::L1LossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `L1Loss` with +/// `torch::nn::L1LossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(L1Loss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// The Kullback-Leibler divergence loss measure /// See https://pytorch.org/docs/master/nn.html#torch.nn.KLDivLoss to learn @@ -81,9 +82,9 @@ struct TORCH_API KLDivLossImpl : Cloneable { /// A `ModuleHolder` subclass for `KLDivLossImpl`. /// See the documentation for `KLDivLossImpl` class to learn what methods it -/// provides, and examples of how to use `KLDivLoss` with `torch::nn::KLDivLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `KLDivLoss` with +/// `torch::nn::KLDivLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(KLDivLoss); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -117,9 +118,9 @@ struct TORCH_API MSELossImpl : Cloneable { /// A `ModuleHolder` subclass for `MSELossImpl`. /// See the documentation for `MSELossImpl` class to learn what methods it -/// provides, and examples of how to use `MSELoss` with `torch::nn::MSELossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MSELoss` with +/// `torch::nn::MSELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MSELoss); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -153,24 +154,26 @@ struct TORCH_API BCELossImpl : Cloneable { /// A `ModuleHolder` subclass for `BCELossImpl`. /// See the documentation for `BCELossImpl` class to learn what methods it -/// provides, and examples of how to use `BCELoss` with `torch::nn::BCELossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `BCELoss` with +/// `torch::nn::BCELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(BCELoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that measures the loss given an input tensor :math:`x` /// and a labels tensor :math:`y` (containing 1 or -1). -/// See https://pytorch.org/docs/master/nn.html#torch.nn.HingeEmbeddingLoss to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.HingeEmbeddingLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// HingeEmbeddingLoss model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); +/// HingeEmbeddingLoss +/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) struct TORCH_API HingeEmbeddingLossImpl : Cloneable { @@ -189,23 +192,24 @@ struct TORCH_API HingeEmbeddingLossImpl : Cloneable { }; /// A `ModuleHolder` subclass for `HingeEmbeddingLossImpl`. -/// See the documentation for `HingeEmbeddingLossImpl` class to learn what methods it -/// provides, and examples of how to use `HingeEmbeddingLoss` with `torch::nn::HingeEmbeddingLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `HingeEmbeddingLossImpl` class to learn what +/// methods it provides, and examples of how to use `HingeEmbeddingLoss` with +/// `torch::nn::HingeEmbeddingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(HingeEmbeddingLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that optimizes a multi-class classification hinge -/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and -/// output :math:`y` (which is a 1D tensor of target class indices, -/// :math:`0 \leq y \leq \text{x.size}(1)-1`). -/// See https://pytorch.org/docs/master/nn.html#torch.nn.MultiMarginLoss to learn +/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) +/// and output :math:`y` (which is a 1D tensor of target class indices, :math:`0 +/// \leq y \leq \text{x.size}(1)-1`). See +/// https://pytorch.org/docs/master/nn.html#torch.nn.MultiMarginLoss to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -213,8 +217,7 @@ TORCH_MODULE(HingeEmbeddingLoss); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) struct TORCH_API MultiMarginLossImpl : public Cloneable { - explicit MultiMarginLossImpl( - const MultiMarginLossOptions& options_ = {}); + explicit MultiMarginLossImpl(const MultiMarginLossOptions& options_ = {}); void reset() override; @@ -228,31 +231,33 @@ struct TORCH_API MultiMarginLossImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `MultiMarginLossImpl`. -/// See the documentation for `MultiMarginLossImpl` class to learn what methods it -/// provides, and examples of how to use `MultiMarginLoss` with `torch::nn::MultiMarginLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `MultiMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `MultiMarginLoss` with +/// `torch::nn::MultiMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(MultiMarginLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that measures the loss given input tensors /// `input1`, `input2`, and a `Tensor` label `target` with values 1 or /// -1. This is used for measuring whether two inputs are similar or /// dissimilar, using the cosine distance, and is typically used for learning /// nonlinear embeddings or semi-supervised learning. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.CosineEmbeddingLoss to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.CosineEmbeddingLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API CosineEmbeddingLossImpl : public Cloneable { +struct TORCH_API CosineEmbeddingLossImpl + : public Cloneable { explicit CosineEmbeddingLossImpl( const CosineEmbeddingLossOptions& options_ = {}); @@ -271,23 +276,24 @@ struct TORCH_API CosineEmbeddingLossImpl : public Cloneable { /// A `ModuleHolder` subclass for `SmoothL1LossImpl`. /// See the documentation for `SmoothL1LossImpl` class to learn what methods it -/// provides, and examples of how to use `SmoothL1Loss` with `torch::nn::SmoothL1LossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `SmoothL1Loss` with +/// `torch::nn::SmoothL1LossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(SmoothL1Loss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that uses a squared term if the absolute /// element-wise error falls below delta and a delta-scaled L1 term otherwise. @@ -346,30 +353,33 @@ struct TORCH_API HuberLossImpl : public Cloneable { /// A `ModuleHolder` subclass for `HuberLossImpl`. /// See the documentation for `HuberLossImpl` class to learn what methods it -/// provides, and examples of how to use `HuberLoss` with `torch::nn::HuberLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `HuberLoss` with +/// `torch::nn::HuberLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(HuberLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that optimizes a multi-class multi-classification -/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) -/// and output :math:`y` (which is a 2D `Tensor` of target class indices). -/// See https://pytorch.org/docs/master/nn.html#torch.nn.MultiLabelMarginLoss to learn -/// about the exact behavior of this module. +/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch +/// `Tensor`) and output :math:`y` (which is a 2D `Tensor` of target class +/// indices). See +/// https://pytorch.org/docs/master/nn.html#torch.nn.MultiLabelMarginLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API MultiLabelMarginLossImpl : public Cloneable { +struct TORCH_API MultiLabelMarginLossImpl + : public Cloneable { explicit MultiLabelMarginLossImpl( - const MultiLabelMarginLossOptions& options_ = {}); + const MultiLabelMarginLossOptions& options_ = {}); void reset() override; @@ -383,13 +393,14 @@ struct TORCH_API MultiLabelMarginLossImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `SoftMarginLossImpl`. -/// See the documentation for `SoftMarginLossImpl` class to learn what methods it -/// provides, and examples of how to use `SoftMarginLoss` with `torch::nn::SoftMarginLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `SoftMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `SoftMarginLoss` with +/// `torch::nn::SoftMarginLossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(SoftMarginLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that optimizes a multi-label one-versus-all -/// loss based on max-entropy, between input :math:`x` and target :math:`y` of size -/// :math:`(N, C)`. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.MultiLabelSoftMarginLoss to learn -/// about the exact behavior of this module. +/// loss based on max-entropy, between input :math:`x` and target :math:`y` of +/// size :math:`(N, C)`. See +/// https://pytorch.org/docs/master/nn.html#torch.nn.MultiLabelSoftMarginLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class +/// to learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// MultiLabelSoftMarginLoss model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); +/// MultiLabelSoftMarginLoss +/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API MultiLabelSoftMarginLossImpl : public Cloneable { +struct TORCH_API MultiLabelSoftMarginLossImpl + : public Cloneable { explicit MultiLabelSoftMarginLossImpl( - const MultiLabelSoftMarginLossOptions& options_ = {}); + const MultiLabelSoftMarginLossOptions& options_ = {}); - /// Pretty prints the `MultiLabelSoftMarginLoss` module into the given `stream`. + /// Pretty prints the `MultiLabelSoftMarginLoss` module into the given + /// `stream`. void pretty_print(std::ostream& stream) const override; void reset() override; @@ -458,13 +473,14 @@ struct TORCH_API MultiLabelSoftMarginLossImpl : public Cloneable { - explicit TripletMarginLossImpl( - const TripletMarginLossOptions& options_ = {}); +struct TORCH_API TripletMarginLossImpl + : public Cloneable { + explicit TripletMarginLossImpl(const TripletMarginLossOptions& options_ = {}); void reset() override; @@ -502,38 +519,44 @@ struct TORCH_API TripletMarginLossImpl : public Cloneable }; /// A `ModuleHolder` subclass for `TripletMarginLossImpl`. -/// See the documentation for `TripletMarginLossImpl` class to learn what methods it -/// provides, and examples of how to use `TripletMarginLoss` with `torch::nn::TripletMarginLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `TripletMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `TripletMarginLoss` with +/// `torch::nn::TripletMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(TripletMarginLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that measures the triplet loss given input /// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, -/// positive, and negative examples, respectively); and a nonnegative, real-valued function +/// positive, and negative examples, respectively); and a nonnegative, +/// real-valued function /// ("distance function") used to compute the relationships between the anchor /// and positive example ("positive distance") and the anchor and negative /// example ("negative distance"). -/// See https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginWithDistanceLoss to learn -/// about the exact behavior of this module. +/// See +/// https://pytorch.org/docs/master/nn.html#torch.nn.TripletMarginWithDistanceLoss +/// to learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` +/// class to learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// TripletMarginWithDistanceLoss model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// TripletMarginWithDistanceLoss +/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API TripletMarginWithDistanceLossImpl : public Cloneable { +struct TORCH_API TripletMarginWithDistanceLossImpl + : public Cloneable { explicit TripletMarginWithDistanceLossImpl( TripletMarginWithDistanceLossOptions options_ = {}); void reset() override; - /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given `stream`. + /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given + /// `stream`. void pretty_print(std::ostream& stream) const override; Tensor forward( @@ -546,8 +569,9 @@ struct TORCH_API TripletMarginWithDistanceLossImpl : public Cloneable { - explicit CTCLossImpl(const CTCLossOptions& options_ = {}); void reset() override; @@ -576,8 +600,11 @@ struct TORCH_API CTCLossImpl : public Cloneable { /// Pretty prints the `CTCLoss` module into the given `stream`. void pretty_print(std::ostream& stream) const override; - Tensor forward(const Tensor& log_probs, const Tensor& targets, - const Tensor& input_lengths, const Tensor& target_lengths); + Tensor forward( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths); /// The options with which this `Module` was constructed. CTCLossOptions options; @@ -585,23 +612,25 @@ struct TORCH_API CTCLossImpl : public Cloneable { /// A `ModuleHolder` subclass for `CTCLossImpl`. /// See the documentation for `CTCLossImpl` class to learn what methods it -/// provides, and examples of how to use `CTCLoss` with `torch::nn::CTCLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `CTCLoss` with +/// `torch::nn::CTCLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(CTCLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Negative log likelihood loss with Poisson distribution of target. /// See https://pytorch.org/docs/master/nn.html#torch.nn.PoissonNLLLoss to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` -/// PoissonNLLLoss model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); +/// PoissonNLLLoss +/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) struct TORCH_API PoissonNLLLossImpl : public Cloneable { @@ -619,29 +648,32 @@ struct TORCH_API PoissonNLLLossImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `PoissonNLLLossImpl`. -/// See the documentation for `PoissonNLLLossImpl` class to learn what methods it -/// provides, and examples of how to use `PoissonNLLLoss` with `torch::nn::PoissonNLLLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `PoissonNLLLossImpl` class to learn what methods +/// it provides, and examples of how to use `PoissonNLLLoss` with +/// `torch::nn::PoissonNLLLossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(PoissonNLLLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Creates a criterion that measures the loss given /// inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, /// and a label 1D mini-batch tensor :math:`y` (containing 1 or -1). -/// See https://pytorch.org/docs/master/nn.html#torch.nn.MarginRankingLoss to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.MarginRankingLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MarginRankingLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MarginRankingLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// MarginRankingLoss model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); +/// MarginRankingLoss +/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API MarginRankingLossImpl : public Cloneable { +struct TORCH_API MarginRankingLossImpl + : public Cloneable { explicit MarginRankingLossImpl(const MarginRankingLossOptions& options_ = {}); void reset() override; @@ -649,18 +681,20 @@ struct TORCH_API MarginRankingLossImpl : public Cloneable /// Pretty prints the `MarginRankingLoss` module into the given `stream`. void pretty_print(std::ostream& stream) const override; - Tensor forward(const Tensor& input1, - const Tensor& input2, const Tensor& targets); + Tensor forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& targets); /// The options with which this `Module` was constructed. MarginRankingLossOptions options; }; /// A `ModuleHolder` subclass for `MarginRankingLossImpl`. -/// See the documentation for `MarginRankingLossImpl` class to learn what methods it -/// provides, and examples of how to use `MarginRankingLoss` with `torch::nn::MarginRankingLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `MarginRankingLossImpl` class to learn what +/// methods it provides, and examples of how to use `MarginRankingLoss` with +/// `torch::nn::MarginRankingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(MarginRankingLoss); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -679,17 +713,14 @@ TORCH_MODULE(MarginRankingLoss); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) struct TORCH_API NLLLossImpl : public Cloneable { - explicit NLLLossImpl( - const NLLLossOptions& options_ = {}); + explicit NLLLossImpl(const NLLLossOptions& options_ = {}); /// Pretty prints the `NLLLoss` module into the given `stream`. void pretty_print(std::ostream& stream) const override; void reset() override; - Tensor forward( - const Tensor& input, - const Tensor& target); + Tensor forward(const Tensor& input, const Tensor& target); /// The options with which this `Module` was constructed. NLLLossOptions options; @@ -700,37 +731,37 @@ struct TORCH_API NLLLossImpl : public Cloneable { /// A `ModuleHolder` subclass for `NLLLossImpl`. /// See the documentation for `NLLLossImpl` class to learn what methods it -/// provides, and examples of how to use `NLLLoss` with `torch::nn::NLLLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `NLLLoss` with +/// `torch::nn::NLLLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(NLLLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/// Creates a criterion that computes cross entropy loss between input and target. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.CrossEntropyLoss to learn +/// Creates a criterion that computes cross entropy loss between input and +/// target. See +/// https://pytorch.org/docs/master/nn.html#torch.nn.CrossEntropyLoss to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// CrossEntropyLoss model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// CrossEntropyLoss +/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) struct TORCH_API CrossEntropyLossImpl : public Cloneable { - explicit CrossEntropyLossImpl( - const CrossEntropyLossOptions& options_ = {}); + explicit CrossEntropyLossImpl(const CrossEntropyLossOptions& options_ = {}); void reset() override; /// Pretty prints the `CrossEntropyLoss` module into the given `stream`. void pretty_print(std::ostream& stream) const override; - Tensor forward( - const Tensor& input, - const Tensor& target); + Tensor forward(const Tensor& input, const Tensor& target); /// The options with which this `Module` was constructed. CrossEntropyLossOptions options; @@ -740,30 +771,33 @@ struct TORCH_API CrossEntropyLossImpl : public Cloneable { }; /// A `ModuleHolder` subclass for `CrossEntropyLossImpl`. -/// See the documentation for `CrossEntropyLossImpl` class to learn what methods it -/// provides, and examples of how to use `CrossEntropyLoss` with `torch::nn::CrossEntropyLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `CrossEntropyLossImpl` class to learn what methods +/// it provides, and examples of how to use `CrossEntropyLoss` with +/// `torch::nn::CrossEntropyLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(CrossEntropyLoss); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// This loss combines a `Sigmoid` layer and the `BCELoss` in one single /// class. This version is more numerically stable than using a plain `Sigmoid` /// followed by a `BCELoss` as, by combining the operations into one layer, /// we take advantage of the log-sum-exp trick for numerical stability. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.BCEWithLogitsLoss to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.BCEWithLogitsLoss to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// BCEWithLogitsLoss model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); +/// BCEWithLogitsLoss +/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API BCEWithLogitsLossImpl : public Cloneable { +struct TORCH_API BCEWithLogitsLossImpl + : public Cloneable { explicit BCEWithLogitsLossImpl(const BCEWithLogitsLossOptions& options_ = {}); void reset() override; @@ -784,10 +818,10 @@ struct TORCH_API BCEWithLogitsLossImpl : public Cloneable }; /// A `ModuleHolder` subclass for `BCEWithLogitsLossImpl`. -/// See the documentation for `BCEWithLogitsLossImpl` class to learn what methods it -/// provides, and examples of how to use `BCEWithLogitsLoss` with `torch::nn::BCEWithLogitsLossOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `BCEWithLogitsLossImpl` class to learn what +/// methods it provides, and examples of how to use `BCEWithLogitsLoss` with +/// `torch::nn::BCEWithLogitsLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(BCEWithLogitsLoss); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/normalization.h b/torch/csrc/api/include/torch/nn/modules/normalization.h index bba2fc022d859f..d57c26c9410336 100644 --- a/torch/csrc/api/include/torch/nn/modules/normalization.h +++ b/torch/csrc/api/include/torch/nn/modules/normalization.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include #include @@ -25,7 +25,8 @@ namespace nn { /// /// Example: /// ``` -/// LayerNorm model(LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5)); +/// LayerNorm model(LayerNormOptions({2, +/// 2}).elementwise_affine(false).eps(2e-5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API LayerNormImpl : public torch::nn::Cloneable { @@ -55,38 +56,43 @@ class TORCH_API LayerNormImpl : public torch::nn::Cloneable { LayerNormOptions options; /// The learned weight. - /// Initialized to ones if the `elementwise_affine` option is set to `true` upon construction. + /// Initialized to ones if the `elementwise_affine` option is set to `true` + /// upon construction. Tensor weight; /// The learned bias. - /// Initialized to zeros `elementwise_affine` option is set to `true` upon construction. + /// Initialized to zeros `elementwise_affine` option is set to `true` upon + /// construction. Tensor bias; }; /// A `ModuleHolder` subclass for `LayerNormImpl`. /// See the documentation for `LayerNormImpl` class to learn what methods it -/// provides, and examples of how to use `LayerNorm` with `torch::nn::LayerNormOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LayerNorm` with +/// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LayerNorm); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies local response normalization over an input signal composed /// of several input planes, where channels occupy the second dimension. /// Applies normalization across channels. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.LocalResponseNorm to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.LocalResponseNorm to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::LocalResponseNormOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::LocalResponseNormOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// LocalResponseNorm model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); +/// LocalResponseNorm +/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API LocalResponseNormImpl : public Cloneable { +class TORCH_API LocalResponseNormImpl + : public Cloneable { public: LocalResponseNormImpl(int64_t size) : LocalResponseNormImpl(LocalResponseNormOptions(size)) {} @@ -104,23 +110,24 @@ class TORCH_API LocalResponseNormImpl : public Cloneable }; /// A `ModuleHolder` subclass for `LocalResponseNormImpl`. -/// See the documentation for `LocalResponseNormImpl` class to learn what methods it -/// provides, and examples of how to use `LocalResponseNorm` with `torch::nn::LocalResponseNormOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `LocalResponseNormImpl` class to learn what +/// methods it provides, and examples of how to use `LocalResponseNorm` with +/// `torch::nn::LocalResponseNormOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(LocalResponseNorm); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API CrossMapLRN2dImpl : public torch::nn::Cloneable { +class TORCH_API CrossMapLRN2dImpl + : public torch::nn::Cloneable { public: CrossMapLRN2dImpl(int64_t size) : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {} @@ -139,9 +146,9 @@ class TORCH_API CrossMapLRN2dImpl : public torch::nn::Cloneable { /// A `ModuleHolder` subclass for `GroupNormImpl`. /// See the documentation for `GroupNormImpl` class to learn what methods it -/// provides, and examples of how to use `GroupNorm` with `torch::nn::GroupNormOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `GroupNorm` with +/// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(GroupNorm); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/padding.h b/torch/csrc/api/include/torch/nn/modules/padding.h index 5ad3087e4011c1..3efa41af8fb8be 100644 --- a/torch/csrc/api/include/torch/nn/modules/padding.h +++ b/torch/csrc/api/include/torch/nn/modules/padding.h @@ -14,7 +14,7 @@ template // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API ReflectionPadImpl : public torch::nn::Cloneable { public: - ReflectionPadImpl(ExpandingArray padding) + ReflectionPadImpl(ExpandingArray padding) : ReflectionPadImpl(ReflectionPadOptions(padding)) {} explicit ReflectionPadImpl(const ReflectionPadOptions& options_); @@ -29,66 +29,71 @@ class TORCH_API ReflectionPadImpl : public torch::nn::Cloneable { ReflectionPadOptions options; }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReflectionPad over a 1-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad1d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad1d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReflectionPad1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReflectionPad1dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// ReflectionPad1d model(ReflectionPad1dOptions({3, 1})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReflectionPad1dImpl : public ReflectionPadImpl<1, ReflectionPad1dImpl> { +class TORCH_API ReflectionPad1dImpl + : public ReflectionPadImpl<1, ReflectionPad1dImpl> { public: using ReflectionPadImpl<1, ReflectionPad1dImpl>::ReflectionPadImpl; }; /// A `ModuleHolder` subclass for `ReflectionPad1dImpl`. -/// See the documentation for `ReflectionPad1dImpl` class to learn what methods it -/// provides, and examples of how to use `ReflectionPad1d` with `torch::nn::ReflectionPad1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReflectionPad1dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad1d` with +/// `torch::nn::ReflectionPad1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReflectionPad1d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReflectionPad over a 2-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad2d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad2d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReflectionPad2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReflectionPad2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// ReflectionPad2d model(ReflectionPad2dOptions({1, 1, 2, 0})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReflectionPad2dImpl : public ReflectionPadImpl<2, ReflectionPad2dImpl> { +class TORCH_API ReflectionPad2dImpl + : public ReflectionPadImpl<2, ReflectionPad2dImpl> { public: using ReflectionPadImpl<2, ReflectionPad2dImpl>::ReflectionPadImpl; }; /// A `ModuleHolder` subclass for `ReflectionPad2dImpl`. -/// See the documentation for `ReflectionPad2dImpl` class to learn what methods it -/// provides, and examples of how to use `ReflectionPad2d` with `torch::nn::ReflectionPad2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReflectionPad2dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad2d` with +/// `torch::nn::ReflectionPad2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReflectionPad2d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReflectionPad over a 3-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad3d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReflectionPad3d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReflectionPad3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReflectionPad3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -96,16 +101,17 @@ TORCH_MODULE(ReflectionPad2d); /// ReflectionPad3d model(ReflectionPad3dOptions({1, 1, 2, 0, 1, 2})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReflectionPad3dImpl : public ReflectionPadImpl<3, ReflectionPad3dImpl> { +class TORCH_API ReflectionPad3dImpl + : public ReflectionPadImpl<3, ReflectionPad3dImpl> { public: using ReflectionPadImpl<3, ReflectionPad3dImpl>::ReflectionPadImpl; }; /// A `ModuleHolder` subclass for `ReflectionPad3dImpl`. -/// See the documentation for `ReflectionPad3dImpl` class to learn what methods it -/// provides, and examples of how to use `ReflectionPad3d` with `torch::nn::ReflectionPad3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReflectionPad3dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad3d` with +/// `torch::nn::ReflectionPad3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReflectionPad3d); // ============================================================================ @@ -115,7 +121,7 @@ template // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API ReplicationPadImpl : public torch::nn::Cloneable { public: - ReplicationPadImpl(ExpandingArray padding) + ReplicationPadImpl(ExpandingArray padding) : ReplicationPadImpl(ReplicationPadOptions(padding)) {} explicit ReplicationPadImpl(const ReplicationPadOptions& options_); @@ -130,82 +136,88 @@ class TORCH_API ReplicationPadImpl : public torch::nn::Cloneable { ReplicationPadOptions options; }; -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReplicationPad over a 1-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad1d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad1d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReplicationPad1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReplicationPad1dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// ReplicationPad1d model(ReplicationPad1dOptions({3, 1})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReplicationPad1dImpl : public ReplicationPadImpl<1, ReplicationPad1dImpl> { +class TORCH_API ReplicationPad1dImpl + : public ReplicationPadImpl<1, ReplicationPad1dImpl> { public: using ReplicationPadImpl<1, ReplicationPad1dImpl>::ReplicationPadImpl; }; /// A `ModuleHolder` subclass for `ReplicationPad1dImpl`. -/// See the documentation for `ReplicationPad1dImpl` class to learn what methods it -/// provides, and examples of how to use `ReplicationPad1d` with `torch::nn::ReplicationPad1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReplicationPad1dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad1d` with +/// `torch::nn::ReplicationPad1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReplicationPad1d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReplicationPad over a 2-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad2d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad2d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReplicationPad2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReplicationPad2dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// ReplicationPad2d model(ReplicationPad2dOptions({1, 1, 2, 0})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReplicationPad2dImpl : public ReplicationPadImpl<2, ReplicationPad2dImpl> { +class TORCH_API ReplicationPad2dImpl + : public ReplicationPadImpl<2, ReplicationPad2dImpl> { public: using ReplicationPadImpl<2, ReplicationPad2dImpl>::ReplicationPadImpl; }; /// A `ModuleHolder` subclass for `ReplicationPad2dImpl`. -/// See the documentation for `ReplicationPad2dImpl` class to learn what methods it -/// provides, and examples of how to use `ReplicationPad2d` with `torch::nn::ReplicationPad2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReplicationPad2dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad2d` with +/// `torch::nn::ReplicationPad2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReplicationPad2d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies ReplicationPad over a 3-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad3d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.ReplicationPad3d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ReplicationPad3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ReplicationPad3dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// ReplicationPad3d model(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ReplicationPad3dImpl : public ReplicationPadImpl<3, ReplicationPad3dImpl> { +class TORCH_API ReplicationPad3dImpl + : public ReplicationPadImpl<3, ReplicationPad3dImpl> { public: using ReplicationPadImpl<3, ReplicationPad3dImpl>::ReplicationPadImpl; }; /// A `ModuleHolder` subclass for `ReplicationPad3dImpl`. -/// See the documentation for `ReplicationPad3dImpl` class to learn what methods it -/// provides, and examples of how to use `ReplicationPad3d` with `torch::nn::ReplicationPad3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `ReplicationPad3dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad3d` with +/// `torch::nn::ReplicationPad3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ReplicationPad3d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -241,9 +253,9 @@ class TORCH_API ZeroPad2dImpl : public Cloneable { /// A `ModuleHolder` subclass for `ZeroPad2dImpl`. /// See the documentation for `ZeroPad2dImpl` class to learn what methods it -/// provides, and examples of how to use `ZeroPad2d` with `torch::nn::ZeroPad2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `ZeroPad2d` with +/// `torch::nn::ZeroPad2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(ZeroPad2d); // ============================================================================ @@ -253,7 +265,7 @@ template // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API ConstantPadImpl : public torch::nn::Cloneable { public: - ConstantPadImpl(ExpandingArray padding, double value) + ConstantPadImpl(ExpandingArray padding, double value) : ConstantPadImpl(ConstantPadOptions(padding, value)) {} explicit ConstantPadImpl(const ConstantPadOptions& options_); @@ -274,24 +286,25 @@ class TORCH_API ConstantPadImpl : public torch::nn::Cloneable { /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConstantPad1d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConstantPad1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConstantPad1dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// ConstantPad1d model(ConstantPad1dOptions({3, 1}, 3.5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConstantPad1dImpl : public ConstantPadImpl<1, ConstantPad1dImpl> { +class TORCH_API ConstantPad1dImpl + : public ConstantPadImpl<1, ConstantPad1dImpl> { public: using ConstantPadImpl<1, ConstantPad1dImpl>::ConstantPadImpl; }; /// A `ModuleHolder` subclass for `ConstantPad1dImpl`. /// See the documentation for `ConstantPad1dImpl` class to learn what methods it -/// provides, and examples of how to use `ConstantPad1d` with `torch::nn::ConstantPad1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `ConstantPad1d` with +/// `torch::nn::ConstantPad1dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(ConstantPad1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -300,24 +313,25 @@ TORCH_MODULE(ConstantPad1d); /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConstantPad2d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConstantPad2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConstantPad2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// ConstantPad2d model(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConstantPad2dImpl : public ConstantPadImpl<2, ConstantPad2dImpl> { +class TORCH_API ConstantPad2dImpl + : public ConstantPadImpl<2, ConstantPad2dImpl> { public: using ConstantPadImpl<2, ConstantPad2dImpl>::ConstantPadImpl; }; /// A `ModuleHolder` subclass for `ConstantPad2dImpl`. /// See the documentation for `ConstantPad2dImpl` class to learn what methods it -/// provides, and examples of how to use `ConstantPad2d` with `torch::nn::ConstantPad2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `ConstantPad2d` with +/// `torch::nn::ConstantPad2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(ConstantPad2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -326,24 +340,25 @@ TORCH_MODULE(ConstantPad2d); /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConstantPad3d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::ConstantPad3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::ConstantPad3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` /// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API ConstantPad3dImpl : public ConstantPadImpl<3, ConstantPad3dImpl> { +class TORCH_API ConstantPad3dImpl + : public ConstantPadImpl<3, ConstantPad3dImpl> { public: using ConstantPadImpl<3, ConstantPad3dImpl>::ConstantPadImpl; }; /// A `ModuleHolder` subclass for `ConstantPad3dImpl`. /// See the documentation for `ConstantPad3dImpl` class to learn what methods it -/// provides, and examples of how to use `ConstantPad3d` with `torch::nn::ConstantPad3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `ConstantPad3d` with +/// `torch::nn::ConstantPad3dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. TORCH_MODULE(ConstantPad3d); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h index f24183ce364175..3fb456f618a5fb 100644 --- a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h @@ -9,7 +9,8 @@ namespace torch { namespace nn { -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` /// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an @@ -25,7 +26,8 @@ namespace nn { /// PixelShuffle model(PixelShuffleOptions(5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable { +struct TORCH_API PixelShuffleImpl + : public torch::nn::Cloneable { explicit PixelShuffleImpl(const PixelShuffleOptions& options_); /// Pretty prints the `PixelShuffle` module into the given `stream`. @@ -41,9 +43,9 @@ struct TORCH_API PixelShuffleImpl : public torch::nn::Cloneable #include -#include #include #include +#include #include @@ -51,9 +51,9 @@ class TORCH_API AvgPool1dImpl : public AvgPoolImpl<1, AvgPool1dImpl> { /// A `ModuleHolder` subclass for `AvgPool1dImpl`. /// See the documentation for `AvgPool1dImpl` class to learn what methods it -/// provides, and examples of how to use `AvgPool1d` with `torch::nn::AvgPool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `AvgPool1d` with +/// `torch::nn::AvgPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(AvgPool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -78,9 +78,9 @@ class TORCH_API AvgPool2dImpl : public AvgPoolImpl<2, AvgPool2dImpl> { /// A `ModuleHolder` subclass for `AvgPool2dImpl`. /// See the documentation for `AvgPool2dImpl` class to learn what methods it -/// provides, and examples of how to use `AvgPool2d` with `torch::nn::AvgPool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `AvgPool2d` with +/// `torch::nn::AvgPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(AvgPool2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -105,9 +105,9 @@ class TORCH_API AvgPool3dImpl : public AvgPoolImpl<3, AvgPool3dImpl> { /// A `ModuleHolder` subclass for `AvgPool3dImpl`. /// See the documentation for `AvgPool3dImpl` class to learn what methods it -/// provides, and examples of how to use `AvgPool3d` with `torch::nn::AvgPool3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `AvgPool3d` with +/// `torch::nn::AvgPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(AvgPool3d); // ============================================================================ @@ -156,9 +156,9 @@ class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> { /// A `ModuleHolder` subclass for `MaxPool1dImpl`. /// See the documentation for `MaxPool1dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxPool1d` with `torch::nn::MaxPool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxPool1d` with +/// `torch::nn::MaxPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxPool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -187,9 +187,9 @@ class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> { /// A `ModuleHolder` subclass for `MaxPool2dImpl`. /// See the documentation for `MaxPool2dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxPool2d` with `torch::nn::MaxPool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxPool2d` with +/// `torch::nn::MaxPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxPool2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -218,9 +218,9 @@ class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> { /// A `ModuleHolder` subclass for `MaxPool3dImpl`. /// See the documentation for `MaxPool3dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxPool3d` with `torch::nn::MaxPool3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxPool3d` with +/// `torch::nn::MaxPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxPool3d); // ============================================================================ @@ -231,13 +231,16 @@ template class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable { public: AdaptiveMaxPoolImpl(output_size_t output_size) - : AdaptiveMaxPoolImpl(AdaptiveMaxPoolOptions(output_size)) {} + : AdaptiveMaxPoolImpl( + AdaptiveMaxPoolOptions(output_size)) {} explicit AdaptiveMaxPoolImpl( - const AdaptiveMaxPoolOptions& options_) : options(options_) {} + const AdaptiveMaxPoolOptions& options_) + : options(options_) {} - void reset() override {}; + void reset() override{}; - /// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given `stream`. + /// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given + /// `stream`. void pretty_print(std::ostream& stream) const override { stream << "torch::nn::AdaptiveMaxPool" << D << "d" << "(output_size=" << options.output_size() << ")"; @@ -250,21 +253,22 @@ class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive maxpool over a 1-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool1d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool1d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveMaxPool1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveMaxPool1dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveMaxPool1d model(AdaptiveMaxPool1dOptions(3)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveMaxPool1dImpl : - public AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl> { +class TORCH_API AdaptiveMaxPool1dImpl + : public AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl> { public: - using AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>::AdaptiveMaxPoolImpl; + using AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>:: + AdaptiveMaxPoolImpl; Tensor forward(const Tensor& input); @@ -274,30 +278,35 @@ class TORCH_API AdaptiveMaxPool1dImpl : }; /// A `ModuleHolder` subclass for `AdaptiveMaxPool1dImpl`. -/// See the documentation for `AdaptiveMaxPool1dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveMaxPool1d` with `torch::nn::AdaptiveMaxPool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveMaxPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool1d` with +/// `torch::nn::AdaptiveMaxPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveMaxPool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive maxpool over a 2-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool2d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool2d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveMaxPool2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveMaxPool2dImpl : - public AdaptiveMaxPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveMaxPool2dImpl> { +class TORCH_API AdaptiveMaxPool2dImpl : public AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl> { public: - using AdaptiveMaxPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveMaxPool2dImpl>::AdaptiveMaxPoolImpl; + using AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl>::AdaptiveMaxPoolImpl; Tensor forward(const Tensor& input); @@ -307,30 +316,35 @@ class TORCH_API AdaptiveMaxPool2dImpl : }; /// A `ModuleHolder` subclass for `AdaptiveMaxPool2dImpl`. -/// See the documentation for `AdaptiveMaxPool2dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveMaxPool2d` with `torch::nn::AdaptiveMaxPool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveMaxPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool2d` with +/// `torch::nn::AdaptiveMaxPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveMaxPool2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive maxpool over a 3-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool3d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveMaxPool3d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveMaxPool3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveMaxPool3dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveMaxPool3d model(AdaptiveMaxPool3dOptions(3)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveMaxPool3dImpl : - public AdaptiveMaxPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveMaxPool3dImpl> { +class TORCH_API AdaptiveMaxPool3dImpl : public AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl> { public: - using AdaptiveMaxPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveMaxPool3dImpl>::AdaptiveMaxPoolImpl; + using AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl>::AdaptiveMaxPoolImpl; Tensor forward(const Tensor& input); @@ -340,10 +354,10 @@ class TORCH_API AdaptiveMaxPool3dImpl : }; /// A `ModuleHolder` subclass for `AdaptiveMaxPool3dImpl`. -/// See the documentation for `AdaptiveMaxPool3dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveMaxPool3d` with `torch::nn::AdaptiveMaxPool3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveMaxPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool3d` with +/// `torch::nn::AdaptiveMaxPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveMaxPool3d); // ============================================================================ @@ -354,13 +368,16 @@ template class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable { public: AdaptiveAvgPoolImpl(output_size_t output_size) - : AdaptiveAvgPoolImpl(AdaptiveAvgPoolOptions(output_size)) {} + : AdaptiveAvgPoolImpl( + AdaptiveAvgPoolOptions(output_size)) {} explicit AdaptiveAvgPoolImpl( - const AdaptiveAvgPoolOptions& options_) : options(options_) {} + const AdaptiveAvgPoolOptions& options_) + : options(options_) {} void reset() override {} - /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given `stream`. + /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given + /// `stream`. void pretty_print(std::ostream& stream) const override { stream << "torch::nn::AdaptiveAvgPool" << D << "d" << "(output_size=" << options.output_size() << ")"; @@ -373,88 +390,99 @@ class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive avgpool over a 1-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool1d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool1d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveAvgPool1d model(AdaptiveAvgPool1dOptions(5)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveAvgPool1dImpl : - public AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl> { +class TORCH_API AdaptiveAvgPool1dImpl + : public AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl> { public: - using AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>::AdaptiveAvgPoolImpl; + using AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>:: + AdaptiveAvgPoolImpl; Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `AdaptiveAvgPool1dImpl`. -/// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveAvgPool1d` with `torch::nn::AdaptiveAvgPool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool1d` with +/// `torch::nn::AdaptiveAvgPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveAvgPool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive avgpool over a 2-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool2d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool2d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveAvgPool2dImpl : - public AdaptiveAvgPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveAvgPool2dImpl> { +class TORCH_API AdaptiveAvgPool2dImpl : public AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl> { public: - using AdaptiveAvgPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveAvgPool2dImpl>::AdaptiveAvgPoolImpl; + using AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl>::AdaptiveAvgPoolImpl; Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `AdaptiveAvgPool2dImpl`. -/// See the documentation for `AdaptiveAvgPool2dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveAvgPool2d` with `torch::nn::AdaptiveAvgPool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveAvgPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool2d` with +/// `torch::nn::AdaptiveAvgPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveAvgPool2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies adaptive avgpool over a 3-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool3d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AdaptiveAvgPool3d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// AdaptiveAvgPool3d model(AdaptiveAvgPool3dOptions(3)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API AdaptiveAvgPool3dImpl : - public AdaptiveAvgPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveAvgPool3dImpl> { +class TORCH_API AdaptiveAvgPool3dImpl : public AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl> { public: - using AdaptiveAvgPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveAvgPool3dImpl>::AdaptiveAvgPoolImpl; + using AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl>::AdaptiveAvgPoolImpl; Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `AdaptiveAvgPool3dImpl`. -/// See the documentation for `AdaptiveAvgPool3dImpl` class to learn what methods it -/// provides, and examples of how to use `AdaptiveAvgPool3d` with `torch::nn::AdaptiveAvgPool3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// See the documentation for `AdaptiveAvgPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool3d` with +/// `torch::nn::AdaptiveAvgPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveAvgPool3d); // ============================================================================ @@ -483,8 +511,8 @@ class TORCH_API MaxUnpoolImpl : public torch::nn::Cloneable { /// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxUnpool1d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MaxUnpool1dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MaxUnpool1dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -494,17 +522,20 @@ class TORCH_API MaxUnpoolImpl : public torch::nn::Cloneable { class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> { public: using MaxUnpoolImpl<1, MaxUnpool1dImpl>::MaxUnpoolImpl; - Tensor forward(const Tensor& input, const Tensor& indices, - const c10::optional>& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const Tensor& indices, + const c10::optional>& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(c10::optional>())}) }; /// A `ModuleHolder` subclass for `MaxUnpool1dImpl`. /// See the documentation for `MaxUnpool1dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxUnpool1d` with `torch::nn::MaxUnpool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxUnpool1d` with +/// `torch::nn::MaxUnpool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxUnpool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -513,8 +544,8 @@ TORCH_MODULE(MaxUnpool1d); /// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxUnpool2d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MaxUnpool2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MaxUnpool2dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -524,17 +555,20 @@ TORCH_MODULE(MaxUnpool1d); class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> { public: using MaxUnpoolImpl<2, MaxUnpool2dImpl>::MaxUnpoolImpl; - Tensor forward(const Tensor& input, const Tensor& indices, - const c10::optional>& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const Tensor& indices, + const c10::optional>& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(c10::optional>())}) }; /// A `ModuleHolder` subclass for `MaxUnpool2dImpl`. /// See the documentation for `MaxUnpool2dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxUnpool2d` with `torch::nn::MaxUnpool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxUnpool2d` with +/// `torch::nn::MaxUnpool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxUnpool2d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -543,8 +577,8 @@ TORCH_MODULE(MaxUnpool2d); /// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxUnpool3d to learn /// about the exact behavior of this module. /// -/// See the documentation for `torch::nn::MaxUnpool3dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::MaxUnpool3dOptions` class to learn +/// what constructor arguments are supported for this module. /// /// Example: /// ``` @@ -554,34 +588,39 @@ TORCH_MODULE(MaxUnpool2d); class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> { public: using MaxUnpoolImpl<3, MaxUnpool3dImpl>::MaxUnpoolImpl; - Tensor forward(const Tensor& input, const Tensor& indices, - const c10::optional>& output_size = c10::nullopt); + Tensor forward( + const Tensor& input, + const Tensor& indices, + const c10::optional>& output_size = c10::nullopt); + protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(c10::optional>())}) }; /// A `ModuleHolder` subclass for `MaxUnpool3dImpl`. /// See the documentation for `MaxUnpool3dImpl` class to learn what methods it -/// provides, and examples of how to use `MaxUnpool3d` with `torch::nn::MaxUnpool3dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `MaxUnpool3d` with +/// `torch::nn::MaxUnpool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(MaxUnpool3d); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Applies fractional maxpool over a 2-D input. -/// See https://pytorch.org/docs/master/nn.html#torch.nn.FractionalMaxPool2d to learn -/// about the exact behavior of this module. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.FractionalMaxPool2d to +/// learn about the exact behavior of this module. /// -/// See the documentation for `torch::nn::FractionalMaxPool2dOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::FractionalMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. /// /// Example: /// ``` /// FractionalMaxPool2d model(FractionalMaxPool2dOptions(5).output_size(1)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API FractionalMaxPool2dImpl : public torch::nn::Cloneable { +class TORCH_API FractionalMaxPool2dImpl + : public torch::nn::Cloneable { public: FractionalMaxPool2dImpl(ExpandingArray<2> kernel_size) : FractionalMaxPool2dImpl(FractionalMaxPool2dOptions(kernel_size)) {} @@ -605,27 +644,29 @@ class TORCH_API FractionalMaxPool2dImpl : public torch::nn::Cloneable { +class TORCH_API FractionalMaxPool3dImpl + : public torch::nn::Cloneable { public: FractionalMaxPool3dImpl(ExpandingArray<3> kernel_size) : FractionalMaxPool3dImpl(FractionalMaxPool3dOptions(kernel_size)) {} @@ -649,10 +690,10 @@ class TORCH_API FractionalMaxPool3dImpl : public torch::nn::Cloneable { using LPPoolImpl<1, LPPool1dImpl>::LPPoolImpl; Tensor forward(const Tensor& input); - }; /// A `ModuleHolder` subclass for `LPPool1dImpl`. /// See the documentation for `LPPool1dImpl` class to learn what methods it -/// provides, and examples of how to use `LPPool1d` with `torch::nn::LPPool1dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LPPool1d` with +/// `torch::nn::LPPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LPPool1d); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -714,7 +754,8 @@ TORCH_MODULE(LPPool1d); /// /// Example: /// ``` -/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, 6}).ceil_mode(true)); +/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, +/// 6}).ceil_mode(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> { @@ -722,14 +763,13 @@ class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> { using LPPoolImpl<2, LPPool2dImpl>::LPPoolImpl; Tensor forward(const Tensor& input); - }; /// A `ModuleHolder` subclass for `LPPool2dImpl`. /// See the documentation for `LPPool2dImpl` class to learn what methods it -/// provides, and examples of how to use `LPPool2d` with `torch::nn::LPPool2dOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LPPool2d` with +/// `torch::nn::LPPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LPPool2d); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index d10b150a7890f8..c3d892dc0f5fcb 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -1,9 +1,9 @@ #pragma once #include -#include #include #include +#include #include #include #include @@ -66,14 +66,17 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable { void check_input(const Tensor& input, const Tensor& batch_sizes) const; - std::tuple get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) const; + std::tuple get_expected_hidden_size( + const Tensor& input, + const Tensor& batch_sizes) const; void check_hidden_size( - const Tensor& hx, - std::tuple expected_hidden_size, - std::string msg = "Expected hidden size {1}, got {2}") const; + const Tensor& hx, + std::tuple expected_hidden_size, + std::string msg = "Expected hidden size {1}, got {2}") const; - void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const; + void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) + const; Tensor permute_hidden(Tensor hx, const Tensor& permutation) const; @@ -97,7 +100,8 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable { /// /// Example: /// ``` -/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); +/// RNN model(RNNOptions(128, +/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API RNNImpl : public detail::RNNImplBase { @@ -107,21 +111,25 @@ class TORCH_API RNNImpl : public detail::RNNImplBase { explicit RNNImpl(const RNNOptions& options_); std::tuple forward(const Tensor& input, Tensor hx = {}); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) public: - std::tuple forward_with_packed_input(const torch::nn::utils::rnn::PackedSequence& packed_input, Tensor hx = {}); + std::tuple + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + Tensor hx = {}); RNNOptions options; protected: std::tuple forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - Tensor hx); + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx); }; /// A `ModuleHolder` subclass for `RNNImpl`. @@ -142,7 +150,8 @@ TORCH_MODULE(RNN); /// /// Example: /// ``` -/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true)); +/// LSTM model(LSTMOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API LSTMImpl : public detail::RNNImplBase { @@ -152,29 +161,41 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase { explicit LSTMImpl(const LSTMOptions& options_); std::tuple> forward( - const Tensor& input, torch::optional> hx_opt = {}); + const Tensor& input, + torch::optional> hx_opt = {}); + protected: - FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional>())}) + FORWARD_HAS_DEFAULT_ARGS( + {1, AnyValue(torch::optional>())}) public: - std::tuple> forward_with_packed_input( - const torch::nn::utils::rnn::PackedSequence& packed_input, torch::optional> hx_opt = {}); + std::tuple> + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + torch::optional> hx_opt = {}); LSTMOptions options; protected: - void check_forward_args(const Tensor& input, std::tuple hidden, const Tensor& batch_sizes) const; + void check_forward_args( + const Tensor& input, + std::tuple hidden, + const Tensor& batch_sizes) const; - std::tuple get_expected_cell_size(const Tensor& input, const Tensor& batch_sizes) const; + std::tuple get_expected_cell_size( + const Tensor& input, + const Tensor& batch_sizes) const; - std::tuple permute_hidden(std::tuple hx, const Tensor& permutation) const; + std::tuple permute_hidden( + std::tuple hx, + const Tensor& permutation) const; std::tuple> forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - torch::optional> hx_opt); + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + torch::optional> hx_opt); }; /// A `ModuleHolder` subclass for `LSTMImpl`. @@ -195,7 +216,8 @@ TORCH_MODULE(LSTM); /// /// Example: /// ``` -/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true)); +/// GRU model(GRUOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API GRUImpl : public detail::RNNImplBase { @@ -205,21 +227,25 @@ class TORCH_API GRUImpl : public detail::RNNImplBase { explicit GRUImpl(const GRUOptions& options_); std::tuple forward(const Tensor& input, Tensor hx = {}); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())}) public: - std::tuple forward_with_packed_input(const torch::nn::utils::rnn::PackedSequence& packed_input, Tensor hx = {}); + std::tuple + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + Tensor hx = {}); GRUOptions options; protected: std::tuple forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - Tensor hx); + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx); }; /// A `ModuleHolder` subclass for `GRUImpl`. @@ -229,7 +255,8 @@ class TORCH_API GRUImpl : public detail::RNNImplBase { /// module storage semantics. TORCH_MODULE(GRU); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ namespace detail { /// Base class for all RNNCell implementations (intended for code sharing). @@ -256,13 +283,16 @@ class TORCH_API RNNCellImplBase : public torch::nn::Cloneable { protected: void check_forward_input(const Tensor& input) const; - void check_forward_hidden(const Tensor& input, const Tensor& hx, std::string hidden_label) const; + void check_forward_hidden( + const Tensor& input, + const Tensor& hx, + std::string hidden_label) const; virtual std::string get_nonlinearity_str() const; }; } // namespace detail - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// An Elman RNN cell with tanh or ReLU non-linearity. /// See https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell to learn @@ -273,7 +303,8 @@ class TORCH_API RNNCellImplBase : public torch::nn::Cloneable { /// /// Example: /// ``` -/// RNNCell model(RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU)); +/// RNNCell model(RNNCellOptions(20, +/// 10).bias(false).nonlinearity(torch::kReLU)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API RNNCellImpl : public detail::RNNCellImplBase { @@ -283,6 +314,7 @@ class TORCH_API RNNCellImpl : public detail::RNNCellImplBase { explicit RNNCellImpl(const RNNCellOptions& options_); Tensor forward(const Tensor& input, Tensor hx = {}); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) @@ -295,12 +327,13 @@ class TORCH_API RNNCellImpl : public detail::RNNCellImplBase { /// A `ModuleHolder` subclass for `RNNCellImpl`. /// See the documentation for `RNNCellImpl` class to learn what methods it -/// provides, and examples of how to use `RNNCell` with `torch::nn::RNNCellOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `RNNCell` with +/// `torch::nn::RNNCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(RNNCell); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// A long short-term memory (LSTM) cell. /// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTMCell to learn @@ -320,9 +353,13 @@ class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase { : LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {} explicit LSTMCellImpl(const LSTMCellOptions& options_); - std::tuple forward(const Tensor& input, torch::optional> hx_opt = {}); + std::tuple forward( + const Tensor& input, + torch::optional> hx_opt = {}); + protected: - FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional>())}) + FORWARD_HAS_DEFAULT_ARGS( + {1, AnyValue(torch::optional>())}) public: LSTMCellOptions options; @@ -330,12 +367,13 @@ class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase { /// A `ModuleHolder` subclass for `LSTMCellImpl`. /// See the documentation for `LSTMCellImpl` class to learn what methods it -/// provides, and examples of how to use `LSTMCell` with `torch::nn::LSTMCellOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `LSTMCell` with +/// `torch::nn::LSTMCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(LSTMCell); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// A gated recurrent unit (GRU) cell. /// See https://pytorch.org/docs/master/nn.html#torch.nn.GRUCell to learn @@ -356,6 +394,7 @@ class TORCH_API GRUCellImpl : public detail::RNNCellImplBase { explicit GRUCellImpl(const GRUCellOptions& options_); Tensor forward(const Tensor& input, Tensor hx = {}); + protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) @@ -365,9 +404,9 @@ class TORCH_API GRUCellImpl : public detail::RNNCellImplBase { /// A `ModuleHolder` subclass for `GRUCellImpl`. /// See the documentation for `GRUCellImpl` class to learn what methods it -/// provides, and examples of how to use `GRUCell` with `torch::nn::GRUCellOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `GRUCell` with +/// `torch::nn::GRUCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(GRUCell); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/transformer.h b/torch/csrc/api/include/torch/nn/modules/transformer.h index 212669f18c22f6..d4e6264e1b3acd 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformer.h +++ b/torch/csrc/api/include/torch/nn/modules/transformer.h @@ -2,9 +2,9 @@ #include #include +#include #include #include -#include #include @@ -15,11 +15,11 @@ namespace nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/// A transformer model. User is able to modify the attributes as needed. The architecture -/// is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, -/// Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and -/// Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information -/// Processing Systems, pages 6000-6010. +/// A transformer model. User is able to modify the attributes as needed. The +/// architecture is based on the paper "Attention Is All You Need". Ashish +/// Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N +/// Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. +/// In Advances in Neural Information Processing Systems, pages 6000-6010. /// /// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to /// learn about the exact behavior of this transformer model @@ -33,57 +33,59 @@ namespace nn { /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) class TORCH_API TransformerImpl : public Cloneable { - - public: - explicit TransformerImpl(TransformerOptions options_); - - - /// forward function for Transformer Module - /// Args: - /// src: the sequence to the encoder (required). - /// tgt: the sequence to the decoder (required). - /// src_mask: the additive mask for the src sequence (optional). - /// tgt_mask: the additive mask for the tgt sequence (optional). - /// memory_mask: the additive mask for the encoder output (optional). - /// src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). - /// tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). - /// memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). - /// - /// Shape: - /// src: `(S, N, E)` - /// tgt: `(T, N, E)` - /// src_mask: `(S, S)` - /// tgt_mask: `(T, T)` - /// memory_mask: `(T, S)` - /// src_key_padding_mask: `(N, S)` - /// tgt_key_padding_mask: `(N, T)` - /// memory_key_padding_mask: `(N, S)` - /// - /// Note: - /// [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked - /// positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - /// while the zero positions will be unchanged. If a BoolTensor is provided, positions with `True` - /// are not allowed to attend while `False` values will be unchanged. If a FloatTensor - /// is provided, it will be added to the attention weight. - /// - /// [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by - /// the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero - /// positions will be unchanged. If a BoolTensor is provided, the positions with the - /// value of `True` will be ignored while the position with the value of `False` will be unchanged. - /// - /// output: `(T, N, E)` - /// - /// Note: - /// Due to the multi-head attention architecture in the transformer model, - /// the output sequence length of a transformer is same as the input sequence - /// (i.e. target) length of the decode. - /// - /// where - /// S is the source sequence length, - /// T is the target sequence length, - /// N is the batch size, - /// E is the feature number. - Tensor forward( + public: + explicit TransformerImpl(TransformerOptions options_); + + /// forward function for Transformer Module + /// Args: + /// src: the sequence to the encoder (required). + /// tgt: the sequence to the decoder (required). + /// src_mask: the additive mask for the src sequence (optional). + /// tgt_mask: the additive mask for the tgt sequence (optional). + /// memory_mask: the additive mask for the encoder output (optional). + /// src_key_padding_mask: the ByteTensor mask for src keys per batch + /// (optional). tgt_key_padding_mask: the ByteTensor mask for tgt keys per + /// batch (optional). memory_key_padding_mask: the ByteTensor mask for + /// memory keys per batch (optional). + /// + /// Shape: + /// src: `(S, N, E)` + /// tgt: `(T, N, E)` + /// src_mask: `(S, S)` + /// tgt_mask: `(T, T)` + /// memory_mask: `(T, S)` + /// src_key_padding_mask: `(N, S)` + /// tgt_key_padding_mask: `(N, T)` + /// memory_key_padding_mask: `(N, S)` + /// + /// Note: + /// [src/tgt/memory]_mask ensures that position i is allowed to attend the + /// unmasked positions. If a ByteTensor is provided, the non-zero + /// positions are not allowed to attend while the zero positions will be + /// unchanged. If a BoolTensor is provided, positions with `True` are not + /// allowed to attend while `False` values will be unchanged. If a + /// FloatTensor is provided, it will be added to the attention weight. + /// + /// [src/tgt/memory]_key_padding_mask provides specified elements in the + /// key to be ignored by the attention. If a ByteTensor is provided, the + /// non-zero positions will be ignored while the zero positions will be + /// unchanged. If a BoolTensor is provided, the positions with the value + /// of `True` will be ignored while the position with the value of `False` + /// will be unchanged. + /// + /// output: `(T, N, E)` + /// + /// Note: + /// Due to the multi-head attention architecture in the transformer model, + /// the output sequence length of a transformer is same as the input + /// sequence (i.e. target) length of the decode. + /// + /// where + /// S is the source sequence length, + /// T is the target sequence length, + /// N is the batch size, + /// E is the feature number. + Tensor forward( const Tensor& src, const Tensor& tgt, const Tensor& src_mask = {}, @@ -93,23 +95,25 @@ class TORCH_API TransformerImpl : public Cloneable { const Tensor& tgt_key_padding_mask = {}, const Tensor& memory_key_padding_mask = {}); - void reset() override; - - void reset_parameters(); - - /// Generate a square mask for the sequence. - /// The masked positions are filled with `-inf` in float type. - /// Unmasked positions are filled with `0.0` in float type. - /// Note: - /// 1. This function will always return a CPU tensor. - /// 2. This function requires the platform support IEEE754, since `-inf` is guaranteed to - /// be valid only when IEEE754 is supported. If the platform doesn't support IEEE754, - /// this function will fill the mask with the smallest float number instead of `-inf`, - /// a one time warning will pop up as well. - static Tensor generate_square_subsequent_mask(int64_t sz); - - protected: - FORWARD_HAS_DEFAULT_ARGS( + void reset() override; + + void reset_parameters(); + + /// Generate a square mask for the sequence. + /// The masked positions are filled with `-inf` in float type. + /// Unmasked positions are filled with `0.0` in float type. + /// Note: + /// 1. This function will always return a CPU tensor. + /// 2. This function requires the platform support IEEE754, since `-inf` is + /// guaranteed to + /// be valid only when IEEE754 is supported. If the platform doesn't + /// support IEEE754, this function will fill the mask with the smallest + /// float number instead of `-inf`, a one time warning will pop up as + /// well. + static Tensor generate_square_subsequent_mask(int64_t sz); + + protected: + FORWARD_HAS_DEFAULT_ARGS( {2, AnyValue(Tensor())}, {3, AnyValue(Tensor())}, {4, AnyValue(Tensor())}, @@ -117,15 +121,15 @@ class TORCH_API TransformerImpl : public Cloneable { {6, AnyValue(Tensor())}, {7, AnyValue(Tensor())}) - public: - /// options with which this `Transformer` was constructed - TransformerOptions options; + public: + /// options with which this `Transformer` was constructed + TransformerOptions options; - /// encoder module - AnyModule encoder; + /// encoder module + AnyModule encoder; - /// decoder module - AnyModule decoder; + /// decoder module + AnyModule decoder; }; /// A `ModuleHolder` subclass for `TransformerImpl`. diff --git a/torch/csrc/api/include/torch/nn/modules/transformercoder.h b/torch/csrc/api/include/torch/nn/modules/transformercoder.h index 82ebcd7984655d..38d432e86a03bc 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/modules/transformercoder.h @@ -2,11 +2,11 @@ #include #include +#include #include #include #include #include -#include #include @@ -15,51 +15,56 @@ namespace torch { namespace nn { -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// TransformerEncoder module. -/// See https://pytorch.org/docs/master/generated/torch.nn.TransformerEncoder.html to -/// learn abouut the exact behavior of this encoder layer module. +/// See +/// https://pytorch.org/docs/master/generated/torch.nn.TransformerEncoder.html +/// to learn abouut the exact behavior of this encoder layer module. /// -/// See the documentation for `torch::nn::TransformerEncoder` class to learn what -/// constructor arguments are supported for this encoder module. +/// See the documentation for `torch::nn::TransformerEncoder` class to learn +/// what constructor arguments are supported for this encoder module. /// /// Example: /// ``` -/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 8).dropout(0.1)); -/// TransformerEncoder encoder(TransformerEncoderOptions(encoderLayer, 6).norm(LayerNorm(LayerNormOptions({2})))); +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); TransformerEncoder +/// encoder(TransformerEncoderOptions(encoderLayer, +/// 6).norm(LayerNorm(LayerNormOptions({2})))); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API TransformerEncoderImpl : public Cloneable { - - public: - TransformerEncoderImpl(TransformerEncoderLayer encoder_layer, int64_t num_layers) - : TransformerEncoderImpl(TransformerEncoderOptions(encoder_layer, num_layers)) {} - explicit TransformerEncoderImpl(TransformerEncoderOptions options_); - - Tensor forward( +class TORCH_API TransformerEncoderImpl + : public Cloneable { + public: + TransformerEncoderImpl( + TransformerEncoderLayer encoder_layer, + int64_t num_layers) + : TransformerEncoderImpl( + TransformerEncoderOptions(encoder_layer, num_layers)) {} + explicit TransformerEncoderImpl(TransformerEncoderOptions options_); + + Tensor forward( const Tensor& src, const Tensor& src_mask = {}, const Tensor& src_key_padding_mask = {}); - void reset() override; + void reset() override; - void reset_parameters(); + void reset_parameters(); - protected: - FORWARD_HAS_DEFAULT_ARGS( - {1, AnyValue(Tensor())}, - {2, AnyValue(Tensor())}) + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) - public: - /// options with which this `TransformerEncoder` was constructed - TransformerEncoderOptions options; + public: + /// options with which this `TransformerEncoder` was constructed + TransformerEncoderOptions options; - /// module list that contains all the encoder layers - ModuleList layers = nullptr; + /// module list that contains all the encoder layers + ModuleList layers = nullptr; - /// optional normalization module - AnyModule norm; + /// optional normalization module + AnyModule norm; }; /// A `ModuleHolder` subclass for `TransformerEncoderImpl`. @@ -70,44 +75,51 @@ class TORCH_API TransformerEncoderImpl : public Cloneable { - - public: - TransformerDecoderImpl(TransformerDecoderLayer decoder_layer, int64_t num_layers) - : TransformerDecoderImpl(TransformerDecoderOptions(decoder_layer, num_layers)) {} - explicit TransformerDecoderImpl(TransformerDecoderOptions options_); - - void reset() override; - - void reset_parameters(); - - /// Pass the inputs (and mask) through the decoder layer in turn. - /// Args: - /// tgt: the sequence to the decoder layer (required). - /// memory: the sequence from the last layer of the encoder (required). - /// tgt_mask: the mask for the tgt sequence (optional). - /// memory_mask: the mask for the memory sequence (optional). - /// tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - /// memory_key_padding_mask: the mask for the memory keys per batch (optional). - Tensor forward( +class TORCH_API TransformerDecoderImpl + : public Cloneable { + public: + TransformerDecoderImpl( + TransformerDecoderLayer decoder_layer, + int64_t num_layers) + : TransformerDecoderImpl( + TransformerDecoderOptions(decoder_layer, num_layers)) {} + explicit TransformerDecoderImpl(TransformerDecoderOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pass the inputs (and mask) through the decoder layer in turn. + /// Args: + /// tgt: the sequence to the decoder layer (required). + /// memory: the sequence from the last layer of the encoder (required). + /// tgt_mask: the mask for the tgt sequence (optional). + /// memory_mask: the mask for the memory sequence (optional). + /// tgt_key_padding_mask: the mask for the tgt keys per batch + /// (optional). memory_key_padding_mask: the mask for the memory keys + /// per batch (optional). + Tensor forward( const Tensor& tgt, const Tensor& memory, const Tensor& tgt_mask = {}, @@ -115,17 +127,17 @@ class TORCH_API TransformerDecoderImpl : public Cloneable #include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include #include @@ -17,154 +17,164 @@ namespace torch { namespace nn { -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// TransformerEncoderLayer module. -/// See https://pytorch.org/docs/master/generated/torch.nn.TransformerEncoderLayer.html to -/// learn abouut the exact behavior of this encoder layer model +/// See +/// https://pytorch.org/docs/master/generated/torch.nn.TransformerEncoderLayer.html +/// to learn abouut the exact behavior of this encoder layer model /// -/// See the documentation for `torch::nn::TransformerEncoderLayer` class to learn what -/// constructor arguments are supported for this encoder layer model +/// See the documentation for `torch::nn::TransformerEncoderLayer` class to +/// learn what constructor arguments are supported for this encoder layer model /// /// Example: /// ``` -/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 8).dropout(0.1)); +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API TransformerEncoderLayerImpl : public Cloneable { - - public: - TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead) - : TransformerEncoderLayerImpl(TransformerEncoderLayerOptions(d_model, nhead)) {} - explicit TransformerEncoderLayerImpl(const TransformerEncoderLayerOptions& options_); +class TORCH_API TransformerEncoderLayerImpl + : public Cloneable { + public: + TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead) + : TransformerEncoderLayerImpl( + TransformerEncoderLayerOptions(d_model, nhead)) {} + explicit TransformerEncoderLayerImpl( + const TransformerEncoderLayerOptions& options_); - Tensor forward( + Tensor forward( const Tensor& src, const Tensor& src_mask = {}, const Tensor& src_key_padding_mask = {}); - void reset() override; + void reset() override; - void reset_parameters(); + void reset_parameters(); - protected: - FORWARD_HAS_DEFAULT_ARGS( - {1, AnyValue(Tensor())}, - {2, AnyValue(Tensor())}) + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) - public: - /// options with which this `TransformerEncoderLayer` was constructed - TransformerEncoderLayerOptions options; + public: + /// options with which this `TransformerEncoderLayer` was constructed + TransformerEncoderLayerOptions options; - /// self attention - MultiheadAttention self_attn = nullptr; + /// self attention + MultiheadAttention self_attn = nullptr; - /// feedforward first linear layer - Linear linear1 = nullptr; + /// feedforward first linear layer + Linear linear1 = nullptr; - /// feedforward dropout layer - Dropout dropout = nullptr; + /// feedforward dropout layer + Dropout dropout = nullptr; - /// feedforward second linear layer - Linear linear2 = nullptr; + /// feedforward second linear layer + Linear linear2 = nullptr; - /// pre feedforward, normalization layer - LayerNorm norm1 = nullptr; - /// post feedfastward, normalization layer - LayerNorm norm2 = nullptr; + /// pre feedforward, normalization layer + LayerNorm norm1 = nullptr; + /// post feedfastward, normalization layer + LayerNorm norm2 = nullptr; - /// pre feedfastward, dropout layer - Dropout dropout1 = nullptr; - /// post feedfastward, dropout layer - Dropout dropout2 = nullptr; + /// pre feedfastward, dropout layer + Dropout dropout1 = nullptr; + /// post feedfastward, dropout layer + Dropout dropout2 = nullptr; }; /// A `ModuleHolder` subclass for `TransformerEncoderLayerImpl``. /// See the documentation for `TransformerEncoderLayerImpl` class to learn what -/// methods it provides, and examples of how to use `TransformerEncoderLayer` with -/// `torch::nn::TransformerEncoderLayerOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// methods it provides, and examples of how to use `TransformerEncoderLayer` +/// with `torch::nn::TransformerEncoderLayerOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(TransformerEncoderLayer); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -/// TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. -/// This standard decoder layer is based on the paper "Attention Is All You Need". -/// Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, -/// Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in -/// Neural Information Processing Systems, pages 6000-6010. Users may modify or implement -/// in a different way during application. -/// See https://pytorch.org/docs/master/nn.html#transformer-layers to learn -/// about the exact behavior of this module. +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// TransformerDecoderLayer is made up of self-attn, multi-head-attn and +/// feedforward network. This standard decoder layer is based on the paper +/// "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, +/// Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia +/// Polosukhin. 2017. Attention is all you need. In Advances in Neural +/// Information Processing Systems, pages 6000-6010. Users may modify or +/// implement in a different way during application. See +/// https://pytorch.org/docs/master/nn.html#transformer-layers to learn about +/// the exact behavior of this module. /// -/// See the documentation for `torch::nn::TransformerDecoderLayerOptions` class to learn what -/// constructor arguments are supported for this module. +/// See the documentation for `torch::nn::TransformerDecoderLayerOptions` class +/// to learn what constructor arguments are supported for this module. /// /// Example: /// ``` -/// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512, 8).dropout(0.2)); +/// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.2)); /// ``` // NOLINTNEXTLINE(bugprone-exception-escape) -class TORCH_API TransformerDecoderLayerImpl : public Cloneable { +class TORCH_API TransformerDecoderLayerImpl + : public Cloneable { public: TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead) - : TransformerDecoderLayerImpl(TransformerDecoderLayerOptions(d_model, nhead)) {} - explicit TransformerDecoderLayerImpl(const TransformerDecoderLayerOptions& options_); + : TransformerDecoderLayerImpl( + TransformerDecoderLayerOptions(d_model, nhead)) {} + explicit TransformerDecoderLayerImpl( + const TransformerDecoderLayerOptions& options_); void reset() override; void reset_parameters(); /// Pass the inputs (and mask) through the decoder layer. - ///Args: + /// Args: /// tgt: the sequence to the decoder layer (required). /// memory: the sequence from the last layer of the encoder (required). /// tgt_mask: the mask for the tgt sequence (optional). /// memory_mask: the mask for the memory sequence (optional). - /// tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - /// memory_key_padding_mask: the mask for the memory keys per batch (optional). - Tensor forward(Tensor tgt, - const Tensor& memory, - const Tensor& tgt_mask = {}, - const Tensor& memory_mask = {}, - const Tensor& tgt_key_padding_mask = {}, - const Tensor& memory_key_padding_mask = {}); + /// tgt_key_padding_mask: the mask for the tgt keys per batch + /// (optional). memory_key_padding_mask: the mask for the memory keys + /// per batch (optional). + Tensor forward( + Tensor tgt, + const Tensor& memory, + const Tensor& tgt_mask = {}, + const Tensor& memory_mask = {}, + const Tensor& tgt_key_padding_mask = {}, + const Tensor& memory_key_padding_mask = {}); /// The options used to configure this module. TransformerDecoderLayerOptions options; - ///self attention + /// self attention MultiheadAttention self_attn{nullptr}; - ///Dropout, post self attention + /// Dropout, post self attention Dropout dropout1{nullptr}; - ///Normalization, post self attention + /// Normalization, post self attention LayerNorm norm1{nullptr}; - ///Multi-headed attention + /// Multi-headed attention MultiheadAttention multihead_attn{nullptr}; - ///Dropout, post multi-headed attention + /// Dropout, post multi-headed attention Dropout dropout2{nullptr}; - ///Normalization, post multi-headed attention + /// Normalization, post multi-headed attention LayerNorm norm2{nullptr}; - ///Feed forward first linear layer + /// Feed forward first linear layer Linear linear1{nullptr}; - ///Feed forward dropout layer + /// Feed forward dropout layer Dropout dropout{nullptr}; - ///Feed forward second linear layer + /// Feed forward second linear layer Linear linear2{nullptr}; - ///Dropout, post feed forward + /// Dropout, post feed forward Dropout dropout3{nullptr}; - ///Normalization, post feed forward + /// Normalization, post feed forward LayerNorm norm3{nullptr}; protected: @@ -174,16 +184,15 @@ class TORCH_API TransformerDecoderLayerImpl : public Cloneable { @@ -46,9 +47,9 @@ class TORCH_API UpsampleImpl : public Cloneable { /// A `ModuleHolder` subclass for `UpsampleImpl`. /// See the documentation for `UpsampleImpl` class to learn what methods it -/// provides, and examples of how to use `Upsample` with `torch::nn::UpsampleOptions`. -/// See the documentation for `ModuleHolder` to learn about PyTorch's -/// module storage semantics. +/// provides, and examples of how to use `Upsample` with +/// `torch::nn::UpsampleOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. TORCH_MODULE(Upsample); } // namespace nn diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h index 87f307c4a787b2..6d3d383465f331 100644 --- a/torch/csrc/api/include/torch/nn/modules/utils.h +++ b/torch/csrc/api/include/torch/nn/modules/utils.h @@ -16,7 +16,9 @@ namespace utils { // to the ones used by `F::pad`. // // This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`. -inline std::vector _reverse_repeat_vector(at::ArrayRef t, int64_t n) { +inline std::vector _reverse_repeat_vector( + at::ArrayRef t, + int64_t n) { TORCH_INTERNAL_ASSERT(n >= 0); std::vector ret; ret.reserve(t.size() * n); @@ -30,12 +32,15 @@ inline std::vector _reverse_repeat_vector(at::ArrayRef t, int6 } inline std::vector _list_with_default( - torch::ArrayRef> out_size, torch::IntArrayRef defaults) { + torch::ArrayRef> out_size, + torch::IntArrayRef defaults) { TORCH_CHECK( - defaults.size() > out_size.size(), - "Input dimension should be at least ", out_size.size() + 1); + defaults.size() > out_size.size(), + "Input dimension should be at least ", + out_size.size() + 1); std::vector ret; - torch::IntArrayRef defaults_slice = defaults.slice(defaults.size() - out_size.size(), out_size.size()); + torch::IntArrayRef defaults_slice = + defaults.slice(defaults.size() - out_size.size(), out_size.size()); for (const auto i : c10::irange(out_size.size())) { auto v = out_size.at(i); auto d = defaults_slice.at(i); diff --git a/torch/csrc/api/include/torch/nn/options.h b/torch/csrc/api/include/torch/nn/options.h index 9ce61e696d5dc8..4a5224a478e1b6 100644 --- a/torch/csrc/api/include/torch/nn/options.h +++ b/torch/csrc/api/include/torch/nn/options.h @@ -8,11 +8,11 @@ #include #include #include +#include #include -#include #include -#include -#include -#include -#include #include +#include +#include +#include +#include diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index 16ab0245fbb66b..e51805d3648521 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include namespace torch { @@ -156,7 +156,8 @@ using HardshrinkFuncOptions = HardshrinkOptions; /// /// Example: /// ``` -/// Hardtanh model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); +/// Hardtanh +/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); /// ``` struct TORCH_API HardtanhOptions { /// minimum value of the linear region range. Default: -1 @@ -178,7 +179,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::hardtanh(x, F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); +/// F::hardtanh(x, +/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); /// ``` using HardtanhFuncOptions = HardtanhOptions; } // namespace functional @@ -208,10 +210,11 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::leaky_relu(x, F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); +/// F::leaky_relu(x, +/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); /// ``` using LeakyReLUFuncOptions = LeakyReLUOptions; -} +} // namespace functional // ============================================================================ @@ -247,7 +250,8 @@ struct TORCH_API SoftmaxFuncOptions { /// the desired data type of returned tensor. /// If specified, the input tensor is casted to `dtype` before the operation - /// is performed. This is useful for preventing data type overflows. Default: None. + /// is performed. This is useful for preventing data type overflows. Default: + /// None. TORCH_ARG(c10::optional, dtype) = c10::nullopt; }; @@ -287,7 +291,8 @@ struct TORCH_API SoftminFuncOptions { /// the desired data type of returned tensor. /// If specified, the input tensor is casted to `dtype` before the operation - /// is performed. This is useful for preventing data type overflows. Default: None. + /// is performed. This is useful for preventing data type overflows. Default: + /// None. TORCH_ARG(c10::optional, dtype) = c10::nullopt; }; @@ -327,7 +332,8 @@ struct TORCH_API LogSoftmaxFuncOptions { /// the desired data type of returned tensor. /// If specified, the input tensor is casted to `dtype` before the operation - /// is performed. This is useful for preventing data type overflows. Default: None. + /// is performed. This is useful for preventing data type overflows. Default: + /// None. TORCH_ARG(c10::optional, dtype) = c10::nullopt; }; @@ -343,7 +349,8 @@ struct TORCH_API LogSoftmaxFuncOptions { /// ``` struct TORCH_API PReLUOptions { /// number of `a` to learn. Although it takes an int as input, there is only - /// two values are legitimate: 1, or the number of channels at input. Default: 1 + /// two values are legitimate: 1, or the number of channels at input. Default: + /// 1 TORCH_ARG(int64_t, num_parameters) = 1; /// the initial value of `a`. Default: 0.25 @@ -552,7 +559,7 @@ using SoftshrinkFuncOptions = SoftshrinkOptions; /// ``` struct TORCH_API ThresholdOptions { ThresholdOptions(double threshold, double value) - : threshold_(threshold), value_(value) {} + : threshold_(threshold), value_(value) {} /// The value to threshold at TORCH_ARG(double, threshold); @@ -594,7 +601,8 @@ struct TORCH_API GumbelSoftmaxFuncOptions { TORCH_ARG(double, tau) = 1.0; /// returned samples will be discretized as one-hot vectors, - /// but will be differentiated as if it is the soft sample in autograd. Default: False + /// but will be differentiated as if it is the soft sample in autograd. + /// Default: False TORCH_ARG(bool, hard) = false; /// dimension along which softmax will be computed. Default: -1 @@ -645,14 +653,17 @@ namespace functional { /// Options for `torch::nn::functional::multi_head_attention_forward` struct TORCH_API MultiheadAttentionForwardFuncOptions { - MultiheadAttentionForwardFuncOptions( - int64_t embed_dim_to_check, int64_t num_heads, - Tensor in_proj_weight, Tensor in_proj_bias, - Tensor bias_k, Tensor bias_v, - bool add_zero_attn, double dropout_p, - Tensor out_proj_weight, Tensor out_proj_bias - ); + int64_t embed_dim_to_check, + int64_t num_heads, + Tensor in_proj_weight, + Tensor in_proj_bias, + Tensor bias_k, + Tensor bias_v, + bool add_zero_attn, + double dropout_p, + Tensor out_proj_weight, + Tensor out_proj_bias); TORCH_ARG(int64_t, embed_dim_to_check); diff --git a/torch/csrc/api/include/torch/nn/options/adaptive.h b/torch/csrc/api/include/torch/nn/options/adaptive.h index 3be918184af323..d4754747a1d296 100644 --- a/torch/csrc/api/include/torch/nn/options/adaptive.h +++ b/torch/csrc/api/include/torch/nn/options/adaptive.h @@ -11,10 +11,14 @@ namespace nn { /// /// Example: /// ``` -/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.).head_bias(true)); +/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, +/// {4, 8}).div_value(2.).head_bias(true)); /// ``` struct TORCH_API AdaptiveLogSoftmaxWithLossOptions { - /* implicit */ AdaptiveLogSoftmaxWithLossOptions(int64_t in_features, int64_t n_classes, std::vector cutoffs); + /* implicit */ AdaptiveLogSoftmaxWithLossOptions( + int64_t in_features, + int64_t n_classes, + std::vector cutoffs); /// Number of features in the input tensor TORCH_ARG(int64_t, in_features); diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h index 8f8412b14133f6..cd2d7f164203e8 100644 --- a/torch/csrc/api/include/torch/nn/options/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h @@ -38,7 +38,8 @@ struct TORCH_API BatchNormOptions { /// /// Example: /// ``` -/// BatchNorm1d model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm1d +/// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using BatchNorm1dOptions = BatchNormOptions; @@ -46,7 +47,8 @@ using BatchNorm1dOptions = BatchNormOptions; /// /// Example: /// ``` -/// BatchNorm2d model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm2d +/// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using BatchNorm2dOptions = BatchNormOptions; @@ -54,7 +56,8 @@ using BatchNorm2dOptions = BatchNormOptions; /// /// Example: /// ``` -/// BatchNorm3d model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// BatchNorm3d +/// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using BatchNorm3dOptions = BatchNormOptions; @@ -67,7 +70,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::batch_norm(input, mean, variance, F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); +/// F::batch_norm(input, mean, variance, +/// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); /// ``` struct TORCH_API BatchNormFuncOptions { TORCH_ARG(Tensor, weight) = Tensor(); diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h index adffd4e051019d..a2448d313b3637 100644 --- a/torch/csrc/api/include/torch/nn/options/conv.h +++ b/torch/csrc/api/include/torch/nn/options/conv.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include @@ -12,17 +12,15 @@ namespace nn { namespace detail { typedef c10::variant< - enumtype::kZeros, - enumtype::kReflect, - enumtype::kReplicate, - enumtype::kCircular -> conv_padding_mode_t; + enumtype::kZeros, + enumtype::kReflect, + enumtype::kReplicate, + enumtype::kCircular> + conv_padding_mode_t; template -using conv_padding_t = c10::variant< - ExpandingArray, - enumtype::kValid, - enumtype::kSame>; +using conv_padding_t = + c10::variant, enumtype::kValid, enumtype::kSame>; /// Options for a `D`-dimensional convolution or convolution transpose module. template @@ -31,10 +29,10 @@ struct ConvNdOptions { ConvNdOptions( int64_t in_channels, int64_t out_channels, - ExpandingArray kernel_size) : - in_channels_(in_channels), - out_channels_(out_channels), - kernel_size_(std::move(kernel_size)) {} + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} /// The number of channels the input volumes will have. /// Changing this parameter after construction __has no effect__. @@ -62,7 +60,7 @@ struct ConvNdOptions { /// This parameter __can__ be changed after construction. TORCH_ARG(padding_t, padding) = 0; -public: + public: decltype(auto) padding(std::initializer_list il) { return padding(IntArrayRef{il}); } @@ -92,7 +90,8 @@ struct ConvNdOptions { /// Changing this parameter after construction __has no effect__. TORCH_ARG(bool, bias) = true; - /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros` + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros; }; @@ -109,10 +108,10 @@ struct ConvOptions { ConvOptions( int64_t in_channels, int64_t out_channels, - ExpandingArray kernel_size) : - in_channels_(in_channels), - out_channels_(out_channels), - kernel_size_(std::move(kernel_size)) {} + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} /// The number of channels the input volumes will have. /// Changing this parameter after construction __has no effect__. @@ -140,7 +139,7 @@ struct ConvOptions { /// This parameter __can__ be changed after construction. TORCH_ARG(padding_t, padding) = 0; -public: + public: decltype(auto) padding(std::initializer_list il) { return padding(IntArrayRef{il}); } @@ -159,7 +158,8 @@ struct ConvOptions { /// Changing this parameter after construction __has no effect__. TORCH_ARG(bool, bias) = true; - /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros` + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; }; @@ -209,7 +209,7 @@ struct ConvFuncOptions { /// numbers. TORCH_ARG(padding_t, padding) = 0; -public: + public: decltype(auto) padding(std::initializer_list il) { return padding(IntArrayRef{il}); } @@ -262,10 +262,10 @@ struct ConvTransposeOptions { ConvTransposeOptions( int64_t in_channels, int64_t out_channels, - ExpandingArray kernel_size) : - in_channels_(in_channels), - out_channels_(out_channels), - kernel_size_(std::move(kernel_size)) {} + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} /// The number of channels the input volumes will have. /// Changing this parameter after construction __has no effect__. @@ -313,7 +313,8 @@ struct ConvTransposeOptions { /// This parameter __can__ be changed after construction. TORCH_ARG(ExpandingArray, dilation) = 1; - /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros` + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; }; @@ -321,7 +322,8 @@ struct ConvTransposeOptions { /// /// Example: /// ``` -/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)); +/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, +/// 3).stride(1).bias(false)); /// ``` using ConvTranspose1dOptions = ConvTransposeOptions<1>; @@ -329,7 +331,8 @@ using ConvTranspose1dOptions = ConvTransposeOptions<1>; /// /// Example: /// ``` -/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)); +/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, +/// 3).stride(1).bias(false)); /// ``` using ConvTranspose2dOptions = ConvTransposeOptions<2>; @@ -337,7 +340,8 @@ using ConvTranspose2dOptions = ConvTransposeOptions<2>; /// /// Example: /// ``` -/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)); +/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, +/// 2).stride(1).bias(false)); /// ``` using ConvTranspose3dOptions = ConvTransposeOptions<3>; @@ -361,7 +365,8 @@ struct ConvTransposeFuncOptions { /// numbers. TORCH_ARG(ExpandingArray, padding) = 0; - /// Additional size added to one side of each dimension in the output shape. Default: 0 + /// Additional size added to one side of each dimension in the output shape. + /// Default: 0 TORCH_ARG(ExpandingArray, output_padding) = 0; /// Split input into groups, `in_channels` should be divisible by @@ -374,7 +379,8 @@ struct ConvTransposeFuncOptions { TORCH_ARG(ExpandingArray, dilation) = 1; }; -/// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose1d`. +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose1d`. /// /// Example: /// ``` @@ -383,7 +389,8 @@ struct ConvTransposeFuncOptions { /// ``` using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>; -/// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose2d`. +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose2d`. /// /// Example: /// ``` @@ -392,7 +399,8 @@ using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>; /// ``` using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>; -/// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose3d`. +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose3d`. /// /// Example: /// ``` diff --git a/torch/csrc/api/include/torch/nn/options/distance.h b/torch/csrc/api/include/torch/nn/options/distance.h index 302533d3cdf29a..654cd6626498db 100644 --- a/torch/csrc/api/include/torch/nn/options/distance.h +++ b/torch/csrc/api/include/torch/nn/options/distance.h @@ -23,13 +23,14 @@ struct TORCH_API CosineSimilarityOptions { namespace functional { /// Options for `torch::nn::functional::cosine_similarity`. /// -/// See the documentation for `torch::nn::CosineSimilarityOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::CosineSimilarityOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cosine_similarity(input1, input2, F::CosineSimilarityFuncOptions().dim(1)); +/// F::cosine_similarity(input1, input2, +/// F::CosineSimilarityFuncOptions().dim(1)); /// ``` using CosineSimilarityFuncOptions = CosineSimilarityOptions; } // namespace functional @@ -40,7 +41,8 @@ using CosineSimilarityFuncOptions = CosineSimilarityOptions; /// /// Example: /// ``` -/// PairwiseDistance model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); +/// PairwiseDistance +/// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); /// ``` struct TORCH_API PairwiseDistanceOptions { /// The norm degree. Default: 2 @@ -54,8 +56,8 @@ struct TORCH_API PairwiseDistanceOptions { namespace functional { /// Options for `torch::nn::functional::pairwise_distance`. /// -/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` diff --git a/torch/csrc/api/include/torch/nn/options/dropout.h b/torch/csrc/api/include/torch/nn/options/dropout.h index 57f4bc9b5e73b8..7f41f5672382c9 100644 --- a/torch/csrc/api/include/torch/nn/options/dropout.h +++ b/torch/csrc/api/include/torch/nn/options/dropout.h @@ -97,7 +97,8 @@ using Dropout3dFuncOptions = DropoutFuncOptions; /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::alpha_dropout(input, F::AlphaDropoutFuncOptions().p(0.5).training(false)); +/// F::alpha_dropout(input, +/// F::AlphaDropoutFuncOptions().p(0.5).training(false)); /// ``` struct TORCH_API AlphaDropoutFuncOptions { TORCH_ARG(double, p) = 0.5; @@ -112,7 +113,8 @@ struct TORCH_API AlphaDropoutFuncOptions { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::feature_alpha_dropout(input, F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); +/// F::feature_alpha_dropout(input, +/// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); /// ``` struct TORCH_API FeatureAlphaDropoutFuncOptions { TORCH_ARG(double, p) = 0.5; diff --git a/torch/csrc/api/include/torch/nn/options/embedding.h b/torch/csrc/api/include/torch/nn/options/embedding.h index c304f04ead56ef..23c96086d5a8a9 100644 --- a/torch/csrc/api/include/torch/nn/options/embedding.h +++ b/torch/csrc/api/include/torch/nn/options/embedding.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include namespace torch { namespace nn { @@ -12,7 +12,8 @@ namespace nn { /// /// Example: /// ``` -/// Embedding model(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// Embedding model(EmbeddingOptions(10, +/// 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); /// ``` struct TORCH_API EmbeddingOptions { EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim); @@ -25,17 +26,21 @@ struct TORCH_API EmbeddingOptions { /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". For a newly constructed /// Embedding, the embedding vector at `padding_idx` will default to all - /// zeros, but can be updated to another value to be used as the padding vector. + /// zeros, but can be updated to another value to be used as the padding + /// vector. TORCH_ARG(c10::optional, padding_idx) = c10::nullopt; - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. TORCH_ARG(bool, scale_grad_by_freq) = false; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. TORCH_ARG(bool, sparse) = false; - /// The learnable weights of the module of shape (num_embeddings, embedding_dim) + /// The learnable weights of the module of shape (num_embeddings, + /// embedding_dim) TORCH_ARG(torch::Tensor, _weight) = Tensor(); }; @@ -44,17 +49,20 @@ struct TORCH_API EmbeddingOptions { /// Options for the `Embedding::from_pretrained` function. struct TORCH_API EmbeddingFromPretrainedOptions { /// If ``true``, the tensor does not get updated in the learning process. - /// Equivalent to ``embedding.weight.requires_grad_(false)``. Default: ``true`` + /// Equivalent to ``embedding.weight.requires_grad_(false)``. Default: + /// ``true`` TORCH_ARG(bool, freeze) = true; /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". TORCH_ARG(c10::optional, padding_idx) = c10::nullopt; - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. TORCH_ARG(bool, scale_grad_by_freq) = false; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. TORCH_ARG(bool, sparse) = false; @@ -69,18 +77,21 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::embedding(input, weight, F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// F::embedding(input, weight, +/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); /// ``` struct TORCH_API EmbeddingFuncOptions { /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". TORCH_ARG(c10::optional, padding_idx) = c10::nullopt; - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. TORCH_ARG(bool, scale_grad_by_freq) = false; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. TORCH_ARG(bool, sparse) = false; @@ -90,13 +101,15 @@ struct TORCH_API EmbeddingFuncOptions { // ============================================================================ -typedef c10::variant EmbeddingBagMode; +typedef c10::variant + EmbeddingBagMode; /// Options for the `EmbeddingBag` module. /// /// Example: /// ``` -/// EmbeddingBag model(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum)); +/// EmbeddingBag model(EmbeddingBagOptions(10, +/// 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum)); /// ``` struct TORCH_API EmbeddingBagOptions { EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim); @@ -105,20 +118,25 @@ struct TORCH_API EmbeddingBagOptions { TORCH_ARG(int64_t, num_embeddings); /// The size of each embedding vector. TORCH_ARG(int64_t, embedding_dim); - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. - /// Note: this option is not supported when ``mode="kMax"``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. TORCH_ARG(bool, scale_grad_by_freq) = false; - /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` - /// into consideration. ``"kMean"`` computes the average of the values in the bag, ``"kMax"`` computes the max value over each bag. + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. /// Note: this option is not supported when ``mode="kMax"``. TORCH_ARG(bool, sparse) = false; - /// The learnable weights of the module of shape (num_embeddings, embedding_dim) + /// The learnable weights of the module of shape (num_embeddings, + /// embedding_dim) TORCH_ARG(torch::Tensor, _weight) = Tensor(); /// If ``true``, `offsets` has one additional element, where the last element /// is equivalent to the size of `indices`. This matches the CSR format. @@ -127,8 +145,9 @@ struct TORCH_API EmbeddingBagOptions { /// gradient; therefore, the embedding vector at padding_idx is not updated /// during training, i.e. it remains as a fixed "pad". For a newly constructed /// EmbeddingBag, the embedding vector at `padding_idx` will default to all - /// zeros, but can be updated to another value to be used as the padding vector. - /// Note that the embedding vector at `padding_idx` is excluded from the reduction. + /// zeros, but can be updated to another value to be used as the padding + /// vector. Note that the embedding vector at `padding_idx` is excluded from + /// the reduction. TORCH_ARG(c10::optional, padding_idx) = c10::nullopt; }; @@ -137,17 +156,22 @@ struct TORCH_API EmbeddingBagOptions { /// Options for the `EmbeddingBag::from_pretrained` function. struct TORCH_API EmbeddingBagFromPretrainedOptions { /// If ``true``, the tensor does not get updated in the learning process. - /// Equivalent to ``embeddingbag.weight.requires_grad_(false)``. Default: ``true`` + /// Equivalent to ``embeddingbag.weight.requires_grad_(false)``. Default: + /// ``true`` TORCH_ARG(bool, freeze) = true; - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. - /// Note: this option is not supported when ``mode="kMax"``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. TORCH_ARG(bool, scale_grad_by_freq) = false; - /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` - /// into consideration. ``"kMean"`` computes the average of the values in the bag, ``"kMax"`` computes the max value over each bag. + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. /// Note: this option is not supported when ``mode="kMax"``. @@ -172,28 +196,34 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); +/// F::embedding_bag(input, weight, +/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); /// ``` struct TORCH_API EmbeddingBagFuncOptions { /// Only used when `input` is 1D. `offsets` determines /// the starting index position of each bag (sequence) in `input`. TORCH_ARG(torch::Tensor, offsets) = Tensor(); - /// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`. + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. TORCH_ARG(c10::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; - /// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``false``. - /// Note: this option is not supported when ``mode="kMax"``. + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. TORCH_ARG(bool, scale_grad_by_freq) = false; - /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` - /// into consideration. ``"kMean"`` computes the average of the values in the bag, ``"kMax"`` computes the max value over each bag. + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. /// Note: this option is not supported when ``mode="kMax"``. TORCH_ARG(bool, sparse) = false; - /// a tensor of float / double weights, or None to indicate all weights should be taken to be 1. - /// If specified, `per_sample_weights` must have exactly the same shape as input and is treated as - /// having the same `offsets`, if those are not None. + /// a tensor of float / double weights, or None to indicate all weights should + /// be taken to be 1. If specified, `per_sample_weights` must have exactly the + /// same shape as input and is treated as having the same `offsets`, if those + /// are not None. TORCH_ARG(torch::Tensor, per_sample_weights) = Tensor(); /// If ``true``, `offsets` has one additional element, where the last element /// is equivalent to the size of `indices`. This matches the CSR format. Note: diff --git a/torch/csrc/api/include/torch/nn/options/fold.h b/torch/csrc/api/include/torch/nn/options/fold.h index 00edd84dfaebf8..a8979f0d6050eb 100644 --- a/torch/csrc/api/include/torch/nn/options/fold.h +++ b/torch/csrc/api/include/torch/nn/options/fold.h @@ -12,7 +12,8 @@ namespace nn { /// /// Example: /// ``` -/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2)); +/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, +/// 1}).stride(2)); /// ``` struct TORCH_API FoldOptions { FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) diff --git a/torch/csrc/api/include/torch/nn/options/instancenorm.h b/torch/csrc/api/include/torch/nn/options/instancenorm.h index ae9b3b6089a4da..d93e10d0c95a23 100644 --- a/torch/csrc/api/include/torch/nn/options/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/options/instancenorm.h @@ -34,7 +34,8 @@ struct TORCH_API InstanceNormOptions { /// /// Example: /// ``` -/// InstanceNorm1d model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// InstanceNorm1d +/// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using InstanceNorm1dOptions = InstanceNormOptions; @@ -42,7 +43,8 @@ using InstanceNorm1dOptions = InstanceNormOptions; /// /// Example: /// ``` -/// InstanceNorm2d model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// InstanceNorm2d +/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using InstanceNorm2dOptions = InstanceNormOptions; @@ -50,7 +52,8 @@ using InstanceNorm2dOptions = InstanceNormOptions; /// /// Example: /// ``` -/// InstanceNorm3d model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// InstanceNorm3d +/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); /// ``` using InstanceNorm3dOptions = InstanceNormOptions; @@ -61,7 +64,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::instance_norm(input, F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); +/// F::instance_norm(input, +/// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); /// ``` struct TORCH_API InstanceNormFuncOptions { TORCH_ARG(Tensor, running_mean) = Tensor(); diff --git a/torch/csrc/api/include/torch/nn/options/linear.h b/torch/csrc/api/include/torch/nn/options/linear.h index 742a887e285b4c..ec08da7539a81b 100644 --- a/torch/csrc/api/include/torch/nn/options/linear.h +++ b/torch/csrc/api/include/torch/nn/options/linear.h @@ -1,9 +1,9 @@ #pragma once +#include #include #include #include -#include namespace torch { namespace nn { @@ -78,7 +78,10 @@ struct TORCH_API UnflattenOptions { /// Bilinear model(BilinearOptions(3, 2, 4).bias(false)); /// ``` struct TORCH_API BilinearOptions { - BilinearOptions(int64_t in1_features, int64_t in2_features, int64_t out_features); + BilinearOptions( + int64_t in1_features, + int64_t in2_features, + int64_t out_features); /// The number of features in input 1 (columns of the input1 matrix). TORCH_ARG(int64_t, in1_features); /// The number of features in input 2 (columns of the input2 matrix). diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index 4961f4b8cb1242..57f4046a4799dc 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include namespace torch { @@ -15,7 +15,8 @@ namespace nn { /// L1Loss model(L1LossOptions(torch::kNone)); /// ``` struct TORCH_API L1LossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; TORCH_OPTIONS_CTOR_VARIANT_ARG3(L1LossOptions, reduction, kNone, kMean, kSum) @@ -43,12 +44,24 @@ using L1LossFuncOptions = L1LossOptions; /// /// Example: /// ``` -/// KLDivLoss model(KLDivLossOptions().reduction(torch::kNone).log_target(false)); +/// KLDivLoss +/// model(KLDivLossOptions().reduction(torch::kNone).log_target(false)); /// ``` struct TORCH_API KLDivLossOptions { - typedef c10::variant reduction_t; - - TORCH_OPTIONS_CTOR_VARIANT_ARG4(KLDivLossOptions, reduction, kNone, kBatchMean, kSum, kMean) + typedef c10::variant< + enumtype::kNone, + enumtype::kBatchMean, + enumtype::kSum, + enumtype::kMean> + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG4( + KLDivLossOptions, + reduction, + kNone, + kBatchMean, + kSum, + kMean) /// Specifies the reduction to apply to the output. /// ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. Default: ``'mean'`` @@ -67,7 +80,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::kl_div(input, target, F::KLDivFuncOptions().reduction(torch::kNone).log_target(false)); +/// F::kl_div(input, target, +/// F::KLDivFuncOptions().reduction(torch::kNone).log_target(false)); /// ``` using KLDivFuncOptions = KLDivLossOptions; } // namespace functional @@ -81,7 +95,8 @@ using KLDivFuncOptions = KLDivLossOptions; /// MSELoss model(MSELossOptions(torch::kNone)); /// ``` struct TORCH_API MSELossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; TORCH_OPTIONS_CTOR_VARIANT_ARG3(MSELossOptions, reduction, kNone, kMean, kSum) @@ -113,7 +128,8 @@ using MSELossFuncOptions = MSELossOptions; /// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight)); /// ``` struct TORCH_API BCELossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// A manual rescaling weight given to the loss of each batch element. TORCH_ARG(Tensor, weight) = {}; @@ -131,7 +147,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::binary_cross_entropy(input, target, F::BinaryCrossEntropyFuncOptions().weight(weight)); +/// F::binary_cross_entropy(input, target, +/// F::BinaryCrossEntropyFuncOptions().weight(weight)); /// ``` using BinaryCrossEntropyFuncOptions = BCELossOptions; } // namespace functional @@ -142,10 +159,12 @@ using BinaryCrossEntropyFuncOptions = BCELossOptions; /// /// Example: /// ``` -/// HingeEmbeddingLoss model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); +/// HingeEmbeddingLoss +/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); /// ``` struct TORCH_API HingeEmbeddingLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// Specifies the threshold for which the distance of a negative sample must /// reach in order to incur zero loss. Default: 1 @@ -157,13 +176,14 @@ struct TORCH_API HingeEmbeddingLossOptions { namespace functional { /// Options for `torch::nn::functional::hinge_embedding_loss`. /// -/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::hinge_embedding_loss(input, target, F::HingeEmbeddingLossFuncOptions().margin(2)); +/// F::hinge_embedding_loss(input, target, +/// F::HingeEmbeddingLossFuncOptions().margin(2)); /// ``` using HingeEmbeddingLossFuncOptions = HingeEmbeddingLossOptions; } // namespace functional @@ -177,7 +197,8 @@ using HingeEmbeddingLossFuncOptions = HingeEmbeddingLossOptions; /// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight)); /// ``` struct TORCH_API MultiMarginLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// Has a default value of :math:`1`. :math:`1` and :math:`2` /// are the only supported values. @@ -189,22 +210,25 @@ struct TORCH_API MultiMarginLossOptions { /// treated as if having all ones. TORCH_ARG(Tensor, weight) = Tensor(); /// Specifies the reduction to apply to the output: - /// ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + /// ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be + /// applied, /// ``'mean'``: the sum of the output will be divided by the number of - /// elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + /// elements in the output, ``'sum'``: the output will be summed. Default: + /// ``'mean'`` TORCH_ARG(reduction_t, reduction) = torch::kMean; }; namespace functional { /// Options for `torch::nn::functional::multi_margin_loss`. /// -/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn +/// what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multi_margin_loss(input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight)); +/// F::multi_margin_loss(input, target, +/// F::MultiMarginLossFuncOptions().margin(2).weight(weight)); /// ``` using MultiMarginLossFuncOptions = MultiMarginLossOptions; } // namespace functional @@ -218,7 +242,8 @@ using MultiMarginLossFuncOptions = MultiMarginLossOptions; /// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5)); /// ``` struct TORCH_API CosineEmbeddingLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// Specifies the threshold for which the distance of a negative sample must /// reach in order to incur zero loss. Should be a number from -1 to 1, 0 @@ -231,13 +256,14 @@ struct TORCH_API CosineEmbeddingLossOptions { namespace functional { /// Options for `torch::nn::functional::cosine_embedding_loss`. /// -/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cosine_embedding_loss(input1, input2, target, F::CosineEmbeddingLossFuncOptions().margin(0.5)); +/// F::cosine_embedding_loss(input1, input2, target, +/// F::CosineEmbeddingLossFuncOptions().margin(0.5)); /// ``` using CosineEmbeddingLossFuncOptions = CosineEmbeddingLossOptions; } // namespace functional @@ -251,9 +277,15 @@ using CosineEmbeddingLossFuncOptions = CosineEmbeddingLossOptions; /// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone)); /// ``` struct TORCH_API MultiLabelMarginLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; - TORCH_OPTIONS_CTOR_VARIANT_ARG3(MultiLabelMarginLossOptions, reduction, kNone, kMean, kSum) + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + MultiLabelMarginLossOptions, + reduction, + kNone, + kMean, + kSum) /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. /// 'none': no reduction will be applied, 'mean': the sum of the output will @@ -265,13 +297,14 @@ struct TORCH_API MultiLabelMarginLossOptions { namespace functional { /// Options for `torch::nn::functional::multilabel_margin_loss`. /// -/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multilabel_margin_loss(input, target, F::MultilabelMarginLossFuncOptions(torch::kNone)); +/// F::multilabel_margin_loss(input, target, +/// F::MultilabelMarginLossFuncOptions(torch::kNone)); /// ``` using MultilabelMarginLossFuncOptions = MultiLabelMarginLossOptions; } // namespace functional @@ -285,9 +318,15 @@ using MultilabelMarginLossFuncOptions = MultiLabelMarginLossOptions; /// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone)); /// ``` struct TORCH_API SoftMarginLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; - TORCH_OPTIONS_CTOR_VARIANT_ARG3(SoftMarginLossOptions, reduction, kNone, kMean, kSum) + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + SoftMarginLossOptions, + reduction, + kNone, + kMean, + kSum) /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. /// 'none': no reduction will be applied, 'mean': the sum of the output will @@ -299,13 +338,14 @@ struct TORCH_API SoftMarginLossOptions { namespace functional { /// Options for `torch::nn::functional::soft_margin_loss`. /// -/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn +/// what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::soft_margin_loss(input, target, F::SoftMarginLossFuncOptions(torch::kNone)); +/// F::soft_margin_loss(input, target, +/// F::SoftMarginLossFuncOptions(torch::kNone)); /// ``` using SoftMarginLossFuncOptions = SoftMarginLossOptions; } // namespace functional @@ -316,10 +356,12 @@ using SoftMarginLossFuncOptions = SoftMarginLossOptions; /// /// Example: /// ``` -/// MultiLabelSoftMarginLoss model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); +/// MultiLabelSoftMarginLoss +/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); /// ``` struct TORCH_API MultiLabelSoftMarginLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// A manual rescaling weight given to each /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is @@ -336,13 +378,14 @@ struct TORCH_API MultiLabelSoftMarginLossOptions { namespace functional { /// Options for `torch::nn::functional::multilabel_soft_margin_loss`. /// -/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class +/// to learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::multilabel_soft_margin_loss(input, target, F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); +/// F::multilabel_soft_margin_loss(input, target, +/// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); /// ``` using MultilabelSoftMarginLossFuncOptions = MultiLabelSoftMarginLossOptions; } // namespace functional @@ -353,10 +396,12 @@ using MultilabelSoftMarginLossFuncOptions = MultiLabelSoftMarginLossOptions; /// /// Example: /// ``` -/// TripletMarginLoss model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false)); +/// TripletMarginLoss +/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false)); /// ``` struct TORCH_API TripletMarginLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// Specifies the threshold for which the distance of a negative sample must /// reach in order to incur zero loss. Default: 1 @@ -375,13 +420,14 @@ struct TORCH_API TripletMarginLossOptions { namespace functional { /// Options for `torch::nn::functional::triplet_margin_loss`. /// -/// See the documentation for `torch::nn::TripletMarginLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::TripletMarginLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::triplet_margin_loss(anchor, positive, negative, F::TripletMarginLossFuncOptions().margin(1.0)); +/// F::triplet_margin_loss(anchor, positive, negative, +/// F::TripletMarginLossFuncOptions().margin(1.0)); /// ``` using TripletMarginLossFuncOptions = TripletMarginLossOptions; } // namespace functional @@ -392,16 +438,20 @@ using TripletMarginLossFuncOptions = TripletMarginLossOptions; /// /// Example: /// ``` -/// TripletMarginWithDistanceLoss model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// TripletMarginWithDistanceLoss +/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); /// ``` struct TORCH_API TripletMarginWithDistanceLossOptions { - typedef c10::variant reduction_t; - typedef std::function distance_function_t; + typedef c10::variant + reduction_t; + typedef std::function + distance_function_t; /// Specifies a nonnegative, real-valued function that quantifies the /// closeness of two tensors. If not specified, `F::pairwise_distance` will /// be used. Default: nullopt - TORCH_ARG(c10::optional, distance_function) = c10::nullopt; + TORCH_ARG(c10::optional, distance_function) = + c10::nullopt; /// Specifies a nonnegative margin representing the minimum difference /// between the positive and negative distances required for the loss to be 0. /// Larger margins penalize cases where the negative examples are not distance @@ -420,15 +470,17 @@ struct TORCH_API TripletMarginWithDistanceLossOptions { namespace functional { /// Options for `torch::nn::functional::triplet_margin_with_distance_loss`. /// -/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` +/// class to learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::triplet_margin_with_distance_loss(anchor, positive, negative, F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, +/// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); /// ``` -using TripletMarginWithDistanceLossFuncOptions = TripletMarginWithDistanceLossOptions; +using TripletMarginWithDistanceLossFuncOptions = + TripletMarginWithDistanceLossOptions; } // namespace functional // ============================================================================ @@ -437,10 +489,12 @@ using TripletMarginWithDistanceLossFuncOptions = TripletMarginWithDistanceLossOp /// /// Example: /// ``` -/// CTCLoss model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum)); +/// CTCLoss +/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum)); /// ``` struct TORCH_API CTCLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// blank label. Default `0`. TORCH_ARG(int64_t, blank) = 0; @@ -461,7 +515,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, F::CTCLossFuncOptions().reduction(torch::kNone)); +/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, +/// F::CTCLossFuncOptions().reduction(torch::kNone)); /// ``` using CTCLossFuncOptions = CTCLossOptions; } // namespace functional @@ -475,24 +530,31 @@ using CTCLossFuncOptions = CTCLossOptions; /// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5)); /// ``` struct TORCH_API SmoothL1LossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; - TORCH_OPTIONS_CTOR_VARIANT_ARG3(SmoothL1LossOptions, reduction, kNone, kMean, kSum) + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + SmoothL1LossOptions, + reduction, + kNone, + kMean, + kSum) /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. /// 'none': no reduction will be applied, 'mean': the sum of the output will /// be divided by the number of elements in the output, 'sum': the output will /// be summed. Default: 'mean' TORCH_ARG(reduction_t, reduction) = torch::kMean; - /// Specifies the threshold at which to change between L1 and L2 loss. Default: 1.0 + /// Specifies the threshold at which to change between L1 and L2 loss. + /// Default: 1.0 TORCH_ARG(double, beta) = 1.0; }; namespace functional { /// Options for `torch::nn::functional::smooth_l1_loss`. /// -/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn +/// what arguments are supported. /// /// Example: /// ``` @@ -511,16 +573,23 @@ using SmoothL1LossFuncOptions = SmoothL1LossOptions; /// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5)); /// ``` struct TORCH_API HuberLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; - TORCH_OPTIONS_CTOR_VARIANT_ARG3(HuberLossOptions, reduction, kNone, kMean, kSum) + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + HuberLossOptions, + reduction, + kNone, + kMean, + kSum) /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. /// 'none': no reduction will be applied, 'mean': the sum of the output will /// be divided by the number of elements in the output, 'sum': the output will /// be summed. Default: 'mean' TORCH_ARG(reduction_t, reduction) = torch::kMean; - /// Specifies the threshold at which to change between L1 and L2 loss. Default: 1.0 + /// Specifies the threshold at which to change between L1 and L2 loss. + /// Default: 1.0 TORCH_ARG(double, delta) = 1.0; }; @@ -544,10 +613,12 @@ using HuberLossFuncOptions = HuberLossOptions; /// /// Example: /// ``` -/// PoissonNLLLoss model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); +/// PoissonNLLLoss +/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); /// ``` struct TORCH_API PoissonNLLLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// if true the loss is computed as `exp(input) - target * input`, /// if false the loss is `input - target * log(input + eps)`. @@ -565,13 +636,14 @@ struct TORCH_API PoissonNLLLossOptions { namespace functional { /// Options for `torch::nn::functional::poisson_nll_loss`. /// -/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn +/// what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::poisson_nll_loss(input, target, F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); +/// F::poisson_nll_loss(input, target, +/// F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); /// ``` using PoissonNLLLossFuncOptions = PoissonNLLLossOptions; } // namespace functional @@ -582,10 +654,12 @@ using PoissonNLLLossFuncOptions = PoissonNLLLossOptions; /// /// Example: /// ``` -/// MarginRankingLoss model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); +/// MarginRankingLoss +/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); /// ``` struct TORCH_API MarginRankingLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// Has a default value of `0`. TORCH_ARG(double, margin) = 0; @@ -596,13 +670,14 @@ struct TORCH_API MarginRankingLossOptions { namespace functional { /// Options for `torch::nn::functional::margin_ranking_loss`. /// -/// See the documentation for `torch::nn::MarginRankingLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::MarginRankingLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::margin_ranking_loss(input1, input2, target, F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); +/// F::margin_ranking_loss(input1, input2, target, +/// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); /// ``` using MarginRankingLossFuncOptions = MarginRankingLossOptions; } // namespace functional @@ -616,7 +691,8 @@ using MarginRankingLossFuncOptions = MarginRankingLossOptions; /// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` struct TORCH_API NLLLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// A manual rescaling weight given to each /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is @@ -638,7 +714,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::nll_loss(input, target, F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// F::nll_loss(input, target, +/// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` using NLLLossFuncOptions = NLLLossOptions; } // namespace functional @@ -649,13 +726,15 @@ using NLLLossFuncOptions = NLLLossOptions; /// /// Example: /// ``` -/// CrossEntropyLoss model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// CrossEntropyLoss +/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` struct TORCH_API CrossEntropyLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; - /// A manual rescaling weight given to each class. If given, has to be a Tensor - /// of size C + /// A manual rescaling weight given to each class. If given, has to be a + /// Tensor of size C TORCH_ARG(Tensor, weight) = {}; /// Specifies a target value that is ignored /// and does not contribute to the input gradient. @@ -669,13 +748,14 @@ struct TORCH_API CrossEntropyLossOptions { namespace functional { /// Options for `torch::nn::functional::cross_entropy`. /// -/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::cross_entropy(input, target, F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// F::cross_entropy(input, target, +/// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); /// ``` using CrossEntropyFuncOptions = CrossEntropyLossOptions; } // namespace functional @@ -686,10 +766,12 @@ using CrossEntropyFuncOptions = CrossEntropyLossOptions; /// /// Example: /// ``` -/// BCEWithLogitsLoss model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); +/// BCEWithLogitsLoss +/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); /// ``` struct TORCH_API BCEWithLogitsLossOptions { - typedef c10::variant reduction_t; + typedef c10::variant + reduction_t; /// A manual rescaling weight given to the loss of each batch element. /// If given, has to be a Tensor of size `nbatch`. TORCH_ARG(Tensor, weight) = {}; @@ -703,13 +785,14 @@ struct TORCH_API BCEWithLogitsLossOptions { namespace functional { /// Options for `torch::nn::functional::binary_cross_entropy_with_logits`. /// -/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::binary_cross_entropy_with_logits(input, target, F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); +/// F::binary_cross_entropy_with_logits(input, target, +/// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); /// ``` using BinaryCrossEntropyWithLogitsFuncOptions = BCEWithLogitsLossOptions; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h index 05ed1dcda954ab..ae8c206736d50d 100644 --- a/torch/csrc/api/include/torch/nn/options/normalization.h +++ b/torch/csrc/api/include/torch/nn/options/normalization.h @@ -12,17 +12,19 @@ namespace nn { /// /// Example: /// ``` -/// LayerNorm model(LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5)); +/// LayerNorm model(LayerNormOptions({2, +/// 2}).elementwise_affine(false).eps(2e-5)); /// ``` struct TORCH_API LayerNormOptions { /* implicit */ LayerNormOptions(std::vector normalized_shape); /// input shape from an expected input. TORCH_ARG(std::vector, normalized_shape); - /// a value added to the denominator for numerical stability. ``Default: 1e-5``. + /// a value added to the denominator for numerical stability. ``Default: + /// 1e-5``. TORCH_ARG(double, eps) = 1e-5; /// a boolean value that when set to ``true``, this module - /// has learnable per-element affine parameters initialized to ones (for weights) - /// and zeros (for biases). ``Default: true``. + /// has learnable per-element affine parameters initialized to ones (for + /// weights) and zeros (for biases). ``Default: true``. TORCH_ARG(bool, elementwise_affine) = true; }; @@ -46,7 +48,8 @@ struct TORCH_API LayerNormFuncOptions { TORCH_ARG(Tensor, bias) = {}; - /// a value added to the denominator for numerical stability. ``Default: 1e-5``. + /// a value added to the denominator for numerical stability. ``Default: + /// 1e-5``. TORCH_ARG(double, eps) = 1e-5; }; @@ -58,7 +61,8 @@ struct TORCH_API LayerNormFuncOptions { /// /// Example: /// ``` -/// LocalResponseNorm model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); +/// LocalResponseNorm +/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); /// ``` struct TORCH_API LocalResponseNormOptions { /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {} @@ -78,8 +82,8 @@ struct TORCH_API LocalResponseNormOptions { namespace functional { /// Options for `torch::nn::functional::local_response_norm`. /// -/// See the documentation for `torch::nn::LocalResponseNormOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::LocalResponseNormOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` @@ -107,7 +111,6 @@ struct TORCH_API CrossMapLRN2dOptions { TORCH_ARG(double, beta) = 0.75; TORCH_ARG(int64_t, k) = 1; - }; // ============================================================================ @@ -153,8 +156,8 @@ struct TORCH_API GroupNormOptions { /// a value added to the denominator for numerical stability. Default: 1e-5 TORCH_ARG(double, eps) = 1e-5; /// a boolean value that when set to ``true``, this module - /// has learnable per-channel affine parameters initialized to ones (for weights) - /// and zeros (for biases). Default: ``true``. + /// has learnable per-channel affine parameters initialized to ones (for + /// weights) and zeros (for biases). Default: ``true``. TORCH_ARG(bool, affine) = true; }; diff --git a/torch/csrc/api/include/torch/nn/options/padding.h b/torch/csrc/api/include/torch/nn/options/padding.h index 36c12cc16e3a90..7b7012211daa2d 100644 --- a/torch/csrc/api/include/torch/nn/options/padding.h +++ b/torch/csrc/api/include/torch/nn/options/padding.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include #include @@ -13,15 +13,17 @@ namespace nn { /// Options for a `D`-dimensional ReflectionPad module. template struct TORCH_API ReflectionPadOptions { - ReflectionPadOptions(ExpandingArray padding) : padding_(padding) {} + ReflectionPadOptions(ExpandingArray padding) : padding_(padding) {} /// The size of the padding. /// If it is `int`, uses the same padding in all boundaries. - /// If it is a 2-`tuple` (for ReflectionPad1d), uses (padding_left, padding_right). - /// If it is a 4-`tuple` (for ReflectionPad2d), uses (padding_left, padding_right, padding_top, padding_bottom). - /// If it is a 6-`tuple` (for ReflectionPad3d), uses (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back). + /// If it is a 2-`tuple` (for ReflectionPad1d), uses (padding_left, + /// padding_right). If it is a 4-`tuple` (for ReflectionPad2d), uses + /// (padding_left, padding_right, padding_top, padding_bottom). If it is a + /// 6-`tuple` (for ReflectionPad3d), uses (padding_left, padding_right, + /// padding_top, padding_bottom, padding_front, padding_back). - TORCH_ARG(ExpandingArray, padding); + TORCH_ARG(ExpandingArray, padding); }; /// `ReflectionPadOptions` specialized for the `ReflectionPad1d` module. @@ -53,15 +55,18 @@ using ReflectionPad3dOptions = ReflectionPadOptions<3>; /// Options for a `D`-dimensional ReplicationPad module. template struct TORCH_API ReplicationPadOptions { - ReplicationPadOptions(ExpandingArray padding) : padding_(padding) {} + ReplicationPadOptions(ExpandingArray padding) : padding_(padding) {} /// The size of the padding. /// - If it is `int`, uses the same padding in all boundaries. - /// - If it is a 2-`tuple` (for ReplicationPad1d), uses (padding_left, padding_right). - /// - If it is a 4-`tuple` (for ReplicationPad2d), uses (padding_left, padding_right, padding_top, padding_bottom). + /// - If it is a 2-`tuple` (for ReplicationPad1d), uses (padding_left, + /// padding_right). + /// - If it is a 4-`tuple` (for ReplicationPad2d), uses (padding_left, + /// padding_right, padding_top, padding_bottom). /// - If it is a 6-`tuple` (for ReplicationPad3d), uses - /// (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back). - TORCH_ARG(ExpandingArray, padding); + /// (padding_left, padding_right, padding_top, padding_bottom, + /// padding_front, padding_back). + TORCH_ARG(ExpandingArray, padding); }; /// `ReplicationPadOptions` specialized for the `ReplicationPad1d` module. @@ -101,7 +106,8 @@ struct TORCH_API ZeroPad2dOptions { /// The size of the padding. /// - If it is `int`, uses the same padding in all boundaries. - /// - If it is a 4-`tuple` (for ZeroPad2d), uses (padding_left, padding_right, padding_top, padding_bottom). + /// - If it is a 4-`tuple` (for ZeroPad2d), uses (padding_left, padding_right, + /// padding_top, padding_bottom). TORCH_ARG(ExpandingArray<4>, padding); }; @@ -110,15 +116,19 @@ struct TORCH_API ZeroPad2dOptions { /// Options for a `D`-dimensional ConstantPad module. template struct TORCH_API ConstantPadOptions { - ConstantPadOptions(ExpandingArray padding, double value) : padding_(padding), value_(value) {} + ConstantPadOptions(ExpandingArray padding, double value) + : padding_(padding), value_(value) {} /// The size of the padding. /// - If it is `int`, uses the same padding in all boundaries. - /// - If it is a 2-`tuple` (for ConstantPad1d), uses (padding_left, padding_right). - /// - If it is a 4-`tuple` (for ConstantPad2d), uses (padding_left, padding_right, padding_top, padding_bottom). + /// - If it is a 2-`tuple` (for ConstantPad1d), uses (padding_left, + /// padding_right). + /// - If it is a 4-`tuple` (for ConstantPad2d), uses (padding_left, + /// padding_right, padding_top, padding_bottom). /// - If it is a 6-`tuple` (for ConstantPad3d), uses - /// (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back). - TORCH_ARG(ExpandingArray, padding); + /// (padding_left, padding_right, padding_top, padding_bottom, + /// padding_front, padding_back). + TORCH_ARG(ExpandingArray, padding); /// Fill value for constant padding. TORCH_ARG(double, value); @@ -157,14 +167,16 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate)); +/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, +/// 2}).mode(torch::kReplicate)); /// ``` struct TORCH_API PadFuncOptions { typedef c10::variant< - enumtype::kConstant, - enumtype::kReflect, - enumtype::kReplicate, - enumtype::kCircular> mode_t; + enumtype::kConstant, + enumtype::kReflect, + enumtype::kReplicate, + enumtype::kCircular> + mode_t; PadFuncOptions(std::vector pad); diff --git a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h index 89edb2a8bfa667..859da98616db15 100644 --- a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h @@ -38,8 +38,8 @@ struct TORCH_API PixelUnshuffleOptions { namespace functional { /// Options for `torch::nn::functional::pixel_shuffle`. /// -/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn +/// what arguments are supported. /// /// Example: /// ``` diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h index 38e20bdd02ce45..219ae7fd5dd01c 100644 --- a/torch/csrc/api/include/torch/nn/options/pooling.h +++ b/torch/csrc/api/include/torch/nn/options/pooling.h @@ -29,7 +29,8 @@ struct AvgPoolOptions { /// when True, will include the zero-padding in the averaging calculation TORCH_ARG(bool, count_include_pad) = true; - /// if specified, it will be used as divisor, otherwise size of the pooling region will be used. + /// if specified, it will be used as divisor, otherwise size of the pooling + /// region will be used. TORCH_ARG(c10::optional, divisor_override) = c10::nullopt; }; @@ -149,7 +150,8 @@ using MaxPool2dOptions = MaxPoolOptions<2>; using MaxPool3dOptions = MaxPoolOptions<3>; namespace functional { -/// Options for `torch::nn::functional::max_pool1d` and `torch::nn::functional::max_pool1d_with_indices`. +/// Options for `torch::nn::functional::max_pool1d` and +/// `torch::nn::functional::max_pool1d_with_indices`. /// /// Example: /// ``` @@ -160,7 +162,8 @@ using MaxPool1dFuncOptions = MaxPool1dOptions; } // namespace functional namespace functional { -/// Options for `torch::nn::functional::max_pool2d` and `torch::nn::functional::max_pool2d_with_indices`. +/// Options for `torch::nn::functional::max_pool2d` and +/// `torch::nn::functional::max_pool2d_with_indices`. /// /// Example: /// ``` @@ -171,7 +174,8 @@ using MaxPool2dFuncOptions = MaxPool2dOptions; } // namespace functional namespace functional { -/// Options for `torch::nn::functional::max_pool3d` and `torch::nn::functional::max_pool3d_with_indices`. +/// Options for `torch::nn::functional::max_pool3d` and +/// `torch::nn::functional::max_pool3d_with_indices`. /// /// Example: /// ``` @@ -207,7 +211,8 @@ using AdaptiveMaxPool1dOptions = AdaptiveMaxPoolOptions>; /// ``` /// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); /// ``` -using AdaptiveMaxPool2dOptions = AdaptiveMaxPoolOptions>; +using AdaptiveMaxPool2dOptions = + AdaptiveMaxPoolOptions>; /// `AdaptiveMaxPoolOptions` specialized for the `AdaptiveMaxPool3d` module. /// @@ -215,10 +220,12 @@ using AdaptiveMaxPool2dOptions = AdaptiveMaxPoolOptions>; +using AdaptiveMaxPool3dOptions = + AdaptiveMaxPoolOptions>; namespace functional { -/// Options for `torch::nn::functional::adaptive_max_pool1d` and `torch::nn::functional::adaptive_max_pool1d_with_indices` +/// Options for `torch::nn::functional::adaptive_max_pool1d` and +/// `torch::nn::functional::adaptive_max_pool1d_with_indices` /// /// Example: /// ``` @@ -229,7 +236,8 @@ using AdaptiveMaxPool1dFuncOptions = AdaptiveMaxPool1dOptions; } // namespace functional namespace functional { -/// Options for `torch::nn::functional::adaptive_max_pool2d` and `torch::nn::functional::adaptive_max_pool2d_with_indices` +/// Options for `torch::nn::functional::adaptive_max_pool2d` and +/// `torch::nn::functional::adaptive_max_pool2d_with_indices` /// /// Example: /// ``` @@ -240,7 +248,8 @@ using AdaptiveMaxPool2dFuncOptions = AdaptiveMaxPool2dOptions; } // namespace functional namespace functional { -/// Options for `torch::nn::functional::adaptive_max_pool3d` and `torch::nn::functional::adaptive_max_pool3d_with_indices` +/// Options for `torch::nn::functional::adaptive_max_pool3d` and +/// `torch::nn::functional::adaptive_max_pool3d_with_indices` /// /// Example: /// ``` @@ -276,7 +285,8 @@ using AdaptiveAvgPool1dOptions = AdaptiveAvgPoolOptions>; /// ``` /// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); /// ``` -using AdaptiveAvgPool2dOptions = AdaptiveAvgPoolOptions>; +using AdaptiveAvgPool2dOptions = + AdaptiveAvgPoolOptions>; /// `AdaptiveAvgPoolOptions` specialized for the `AdaptiveAvgPool3d` module. /// @@ -284,13 +294,14 @@ using AdaptiveAvgPool2dOptions = AdaptiveAvgPoolOptions>; +using AdaptiveAvgPool3dOptions = + AdaptiveAvgPoolOptions>; namespace functional { /// Options for `torch::nn::functional::adaptive_avg_pool1d`. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` @@ -303,8 +314,8 @@ using AdaptiveAvgPool1dFuncOptions = AdaptiveAvgPool1dOptions; namespace functional { /// Options for `torch::nn::functional::adaptive_avg_pool2d`. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` @@ -317,8 +328,8 @@ using AdaptiveAvgPool2dFuncOptions = AdaptiveAvgPool2dOptions; namespace functional { /// Options for `torch::nn::functional::adaptive_avg_pool3d`. /// -/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to learn what -/// arguments are supported. +/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to +/// learn what arguments are supported. /// /// Example: /// ``` @@ -393,25 +404,30 @@ struct MaxUnpoolFuncOptions { TORCH_ARG(c10::optional>, output_size) = c10::nullopt; }; -/// `MaxUnpoolFuncOptions` specialized for `torch::nn::functional::max_unpool1d`. +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool1d`. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); +/// F::max_unpool1d(x, indices, +/// F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); /// ``` using MaxUnpool1dFuncOptions = MaxUnpoolFuncOptions<1>; -/// `MaxUnpoolFuncOptions` specialized for `torch::nn::functional::max_unpool2d`. +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool2d`. /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::max_unpool2d(x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); +/// F::max_unpool2d(x, indices, +/// F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); /// ``` using MaxUnpool2dFuncOptions = MaxUnpoolFuncOptions<2>; -/// `MaxUnpoolFuncOptions` specialized for `torch::nn::functional::max_unpool3d`. +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool3d`. /// /// Example: /// ``` @@ -436,9 +452,9 @@ struct FractionalMaxPoolOptions { /// the target output size of the image TORCH_ARG(c10::optional>, output_size) = c10::nullopt; - /// If one wants to have an output size as a ratio of the input size, this option can be given. - /// This has to be a number or tuple in the range (0, 1) - using ExpandingArrayDouble=torch::ExpandingArray; + /// If one wants to have an output size as a ratio of the input size, this + /// option can be given. This has to be a number or tuple in the range (0, 1) + using ExpandingArrayDouble = torch::ExpandingArray; TORCH_ARG(c10::optional, output_ratio) = c10::nullopt; TORCH_ARG(torch::Tensor, _random_samples) = Tensor(); @@ -461,23 +477,27 @@ using FractionalMaxPool2dOptions = FractionalMaxPoolOptions<2>; using FractionalMaxPool3dOptions = FractionalMaxPoolOptions<3>; namespace functional { -/// Options for `torch::nn::functional::fractional_max_pool2d` and `torch::nn::functional::fractional_max_pool2d_with_indices` +/// Options for `torch::nn::functional::fractional_max_pool2d` and +/// `torch::nn::functional::fractional_max_pool2d_with_indices` /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool2d(x, F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool2d(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); /// ``` using FractionalMaxPool2dFuncOptions = FractionalMaxPool2dOptions; } // namespace functional namespace functional { -/// Options for `torch::nn::functional::fractional_max_pool3d` and `torch::nn::functional::fractional_max_pool3d_with_indices` +/// Options for `torch::nn::functional::fractional_max_pool3d` and +/// `torch::nn::functional::fractional_max_pool3d_with_indices` /// /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::fractional_max_pool3d(x, F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// F::fractional_max_pool3d(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); /// ``` using FractionalMaxPool3dFuncOptions = FractionalMaxPool3dOptions; } // namespace functional @@ -488,7 +508,9 @@ using FractionalMaxPool3dFuncOptions = FractionalMaxPool3dOptions; template struct LPPoolOptions { LPPoolOptions(double norm_type, ExpandingArray kernel_size) - : norm_type_(norm_type), kernel_size_(kernel_size), stride_(kernel_size) {} + : norm_type_(norm_type), + kernel_size_(kernel_size), + stride_(kernel_size) {} TORCH_ARG(double, norm_type); @@ -514,7 +536,8 @@ using LPPool1dOptions = LPPoolOptions<1>; /// /// Example: /// ``` -/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, 6}).ceil_mode(true)); +/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, +/// 6}).ceil_mode(true)); /// ``` using LPPool2dOptions = LPPoolOptions<2>; diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h index eb1571ef8224b5..fec9392a3aeecc 100644 --- a/torch/csrc/api/include/torch/nn/options/rnn.h +++ b/torch/csrc/api/include/torch/nn/options/rnn.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include namespace torch { @@ -13,12 +13,16 @@ namespace detail { /// Common options for RNN, LSTM and GRU modules. struct TORCH_API RNNOptionsBase { typedef c10::variant< - enumtype::kLSTM, - enumtype::kGRU, - enumtype::kRNN_TANH, - enumtype::kRNN_RELU> rnn_options_base_mode_t; + enumtype::kLSTM, + enumtype::kGRU, + enumtype::kRNN_TANH, + enumtype::kRNN_RELU> + rnn_options_base_mode_t; - RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size); + RNNOptionsBase( + rnn_options_base_mode_t mode, + int64_t input_size, + int64_t hidden_size); TORCH_ARG(rnn_options_base_mode_t, mode); /// The number of features of a single sample in the input sequence `x`. @@ -38,7 +42,8 @@ struct TORCH_API RNNOptionsBase { TORCH_ARG(double, dropout) = 0.0; /// Whether to make the RNN bidirectional. TORCH_ARG(bool, bidirectional) = false; - /// Cell projection dimension. If 0, projections are not added. Can only be used for LSTMs. + /// Cell projection dimension. If 0, projections are not added. Can only be + /// used for LSTMs. TORCH_ARG(int64_t, proj_size) = 0; }; @@ -48,7 +53,8 @@ struct TORCH_API RNNOptionsBase { /// /// Example: /// ``` -/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); +/// RNN model(RNNOptions(128, +/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); /// ``` struct TORCH_API RNNOptions { typedef c10::variant nonlinearity_t; @@ -64,7 +70,8 @@ struct TORCH_API RNNOptions { /// with the second RNN taking in outputs of the first RNN and /// computing the final results. Default: 1 TORCH_ARG(int64_t, num_layers) = 1; - /// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh`` + /// The non-linearity to use. Can be either ``torch::kTanh`` or + /// ``torch::kReLU``. Default: ``torch::kTanh`` TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. /// Default: ``true`` @@ -84,7 +91,8 @@ struct TORCH_API RNNOptions { /// /// Example: /// ``` -/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true)); +/// LSTM model(LSTMOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); /// ``` struct TORCH_API LSTMOptions { LSTMOptions(int64_t input_size, int64_t hidden_size); @@ -118,7 +126,8 @@ struct TORCH_API LSTMOptions { /// /// Example: /// ``` -/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true)); +/// GRU model(GRUOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); /// ``` struct TORCH_API GRUOptions { GRUOptions(int64_t input_size, int64_t hidden_size); @@ -150,7 +159,11 @@ namespace detail { /// Common options for RNNCell, LSTMCell and GRUCell modules struct TORCH_API RNNCellOptionsBase { - RNNCellOptionsBase(int64_t input_size, int64_t hidden_size, bool bias, int64_t num_chunks); + RNNCellOptionsBase( + int64_t input_size, + int64_t hidden_size, + bool bias, + int64_t num_chunks); virtual ~RNNCellOptionsBase() = default; TORCH_ARG(int64_t, input_size); @@ -165,7 +178,8 @@ struct TORCH_API RNNCellOptionsBase { /// /// Example: /// ``` -/// RNNCell model(RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU)); +/// RNNCell model(RNNCellOptions(20, +/// 10).bias(false).nonlinearity(torch::kReLU)); /// ``` struct TORCH_API RNNCellOptions { typedef c10::variant nonlinearity_t; @@ -179,7 +193,8 @@ struct TORCH_API RNNCellOptions { /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. /// Default: ``true`` TORCH_ARG(bool, bias) = true; - /// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh`` + /// The non-linearity to use. Can be either ``torch::kTanh`` or + /// ``torch::kReLU``. Default: ``torch::kTanh`` TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; }; diff --git a/torch/csrc/api/include/torch/nn/options/transformer.h b/torch/csrc/api/include/torch/nn/options/transformer.h index bbdea49b3eaa35..41db38fe0757a7 100644 --- a/torch/csrc/api/include/torch/nn/options/transformer.h +++ b/torch/csrc/api/include/torch/nn/options/transformer.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include #include @@ -20,14 +20,18 @@ namespace nn { /// auto options = TransformerOptions().d_model(4).nhead(2).dropout(0.0); /// ``` struct TORCH_API TransformerOptions { - // The following constructors are commonly used // Please don't add more unless it is proved as a common usage TransformerOptions() = default; TransformerOptions(int64_t d_model, int64_t nhead); - TransformerOptions(int64_t d_model, int64_t nhead, int64_t num_encoder_layers, int64_t num_decoder_layers); - - /// the number of expected features in the encoder/decoder inputs (default=512) + TransformerOptions( + int64_t d_model, + int64_t nhead, + int64_t num_encoder_layers, + int64_t num_decoder_layers); + + /// the number of expected features in the encoder/decoder inputs + /// (default=512) TORCH_ARG(int64_t, d_model) = 512; /// the number of heads in the multiheadattention models (default=8) @@ -45,7 +49,8 @@ struct TORCH_API TransformerOptions { /// the dropout value (default=0.1) TORCH_ARG(double, dropout) = 0.1; - /// the activation function of encoder/decoder intermediate layer (default=``torch::kReLU``) + /// the activation function of encoder/decoder intermediate layer + /// (default=``torch::kReLU``) TORCH_ARG(activation_t, activation) = torch::kReLU; /// custom encoder (default=None) @@ -55,6 +60,5 @@ struct TORCH_API TransformerOptions { TORCH_ARG(AnyModule, custom_decoder); }; - } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/transformercoder.h b/torch/csrc/api/include/torch/nn/options/transformercoder.h index d9e59b35eb1cb7..64f6b998f4c657 100644 --- a/torch/csrc/api/include/torch/nn/options/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/options/transformercoder.h @@ -2,56 +2,65 @@ #include #include -#include #include +#include -#include #include +#include namespace torch { namespace nn { - /// Options for the `TransformerEncoder` - /// - /// Example: - /// ``` - /// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 8).dropout(0.1)); - /// auto options = TransformerEncoderOptions(encoderLayer, 6).norm(LayerNorm(LayerNormOptions({2}))); - /// ``` - struct TORCH_API TransformerEncoderOptions { - // This constructor will keep a shallow copy of encoder_layer, so it keeps all the data in encoder_layer. - TransformerEncoderOptions(TransformerEncoderLayer encoder_layer, int64_t num_layers); - // This constructor will create a new TransformerEncoderLayer obj based on passed in encoder_layer_options. - TransformerEncoderOptions(const TransformerEncoderLayerOptions& encoder_layer_options, int64_t num_layers); +/// Options for the `TransformerEncoder` +/// +/// Example: +/// ``` +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); auto options = TransformerEncoderOptions(encoderLayer, +/// 6).norm(LayerNorm(LayerNormOptions({2}))); +/// ``` +struct TORCH_API TransformerEncoderOptions { + // This constructor will keep a shallow copy of encoder_layer, so it keeps all + // the data in encoder_layer. + TransformerEncoderOptions( + TransformerEncoderLayer encoder_layer, + int64_t num_layers); + // This constructor will create a new TransformerEncoderLayer obj based on + // passed in encoder_layer_options. + TransformerEncoderOptions( + const TransformerEncoderLayerOptions& encoder_layer_options, + int64_t num_layers); - /// transformer Encoder Layer - TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr; + /// transformer Encoder Layer + TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr; - /// number of encoder layers - TORCH_ARG(int64_t, num_layers); + /// number of encoder layers + TORCH_ARG(int64_t, num_layers); - /// normalization module - TORCH_ARG(AnyModule, norm); - }; + /// normalization module + TORCH_ARG(AnyModule, norm); +}; /// Options for the `TransformerDecoder` module. /// /// Example: /// ``` -/// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, 8).dropout(0.1)); -/// auto options = TransformerDecoderOptions(decoder_layer, 6)norm(LayerNorm(LayerNormOptions({2}))); -/// TransformerDecoder transformer_decoder(options); +/// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.1)); auto options = TransformerDecoderOptions(decoder_layer, +/// 6)norm(LayerNorm(LayerNormOptions({2}))); TransformerDecoder +/// transformer_decoder(options); /// ``` struct TORCH_API TransformerDecoderOptions { // This constructor will keep the a ref of passed in decoder_layer, // so it keeps all the data in decoder_layer. - TransformerDecoderOptions(TransformerDecoderLayer decoder_layer, - int64_t num_layers); + TransformerDecoderOptions( + TransformerDecoderLayer decoder_layer, + int64_t num_layers); // This constructor will create a new TransformerDecoderLayer obj, // based on passed in decoder_layer_options. TransformerDecoderOptions( - const TransformerDecoderLayerOptions& decoder_layer_options, - int64_t num_layers); + const TransformerDecoderLayerOptions& decoder_layer_options, + int64_t num_layers); /// decoder layer to be cloned TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr; diff --git a/torch/csrc/api/include/torch/nn/options/transformerlayer.h b/torch/csrc/api/include/torch/nn/options/transformerlayer.h index aaad0040b3b6bd..7c24cb8dcdd196 100644 --- a/torch/csrc/api/include/torch/nn/options/transformerlayer.h +++ b/torch/csrc/api/include/torch/nn/options/transformerlayer.h @@ -2,13 +2,16 @@ #include #include -#include #include +#include namespace torch { namespace nn { -using activation_t = c10::variant >; +using activation_t = c10::variant< + enumtype::kReLU, + enumtype::kGELU, + std::function>; /// Options for the `TransformerEncoderLayer` /// @@ -17,7 +20,6 @@ using activation_t = c10::variant #include -#include #include +#include #include #include @@ -16,7 +16,8 @@ namespace nn { /// /// Example: /// ``` -/// Upsample model(UpsampleOptions().scale_factor(std::vector({3})).mode(torch::kLinear).align_corners(false)); +/// Upsample +/// model(UpsampleOptions().scale_factor(std::vector({3})).mode(torch::kLinear).align_corners(false)); /// ``` struct TORCH_API UpsampleOptions { /// output spatial sizes. @@ -32,7 +33,8 @@ struct TORCH_API UpsampleOptions { enumtype::kLinear, enumtype::kBilinear, enumtype::kBicubic, - enumtype::kTrilinear> mode_t; + enumtype::kTrilinear> + mode_t; TORCH_ARG(mode_t, mode) = torch::kNearest; /// if "True", the corner pixels of the input and output tensors are @@ -49,7 +51,8 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::interpolate(input, F::InterpolateFuncOptions().size(std::vector({4})).mode(torch::kNearest)); +/// F::interpolate(input, +/// F::InterpolateFuncOptions().size(std::vector({4})).mode(torch::kNearest)); /// ``` struct TORCH_API InterpolateFuncOptions { typedef c10::variant< @@ -59,7 +62,8 @@ struct TORCH_API InterpolateFuncOptions { enumtype::kBicubic, enumtype::kTrilinear, enumtype::kArea, - enumtype::kNearestExact> mode_t; + enumtype::kNearestExact> + mode_t; /// output spatial sizes. TORCH_ARG(c10::optional>, size) = c10::nullopt; @@ -83,19 +87,21 @@ struct TORCH_API InterpolateFuncOptions { TORCH_ARG(c10::optional, align_corners) = c10::nullopt; /// recompute the scale_factor for use in the - /// interpolation calculation. When `scale_factor` is passed as a parameter, it is used - /// to compute the `output_size`. If `recompute_scale_factor` is `true` or not specified, - /// a new `scale_factor` will be computed based on the output and input sizes for use in the - /// interpolation computation (i.e. the computation will be identical to if the computed - /// `output_size` were passed-in explicitly). Otherwise, the passed-in `scale_factor` will - /// be used in the interpolation computation. Note that when `scale_factor` is floating-point, - /// the recomputed scale_factor may differ from the one passed in due to rounding and precision - /// issues. + /// interpolation calculation. When `scale_factor` is passed as a parameter, + /// it is used to compute the `output_size`. If `recompute_scale_factor` is + /// `true` or not specified, a new `scale_factor` will be computed based on + /// the output and input sizes for use in the interpolation computation (i.e. + /// the computation will be identical to if the computed `output_size` were + /// passed-in explicitly). Otherwise, the passed-in `scale_factor` will be + /// used in the interpolation computation. Note that when `scale_factor` is + /// floating-point, the recomputed scale_factor may differ from the one passed + /// in due to rounding and precision issues. TORCH_ARG(c10::optional, recompute_scale_factor) = c10::nullopt; /// flag to apply anti-aliasing. Using anti-alias - /// option together with :attr:`align_corners` equals "False", interpolation result would match Pillow - /// result for downsampling operation. Supported modes: "bilinear". Default: "False". + /// option together with :attr:`align_corners` equals "False", interpolation + /// result would match Pillow result for downsampling operation. Supported + /// modes: "bilinear". Default: "False". TORCH_ARG(bool, antialias) = false; }; diff --git a/torch/csrc/api/include/torch/nn/options/vision.h b/torch/csrc/api/include/torch/nn/options/vision.h index 1b06ab252b789a..e2c59df0a48fde 100644 --- a/torch/csrc/api/include/torch/nn/options/vision.h +++ b/torch/csrc/api/include/torch/nn/options/vision.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include namespace torch { @@ -14,11 +14,14 @@ namespace functional { /// Example: /// ``` /// namespace F = torch::nn::functional; -/// F::grid_sample(input, grid, F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); +/// F::grid_sample(input, grid, +/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); /// ``` struct TORCH_API GridSampleFuncOptions { typedef c10::variant mode_t; - typedef c10::variant padding_mode_t; + typedef c10:: + variant + padding_mode_t; /// interpolation mode to calculate output values. Default: Bilinear TORCH_ARG(mode_t, mode) = torch::kBilinear; diff --git a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h index 6794d856252b53..46bf2ac6953e7b 100644 --- a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h +++ b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h @@ -5,9 +5,9 @@ #include #include +#include #include #include -#include #include #include @@ -61,23 +61,30 @@ namespace { // in data parallel, and should not be exposed as a user API. struct ReduceAdd : public autograd::Node { explicit ReduceAdd(const at::Device& destination_device) - : destination_device_(destination_device) {}; + : destination_device_(destination_device){}; ~ReduceAdd() override {} autograd::variable_list apply(autograd::variable_list&& inputs) override { - TORCH_CHECK(!torch::autograd::compute_requires_grad(inputs), + TORCH_CHECK( + !torch::autograd::compute_requires_grad(inputs), "ReduceAdd can only be used during the backward pass of data parallel."); Tensor output = torch::zeros_like(inputs[0], {destination_device_}); - for (auto& input: inputs) { - TORCH_CHECK(input.sizes() == inputs[0].sizes(), + for (auto& input : inputs) { + TORCH_CHECK( + input.sizes() == inputs[0].sizes(), "All inputs of ReduceAdd must have the same size, but got ", - input.sizes(), " and ", inputs[0].sizes()); + input.sizes(), + " and ", + inputs[0].sizes()); - TORCH_CHECK(input.dtype() == inputs[0].dtype(), + TORCH_CHECK( + input.dtype() == inputs[0].dtype(), "All inputs of ReduceAdd must have the same dtype, but got ", - input.dtype(), " and ", inputs[0].dtype()); + input.dtype(), + " and ", + inputs[0].dtype()); // TODO: use nccl reduce output.add_(input.to(destination_device_)); @@ -100,7 +107,6 @@ void replicate_grad_edges( const std::shared_ptr& module, const std::vector>& replicas, const std::vector& devices) { - for (auto& parameter : module->named_parameters(/*recurse=*/false)) { auto grad_fn = std::make_shared((*parameter).device()); grad_fn->set_next_edges(autograd::collect_next_edges(*parameter)); @@ -111,7 +117,7 @@ void replicate_grad_edges( } for (auto& buffer : module->named_buffers(/*recurse=*/false)) { - if (buffer.value().requires_grad()){ + if (buffer.value().requires_grad()) { auto grad_fn = std::make_shared((*buffer).device()); grad_fn->set_next_edges(autograd::collect_next_edges(*buffer)); @@ -274,7 +280,9 @@ Tensor data_parallel( auto scattered_inputs = fmap(scatter.apply({std::move(input)})); // Input tensor might not be big enough to scale across all available devices if (scattered_inputs.size() < devices->size()) { - devices->resize(scattered_inputs.size(), Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)); + devices->resize( + scattered_inputs.size(), + Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)); } auto replicas = replicate(module, *devices); diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h index 23b94f8466f254..6fcc2bd35a7735 100644 --- a/torch/csrc/api/include/torch/nn/utils/clip_grad.h +++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -11,11 +11,11 @@ namespace utils { // https://pytorch.org/docs/stable/nn.html?highlight=clip_grad_norm#torch.nn.utils.clip_grad_norm_ // for more details about this module. // -// Difference with the python version: unlike the python version, even when skipping the finiteness -// checks (error_if_nonfinite = false), this function will introduce a device <=> CPU -// synchronization (for devices where that makes sense!) in order to return a CPU-side `double`. -// This C++ version therefore cannot be run fully asynchronously w.r.t. the device of the -// gradients. +// Difference with the python version: unlike the python version, even when +// skipping the finiteness checks (error_if_nonfinite = false), this function +// will introduce a device <=> CPU synchronization (for devices where that makes +// sense!) in order to return a CPU-side `double`. This C++ version therefore +// cannot be run fully asynchronously w.r.t. the device of the gradients. inline double clip_grad_norm_( std::vector parameters, double max_norm, @@ -42,9 +42,11 @@ inline double clip_grad_norm_( for (const auto& param : params_with_grad) { norms.emplace_back(param.grad().data().abs().max()); } - total_norm_tensor = (norms.size() == 1) ? norms[0] : torch::max(torch::stack(norms)); + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::max(torch::stack(norms)); } else if (norm_type == 0) { - total_norm_tensor = torch::full({}, static_cast(params_with_grad.size())); + total_norm_tensor = + torch::full({}, static_cast(params_with_grad.size())); } else { std::vector norms; norms.reserve(params_with_grad.size()); @@ -52,24 +54,30 @@ inline double clip_grad_norm_( for (const auto& param : params_with_grad) { norms.emplace_back(param.grad().data().norm(norm_type)); } - total_norm_tensor = (norms.size() == 1) ? norms[0] : torch::stack(norms).norm(norm_type); + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::stack(norms).norm(norm_type); } - // When possible (ie when skipping the finiteness check), we avoid synchronizing the CPU and the - // gradients' device until the very end to preserve async execution on the device. - // When checking for finite-ness, this optional ensures we only sync once. + // When possible (ie when skipping the finiteness check), we avoid + // synchronizing the CPU and the gradients' device until the very end to + // preserve async execution on the device. When checking for finite-ness, this + // optional ensures we only sync once. c10::optional total_norm = c10::nullopt; if (error_if_nonfinite) { total_norm = total_norm_tensor.item().toDouble(); - TORCH_CHECK(std::isfinite(*total_norm), - "The total norm of order ", norm_type, " for gradients from `parameters` ", - "is non-finite, so it cannot be clipped. To disable this error and scale ", - "the gradients with the non-finite norm anyway, set ", - "`error_if_nonfinite=false`"); + TORCH_CHECK( + std::isfinite(*total_norm), + "The total norm of order ", + norm_type, + " for gradients from `parameters` ", + "is non-finite, so it cannot be clipped. To disable this error and scale ", + "the gradients with the non-finite norm anyway, set ", + "`error_if_nonfinite=false`"); } - auto clip_coef = max_norm / (total_norm_tensor + 1e-6); - auto clip_coef_clamped = torch::clamp(clip_coef, c10::nullopt /* min */, 1.0 /* max */); + auto clip_coef = max_norm / (total_norm_tensor + 1e-6); + auto clip_coef_clamped = + torch::clamp(clip_coef, c10::nullopt /* min */, 1.0 /* max */); for (auto& param : params_with_grad) { param.grad().data().mul_(clip_coef_clamped); } @@ -87,7 +95,8 @@ inline double clip_grad_norm_( double max_norm, double norm_type = 2.0, bool error_if_nonfinite = false) { - return clip_grad_norm_(std::vector(parameters), max_norm, norm_type, error_if_nonfinite); + return clip_grad_norm_( + std::vector(parameters), max_norm, norm_type, error_if_nonfinite); } // A wrapper around clip_grad_norm_ that allows us to call the function with a @@ -108,7 +117,6 @@ inline double clip_grad_norm_( inline void clip_grad_value_( std::vector parameters, double clip_value) { - for (const auto& param : parameters) { if (param.grad().defined()) { param.grad().data().clamp_(-clip_value, clip_value); @@ -118,7 +126,9 @@ inline void clip_grad_value_( // A wrapper around clip_grad_value_ that allows us to call the function with a // braced-init-list of Tensors. -inline void clip_grad_value_(std::initializer_list parameters, double clip_value) { +inline void clip_grad_value_( + std::initializer_list parameters, + double clip_value) { clip_grad_value_(std::vector(parameters), clip_value); } diff --git a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h index 8b730e0f0c00aa..e08bb6228389a7 100644 --- a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h +++ b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -11,22 +11,24 @@ namespace utils { // in the same device. Currently, the conversion between model parameters // and single vector form is not supported for multiple allocations, // e.g. parameters in different GPUs, or mixture of CPU/GPU. -inline c10::optional _check_param_device(const torch::Tensor& param, c10::optional old_param_device) { +inline c10::optional _check_param_device( + const torch::Tensor& param, + c10::optional old_param_device) { // Meet the first parameter if (old_param_device == c10::nullopt) { old_param_device = param.is_cuda() ? param.get_device() : -1; - } - else { + } else { bool warn = false; if (param.is_cuda()) { // Check if in same GPU warn = (param.get_device() != old_param_device.value()); - } - else { // Check if in CPU + } else { // Check if in CPU warn = (old_param_device.value() != -1); } if (warn) { - TORCH_CHECK(false, "Found two parameters on different devices, ", - "this is currently not supported."); + TORCH_CHECK( + false, + "Found two parameters on different devices, ", + "this is currently not supported."); } } @@ -34,7 +36,8 @@ inline c10::optional _check_param_device(const torch::Tensor& param, c1 } // Convert parameters to one vector -inline torch::Tensor parameters_to_vector(const std::vector& parameters) { +inline torch::Tensor parameters_to_vector( + const std::vector& parameters) { c10::optional param_device; std::vector vec; @@ -51,7 +54,9 @@ inline torch::Tensor parameters_to_vector(const std::vector& para } // Convert one vector to the parameters -inline void vector_to_parameters(const torch::Tensor& vec, std::vector parameters) { +inline void vector_to_parameters( + const torch::Tensor& vec, + std::vector parameters) { // Flag for the device where the parameter is located c10::optional param_device; @@ -66,7 +71,8 @@ inline void vector_to_parameters(const torch::Tensor& vec, std::vector pad_packed_sequence( if (total_length.has_value()) { int64_t total_length_val = total_length.value(); TORCH_CHECK( - total_length_val >= max_seq_length, - "Expected total_length to be at least the length " - "of the longest sequence in input, but got " - "total_length=", total_length_val, " and max sequence length being ", max_seq_length); + total_length_val >= max_seq_length, + "Expected total_length to be at least the length " + "of the longest sequence in input, but got " + "total_length=", + total_length_val, + " and max sequence length being ", + max_seq_length); max_seq_length = total_length_val; } Tensor padded_output, lengths; std::tie(padded_output, lengths) = torch::_pad_packed_sequence( - sequence.data(), sequence.batch_sizes(), batch_first, padding_value, max_seq_length); + sequence.data(), + sequence.batch_sizes(), + batch_first, + padding_value, + max_seq_length); const Tensor& unsorted_indices = sequence.unsorted_indices(); if (unsorted_indices.defined()) { int64_t batch_dim = batch_first ? 0 : 1; - return std::make_tuple(padded_output.index_select(batch_dim, unsorted_indices), lengths.index({unsorted_indices})); + return std::make_tuple( + padded_output.index_select(batch_dim, unsorted_indices), + lengths.index({unsorted_indices})); } return std::make_tuple(padded_output, lengths); } @@ -269,7 +292,8 @@ inline std::tuple pad_packed_sequence( /// /// Arguments: /// sequences (torch::ArrayRef): list of variable length sequences. -/// batch_first (bool, optional): output will be in ``B x T x *`` if true, or in +/// batch_first (bool, optional): output will be in ``B x T x *`` if true, +/// or in /// ``T x B x *`` otherwise /// padding_value (double, optional): value for padded elements. Default: 0. /// @@ -290,24 +314,31 @@ inline Tensor pad_sequence( /// including zero. /// /// For unsorted sequences, use `enforce_sorted = false`. If ``enforce_sorted`` -/// is ``true``, the sequences should be sorted in the order of decreasing length. +/// is ``true``, the sequences should be sorted in the order of decreasing +/// length. /// /// /// Arguments: -/// sequences (torch::ArrayRef): A list of sequences of decreasing length. -/// enforce_sorted (bool, optional): if ``true``, checks that the input +/// sequences (torch::ArrayRef): A list of sequences of decreasing +/// length. enforce_sorted (bool, optional): if ``true``, checks that the +/// input /// contains sequences sorted by length in a decreasing order. If /// ``false``, this condition is not checked. Default: ``true``. /// /// Returns: /// a `PackedSequence` object -inline PackedSequence pack_sequence(ArrayRef sequences, bool enforce_sorted = true) { +inline PackedSequence pack_sequence( + ArrayRef sequences, + bool enforce_sorted = true) { Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64); for (const auto i : c10::irange(sequences.size())) { lengths[i] = sequences[i].size(0); } return pack_padded_sequence( - at::pad_sequence(sequences), lengths, /*batch_first=*/false, /*enforce_sorted=*/enforce_sorted); + at::pad_sequence(sequences), + lengths, + /*batch_first=*/false, + /*enforce_sorted=*/enforce_sorted); } } // namespace rnn diff --git a/torch/csrc/api/include/torch/optim/adagrad.h b/torch/csrc/api/include/torch/optim/adagrad.h index 09c1de9f614e96..aa520224c52655 100644 --- a/torch/csrc/api/include/torch/optim/adagrad.h +++ b/torch/csrc/api/include/torch/optim/adagrad.h @@ -19,49 +19,73 @@ class InputArchive; namespace torch { namespace optim { -struct TORCH_API AdagradOptions : public OptimizerCloneableOptions { +struct TORCH_API AdagradOptions + : public OptimizerCloneableOptions { AdagradOptions(double lr = 1e-2); TORCH_ARG(double, lr) = 1e-2; TORCH_ARG(double, lr_decay) = 0; TORCH_ARG(double, weight_decay) = 0; TORCH_ARG(double, initial_accumulator_value) = 0; TORCH_ARG(double, eps) = 1e-10; -public: + + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdagradOptions& lhs, const AdagradOptions& rhs); + TORCH_API friend bool operator==( + const AdagradOptions& lhs, + const AdagradOptions& rhs); ~AdagradOptions() override = default; double get_lr() const override; void set_lr(const double lr) override; }; -struct TORCH_API AdagradParamState : public OptimizerCloneableParamState { +struct TORCH_API AdagradParamState + : public OptimizerCloneableParamState { TORCH_ARG(torch::Tensor, sum); TORCH_ARG(int64_t, step) = 0; -public: + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdagradParamState& lhs, const AdagradParamState& rhs); + TORCH_API friend bool operator==( + const AdagradParamState& lhs, + const AdagradParamState& rhs); ~AdagradParamState() override = default; }; class TORCH_API Adagrad : public Optimizer { public: - explicit Adagrad(std::vector param_groups, - AdagradOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { + explicit Adagrad( + std::vector param_groups, + AdagradOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); - TORCH_CHECK(defaults.lr_decay() >= 0, "Invalid lr_decay value: ", defaults.lr_decay()); - TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay()); - TORCH_CHECK(defaults.initial_accumulator_value() >= 0, "Invalid initial_accumulator_value value: ", defaults.initial_accumulator_value()); + TORCH_CHECK( + defaults.lr_decay() >= 0, + "Invalid lr_decay value: ", + defaults.lr_decay()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.initial_accumulator_value() >= 0, + "Invalid initial_accumulator_value value: ", + defaults.initial_accumulator_value()); TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); for (const auto& group : param_groups_) { for (const auto& p : group.params()) { auto state = std::make_unique(); state->step(0); - state->sum(torch::full_like(p.data(), defaults.initial_accumulator_value(), at::MemoryFormat::Preserve)); - state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state); + state->sum(torch::full_like( + p.data(), + defaults.initial_accumulator_value(), + at::MemoryFormat::Preserve)); + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = + std::move(state); } } } @@ -69,7 +93,8 @@ class TORCH_API Adagrad : public Optimizer { explicit Adagrad( std::vector params, // NOLINTNEXTLINE(performance-move-const-arg) - AdagradOptions defaults = {}) : Adagrad({std::move(OptimizerParamGroup(params))}, defaults) {} + AdagradOptions defaults = {}) + : Adagrad({std::move(OptimizerParamGroup(params))}, defaults) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; diff --git a/torch/csrc/api/include/torch/optim/adam.h b/torch/csrc/api/include/torch/optim/adam.h index eec6304421b5d7..c4360a2a0838b2 100644 --- a/torch/csrc/api/include/torch/optim/adam.h +++ b/torch/csrc/api/include/torch/optim/adam.h @@ -25,43 +25,63 @@ struct TORCH_API AdamOptions : public OptimizerCloneableOptions { TORCH_ARG(double, eps) = 1e-8; TORCH_ARG(double, weight_decay) = 0; TORCH_ARG(bool, amsgrad) = false; -public: + + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdamOptions& lhs, const AdamOptions& rhs); + TORCH_API friend bool operator==( + const AdamOptions& lhs, + const AdamOptions& rhs); ~AdamOptions() override = default; double get_lr() const override; void set_lr(const double lr) override; }; -struct TORCH_API AdamParamState : public OptimizerCloneableParamState { +struct TORCH_API AdamParamState + : public OptimizerCloneableParamState { TORCH_ARG(int64_t, step) = 0; TORCH_ARG(torch::Tensor, exp_avg); TORCH_ARG(torch::Tensor, exp_avg_sq); TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; -public: + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdamParamState& lhs, const AdamParamState& rhs); + TORCH_API friend bool operator==( + const AdamParamState& lhs, + const AdamParamState& rhs); ~AdamParamState() override = default; }; class TORCH_API Adam : public Optimizer { public: - explicit Adam(std::vector param_groups, - AdamOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { - TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); - TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); - auto betas = defaults.betas(); - TORCH_CHECK(0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, "Invalid beta parameter at index 0: ", std::get<0>(betas)); - TORCH_CHECK(0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, "Invalid beta parameter at index 1: ", std::get<1>(betas)); - TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay()); - } - explicit Adam( - std::vector params, - // NOLINTNEXTLINE(performance-move-const-arg) - AdamOptions defaults = {}) : Adam({std::move(OptimizerParamGroup(params))}, defaults) {} + explicit Adam( + std::vector param_groups, + AdamOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit Adam( + std::vector params, + // NOLINTNEXTLINE(performance-move-const-arg) + AdamOptions defaults = {}) + : Adam({std::move(OptimizerParamGroup(params))}, defaults) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; diff --git a/torch/csrc/api/include/torch/optim/adamw.h b/torch/csrc/api/include/torch/optim/adamw.h index aa6430db4417d0..8b7bcf3130760f 100644 --- a/torch/csrc/api/include/torch/optim/adamw.h +++ b/torch/csrc/api/include/torch/optim/adamw.h @@ -25,43 +25,63 @@ struct TORCH_API AdamWOptions : public OptimizerCloneableOptions { TORCH_ARG(double, eps) = 1e-8; TORCH_ARG(double, weight_decay) = 1e-2; TORCH_ARG(bool, amsgrad) = false; -public: + + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs); + TORCH_API friend bool operator==( + const AdamWOptions& lhs, + const AdamWOptions& rhs); ~AdamWOptions() override = default; double get_lr() const override; void set_lr(const double lr) override; }; -struct TORCH_API AdamWParamState : public OptimizerCloneableParamState { +struct TORCH_API AdamWParamState + : public OptimizerCloneableParamState { TORCH_ARG(int64_t, step) = 0; TORCH_ARG(torch::Tensor, exp_avg); TORCH_ARG(torch::Tensor, exp_avg_sq); TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; -public: + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs); + TORCH_API friend bool operator==( + const AdamWParamState& lhs, + const AdamWParamState& rhs); ~AdamWParamState() override = default; }; class TORCH_API AdamW : public Optimizer { public: - explicit AdamW(std::vector param_groups, - AdamWOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { - TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); - TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); - auto betas = defaults.betas(); - TORCH_CHECK(0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, "Invalid beta parameter at index 0: ", std::get<0>(betas)); - TORCH_CHECK(0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, "Invalid beta parameter at index 1: ", std::get<1>(betas)); - TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay()); - } - explicit AdamW( - std::vector params, - // NOLINTNEXTLINE(performance-move-const-arg) - AdamWOptions defaults = {}) : AdamW({std::move(OptimizerParamGroup(params))}, defaults) {} + explicit AdamW( + std::vector param_groups, + AdamWOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit AdamW( + std::vector params, + // NOLINTNEXTLINE(performance-move-const-arg) + AdamWOptions defaults = {}) + : AdamW({std::move(OptimizerParamGroup(params))}, defaults) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index 2b189b96ca4d4b..1906f4d4d19f78 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -22,17 +22,21 @@ struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { TORCH_ARG(double, tolerance_change) = 1e-9; TORCH_ARG(int64_t, history_size) = 100; TORCH_ARG(c10::optional, line_search_fn) = c10::nullopt; + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs); + TORCH_API friend bool operator==( + const LBFGSOptions& lhs, + const LBFGSOptions& rhs); ~LBFGSOptions() override = default; double get_lr() const override; void set_lr(const double lr) override; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -struct TORCH_API LBFGSParamState : public OptimizerCloneableParamState { +struct TORCH_API LBFGSParamState + : public OptimizerCloneableParamState { TORCH_ARG(int64_t, func_evals) = 0; TORCH_ARG(int64_t, n_iter) = 0; TORCH_ARG(double, t); @@ -48,26 +52,36 @@ struct TORCH_API LBFGSParamState : public OptimizerCloneableParamState param_groups, - LBFGSOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { - TORCH_CHECK(param_groups_.size() == 1, "LBFGS doesn't support per-parameter options (parameter groups)"); - if (defaults.max_eval() == c10::nullopt) { - auto max_eval_val = (defaults.max_iter() * 5) / 4; - static_cast(param_groups_[0].options()).max_eval(max_eval_val); - static_cast(*defaults_.get()).max_eval(max_eval_val); - } - _numel_cache = c10::nullopt; - } - explicit LBFGS( - std::vector params, - // NOLINTNEXTLINE(performance-move-const-arg) - LBFGSOptions defaults = {}) : LBFGS({std::move(OptimizerParamGroup(params))}, defaults) {} + explicit LBFGS( + std::vector param_groups, + LBFGSOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK( + param_groups_.size() == 1, + "LBFGS doesn't support per-parameter options (parameter groups)"); + if (defaults.max_eval() == c10::nullopt) { + auto max_eval_val = (defaults.max_iter() * 5) / 4; + static_cast(param_groups_[0].options()) + .max_eval(max_eval_val); + static_cast(*defaults_.get()).max_eval(max_eval_val); + } + _numel_cache = c10::nullopt; + } + explicit LBFGS( + std::vector params, + // NOLINTNEXTLINE(performance-move-const-arg) + LBFGSOptions defaults = {}) + : LBFGS({std::move(OptimizerParamGroup(params))}, defaults) {} Tensor step(LossClosure closure) override; void save(serialize::OutputArchive& archive) const override; @@ -79,7 +93,10 @@ class TORCH_API LBFGS : public Optimizer { Tensor _gather_flat_grad(); void _add_grad(const double step_size, const Tensor& update); std::tuple _directional_evaluate( - const LossClosure& closure, const std::vector& x, double t, const Tensor& d); + const LossClosure& closure, + const std::vector& x, + double t, + const Tensor& d); void _set_param(const std::vector& params_data); std::vector _clone_param(); diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index 010cc15e38a1c3..7a6ef55c956fdf 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -1,11 +1,11 @@ #pragma once #include -#include #include +#include -#include #include +#include #include #include @@ -59,19 +59,29 @@ class TORCH_API OptimizerOptions { template class OptimizerCloneableOptions : public OptimizerOptions { -private: + private: std::unique_ptr clone() const override { return std::make_unique(static_cast(*this)); } }; -/// Stores parameters in the param_group and stores a pointer to the OptimizerOptions +/// Stores parameters in the param_group and stores a pointer to the +/// OptimizerOptions class TORCH_API OptimizerParamGroup { public: - // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to be copy-constructible. - OptimizerParamGroup(const OptimizerParamGroup& param_group) : params_(param_group.params()), options_(param_group.has_options() ? param_group.options().clone() : nullptr) {} - OptimizerParamGroup(std::vector params) : params_(std::move(params)) {} - OptimizerParamGroup(std::vector params, std::unique_ptr options) : params_(std::move(params)), options_(std::move(options)) {} + // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to + // be copy-constructible. + OptimizerParamGroup(const OptimizerParamGroup& param_group) + : params_(param_group.params()), + options_( + param_group.has_options() ? param_group.options().clone() + : nullptr) {} + OptimizerParamGroup(std::vector params) + : params_(std::move(params)) {} + OptimizerParamGroup( + std::vector params, + std::unique_ptr options) + : params_(std::move(params)), options_(std::move(options)) {} bool has_options() const; OptimizerOptions& options(); @@ -94,7 +104,10 @@ class TORCH_API Optimizer { Optimizer(const Optimizer& optimizer) = delete; Optimizer(Optimizer&& optimizer) = default; - explicit Optimizer(std::vector param_groups, std::unique_ptr defaults) : defaults_(std::move(defaults)) { + explicit Optimizer( + std::vector param_groups, + std::unique_ptr defaults) + : defaults_(std::move(defaults)) { for (const auto& param_group : param_groups) { add_param_group(param_group); } @@ -102,7 +115,12 @@ class TORCH_API Optimizer { /// Constructs the `Optimizer` from a vector of parameters. // NOLINTNEXTLINE(performance-move-const-arg) - explicit Optimizer(std::vector parameters, std::unique_ptr defaults) : Optimizer({std::move(OptimizerParamGroup(parameters))}, std::move(defaults)) {}; + explicit Optimizer( + std::vector parameters, + std::unique_ptr defaults) + : Optimizer( + {std::move(OptimizerParamGroup(parameters))}, + std::move(defaults)){}; /// Adds the given param_group to the optimizer's param_group list. void add_param_group(const OptimizerParamGroup& param_group); @@ -119,10 +137,12 @@ class TORCH_API Optimizer { /// Zeros out the gradients of all parameters. void zero_grad(); - /// Provides a const reference to the parameters in the first param_group this optimizer holds. + /// Provides a const reference to the parameters in the first param_group this + /// optimizer holds. const std::vector& parameters() const noexcept; - /// Provides a reference to the parameters in the first param_group this optimizer holds. + /// Provides a reference to the parameters in the first param_group this + /// optimizer holds. std::vector& parameters() noexcept; /// Returns the number of parameters referenced by the optimizer. @@ -139,10 +159,12 @@ class TORCH_API Optimizer { const std::vector& param_groups() const noexcept; /// Provides a reference to the state this optimizer holds - ska::flat_hash_map>& state() noexcept; + ska::flat_hash_map>& + state() noexcept; /// Provides a const reference to the state this optimizer holds - const ska::flat_hash_map>& state() const noexcept; + const ska::flat_hash_map>& + state() const noexcept; /// Serializes the optimizer state into the given `archive`. virtual void save(serialize::OutputArchive& archive) const; @@ -151,12 +173,12 @@ class TORCH_API Optimizer { virtual void load(serialize::InputArchive& archive); protected: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::vector param_groups_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - ska::flat_hash_map> state_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr defaults_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector param_groups_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + ska::flat_hash_map> state_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unique_ptr defaults_; }; /* How do we decide whether to serialize undefined tensors or @@ -164,18 +186,18 @@ class TORCH_API Optimizer { Answer: we strictly follow the behavior of Python API. To be more specific: For optimizer options: -a) For undefined tensor: currently no tensor is used as an options argument in Python API, - so we don't need to worry about it now. -b) For c10::nullopt value: we serialize c10::nullopt values into the output archive, - to follow the exact same behavior as Python API. +a) For undefined tensor: currently no tensor is used as an options argument in +Python API, so we don't need to worry about it now. b) For c10::nullopt value: +we serialize c10::nullopt values into the output archive, to follow the exact +same behavior as Python API. For optimizer param state: -a) For undefined tensor: in param state, undefined tensor in C++ impl is equivalent to - missing key in Python impl. Since we don't serialize missing keys in Python API, - we skip undefined tensors when serializing the param state. -b) For c10::nullopt value: in param state, c10::nullopt value in C++ impl is equivalent to - missing key in Python impl. Since we don't serialize missing keys in Python API, - we skip c10::nullopt values when serializing the param state. */ +a) For undefined tensor: in param state, undefined tensor in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip undefined tensors when serializing the param state. b) +For c10::nullopt value: in param state, c10::nullopt value in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip c10::nullopt values when serializing the param state. */ /// Serializes an `Optimizer` into an `OutputArchive`. TORCH_API serialize::OutputArchive& operator<<( diff --git a/torch/csrc/api/include/torch/optim/rmsprop.h b/torch/csrc/api/include/torch/optim/rmsprop.h index 47be24914d4855..fa1197aae293e7 100644 --- a/torch/csrc/api/include/torch/optim/rmsprop.h +++ b/torch/csrc/api/include/torch/optim/rmsprop.h @@ -21,7 +21,8 @@ class InputArchive; namespace torch { namespace optim { -struct TORCH_API RMSpropOptions : public OptimizerCloneableOptions { +struct TORCH_API RMSpropOptions + : public OptimizerCloneableOptions { RMSpropOptions(double lr = 1e-2); TORCH_ARG(double, lr) = 1e-2; TORCH_ARG(double, alpha) = 0.99; @@ -33,13 +34,16 @@ struct TORCH_API RMSpropOptions : public OptimizerCloneableOptions { +struct TORCH_API RMSpropParamState + : public OptimizerCloneableParamState { TORCH_ARG(int64_t, step) = 0; TORCH_ARG(torch::Tensor, square_avg); TORCH_ARG(torch::Tensor, momentum_buffer) = {}; @@ -48,24 +52,39 @@ struct TORCH_API RMSpropParamState : public OptimizerCloneableParamState param_groups, - RMSpropOptions defaults = {}) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { + explicit RMSprop( + std::vector param_groups, + RMSpropOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); - TORCH_CHECK(defaults.momentum() >= 0, "Invalid momentum value: ", defaults.momentum()); - TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay()); - TORCH_CHECK(defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha()); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha()); } - explicit RMSprop(std::vector params, + explicit RMSprop( + std::vector params, // NOLINTNEXTLINE(performance-move-const-arg) - RMSpropOptions defaults = {}) : RMSprop({std::move(OptimizerParamGroup(params))}, defaults) {} + RMSpropOptions defaults = {}) + : RMSprop({std::move(OptimizerParamGroup(params))}, defaults) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; diff --git a/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h b/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h index 21254ca49fee55..4a24c10de252c3 100644 --- a/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h +++ b/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h @@ -8,34 +8,33 @@ namespace torch { namespace optim { class TORCH_API LRScheduler { -public: - - //This class needs to take a reference of an optimizer from outside such that it can - //modify its learning rates; due to this the lifetime of said optimizer must be maintained + public: + // This class needs to take a reference of an optimizer from outside such that + // it can modify its learning rates; due to this the lifetime of said + // optimizer must be maintained LRScheduler(torch::optim::Optimizer& optimizer); virtual ~LRScheduler() = default; void step(); -protected: - //A vector of learning rates is calculated and returned from the specific subclass. - //A vector is returned with each element being a separate learning rate for each - //param group - although the normal use case would be to return a vector of identical - //elements. + protected: + // A vector of learning rates is calculated and returned from the specific + // subclass. A vector is returned with each element being a separate learning + // rate for each param group - although the normal use case would be to return + // a vector of identical elements. virtual std::vector get_lrs() = 0; - //Get current learning rates from the optimizer + // Get current learning rates from the optimizer std::vector get_current_lrs() const; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) unsigned step_count_; -private: + private: void set_optimizer_lrs(const std::vector& learning_rates); torch::optim::Optimizer& optimizer_; - }; } // namespace optim -} // namspace torch +} // namespace torch diff --git a/torch/csrc/api/include/torch/optim/schedulers/step_lr.h b/torch/csrc/api/include/torch/optim/schedulers/step_lr.h index 129da878c7f02c..289bb4bd84e54e 100644 --- a/torch/csrc/api/include/torch/optim/schedulers/step_lr.h +++ b/torch/csrc/api/include/torch/optim/schedulers/step_lr.h @@ -6,18 +6,17 @@ namespace torch { namespace optim { class TORCH_API StepLR : public LRScheduler { -public: + public: + StepLR( + torch::optim::Optimizer& optimizer, + const unsigned step_size, + const double gamma = 0.1); - StepLR(torch::optim::Optimizer& optimizer, - const unsigned step_size, - const double gamma = 0.1); - -private: + private: std::vector get_lrs() override; const unsigned step_size_; const double gamma_; - }; } // namespace optim } // namespace torch diff --git a/torch/csrc/api/include/torch/optim/serialize.h b/torch/csrc/api/include/torch/optim/serialize.h index bcbb432b06245f..04edde361f4e25 100644 --- a/torch/csrc/api/include/torch/optim/serialize.h +++ b/torch/csrc/api/include/torch/optim/serialize.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include #include #include @@ -13,93 +13,108 @@ namespace torch { namespace optim { namespace detail { - // Utility function to save state - template - void serialize( - serialize::OutputArchive& archive, - const ska::flat_hash_map>& state) { - for (const auto& item : state) { - serialize::OutputArchive param_state_archive(archive.compilation_unit()); - std::string tensorimpl_key = item.first; - const DerivedOptimizerParamState& curr_state = static_cast(*(item.second.get())); - curr_state.serialize(param_state_archive); - archive.write(tensorimpl_key, param_state_archive); - } +// Utility function to save state +template +void serialize( + serialize::OutputArchive& archive, + const ska::flat_hash_map>& + state) { + for (const auto& item : state) { + serialize::OutputArchive param_state_archive(archive.compilation_unit()); + std::string tensorimpl_key = item.first; + const DerivedOptimizerParamState& curr_state = + static_cast(*(item.second.get())); + curr_state.serialize(param_state_archive); + archive.write(tensorimpl_key, param_state_archive); } +} - // Utility function to load state - template - void serialize( - serialize::InputArchive& archive, - ska::flat_hash_map>& state) { - std::vector tensorimpl_keys = archive.keys(); - for (const std::string& tensorimpl_key : tensorimpl_keys) { - serialize::InputArchive param_state_archive; - archive.read(tensorimpl_key, param_state_archive); - DerivedOptimizerParamState param_state; - param_state.serialize(param_state_archive); - state[tensorimpl_key] = std::make_unique(param_state); - } +// Utility function to load state +template +void serialize( + serialize::InputArchive& archive, + ska::flat_hash_map>& + state) { + std::vector tensorimpl_keys = archive.keys(); + for (const std::string& tensorimpl_key : tensorimpl_keys) { + serialize::InputArchive param_state_archive; + archive.read(tensorimpl_key, param_state_archive); + DerivedOptimizerParamState param_state; + param_state.serialize(param_state_archive); + state[tensorimpl_key] = + std::make_unique(param_state); } +} - // Utility function to save param_groups - template - void serialize( - serialize::OutputArchive& archive, - const std::vector& param_groups) { - archive.write("param_groups/size", torch::tensor(static_cast(param_groups.size()))); - for (const auto i : c10::irange(param_groups.size())) { - serialize::OutputArchive param_group_archive(archive.compilation_unit()); - std::vector params = param_groups[i].params(); +// Utility function to save param_groups +template +void serialize( + serialize::OutputArchive& archive, + const std::vector& param_groups) { + archive.write( + "param_groups/size", + torch::tensor(static_cast(param_groups.size()))); + for (const auto i : c10::irange(param_groups.size())) { + serialize::OutputArchive param_group_archive(archive.compilation_unit()); + std::vector params = param_groups[i].params(); + param_group_archive.write( + "params/size", torch::tensor(static_cast(params.size()))); + for (const auto index : c10::irange(params.size())) { param_group_archive.write( - "params/size", torch::tensor(static_cast(params.size()))); - for (const auto index : c10::irange(params.size())) { - param_group_archive.write( - "params/" + c10::guts::to_string(index), IValue(c10::guts::to_string(params[index].unsafeGetTensorImpl()))); - } - const DerivedOptimizerParamOptions& param_group_options = static_cast(param_groups[i].options()); - serialize::OutputArchive param_group_options_archive(param_group_archive.compilation_unit()); - param_group_options.serialize(param_group_options_archive); - param_group_archive.write("options", param_group_options_archive); - archive.write("param_groups/" + c10::guts::to_string(i), param_group_archive); + "params/" + c10::guts::to_string(index), + IValue(c10::guts::to_string(params[index].unsafeGetTensorImpl()))); } + const DerivedOptimizerParamOptions& param_group_options = + static_cast( + param_groups[i].options()); + serialize::OutputArchive param_group_options_archive( + param_group_archive.compilation_unit()); + param_group_options.serialize(param_group_options_archive); + param_group_archive.write("options", param_group_options_archive); + archive.write( + "param_groups/" + c10::guts::to_string(i), param_group_archive); } +} - // Utility function to load param_groups - // We take as input vector of pair of string and unique_ptr to optimizer options so that we can retain the state - // for each param by using the old tensor impl keys (saved during serialization) and map the new tensor impl keys to - // the correct state for each param - template - void serialize( - serialize::InputArchive& archive, - std::vector, std::unique_ptr>>& param_groups) { - torch::Tensor param_groups_size_tensor; - archive.read("param_groups/size", param_groups_size_tensor); - const int64_t param_groups_size = param_groups_size_tensor.item(); - for (const auto i : c10::irange(param_groups_size)) { - serialize::InputArchive param_group_archive; - archive.read("param_groups/" + c10::guts::to_string(i), param_group_archive); - torch::Tensor size_tensor; - param_group_archive.read("params/size", size_tensor); - const int64_t size = size_tensor.item(); - std::vector params; - for (const auto index : c10::irange(size)) { - IValue ivalue; - param_group_archive.read( - "params/" + c10::to_string(index), ivalue); - std::string element = ivalue.toStringRef(); - params.emplace_back(element); - } - serialize::InputArchive param_group_options_archive; - param_group_archive.read("options", param_group_options_archive); - DerivedOptimizerParamOptions param_group_options(0); - param_group_options.serialize(param_group_options_archive); - param_groups.emplace_back(std::make_pair(params, std::make_unique(param_group_options))); +// Utility function to load param_groups +// We take as input vector of pair of string and unique_ptr to optimizer options +// so that we can retain the state for each param by using the old tensor impl +// keys (saved during serialization) and map the new tensor impl keys to the +// correct state for each param +template +void serialize( + serialize::InputArchive& archive, + std::vector< + std::pair, std::unique_ptr>>& + param_groups) { + torch::Tensor param_groups_size_tensor; + archive.read("param_groups/size", param_groups_size_tensor); + const int64_t param_groups_size = param_groups_size_tensor.item(); + for (const auto i : c10::irange(param_groups_size)) { + serialize::InputArchive param_group_archive; + archive.read( + "param_groups/" + c10::guts::to_string(i), param_group_archive); + torch::Tensor size_tensor; + param_group_archive.read("params/size", size_tensor); + const int64_t size = size_tensor.item(); + std::vector params; + for (const auto index : c10::irange(size)) { + IValue ivalue; + param_group_archive.read("params/" + c10::to_string(index), ivalue); + std::string element = ivalue.toStringRef(); + params.emplace_back(element); } + serialize::InputArchive param_group_options_archive; + param_group_archive.read("options", param_group_options_archive); + DerivedOptimizerParamOptions param_group_options(0); + param_group_options.serialize(param_group_options_archive); + param_groups.emplace_back(std::make_pair( + params, + std::make_unique(param_group_options))); } +} } // namespace detail - // Note: These functions are all called `serialize()` so they can be called // inside a template where the archive type is a template type and can thus be // passed such that the appropriate overload is selected. @@ -129,52 +144,63 @@ void serialize( std::vector& steps); // Utility function to save state and param_groups -template -void serialize( - serialize::OutputArchive& archive, - const Optimizer& optimizer) { +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) { archive.write("pytorch_version", IValue("1.5.0")); serialize::OutputArchive state_archive(archive.compilation_unit()); - detail::serialize(state_archive, optimizer.state()); + detail::serialize( + state_archive, optimizer.state()); archive.write("state", state_archive); serialize::OutputArchive param_groups_archive(archive.compilation_unit()); - detail::serialize(param_groups_archive, optimizer.param_groups()); + detail::serialize( + param_groups_archive, optimizer.param_groups()); archive.write("param_groups", param_groups_archive); } // Utility function to load state and param_groups and update state -template -void serialize( - serialize::InputArchive& archive, - Optimizer& optimizer) { - - IValue pytorch_version; - archive.read("pytorch_version", pytorch_version); - TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0"); - serialize::InputArchive state_archive; - archive.read("state", state_archive); - ska::flat_hash_map> saved_state; - detail::serialize(state_archive, saved_state); +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::InputArchive& archive, Optimizer& optimizer) { + IValue pytorch_version; + archive.read("pytorch_version", pytorch_version); + TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0"); + serialize::InputArchive state_archive; + archive.read("state", state_archive); + ska::flat_hash_map> + saved_state; + detail::serialize(state_archive, saved_state); - serialize::InputArchive param_groups_archive; - archive.read("param_groups", param_groups_archive); - std::vector, std::unique_ptr>> saved_param_groups; - detail::serialize(param_groups_archive, saved_param_groups); + serialize::InputArchive param_groups_archive; + archive.read("param_groups", param_groups_archive); + std::vector< + std::pair, std::unique_ptr>> + saved_param_groups; + detail::serialize( + param_groups_archive, saved_param_groups); - // update state - TORCH_CHECK(saved_param_groups.size() == optimizer.param_groups().size(), "loaded state dict has a different number of parameter groups"); - for (const auto i : c10::irange(saved_param_groups.size())) { - std::vector param_group_old_keys = saved_param_groups[i].first; - std::vector params = optimizer.param_groups()[i].params(); - TORCH_CHECK(param_group_old_keys.size() == params.size(), "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group"); + // update state + TORCH_CHECK( + saved_param_groups.size() == optimizer.param_groups().size(), + "loaded state dict has a different number of parameter groups"); + for (const auto i : c10::irange(saved_param_groups.size())) { + std::vector param_group_old_keys = saved_param_groups[i].first; + std::vector params = optimizer.param_groups()[i].params(); + TORCH_CHECK( + param_group_old_keys.size() == params.size(), + "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group"); - for (const auto idx : c10::irange(params.size())) { - if(saved_state.find(param_group_old_keys[idx]) != saved_state.end()) { - optimizer.state()[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = std::move(saved_state[param_group_old_keys[idx]]); - } + for (const auto idx : c10::irange(params.size())) { + if (saved_state.find(param_group_old_keys[idx]) != saved_state.end()) { + optimizer + .state()[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = + std::move(saved_state[param_group_old_keys[idx]]); } } + } } /// Utility function to save a vector of buffers. @@ -230,47 +256,55 @@ std::deque list_to_deque(const c10::List& list) { #define _TORCH_OPTIM_SERIALIZE(name) \ torch::optim::serialize(archive, #name, self.name) -#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName) \ - torch::optim::serialize(archive, self) +#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName) \ + torch::optim::serialize( \ + archive, self) -#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name) { \ - auto ivalue = torch::IValue(name()); \ - /* do not serialize if name is an undefined tensor*/ \ - if (!(ivalue.isTensor() && ivalue.unsafeToTensorImpl() == at::UndefinedTensorImpl::singleton())) { \ - archive.write(#name, ivalue); \ - } \ -} +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name) \ + { \ + auto ivalue = torch::IValue(name()); \ + /* do not serialize if name is an undefined tensor*/ \ + if (!(ivalue.isTensor() && \ + ivalue.unsafeToTensorImpl() == \ + at::UndefinedTensorImpl::singleton())) { \ + archive.write(#name, ivalue); \ + } \ + } -#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name) { \ - c10::IValue ivalue = torch::IValue(deque_to_list(name())); \ - archive.write(#name, ivalue); \ -} +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name) \ + { \ + c10::IValue ivalue = torch::IValue(deque_to_list(name())); \ + archive.write(#name, ivalue); \ + } -#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) { \ - c10::IValue ivalue; \ - bool exists = archive.try_read(#name, ivalue); \ - if (exists) {\ - name(ivalue.to()); \ - } else { \ - bool is_tensor_type = std::is_base_of::value; \ - TORCH_INTERNAL_ASSERT(is_tensor_type); \ - } \ -} +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.to()); \ + } else { \ + bool is_tensor_type = std::is_base_of::value; \ + TORCH_INTERNAL_ASSERT(is_tensor_type); \ + } \ + } -#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) { \ - c10::IValue ivalue; \ - bool exists = archive.try_read(#name, ivalue); \ - if (exists) { \ - name(ivalue.toOptional()); \ - } \ -} +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.toOptional()); \ + } \ + } -#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) { \ - c10::IValue ivalue; \ - archive.read(#name, ivalue); \ - auto list = ivalue.to>(); \ - name(list_to_deque(list)); \ -} +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) \ + { \ + c10::IValue ivalue; \ + archive.read(#name, ivalue); \ + auto list = ivalue.to>(); \ + name(list_to_deque(list)); \ + } } // namespace optim } // namespace torch diff --git a/torch/csrc/api/include/torch/optim/sgd.h b/torch/csrc/api/include/torch/optim/sgd.h index 4253b25428b454..2533e91d482b53 100644 --- a/torch/csrc/api/include/torch/optim/sgd.h +++ b/torch/csrc/api/include/torch/optim/sgd.h @@ -27,38 +27,59 @@ struct TORCH_API SGDOptions : public OptimizerCloneableOptions { TORCH_ARG(double, dampening) = 0; TORCH_ARG(double, weight_decay) = 0; TORCH_ARG(bool, nesterov) = false; -public: + + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const SGDOptions& lhs, const SGDOptions& rhs); + TORCH_API friend bool operator==( + const SGDOptions& lhs, + const SGDOptions& rhs); ~SGDOptions() override = default; double get_lr() const override; void set_lr(const double lr) override; }; -struct TORCH_API SGDParamState : public OptimizerCloneableParamState { +struct TORCH_API SGDParamState + : public OptimizerCloneableParamState { TORCH_ARG(torch::Tensor, momentum_buffer); -public: + public: void serialize(torch::serialize::InputArchive& archive) override; void serialize(torch::serialize::OutputArchive& archive) const override; - TORCH_API friend bool operator==(const SGDParamState& lhs, const SGDParamState& rhs); + TORCH_API friend bool operator==( + const SGDParamState& lhs, + const SGDParamState& rhs); ~SGDParamState() override = default; }; class TORCH_API SGD : public Optimizer { public: - explicit SGD(std::vector param_groups, - SGDOptions defaults) : Optimizer(std::move(param_groups), std::make_unique(defaults)) { + explicit SGD( + std::vector param_groups, + SGDOptions defaults) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); - TORCH_CHECK(defaults.momentum() >= 0, "Invalid momentum value: ", defaults.momentum()); - TORCH_CHECK(defaults.weight_decay() >= 0, "Invalid weight_decay value: ", defaults.weight_decay()); - TORCH_CHECK(!defaults.nesterov() || (defaults.momentum() > 0 && defaults.dampening() == 0), "Nesterov momentum requires a momentum and zero dampening"); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + !defaults.nesterov() || + (defaults.momentum() > 0 && defaults.dampening() == 0), + "Nesterov momentum requires a momentum and zero dampening"); } - explicit SGD(std::vector params, + explicit SGD( + std::vector params, // NOLINTNEXTLINE(performance-move-const-arg) - SGDOptions defaults) : SGD({std::move(OptimizerParamGroup(params))}, defaults) {} + SGDOptions defaults) + : SGD({std::move(OptimizerParamGroup(params))}, defaults) {} torch::Tensor step(LossClosure closure = nullptr) override; diff --git a/torch/csrc/api/include/torch/ordered_dict.h b/torch/csrc/api/include/torch/ordered_dict.h index 38f1abdeee04a2..dbe67e1c49eb52 100644 --- a/torch/csrc/api/include/torch/ordered_dict.h +++ b/torch/csrc/api/include/torch/ordered_dict.h @@ -170,9 +170,12 @@ class OrderedDict { /// `OrderedDict` into a vector of `std::pair`. ::std::vector> pairs() const; - /// Returns true if both dicts contain the same keys and values, in the same order. - template - friend bool operator==(const OrderedDict &a, const OrderedDict &b); + /// Returns true if both dicts contain the same keys and values, in the same + /// order. + template + friend bool operator==( + const OrderedDict& a, + const OrderedDict& b); private: /// A mapping from a key to an index into the `items_` vector. @@ -345,8 +348,8 @@ typename OrderedDict::Item& OrderedDict::operator[]( } template -const typename OrderedDict:: - Item& OrderedDict::operator[](size_t index) const { +const typename OrderedDict::Item& OrderedDict:: +operator[](size_t index) const { TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds"); return items_[index]; } @@ -502,16 +505,22 @@ void OrderedDict::reserve(size_t requested_capacity) { items_.reserve(requested_capacity); } -template -bool operator==(const torch::OrderedDict& a, const torch::OrderedDict& b) { +template +bool operator==( + const torch::OrderedDict& a, + const torch::OrderedDict& b) { using Item = typename torch::OrderedDict::Item; - if (a.index_ != b.index_) return false; - if (a.items_.size() != b.items_.size()) return false; - // NOTE: There's no point in comparing keys for items_, as we already know that index is equal. - return std::equal(a.items_.begin(), a.items_.end(), - b.items_.begin(), - [](const Item& a, const Item& b) - { return a.value() == b.value(); }); + if (a.index_ != b.index_) + return false; + if (a.items_.size() != b.items_.size()) + return false; + // NOTE: There's no point in comparing keys for items_, as we already know + // that index is equal. + return std::equal( + a.items_.begin(), + a.items_.end(), + b.items_.begin(), + [](const Item& a, const Item& b) { return a.value() == b.value(); }); } } // namespace torch diff --git a/torch/csrc/api/include/torch/serialize.h b/torch/csrc/api/include/torch/serialize.h index f92b6123d13966..c52575cc219445 100644 --- a/torch/csrc/api/include/torch/serialize.h +++ b/torch/csrc/api/include/torch/serialize.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include @@ -39,8 +39,7 @@ namespace torch { /// \endrst template void save(const Value& value, SaveToArgs&&... args) { - serialize::OutputArchive archive( - std::make_shared()); + serialize::OutputArchive archive(std::make_shared()); archive << value; archive.save_to(std::forward(args)...); } @@ -54,19 +53,18 @@ void save(const Value& value, SaveToArgs&&... args) { /// \rst /// .. code-block:: cpp /// -/// std::vector tensor_vec = { torch::randn({1, 2}), torch::randn({3, 4}) }; -/// torch::save(tensor_vec, "my_tensor_vec.pt"); +/// std::vector tensor_vec = { torch::randn({1, 2}), +/// torch::randn({3, 4}) }; torch::save(tensor_vec, "my_tensor_vec.pt"); /// -/// std::vector tensor_vec = { torch::randn({5, 6}), torch::randn({7, 8}) }; -/// std::ostringstream stream; +/// std::vector tensor_vec = { torch::randn({5, 6}), +/// torch::randn({7, 8}) }; std::ostringstream stream; /// // Note that the same stream cannot be used in multiple torch::save(...) /// // invocations, otherwise the header will be corrupted. /// torch::save(tensor_vec, stream); /// \endrst template void save(const std::vector& tensor_vec, SaveToArgs&&... args) { - serialize::OutputArchive archive( - std::make_shared()); + serialize::OutputArchive archive(std::make_shared()); for (const auto i : c10::irange(tensor_vec.size())) { auto& value = tensor_vec[i]; archive.write(c10::to_string(i), value); diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index 9b2174c83df3ec..83d1a543ddacb1 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -1,10 +1,10 @@ #pragma once -#include #include +#include #include -#include #include +#include #include #include @@ -74,25 +74,29 @@ class TORCH_API InputArchive final { /// Loads the `InputArchive` from a serialized representation stored in the /// file at `filename`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. - void load_from(const std::string& filename, + void load_from( + const std::string& filename, c10::optional device = c10::nullopt); /// Loads the `InputArchive` from a serialized representation stored in the /// given `stream`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. - void load_from(std::istream& stream, + void load_from( + std::istream& stream, c10::optional device = c10::nullopt); // Loads given the specified flat array. - void load_from(const char* data, size_t size, + void load_from( + const char* data, + size_t size, c10::optional device = c10::nullopt); // Loads given the specified read and size functions. void load_from( - const std::function& read_func, - const std::function& size_func, - c10::optional device = c10::nullopt); + const std::function& + read_func, + const std::function& size_func, + c10::optional device = c10::nullopt); // Returns the vector of keys in the input archive. std::vector keys(); diff --git a/torch/csrc/api/include/torch/serialize/output-archive.h b/torch/csrc/api/include/torch/serialize/output-archive.h index 59335edf6f253a..12e0f54971cb39 100644 --- a/torch/csrc/api/include/torch/serialize/output-archive.h +++ b/torch/csrc/api/include/torch/serialize/output-archive.h @@ -24,7 +24,9 @@ namespace serialize { class TORCH_API OutputArchive final { public: explicit OutputArchive(std::shared_ptr cu); - explicit OutputArchive() : cu_(std::make_shared()), module_("__torch__.Module", cu_) {} + explicit OutputArchive() + : cu_(std::make_shared()), + module_("__torch__.Module", cu_) {} // Move is allowed. OutputArchive(OutputArchive&&) = default; diff --git a/torch/csrc/api/include/torch/sparse.h b/torch/csrc/api/include/torch/sparse.h index 912c4ef6e51e05..a30e74477e3658 100644 --- a/torch/csrc/api/include/torch/sparse.h +++ b/torch/csrc/api/include/torch/sparse.h @@ -3,6 +3,5 @@ #include namespace torch { -namespace sparse { - -}} // torch::sparse +namespace sparse {} +} // namespace torch diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h index 4eca020d037aae..6d3c3df33efbeb 100644 --- a/torch/csrc/api/include/torch/special.h +++ b/torch/csrc/api/include/torch/special.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace torch { namespace special { @@ -34,7 +35,10 @@ inline Tensor gammainc(const Tensor& self, const Tensor& other) { return torch::special_gammainc(self, other); } -inline Tensor& gammainc_out(Tensor& result, const Tensor& self, const Tensor& other) { +inline Tensor& gammainc_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { return torch::special_gammainc_out(result, self, other); } @@ -51,7 +55,10 @@ inline Tensor gammaincc(const Tensor& self, const Tensor& other) { return torch::special_gammaincc(self, other); } -inline Tensor& gammaincc_out(Tensor& result, const Tensor& self, const Tensor& other) { +inline Tensor& gammaincc_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { return torch::special_gammaincc_out(result, self, other); } @@ -199,8 +206,9 @@ inline Tensor& erfinv_out(Tensor& result, const Tensor& self) { return torch::special_erfinv_out(result, self); } -/// Computes the log of summed exponentials of each row of input in the given dimension dim -/// See https://pytorch.org/docs/master/special.html#torch.special.logsumexp. +/// Computes the log of summed exponentials of each row of input in the given +/// dimension dim See +/// https://pytorch.org/docs/master/special.html#torch.special.logsumexp. /// /// Example: /// ``` @@ -211,13 +219,18 @@ inline Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) { return torch::special_logsumexp(self, dims, keepdim); } -inline Tensor& logsumexp_out(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) { +inline Tensor& logsumexp_out( + Tensor& result, + const Tensor& self, + IntArrayRef dims, + bool keepdim) { return torch::special_logsumexp_out(result, self, dims, keepdim); } -/// Computes the argument, x, for which the area under the Gaussian probability density -/// function (integrated from minus infinity to x) is equal to input, elementwise. -/// See https://pytorch.org/docs/master/special.html#torch.special.ndtri +/// Computes the argument, x, for which the area under the Gaussian probability +/// density function (integrated from minus infinity to x) is equal to input, +/// elementwise. See +/// https://pytorch.org/docs/master/special.html#torch.special.ndtri /// /// Example: /// ``` @@ -232,9 +245,9 @@ inline Tensor& ndtri_out(Tensor& result, const Tensor& self) { return torch::special_ndtri_out(result, self); } -/// Computes the log of area under the standard Gaussian probability density function, -/// integrated from minus infinity to :attr:`input`, elementwise -/// See https://pytorch.org/docs/master/special.html#torch.special.log_ndtr +/// Computes the log of area under the standard Gaussian probability density +/// function, integrated from minus infinity to :attr:`input`, elementwise See +/// https://pytorch.org/docs/master/special.html#torch.special.log_ndtr /// /// Example: /// ``` @@ -265,8 +278,9 @@ inline Tensor& logit_out(Tensor& result, const Tensor& self) { return torch::special_logit_out(result, self); } -/// Computes the expit (also known as the logistic sigmoid function) of input, elementwise -/// See https://pytorch.org/docs/master/special.html#torch.special.expit. +/// Computes the expit (also known as the logistic sigmoid function) of input, +/// elementwise See +/// https://pytorch.org/docs/master/special.html#torch.special.expit. /// /// Example: /// ``` @@ -334,15 +348,24 @@ inline Tensor xlogy(const Tensor& self, const Scalar& other) { return torch::special_xlogy(self, other); } -inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) { +inline Tensor& xlogy_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { return torch::special_xlogy_out(result, self, other); } -inline Tensor& xlogy_out(Tensor& result, const Scalar& self, const Tensor& other) { +inline Tensor& xlogy_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { return torch::special_xlogy_out(result, self, other); } -inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Scalar& other) { +inline Tensor& xlogy_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { return torch::special_xlogy_out(result, self, other); } @@ -367,15 +390,24 @@ inline Tensor xlog1py(const Tensor& self, const Scalar& other) { return torch::special_xlog1py(self, other); } -inline Tensor& xlog1py_out(Tensor& result, const Tensor& self, const Tensor& other) { +inline Tensor& xlog1py_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { return torch::special_xlog1py_out(result, self, other); } -inline Tensor& xlog1py_out(Tensor& result, const Scalar& self, const Tensor& other) { +inline Tensor& xlog1py_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { return torch::special_xlog1py_out(result, self, other); } -inline Tensor& xlog1py_out(Tensor& result, const Tensor& self, const Scalar& other) { +inline Tensor& xlog1py_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { return torch::special_xlog1py_out(result, self, other); } @@ -400,20 +432,30 @@ inline Tensor zeta(const Tensor& self, const Scalar& other) { return torch::special_zeta(self, other); } -inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Tensor& other) { +inline Tensor& zeta_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { return torch::special_zeta_out(result, self, other); } -inline Tensor& zeta_out(Tensor& result, const Scalar& self, const Tensor& other) { +inline Tensor& zeta_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { return torch::special_zeta_out(result, self, other); } -inline Tensor& zeta_out(Tensor& result, const Tensor& self, const Scalar& other) { +inline Tensor& zeta_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { return torch::special_zeta_out(result, self, other); } -/// Computes the zeroth order modified Bessel function of the first kind of input, elementwise -/// See https://pytorch.org/docs/master/special.html#torch.special.i0 +/// Computes the zeroth order modified Bessel function of the first kind of +/// input, elementwise See +/// https://pytorch.org/docs/master/special.html#torch.special.i0 /// /// Example: /// ``` @@ -445,8 +487,9 @@ inline Tensor& ndtr_out(Tensor& result, const Tensor& self) { return torch::special_ndtr_out(result, self); } -/// Computes the exponentially scaled zeroth order modified Bessel function of the first kind -/// See https://pytorch.org/docs/master/special.html#torch.special.i0e. +/// Computes the exponentially scaled zeroth order modified Bessel function of +/// the first kind See +/// https://pytorch.org/docs/master/special.html#torch.special.i0e. /// /// Example: /// ``` @@ -477,8 +520,9 @@ inline Tensor& i1_out(Tensor& result, const Tensor& self) { return torch::special_i1_out(result, self); } -/// Computes the exponentially scaled first order modified Bessel function of the first kind -/// See https://pytorch.org/docs/master/special.html#torch.special.i1e. +/// Computes the exponentially scaled first order modified Bessel function of +/// the first kind See +/// https://pytorch.org/docs/master/special.html#torch.special.i1e. /// /// Example: /// ``` @@ -549,7 +593,10 @@ inline Tensor& log1p_out(Tensor& result, const Tensor& self) { /// auto t = torch::randn(128, 128, dtype=kDouble); /// torch::special::log_softmax(t, 0); /// ``` -inline Tensor log_softmax(const Tensor& self, int64_t dim, c10::optional dtype) { +inline Tensor log_softmax( + const Tensor& self, + int64_t dim, + c10::optional dtype) { return torch::special_log_softmax(self, dim, dtype); } @@ -561,7 +608,10 @@ inline Tensor log_softmax(const Tensor& self, int64_t dim, c10::optional dtype) { +inline Tensor softmax( + const Tensor& self, + int64_t dim, + c10::optional dtype) { return torch::special_softmax(self, dim, dtype); } @@ -643,7 +693,8 @@ inline Tensor& bessel_y1_out(Tensor& result, const Tensor& self) { /// Chebyshev polynomial of the first kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_t. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_t. /// /// Example: /// @@ -665,21 +716,31 @@ inline Tensor chebyshev_polynomial_t(const Tensor& x, const Scalar& n) { return torch::special_chebyshev_polynomial_t(x, n); } -inline Tensor& chebyshev_polynomial_t_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_t_out(output, x, n); } -inline Tensor& chebyshev_polynomial_t_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_t_out(output, x, n); } -inline Tensor& chebyshev_polynomial_t_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_chebyshev_polynomial_t_out(output, x, n); } /// Chebyshev polynomial of the second kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_u. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_u. /// /// Example: /// @@ -701,21 +762,31 @@ inline Tensor chebyshev_polynomial_u(const Tensor& x, const Scalar& n) { return torch::special_chebyshev_polynomial_u(x, n); } -inline Tensor& chebyshev_polynomial_u_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_u_out(output, x, n); } -inline Tensor& chebyshev_polynomial_u_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_u_out(output, x, n); } -inline Tensor& chebyshev_polynomial_u_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_chebyshev_polynomial_u_out(output, x, n); } /// Chebyshev polynomial of the third kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_v. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_v. /// /// Example: /// @@ -737,21 +808,31 @@ inline Tensor chebyshev_polynomial_v(const Tensor& x, const Scalar& n) { return torch::special_chebyshev_polynomial_v(x, n); } -inline Tensor& chebyshev_polynomial_v_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_v_out(output, x, n); } -inline Tensor& chebyshev_polynomial_v_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_v_out(output, x, n); } -inline Tensor& chebyshev_polynomial_v_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_chebyshev_polynomial_v_out(output, x, n); } /// Chebyshev polynomial of the fourth kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_w. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.chebyshev_polynomial_w. /// /// Example: /// @@ -773,21 +854,31 @@ inline Tensor chebyshev_polynomial_w(const Tensor& x, const Scalar& n) { return torch::special_chebyshev_polynomial_w(x, n); } -inline Tensor& chebyshev_polynomial_w_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_w_out(output, x, n); } -inline Tensor& chebyshev_polynomial_w_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_chebyshev_polynomial_w_out(output, x, n); } -inline Tensor& chebyshev_polynomial_w_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_chebyshev_polynomial_w_out(output, x, n); } /// Physicist’s Hermite polynomial. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.hermite_polynomial_h. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.hermite_polynomial_h. /// /// Example: /// @@ -809,21 +900,31 @@ inline Tensor hermite_polynomial_h(const Tensor& x, const Scalar& n) { return torch::special_hermite_polynomial_h(x, n); } -inline Tensor& hermite_polynomial_h_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_hermite_polynomial_h_out(output, x, n); } -inline Tensor& hermite_polynomial_h_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_hermite_polynomial_h_out(output, x, n); } -inline Tensor& hermite_polynomial_h_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_hermite_polynomial_h_out(output, x, n); } /// Probabilist’s Hermite polynomial. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.hermite_polynomial_he. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.hermite_polynomial_he. /// /// Example: /// @@ -845,21 +946,31 @@ inline Tensor hermite_polynomial_he(const Tensor& x, const Scalar& n) { return torch::special_hermite_polynomial_he(x, n); } -inline Tensor& hermite_polynomial_he_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_hermite_polynomial_he_out(output, x, n); } -inline Tensor& hermite_polynomial_he_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_hermite_polynomial_he_out(output, x, n); } -inline Tensor& hermite_polynomial_he_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_hermite_polynomial_he_out(output, x, n); } /// Laguerre polynomial. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.laguerre_polynomial_l. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.laguerre_polynomial_l. /// /// Example: /// @@ -881,21 +992,31 @@ inline Tensor laguerre_polynomial_l(const Tensor& x, const Scalar& n) { return torch::special_laguerre_polynomial_l(x, n); } -inline Tensor& laguerre_polynomial_l_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_laguerre_polynomial_l_out(output, x, n); } -inline Tensor& laguerre_polynomial_l_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_laguerre_polynomial_l_out(output, x, n); } -inline Tensor& laguerre_polynomial_l_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_laguerre_polynomial_l_out(output, x, n); } /// Legendre polynomial. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.legendre_polynomial_p. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.legendre_polynomial_p. /// /// Example: /// @@ -917,21 +1038,31 @@ inline Tensor legendre_polynomial_p(const Tensor& x, const Scalar& n) { return torch::special_legendre_polynomial_p(x, n); } -inline Tensor& legendre_polynomial_p_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_legendre_polynomial_p_out(output, x, n); } -inline Tensor& legendre_polynomial_p_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_legendre_polynomial_p_out(output, x, n); } -inline Tensor& legendre_polynomial_p_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_legendre_polynomial_p_out(output, x, n); } /// Modified Bessel function of the first kind of order 0. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i0. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i0. /// /// Example: /// @@ -950,7 +1081,8 @@ inline Tensor& modified_bessel_i0_out(Tensor& result, const Tensor& self) { /// Modified Bessel function of the first kind of order 1. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i1. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_i1. /// /// Example: /// @@ -969,7 +1101,8 @@ inline Tensor& modified_bessel_i1_out(Tensor& result, const Tensor& self) { /// Modified Bessel function of the second kind of order 0. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k0. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k0. /// /// Example: /// @@ -988,7 +1121,8 @@ inline Tensor& modified_bessel_k0_out(Tensor& result, const Tensor& self) { /// Modified Bessel function of the second kind of order 1. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k1. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.modified_bessel_k1. /// /// Example: /// @@ -1007,7 +1141,8 @@ inline Tensor& modified_bessel_k1_out(Tensor& result, const Tensor& self) { /// Shifted Chebyshev polynomial of the first kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_t. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_t. /// /// Example: /// @@ -1029,21 +1164,31 @@ inline Tensor shifted_chebyshev_polynomial_t(const Tensor& x, const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_t(x, n); } -inline Tensor& shifted_chebyshev_polynomial_t_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_t_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_t_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); } /// Shifted Chebyshev polynomial of the second kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_u. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_u. /// /// Example: /// @@ -1065,21 +1210,31 @@ inline Tensor shifted_chebyshev_polynomial_u(const Tensor& x, const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_u(x, n); } -inline Tensor& shifted_chebyshev_polynomial_u_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_u_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_u_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); } /// Shifted Chebyshev polynomial of the third kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_v. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_v. /// /// Example: /// @@ -1101,21 +1256,31 @@ inline Tensor shifted_chebyshev_polynomial_v(const Tensor& x, const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_v(x, n); } -inline Tensor& shifted_chebyshev_polynomial_v_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_v_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_v_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); } /// Shifted Chebyshev polynomial of the fourth kind. /// -/// See https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_w. +/// See +/// https://pytorch.org/docs/master/special.html#torch.special.shifted_chebyshev_polynomial_w. /// /// Example: /// @@ -1137,16 +1302,26 @@ inline Tensor shifted_chebyshev_polynomial_w(const Tensor& x, const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_w(x, n); } -inline Tensor& shifted_chebyshev_polynomial_w_out(Tensor& output, const Tensor& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_w_out(Tensor& output, const Scalar& x, const Tensor& n) { +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); } -inline Tensor& shifted_chebyshev_polynomial_w_out(Tensor& output, const Tensor& x, const Scalar& n) { +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); } -}} // torch::special +} // namespace special +} // namespace torch diff --git a/torch/csrc/api/include/torch/types.h b/torch/csrc/api/include/torch/types.h index 102eb055e58437..92be710cf4bf46 100644 --- a/torch/csrc/api/include/torch/types.h +++ b/torch/csrc/api/include/torch/types.h @@ -9,34 +9,37 @@ // TODO: These don't really belong here but torchvision builds in CI need them // Remove once the torchvision version being compiled in CI is updated -#include #include +#include namespace torch { // NOTE [ Exposing declarations in `at::` to `torch::` ] // -// The following line `using namespace at;` is responsible for exposing all declarations in -// `at::` namespace to `torch::` namespace. +// The following line `using namespace at;` is responsible for exposing all +// declarations in `at::` namespace to `torch::` namespace. // // According to the rules laid out in -// https://en.cppreference.com/w/cpp/language/qualified_lookup, section "Namespace members": +// https://en.cppreference.com/w/cpp/language/qualified_lookup, section +// "Namespace members": // ``` -// Qualified lookup within the scope of a namespace N first considers all declarations that are -// located in N and all declarations that are located in the inline namespace members of N -// (and, transitively, in their inline namespace members). If there are no declarations in that set -// then it considers declarations in all namespaces named by using-directives found in N and -// in all transitive inline namespace members of N. +// Qualified lookup within the scope of a namespace N first considers all +// declarations that are located in N and all declarations that are located in +// the inline namespace members of N (and, transitively, in their inline +// namespace members). If there are no declarations in that set then it +// considers declarations in all namespaces named by using-directives found in N +// and in all transitive inline namespace members of N. // ``` // -// This means that if both `at::` and `torch::` namespaces have a function with the same signature -// (e.g. both `at::func()` and `torch::func()` exist), after `namespace torch { using namespace at; }`, -// when we call `torch::func()`, the `func()` function defined in `torch::` namespace will always -// be called, and the `func()` function defined in `at::` namespace is always hidden. +// This means that if both `at::` and `torch::` namespaces have a function with +// the same signature (e.g. both `at::func()` and `torch::func()` exist), after +// `namespace torch { using namespace at; }`, when we call `torch::func()`, the +// `func()` function defined in `torch::` namespace will always be called, and +// the `func()` function defined in `at::` namespace is always hidden. using namespace at; // NOLINT -using c10::optional; using c10::nullopt; +using c10::optional; using Dtype = at::ScalarType; diff --git a/torch/csrc/api/include/torch/utils.h b/torch/csrc/api/include/torch/utils.h index 3bb6363a4ced1b..004a0064636ef4 100644 --- a/torch/csrc/api/include/torch/utils.h +++ b/torch/csrc/api/include/torch/utils.h @@ -2,9 +2,9 @@ #include #include +#include #include #include -#include #include namespace torch { @@ -13,10 +13,12 @@ namespace torch { /// /// Disabling gradient calculation is useful for inference, when you are sure /// that you will not call `at::Tensor::backward`. It will reduce memory -/// consumption for computations that would otherwise have `requires_grad() == true`. +/// consumption for computations that would otherwise have `requires_grad() == +/// true`. /// /// In this mode, the result of every computation will have -/// `requires_grad() == false`, even when the inputs have `requires_grad() == true`. +/// `requires_grad() == false`, even when the inputs have `requires_grad() == +/// true`. /// /// This context manager is thread-local; it will not affect computation /// in other threads. @@ -42,7 +44,8 @@ using NoGradGuard = at::NoGradGuard; /// A RAII, thread-local guard that sets gradient calculation to on or off. /// -/// ``AutoGradMode`` will enable or disable grads based on its argument `enabled`. +/// ``AutoGradMode`` will enable or disable grads based on its argument +/// `enabled`. /// /// This context manager is thread-local; it will not affect computation /// in other threads. @@ -87,25 +90,27 @@ using at::set_num_interop_threads; // Returns true if both t1, t2 are undefined or both are defined and equal inline bool equal_if_defined(Tensor t1, Tensor t2) { - return ((!t1.defined() && !t2.defined()) || (t1.defined() && t2.defined() && torch::equal(t1, t2))); + return ( + (!t1.defined() && !t2.defined()) || + (t1.defined() && t2.defined() && torch::equal(t1, t2))); } // RecordFunction API -using at::RecordFunctionCallback; -using at::addThreadLocalCallback; -using at::hasThreadLocalCallbacks; -using at::clearThreadLocalCallbacks; using at::addGlobalCallback; -using at::removeCallback; -using at::hasGlobalCallbacks; -using at::clearGlobalCallbacks; -using at::hasCallbacks; +using at::addThreadLocalCallback; +using at::CallbackHandle; using at::clearCallbacks; +using at::clearGlobalCallbacks; +using at::clearThreadLocalCallbacks; +using at::DisableRecordFunctionGuard; using at::enableRecordFunction; +using at::hasCallbacks; +using at::hasGlobalCallbacks; +using at::hasThreadLocalCallbacks; using at::isRecordFunctionEnabled; -using at::RecordFunctionGuard; -using at::DisableRecordFunctionGuard; -using at::CallbackHandle; using at::RecordFunction; +using at::RecordFunctionCallback; +using at::RecordFunctionGuard; +using at::removeCallback; } // namespace torch diff --git a/torch/csrc/api/src/cuda.cpp b/torch/csrc/api/src/cuda.cpp index f11f0988200e2f..be78073ccf35d0 100644 --- a/torch/csrc/api/src/cuda.cpp +++ b/torch/csrc/api/src/cuda.cpp @@ -54,8 +54,10 @@ void manual_seed_all(uint64_t seed) { void synchronize(int64_t device_index) { TORCH_CHECK(is_available(), "No CUDA GPUs are available"); int64_t num_gpus = cuda::device_count(); - TORCH_CHECK(device_index == -1 || device_index < num_gpus, - "Device index out of range: ", device_index); + TORCH_CHECK( + device_index == -1 || device_index < num_gpus, + "Device index out of range: ", + device_index); at::detail::getCUDAHooks().deviceSynchronize(device_index); } diff --git a/torch/csrc/api/src/imethod.cpp b/torch/csrc/api/src/imethod.cpp index cfbf1c9e2805df..3c74eddc7b4773 100644 --- a/torch/csrc/api/src/imethod.cpp +++ b/torch/csrc/api/src/imethod.cpp @@ -2,8 +2,7 @@ namespace torch { -const std::vector& IMethod::getArgumentNames() const -{ +const std::vector& IMethod::getArgumentNames() const { if (isArgumentNamesInitialized_) { return argumentNames_; } diff --git a/torch/csrc/api/src/jit.cpp b/torch/csrc/api/src/jit.cpp index 4323b8a38bf112..16d9d0040a6592 100644 --- a/torch/csrc/api/src/jit.cpp +++ b/torch/csrc/api/src/jit.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include @@ -11,11 +11,7 @@ namespace jit { std::shared_ptr compile(const std::string& source) { auto module = std::make_shared(); - module->define( - c10::nullopt, - source, - nativeResolver(), - nullptr); + module->define(c10::nullopt, source, nativeResolver(), nullptr); return module; } diff --git a/torch/csrc/api/src/nn/init.cpp b/torch/csrc/api/src/nn/init.cpp index 554b34e0add081..25fdcdcbc791d7 100644 --- a/torch/csrc/api/src/nn/init.cpp +++ b/torch/csrc/api/src/nn/init.cpp @@ -58,11 +58,11 @@ double calculate_kaiming_std( double calculate_gain(NonlinearityType nonlinearity, double param) { if (c10::get_if(&nonlinearity)) { - return 5.0 / 3.0; // NOLINT + return 5.0 / 3.0; // NOLINT } else if (c10::get_if(&nonlinearity)) { - return std::sqrt(2.0); // NOLINT + return std::sqrt(2.0); // NOLINT } else if (c10::get_if(&nonlinearity)) { - return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT + return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT } return 1.0; @@ -225,11 +225,13 @@ Tensor zeros_(Tensor tensor) { return tensor.zero_(); } -std::tuple _calculate_fan_in_and_fan_out(const Tensor& tensor) { +std::tuple _calculate_fan_in_and_fan_out( + const Tensor& tensor) { const auto dimensions = tensor.dim(); - TORCH_CHECK(dimensions >= 2, - "Fan in and fan out can not be computed " - "for tensor with fewer than 2 dimensions") + TORCH_CHECK( + dimensions >= 2, + "Fan in and fan out can not be computed " + "for tensor with fewer than 2 dimensions") // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t fan_in, fan_out; diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index e84fb3e9c5a65d..1f569aa14386dd 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -316,8 +316,8 @@ Tensor& Module::register_parameter( if (!tensor.defined()) { if (requires_grad) { TORCH_WARN( - "An undefined tensor cannot require grad. ", - "Ignoring the `requires_grad=true` function parameter."); + "An undefined tensor cannot require grad. ", + "Ignoring the `requires_grad=true` function parameter."); } } else { tensor.set_requires_grad(requires_grad); diff --git a/torch/csrc/api/src/nn/modules/_functions.cpp b/torch/csrc/api/src/nn/modules/_functions.cpp index 1a81b44afc7d7b..62ee4b3c536be7 100644 --- a/torch/csrc/api/src/nn/modules/_functions.cpp +++ b/torch/csrc/api/src/nn/modules/_functions.cpp @@ -8,9 +8,9 @@ namespace nn { namespace functions { Variable CrossMapLRN2d::forward( - AutogradContext *ctx, + AutogradContext* ctx, const Variable& input, - const CrossMapLRN2dOptions& options){ + const CrossMapLRN2dOptions& options) { ctx->saved_data["size"] = options.size(); ctx->saved_data["alpha"] = options.alpha(); ctx->saved_data["beta"] = options.beta(); @@ -19,7 +19,9 @@ Variable CrossMapLRN2d::forward( TORCH_CHECK(input.dim() == 4); - ctx->saved_data["scale"] = ctx->saved_data["scale"].toTensor().defined() ? ctx->saved_data["scale"] : torch::empty({0}, input.options()); + ctx->saved_data["scale"] = ctx->saved_data["scale"].toTensor().defined() + ? ctx->saved_data["scale"] + : torch::empty({0}, input.options()); torch::Tensor output = torch::empty({0}, input.options()); @@ -32,7 +34,8 @@ Variable CrossMapLRN2d::forward( auto input_square = output; torch::pow_out(input_square, input, 2); - int64_t pre_pad = static_cast((ctx->saved_data["size"].toInt() - 1) / 2 + 1); + int64_t pre_pad = + static_cast((ctx->saved_data["size"].toInt() - 1) / 2 + 1); int64_t pre_pad_crop = pre_pad > channels ? channels : pre_pad; auto scale_first = ctx->saved_data["scale"].toTensor().select(1, 0); @@ -63,18 +66,25 @@ Variable CrossMapLRN2d::forward( } } - ctx->saved_data["scale"].toTensor() - .mul_(ctx->saved_data["alpha"].toDouble() / ctx->saved_data["size"].toInt()) + ctx->saved_data["scale"] + .toTensor() + .mul_( + ctx->saved_data["alpha"].toDouble() / ctx->saved_data["size"].toInt()) .add_(ctx->saved_data["k"].toInt()); - torch::pow_out(output, ctx->saved_data["scale"].toTensor(), -ctx->saved_data["beta"].toDouble()); + torch::pow_out( + output, + ctx->saved_data["scale"].toTensor(), + -ctx->saved_data["beta"].toDouble()); output.mul_(input); ctx->save_for_backward({input, output}); return output; } -variable_list CrossMapLRN2d::backward(AutogradContext *ctx, variable_list grad_outputs) { +variable_list CrossMapLRN2d::backward( + AutogradContext* ctx, + variable_list grad_outputs) { auto grad_output = grad_outputs[0]; auto input = ctx->get_saved_variables()[0]; auto output = ctx->get_saved_variables()[1]; @@ -85,15 +95,24 @@ variable_list CrossMapLRN2d::backward(AutogradContext *ctx, variable_list grad_o int64_t input_height = input.size(2); int64_t input_width = input.size(3); - auto padded_ratio = torch::empty({channels + ctx->saved_data["size"].toInt() - 1, input_height, input_width}, - input.options()); - auto accum_ratio = torch::empty({input_height, input_width}, - input.options()); - double cache_ratio_value = 2 * ctx->saved_data["alpha"].toDouble() * ctx->saved_data["beta"].toDouble() / ctx->saved_data["size"].toInt(); - int64_t inversePrePad = static_cast(ctx->saved_data["size"].toInt() - (ctx->saved_data["size"].toInt() - 1) / 2); + auto padded_ratio = torch::empty( + {channels + ctx->saved_data["size"].toInt() - 1, + input_height, + input_width}, + input.options()); + auto accum_ratio = torch::empty({input_height, input_width}, input.options()); + double cache_ratio_value = 2 * ctx->saved_data["alpha"].toDouble() * + ctx->saved_data["beta"].toDouble() / ctx->saved_data["size"].toInt(); + int64_t inversePrePad = static_cast( + ctx->saved_data["size"].toInt() - + (ctx->saved_data["size"].toInt() - 1) / 2); grad_input.resize_as_(input); - torch::pow_out(grad_input, ctx->saved_data["scale"].toTensor(), -ctx->saved_data["beta"].toDouble()).mul_(grad_output); + torch::pow_out( + grad_input, + ctx->saved_data["scale"].toTensor(), + -ctx->saved_data["beta"].toDouble()) + .mul_(grad_output); padded_ratio.zero_(); auto padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels); @@ -104,7 +123,8 @@ variable_list CrossMapLRN2d::backward(AutogradContext *ctx, variable_list grad_o torch::sum_out( accum_ratio, padded_ratio.narrow(0, 0, ctx->saved_data["size"].toInt() - 1), - 0, /*keepdim=*/false); + 0, + /*keepdim=*/false); for (const auto c : c10::irange(channels)) { accum_ratio.add_(padded_ratio[c + ctx->saved_data["size"].toInt() - 1]); grad_input[n][c].addcmul_(input[n][c], accum_ratio, -cache_ratio_value); @@ -112,9 +132,10 @@ variable_list CrossMapLRN2d::backward(AutogradContext *ctx, variable_list grad_o } } - return variable_list{grad_input, Variable(), Variable(), Variable(), Variable()}; + return variable_list{ + grad_input, Variable(), Variable(), Variable(), Variable()}; } -} +} // namespace functions } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index 001199e98edd55..4ce04a7cddad79 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include namespace F = torch::nn::functional; @@ -18,7 +18,7 @@ void ELUImpl::reset() {} void ELUImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::ELU(alpha=" << options.alpha(); if (options.inplace()) { - stream << std::boolalpha << ", inplace=" << options.inplace(); + stream << std::boolalpha << ", inplace=" << options.inplace(); } stream << ")"; } @@ -53,8 +53,8 @@ Tensor HardshrinkImpl::forward(const Tensor& input) { void HardshrinkImpl::reset() {} void HardshrinkImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::Hardshrink(" << options.lambda() << ")"; + stream << std::boolalpha << "torch::nn::Hardshrink(" << options.lambda() + << ")"; } // ============================================================================ @@ -66,12 +66,14 @@ HardtanhImpl::HardtanhImpl(const HardtanhOptions& options_) } Tensor HardtanhImpl::forward(Tensor input) { - return F::detail::hardtanh(input, options.min_val(), options.max_val(), options.inplace()); + return F::detail::hardtanh( + input, options.min_val(), options.max_val(), options.inplace()); } void HardtanhImpl::reset() { - TORCH_CHECK(options.max_val() > options.min_val(), - "max_val must be greater than min_val"); + TORCH_CHECK( + options.max_val() > options.min_val(), + "max_val must be greater than min_val"); } void HardtanhImpl::pretty_print(std::ostream& stream) const { @@ -79,7 +81,7 @@ void HardtanhImpl::pretty_print(std::ostream& stream) const { << "torch::nn::Hardtanh(min_val=" << options.min_val() << ", max_val=" << options.max_val(); if (options.inplace()) { - stream << std::boolalpha << ", inplace=" << options.inplace(); + stream << std::boolalpha << ", inplace=" << options.inplace(); } stream << ")"; } @@ -90,7 +92,8 @@ LeakyReLUImpl::LeakyReLUImpl(const LeakyReLUOptions& options_) : options(options_) {} Tensor LeakyReLUImpl::forward(Tensor input) { - return F::detail::leaky_relu(input, options.negative_slope(), options.inplace()); + return F::detail::leaky_relu( + input, options.negative_slope(), options.inplace()); } void LeakyReLUImpl::reset() {} @@ -99,7 +102,7 @@ void LeakyReLUImpl::pretty_print(std::ostream& stream) const { stream << std::boolalpha << "torch::nn::LeakyReLU(negative_slope=" << options.negative_slope(); if (options.inplace()) { - stream << std::boolalpha << ", inplace=" << options.inplace(); + stream << std::boolalpha << ", inplace=" << options.inplace(); } stream << ")"; } @@ -118,8 +121,7 @@ void LogSigmoidImpl::pretty_print(std::ostream& stream) const { // ============================================================================ -SoftmaxImpl::SoftmaxImpl(const SoftmaxOptions& options_) - : options(options_) {} +SoftmaxImpl::SoftmaxImpl(const SoftmaxOptions& options_) : options(options_) {} void SoftmaxImpl::reset() {} @@ -133,8 +135,7 @@ Tensor SoftmaxImpl::forward(const Tensor& input) { // ============================================================================ -SoftminImpl::SoftminImpl(const SoftminOptions& options_) - : options(options_) {} +SoftminImpl::SoftminImpl(const SoftminOptions& options_) : options(options_) {} void SoftminImpl::reset() {} @@ -170,7 +171,9 @@ void Softmax2dImpl::pretty_print(std::ostream& stream) const { } Tensor Softmax2dImpl::forward(const Tensor& input) { - TORCH_CHECK(input.dim() == 4 || input.dim() == 3, "Softmax2d requires a 3D or 4D tensor as input"); + TORCH_CHECK( + input.dim() == 4 || input.dim() == 3, + "Softmax2d requires a 3D or 4D tensor as input"); return F::detail::softmax(input, /*dim=*/-3, c10::nullopt); } @@ -186,13 +189,13 @@ Tensor PReLUImpl::forward(const Tensor& input) { } void PReLUImpl::reset() { - weight = register_parameter("weight", - torch::full(options.num_parameters(), options.init())); + weight = register_parameter( + "weight", torch::full(options.num_parameters(), options.init())); } void PReLUImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::PReLU(num_parameters=" - << options.num_parameters() << ")"; + stream << "torch::nn::PReLU(num_parameters=" << options.num_parameters() + << ")"; } // ============================================================================ @@ -236,7 +239,12 @@ void ReLU6Impl::pretty_print(std::ostream& stream) const { RReLUImpl::RReLUImpl(const RReLUOptions& options_) : options(options_) {} Tensor RReLUImpl::forward(Tensor input) { - return F::detail::rrelu(input, options.lower(), options.upper(), is_training(), options.inplace()); + return F::detail::rrelu( + input, + options.lower(), + options.upper(), + is_training(), + options.inplace()); } void RReLUImpl::reset() {} @@ -335,7 +343,7 @@ void SigmoidImpl::pretty_print(std::ostream& stream) const { // ============================================================================ SoftplusImpl::SoftplusImpl(const SoftplusOptions& options_) - : options(options_) {} + : options(options_) {} Tensor SoftplusImpl::forward(const Tensor& input) { return F::detail::softplus(input, options.beta(), options.threshold()); @@ -405,7 +413,8 @@ ThresholdImpl::ThresholdImpl(const ThresholdOptions& options_) : options(options_) {} Tensor ThresholdImpl::forward(Tensor input) { - return F::detail::threshold(input, options.threshold(), options.value(), options.inplace()); + return F::detail::threshold( + input, options.threshold(), options.value(), options.inplace()); } void ThresholdImpl::reset() {} @@ -422,99 +431,109 @@ void ThresholdImpl::pretty_print(std::ostream& stream) const { // ============================================================================ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -MultiheadAttentionImpl::MultiheadAttentionImpl(const MultiheadAttentionOptions& options_) +MultiheadAttentionImpl::MultiheadAttentionImpl( + const MultiheadAttentionOptions& options_) : Module("torch::nn::MultiheadAttention"), options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } std::tuple MultiheadAttentionImpl::forward( - const Tensor& query, const Tensor& key, - const Tensor& value, const Tensor& key_padding_mask, - bool need_weights, const Tensor& attn_mask, - bool average_attn_weights) { + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& key_padding_mask, + bool need_weights, + const Tensor& attn_mask, + bool average_attn_weights) { if (!_qkv_same_embed_dim) { return F::multi_head_attention_forward( - query, key, value, - F::MultiheadAttentionForwardFuncOptions( - /*embed_dim_to_check=*/options.embed_dim(), - /*num_heads=*/options.num_heads(), - /*in_proj_weight=*/in_proj_weight, - /*in_proj_bias=*/in_proj_bias, - /*bias_k=*/bias_k, - /*bias_v=*/bias_v, - /*add_zero_attn=*/options.add_zero_attn(), - /*dropout_p=*/options.dropout(), - /*out_proj_weight=*/out_proj->weight, - /*out_proj_bias=*/out_proj->bias - ).training(is_training()) - .key_padding_mask(key_padding_mask) - .need_weights(need_weights) - .attn_mask(attn_mask) - .use_separate_proj_weight(true) - .q_proj_weight(q_proj_weight) - .k_proj_weight(k_proj_weight) - .v_proj_weight(v_proj_weight) - .average_attn_weights(average_attn_weights) - ); + query, + key, + value, + F::MultiheadAttentionForwardFuncOptions( + /*embed_dim_to_check=*/options.embed_dim(), + /*num_heads=*/options.num_heads(), + /*in_proj_weight=*/in_proj_weight, + /*in_proj_bias=*/in_proj_bias, + /*bias_k=*/bias_k, + /*bias_v=*/bias_v, + /*add_zero_attn=*/options.add_zero_attn(), + /*dropout_p=*/options.dropout(), + /*out_proj_weight=*/out_proj->weight, + /*out_proj_bias=*/out_proj->bias) + .training(is_training()) + .key_padding_mask(key_padding_mask) + .need_weights(need_weights) + .attn_mask(attn_mask) + .use_separate_proj_weight(true) + .q_proj_weight(q_proj_weight) + .k_proj_weight(k_proj_weight) + .v_proj_weight(v_proj_weight) + .average_attn_weights(average_attn_weights)); } else { return F::multi_head_attention_forward( - query, key, value, - F::MultiheadAttentionForwardFuncOptions( - /*embed_dim_to_check=*/options.embed_dim(), - /*num_heads=*/options.num_heads(), - /*in_proj_weight=*/in_proj_weight, - /*in_proj_bias=*/in_proj_bias, - /*bias_k=*/bias_k, - /*bias_v=*/bias_v, - /*add_zero_attn=*/options.add_zero_attn(), - /*dropout_p=*/options.dropout(), - /*out_proj_weight=*/out_proj->weight, - /*out_proj_bias=*/out_proj->bias - ).training(is_training()) - .key_padding_mask(key_padding_mask) - .need_weights(need_weights) - .attn_mask(attn_mask) - .average_attn_weights(average_attn_weights) - ); + query, + key, + value, + F::MultiheadAttentionForwardFuncOptions( + /*embed_dim_to_check=*/options.embed_dim(), + /*num_heads=*/options.num_heads(), + /*in_proj_weight=*/in_proj_weight, + /*in_proj_bias=*/in_proj_bias, + /*bias_k=*/bias_k, + /*bias_v=*/bias_v, + /*add_zero_attn=*/options.add_zero_attn(), + /*dropout_p=*/options.dropout(), + /*out_proj_weight=*/out_proj->weight, + /*out_proj_bias=*/out_proj->bias) + .training(is_training()) + .key_padding_mask(key_padding_mask) + .need_weights(need_weights) + .attn_mask(attn_mask) + .average_attn_weights(average_attn_weights)); } } void MultiheadAttentionImpl::reset() { _qkv_same_embed_dim = options.kdim() == options.embed_dim() && - options.vdim() == options.embed_dim(); + options.vdim() == options.embed_dim(); head_dim = options.embed_dim() / options.num_heads(); - TORCH_CHECK(head_dim * options.num_heads() == options.embed_dim(), - "embed_dim must be divisible by num_heads"); + TORCH_CHECK( + head_dim * options.num_heads() == options.embed_dim(), + "embed_dim must be divisible by num_heads"); if (!_qkv_same_embed_dim) { q_proj_weight = register_parameter( - "q_proj_weight", torch::empty({options.embed_dim(), options.embed_dim()})); + "q_proj_weight", + torch::empty({options.embed_dim(), options.embed_dim()})); k_proj_weight = register_parameter( - "k_proj_weight", torch::empty({options.embed_dim(), options.kdim()})); + "k_proj_weight", torch::empty({options.embed_dim(), options.kdim()})); v_proj_weight = register_parameter( - "v_proj_weight", torch::empty({options.embed_dim(), options.vdim()})); + "v_proj_weight", torch::empty({options.embed_dim(), options.vdim()})); register_parameter("in_proj_weight", {}, /*requires_grad=*/false); } else { in_proj_weight = register_parameter( - "in_proj_weight", torch::empty({3 * options.embed_dim(), options.embed_dim()})); + "in_proj_weight", + torch::empty({3 * options.embed_dim(), options.embed_dim()})); register_parameter("q_proj_weight", {}, /*requires_grad=*/false); register_parameter("k_proj_weight", {}, /*requires_grad=*/false); register_parameter("v_proj_weight", {}, /*requires_grad=*/false); } if (options.bias()) { in_proj_bias = register_parameter( - "in_proj_bias", torch::empty(3 * options.embed_dim())); + "in_proj_bias", torch::empty(3 * options.embed_dim())); } else { register_parameter("in_proj_bias", {}, /*requires_grad=*/false); } out_proj = register_module( - "out_proj", - Linear(LinearOptions(options.embed_dim(), - options.embed_dim()).bias(options.bias())) - ); + "out_proj", + Linear(LinearOptions(options.embed_dim(), options.embed_dim()) + .bias(options.bias()))); if (options.add_bias_kv()) { - bias_k = register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()})); - bias_v = register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()})); + bias_k = + register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()})); + bias_v = + register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()})); } else { bias_k.reset(); bias_v.reset(); diff --git a/torch/csrc/api/src/nn/modules/adaptive.cpp b/torch/csrc/api/src/nn/modules/adaptive.cpp index 6842b14550cd85..d14cbe129abbe3 100644 --- a/torch/csrc/api/src/nn/modules/adaptive.cpp +++ b/torch/csrc/api/src/nn/modules/adaptive.cpp @@ -10,9 +10,11 @@ using namespace torch::indexing; namespace torch { namespace nn { -ASMoutput::ASMoutput(Tensor output_, double loss_): output(std::move(output_)), loss(loss_) {} +ASMoutput::ASMoutput(Tensor output_, double loss_) + : output(std::move(output_)), loss(loss_) {} -AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_) +AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl( + AdaptiveLogSoftmaxWithLossOptions options_) : options(std::move(options_)), shortlist_size(0), n_clusters(0), @@ -22,12 +24,17 @@ AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftma } void AdaptiveLogSoftmaxWithLossImpl::reset() { - TORCH_CHECK( std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) && - *std::min_element(options.cutoffs().begin(), options.cutoffs().end()) > 0 && - *std::max_element(options.cutoffs().begin(), options.cutoffs().end()) <= (options.n_classes() - 1) && - std::set(options.cutoffs().begin(), options.cutoffs().end()).size() == options.cutoffs().size(), - "cutoffs should be a sequence of unique, positive integers sorted in an increasing order, ", - "where each value is between 1 and n_classes-1"); + TORCH_CHECK( + std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) && + *std::min_element( + options.cutoffs().begin(), options.cutoffs().end()) > 0 && + *std::max_element( + options.cutoffs().begin(), options.cutoffs().end()) <= + (options.n_classes() - 1) && + std::set(options.cutoffs().begin(), options.cutoffs().end()) + .size() == options.cutoffs().size(), + "cutoffs should be a sequence of unique, positive integers sorted in an increasing order, ", + "where each value is between 1 and n_classes-1"); cutoffs = options.cutoffs(); cutoffs.push_back(options.n_classes()); @@ -36,11 +43,15 @@ void AdaptiveLogSoftmaxWithLossImpl::reset() { n_clusters = cutoffs.size() - 1; head_size = shortlist_size + n_clusters; - head = this->register_module("head", Linear(LinearOptions(options.in_features(), head_size).bias(options.head_bias()))); + head = this->register_module( + "head", + Linear(LinearOptions(options.in_features(), head_size) + .bias(options.head_bias()))); tail = this->register_module("tail", ModuleList()); - for(const auto i : c10::irange(n_clusters)) { - int64_t hsz = options.in_features() / static_cast(std::pow(options.div_value(), (i + 1))); + for (const auto i : c10::irange(n_clusters)) { + int64_t hsz = options.in_features() / + static_cast(std::pow(options.div_value(), (i + 1))); int64_t osz = cutoffs[i + 1] - cutoffs[i]; Sequential projection( @@ -60,25 +71,27 @@ void AdaptiveLogSoftmaxWithLossImpl::reset_parameters() { } } -ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input_, const Tensor& target_) { +ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward( + const Tensor& input_, + const Tensor& target_) { auto targ_dim = target_.dim(); TORCH_CHECK( - targ_dim == 1 || targ_dim == 0, - "0D or 1D target tensor expected, multi-target not supported"); + targ_dim == 1 || targ_dim == 0, + "0D or 1D target tensor expected, multi-target not supported"); if (targ_dim == 1) { - TORCH_CHECK( - input_.dim() == 2, - "1D target tensor expects 2D input tensors, but found inputs with sizes ", - input_.sizes(), - "."); + TORCH_CHECK( + input_.dim() == 2, + "1D target tensor expects 2D input tensors, but found inputs with sizes ", + input_.sizes(), + "."); } else { TORCH_CHECK( - input_.dim() == 1, - "0D target tensor expects 1D input tensors, but found inputs with sizes ", - input_.sizes(), - "."); + input_.dim() == 1, + "0D target tensor expects 1D input tensors, but found inputs with sizes ", + input_.sizes(), + "."); } bool is_batched = (targ_dim > 0); @@ -111,13 +124,15 @@ ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input_, const Te Tensor relative_target = target.index({target_mask}) - low_idx; Tensor input_subset = input.index_select(0, row_indices); - const Tensor cluster_output = tail[i - 1]->as()->forward(input_subset); + const Tensor cluster_output = + tail[i - 1]->as()->forward(input_subset); int64_t cluster_index = shortlist_size + i - 1; gather_inds.index_fill_(0, row_indices, cluster_index); const Tensor cluster_logprob = F::log_softmax(cluster_output, 1); - const Tensor local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)); + const Tensor local_logprob = + cluster_logprob.gather(1, relative_target.unsqueeze(1)); output.index_copy_(0, row_indices, local_logprob.squeeze(1)); } @@ -125,10 +140,16 @@ ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input_, const Te } TORCH_CHECK( - used_rows == batch_size, - "Target values should be in [0, ", options.n_classes() - 1, "], " - "but values in range [", target.min().item().toDouble(), ", ", target.max().item().toDouble(), "] " - "were found. "); + used_rows == batch_size, + "Target values should be in [0, ", + options.n_classes() - 1, + "], " + "but values in range [", + target.min().item().toDouble(), + ", ", + target.max().item().toDouble(), + "] " + "were found. "); const Tensor head_output = head(input); const Tensor head_logprob = F::log_softmax(head_output, 1); @@ -142,25 +163,32 @@ ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(const Tensor& input_, const Te return ASMoutput(output, loss); } -Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob(const Tensor& input, const Tensor& head_output) { +Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob( + const Tensor& input, + const Tensor& head_output) { Tensor out = input.new_empty({head_output.size(0), options.n_classes()}); const Tensor head_logprob = F::log_softmax(head_output, 1); - out.index_put_({Slice(), Slice(None, shortlist_size)}, head_logprob.index({Slice(), Slice(None, shortlist_size)})); + out.index_put_( + {Slice(), Slice(None, shortlist_size)}, + head_logprob.index({Slice(), Slice(None, shortlist_size)})); for (const auto i : c10::irange(cutoffs.size() - 1)) { int64_t start_idx = cutoffs[i]; - int64_t stop_idx = cutoffs[i+1]; + int64_t stop_idx = cutoffs[i + 1]; const Tensor cluster_output = tail[i]->as()->forward(input); const Tensor cluster_logprob = F::log_softmax(cluster_output, 1); - auto output_logprob = cluster_logprob + head_logprob.index({Slice(), static_cast(shortlist_size + i)}).unsqueeze(1); + auto output_logprob = cluster_logprob + + head_logprob.index({Slice(), static_cast(shortlist_size + i)}) + .unsqueeze(1); out.index_put_({Slice(), Slice(start_idx, stop_idx)}, output_logprob); } return out; } -Tensor AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl::log_prob(const Tensor& input) { +Tensor AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl::log_prob( + const Tensor& input) { const Tensor head_output = head(input); return _get_full_log_prob(input, head_output); } @@ -178,8 +206,7 @@ Tensor AdaptiveLogSoftmaxWithLossImpl::predict(const Tensor& input) { return torch::argmax(log_prob, 1); } else { const Tensor log_prob = _get_full_log_prob( - input.index({not_in_shortlist}), - head_output.index({not_in_shortlist})); + input.index({not_in_shortlist}), head_output.index({not_in_shortlist})); output.index_put_({not_in_shortlist}, torch::argmax(log_prob, 1)); return output; } diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp index be3343135c9443..8032001857ec9d 100644 --- a/torch/csrc/api/src/nn/modules/batchnorm.cpp +++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp @@ -18,39 +18,39 @@ namespace nn { template void BatchNormImplBase::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::BatchNorm" << D << "d(" + stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d(" << this->options.num_features() << ", " << "eps=" << this->options.eps() << ", " << "momentum="; if (this->options.momentum().has_value()) { - stream << this->options.momentum().value(); + stream << this->options.momentum().value(); } else { - stream << "None"; + stream << "None"; } - stream << ", " + stream << ", " << "affine=" << this->options.affine() << ", " - << "track_running_stats=" << this->options.track_running_stats() << ")"; + << "track_running_stats=" << this->options.track_running_stats() + << ")"; } void BatchNorm1dImpl::_check_input_dim(const Tensor& input) { TORCH_CHECK( input.dim() == 2 || input.dim() == 3, - "expected 2D or 3D input (got ", input.dim(), "D input)"); + "expected 2D or 3D input (got ", + input.dim(), + "D input)"); } void BatchNorm2dImpl::_check_input_dim(const Tensor& input) { TORCH_CHECK( - input.dim() == 4, - "expected 4D input (got ", input.dim(), "D input)"); + input.dim() == 4, "expected 4D input (got ", input.dim(), "D input)"); } void BatchNorm3dImpl::_check_input_dim(const Tensor& input) { TORCH_CHECK( - input.dim() == 5, - "expected 5D input (got ", input.dim(), "D input)"); + input.dim() == 5, "expected 5D input (got ", input.dim(), "D input)"); } template class BatchNormImplBase<1, BatchNorm1dImpl>; diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index b3ea8ec4a9d9f1..8f241dc71e366e 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -17,7 +17,8 @@ namespace F = torch::nn::functional; -F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(torch::nn::detail::conv_padding_mode_t conv_padding_mode) { +F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode( + torch::nn::detail::conv_padding_mode_t conv_padding_mode) { F::PadFuncOptions::mode_t pad_mode; if (c10::get_if(&conv_padding_mode)) { pad_mode = torch::kReflect; @@ -26,123 +27,135 @@ F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(torch::nn::detail } else if (c10::get_if(&conv_padding_mode)) { pad_mode = torch::kCircular; } else { - TORCH_CHECK(false, "Unsupported conv padding mode: ", torch::enumtype::get_enum_name(conv_padding_mode)); + TORCH_CHECK( + false, + "Unsupported conv padding mode: ", + torch::enumtype::get_enum_name(conv_padding_mode)); } return pad_mode; } namespace torch { namespace nn { -Conv1dImpl::Conv1dImpl( - Conv1dOptions options_) - : ConvNdImpl( - detail::ConvNdOptions<1>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(false) - .output_padding(0) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} +Conv1dImpl::Conv1dImpl(Conv1dOptions options_) + : ConvNdImpl(detail::ConvNdOptions<1>( + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(false) + .output_padding(0) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor Conv1dImpl::forward(const Tensor& input) { if (!c10::get_if(&options.padding_mode())) { return F::detail::conv1d( - F::pad(input, F::PadFuncOptions(_reversed_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))), - weight, bias, + F::pad( + input, + F::PadFuncOptions(_reversed_padding_repeated_twice) + .mode(_get_pad_mode_from_conv_padding_mode( + options.padding_mode()))), + weight, + bias, + options.stride(), + /*padding=*/0, + options.dilation(), + options.groups()); + } + return F::detail::conv1d( + input, + weight, + bias, options.stride(), - /*padding=*/0, + options.padding(), options.dilation(), options.groups()); - } - return F::detail::conv1d( - input, - weight, - bias, - options.stride(), - options.padding(), - options.dilation(), - options.groups()); } -Conv2dImpl::Conv2dImpl( - Conv2dOptions options_) - : ConvNdImpl( - detail::ConvNdOptions<2>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(false) - .output_padding(0) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} +Conv2dImpl::Conv2dImpl(Conv2dOptions options_) + : ConvNdImpl(detail::ConvNdOptions<2>( + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(false) + .output_padding(0) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) { if (!c10::get_if(&options.padding_mode())) { return F::detail::conv2d( - F::pad(input, F::PadFuncOptions(_reversed_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))), - weight, bias, + F::pad( + input, + F::PadFuncOptions(_reversed_padding_repeated_twice) + .mode(_get_pad_mode_from_conv_padding_mode( + options.padding_mode()))), + weight, + bias, + options.stride(), + /*padding=*/0, + options.dilation(), + options.groups()); + } + return F::detail::conv2d( + input, + weight, + bias, options.stride(), - /*padding=*/0, + options.padding(), options.dilation(), options.groups()); - } - return F::detail::conv2d( - input, - weight, - bias, - options.stride(), - options.padding(), - options.dilation(), - options.groups()); } Tensor Conv2dImpl::forward(const Tensor& input) { return _conv_forward(input, weight); } -Conv3dImpl::Conv3dImpl( - Conv3dOptions options_) - : ConvNdImpl( - detail::ConvNdOptions<3>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(false) - .output_padding(0) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} +Conv3dImpl::Conv3dImpl(Conv3dOptions options_) + : ConvNdImpl(detail::ConvNdOptions<3>( + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(false) + .output_padding(0) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor Conv3dImpl::forward(const Tensor& input) { if (!c10::get_if(&options.padding_mode())) { return F::detail::conv3d( - F::pad(input, F::PadFuncOptions(_reversed_padding_repeated_twice).mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode()))), - weight, bias, + F::pad( + input, + F::PadFuncOptions(_reversed_padding_repeated_twice) + .mode(_get_pad_mode_from_conv_padding_mode( + options.padding_mode()))), + weight, + bias, + options.stride(), + /*padding=*/0, + options.dilation(), + options.groups()); + } + return F::detail::conv3d( + input, + weight, + bias, options.stride(), - /*padding=*/0, + options.padding(), options.dilation(), options.groups()); - } - return F::detail::conv3d( - input, - weight, - bias, - options.stride(), - options.padding(), - options.dilation(), - options.groups()); } template class ConvNdImpl<1, Conv1dImpl>; @@ -153,8 +166,10 @@ template class ConvNdImpl<3, Conv3dImpl>; template std::vector ConvTransposeNdImpl::_output_padding( - const Tensor& input, const c10::optional& output_size, - const ExpandingArray& stride, const ExpandingArray& padding, + const Tensor& input, + const c10::optional& output_size, + const ExpandingArray& stride, + const ExpandingArray& padding, const ExpandingArray& kernel_size) { std::vector ret; c10::optional output_size_ = output_size; @@ -169,125 +184,163 @@ std::vector ConvTransposeNdImpl::_output_padding( } // NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (output_size_.value().size() != k) { - TORCH_CHECK(false, - "output_size must have ", k, " or ", k + 2, " elements (got ", output_size_.value().size(), ")"); + TORCH_CHECK( + false, + "output_size must have ", + k, + " or ", + k + 2, + " elements (got ", + output_size_.value().size(), + ")"); } std::vector min_sizes; std::vector max_sizes; - for(const auto d : c10::irange(k)) { - int64_t dim_size = ((input.sizes()[d + 2] - 1) * (*stride)[d] - 2 * (*padding)[d] + (*kernel_size)[d]); + for (const auto d : c10::irange(k)) { + int64_t dim_size = + ((input.sizes()[d + 2] - 1) * (*stride)[d] - 2 * (*padding)[d] + + (*kernel_size)[d]); min_sizes.push_back(dim_size); max_sizes.push_back(min_sizes[d] + (*stride)[d] - 1); } - for(const auto i : c10::irange(output_size_.value().size())) { + for (const auto i : c10::irange(output_size_.value().size())) { int64_t size = output_size_.value()[i]; int64_t min_size = min_sizes[i]; int64_t max_size = max_sizes[i]; if (size < min_size || size > max_size) { - TORCH_CHECK(false, - "requested an output size of ", output_size_.value(), ", but valid sizes range " - "from ", min_sizes, " to ", max_sizes, " (for an input of ", input.sizes().slice(2), ")"); + TORCH_CHECK( + false, + "requested an output size of ", + output_size_.value(), + ", but valid sizes range " + "from ", + min_sizes, + " to ", + max_sizes, + " (for an input of ", + input.sizes().slice(2), + ")"); } } - for(const auto d : c10::irange(k)) { + for (const auto d : c10::irange(k)) { ret.push_back(output_size_.value()[d] - min_sizes[d]); } } return ret; } -ConvTranspose1dImpl::ConvTranspose1dImpl( - ConvTranspose1dOptions options_) - : ConvTransposeNdImpl( - detail::ConvNdOptions<1>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(true) - .output_padding(options_.output_padding()) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} +ConvTranspose1dImpl::ConvTranspose1dImpl(ConvTranspose1dOptions options_) + : ConvTransposeNdImpl(detail::ConvNdOptions<1>( + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(true) + .output_padding(options_.output_padding()) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor ConvTranspose1dImpl::forward( - const Tensor& input, const c10::optional& output_size) { + const Tensor& input, + const c10::optional& output_size) { if (!c10::get_if(&options.padding_mode())) { - TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose1d"); + TORCH_CHECK( + false, "Only `zeros` padding mode is supported for ConvTranspose1d"); } - const auto & pad = padding(); + const auto& pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), pad, options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose1d( - input, weight, bias, options.stride(), pad, - output_padding, options.groups(), options.dilation()); + input, + weight, + bias, + options.stride(), + pad, + output_padding, + options.groups(), + options.dilation()); } -ConvTranspose2dImpl::ConvTranspose2dImpl( - ConvTranspose2dOptions options_) +ConvTranspose2dImpl::ConvTranspose2dImpl(ConvTranspose2dOptions options_) : ConvTransposeNdImpl(detail::ConvNdOptions<2>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(true) - .output_padding(options_.output_padding()) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(true) + .output_padding(options_.output_padding()) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor ConvTranspose2dImpl::forward( - const Tensor& input, const c10::optional& output_size) { + const Tensor& input, + const c10::optional& output_size) { if (!c10::get_if(&options.padding_mode())) { - TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose2d"); + TORCH_CHECK( + false, "Only `zeros` padding mode is supported for ConvTranspose2d"); } - const auto & pad = padding(); + const auto& pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), pad, options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose2d( - input, weight, bias, options.stride(), pad, - output_padding, options.groups(), options.dilation()); + input, + weight, + bias, + options.stride(), + pad, + output_padding, + options.groups(), + options.dilation()); } -ConvTranspose3dImpl::ConvTranspose3dImpl( - ConvTranspose3dOptions options_) +ConvTranspose3dImpl::ConvTranspose3dImpl(ConvTranspose3dOptions options_) : ConvTransposeNdImpl(detail::ConvNdOptions<3>( - /*in_channels=*/options_.in_channels(), - /*out_channels=*/options_.out_channels(), - /*kernel_size=*/options_.kernel_size()) - .stride(options_.stride()) - .padding(options_.padding()) - .dilation(options_.dilation()) - .transposed(true) - .output_padding(options_.output_padding()) - .groups(options_.groups()) - .bias(options_.bias()) - .padding_mode(options_.padding_mode())) {} + /*in_channels=*/options_.in_channels(), + /*out_channels=*/options_.out_channels(), + /*kernel_size=*/options_.kernel_size()) + .stride(options_.stride()) + .padding(options_.padding()) + .dilation(options_.dilation()) + .transposed(true) + .output_padding(options_.output_padding()) + .groups(options_.groups()) + .bias(options_.bias()) + .padding_mode(options_.padding_mode())) {} Tensor ConvTranspose3dImpl::forward( - const Tensor& input, const c10::optional& output_size) { + const Tensor& input, + const c10::optional& output_size) { if (!c10::get_if(&options.padding_mode())) { - TORCH_CHECK(false, "Only `zeros` padding mode is supported for ConvTranspose3d"); + TORCH_CHECK( + false, "Only `zeros` padding mode is supported for ConvTranspose3d"); } - const auto & pad = padding(); + const auto& pad = padding(); std::vector output_padding = _output_padding( - input, output_size, options.stride(), pad, options.kernel_size()); + input, output_size, options.stride(), pad, options.kernel_size()); return F::detail::conv_transpose3d( - input, weight, bias, options.stride(), pad, - output_padding, options.groups(), options.dilation()); + input, + weight, + bias, + options.stride(), + pad, + output_padding, + options.groups(), + options.dilation()); } template class ConvTransposeNdImpl<1, ConvTranspose1dImpl>; diff --git a/torch/csrc/api/src/nn/modules/distance.cpp b/torch/csrc/api/src/nn/modules/distance.cpp index bd8502c3b66a44..d62d608977c1d7 100644 --- a/torch/csrc/api/src/nn/modules/distance.cpp +++ b/torch/csrc/api/src/nn/modules/distance.cpp @@ -5,16 +5,15 @@ namespace F = torch::nn::functional; namespace torch { namespace nn { -CosineSimilarityImpl::CosineSimilarityImpl(const CosineSimilarityOptions& options_) +CosineSimilarityImpl::CosineSimilarityImpl( + const CosineSimilarityOptions& options_) : options(options_) {} void CosineSimilarityImpl::reset() {} void CosineSimilarityImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::CosineSimilarity" - << "(dim=" << options.dim() - << ", eps=" << options.eps() << ")"; + stream << std::boolalpha << "torch::nn::CosineSimilarity" + << "(dim=" << options.dim() << ", eps=" << options.eps() << ")"; } Tensor CosineSimilarityImpl::forward(const Tensor& x1, const Tensor& x2) { @@ -23,21 +22,21 @@ Tensor CosineSimilarityImpl::forward(const Tensor& x1, const Tensor& x2) { // ============================================================================ -PairwiseDistanceImpl::PairwiseDistanceImpl(const PairwiseDistanceOptions& options_) +PairwiseDistanceImpl::PairwiseDistanceImpl( + const PairwiseDistanceOptions& options_) : options(options_) {} void PairwiseDistanceImpl::reset() {} void PairwiseDistanceImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::PairwiseDistance" - << "(p=" << options.p() - << ", eps=" << options.eps() + stream << std::boolalpha << "torch::nn::PairwiseDistance" + << "(p=" << options.p() << ", eps=" << options.eps() << ", keepdim=" << options.keepdim() << ")"; } Tensor PairwiseDistanceImpl::forward(const Tensor& x1, const Tensor& x2) { - return F::detail::pairwise_distance(x1, x2, options.p(), options.eps(), options.keepdim()); + return F::detail::pairwise_distance( + x1, x2, options.p(), options.eps(), options.keepdim()); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp index 12adb87d04edfb..2e826869db08ce 100644 --- a/torch/csrc/api/src/nn/modules/dropout.cpp +++ b/torch/csrc/api/src/nn/modules/dropout.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include @@ -15,63 +15,60 @@ namespace torch { namespace nn { Tensor DropoutImpl::forward(Tensor input) { - return F::detail::dropout(input, options.p(), - is_training(), options.inplace()); + return F::detail::dropout( + input, options.p(), is_training(), options.inplace()); } void DropoutImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::Dropout(p=" << options.p() + stream << std::boolalpha << "torch::nn::Dropout(p=" << options.p() << ", inplace=" << options.inplace() << ")"; } // ============================================================================ Tensor Dropout2dImpl::forward(Tensor input) { - return F::detail::dropout2d(input, options.p(), - is_training(), options.inplace()); + return F::detail::dropout2d( + input, options.p(), is_training(), options.inplace()); } void Dropout2dImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::Dropout2d(p=" << options.p() + stream << std::boolalpha << "torch::nn::Dropout2d(p=" << options.p() << ", inplace=" << options.inplace() << ")"; } // ============================================================================ Tensor Dropout3dImpl::forward(Tensor input) { - return F::detail::dropout3d(input, options.p(), - is_training(), options.inplace()); + return F::detail::dropout3d( + input, options.p(), is_training(), options.inplace()); } void Dropout3dImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::Dropout3d(p=" << options.p() + stream << std::boolalpha << "torch::nn::Dropout3d(p=" << options.p() << ", inplace=" << options.inplace() << ")"; } // ============================================================================ Tensor AlphaDropoutImpl::forward(const Tensor& input) { - return F::detail::alpha_dropout(input, options.p(), is_training(), /*inplace=*/false); + return F::detail::alpha_dropout( + input, options.p(), is_training(), /*inplace=*/false); } void AlphaDropoutImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::AlphaDropout(p=" << options.p() + stream << std::boolalpha << "torch::nn::AlphaDropout(p=" << options.p() << ", inplace=" << options.inplace() << ")"; } // ============================================================================ Tensor FeatureAlphaDropoutImpl::forward(const Tensor& input) { - return F::detail::feature_alpha_dropout(input, options.p(), is_training(), /*inplace=*/false); + return F::detail::feature_alpha_dropout( + input, options.p(), is_training(), /*inplace=*/false); } void FeatureAlphaDropoutImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::FeatureAlphaDropout(p=" << options.p() + stream << std::boolalpha << "torch::nn::FeatureAlphaDropout(p=" << options.p() << ", inplace=" << options.inplace() << ")"; } diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index fec2d7a0f2466a..5354cef4862557 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -1,8 +1,8 @@ #include +#include #include #include -#include #include #include @@ -13,7 +13,8 @@ namespace F = torch::nn::functional; namespace torch { namespace nn { -EmbeddingImpl::EmbeddingImpl(const EmbeddingOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) +EmbeddingImpl::EmbeddingImpl(const EmbeddingOptions& options_) + : options(options_) { // NOLINT(modernize-pass-by-value) // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -21,20 +22,28 @@ EmbeddingImpl::EmbeddingImpl(const EmbeddingOptions& options_) : options(options void EmbeddingImpl::reset() { if (options.padding_idx() != c10::nullopt) { if (*options.padding_idx() > 0) { - TORCH_CHECK(*options.padding_idx() < options.num_embeddings(), "Padding_idx must be within num_embeddings"); - } - else if (*options.padding_idx() < 0) { - TORCH_CHECK(*options.padding_idx() >= -options.num_embeddings(), "Padding_idx must be within num_embedding"); + TORCH_CHECK( + *options.padding_idx() < options.num_embeddings(), + "Padding_idx must be within num_embeddings"); + } else if (*options.padding_idx() < 0) { + TORCH_CHECK( + *options.padding_idx() >= -options.num_embeddings(), + "Padding_idx must be within num_embedding"); options.padding_idx(options.num_embeddings() + *options.padding_idx()); } } if (!options._weight().defined()) { weight = register_parameter( - "weight", torch::empty({options.num_embeddings(), options.embedding_dim()})); + "weight", + torch::empty({options.num_embeddings(), options.embedding_dim()})); reset_parameters(); } else { - TORCH_CHECK(options._weight().sizes() == torch::IntArrayRef({options.num_embeddings(), options.embedding_dim()}), "Shape of _weight does not match num_embeddings and embedding_dim"); + TORCH_CHECK( + options._weight().sizes() == + torch::IntArrayRef( + {options.num_embeddings(), options.embedding_dim()}), + "Shape of _weight does not match num_embeddings and embedding_dim"); weight = register_parameter("weight", options._weight()); } } @@ -60,7 +69,8 @@ void EmbeddingImpl::pretty_print(std::ostream& stream) const { stream << ", norm_type=" << options.norm_type(); } if (options.scale_grad_by_freq()) { - stream << ", scale_grad_by_freq=" << std::boolalpha << options.scale_grad_by_freq(); + stream << ", scale_grad_by_freq=" << std::boolalpha + << options.scale_grad_by_freq(); } if (options.sparse()) { stream << ", sparse=" << std::boolalpha << options.sparse(); @@ -70,16 +80,17 @@ void EmbeddingImpl::pretty_print(std::ostream& stream) const { torch::Tensor EmbeddingImpl::forward(const Tensor& input) { return F::detail::embedding( - input, - weight, - options.padding_idx(), - options.max_norm(), - options.norm_type(), - options.scale_grad_by_freq(), - options.sparse()); + input, + weight, + options.padding_idx(), + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.sparse()); } -EmbeddingBagImpl::EmbeddingBagImpl(const EmbeddingBagOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) +EmbeddingBagImpl::EmbeddingBagImpl(const EmbeddingBagOptions& options_) + : options(options_) { // NOLINT(modernize-pass-by-value) // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -88,21 +99,27 @@ void EmbeddingBagImpl::reset() { if (options.padding_idx().has_value()) { auto padding_idx = options.padding_idx().value(); if (padding_idx > 0) { - TORCH_CHECK(padding_idx < options.num_embeddings(), "Padding_idx must be within num_embeddings"); - } - else if (padding_idx < 0) { - TORCH_CHECK(padding_idx >= -options.num_embeddings(), "Padding_idx must be within num_embedding"); + TORCH_CHECK( + padding_idx < options.num_embeddings(), + "Padding_idx must be within num_embeddings"); + } else if (padding_idx < 0) { + TORCH_CHECK( + padding_idx >= -options.num_embeddings(), + "Padding_idx must be within num_embedding"); options.padding_idx(options.num_embeddings() + padding_idx); } } if (!options._weight().defined()) { weight = register_parameter( - "weight", torch::empty({options.num_embeddings(), options.embedding_dim()})); + "weight", + torch::empty({options.num_embeddings(), options.embedding_dim()})); reset_parameters(); } else { TORCH_CHECK( - options._weight().sizes() == torch::IntArrayRef({options.num_embeddings(), options.embedding_dim()}), - "Shape of weight does not match num_embeddings and embedding_dim"); + options._weight().sizes() == + torch::IntArrayRef( + {options.num_embeddings(), options.embedding_dim()}), + "Shape of weight does not match num_embeddings and embedding_dim"); weight = register_parameter("weight", options._weight()); } } @@ -115,24 +132,28 @@ void EmbeddingBagImpl::reset_parameters() { torch::nn::init::normal_(weight); } -torch::Tensor EmbeddingBagImpl::forward(const Tensor& input, const Tensor& offsets, const Tensor& per_sample_weights) { +torch::Tensor EmbeddingBagImpl::forward( + const Tensor& input, + const Tensor& offsets, + const Tensor& per_sample_weights) { return F::detail::embedding_bag( - input, - weight, - offsets, - options.max_norm(), - options.norm_type(), - options.scale_grad_by_freq(), - options.mode(), - options.sparse(), - per_sample_weights, - options.include_last_offset(), - options.padding_idx()); + input, + weight, + offsets, + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.mode(), + options.sparse(), + per_sample_weights, + options.include_last_offset(), + options.padding_idx()); } void EmbeddingBagImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::EmbeddingBag(num_embeddings=" << options.num_embeddings() - << ", embedding_dim=" << options.embedding_dim(); + stream << "torch::nn::EmbeddingBag(num_embeddings=" + << options.num_embeddings() + << ", embedding_dim=" << options.embedding_dim(); if (options.max_norm() != c10::nullopt) { stream << ", max_norm=" << *options.max_norm(); } @@ -140,16 +161,18 @@ void EmbeddingBagImpl::pretty_print(std::ostream& stream) const { stream << ", norm_type=" << options.norm_type(); } if (options.scale_grad_by_freq()) { - stream << ", scale_grad_by_freq=" << std::boolalpha << options.scale_grad_by_freq(); + stream << ", scale_grad_by_freq=" << std::boolalpha + << options.scale_grad_by_freq(); } if (options.sparse()) { stream << ", sparse=" << std::boolalpha << options.sparse(); } if (!c10::get_if(&options.mode())) { - stream << ", mode=" << torch::enumtype::get_enum_name(options.mode()); + stream << ", mode=" << torch::enumtype::get_enum_name(options.mode()); } if (options.include_last_offset()) { - stream << ", include_last_offset=" << std::boolalpha << options.include_last_offset(); + stream << ", include_last_offset=" << std::boolalpha + << options.include_last_offset(); } if (options.padding_idx().has_value()) { stream << ", padding_idx=" << options.padding_idx().value(); diff --git a/torch/csrc/api/src/nn/modules/fold.cpp b/torch/csrc/api/src/nn/modules/fold.cpp index 0dc30274909177..20dadc4e62f152 100644 --- a/torch/csrc/api/src/nn/modules/fold.cpp +++ b/torch/csrc/api/src/nn/modules/fold.cpp @@ -17,19 +17,18 @@ void FoldImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Fold(output_size=" << options.output_size() << ", kernel_size=" << options.kernel_size() << ", dilation=" << options.dilation() - << ", padding=" << options.padding() - << ", stride=" << options.stride() + << ", padding=" << options.padding() << ", stride=" << options.stride() << ")"; } Tensor FoldImpl::forward(const Tensor& input) { return F::detail::fold( - input, - options.output_size(), - options.kernel_size(), - options.dilation(), - options.padding(), - options.stride()); + input, + options.output_size(), + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); } // ============================================================================ @@ -41,13 +40,17 @@ void UnfoldImpl::reset() {} void UnfoldImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Unfold(kernel_size=" << options.kernel_size() << ", dilation=" << options.dilation() - << ", padding=" << options.padding() - << ", stride=" << options.stride() + << ", padding=" << options.padding() << ", stride=" << options.stride() << ")"; } Tensor UnfoldImpl::forward(const Tensor& input) { - return F::detail::unfold(input, options.kernel_size(), options.dilation(), options.padding(), options.stride()); + return F::detail::unfold( + input, + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/instancenorm.cpp b/torch/csrc/api/src/nn/modules/instancenorm.cpp index a63bf0810fe7cb..a7eb31882e7d4c 100644 --- a/torch/csrc/api/src/nn/modules/instancenorm.cpp +++ b/torch/csrc/api/src/nn/modules/instancenorm.cpp @@ -8,36 +8,34 @@ namespace nn { template void InstanceNormImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::InstanceNorm" << D << "d(" + stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d(" << this->options.num_features() << ", " << "eps=" << this->options.eps() << ", " << "momentum=" << this->options.momentum() << ", " << "affine=" << this->options.affine() << ", " - << "track_running_stats=" << this->options.track_running_stats() << ")"; + << "track_running_stats=" << this->options.track_running_stats() + << ")"; } void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) { if (input.dim() != 3 && input.dim() != 2) { TORCH_CHECK( - false, - "expected 2D or 3D input (got ", input.dim(), "D input)"); + false, "expected 2D or 3D input (got ", input.dim(), "D input)"); } } void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) { if (input.dim() != 4 && input.dim() != 3) { TORCH_CHECK( - false, - "expected 3D or 4D input (got ", input.dim(), "D input)"); + false, "expected 3D or 4D input (got ", input.dim(), "D input)"); } } void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) { - if (input.dim() != 5 && input.dim() != 4) { // NOLINT(cppcoreguidelines-avoid-magic-numbers) + if (input.dim() != 5 && + input.dim() != 4) { // NOLINT(cppcoreguidelines-avoid-magic-numbers) TORCH_CHECK( - false, - "expected 4D or 5D input (got ", input.dim(), "D input)"); + false, "expected 4D or 5D input (got ", input.dim(), "D input)"); } } diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index d6f16077c4a4f1..56ff49eff189c1 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -31,8 +31,8 @@ LinearImpl::LinearImpl(const LinearOptions& options_) : options(options_) { } void LinearImpl::reset() { - weight = register_parameter("weight", - torch::empty({options.out_features(), options.in_features()})); + weight = register_parameter( + "weight", torch::empty({options.out_features(), options.in_features()})); if (options.bias()) { bias = register_parameter("bias", torch::empty(options.out_features())); } else { @@ -43,12 +43,13 @@ void LinearImpl::reset() { } void LinearImpl::reset_parameters() { - torch::nn::init::kaiming_uniform_(weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) + torch::nn::init::kaiming_uniform_( + weight, std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) if (bias.defined()) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t fan_in, fan_out; std::tie(fan_in, fan_out) = - torch::nn::init::_calculate_fan_in_and_fan_out(weight); + torch::nn::init::_calculate_fan_in_and_fan_out(weight); const auto bound = 1 / std::sqrt(fan_in); torch::nn::init::uniform_(bias, -bound, bound); } @@ -82,21 +83,26 @@ Tensor FlattenImpl::forward(const Tensor& input) { // ============================================================================ -UnflattenImpl::UnflattenImpl(UnflattenOptions options_) : options(std::move(options_)) {} +UnflattenImpl::UnflattenImpl(UnflattenOptions options_) + : options(std::move(options_)) {} void UnflattenImpl::reset() {} void UnflattenImpl::pretty_print(std::ostream& stream) const { auto namedshape = options.namedshape(); if (!namedshape.empty()) { - stream << "torch::nn::Unflatten(dim=\"" << options.dimname() << "\", unflattened_size={"; + stream << "torch::nn::Unflatten(dim=\"" << options.dimname() + << "\", unflattened_size={"; size_t i = 0; for (; i < namedshape.size() - 1; ++i) { - stream << "{\"" << std::get<0>(namedshape[i]) << "\", " << std::get<1>(namedshape[i]) << "}, "; + stream << "{\"" << std::get<0>(namedshape[i]) << "\", " + << std::get<1>(namedshape[i]) << "}, "; } - stream << "{\"" << std::get<0>(namedshape[i]) << "\", " << std::get<1>(namedshape[i]) << "}})"; + stream << "{\"" << std::get<0>(namedshape[i]) << "\", " + << std::get<1>(namedshape[i]) << "}})"; } else { - stream << "torch::nn::Unflatten(dim=" << options.dim() << ", unflattened_size={"; + stream << "torch::nn::Unflatten(dim=" << options.dim() + << ", unflattened_size={"; auto sizes = options.sizes(); size_t i = 0; for (; i < sizes.size() - 1; ++i) { @@ -109,11 +115,13 @@ void UnflattenImpl::pretty_print(std::ostream& stream) const { Tensor UnflattenImpl::forward(const Tensor& input) { auto namedshape = options.namedshape(); if (!namedshape.empty()) { - auto dimname = torch::Dimname::fromSymbol(torch::Symbol::dimname(options.dimname())); + auto dimname = + torch::Dimname::fromSymbol(torch::Symbol::dimname(options.dimname())); std::vector sizes; std::vector names; for (auto i : namedshape) { - names.push_back(torch::Dimname::fromSymbol(torch::Symbol::dimname(std::get<0>(i)))); + names.push_back( + torch::Dimname::fromSymbol(torch::Symbol::dimname(std::get<0>(i)))); sizes.push_back(std::get<1>(i)); } return input.unflatten(dimname, sizes, names); @@ -123,14 +131,19 @@ Tensor UnflattenImpl::forward(const Tensor& input) { // ============================================================================ -BilinearImpl::BilinearImpl(const BilinearOptions& options_) : options(options_) { +BilinearImpl::BilinearImpl(const BilinearOptions& options_) + : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void BilinearImpl::reset() { - weight = - register_parameter("weight", torch::empty({options.out_features(), options.in1_features(), options.in2_features()})); + weight = register_parameter( + "weight", + torch::empty( + {options.out_features(), + options.in1_features(), + options.in2_features()})); if (options.bias()) { bias = register_parameter("bias", torch::empty(options.out_features())); } else { @@ -144,14 +157,16 @@ void BilinearImpl::reset_parameters() { const auto bound = 1.0 / std::sqrt(weight.size(1)); init::uniform_(weight, -bound, bound); if (bias.defined()) { - init::uniform_(bias, -bound, bound); + init::uniform_(bias, -bound, bound); } } void BilinearImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha << "torch::nn::Bilinear(in1_features=" << options.in1_features() - << ", in2_features=" << options.in2_features() << ", out_features=" << options.out_features() << ", bias=" << options.bias() - << ")"; + stream << std::boolalpha + << "torch::nn::Bilinear(in1_features=" << options.in1_features() + << ", in2_features=" << options.in2_features() + << ", out_features=" << options.out_features() + << ", bias=" << options.bias() << ")"; } Tensor BilinearImpl::forward(const Tensor& input1, const Tensor& input2) { diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp index dda67fe9c728e3..3e4ecca31d8408 100644 --- a/torch/csrc/api/src/nn/modules/loss.cpp +++ b/torch/csrc/api/src/nn/modules/loss.cpp @@ -29,7 +29,8 @@ void KLDivLossImpl::pretty_print(std::ostream& stream) const { } Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) { - return F::detail::kl_div(input, target, options.reduction(), options.log_target()); + return F::detail::kl_div( + input, target, options.reduction(), options.log_target()); } // ============================================================================ @@ -48,7 +49,8 @@ Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) { // ============================================================================ -BCELossImpl::BCELossImpl(const BCELossOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) +BCELossImpl::BCELossImpl(const BCELossOptions& options_) + : options(options_) { // NOLINT(modernize-pass-by-value) // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -62,7 +64,8 @@ void BCELossImpl::pretty_print(std::ostream& stream) const { } Tensor BCELossImpl::forward(const Tensor& input, const Tensor& target) { - return F::detail::binary_cross_entropy(input, target, options.weight(), options.reduction()); + return F::detail::binary_cross_entropy( + input, target, options.weight(), options.reduction()); } // ============================================================================ @@ -80,7 +83,8 @@ void HingeEmbeddingLossImpl::pretty_print(std::ostream& stream) const { Tensor HingeEmbeddingLossImpl::forward( const Tensor& input, const Tensor& target) { - return F::detail::hinge_embedding_loss(input, target, options.margin(), options.reduction()); + return F::detail::hinge_embedding_loss( + input, target, options.margin(), options.reduction()); } // ============================================================================ @@ -104,11 +108,18 @@ void MultiMarginLossImpl::reset() { void MultiMarginLossImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::MultiMarginLoss(p=" << options.p() << ", margin=" << options.margin() << ", weight=" << options.weight() - << ", reduction=" << enumtype::get_enum_name(options.reduction()) << ")"; + << ", reduction=" << enumtype::get_enum_name(options.reduction()) + << ")"; } Tensor MultiMarginLossImpl::forward(const Tensor& input, const Tensor& target) { - return F::detail::multi_margin_loss(input, target, options.p(), options.margin(), options.weight(), options.reduction()); + return F::detail::multi_margin_loss( + input, + target, + options.p(), + options.margin(), + options.weight(), + options.reduction()); } // ============================================================================ @@ -127,12 +138,14 @@ Tensor CosineEmbeddingLossImpl::forward( const Tensor& input1, const Tensor& input2, const Tensor& target) { - return F::detail::cosine_embedding_loss(input1, input2, target, options.margin(), options.reduction()); + return F::detail::cosine_embedding_loss( + input1, input2, target, options.margin(), options.reduction()); } // ============================================================================ MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl( - const torch::nn::MultiLabelSoftMarginLossOptions& options_) // NOLINT(modernize-pass-by-value) + const torch::nn::MultiLabelSoftMarginLossOptions& + options_) // NOLINT(modernize-pass-by-value) : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); @@ -149,7 +162,8 @@ void MultiLabelSoftMarginLossImpl::reset() { Tensor MultiLabelSoftMarginLossImpl::forward( const Tensor& input, const Tensor& target) { - return F::detail::multilabel_soft_margin_loss(input, target, options.weight(), options.reduction()); + return F::detail::multilabel_soft_margin_loss( + input, target, options.weight(), options.reduction()); } // ============================================================================ @@ -171,14 +185,14 @@ Tensor TripletMarginLossImpl::forward( const Tensor& positive, const Tensor& negative) { return F::detail::triplet_margin_loss( - anchor, - positive, - negative, - options.margin(), - options.p(), - options.eps(), - options.swap(), - options.reduction()); + anchor, + positive, + negative, + options.margin(), + options.p(), + options.eps(), + options.swap(), + options.reduction()); } // ============================================================================ @@ -189,9 +203,11 @@ TripletMarginWithDistanceLossImpl::TripletMarginWithDistanceLossImpl( void TripletMarginWithDistanceLossImpl::reset() {} -void TripletMarginWithDistanceLossImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::TripletMarginWithDistanceLoss(margin=" << options.margin() - << std::boolalpha << ", swap=" << options.swap() << ")"; +void TripletMarginWithDistanceLossImpl::pretty_print( + std::ostream& stream) const { + stream << "torch::nn::TripletMarginWithDistanceLoss(margin=" + << options.margin() << std::boolalpha << ", swap=" << options.swap() + << ")"; } Tensor TripletMarginWithDistanceLossImpl::forward( @@ -199,13 +215,13 @@ Tensor TripletMarginWithDistanceLossImpl::forward( const Tensor& positive, const Tensor& negative) { return F::detail::triplet_margin_with_distance_loss( - anchor, - positive, - negative, - options.distance_function(), - options.margin(), - options.swap(), - options.reduction()); + anchor, + positive, + negative, + options.distance_function(), + options.margin(), + options.swap(), + options.reduction()); } // ============================================================================ @@ -220,14 +236,17 @@ void MultiLabelMarginLossImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::MultiLabelMarginLoss()"; } -Tensor MultiLabelMarginLossImpl::forward(const Tensor& input, const Tensor& target) { +Tensor MultiLabelMarginLossImpl::forward( + const Tensor& input, + const Tensor& target) { return F::detail::multilabel_margin_loss(input, target, options.reduction()); } // ============================================================================ SoftMarginLossImpl::SoftMarginLossImpl( - const torch::nn::SoftMarginLossOptions& options_) : options(options_) {} + const torch::nn::SoftMarginLossOptions& options_) + : options(options_) {} void SoftMarginLossImpl::reset() {} @@ -242,7 +261,8 @@ Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) { // ============================================================================ SmoothL1LossImpl::SmoothL1LossImpl( - const torch::nn::SmoothL1LossOptions& options_) : options(options_) {} + const torch::nn::SmoothL1LossOptions& options_) + : options(options_) {} void SmoothL1LossImpl::reset() {} @@ -251,13 +271,14 @@ void SmoothL1LossImpl::pretty_print(std::ostream& stream) const { } Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) { - return F::detail::smooth_l1_loss(input, target, options.reduction(), options.beta()); + return F::detail::smooth_l1_loss( + input, target, options.reduction(), options.beta()); } // ============================================================================ -HuberLossImpl::HuberLossImpl( - const torch::nn::HuberLossOptions& options_) : options(options_) {} +HuberLossImpl::HuberLossImpl(const torch::nn::HuberLossOptions& options_) + : options(options_) {} void HuberLossImpl::reset() {} @@ -266,7 +287,8 @@ void HuberLossImpl::pretty_print(std::ostream& stream) const { } Tensor HuberLossImpl::forward(const Tensor& input, const Tensor& target) { - return F::detail::huber_loss(input, target, options.reduction(), options.delta()); + return F::detail::huber_loss( + input, target, options.reduction(), options.delta()); } // ============================================================================ @@ -279,22 +301,25 @@ void CTCLossImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::CTCLoss()"; } -Tensor CTCLossImpl::forward(const Tensor& log_probs, const Tensor& targets, - const Tensor& input_lengths, const Tensor& target_lengths) { +Tensor CTCLossImpl::forward( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths) { return F::detail::ctc_loss( - log_probs, - targets, - input_lengths, - target_lengths, - options.blank(), - options.reduction(), - options.zero_infinity()); + log_probs, + targets, + input_lengths, + target_lengths, + options.blank(), + options.reduction(), + options.zero_infinity()); } // ============================================================================ PoissonNLLLossImpl::PoissonNLLLossImpl(const PoissonNLLLossOptions& options_) - : options(options_) {} + : options(options_) {} void PoissonNLLLossImpl::reset() {} @@ -303,16 +328,22 @@ void PoissonNLLLossImpl::pretty_print(std::ostream& stream) const { } Tensor PoissonNLLLossImpl::forward( - const Tensor& log_input, const Tensor& target) { + const Tensor& log_input, + const Tensor& target) { return F::detail::poisson_nll_loss( - log_input, target, - options.log_input(), options.full(), options.eps(), options.reduction()); + log_input, + target, + options.log_input(), + options.full(), + options.eps(), + options.reduction()); } // ============================================================================ MarginRankingLossImpl::MarginRankingLossImpl( - const MarginRankingLossOptions& options_) : options(options_) {} + const MarginRankingLossOptions& options_) + : options(options_) {} void MarginRankingLossImpl::reset() {} @@ -320,9 +351,12 @@ void MarginRankingLossImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::MarginRankingLoss()"; } -Tensor MarginRankingLossImpl::forward(const Tensor& input1, - const Tensor& input2, const Tensor& target) { - return F::detail::margin_ranking_loss(input1, input2, target, options.margin(), options.reduction()); +Tensor MarginRankingLossImpl::forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& target) { + return F::detail::margin_ranking_loss( + input1, input2, target, options.margin(), options.reduction()); } // ============================================================================ @@ -342,15 +376,9 @@ void NLLLossImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::NLLLoss()"; } -Tensor NLLLossImpl::forward( - const Tensor& input, - const Tensor& target) { +Tensor NLLLossImpl::forward(const Tensor& input, const Tensor& target) { return F::detail::nll_loss( - input, - target, - weight, - options.ignore_index(), - options.reduction()); + input, target, weight, options.ignore_index(), options.reduction()); } // ============================================================================ @@ -374,19 +402,20 @@ Tensor CrossEntropyLossImpl::forward( const Tensor& input, const Tensor& target) { return F::detail::cross_entropy( - input, - target, - weight, - options.ignore_index(), - options.reduction(), - options.label_smoothing()); + input, + target, + weight, + options.ignore_index(), + options.reduction(), + options.label_smoothing()); } // ============================================================================ BCEWithLogitsLossImpl::BCEWithLogitsLossImpl( - // NOLINTNEXTLINE(modernize-pass-by-value) - const BCEWithLogitsLossOptions& options_) : options(options_) { + // NOLINTNEXTLINE(modernize-pass-by-value) + const BCEWithLogitsLossOptions& options_) + : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -401,9 +430,14 @@ void BCEWithLogitsLossImpl::pretty_print(std::ostream& stream) const { } Tensor BCEWithLogitsLossImpl::forward( - const Tensor& input, const Tensor& target) { - return F::detail::binary_cross_entropy_with_logits(input, target, - options.weight(), options.reduction(), options.pos_weight()); + const Tensor& input, + const Tensor& target) { + return F::detail::binary_cross_entropy_with_logits( + input, + target, + options.weight(), + options.reduction(), + options.pos_weight()); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/normalization.cpp b/torch/csrc/api/src/nn/modules/normalization.cpp index b2e30df675e50b..e64433b5a66549 100644 --- a/torch/csrc/api/src/nn/modules/normalization.cpp +++ b/torch/csrc/api/src/nn/modules/normalization.cpp @@ -1,8 +1,8 @@ #include #include -#include #include +#include #include #include @@ -12,17 +12,20 @@ namespace F = torch::nn::functional; namespace torch { namespace nn { -LayerNormImpl::LayerNormImpl(const LayerNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) +LayerNormImpl::LayerNormImpl(const LayerNormOptions& options_) + : options(options_) { // NOLINT(modernize-pass-by-value) // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void LayerNormImpl::reset() { if (options.elementwise_affine()) { - weight = register_parameter("weight", torch::empty(options.normalized_shape())); + weight = + register_parameter("weight", torch::empty(options.normalized_shape())); bias = register_parameter("bias", torch::empty(options.normalized_shape())); } else { - weight = register_parameter("weight", torch::Tensor(), /*requires_grad=*/false); + weight = + register_parameter("weight", torch::Tensor(), /*requires_grad=*/false); bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false); } reset_parameters(); @@ -36,35 +39,34 @@ void LayerNormImpl::reset_parameters() { } void LayerNormImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::LayerNorm(" << torch::IntArrayRef(options.normalized_shape()) + stream << std::boolalpha << "torch::nn::LayerNorm(" + << torch::IntArrayRef(options.normalized_shape()) << ", eps=" << options.eps() - << ", elementwise_affine=" << options.elementwise_affine() - << ")"; + << ", elementwise_affine=" << options.elementwise_affine() << ")"; } torch::Tensor LayerNormImpl::forward(const Tensor& input) { - return F::detail::layer_norm(input, options.normalized_shape(), weight, bias, options.eps()); + return F::detail::layer_norm( + input, options.normalized_shape(), weight, bias, options.eps()); } // ============================================================================ -LocalResponseNormImpl::LocalResponseNormImpl(const LocalResponseNormOptions& options_) +LocalResponseNormImpl::LocalResponseNormImpl( + const LocalResponseNormOptions& options_) : options(options_) {} Tensor LocalResponseNormImpl::forward(const Tensor& input) { - return F::detail::local_response_norm(input, options.size(), options.alpha(), options.beta(), options.k()); + return F::detail::local_response_norm( + input, options.size(), options.alpha(), options.beta(), options.k()); } void LocalResponseNormImpl::reset() {} void LocalResponseNormImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::LocalResponseNorm(" << options.size() - << ", alpha=" << options.alpha() - << ", beta=" << options.beta() - << ", k=" << options.k() - << ")"; + stream << std::boolalpha << "torch::nn::LocalResponseNorm(" << options.size() + << ", alpha=" << options.alpha() << ", beta=" << options.beta() + << ", k=" << options.k() << ")"; } // ============================================================================ @@ -72,12 +74,9 @@ void LocalResponseNormImpl::pretty_print(std::ostream& stream) const { void CrossMapLRN2dImpl::reset() {} void CrossMapLRN2dImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::CrossMapLRN2d(" << options.size() - << ", alpha=" << options.alpha() - << ", beta=" << options.beta() - << ", k=" << options.k() - << ")"; + stream << std::boolalpha << "torch::nn::CrossMapLRN2d(" << options.size() + << ", alpha=" << options.alpha() << ", beta=" << options.beta() + << ", k=" << options.k() << ")"; } torch::Tensor CrossMapLRN2dImpl::forward(const torch::Tensor& input) { @@ -86,7 +85,8 @@ torch::Tensor CrossMapLRN2dImpl::forward(const torch::Tensor& input) { // ============================================================================ -GroupNormImpl::GroupNormImpl(const GroupNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value) +GroupNormImpl::GroupNormImpl(const GroupNormOptions& options_) + : options(options_) { // NOLINT(modernize-pass-by-value) // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -96,7 +96,8 @@ void GroupNormImpl::reset() { weight = register_parameter("weight", torch::empty(options.num_channels())); bias = register_parameter("bias", torch::empty(options.num_channels())); } else { - weight = register_parameter("weight", torch::Tensor(), /*requires_grad=*/false); + weight = + register_parameter("weight", torch::Tensor(), /*requires_grad=*/false); bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false); } reset_parameters(); @@ -110,16 +111,14 @@ void GroupNormImpl::reset_parameters() { } torch::Tensor GroupNormImpl::forward(const Tensor& input) { - return F::detail::group_norm(input, options.num_groups(), weight, bias, options.eps()); + return F::detail::group_norm( + input, options.num_groups(), weight, bias, options.eps()); } void GroupNormImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::GroupNorm(" << options.num_groups() - << ", " << options.num_channels() - << ", eps=" << options.eps() - << ", affine=" << options.affine() - << ")"; + stream << std::boolalpha << "torch::nn::GroupNorm(" << options.num_groups() + << ", " << options.num_channels() << ", eps=" << options.eps() + << ", affine=" << options.affine() << ")"; } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/padding.cpp b/torch/csrc/api/src/nn/modules/padding.cpp index c2b3238ab6db63..033e8075aa43db 100644 --- a/torch/csrc/api/src/nn/modules/padding.cpp +++ b/torch/csrc/api/src/nn/modules/padding.cpp @@ -8,7 +8,8 @@ namespace torch { namespace nn { template -ReflectionPadImpl::ReflectionPadImpl(const ReflectionPadOptions& options_) +ReflectionPadImpl::ReflectionPadImpl( + const ReflectionPadOptions& options_) : options(options_) {} template @@ -32,7 +33,8 @@ template class ReflectionPadImpl<3, ReflectionPad3dImpl>; // ============================================================================ template -ReplicationPadImpl::ReplicationPadImpl(const ReplicationPadOptions& options_) +ReplicationPadImpl::ReplicationPadImpl( + const ReplicationPadOptions& options_) : options(options_) {} template @@ -72,7 +74,8 @@ Tensor ZeroPad2dImpl::forward(const Tensor& input) { // ============================================================================ template -ConstantPadImpl::ConstantPadImpl(const ConstantPadOptions& options_) +ConstantPadImpl::ConstantPadImpl( + const ConstantPadOptions& options_) : options(options_) {} template @@ -80,14 +83,15 @@ void ConstantPadImpl::reset() {} template Tensor ConstantPadImpl::forward(const Tensor& input) { - return F::detail::pad(input, options.padding(), torch::kConstant, options.value()); + return F::detail::pad( + input, options.padding(), torch::kConstant, options.value()); } template void ConstantPadImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::ConstantPad" << D << "d" - << "(padding=" << options.padding() - << ", value=" << options.value() << ")"; + << "(padding=" << options.padding() << ", value=" << options.value() + << ")"; } template class ConstantPadImpl<1, ConstantPad1dImpl>; diff --git a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp index 7062b07fe5d780..b8d212b729aaca 100644 --- a/torch/csrc/api/src/nn/modules/pixelshuffle.cpp +++ b/torch/csrc/api/src/nn/modules/pixelshuffle.cpp @@ -5,8 +5,7 @@ namespace F = torch::nn::functional; namespace torch { namespace nn { -PixelShuffleImpl::PixelShuffleImpl( - const PixelShuffleOptions& options_) +PixelShuffleImpl::PixelShuffleImpl(const PixelShuffleOptions& options_) : options(options_) {} void PixelShuffleImpl::pretty_print(std::ostream& stream) const { @@ -16,8 +15,7 @@ void PixelShuffleImpl::pretty_print(std::ostream& stream) const { void PixelShuffleImpl::reset() {} -Tensor PixelShuffleImpl::forward( - const Tensor& input) { +Tensor PixelShuffleImpl::forward(const Tensor& input) { return F::detail::pixel_shuffle(input, options.upscale_factor()); } diff --git a/torch/csrc/api/src/nn/modules/pooling.cpp b/torch/csrc/api/src/nn/modules/pooling.cpp index ecda3c3ad25b9a..8fef6353685ac9 100644 --- a/torch/csrc/api/src/nn/modules/pooling.cpp +++ b/torch/csrc/api/src/nn/modules/pooling.cpp @@ -18,40 +18,40 @@ template void AvgPoolImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::AvgPool" << D << "d" << "(kernel_size=" << options.kernel_size() - << ", stride=" << options.stride() - << ", padding=" << options.padding() << ")"; + << ", stride=" << options.stride() << ", padding=" << options.padding() + << ")"; } Tensor AvgPool1dImpl::forward(const Tensor& input) { return F::detail::avg_pool1d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad()); } Tensor AvgPool2dImpl::forward(const Tensor& input) { return F::detail::avg_pool2d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad(), - options.divisor_override()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); } Tensor AvgPool3dImpl::forward(const Tensor& input) { return F::detail::avg_pool3d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.ceil_mode(), - options.count_include_pad(), - options.divisor_override()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); } template class AvgPoolImpl<1, AvgPool1dImpl>; @@ -69,73 +69,74 @@ void MaxPoolImpl::reset() {} template void MaxPoolImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::MaxPool" << D << "d" + stream << std::boolalpha << "torch::nn::MaxPool" << D << "d" << "(kernel_size=" << options.kernel_size() - << ", stride=" << options.stride() - << ", padding=" << options.padding() + << ", stride=" << options.stride() << ", padding=" << options.padding() << ", dilation=" << options.dilation() << ", ceil_mode=" << options.ceil_mode() << ")"; } Tensor MaxPool1dImpl::forward(const Tensor& input) { return F::detail::max_pool1d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } -std::tuple MaxPool1dImpl::forward_with_indices(const Tensor& input) { +std::tuple MaxPool1dImpl::forward_with_indices( + const Tensor& input) { return F::detail::max_pool1d_with_indices( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } Tensor MaxPool2dImpl::forward(const Tensor& input) { return F::detail::max_pool2d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } -std::tuple MaxPool2dImpl::forward_with_indices(const Tensor& input) { +std::tuple MaxPool2dImpl::forward_with_indices( + const Tensor& input) { return F::detail::max_pool2d_with_indices( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } Tensor MaxPool3dImpl::forward(const Tensor& input) { return F::detail::max_pool3d( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } -std::tuple MaxPool3dImpl::forward_with_indices(const Tensor& input) { +std::tuple MaxPool3dImpl::forward_with_indices( + const Tensor& input) { return F::detail::max_pool3d_with_indices( - input, - options.kernel_size(), - options.stride(), - options.padding(), - options.dilation(), - options.ceil_mode()); + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); } template class MaxPoolImpl<1, MaxPool1dImpl>; @@ -148,29 +149,41 @@ Tensor AdaptiveMaxPool1dImpl::forward(const Tensor& input) { return F::detail::adaptive_max_pool1d(input, options.output_size()); } -std::tuple AdaptiveMaxPool1dImpl::forward_with_indices(const Tensor& input) { - return F::detail::adaptive_max_pool1d_with_indices(input, options.output_size()); +std::tuple AdaptiveMaxPool1dImpl::forward_with_indices( + const Tensor& input) { + return F::detail::adaptive_max_pool1d_with_indices( + input, options.output_size()); } Tensor AdaptiveMaxPool2dImpl::forward(const Tensor& input) { return F::detail::adaptive_max_pool2d(input, options.output_size()); } -std::tuple AdaptiveMaxPool2dImpl::forward_with_indices(const Tensor& input) { - return F::detail::adaptive_max_pool2d_with_indices(input, options.output_size()); +std::tuple AdaptiveMaxPool2dImpl::forward_with_indices( + const Tensor& input) { + return F::detail::adaptive_max_pool2d_with_indices( + input, options.output_size()); } Tensor AdaptiveMaxPool3dImpl::forward(const Tensor& input) { return F::detail::adaptive_max_pool3d(input, options.output_size()); } -std::tuple AdaptiveMaxPool3dImpl::forward_with_indices(const Tensor& input) { - return F::detail::adaptive_max_pool3d_with_indices(input, options.output_size()); +std::tuple AdaptiveMaxPool3dImpl::forward_with_indices( + const Tensor& input) { + return F::detail::adaptive_max_pool3d_with_indices( + input, options.output_size()); } template class AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>; -template class AdaptiveMaxPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveMaxPool2dImpl>; -template class AdaptiveMaxPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveMaxPool3dImpl>; +template class AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl>; +template class AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl>; // ============================================================================ @@ -187,8 +200,14 @@ Tensor AdaptiveAvgPool3dImpl::forward(const Tensor& input) { } template class AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>; -template class AdaptiveAvgPoolImpl<2, ExpandingArrayWithOptionalElem<2>, AdaptiveAvgPool2dImpl>; -template class AdaptiveAvgPoolImpl<3, ExpandingArrayWithOptionalElem<3>, AdaptiveAvgPool3dImpl>; +template class AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl>; +template class AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl>; // ============================================================================ @@ -201,44 +220,49 @@ void MaxUnpoolImpl::reset() {} template void MaxUnpoolImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::MaxUnpool" << D << "d" + stream << std::boolalpha << "torch::nn::MaxUnpool" << D << "d" << "(kernel_size=" << options.kernel_size() - << ", stride=" << options.stride() - << ", padding=" << options.padding() << ")"; + << ", stride=" << options.stride() << ", padding=" << options.padding() + << ")"; } -Tensor MaxUnpool1dImpl::forward(const Tensor& input, const Tensor& indices, +Tensor MaxUnpool1dImpl::forward( + const Tensor& input, + const Tensor& indices, const c10::optional>& output_size) { return F::detail::max_unpool1d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - output_size); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + output_size); } -Tensor MaxUnpool2dImpl::forward(const Tensor& input, const Tensor& indices, +Tensor MaxUnpool2dImpl::forward( + const Tensor& input, + const Tensor& indices, const c10::optional>& output_size) { return F::detail::max_unpool2d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - output_size); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + output_size); } -Tensor MaxUnpool3dImpl::forward(const Tensor& input, const Tensor& indices, +Tensor MaxUnpool3dImpl::forward( + const Tensor& input, + const Tensor& indices, const c10::optional>& output_size) { return F::detail::max_unpool3d( - input, - indices, - options.kernel_size(), - options.stride(), - options.padding(), - output_size); + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + output_size); } template class MaxUnpoolImpl<1, MaxUnpool1dImpl>; @@ -247,85 +271,120 @@ template class MaxUnpoolImpl<3, MaxUnpool3dImpl>; // ============================================================================ -FractionalMaxPool2dImpl::FractionalMaxPool2dImpl(const FractionalMaxPool2dOptions& options_) // NOLINT(modernize-pass-by-value) +FractionalMaxPool2dImpl::FractionalMaxPool2dImpl( + const FractionalMaxPool2dOptions& + options_) // NOLINT(modernize-pass-by-value) : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void FractionalMaxPool2dImpl::reset() { - _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == c10::nullopt && options.output_ratio() == c10::nullopt) { + _random_samples = + register_buffer("_random_samples", options._random_samples()); + if (options.output_size() == c10::nullopt && + options.output_ratio() == c10::nullopt) { TORCH_CHECK( - false, - "FractionalMaxPool2d requires specifying either ", - "an output size, or a pooling ratio"); + false, + "FractionalMaxPool2d requires specifying either ", + "an output size, or a pooling ratio"); } - if (options.output_size() != c10::nullopt && options.output_ratio() != c10::nullopt) { - TORCH_CHECK(false, "only one of output_size and output_ratio may be specified"); + if (options.output_size() != c10::nullopt && + options.output_ratio() != c10::nullopt) { + TORCH_CHECK( + false, "only one of output_size and output_ratio may be specified"); } if (options.output_ratio() != c10::nullopt) { - at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); - if (!(0 < output_ratio[0] && output_ratio[0] < 1 && - 0 < output_ratio[1] && output_ratio[1] < 1)) { - TORCH_CHECK(false, "output_ratio must be between 0 and 1 (got ", output_ratio, ")"); + at::ArrayRef output_ratio = + at::ArrayRef(options.output_ratio().value()); + if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && + output_ratio[1] < 1)) { + TORCH_CHECK( + false, + "output_ratio must be between 0 and 1 (got ", + output_ratio, + ")"); } } } Tensor FractionalMaxPool2dImpl::forward(const Tensor& input) { return F::detail::fractional_max_pool2d( - input, options.kernel_size(), options.output_size(), options.output_ratio(), - _random_samples); + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + _random_samples); } -std::tuple FractionalMaxPool2dImpl::forward_with_indices(const Tensor& input) { +std::tuple FractionalMaxPool2dImpl::forward_with_indices( + const Tensor& input) { return F::detail::fractional_max_pool2d_with_indices( - input, options.kernel_size(), options.output_size(), options.output_ratio(), - _random_samples); + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + _random_samples); } void FractionalMaxPool2dImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::FractionalMaxPool2d()"; } -FractionalMaxPool3dImpl::FractionalMaxPool3dImpl(const FractionalMaxPool3dOptions& options_) // NOLINT(modernize-pass-by-value) +FractionalMaxPool3dImpl::FractionalMaxPool3dImpl( + const FractionalMaxPool3dOptions& + options_) // NOLINT(modernize-pass-by-value) : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void FractionalMaxPool3dImpl::reset() { - _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == c10::nullopt && options.output_ratio() == c10::nullopt) { + _random_samples = + register_buffer("_random_samples", options._random_samples()); + if (options.output_size() == c10::nullopt && + options.output_ratio() == c10::nullopt) { TORCH_CHECK( - false, - "FractionalMaxPool3d requires specifying either ", - "an output size, or a pooling ratio"); + false, + "FractionalMaxPool3d requires specifying either ", + "an output size, or a pooling ratio"); } - if (options.output_size() != c10::nullopt && options.output_ratio() != c10::nullopt) { - TORCH_CHECK(false, "only one of output_size and output_ratio may be specified"); + if (options.output_size() != c10::nullopt && + options.output_ratio() != c10::nullopt) { + TORCH_CHECK( + false, "only one of output_size and output_ratio may be specified"); } if (options.output_ratio() != c10::nullopt) { - at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); - if (!(0 < output_ratio[0] && output_ratio[0] < 1 && - 0 < output_ratio[1] && output_ratio[1] < 1 && - 0 < output_ratio[2] && output_ratio[2] < 1)) { - TORCH_CHECK(false, "output_ratio must be between 0 and 1 (got ", output_ratio, ")"); + at::ArrayRef output_ratio = + at::ArrayRef(options.output_ratio().value()); + if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && + output_ratio[1] < 1 && 0 < output_ratio[2] && output_ratio[2] < 1)) { + TORCH_CHECK( + false, + "output_ratio must be between 0 and 1 (got ", + output_ratio, + ")"); } } } Tensor FractionalMaxPool3dImpl::forward(const Tensor& input) { return F::detail::fractional_max_pool3d( - input, options.kernel_size(), options.output_size(), options.output_ratio(), - _random_samples); + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + _random_samples); } -std::tuple FractionalMaxPool3dImpl::forward_with_indices(const Tensor& input) { +std::tuple FractionalMaxPool3dImpl::forward_with_indices( + const Tensor& input) { return F::detail::fractional_max_pool3d_with_indices( - input, options.kernel_size(), options.output_size(), options.output_ratio(), - _random_samples); + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + _random_samples); } void FractionalMaxPool3dImpl::pretty_print(std::ostream& stream) const { @@ -343,8 +402,7 @@ void LPPoolImpl::reset() {} template void LPPoolImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha - << "torch::nn::LPPool" << D << "d(" + stream << std::boolalpha << "torch::nn::LPPool" << D << "d(" << "norm_type=" << options.norm_type() << ", " << "kernel_size=" << options.kernel_size() << ", " << "stride=" << options.stride() << ", " @@ -353,22 +411,22 @@ void LPPoolImpl::pretty_print(std::ostream& stream) const { Tensor LPPool1dImpl::forward(const Tensor& input) { return F::detail::lp_pool1d( - input, - options.norm_type(), - options.kernel_size(), - options.stride(), - options.ceil_mode()); + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); } template class LPPoolImpl<1, LPPool1dImpl>; Tensor LPPool2dImpl::forward(const Tensor& input) { return F::detail::lp_pool2d( - input, - options.norm_type(), - options.kernel_size(), - options.stride(), - options.ceil_mode()); + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); } template class LPPoolImpl<2, LPPool2dImpl>; diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index d92be5e6b51303..3d785ba603d34b 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -28,7 +28,8 @@ namespace nn { /// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 }; -CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t mode) { +CuDNNMode get_cudnn_mode_for_rnn( + detail::RNNOptionsBase::rnn_options_base_mode_t mode) { if (c10::get_if(&mode)) { return CuDNNMode::RNN_RELU; } else if (c10::get_if(&mode)) { @@ -42,7 +43,10 @@ CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t } } -Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_t dim = 1) { +Tensor apply_permutation( + const Tensor& tensor, + const Tensor& permutation, + int64_t dim = 1) { return tensor.index_select(dim, permutation); } @@ -50,7 +54,7 @@ Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_ namespace detail { template RNNImplBase::RNNImplBase(const RNNOptionsBase& options_) - : options_base(options_) { + : options_base(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -60,28 +64,32 @@ void RNNImplBase::reset() { const int64_t num_directions = options_base.bidirectional() ? 2 : 1; TORCH_CHECK( - 0 <= options_base.dropout() && options_base.dropout() <= 1, - "dropout should be a number in range [0, 1] ", - "representing the probability of an element being ", - "zeroed"); + 0 <= options_base.dropout() && options_base.dropout() <= 1, + "dropout should be a number in range [0, 1] ", + "representing the probability of an element being ", + "zeroed"); if (options_base.dropout() > 0 && options_base.num_layers() == 1) { TORCH_WARN( - "dropout option adds dropout after all but last ", - "recurrent layer, so non-zero dropout expects ", - "num_layers greater than 1, but got dropout=", options_base.dropout(), " and ", - "num_layers=", options_base.num_layers()); + "dropout option adds dropout after all but last ", + "recurrent layer, so non-zero dropout expects ", + "num_layers greater than 1, but got dropout=", + options_base.dropout(), + " and ", + "num_layers=", + options_base.num_layers()); } TORCH_CHECK( - 0 <= options_base.proj_size() && options_base.proj_size() < options_base.hidden_size(), - "proj_size has to be a positive integer, smaller than ", - "hidden_size or zero to disable projections"); + 0 <= options_base.proj_size() && + options_base.proj_size() < options_base.hidden_size(), + "proj_size has to be a positive integer, smaller than ", + "hidden_size or zero to disable projections"); if (options_base.proj_size() > 0) { TORCH_CHECK( - c10::get_if(&options_base.mode()), - "proj_size argument is only supported for LSTM, not RNN or GRU"); + c10::get_if(&options_base.mode()), + "proj_size argument is only supported for LSTM, not RNN or GRU"); } int64_t gate_size = 0; @@ -89,22 +97,28 @@ void RNNImplBase::reset() { gate_size = 4 * options_base.hidden_size(); } else if (c10::get_if(&options_base.mode())) { gate_size = 3 * options_base.hidden_size(); - // NOLINTNEXTLINE(bugprone-branch-clone) + // NOLINTNEXTLINE(bugprone-branch-clone) } else if (c10::get_if(&options_base.mode())) { gate_size = options_base.hidden_size(); } else if (c10::get_if(&options_base.mode())) { gate_size = options_base.hidden_size(); } else { - TORCH_CHECK(false, "Unrecognized RNN mode: " + torch::enumtype::get_enum_name(options_base.mode())); + TORCH_CHECK( + false, + "Unrecognized RNN mode: " + + torch::enumtype::get_enum_name(options_base.mode())); } flat_weights_names_ = {}; all_weights_ = {}; - for(const auto layer : c10::irange(options_base.num_layers())) { - for(const auto direction : c10::irange(num_directions)) { - int64_t real_hidden_size = options_base.proj_size() > 0 ? options_base.proj_size() : options_base.hidden_size(); - int64_t layer_input_size = layer == 0 ? options_base.input_size() : real_hidden_size * num_directions; + for (const auto layer : c10::irange(options_base.num_layers())) { + for (const auto direction : c10::irange(num_directions)) { + int64_t real_hidden_size = options_base.proj_size() > 0 + ? options_base.proj_size() + : options_base.hidden_size(); + int64_t layer_input_size = layer == 0 ? options_base.input_size() + : real_hidden_size * num_directions; auto w_ih = torch::empty({gate_size, layer_input_size}); auto w_hh = torch::empty({gate_size, real_hidden_size}); @@ -115,7 +129,8 @@ void RNNImplBase::reset() { std::vector layer_params = {w_ih, w_hh}; std::string suffix = direction == 1 ? "_reverse" : ""; - std::vector param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"}; + std::vector param_names = { + "weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"}; if (options_base.bias()) { param_names.emplace_back("bias_ih_l{layer}{suffix}"); param_names.emplace_back("bias_hh_l{layer}{suffix}"); @@ -123,21 +138,25 @@ void RNNImplBase::reset() { layer_params.emplace_back(b_hh); } if (options_base.proj_size() > 0) { - auto w_hr = torch::empty({options_base.proj_size(), options_base.hidden_size()}); + auto w_hr = torch::empty( + {options_base.proj_size(), options_base.hidden_size()}); layer_params.emplace_back(w_hr); param_names.emplace_back("weight_hr_l{layer}{suffix}"); } - for(auto& param_name : param_names) { - std::string x = std::regex_replace(param_name, std::regex("\\{layer\\}"), c10::str(layer)); - param_name = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix)); + for (auto& param_name : param_names) { + std::string x = std::regex_replace( + param_name, std::regex("\\{layer\\}"), c10::str(layer)); + param_name = + std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix)); } - for(const auto i : c10::irange(param_names.size())) { + for (const auto i : c10::irange(param_names.size())) { auto name = param_names[i]; auto param = layer_params[i]; this->register_parameter(name, param); } - flat_weights_names_.insert(flat_weights_names_.end(), param_names.begin(), param_names.end()); + flat_weights_names_.insert( + flat_weights_names_.end(), param_names.begin(), param_names.end()); all_weights_.emplace_back(param_names); } } @@ -160,31 +179,30 @@ template void RNNImplBase::flatten_parameters() { // Resets parameter data pointer so that they can use faster code paths. // - // Right now, this works only if the module is on the GPU and cuDNN is enabled. - // Otherwise, it's a no-op. + // Right now, this works only if the module is on the GPU and cuDNN is + // enabled. Otherwise, it's a no-op. // Short-circuits if flat_weights_ is only partially instantiated if (flat_weights_.size() != flat_weights_names_.size()) { return; } - // Short-circuits if any tensor in self.flat_weights_ is not acceptable to cuDNN - // or the tensors in flat_weights_ are of different dtypes + // Short-circuits if any tensor in self.flat_weights_ is not acceptable to + // cuDNN or the tensors in flat_weights_ are of different dtypes auto first_fw = flat_weights_[0]; auto dtype = first_fw.dtype(); for (const auto& fw : flat_weights_) { - if (!(fw.dtype() == dtype) || - !fw.is_cuda() || + if (!(fw.dtype() == dtype) || !fw.is_cuda() || !torch::cudnn_is_acceptable(fw)) { return; } } - // If any parameters alias, we fall back to the slower, copying code path. This is - // a sufficient check, because overlapping parameter buffers that don't completely - // alias would break the assumptions of the uniqueness check in - // Module::named_parameters(). + // If any parameters alias, we fall back to the slower, copying code path. + // This is a sufficient check, because overlapping parameter buffers that + // don't completely alias would break the assumptions of the uniqueness check + // in Module::named_parameters(). std::unordered_set unique_data_ptrs; for (const auto& p : flat_weights_) { unique_data_ptrs.emplace(p.data_ptr()); @@ -206,15 +224,15 @@ void RNNImplBase::flatten_parameters() { ++num_weights; } torch::_cudnn_rnn_flatten_weight( - flat_weights_, - num_weights, - options_base.input_size(), - static_cast(get_cudnn_mode_for_rnn(options_base.mode())), - options_base.hidden_size(), - options_base.proj_size(), - options_base.num_layers(), - options_base.batch_first(), - options_base.bidirectional()); + flat_weights_, + num_weights, + options_base.input_size(), + static_cast(get_cudnn_mode_for_rnn(options_base.mode())), + options_base.hidden_size(), + options_base.proj_size(), + options_base.num_layers(), + options_base.batch_first(), + options_base.bidirectional()); } } } @@ -266,19 +284,28 @@ void RNNImplBase::reset_parameters() { } template -void RNNImplBase::check_input(const Tensor& input, const Tensor& batch_sizes) const { - int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3; +void RNNImplBase::check_input( + const Tensor& input, + const Tensor& batch_sizes) const { + int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3; TORCH_CHECK( - input.dim() == expected_input_dim, - "input must have ", expected_input_dim, " dimensions, got ", input.dim()); + input.dim() == expected_input_dim, + "input must have ", + expected_input_dim, + " dimensions, got ", + input.dim()); TORCH_CHECK( - options_base.input_size() == input.size(-1), - "input.size(-1) must be equal to input_size. Expected ", options_base.input_size(), ", got ", input.size(-1)); + options_base.input_size() == input.size(-1), + "input.size(-1) must be equal to input_size. Expected ", + options_base.input_size(), + ", got ", + input.size(-1)); } template -std::tuple RNNImplBase::get_expected_hidden_size( - const Tensor& input, const Tensor& batch_sizes) const { +std::tuple RNNImplBase:: + get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) + const { int64_t mini_batch = 0; if (batch_sizes.defined()) { mini_batch = batch_sizes[0].item(); @@ -286,8 +313,11 @@ std::tuple RNNImplBase::get_expected_hidden_ mini_batch = options_base.batch_first() ? input.size(0) : input.size(1); } int64_t num_directions = options_base.bidirectional() ? 2 : 1; - int64_t real_hidden_size = options_base.proj_size() > 0 ? options_base.proj_size() : options_base.hidden_size(); - return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, real_hidden_size); + int64_t real_hidden_size = options_base.proj_size() > 0 + ? options_base.proj_size() + : options_base.hidden_size(); + return std::make_tuple( + options_base.num_layers() * num_directions, mini_batch, real_hidden_size); } template @@ -296,27 +326,34 @@ void RNNImplBase::check_hidden_size( std::tuple expected_hidden_size, std::string msg) const { auto expected_hidden_size_vec = std::vector({ - std::get<0>(expected_hidden_size), - std::get<1>(expected_hidden_size), - std::get<2>(expected_hidden_size), + std::get<0>(expected_hidden_size), + std::get<1>(expected_hidden_size), + std::get<2>(expected_hidden_size), }); if (hx.sizes() != expected_hidden_size_vec) { - msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec)); + msg = std::regex_replace( + msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec)); msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes())); TORCH_CHECK(false, msg); } } template -void RNNImplBase::check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const { +void RNNImplBase::check_forward_args( + Tensor input, + Tensor hidden, + Tensor batch_sizes) const { this->check_input(input, batch_sizes); - auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes); + auto expected_hidden_size = + this->get_expected_hidden_size(input, batch_sizes); this->check_hidden_size(hidden, expected_hidden_size); } template -Tensor RNNImplBase::permute_hidden(Tensor hx, const Tensor& permutation) const { +Tensor RNNImplBase::permute_hidden( + Tensor hx, + const Tensor& permutation) const { if (!permutation.defined()) { return hx; } @@ -327,7 +364,8 @@ template void RNNImplBase::pretty_print(std::ostream& stream) const { const std::string name = this->name(); const std::string name_without_impl = name.substr(0, name.size() - 4); - stream << std::boolalpha << name_without_impl << "(input_size=" << options_base.input_size() + stream << std::boolalpha << name_without_impl + << "(input_size=" << options_base.input_size() << ", hidden_size=" << options_base.hidden_size() << ", num_layers=" << options_base.num_layers() << ", bias=" << options_base.bias() @@ -360,22 +398,25 @@ template class RNNImplBase; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ detail::RNNOptionsBase::rnn_options_base_mode_t compute_rnn_options_base_mode( - RNNOptions::nonlinearity_t nonlinearity) { + RNNOptions::nonlinearity_t nonlinearity) { if (c10::get_if(&nonlinearity)) { return torch::kRNN_TANH; } else if (c10::get_if(&nonlinearity)) { return torch::kRNN_RELU; } else { - TORCH_CHECK(false, "Unknown nonlinearity ", torch::enumtype::get_enum_name(nonlinearity)); + TORCH_CHECK( + false, + "Unknown nonlinearity ", + torch::enumtype::get_enum_name(nonlinearity)); } } RNNImpl::RNNImpl(const RNNOptions& options_) : detail::RNNImplBase( detail::RNNOptionsBase( - compute_rnn_options_base_mode(options_.nonlinearity()), - options_.input_size(), - options_.hidden_size()) + compute_rnn_options_base_mode(options_.nonlinearity()), + options_.input_size(), + options_.hidden_size()) .num_layers(options_.num_layers()) .bias(options_.bias()) .batch_first(options_.batch_first()) @@ -384,16 +425,18 @@ RNNImpl::RNNImpl(const RNNOptions& options_) options(options_) {} std::tuple RNNImpl::forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - Tensor hx) { + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx) { if (!hx.defined()) { int64_t num_directions = options_base.bidirectional() ? 2 : 1; - hx = torch::zeros({options_base.num_layers() * num_directions, - max_batch_size, options_base.hidden_size()}, - torch::dtype(input.dtype()).device(input.device())); + hx = torch::zeros( + {options_base.num_layers() * num_directions, + max_batch_size, + options_base.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); } else { // Each batch of the hidden state should match the input sequence that // the user believes he/she is passing in. @@ -405,23 +448,61 @@ std::tuple RNNImpl::forward_helper( std::tuple result; if (!batch_sizes.defined()) { if (c10::get_if(&options_base.mode())) { - result = torch::rnn_tanh(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(), - options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first()); + result = torch::rnn_tanh( + input, + hx, + flat_weights_, + options_base.bias(), + options_base.num_layers(), + options_base.dropout(), + this->is_training(), + options_base.bidirectional(), + options_base.batch_first()); } else if (c10::get_if(&options_base.mode())) { - result = torch::rnn_relu(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(), - options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first()); + result = torch::rnn_relu( + input, + hx, + flat_weights_, + options_base.bias(), + options_base.num_layers(), + options_base.dropout(), + this->is_training(), + options_base.bidirectional(), + options_base.batch_first()); } else { - TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode())); + TORCH_CHECK( + false, + "Unknown mode: ", + torch::enumtype::get_enum_name(options_base.mode())); } } else { if (c10::get_if(&options_base.mode())) { - result = torch::rnn_tanh(input, batch_sizes, hx, flat_weights_, options_base.bias(), - options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional()); + result = torch::rnn_tanh( + input, + batch_sizes, + hx, + flat_weights_, + options_base.bias(), + options_base.num_layers(), + options_base.dropout(), + this->is_training(), + options_base.bidirectional()); } else if (c10::get_if(&options_base.mode())) { - result = torch::rnn_relu(input, batch_sizes, hx, flat_weights_, options_base.bias(), - options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional()); + result = torch::rnn_relu( + input, + batch_sizes, + hx, + flat_weights_, + options_base.bias(), + options_base.num_layers(), + options_base.dropout(), + this->is_training(), + options_base.bidirectional()); } else { - TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode())); + TORCH_CHECK( + false, + "Unknown mode: ", + torch::enumtype::get_enum_name(options_base.mode())); } } auto output = std::get<0>(result); @@ -432,17 +513,22 @@ std::tuple RNNImpl::forward_helper( std::tuple RNNImpl::forward(const Tensor& input, Tensor hx) { auto batch_sizes = torch::Tensor(); - auto max_batch_size = options_base.batch_first() ? input.size(0) : input.size(1); + auto max_batch_size = + options_base.batch_first() ? input.size(0) : input.size(1); auto sorted_indices = torch::Tensor(); auto unsorted_indices = torch::Tensor(); Tensor output, hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); - return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices)); + return std::make_tuple( + output, this->permute_hidden(hidden, unsorted_indices)); } -std::tuple RNNImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) { +std::tuple RNNImpl::forward_with_packed_input( + const PackedSequence& packed_input, + Tensor hx) { const auto& input = packed_input.data(); const auto& batch_sizes = packed_input.batch_sizes(); const auto& sorted_indices = packed_input.sorted_indices(); @@ -450,30 +536,33 @@ std::tuple RNNImpl::forward_with_packed_input(const Pack auto max_batch_size = batch_sizes[0].item(); Tensor output, hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); - auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); - return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices)); + auto output_packed = + PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); + return std::make_tuple( + output_packed, this->permute_hidden(hidden, unsorted_indices)); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMImpl::LSTMImpl(const LSTMOptions& options_) - : detail::RNNImplBase( - detail::RNNOptionsBase( - torch::kLSTM, - options_.input_size(), - options_.hidden_size()) - .num_layers(options_.num_layers()) - .bias(options_.bias()) - .batch_first(options_.batch_first()) - .dropout(options_.dropout()) - .bidirectional(options_.bidirectional()) - .proj_size(options_.proj_size())), + : detail::RNNImplBase(detail::RNNOptionsBase( + torch::kLSTM, + options_.input_size(), + options_.hidden_size()) + .num_layers(options_.num_layers()) + .bias(options_.bias()) + .batch_first(options_.batch_first()) + .dropout(options_.dropout()) + .bidirectional(options_.bidirectional()) + .proj_size(options_.proj_size())), options(options_) {} std::tuple LSTMImpl::get_expected_cell_size( - const Tensor& input, const Tensor& batch_sizes) const { + const Tensor& input, + const Tensor& batch_sizes) const { int64_t mini_batch = 0; if (batch_sizes.defined()) { mini_batch = batch_sizes[0].item(); @@ -481,44 +570,59 @@ std::tuple LSTMImpl::get_expected_cell_size( mini_batch = options_base.batch_first() ? input.size(0) : input.size(1); } int64_t num_directions = options_base.bidirectional() ? 2 : 1; - return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size()); + return std::make_tuple( + options_base.num_layers() * num_directions, + mini_batch, + options_base.hidden_size()); } -void LSTMImpl::check_forward_args(const Tensor& input, std::tuple hidden, const Tensor& batch_sizes) const { +void LSTMImpl::check_forward_args( + const Tensor& input, + std::tuple hidden, + const Tensor& batch_sizes) const { this->check_input(input, batch_sizes); - this->check_hidden_size(std::get<0>(hidden), this->get_expected_hidden_size(input, batch_sizes), - "Expected hidden[0] size {1}, got {2}"); - this->check_hidden_size(std::get<1>(hidden), this->get_expected_cell_size(input, batch_sizes), - "Expected hidden[1] size {1}, got {2}"); -} - -std::tuple LSTMImpl::permute_hidden(std::tuple hx, const Tensor& permutation) const { + this->check_hidden_size( + std::get<0>(hidden), + this->get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {1}, got {2}"); + this->check_hidden_size( + std::get<1>(hidden), + this->get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {1}, got {2}"); +} + +std::tuple LSTMImpl::permute_hidden( + std::tuple hx, + const Tensor& permutation) const { if (!permutation.defined()) { return hx; } return std::make_tuple( - apply_permutation(std::get<0>(hx), permutation), - apply_permutation(std::get<1>(hx), permutation) - ); + apply_permutation(std::get<0>(hx), permutation), + apply_permutation(std::get<1>(hx), permutation)); } std::tuple> LSTMImpl::forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - torch::optional> hx_opt) { - + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + torch::optional> hx_opt) { std::tuple hx; if (!hx_opt.has_value()) { int64_t num_directions = options.bidirectional() ? 2 : 1; - int64_t real_hidden_size = options.proj_size() > 0 ? options.proj_size() : options.hidden_size(); - auto h_zeros = torch::zeros({options.num_layers() * num_directions, - max_batch_size, real_hidden_size}, - torch::dtype(input.dtype()).device(input.device())); - auto c_zeros = torch::zeros({options.num_layers() * num_directions, - max_batch_size, options.hidden_size()}, - torch::dtype(input.dtype()).device(input.device())); + int64_t real_hidden_size = + options.proj_size() > 0 ? options.proj_size() : options.hidden_size(); + auto h_zeros = torch::zeros( + {options.num_layers() * num_directions, + max_batch_size, + real_hidden_size}, + torch::dtype(input.dtype()).device(input.device())); + auto c_zeros = torch::zeros( + {options.num_layers() * num_directions, + max_batch_size, + options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); hx = std::make_tuple(h_zeros, c_zeros); } else { hx = hx_opt.value(); @@ -530,11 +634,27 @@ std::tuple> LSTMImpl::forward_helper( this->check_forward_args(input, hx, batch_sizes); std::tuple result; if (!batch_sizes.defined()) { - result = torch::lstm(input, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), options.num_layers(), - options.dropout(), this->is_training(), options.bidirectional(), options.batch_first()); + result = torch::lstm( + input, + {std::get<0>(hx), std::get<1>(hx)}, + flat_weights_, + options.bias(), + options.num_layers(), + options.dropout(), + this->is_training(), + options.bidirectional(), + options.batch_first()); } else { - result = torch::lstm(input, batch_sizes, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), - options.num_layers(), options.dropout(), this->is_training(), options.bidirectional()); + result = torch::lstm( + input, + batch_sizes, + {std::get<0>(hx), std::get<1>(hx)}, + flat_weights_, + options.bias(), + options.num_layers(), + options.dropout(), + this->is_training(), + options.bidirectional()); } auto output = std::get<0>(result); auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result)); @@ -543,7 +663,8 @@ std::tuple> LSTMImpl::forward_helper( } std::tuple> LSTMImpl::forward( - const Tensor& input, torch::optional> hx_opt) { + const Tensor& input, + torch::optional> hx_opt) { auto batch_sizes = torch::Tensor(); auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1); auto sorted_indices = torch::Tensor(); @@ -551,13 +672,17 @@ std::tuple> LSTMImpl::forward( Tensor output; std::tuple hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt)); - return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices)); + return std::make_tuple( + output, this->permute_hidden(hidden, unsorted_indices)); } -std::tuple> LSTMImpl::forward_with_packed_input( - const PackedSequence& packed_input, torch::optional> hx_opt) { +std::tuple> LSTMImpl:: + forward_with_packed_input( + const PackedSequence& packed_input, + torch::optional> hx_opt) { const auto& input = packed_input.data(); const auto& batch_sizes = packed_input.batch_sizes(); const auto& sorted_indices = packed_input.sorted_indices(); @@ -566,10 +691,13 @@ std::tuple> LSTMImpl::forward_with_pa Tensor output; std::tuple hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt)); - auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); - return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices)); + auto output_packed = + PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); + return std::make_tuple( + output_packed, this->permute_hidden(hidden, unsorted_indices)); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -577,9 +705,9 @@ std::tuple> LSTMImpl::forward_with_pa GRUImpl::GRUImpl(const GRUOptions& options_) : detail::RNNImplBase( detail::RNNOptionsBase( - torch::kGRU, - options_.input_size(), - options_.hidden_size()) + torch::kGRU, + options_.input_size(), + options_.hidden_size()) .num_layers(options_.num_layers()) .bias(options_.bias()) .batch_first(options_.batch_first()) @@ -588,16 +716,18 @@ GRUImpl::GRUImpl(const GRUOptions& options_) options(options_) {} std::tuple GRUImpl::forward_helper( - const Tensor& input, - const Tensor& batch_sizes, - const Tensor& sorted_indices, - int64_t max_batch_size, - Tensor hx) { + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx) { if (!hx.defined()) { int64_t num_directions = options.bidirectional() ? 2 : 1; - hx = torch::zeros({options.num_layers() * num_directions, - max_batch_size, options.hidden_size()}, - torch::dtype(input.dtype()).device(input.device())); + hx = torch::zeros( + {options.num_layers() * num_directions, + max_batch_size, + options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); } else { // Each batch of the hidden state should match the input sequence that // the user believes he/she is passing in. @@ -607,11 +737,27 @@ std::tuple GRUImpl::forward_helper( this->check_forward_args(input, hx, batch_sizes); std::tuple result; if (!batch_sizes.defined()) { - result = torch::gru(input, hx, flat_weights_, options.bias(), options.num_layers(), - options.dropout(), this->is_training(), options.bidirectional(), options.batch_first()); + result = torch::gru( + input, + hx, + flat_weights_, + options.bias(), + options.num_layers(), + options.dropout(), + this->is_training(), + options.bidirectional(), + options.batch_first()); } else { - result = torch::gru(input, batch_sizes, hx, flat_weights_, options.bias(), - options.num_layers(), options.dropout(), this->is_training(), options.bidirectional()); + result = torch::gru( + input, + batch_sizes, + hx, + flat_weights_, + options.bias(), + options.num_layers(), + options.dropout(), + this->is_training(), + options.bidirectional()); } auto output = std::get<0>(result); auto hidden = std::get<1>(result); @@ -626,12 +772,16 @@ std::tuple GRUImpl::forward(const Tensor& input, Tensor hx) { auto unsorted_indices = torch::Tensor(); Tensor output, hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); - return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices)); + return std::make_tuple( + output, this->permute_hidden(hidden, unsorted_indices)); } -std::tuple GRUImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) { +std::tuple GRUImpl::forward_with_packed_input( + const PackedSequence& packed_input, + Tensor hx) { const auto& input = packed_input.data(); const auto& batch_sizes = packed_input.batch_sizes(); const auto& sorted_indices = packed_input.sorted_indices(); @@ -639,19 +789,22 @@ std::tuple GRUImpl::forward_with_packed_input(const Pack auto max_batch_size = batch_sizes[0].item(); Tensor output, hidden; - std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); + std::tie(output, hidden) = this->forward_helper( + input, batch_sizes, sorted_indices, max_batch_size, std::move(hx)); - auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); - return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices)); + auto output_packed = + PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices); + return std::make_tuple( + output_packed, this->permute_hidden(hidden, unsorted_indices)); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ namespace detail { template -RNNCellImplBase::RNNCellImplBase( - const RNNCellOptionsBase& options_) - : options_base(options_) { +RNNCellImplBase::RNNCellImplBase(const RNNCellOptionsBase& options_) + : options_base(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -659,16 +812,28 @@ RNNCellImplBase::RNNCellImplBase( template void RNNCellImplBase::reset() { weight_ih = this->register_parameter( - "weight_ih", torch::empty({options_base.num_chunks() * options_base.hidden_size(), options_base.input_size()})); + "weight_ih", + torch::empty( + {options_base.num_chunks() * options_base.hidden_size(), + options_base.input_size()})); weight_hh = this->register_parameter( - "weight_hh", torch::empty({options_base.num_chunks() * options_base.hidden_size(), options_base.hidden_size()})); + "weight_hh", + torch::empty( + {options_base.num_chunks() * options_base.hidden_size(), + options_base.hidden_size()})); if (options_base.bias()) { - bias_ih = this->register_parameter("bias_ih", torch::empty({options_base.num_chunks() * options_base.hidden_size()})); - bias_hh = this->register_parameter("bias_hh", torch::empty({options_base.num_chunks() * options_base.hidden_size()})); + bias_ih = this->register_parameter( + "bias_ih", + torch::empty({options_base.num_chunks() * options_base.hidden_size()})); + bias_hh = this->register_parameter( + "bias_hh", + torch::empty({options_base.num_chunks() * options_base.hidden_size()})); } else { - bias_ih = this->register_parameter("bias_ih", Tensor(), /*requires_grad=*/false); - bias_hh = this->register_parameter("bias_hh", Tensor(), /*requires_grad=*/false); + bias_ih = + this->register_parameter("bias_ih", Tensor(), /*requires_grad=*/false); + bias_hh = + this->register_parameter("bias_hh", Tensor(), /*requires_grad=*/false); } reset_parameters(); @@ -686,9 +851,8 @@ template void RNNCellImplBase::pretty_print(std::ostream& stream) const { const std::string name = this->name(); const std::string name_without_impl = name.substr(0, name.size() - 4); - stream << name_without_impl - << "(" << options_base.input_size() - << ", " << options_base.hidden_size(); + stream << name_without_impl << "(" << options_base.input_size() << ", " + << options_base.hidden_size(); if (!options_base.bias()) { stream << ", bias=" << std::boolalpha << false; } @@ -702,19 +866,35 @@ void RNNCellImplBase::pretty_print(std::ostream& stream) const { template void RNNCellImplBase::check_forward_input(const Tensor& input) const { TORCH_CHECK( - input.size(1) == options_base.input_size(), - "input has inconsistent input_size: got ", input.size(1), " expected ", options_base.input_size()); + input.size(1) == options_base.input_size(), + "input has inconsistent input_size: got ", + input.size(1), + " expected ", + options_base.input_size()); } template -void RNNCellImplBase::check_forward_hidden(const Tensor& input, const Tensor& hx, std::string hidden_label) const { +void RNNCellImplBase::check_forward_hidden( + const Tensor& input, + const Tensor& hx, + std::string hidden_label) const { TORCH_CHECK( - input.size(0) == hx.size(0), - "Input batch size ", input.size(0), " doesn't match hidden", hidden_label, " batch size ", hx.size(0)); + input.size(0) == hx.size(0), + "Input batch size ", + input.size(0), + " doesn't match hidden", + hidden_label, + " batch size ", + hx.size(0)); TORCH_CHECK( - hx.size(1) == options_base.hidden_size(), - "hidden", hidden_label, " has inconsistent hidden_size: got ", hx.size(1), ", expected ", options_base.hidden_size()); + hx.size(1) == options_base.hidden_size(), + "hidden", + hidden_label, + " has inconsistent hidden_size: got ", + hx.size(1), + ", expected ", + options_base.hidden_size()); } template @@ -727,39 +907,37 @@ template class RNNCellImplBase; template class RNNCellImplBase; } // namespace detail -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImpl::RNNCellImpl(const RNNCellOptions& options_) - : detail::RNNCellImplBase( - detail::RNNCellOptionsBase( - options_.input_size(), - options_.hidden_size(), - options_.bias(), - /*num_chunks=*/1)), + : detail::RNNCellImplBase(detail::RNNCellOptionsBase( + options_.input_size(), + options_.hidden_size(), + options_.bias(), + /*num_chunks=*/1)), options(options_) {} - Tensor RNNCellImpl::forward(const Tensor& input, Tensor hx) { this->check_forward_input(input); if (!hx.defined()) { - hx = torch::zeros({input.size(0), options.hidden_size()}, torch::dtype(input.dtype()).device(input.device())); + hx = torch::zeros( + {input.size(0), options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); } this->check_forward_hidden(input, hx, ""); Tensor ret; if (c10::get_if(&options.nonlinearity())) { - ret = torch::rnn_tanh_cell( - input, hx, - weight_ih, weight_hh, - bias_ih, bias_hh - ); + ret = + torch::rnn_tanh_cell(input, hx, weight_ih, weight_hh, bias_ih, bias_hh); } else if (c10::get_if(&options.nonlinearity())) { - ret = torch::rnn_relu_cell( - input, hx, - weight_ih, weight_hh, - bias_ih, bias_hh - ); + ret = + torch::rnn_relu_cell(input, hx, weight_ih, weight_hh, bias_ih, bias_hh); } else { - TORCH_CHECK(false, "Unknown nonlinearity: ", torch::enumtype::get_enum_name(options.nonlinearity())); + TORCH_CHECK( + false, + "Unknown nonlinearity: ", + torch::enumtype::get_enum_name(options.nonlinearity())); } return ret; } @@ -768,24 +946,27 @@ std::string RNNCellImpl::get_nonlinearity_str() const { return get_enum_name(options.nonlinearity()); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCellImpl::LSTMCellImpl(const LSTMCellOptions& options_) - : detail::RNNCellImplBase( - detail::RNNCellOptionsBase( - options_.input_size(), - options_.hidden_size(), - options_.bias(), - /*num_chunks=*/4)), + : detail::RNNCellImplBase(detail::RNNCellOptionsBase( + options_.input_size(), + options_.hidden_size(), + options_.bias(), + /*num_chunks=*/4)), options(options_) {} std::tuple LSTMCellImpl::forward( - const Tensor& input, torch::optional> hx_opt) { + const Tensor& input, + torch::optional> hx_opt) { this->check_forward_input(input); std::tuple hx; if (!hx_opt.has_value()) { - auto zeros = torch::zeros({input.size(0), options.hidden_size()}, torch::dtype(input.dtype()).device(input.device())); + auto zeros = torch::zeros( + {input.size(0), options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); hx = std::make_tuple(zeros, zeros); } else { hx = hx_opt.value(); @@ -795,34 +976,34 @@ std::tuple LSTMCellImpl::forward( this->check_forward_hidden(input, std::get<1>(hx), "[1]"); return torch::lstm_cell( - input, {std::get<0>(hx), std::get<1>(hx)}, - weight_ih, weight_hh, - bias_ih, bias_hh - ); + input, + {std::get<0>(hx), std::get<1>(hx)}, + weight_ih, + weight_hh, + bias_ih, + bias_hh); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCellImpl::GRUCellImpl(const GRUCellOptions& options_) - : detail::RNNCellImplBase( - detail::RNNCellOptionsBase( - options_.input_size(), - options_.hidden_size(), - options_.bias(), - /*num_chunks=*/3)), + : detail::RNNCellImplBase(detail::RNNCellOptionsBase( + options_.input_size(), + options_.hidden_size(), + options_.bias(), + /*num_chunks=*/3)), options(options_) {} Tensor GRUCellImpl::forward(const Tensor& input, Tensor hx) { this->check_forward_input(input); if (!hx.defined()) { - hx = torch::zeros({input.size(0), options.hidden_size()}, torch::dtype(input.dtype()).device(input.device())); + hx = torch::zeros( + {input.size(0), options.hidden_size()}, + torch::dtype(input.dtype()).device(input.device())); } this->check_forward_hidden(input, hx, ""); - return torch::gru_cell( - input, hx, - weight_ih, weight_hh, - bias_ih, bias_hh - ); + return torch::gru_cell(input, hx, weight_ih, weight_hh, bias_ih, bias_hh); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index 956f091c4b49dd..6d643fc7354f03 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -1,12 +1,11 @@ #include #include -#include -#include #include +#include +#include #include - namespace F = torch::nn::functional; namespace torch { @@ -14,32 +13,39 @@ namespace nn { // ========================TransformerEncoderLayerImpl========================= TransformerEncoderLayerImpl::TransformerEncoderLayerImpl( - const TransformerEncoderLayerOptions& options_) : options(options_) { + const TransformerEncoderLayerOptions& options_) + : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void TransformerEncoderLayerImpl::reset() { - // NOTE: reset() is for initializing the model only, calling reset() after the model is created - // will throw exceptionss. Call reset_parameter() if the created model needs a reset - - self_attn = this->register_module("self_attn", - MultiheadAttention(MultiheadAttentionOptions( - options.d_model(), options.nhead()).dropout(options.dropout()))); - - linear1 = this->register_module("linear1", Linear(options.d_model(), options.dim_feedforward())); + // NOTE: reset() is for initializing the model only, calling reset() after the + // model is created will throw exceptionss. Call reset_parameter() if the + // created model needs a reset + + self_attn = this->register_module( + "self_attn", + MultiheadAttention( + MultiheadAttentionOptions(options.d_model(), options.nhead()) + .dropout(options.dropout()))); + + linear1 = this->register_module( + "linear1", Linear(options.d_model(), options.dim_feedforward())); dropout = this->register_module("dropout", Dropout(options.dropout())); - linear2 = this->register_module("linear2", Linear(options.dim_feedforward(), options.d_model())); + linear2 = this->register_module( + "linear2", Linear(options.dim_feedforward(), options.d_model())); - norm1 = this->register_module("norm1", LayerNorm(LayerNormOptions({options.d_model()}))); - norm2 = this->register_module("norm2", LayerNorm(LayerNormOptions({options.d_model()}))); + norm1 = this->register_module( + "norm1", LayerNorm(LayerNormOptions({options.d_model()}))); + norm2 = this->register_module( + "norm2", LayerNorm(LayerNormOptions({options.d_model()}))); dropout1 = this->register_module("dropout1", Dropout(options.dropout())); dropout2 = this->register_module("dropout2", Dropout(options.dropout())); } void TransformerEncoderLayerImpl::reset_parameters() { - // TODO xinyu: standardrize reset_parameters virtual funcs self_attn->_reset_parameters(); @@ -55,24 +61,25 @@ void TransformerEncoderLayerImpl::reset_parameters() { } Tensor TransformerEncoderLayerImpl::forward( - const Tensor& src, - const Tensor& src_mask, - const Tensor& src_key_padding_mask ) { - + const Tensor& src, + const Tensor& src_mask, + const Tensor& src_key_padding_mask) { // multihead attention - Tensor src2 = std::get<0>(self_attn(src, src, src, src_key_padding_mask, /*need_weights=*/true, src_mask)); + Tensor src2 = std::get<0>(self_attn( + src, src, src, src_key_padding_mask, /*need_weights=*/true, src_mask)); // add & norm Tensor ret = norm1(src + dropout1(src2)); // feedforward if (c10::get_if(&options.activation())) { src2 = linear2(dropout(F::gelu(linear1(ret)))); - } - else if (c10::get_if(&options.activation())) { + } else if (c10::get_if(&options.activation())) { src2 = linear2(dropout(F::relu(linear1(ret)))); - } - else if (c10::get_if >(&options.activation())) { - auto callable_activation = *c10::get_if >(&options.activation()); + } else if (c10::get_if>( + &options.activation())) { + auto callable_activation = + *c10::get_if>( + &options.activation()); src2 = linear2(dropout(callable_activation(linear1(ret)))); } else { TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable"); @@ -82,42 +89,51 @@ Tensor TransformerEncoderLayerImpl::forward( return norm2(ret + dropout2(src2)); } - // ========================TransformerDecoderLayerImpl========================= TransformerDecoderLayerImpl::TransformerDecoderLayerImpl( - const TransformerDecoderLayerOptions& options_ ) : options(options_) { + const TransformerDecoderLayerOptions& options_) + : options(options_) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void TransformerDecoderLayerImpl::reset() { - // NOTE: reset() is for initializing the model only, calling reset() after the model is created - // will cause throwing exceptions. Call reset_parameter() if the created model needs a reset. + // NOTE: reset() is for initializing the model only, calling reset() after the + // model is created will cause throwing exceptions. Call reset_parameter() if + // the created model needs a reset. // initialize self attention - self_attn = this->register_module("self_attn", - MultiheadAttention(MultiheadAttentionOptions( - options.d_model(), options.nhead()).dropout(options.dropout()))); + self_attn = this->register_module( + "self_attn", + MultiheadAttention( + MultiheadAttentionOptions(options.d_model(), options.nhead()) + .dropout(options.dropout()))); // initialize multihed attention - multihead_attn = this->register_module("multihead_attn", - MultiheadAttention(MultiheadAttentionOptions( - options.d_model(), options.nhead()).dropout(options.dropout()))); + multihead_attn = this->register_module( + "multihead_attn", + MultiheadAttention( + MultiheadAttentionOptions(options.d_model(), options.nhead()) + .dropout(options.dropout()))); // Initialize Feed forward first linear layer - linear1 = this->register_module("linear1", Linear(options.d_model(), options.dim_feedforward())); + linear1 = this->register_module( + "linear1", Linear(options.d_model(), options.dim_feedforward())); // initialize Feed forward dropout layer dropout = this->register_module("dropout", Dropout(options.dropout())); // initialize Feed forward second linear layer - linear2 = this->register_module("linear2", Linear(options.dim_feedforward(), options.d_model())); - + linear2 = this->register_module( + "linear2", Linear(options.dim_feedforward(), options.d_model())); // initialize Normalization, post self attention - norm1 = this->register_module("norm1", LayerNorm(LayerNormOptions({options.d_model()}))); + norm1 = this->register_module( + "norm1", LayerNorm(LayerNormOptions({options.d_model()}))); // initialize post multi-headed attention Normalization - norm2 = this->register_module("norm2", LayerNorm(LayerNormOptions({options.d_model()}))); + norm2 = this->register_module( + "norm2", LayerNorm(LayerNormOptions({options.d_model()}))); // initialize normalization, post feed forward - norm3 = this->register_module("norm3", LayerNorm(LayerNormOptions({options.d_model()}))); + norm3 = this->register_module( + "norm3", LayerNorm(LayerNormOptions({options.d_model()}))); // initialize Dropout, post self attention dropout1 = this->register_module("dropout1", Dropout(options.dropout())); @@ -128,7 +144,6 @@ void TransformerDecoderLayerImpl::reset() { } void TransformerDecoderLayerImpl::reset_parameters() { - // TODO xinyu: standardrize reset_parameters virtual funcs self_attn->_reset_parameters(); multihead_attn->_reset_parameters(); @@ -145,33 +160,32 @@ void TransformerDecoderLayerImpl::reset_parameters() { // dropout3->reset_paramteres(); } -///Pass the inputs (and mask) through the decoder layer. +/// Pass the inputs (and mask) through the decoder layer. Tensor TransformerDecoderLayerImpl::forward( - Tensor tgt, - const Tensor& memory, - const Tensor& tgt_mask, - const Tensor& memory_mask, - const Tensor& tgt_key_padding_mask, - const Tensor& memory_key_padding_mask){ - - Tensor tgt2 = std::get<0>(self_attn( - tgt, //query - tgt, //key - tgt, //value - tgt_key_padding_mask, //key_padding_mask - false, //need_weights - tgt_mask)//attn_mask + Tensor tgt, + const Tensor& memory, + const Tensor& tgt_mask, + const Tensor& memory_mask, + const Tensor& tgt_key_padding_mask, + const Tensor& memory_key_padding_mask) { + Tensor tgt2 = std::get<0>(self_attn( + tgt, // query + tgt, // key + tgt, // value + tgt_key_padding_mask, // key_padding_mask + false, // need_weights + tgt_mask) // attn_mask ); tgt = tgt + dropout1(tgt2); tgt = norm1(tgt); tgt2 = std::get<0>(multihead_attn( - tgt, //query - memory, //key - memory, //value - memory_key_padding_mask, //key_padding_mask - false, //need_weights - memory_mask)//attn_mask + tgt, // query + memory, // key + memory, // value + memory_key_padding_mask, // key_padding_mask + false, // need_weights + memory_mask) // attn_mask ); tgt = tgt + dropout2(tgt2); tgt = norm2(tgt); @@ -183,23 +197,26 @@ Tensor TransformerDecoderLayerImpl::forward( return tgt; } -Tensor TransformerDecoderLayerImpl::activation(const Tensor& input){ +Tensor TransformerDecoderLayerImpl::activation(const Tensor& input) { if (c10::get_if(&options.activation())) { return F::gelu(input); } else if (c10::get_if(&options.activation())) { return F::relu(input); - } else if(c10::get_if >(&options.activation())) { - auto callable_activation = *c10::get_if >(&options.activation()); + } else if (c10::get_if>( + &options.activation())) { + auto callable_activation = + *c10::get_if>( + &options.activation()); return callable_activation(input); } else { TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable"); } } - // ========================TransformerEncoderImpl========================= TransformerEncoderImpl::TransformerEncoderImpl( - TransformerEncoderOptions options_) : options(std::move(options_)) { + TransformerEncoderOptions options_) + : options(std::move(options_)) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -219,15 +236,19 @@ void TransformerEncoderImpl::reset() { void TransformerEncoderImpl::reset_parameters() { TORCH_CHECK( - layers->size() == static_cast(options.num_layers()), - "TransformerEncoder should have", options.num_layers(), " encoder layers, but got ", layers->size()); + layers->size() == static_cast(options.num_layers()), + "TransformerEncoder should have", + options.num_layers(), + " encoder layers, but got ", + layers->size()); size_t num_layers = layers->size(); for (const auto i : c10::irange(num_layers)) { layers->at(i).reset_parameters(); } - // a. No way to know whether module in AnyModule has api to reset_parameters, so replace instead - // b. Allow user to add/delete normalization module when reset parameters + // a. No way to know whether module in AnyModule has api to reset_parameters, + // so replace instead b. Allow user to add/delete normalization module when + // reset parameters if (!norm.is_empty()) { this->unregister_module("norm"); norm = AnyModule(); @@ -239,17 +260,18 @@ void TransformerEncoderImpl::reset_parameters() { } Tensor TransformerEncoderImpl::forward( - const Tensor& src, - const Tensor& src_mask, - const Tensor& src_key_padding_mask ) { - + const Tensor& src, + const Tensor& src_mask, + const Tensor& src_key_padding_mask) { size_t num_layers = layers->size(); Tensor output; if (num_layers > 0) { - output = layers->at(0).forward(src, src_mask, src_key_padding_mask); + output = layers->at(0).forward( + src, src_mask, src_key_padding_mask); } for (const auto i : c10::irange(1, num_layers)) { - output = layers->at(i).forward(output, src_mask, src_key_padding_mask); + output = layers->at(i).forward( + output, src_mask, src_key_padding_mask); } if (!norm.is_empty()) { @@ -260,13 +282,13 @@ Tensor TransformerEncoderImpl::forward( // ========================TransformerDecoderImpl========================= TransformerDecoderImpl::TransformerDecoderImpl( - TransformerDecoderOptions options_ ) : options(std::move(options_)){ + TransformerDecoderOptions options_) + : options(std::move(options_)) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } void TransformerDecoderImpl::reset() { - layers = this->register_module("layers", ModuleList()); for (const auto i : c10::irange(options.num_layers())) { (void)i; // Suppress unused variable warning @@ -280,17 +302,20 @@ void TransformerDecoderImpl::reset() { } void TransformerDecoderImpl::reset_parameters() { - - TORCH_CHECK(layers->size() == static_cast(options.num_layers()), - "TransformerDecoder should have", options.num_layers(), - " decoder layers, but got ", layers->size()); + TORCH_CHECK( + layers->size() == static_cast(options.num_layers()), + "TransformerDecoder should have", + options.num_layers(), + " decoder layers, but got ", + layers->size()); size_t num_layers = layers->size(); for (const auto i : c10::irange(num_layers)) { layers->at(i).reset_parameters(); } - // a. No way to know whether module in AnyModule has api to reset_parameters, so replace instead - // b. Allow user to add/delete normalization module when reset parameters + // a. No way to know whether module in AnyModule has api to reset_parameters, + // so replace instead b. Allow user to add/delete normalization module when + // reset parameters if (!norm.is_empty()) { this->unregister_module("norm"); norm = AnyModule(); @@ -302,32 +327,31 @@ void TransformerDecoderImpl::reset_parameters() { } Tensor TransformerDecoderImpl::forward( - const Tensor& tgt, - const Tensor& memory, - const Tensor& tgt_mask, - const Tensor& memory_mask, - const Tensor& tgt_key_padding_mask, - const Tensor& memory_key_padding_mask){ - + const Tensor& tgt, + const Tensor& memory, + const Tensor& tgt_mask, + const Tensor& memory_mask, + const Tensor& tgt_key_padding_mask, + const Tensor& memory_key_padding_mask) { size_t num_layers = layers->size(); Tensor output; if (num_layers > 0) { output = layers->at(0).forward( - tgt, - memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask); + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask); } for (const auto i : c10::irange(1, num_layers)) { output = layers->at(i).forward( - output, - memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask); + output, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask); } if (!norm.is_empty()) { @@ -337,9 +361,9 @@ Tensor TransformerDecoderImpl::forward( return output; } - // =======================================TransformerImpl================================ -TransformerImpl::TransformerImpl(TransformerOptions options_ ) : options(std::move(options_)){ +TransformerImpl::TransformerImpl(TransformerOptions options_) + : options(std::move(options_)) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) reset(); } @@ -348,34 +372,35 @@ void TransformerImpl::reset() { // set up encoder if (options.custom_encoder().is_empty()) { LayerNorm norm(LayerNormOptions({options.d_model()})); - TransformerEncoder trans_encoder(TransformerEncoderOptions( - TransformerEncoderLayerOptions(options.d_model(), options.nhead()) - .dim_feedforward(options.dim_feedforward()) - .dropout(options.dropout()) - .activation(options.activation()), options.num_encoder_layers()) - .norm(AnyModule(norm))); + TransformerEncoder trans_encoder( + TransformerEncoderOptions( + TransformerEncoderLayerOptions(options.d_model(), options.nhead()) + .dim_feedforward(options.dim_feedforward()) + .dropout(options.dropout()) + .activation(options.activation()), + options.num_encoder_layers()) + .norm(AnyModule(norm))); this->encoder = AnyModule(trans_encoder); - } - else { + } else { this->encoder = options.custom_encoder().clone(); } this->register_module("encoder", this->encoder.ptr()); // set up decoder if (options.custom_decoder().is_empty()) { - LayerNorm norm(LayerNormOptions({options.d_model()})); - TransformerDecoder trans_decoder(TransformerDecoderOptions( - TransformerDecoderLayerOptions(options.d_model(), options.nhead()) - .dim_feedforward(options.dim_feedforward()) - .dropout(options.dropout()) - .activation(options.activation()), options.num_decoder_layers()) - .norm(AnyModule(norm))); + TransformerDecoder trans_decoder( + TransformerDecoderOptions( + TransformerDecoderLayerOptions(options.d_model(), options.nhead()) + .dim_feedforward(options.dim_feedforward()) + .dropout(options.dropout()) + .activation(options.activation()), + options.num_decoder_layers()) + .norm(AnyModule(norm))); this->decoder = AnyModule(trans_decoder); - } - else { + } else { this->decoder = options.custom_decoder().clone(); } this->register_module("decoder", this->decoder.ptr()); @@ -383,7 +408,6 @@ void TransformerImpl::reset() { reset_parameters(); } - void TransformerImpl::reset_parameters() { auto parameters = this->parameters(); for (auto& param : parameters) { @@ -393,49 +417,72 @@ void TransformerImpl::reset_parameters() { } } - Tensor TransformerImpl::forward( - const Tensor& src, - const Tensor& tgt, - const Tensor& src_mask, - const Tensor& tgt_mask, - const Tensor& memory_mask, - const Tensor& src_key_padding_mask, - const Tensor& tgt_key_padding_mask, - const Tensor& memory_key_padding_mask) { - - TORCH_CHECK(src.dim() == 3 && tgt.dim() == 3, - "src and tgt should have 3 dimensions, but got ", src.dim(), " and ", tgt.dim()); - - TORCH_CHECK(src.size(1) == tgt.size(1), - "src and tgt should have equal batch size (at dim 1), but got ", src.size(1), " and ", tgt.size(1)); - - TORCH_CHECK(src.size(2) == options.d_model() && tgt.size(2) == options.d_model(), - "src and tgt should have same feature size as d_model (at dim 2), but got ", - src.size(2), " and ", tgt.size(2), " while d_model is ", options.d_model()); - - Tensor memory = this->encoder.forward(src, src_mask, src_key_padding_mask); + const Tensor& src, + const Tensor& tgt, + const Tensor& src_mask, + const Tensor& tgt_mask, + const Tensor& memory_mask, + const Tensor& src_key_padding_mask, + const Tensor& tgt_key_padding_mask, + const Tensor& memory_key_padding_mask) { + TORCH_CHECK( + src.dim() == 3 && tgt.dim() == 3, + "src and tgt should have 3 dimensions, but got ", + src.dim(), + " and ", + tgt.dim()); + + TORCH_CHECK( + src.size(1) == tgt.size(1), + "src and tgt should have equal batch size (at dim 1), but got ", + src.size(1), + " and ", + tgt.size(1)); + + TORCH_CHECK( + src.size(2) == options.d_model() && tgt.size(2) == options.d_model(), + "src and tgt should have same feature size as d_model (at dim 2), but got ", + src.size(2), + " and ", + tgt.size(2), + " while d_model is ", + options.d_model()); + + Tensor memory = + this->encoder.forward(src, src_mask, src_key_padding_mask); Tensor output = this->decoder.forward( - tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask); + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask); return output; } Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) { // Treat 0 dim valid here - TORCH_CHECK(sz >= 0, - "Input size must be non-negative to genearte a valid square subsequent mask, but got ", sz); + TORCH_CHECK( + sz >= 0, + "Input size must be non-negative to genearte a valid square subsequent mask, but got ", + sz); - // check IEEE754 support here since -inf is not guaranteed to be valid on non IEEE754 platform + // check IEEE754 support here since -inf is not guaranteed to be valid on non + // IEEE754 platform if (std::numeric_limits::is_iec559) { - return torch::triu(torch::full({sz, sz}, -std::numeric_limits::infinity()), 1); + return torch::triu( + torch::full({sz, sz}, -std::numeric_limits::infinity()), 1); } - // if IEEE754 is not supported, we use the smallest float number in current platform + // if IEEE754 is not supported, we use the smallest float number in current + // platform else { TORCH_WARN_ONCE( - "IEEE754 is not supporetd on this platform, generate_square_subsequent_mask will fill " - "the mask with smallest float number on this platform instead of -inf"); - return torch::triu(torch::full({sz, sz}, std::numeric_limits::lowest()), 1); + "IEEE754 is not supporetd on this platform, generate_square_subsequent_mask will fill " + "the mask with smallest float number on this platform instead of -inf"); + return torch::triu( + torch::full({sz, sz}, std::numeric_limits::lowest()), 1); } } diff --git a/torch/csrc/api/src/nn/modules/upsampling.cpp b/torch/csrc/api/src/nn/modules/upsampling.cpp index 600493e5f79128..2d74cc330690f6 100644 --- a/torch/csrc/api/src/nn/modules/upsampling.cpp +++ b/torch/csrc/api/src/nn/modules/upsampling.cpp @@ -7,7 +7,8 @@ namespace F = torch::nn::functional; namespace torch { namespace nn { -UpsampleImpl::UpsampleImpl(const UpsampleOptions& options_) // NOLINT(modernize-pass-by-value) +UpsampleImpl::UpsampleImpl( + const UpsampleOptions& options_) // NOLINT(modernize-pass-by-value) : options(options_) {} void UpsampleImpl::reset() {} diff --git a/torch/csrc/api/src/nn/options/activation.cpp b/torch/csrc/api/src/nn/options/activation.cpp index 2d56128f9d4c10..8476a4ff61a270 100644 --- a/torch/csrc/api/src/nn/options/activation.cpp +++ b/torch/csrc/api/src/nn/options/activation.cpp @@ -22,9 +22,12 @@ ReLU6Options::ReLU6Options(bool inplace) : inplace_(inplace) {} SoftshrinkOptions::SoftshrinkOptions(double lambda) : lambda_(lambda) {} MultiheadAttentionOptions::MultiheadAttentionOptions( - int64_t embed_dim, int64_t num_heads) - : embed_dim_(embed_dim), num_heads_(num_heads), - kdim_(embed_dim), vdim_(embed_dim) {} + int64_t embed_dim, + int64_t num_heads) + : embed_dim_(embed_dim), + num_heads_(num_heads), + kdim_(embed_dim), + vdim_(embed_dim) {} namespace functional { @@ -35,16 +38,22 @@ SoftminFuncOptions::SoftminFuncOptions(int64_t dim) : dim_(dim) {} LogSoftmaxFuncOptions::LogSoftmaxFuncOptions(int64_t dim) : dim_(dim) {} MultiheadAttentionForwardFuncOptions::MultiheadAttentionForwardFuncOptions( - int64_t embed_dim_to_check, int64_t num_heads, - Tensor in_proj_weight, Tensor in_proj_bias, - Tensor bias_k, Tensor bias_v, - bool add_zero_attn, double dropout_p, - Tensor out_proj_weight, Tensor out_proj_bias - ) : embed_dim_to_check_(embed_dim_to_check), + int64_t embed_dim_to_check, + int64_t num_heads, + Tensor in_proj_weight, + Tensor in_proj_bias, + Tensor bias_k, + Tensor bias_v, + bool add_zero_attn, + double dropout_p, + Tensor out_proj_weight, + Tensor out_proj_bias) + : embed_dim_to_check_(embed_dim_to_check), num_heads_(num_heads), in_proj_weight_(std::move(in_proj_weight)), in_proj_bias_(std::move(in_proj_bias)), - bias_k_(std::move(bias_k)), bias_v_(std::move(bias_v)), + bias_k_(std::move(bias_k)), + bias_v_(std::move(bias_v)), add_zero_attn_(add_zero_attn), dropout_p_(dropout_p), out_proj_weight_(std::move(out_proj_weight)), diff --git a/torch/csrc/api/src/nn/options/adaptive.cpp b/torch/csrc/api/src/nn/options/adaptive.cpp index bff2e4292d9302..1a8fcc4dc61edb 100644 --- a/torch/csrc/api/src/nn/options/adaptive.cpp +++ b/torch/csrc/api/src/nn/options/adaptive.cpp @@ -4,8 +4,12 @@ namespace torch { namespace nn { AdaptiveLogSoftmaxWithLossOptions::AdaptiveLogSoftmaxWithLossOptions( - int64_t in_features, int64_t n_classes, std::vector cutoffs) - : in_features_(in_features), n_classes_(n_classes), cutoffs_(std::move(cutoffs)) {} + int64_t in_features, + int64_t n_classes, + std::vector cutoffs) + : in_features_(in_features), + n_classes_(n_classes), + cutoffs_(std::move(cutoffs)) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/batchnorm.cpp b/torch/csrc/api/src/nn/options/batchnorm.cpp index 21444439137490..a0f7f226389859 100644 --- a/torch/csrc/api/src/nn/options/batchnorm.cpp +++ b/torch/csrc/api/src/nn/options/batchnorm.cpp @@ -3,7 +3,8 @@ namespace torch { namespace nn { -BatchNormOptions::BatchNormOptions(int64_t num_features) : num_features_(num_features) {} +BatchNormOptions::BatchNormOptions(int64_t num_features) + : num_features_(num_features) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/embedding.cpp b/torch/csrc/api/src/nn/options/embedding.cpp index d7965ccc968be9..3b9509d19a026c 100644 --- a/torch/csrc/api/src/nn/options/embedding.cpp +++ b/torch/csrc/api/src/nn/options/embedding.cpp @@ -2,10 +2,14 @@ namespace torch { namespace nn { -EmbeddingOptions::EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim) : - num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {} +EmbeddingOptions::EmbeddingOptions( + int64_t num_embeddings, + int64_t embedding_dim) + : num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {} -EmbeddingBagOptions::EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim) : - num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {} +EmbeddingBagOptions::EmbeddingBagOptions( + int64_t num_embeddings, + int64_t embedding_dim) + : num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/instancenorm.cpp b/torch/csrc/api/src/nn/options/instancenorm.cpp index 930fc276db9dc2..405c2641955451 100644 --- a/torch/csrc/api/src/nn/options/instancenorm.cpp +++ b/torch/csrc/api/src/nn/options/instancenorm.cpp @@ -3,7 +3,8 @@ namespace torch { namespace nn { -InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {} +InstanceNormOptions::InstanceNormOptions(int64_t num_features) + : num_features_(num_features) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/linear.cpp b/torch/csrc/api/src/nn/options/linear.cpp index 5c1922894c0954..67e167ee117102 100644 --- a/torch/csrc/api/src/nn/options/linear.cpp +++ b/torch/csrc/api/src/nn/options/linear.cpp @@ -4,19 +4,28 @@ namespace torch { namespace nn { LinearOptions::LinearOptions(int64_t in_features, int64_t out_features) - : in_features_(in_features), out_features_(out_features) {} + : in_features_(in_features), out_features_(out_features) {} -BilinearOptions::BilinearOptions(int64_t in1_features, int64_t in2_features, int64_t out_features) - : in1_features_(in1_features), in2_features_(in2_features), out_features_(out_features) {} +BilinearOptions::BilinearOptions( + int64_t in1_features, + int64_t in2_features, + int64_t out_features) + : in1_features_(in1_features), + in2_features_(in2_features), + out_features_(out_features) {} UnflattenOptions::UnflattenOptions(int64_t dim, std::vector sizes) - : dim_(dim), sizes_(std::move(sizes)) {} + : dim_(dim), sizes_(std::move(sizes)) {} UnflattenOptions::UnflattenOptions(const char* dimname, namedshape_t namedshape) - : dim_(0), dimname_(std::string(dimname)), namedshape_(std::move(namedshape)) {} + : dim_(0), + dimname_(std::string(dimname)), + namedshape_(std::move(namedshape)) {} UnflattenOptions::UnflattenOptions(std::string dimname, namedshape_t namedshape) - : dim_(0), dimname_(std::move(dimname)), namedshape_(std::move(namedshape)) {} + : dim_(0), + dimname_(std::move(dimname)), + namedshape_(std::move(namedshape)) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/normalization.cpp b/torch/csrc/api/src/nn/options/normalization.cpp index 9d6549ef5a3609..3b1600c6a69b74 100644 --- a/torch/csrc/api/src/nn/options/normalization.cpp +++ b/torch/csrc/api/src/nn/options/normalization.cpp @@ -3,18 +3,22 @@ namespace torch { namespace nn { -LayerNormOptions::LayerNormOptions(std::vector normalized_shape) : normalized_shape_(std::move(normalized_shape)) {} +LayerNormOptions::LayerNormOptions(std::vector normalized_shape) + : normalized_shape_(std::move(normalized_shape)) {} CrossMapLRN2dOptions::CrossMapLRN2dOptions(int64_t size) : size_(size) {} GroupNormOptions::GroupNormOptions(int64_t num_groups, int64_t num_channels) - : num_groups_(num_groups), num_channels_(num_channels) {} + : num_groups_(num_groups), num_channels_(num_channels) {} namespace functional { -LayerNormFuncOptions::LayerNormFuncOptions(std::vector normalized_shape) : normalized_shape_(std::move(normalized_shape)) {} +LayerNormFuncOptions::LayerNormFuncOptions( + std::vector normalized_shape) + : normalized_shape_(std::move(normalized_shape)) {} -GroupNormFuncOptions::GroupNormFuncOptions(int64_t num_groups) : num_groups_(num_groups) {} +GroupNormFuncOptions::GroupNormFuncOptions(int64_t num_groups) + : num_groups_(num_groups) {} } // namespace functional diff --git a/torch/csrc/api/src/nn/options/padding.cpp b/torch/csrc/api/src/nn/options/padding.cpp index c85bcc0464edca..30b62adddd2732 100644 --- a/torch/csrc/api/src/nn/options/padding.cpp +++ b/torch/csrc/api/src/nn/options/padding.cpp @@ -16,7 +16,8 @@ template struct ConstantPadOptions<3>; namespace functional { -PadFuncOptions::PadFuncOptions(std::vector pad) : pad_(std::move(pad)) {} +PadFuncOptions::PadFuncOptions(std::vector pad) + : pad_(std::move(pad)) {} } // namespace functional diff --git a/torch/csrc/api/src/nn/options/rnn.cpp b/torch/csrc/api/src/nn/options/rnn.cpp index d4a0841b7c555f..b948c0afac1d1b 100644 --- a/torch/csrc/api/src/nn/options/rnn.cpp +++ b/torch/csrc/api/src/nn/options/rnn.cpp @@ -5,7 +5,10 @@ namespace nn { namespace detail { -RNNOptionsBase::RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size) +RNNOptionsBase::RNNOptionsBase( + rnn_options_base_mode_t mode, + int64_t input_size, + int64_t hidden_size) : mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {} } // namespace detail @@ -21,8 +24,15 @@ GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size) namespace detail { -RNNCellOptionsBase::RNNCellOptionsBase(int64_t input_size, int64_t hidden_size, bool bias, int64_t num_chunks) - : input_size_(input_size), hidden_size_(hidden_size), bias_(bias), num_chunks_(num_chunks) {} +RNNCellOptionsBase::RNNCellOptionsBase( + int64_t input_size, + int64_t hidden_size, + bool bias, + int64_t num_chunks) + : input_size_(input_size), + hidden_size_(hidden_size), + bias_(bias), + num_chunks_(num_chunks) {} } // namespace detail diff --git a/torch/csrc/api/src/nn/options/transformer.cpp b/torch/csrc/api/src/nn/options/transformer.cpp index 937246de83307b..2afb9bda543c41 100644 --- a/torch/csrc/api/src/nn/options/transformer.cpp +++ b/torch/csrc/api/src/nn/options/transformer.cpp @@ -1,46 +1,52 @@ #include -#include #include +#include namespace torch { namespace nn { TransformerEncoderLayerOptions::TransformerEncoderLayerOptions( - int64_t d_model, int64_t nhead) : d_model_(d_model), nhead_(nhead) {} - + int64_t d_model, + int64_t nhead) + : d_model_(d_model), nhead_(nhead) {} TransformerDecoderLayerOptions::TransformerDecoderLayerOptions( - int64_t d_model, int64_t nhead) : d_model_(d_model), nhead_(nhead){} - + int64_t d_model, + int64_t nhead) + : d_model_(d_model), nhead_(nhead) {} TransformerEncoderOptions::TransformerEncoderOptions( - TransformerEncoderLayer encoder_layer, int64_t num_layers) : - encoder_layer_(std::move(encoder_layer)), num_layers_(num_layers) {} - + TransformerEncoderLayer encoder_layer, + int64_t num_layers) + : encoder_layer_(std::move(encoder_layer)), num_layers_(num_layers) {} TransformerEncoderOptions::TransformerEncoderOptions( - const TransformerEncoderLayerOptions& encoder_layer_options, int64_t num_layers) : - encoder_layer_(encoder_layer_options), num_layers_(num_layers) {} - + const TransformerEncoderLayerOptions& encoder_layer_options, + int64_t num_layers) + : encoder_layer_(encoder_layer_options), num_layers_(num_layers) {} TransformerDecoderOptions::TransformerDecoderOptions( - TransformerDecoderLayer decoder_layer, int64_t num_layers) : - decoder_layer_(std::move(decoder_layer)), num_layers_(num_layers) {} - + TransformerDecoderLayer decoder_layer, + int64_t num_layers) + : decoder_layer_(std::move(decoder_layer)), num_layers_(num_layers) {} TransformerDecoderOptions::TransformerDecoderOptions( - const TransformerDecoderLayerOptions& decoder_layer_options, - int64_t num_layers) - : decoder_layer_(decoder_layer_options), num_layers_(num_layers){} - + const TransformerDecoderLayerOptions& decoder_layer_options, + int64_t num_layers) + : decoder_layer_(decoder_layer_options), num_layers_(num_layers) {} -TransformerOptions::TransformerOptions(int64_t d_model, int64_t nhead) : - d_model_(d_model), nhead_(nhead) {} +TransformerOptions::TransformerOptions(int64_t d_model, int64_t nhead) + : d_model_(d_model), nhead_(nhead) {} TransformerOptions::TransformerOptions( - int64_t d_model, int64_t nhead, int64_t num_encoder_layers, int64_t num_decoder_layers) : - d_model_(d_model), nhead_(nhead), - num_encoder_layers_(num_encoder_layers), num_decoder_layers_(num_decoder_layers) {} + int64_t d_model, + int64_t nhead, + int64_t num_encoder_layers, + int64_t num_decoder_layers) + : d_model_(d_model), + nhead_(nhead), + num_encoder_layers_(num_encoder_layers), + num_decoder_layers_(num_decoder_layers) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/optim/adagrad.cpp b/torch/csrc/api/src/optim/adagrad.cpp index 43b992993ccd56..245c1d39817139 100644 --- a/torch/csrc/api/src/optim/adagrad.cpp +++ b/torch/csrc/api/src/optim/adagrad.cpp @@ -1,9 +1,9 @@ #include #include +#include #include #include -#include #include #include @@ -16,11 +16,10 @@ namespace optim { AdagradOptions::AdagradOptions(double lr) : lr_(lr) {} bool operator==(const AdagradOptions& lhs, const AdagradOptions& rhs) { - return (lhs.lr() == rhs.lr()) && - (lhs.lr_decay() == rhs.lr_decay()) && - (lhs.weight_decay() == rhs.weight_decay()) && - (lhs.initial_accumulator_value() == rhs.initial_accumulator_value()) && - (lhs.eps() == rhs.eps()); + return (lhs.lr() == rhs.lr()) && (lhs.lr_decay() == rhs.lr_decay()) && + (lhs.weight_decay() == rhs.weight_decay()) && + (lhs.initial_accumulator_value() == rhs.initial_accumulator_value()) && + (lhs.eps() == rhs.eps()); } void AdagradOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -48,11 +47,11 @@ void AdagradOptions::set_lr(const double lr) { } bool operator==(const AdagradParamState& lhs, const AdagradParamState& rhs) { - return (lhs.step() == rhs.step()) && - torch::equal(lhs.sum(), rhs.sum()); + return (lhs.step() == rhs.step()) && torch::equal(lhs.sum(), rhs.sum()); } -void AdagradParamState::serialize(torch::serialize::OutputArchive& archive) const { +void AdagradParamState::serialize( + torch::serialize::OutputArchive& archive) const { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(sum); } @@ -77,14 +76,20 @@ Tensor Adagrad::step(LossClosure closure) { continue; } auto grad = p.grad(); - TORCH_INTERNAL_ASSERT(state_[c10::guts::to_string(p.unsafeGetTensorImpl())] != nullptr, "state found NULL for the Tensor ", p); - auto& state = static_cast(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); + TORCH_INTERNAL_ASSERT( + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] != nullptr, + "state found NULL for the Tensor ", + p); + auto& state = static_cast( + *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); auto& options = static_cast(group.options()); state.step(state.step() + 1); if (options.weight_decay() != 0) { - TORCH_CHECK(!p.grad().is_sparse(), "weight_decay option is not compatible with sparse gradients"); + TORCH_CHECK( + !p.grad().is_sparse(), + "weight_decay option is not compatible with sparse gradients"); grad = grad.add(p, options.weight_decay()); } const auto clr = options.lr() / @@ -96,19 +101,19 @@ Tensor Adagrad::step(LossClosure closure) { auto grad_values = grad._values(); auto size = grad.sizes(); - auto make_sparse = [&] (const Tensor& values) -> Tensor { + auto make_sparse = [&](const Tensor& values) -> Tensor { if (grad_indices.dim() == 0 || values.dim() == 0) { return torch::empty({0}, grad.options()).resize_as_(grad); } - return torch::sparse_coo_tensor(grad_indices, values, size, grad.options()); + return torch::sparse_coo_tensor( + grad_indices, values, size, grad.options()); }; state.sum(state.sum().add_(make_sparse(grad_values.pow(2)))); auto std = state.sum().sparse_mask(grad); const auto std_values = std._values().sqrt_().add_(options.eps()); p.add_(make_sparse(grad_values / std_values), -clr); - } - else { + } else { state.sum(state.sum().addcmul_(grad, grad, 1.0)); const auto std = state.sum().sqrt().add_(options.eps()); p.addcdiv_(grad, std, -clr); @@ -126,22 +131,24 @@ void Adagrad::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized Adagrad optimizer is still using the old serialization format. " - "You should re-save your Adagrad optimizer to use the new serialization format."); + "Your serialized Adagrad optimizer is still using the old serialization format. " + "You should re-save your Adagrad optimizer to use the new serialization format."); std::vector sum_buffers; std::vector step_buffers; torch::optim::serialize(archive, "sum_buffers", sum_buffers); torch::optim::serialize(archive, "step_buffers", step_buffers); - // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group + // since there were no param_groups prior to version 1.5.0, assuming all + // tensors are now in one param_group std::vector params = param_groups_.at(0).params(); - for(const auto idx : c10::irange(params.size())) { + for (const auto idx : c10::irange(params.size())) { auto state = std::make_unique(); state->step(step_buffers[idx]); state->sum(sum_buffers[idx]); - state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = + std::move(state); } } } diff --git a/torch/csrc/api/src/optim/adam.cpp b/torch/csrc/api/src/optim/adam.cpp index 1169c590479bb4..7aadfcaba60f7b 100644 --- a/torch/csrc/api/src/optim/adam.cpp +++ b/torch/csrc/api/src/optim/adam.cpp @@ -18,11 +18,11 @@ AdamOptions::AdamOptions(double lr) : lr_(lr) {} bool operator==(const AdamOptions& lhs, const AdamOptions& rhs) { return (lhs.lr() == rhs.lr()) && - (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) && - (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) && - (lhs.eps() == rhs.eps()) && - (lhs.weight_decay() == rhs.weight_decay() && - (lhs.amsgrad() == rhs.amsgrad())); + (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) && + (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) && + (lhs.eps() == rhs.eps()) && + (lhs.weight_decay() == rhs.weight_decay() && + (lhs.amsgrad() == rhs.amsgrad())); } void AdamOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -51,9 +51,9 @@ void AdamOptions::set_lr(const double lr) { bool operator==(const AdamParamState& lhs, const AdamParamState& rhs) { return (lhs.step() == rhs.step()) && - torch::equal(lhs.exp_avg(), rhs.exp_avg()) && - torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) && - torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq()); + torch::equal(lhs.exp_avg(), rhs.exp_avg()) && + torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) && + torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq()); } void AdamParamState::serialize(torch::serialize::OutputArchive& archive) const { @@ -70,7 +70,7 @@ void AdamParamState::serialize(torch::serialize::InputArchive& archive) { _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq); } -Tensor Adam::step(LossClosure closure) { +Tensor Adam::step(LossClosure closure) { NoGradGuard no_grad; Tensor loss = {}; if (closure != nullptr) { @@ -83,38 +83,41 @@ Tensor Adam::step(LossClosure closure) { continue; } auto grad = p.grad(); - TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients"/*, please consider SparseAdam instead*/); - auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); + TORCH_CHECK(!grad.is_sparse(), "Adam does not support sparse gradients" /*, please consider SparseAdam instead*/); + auto param_state = + state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); auto& options = static_cast(group.options()); // State initialization - if(param_state == state_.end()) { + if (param_state == state_.end()) { auto state = std::make_unique(); state->step(0); // Exponential moving average of gradient values state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve)); // Exponential moving average of squared gradient values state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve)); - if(options.amsgrad()) { + if (options.amsgrad()) { // Maintains max of all exp. moving avg. of sq. grad. values state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve)); } - state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = + std::move(state); } - auto& state = static_cast(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); + auto& state = static_cast( + *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); auto& exp_avg = state.exp_avg(); auto& exp_avg_sq = state.exp_avg_sq(); auto& max_exp_avg_sq = state.max_exp_avg_sq(); - state.step(state.step()+1); + state.step(state.step() + 1); auto beta1 = std::get<0>(options.betas()); auto beta2 = std::get<1>(options.betas()); auto bias_correction1 = 1 - std::pow(beta1, state.step()); auto bias_correction2 = 1 - std::pow(beta2, state.step()); - if(options.weight_decay() != 0) { + if (options.weight_decay() != 0) { grad = grad.add(p, options.weight_decay()); } @@ -123,13 +126,15 @@ Tensor Adam::step(LossClosure closure) { exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2); Tensor denom; - if(options.amsgrad()) { + if (options.amsgrad()) { // Maintains the maximum of all 2nd moment running avg. till now torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq); // Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); + denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)) + .add_(options.eps()); } else { - denom = (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); + denom = + (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); } auto step_size = options.lr() / bias_correction1; @@ -147,22 +152,26 @@ void Adam::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized Adam optimizer is still using the old serialization format. " - "You should re-save your Adam optimizer to use the new serialization format."); + "Your serialized Adam optimizer is still using the old serialization format. " + "You should re-save your Adam optimizer to use the new serialization format."); std::vector step_buffers; std::vector exp_average_buffers; std::vector exp_average_sq_buffers; std::vector max_exp_average_sq_buffers; torch::optim::serialize(archive, "step_buffers", step_buffers); - torch::optim::serialize(archive, "exp_average_buffers", exp_average_buffers); - torch::optim::serialize(archive, "exp_average_sq_buffers", exp_average_sq_buffers); - torch::optim::serialize(archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); - // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group + torch::optim::serialize( + archive, "exp_average_buffers", exp_average_buffers); + torch::optim::serialize( + archive, "exp_average_sq_buffers", exp_average_sq_buffers); + torch::optim::serialize( + archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); + // since there were no param_groups prior to version 1.5.0, assuming all + // tensors are now in one param_group std::vector params = param_groups_.at(0).params(); - for(const auto idx : c10::irange(step_buffers.size())) { + for (const auto idx : c10::irange(step_buffers.size())) { auto state = std::make_unique(); state->step(step_buffers.at(idx)); state->exp_avg(exp_average_buffers.at(idx)); @@ -170,7 +179,8 @@ void Adam::load(serialize::InputArchive& archive) { if (idx < max_exp_average_sq_buffers.size()) { state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx)); } - state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = + std::move(state); } } } diff --git a/torch/csrc/api/src/optim/adamw.cpp b/torch/csrc/api/src/optim/adamw.cpp index 61dd7f8a25e859..d049fbaf02711a 100644 --- a/torch/csrc/api/src/optim/adamw.cpp +++ b/torch/csrc/api/src/optim/adamw.cpp @@ -18,11 +18,10 @@ AdamWOptions::AdamWOptions(double lr) : lr_(lr) {} bool operator==(const AdamWOptions& lhs, const AdamWOptions& rhs) { return (lhs.lr() == rhs.lr()) && - (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) && - (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) && - (lhs.eps() == rhs.eps()) && - (lhs.weight_decay() == rhs.weight_decay()) && - (lhs.amsgrad() == rhs.amsgrad()); + (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas())) && + (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas())) && + (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) && + (lhs.amsgrad() == rhs.amsgrad()); } void AdamWOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -51,12 +50,13 @@ void AdamWOptions::set_lr(const double lr) { bool operator==(const AdamWParamState& lhs, const AdamWParamState& rhs) { return (lhs.step() == rhs.step()) && - torch::equal(lhs.exp_avg(), rhs.exp_avg()) && - torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) && - torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq()); + torch::equal(lhs.exp_avg(), rhs.exp_avg()) && + torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq()) && + torch::equal_if_defined(lhs.max_exp_avg_sq(), rhs.max_exp_avg_sq()); } -void AdamWParamState::serialize(torch::serialize::OutputArchive& archive) const { +void AdamWParamState::serialize( + torch::serialize::OutputArchive& archive) const { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq); @@ -70,7 +70,7 @@ void AdamWParamState::serialize(torch::serialize::InputArchive& archive) { _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, max_exp_avg_sq); } -Tensor AdamW::step(LossClosure closure) { +Tensor AdamW::step(LossClosure closure) { NoGradGuard no_grad; Tensor loss = {}; if (closure != nullptr) { @@ -83,36 +83,39 @@ Tensor AdamW::step(LossClosure closure) { continue; } const auto& grad = p.grad(); - TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients"/*, please consider SparseAdamW instead*/); - auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); + TORCH_CHECK(!grad.is_sparse(), "AdamW does not support sparse gradients" /*, please consider SparseAdamW instead*/); + auto param_state = + state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); auto& options = static_cast(group.options()); // Perform stepweight decay - if(options.weight_decay() != 0) { + if (options.weight_decay() != 0) { p.mul_(1 - options.lr() * options.weight_decay()); } // State initialization - if(param_state == state_.end()) { + if (param_state == state_.end()) { auto state = std::make_unique(); state->step(0); // Exponential moving average of gradient values state->exp_avg(torch::zeros_like(p, MemoryFormat::Preserve)); // Exponential moving average of squared gradient values state->exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve)); - if(options.amsgrad()) { + if (options.amsgrad()) { // Maintains max of all exp. moving avg. of sq. grad. values state->max_exp_avg_sq(torch::zeros_like(p, MemoryFormat::Preserve)); } - state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = + std::move(state); } - auto& state = static_cast(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); + auto& state = static_cast( + *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); auto& exp_avg = state.exp_avg(); auto& exp_avg_sq = state.exp_avg_sq(); auto& max_exp_avg_sq = state.max_exp_avg_sq(); - state.step(state.step()+1); + state.step(state.step() + 1); auto beta1 = std::get<0>(options.betas()); auto beta2 = std::get<1>(options.betas()); @@ -124,13 +127,15 @@ Tensor AdamW::step(LossClosure closure) { exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2); Tensor denom; - if(options.amsgrad()) { + if (options.amsgrad()) { // Maintains the maximum of all 2nd moment running avg. till now torch::max_out(max_exp_avg_sq, exp_avg_sq, max_exp_avg_sq); // Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); + denom = (max_exp_avg_sq.sqrt() / sqrt(bias_correction2)) + .add_(options.eps()); } else { - denom = (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); + denom = + (exp_avg_sq.sqrt() / sqrt(bias_correction2)).add_(options.eps()); } auto step_size = options.lr() / bias_correction1; @@ -148,22 +153,26 @@ void AdamW::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized AdamW optimizer is still using the old serialization format. " - "You should re-save your AdamW optimizer to use the new serialization format."); + "Your serialized AdamW optimizer is still using the old serialization format. " + "You should re-save your AdamW optimizer to use the new serialization format."); std::vector step_buffers; std::vector exp_average_buffers; std::vector exp_average_sq_buffers; std::vector max_exp_average_sq_buffers; torch::optim::serialize(archive, "step_buffers", step_buffers); - torch::optim::serialize(archive, "exp_average_buffers", exp_average_buffers); - torch::optim::serialize(archive, "exp_average_sq_buffers", exp_average_sq_buffers); - torch::optim::serialize(archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); - // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group + torch::optim::serialize( + archive, "exp_average_buffers", exp_average_buffers); + torch::optim::serialize( + archive, "exp_average_sq_buffers", exp_average_sq_buffers); + torch::optim::serialize( + archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); + // since there were no param_groups prior to version 1.5.0, assuming all + // tensors are now in one param_group std::vector params = param_groups_.at(0).params(); - for(const auto idx : c10::irange(step_buffers.size())) { + for (const auto idx : c10::irange(step_buffers.size())) { auto state = std::make_unique(); state->step(step_buffers.at(idx)); state->exp_avg(exp_average_buffers.at(idx)); @@ -171,7 +180,8 @@ void AdamW::load(serialize::InputArchive& archive) { if (idx < max_exp_average_sq_buffers.size()) { state->max_exp_avg_sq(max_exp_average_sq_buffers.at(idx)); } - state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(params.at(idx).unsafeGetTensorImpl())] = + std::move(state); } } } diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index d3143b07ccdd58..36a63bffba741a 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -8,10 +8,10 @@ #include #include +#include #include #include #include -#include namespace torch { namespace optim { @@ -19,13 +19,12 @@ namespace optim { LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {} bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs) { - return (lhs.lr() == rhs.lr()) && - (lhs.max_iter() == rhs.max_iter()) && - (lhs.max_eval() == rhs.max_eval()) && - (lhs.tolerance_grad() == rhs.tolerance_grad()) && - (lhs.tolerance_change() == rhs.tolerance_change() && - (lhs.history_size() == rhs.history_size())) && - (lhs.line_search_fn() == rhs.line_search_fn()); + return (lhs.lr() == rhs.lr()) && (lhs.max_iter() == rhs.max_iter()) && + (lhs.max_eval() == rhs.max_eval()) && + (lhs.tolerance_grad() == rhs.tolerance_grad()) && + (lhs.tolerance_change() == rhs.tolerance_change() && + (lhs.history_size() == rhs.history_size())) && + (lhs.line_search_fn() == rhs.line_search_fn()); } void LBFGSOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -58,30 +57,35 @@ void LBFGSOptions::set_lr(const double lr) { template bool if_container_equal(T lhs, T rhs) { - if (!(lhs.size() == rhs.size())) return false; - for(const auto i : c10::irange(lhs.size())) { - if (!torch::equal(lhs.at(i), rhs.at(i))) return false; + if (!(lhs.size() == rhs.size())) + return false; + for (const auto i : c10::irange(lhs.size())) { + if (!torch::equal(lhs.at(i), rhs.at(i))) + return false; } return true; } bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) { - auto isNull = [](const c10::optional>& val) { return val == c10::nullopt; }; + auto isNull = [](const c10::optional>& val) { + return val == c10::nullopt; + }; return (lhs.func_evals() == rhs.func_evals()) && - (lhs.n_iter() == rhs.n_iter()) && - (lhs.t() == rhs.t()) && - (lhs.prev_loss() == rhs.prev_loss()) && - torch::equal_if_defined(lhs.d(), rhs.d()) && - torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) && - torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) && - if_container_equal(lhs.old_dirs(), rhs.old_dirs()) && - if_container_equal(lhs.old_stps(), rhs.old_stps()) && - if_container_equal(lhs.ro(), rhs.ro()) && - ((isNull(lhs.al()) && isNull(rhs.al())) || - (!isNull(lhs.al()) && !isNull(rhs.al()) && if_container_equal(*lhs.al(), *rhs.al()))); + (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) && + (lhs.prev_loss() == rhs.prev_loss()) && + torch::equal_if_defined(lhs.d(), rhs.d()) && + torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) && + torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) && + if_container_equal(lhs.old_dirs(), rhs.old_dirs()) && + if_container_equal(lhs.old_stps(), rhs.old_stps()) && + if_container_equal(lhs.ro(), rhs.ro()) && + ((isNull(lhs.al()) && isNull(rhs.al())) || + (!isNull(lhs.al()) && !isNull(rhs.al()) && + if_container_equal(*lhs.al(), *rhs.al()))); } -void LBFGSParamState::serialize(torch::serialize::OutputArchive& archive) const { +void LBFGSParamState::serialize( + torch::serialize::OutputArchive& archive) const { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(func_evals); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(n_iter); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(t); @@ -93,7 +97,7 @@ void LBFGSParamState::serialize(torch::serialize::OutputArchive& archive) const _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps); _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro); // Python version only serializes state vars if explicitly defined - if(al() != c10::nullopt) { + if (al() != c10::nullopt) { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al); } } @@ -117,11 +121,9 @@ Tensor LBFGS::_gather_flat_grad() { for (const auto& p : param_groups_.at(0).params()) { if (!p.grad().defined()) { views.emplace_back(p.new_empty({p.numel()}).zero_()); - } - else if (p.grad().is_sparse()) { + } else if (p.grad().is_sparse()) { views.emplace_back(p.grad().to_dense().view(-1)); - } - else { + } else { views.emplace_back(p.grad().view(-1)); } } @@ -144,8 +146,9 @@ void LBFGS::_add_grad(const double step_size, const Tensor& update) { for (auto& p : param_groups_.at(0).params()) { auto numel = p.numel(); // view as to avoid deprecated pointwise semantics - p.add_(update.index( - {at::indexing::Slice(offset, offset + numel)}).view_as(p), step_size); + p.add_( + update.index({at::indexing::Slice(offset, offset + numel)}).view_as(p), + step_size); offset += numel; } TORCH_INTERNAL_ASSERT(offset == _numel()); @@ -154,7 +157,7 @@ void LBFGS::_add_grad(const double step_size, const Tensor& update) { void LBFGS::_set_param(const std::vector& params_data) { auto& _params = param_groups_.at(0).params(); TORCH_INTERNAL_ASSERT(params_data.size() == _params.size()); - for(const auto i : c10::irange(_params.size())) { + for (const auto i : c10::irange(_params.size())) { _params.at(i).copy_(params_data.at(i)); } } @@ -168,22 +171,30 @@ std::vector LBFGS::_clone_param() { } std::tuple LBFGS::_directional_evaluate( - const LossClosure& closure, const std::vector& x, double t, const Tensor& d) { - _add_grad(t, d); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double loss; - { - torch::AutoGradMode enable_grad(true); - loss = closure().item(); - } - auto flat_grad = _gather_flat_grad(); - _set_param(x); - return std::make_tuple(loss, flat_grad); + const LossClosure& closure, + const std::vector& x, + double t, + const Tensor& d) { + _add_grad(t, d); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double loss; + { + torch::AutoGradMode enable_grad(true); + loss = closure().item(); + } + auto flat_grad = _gather_flat_grad(); + _set_param(x); + return std::make_tuple(loss, flat_grad); } double _cubic_interpolate( - double x1, double f1, double g1, double x2, double f2, double g2, - c10::optional> bounds = c10::nullopt) { + double x1, + double f1, + double g1, + double x2, + double f2, + double g2, + c10::optional> bounds = c10::nullopt) { // ported from https://github.com/torch/optim/blob/master/polyinterp.lua // Compute bounds of interpolation area // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -192,7 +203,7 @@ double _cubic_interpolate( std::tie(xmin_bound, xmax_bound) = *bounds; } else { std::tie(xmin_bound, xmax_bound) = - (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1); + (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1); } // Code for most common case: cubic interpolation of 2 points // w/ function and derivative values for both @@ -221,166 +232,198 @@ double _cubic_interpolate( } } -using Function = std::function(const std::vector& x, double t, const Tensor& d)>; -std::tuple _strong_wolfe(const Function& obj_func, const std::vector& x, - double t, const Tensor& d, double f, Tensor g, const Tensor& gtd, - double c1 = 1e-4, double c2 = 0.9, // // NOLINT(cppcoreguidelines-avoid-magic-numbers) - double tolerance_change = 1e-9, double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers) - - auto val = [](const Tensor& t) {return t.item(); }; - - auto d_norm = val(d.abs().max()); - g = g.clone(at::MemoryFormat::Contiguous); - // evaluate objective and gradient using initial step - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double f_new; - Tensor g_new; - std::tie(f_new, g_new) = obj_func(x, t, d); - int64_t ls_func_evals = 1; - auto gtd_new = g_new.dot(d); - - // bracket an interval containing a point satisfying the Wolfe criteria - double t_prev = 0; - auto f_prev = f; - auto g_prev = g; - auto gtd_prev = gtd; - bool done = false; - auto ls_iter = 0; - std::vector bracket, bracket_f; - std::vector bracket_g, bracket_gtd; - - while (ls_iter < max_ls) { - // check conditions - if ((f_new > (f + c1 * t * val(gtd))) || (ls_iter > 1 && (f_new >= f_prev))) { - bracket = {t_prev, t}; - bracket_f = {f_prev, f_new}; - bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)}; - bracket_gtd = {gtd_prev, gtd_new}; - break; - } - if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) { - bracket = {t, t}; - bracket_f = {f_new, f_new}; - bracket_g = {g_new, g_new}; - done = true; - break; - } - if (val(gtd_new) >= 0) { - bracket = {t_prev, t}; - bracket_f = {f_prev, f_new}; - bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)}; - bracket_gtd = {gtd_prev, gtd_new}; - break; - } - // interpolate - auto min_step = t + 0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers) - auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers) - auto tmp = t; - t = _cubic_interpolate(t_prev, f_prev, val(gtd_prev), - t, f_new, val(gtd_new), - std::make_tuple(min_step, max_step)); - // next step - t_prev = tmp; - f_prev = f_new; - g_prev = g_new.clone(at::MemoryFormat::Contiguous); - gtd_prev = gtd_new; - std::tie(f_new, g_new) = obj_func(x, t, d); - ls_func_evals += 1; - gtd_new = g_new.dot(d); - ls_iter += 1; +using Function = std::function( + const std::vector& x, + double t, + const Tensor& d)>; +std::tuple _strong_wolfe( + const Function& obj_func, + const std::vector& x, + double t, + const Tensor& d, + double f, + Tensor g, + const Tensor& gtd, + double c1 = 1e-4, + double c2 = 0.9, // // NOLINT(cppcoreguidelines-avoid-magic-numbers) + double tolerance_change = 1e-9, + double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers) + + auto val = [](const Tensor& t) { return t.item(); }; + + auto d_norm = val(d.abs().max()); + g = g.clone(at::MemoryFormat::Contiguous); + // evaluate objective and gradient using initial step + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double f_new; + Tensor g_new; + std::tie(f_new, g_new) = obj_func(x, t, d); + int64_t ls_func_evals = 1; + auto gtd_new = g_new.dot(d); + + // bracket an interval containing a point satisfying the Wolfe criteria + double t_prev = 0; + auto f_prev = f; + auto g_prev = g; + auto gtd_prev = gtd; + bool done = false; + auto ls_iter = 0; + std::vector bracket, bracket_f; + std::vector bracket_g, bracket_gtd; + + while (ls_iter < max_ls) { + // check conditions + if ((f_new > (f + c1 * t * val(gtd))) || + (ls_iter > 1 && (f_new >= f_prev))) { + bracket = {t_prev, t}; + bracket_f = {f_prev, f_new}; + bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)}; + bracket_gtd = {gtd_prev, gtd_new}; + break; } - // reached max number of iterations? - if (ls_iter == max_ls) { - bracket = {0, t}; - bracket_f = {f, f_new}; - bracket_g = {g, g_new}; + if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) { + bracket = {t, t}; + bracket_f = {f_new, f_new}; + bracket_g = {g_new, g_new}; + done = true; + break; } + if (val(gtd_new) >= 0) { + bracket = {t_prev, t}; + bracket_f = {f_prev, f_new}; + bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)}; + bracket_gtd = {gtd_prev, gtd_new}; + break; + } + // interpolate + auto min_step = t + + 0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers) + auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers) + auto tmp = t; + t = _cubic_interpolate( + t_prev, + f_prev, + val(gtd_prev), + t, + f_new, + val(gtd_new), + std::make_tuple(min_step, max_step)); + // next step + t_prev = tmp; + f_prev = f_new; + g_prev = g_new.clone(at::MemoryFormat::Contiguous); + gtd_prev = gtd_new; + std::tie(f_new, g_new) = obj_func(x, t, d); + ls_func_evals += 1; + gtd_new = g_new.dot(d); + ls_iter += 1; + } + // reached max number of iterations? + if (ls_iter == max_ls) { + bracket = {0, t}; + bracket_f = {f, f_new}; + bracket_g = {g, g_new}; + } - // zoom phase: we now have a point satisfying the criteria, or - // a bracket around it. We refine the bracket until we find the - // exact point satisfying the criteria - bool insuf_progress = false; - // find high and low points in bracket - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t low_pos, high_pos; - std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1] ? std::make_tuple(0, 1) : std::make_tuple(1, 0); - while(!done && (ls_iter < max_ls)) { - // compute new trial value - t = _cubic_interpolate(bracket[0], bracket_f[0], val(bracket_gtd[0]), - bracket[1], bracket_f[1], val(bracket_gtd[1])); - - // test that we are making sufficient progress: - // in case `t` is so close to boundary, we mark that we are making - // insufficient progress, and if - // + we have made insufficient progress in the last step, or - // + `t` is at one of the boundary, - // we will move `t` to a position which is `0.1 * len(bracket)` - // away from the nearest boundary point. - double bracket_max = std::max(bracket[0], bracket[1]); - auto bracket_min = std::min(bracket[0], bracket[1]); - auto eps = 0.1 * (bracket_max - bracket_min); // // NOLINT(cppcoreguidelines-avoid-magic-numbers) - if (std::min(bracket_max - t, t - bracket_min) < eps) { - // interpolation close to boundary - if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) { - // evaluate at 0.1 away from boundary - t = (std::abs(t - bracket_max) < std::abs(t - bracket_min)) ? bracket_max - eps : bracket_min + eps; - insuf_progress = false; - } else { - insuf_progress = true; - } - } else { + // zoom phase: we now have a point satisfying the criteria, or + // a bracket around it. We refine the bracket until we find the + // exact point satisfying the criteria + bool insuf_progress = false; + // find high and low points in bracket + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t low_pos, high_pos; + std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1] + ? std::make_tuple(0, 1) + : std::make_tuple(1, 0); + while (!done && (ls_iter < max_ls)) { + // compute new trial value + t = _cubic_interpolate( + bracket[0], + bracket_f[0], + val(bracket_gtd[0]), + bracket[1], + bracket_f[1], + val(bracket_gtd[1])); + + // test that we are making sufficient progress: + // in case `t` is so close to boundary, we mark that we are making + // insufficient progress, and if + // + we have made insufficient progress in the last step, or + // + `t` is at one of the boundary, + // we will move `t` to a position which is `0.1 * len(bracket)` + // away from the nearest boundary point. + double bracket_max = std::max(bracket[0], bracket[1]); + auto bracket_min = std::min(bracket[0], bracket[1]); + auto eps = 0.1 * + (bracket_max - + bracket_min); // // NOLINT(cppcoreguidelines-avoid-magic-numbers) + if (std::min(bracket_max - t, t - bracket_min) < eps) { + // interpolation close to boundary + if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) { + // evaluate at 0.1 away from boundary + t = (std::abs(t - bracket_max) < std::abs(t - bracket_min)) + ? bracket_max - eps + : bracket_min + eps; insuf_progress = false; - } - - // Evaluate new point - std::tie(f_new, g_new) = obj_func(x, t, d); - ls_func_evals += 1; - gtd_new = g_new.dot(d); - ls_iter += 1; - - if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) { - // Armijo condition not satisfied or not lower than lowest point - // # Armijo condition not satisfied or not lower than lowest point - bracket[high_pos] = t; - bracket_f[high_pos] = f_new; - bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous); - bracket_gtd[high_pos] = gtd_new; - std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1] ? std::make_tuple(0, 1) : std::make_tuple(1, 0); } else { - if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) { - // Wolfe conditions satisfied - done = true; - } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) { - // old high becomes new low - bracket[high_pos] = bracket[low_pos]; - bracket_f[high_pos] = bracket_f[low_pos]; - bracket_g[high_pos] = bracket_g[low_pos]; - bracket_gtd[high_pos] = bracket_gtd[low_pos]; - } + insuf_progress = true; + } + } else { + insuf_progress = false; + } - // new point becomes new low - bracket[low_pos] = t; - bracket_f[low_pos] = f_new; - bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous); - bracket_gtd[low_pos] = gtd_new; + // Evaluate new point + std::tie(f_new, g_new) = obj_func(x, t, d); + ls_func_evals += 1; + gtd_new = g_new.dot(d); + ls_iter += 1; + + if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) { + // Armijo condition not satisfied or not lower than lowest point + // # Armijo condition not satisfied or not lower than lowest point + bracket[high_pos] = t; + bracket_f[high_pos] = f_new; + bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous); + bracket_gtd[high_pos] = gtd_new; + std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1] + ? std::make_tuple(0, 1) + : std::make_tuple(1, 0); + } else { + if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) { + // Wolfe conditions satisfied + done = true; + } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) { + // old high becomes new low + bracket[high_pos] = bracket[low_pos]; + bracket_f[high_pos] = bracket_f[low_pos]; + bracket_g[high_pos] = bracket_g[low_pos]; + bracket_gtd[high_pos] = bracket_gtd[low_pos]; } - // line-search bracket is so small - if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change) break; + // new point becomes new low + bracket[low_pos] = t; + bracket_f[low_pos] = f_new; + bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous); + bracket_gtd[low_pos] = gtd_new; } - // return stuff - t = bracket[low_pos]; - f_new = bracket_f[low_pos]; - g_new = bracket_g[low_pos]; - return std::make_tuple(f_new, g_new, t, ls_func_evals); + // line-search bracket is so small + if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change) + break; + } + + // return stuff + t = bracket[low_pos]; + f_new = bracket_f[low_pos]; + g_new = bracket_g[low_pos]; + return std::make_tuple(f_new, g_new, t, ls_func_evals); } Tensor LBFGS::step(LossClosure closure) { NoGradGuard no_grad; TORCH_CHECK(closure != nullptr, "LBFGS requires a closure function"); TORCH_INTERNAL_ASSERT(param_groups_.size() == 1); - auto val = [](const Tensor& t) {return t.item(); }; + auto val = [](const Tensor& t) { return t.item(); }; auto& group = param_groups_.at(0); auto& _params = group.params(); @@ -395,11 +438,14 @@ Tensor LBFGS::step(LossClosure closure) { // NOTE: LBFGS has only global state, but we register it as state for // the first param, because this helps with casting in load_state_dict - auto param_state = state_.find(c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())); + auto param_state = + state_.find(c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())); if (param_state == state_.end()) { - state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())] = std::make_unique(); + state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())] = + std::make_unique(); } - auto& state = static_cast(*state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())]); + auto& state = static_cast( + *state_[c10::guts::to_string(_params.at(0).unsafeGetTensorImpl())]); // evaluate initial f(x) and df/dx Tensor orig_loss; { @@ -431,7 +477,7 @@ Tensor LBFGS::step(LossClosure closure) { int n_iter = 0; // optimize for a max of max_iter iterations - while(n_iter < max_iter) { + while (n_iter < max_iter) { // keep track of nb of iterations n_iter += 1; state.n_iter(state.n_iter() + 1); @@ -462,7 +508,7 @@ Tensor LBFGS::step(LossClosure closure) { ro.emplace_back(1. / ys); // update scale of initial Hessian approximation - H_diag = ys / y.dot(y); // (y*y) + H_diag = ys / y.dot(y); // (y*y) } // compute the approximate (L-BFGS) inverse Hessian @@ -485,7 +531,7 @@ Tensor LBFGS::step(LossClosure closure) { // r/d is the final direction auto r = torch::mul(q, H_diag); d = r; - for(const auto i : c10::irange(num_old)) { + for (const auto i : c10::irange(num_old)) { auto be_i = old_dirs.at(i).dot(r) * ro.at(i); r.add_(old_stps.at(i), val((*al).at(i) - be_i)); } @@ -509,18 +555,25 @@ Tensor LBFGS::step(LossClosure closure) { } // directional derivative - auto gtd = flat_grad.dot(d); // g * d + auto gtd = flat_grad.dot(d); // g * d // directional derivative is below tolerance - if (val(gtd) > -tolerance_change) break; + if (val(gtd) > -tolerance_change) + break; // optional line search: user function auto ls_func_evals = 0; if (line_search_fn != c10::nullopt) { - TORCH_CHECK(*line_search_fn == "strong_wolfe", "only 'strong_wolfe' is supported"); + TORCH_CHECK( + *line_search_fn == "strong_wolfe", + "only 'strong_wolfe' is supported"); auto x_init = _clone_param(); - auto obj_func = [&](const std::vector& x, double t, const Tensor& d) { return _directional_evaluate(closure, x, t, d); }; - std::tie(loss, flat_grad, t, ls_func_evals) = _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd); + auto obj_func = + [&](const std::vector& x, double t, const Tensor& d) { + return _directional_evaluate(closure, x, t, d); + }; + std::tie(loss, flat_grad, t, ls_func_evals) = + _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd); _add_grad(t, d); opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad); } else { @@ -546,17 +599,22 @@ Tensor LBFGS::step(LossClosure closure) { // ############################################################ // # check conditions // ############################################################ - if (n_iter == max_iter) break; + if (n_iter == max_iter) + break; - if (current_evals >= *max_eval) break; + if (current_evals >= *max_eval) + break; // optimal condition - if (opt_cond) break; + if (opt_cond) + break; // lack of progress - if (val(d.mul(t).abs().max()) <= tolerance_change) break; + if (val(d.mul(t).abs().max()) <= tolerance_change) + break; - if (std::abs(loss - prev_loss) < tolerance_change) break; + if (std::abs(loss - prev_loss) < tolerance_change) + break; } return orig_loss; @@ -570,13 +628,13 @@ void LBFGS::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized LBFGS optimizer is still using the old serialization format. " - "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque " - "and al will be set to c10::nullopt because the old LBFGS optimizer didn't save these values." - "You should re-save your LBFGS optimizer to use the new serialization format."); + "Your serialized LBFGS optimizer is still using the old serialization format. " + "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque " + "and al will be set to c10::nullopt because the old LBFGS optimizer didn't save these values." + "You should re-save your LBFGS optimizer to use the new serialization format."); Tensor d, t, H_diag, prev_flat_grad, prev_loss; std::deque old_dirs, old_stps; archive("d", d, /*is_buffer=*/true); @@ -597,7 +655,9 @@ void LBFGS::load(serialize::InputArchive& archive) { state->prev_loss(prev_loss.item()); state->old_dirs(old_dirs); state->old_stps(old_stps); - state_[c10::guts::to_string(param_groups_.at(0).params().at(0).unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string( + param_groups_.at(0).params().at(0).unsafeGetTensorImpl())] = + std::move(state); } } } // namespace optim diff --git a/torch/csrc/api/src/optim/optimizer.cpp b/torch/csrc/api/src/optim/optimizer.cpp index de13cc0a9f4a19..95165d850cf6f7 100644 --- a/torch/csrc/api/src/optim/optimizer.cpp +++ b/torch/csrc/api/src/optim/optimizer.cpp @@ -25,7 +25,8 @@ const OptimizerOptions& OptimizerParamGroup::options() const { return *options_.get(); } -void OptimizerParamGroup::set_options(std::unique_ptr options) { +void OptimizerParamGroup::set_options( + std::unique_ptr options) { options_ = std::move(options); } @@ -38,49 +39,61 @@ const std::vector& OptimizerParamGroup::params() const { } std::unique_ptr OptimizerParamState::clone() const { - TORCH_CHECK(false, + TORCH_CHECK( + false, "clone() has not been implemented for torch::optim::OptimizerParamState. ", "Subclass torch::optim::OptimizerCloneableParamState ", "instead of torch::optim::OptimizerParamState to inherit the ability to clone."); } void OptimizerParamState::serialize(torch::serialize::InputArchive& archive) { - TORCH_CHECK(false, - "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ", - "You must override it in your subclass of torch::optim::OptimizerCloneableParamState."); + TORCH_CHECK( + false, + "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ", + "You must override it in your subclass of torch::optim::OptimizerCloneableParamState."); } -void OptimizerParamState::serialize(torch::serialize::OutputArchive& archive) const { - TORCH_CHECK(false, - "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ", - "You must override it in your subclass of torch::optim::OptimizerCloneableParamState."); +void OptimizerParamState::serialize( + torch::serialize::OutputArchive& archive) const { + TORCH_CHECK( + false, + "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerParamState. ", + "You must override it in your subclass of torch::optim::OptimizerCloneableParamState."); } double OptimizerOptions::get_lr() const { - TORCH_CHECK(false, "double get_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + TORCH_CHECK( + false, + "double get_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } void OptimizerOptions::set_lr(const double lr) { - TORCH_CHECK(false, "double set_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + TORCH_CHECK( + false, + "double set_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } std::unique_ptr OptimizerOptions::clone() const { - TORCH_CHECK(false, + TORCH_CHECK( + false, "clone() has not been implemented for torch::optim::OptimizerOptions. ", "Subclass torch::optim::OptimizerCloneableOptions ", "instead of torch::optim::OptimizerOptions to inherit the ability to clone."); } void OptimizerOptions::serialize(torch::serialize::InputArchive& archive) { - TORCH_CHECK(false, - "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ", - "You must override it in your subclass of torch::optim::OptimizerCloneableOptions."); + TORCH_CHECK( + false, + "void serialize(torch::serialize::InputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ", + "You must override it in your subclass of torch::optim::OptimizerCloneableOptions."); } -void OptimizerOptions::serialize(torch::serialize::OutputArchive& archive) const { - TORCH_CHECK(false, - "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ", - "You must override it in your subclass of torch::optim::OptimizerCloneableOptions."); +void OptimizerOptions::serialize( + torch::serialize::OutputArchive& archive) const { + TORCH_CHECK( + false, + "void serialize(torch::serialize::OutputArchive& archive) has not been implemented for torch::optim::OptimizerOptions. ", + "You must override it in your subclass of torch::optim::OptimizerCloneableOptions."); } void Optimizer::add_param_group(const OptimizerParamGroup& param_group) { @@ -95,8 +108,9 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) { param_group_.set_options(param_group.options().clone()); } for (const auto& p : param_group_.params()) { - TORCH_CHECK(state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0, - "some parameters appear in more than one parameter group"); + TORCH_CHECK( + state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0, + "some parameters appear in more than one parameter group"); } param_groups_.emplace_back(std::move(param_group_)); } @@ -119,13 +133,13 @@ void Optimizer::zero_grad() { } const std::vector& Optimizer::parameters() const noexcept { - TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6"); - return param_groups_.at(0).params(); + TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6"); + return param_groups_.at(0).params(); } std::vector& Optimizer::parameters() noexcept { - TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6"); - return param_groups_.at(0).params(); + TORCH_WARN("Optimizer::parameters() will be removed in PyTorch 1.6"); + return param_groups_.at(0).params(); } size_t Optimizer::size() const noexcept { @@ -149,15 +163,18 @@ std::vector& Optimizer::param_groups() noexcept { return param_groups_; } -const std::vector& Optimizer::param_groups() const noexcept { +const std::vector& Optimizer::param_groups() + const noexcept { return param_groups_; } -ska::flat_hash_map>& Optimizer::state() noexcept { +ska::flat_hash_map>& +Optimizer::state() noexcept { return state_; } -const ska::flat_hash_map>& Optimizer::state() const noexcept { +const ska::flat_hash_map>& +Optimizer::state() const noexcept { return state_; } diff --git a/torch/csrc/api/src/optim/rmsprop.cpp b/torch/csrc/api/src/optim/rmsprop.cpp index cc5e4ca7de0d70..f54f8242935867 100644 --- a/torch/csrc/api/src/optim/rmsprop.cpp +++ b/torch/csrc/api/src/optim/rmsprop.cpp @@ -15,12 +15,9 @@ namespace optim { RMSpropOptions::RMSpropOptions(double lr) : lr_(lr) {} bool operator==(const RMSpropOptions& lhs, const RMSpropOptions& rhs) { - return (lhs.lr() == rhs.lr()) && - (lhs.alpha() == rhs.alpha()) && - (lhs.eps() == rhs.eps()) && - (lhs.weight_decay() == rhs.weight_decay()) && - (lhs.momentum() == rhs.momentum()) && - (lhs.centered() == rhs.centered()); + return (lhs.lr() == rhs.lr()) && (lhs.alpha() == rhs.alpha()) && + (lhs.eps() == rhs.eps()) && (lhs.weight_decay() == rhs.weight_decay()) && + (lhs.momentum() == rhs.momentum()) && (lhs.centered() == rhs.centered()); } void RMSpropOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -51,12 +48,13 @@ void RMSpropOptions::set_lr(const double lr) { bool operator==(const RMSpropParamState& lhs, const RMSpropParamState& rhs) { return (lhs.step() == rhs.step()) && - torch::equal(lhs.square_avg(), rhs.square_avg()) && - torch::equal_if_defined(lhs.momentum_buffer(), rhs.momentum_buffer()) && - torch::equal_if_defined(lhs.grad_avg(), rhs.grad_avg()); + torch::equal(lhs.square_avg(), rhs.square_avg()) && + torch::equal_if_defined(lhs.momentum_buffer(), rhs.momentum_buffer()) && + torch::equal_if_defined(lhs.grad_avg(), rhs.grad_avg()); } -void RMSpropParamState::serialize(torch::serialize::OutputArchive& archive) const { +void RMSpropParamState::serialize( + torch::serialize::OutputArchive& archive) const { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(step); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(square_avg); _TORCH_OPTIM_SERIALIZE_TORCH_ARG(momentum_buffer); @@ -72,7 +70,7 @@ void RMSpropParamState::serialize(torch::serialize::InputArchive& archive) { /// Adapted from /// https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py -Tensor RMSprop::step(LossClosure closure) { +Tensor RMSprop::step(LossClosure closure) { NoGradGuard no_grad; Tensor loss = {}; if (closure != nullptr) { @@ -85,8 +83,10 @@ Tensor RMSprop::step(LossClosure closure) { continue; } auto grad = p.grad(); - TORCH_CHECK(!grad.is_sparse(), "RMSprop does not support sparse gradients"); - auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); + TORCH_CHECK( + !grad.is_sparse(), "RMSprop does not support sparse gradients"); + auto param_state = + state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); auto& options = static_cast(group.options()); // State initialization @@ -100,10 +100,12 @@ Tensor RMSprop::step(LossClosure closure) { if (options.centered()) { state->grad_avg(torch::zeros_like(p, MemoryFormat::Preserve)); } - state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = + std::move(state); } - auto& state = static_cast(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); + auto& state = static_cast( + *state_[c10::guts::to_string(p.unsafeGetTensorImpl())]); auto& square_avg = state.square_avg(); auto alpha = options.alpha(); @@ -118,8 +120,10 @@ Tensor RMSprop::step(LossClosure closure) { Tensor avg; if (options.centered()) { auto& grad_avg = state.grad_avg(); - grad_avg.mul_(alpha).add_(grad, 1-alpha); - avg = square_avg.addcmul(grad_avg, grad_avg, -1).sqrt_().add_(options.eps()); + grad_avg.mul_(alpha).add_(grad, 1 - alpha); + avg = square_avg.addcmul(grad_avg, grad_avg, -1) + .sqrt_() + .add_(options.eps()); } else { avg = square_avg.sqrt().add_(options.eps()); } @@ -146,30 +150,34 @@ void RMSprop::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized RMSprop optimizer is still using the old serialization format. " - "The step value in state will be set to 0 because the old RMSprop optimizer didn't track the step value." - "You should re-save your RMSprop optimizer to use the new serialization format."); + "Your serialized RMSprop optimizer is still using the old serialization format. " + "The step value in state will be set to 0 because the old RMSprop optimizer didn't track the step value." + "You should re-save your RMSprop optimizer to use the new serialization format."); std::vector square_average_buffers; std::vector momentum_buffers; std::vector grad_average_buffers; - torch::optim::serialize(archive, "square_average_buffers", square_average_buffers); + torch::optim::serialize( + archive, "square_average_buffers", square_average_buffers); torch::optim::serialize(archive, "momentum_buffers", momentum_buffers); - torch::optim::serialize(archive, "grad_average_buffers", grad_average_buffers); - // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group + torch::optim::serialize( + archive, "grad_average_buffers", grad_average_buffers); + // since there were no param_groups prior to version 1.5.0, assuming all + // tensors are now in one param_group std::vector params = param_groups_.at(0).params(); - for(const auto idx : c10::irange(square_average_buffers.size())) { + for (const auto idx : c10::irange(square_average_buffers.size())) { auto state = std::make_unique(); state->square_avg(square_average_buffers[idx]); - if(idx < momentum_buffers.size()) { + if (idx < momentum_buffers.size()) { state->momentum_buffer(momentum_buffers.at(idx)); } - if(idx < grad_average_buffers.size()) { + if (idx < grad_average_buffers.size()) { state->grad_avg(grad_average_buffers.at(idx)); } - state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = + std::move(state); } } } diff --git a/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp b/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp index a69a289df539a4..00b30a74fb884d 100644 --- a/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp +++ b/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp @@ -4,9 +4,8 @@ namespace torch { namespace optim { -LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer) : - step_count_(0), - optimizer_(optimizer) {} +LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer) + : step_count_(0), optimizer_(optimizer) {} void LRScheduler::step() { std::vector learning_rates = get_lrs(); @@ -15,12 +14,15 @@ void LRScheduler::step() { } void LRScheduler::set_optimizer_lrs(const std::vector& learning_rates) { - //Check the number of learning rates is equal to the number of parameters groups in the - //optimizer - TORCH_CHECK(learning_rates.size() == optimizer_.param_groups().size(), - "Number of learning rates not equal to the number of param groups\n", - "Number of learning rates given: ", learning_rates.size(), - "\nNumber of param groups: ", optimizer_.param_groups().size()); + // Check the number of learning rates is equal to the number of parameters + // groups in the optimizer + TORCH_CHECK( + learning_rates.size() == optimizer_.param_groups().size(), + "Number of learning rates not equal to the number of param groups\n", + "Number of learning rates given: ", + learning_rates.size(), + "\nNumber of param groups: ", + optimizer_.param_groups().size()); for (const auto i : c10::irange(optimizer_.param_groups().size())) { optimizer_.param_groups()[i].options().set_lr(learning_rates[i]); @@ -29,7 +31,7 @@ void LRScheduler::set_optimizer_lrs(const std::vector& learning_rates) { std::vector LRScheduler::get_current_lrs() const { std::vector learnings_rates(optimizer_.param_groups().size()); - if(learnings_rates.size() > 0) { + if (learnings_rates.size() > 0) { for (const auto i : c10::irange(optimizer_.param_groups().size())) { learnings_rates[i] = optimizer_.param_groups()[i].options().get_lr(); } diff --git a/torch/csrc/api/src/optim/schedulers/step_lr.cpp b/torch/csrc/api/src/optim/schedulers/step_lr.cpp index d053dca352d769..497ebe08fed3b1 100644 --- a/torch/csrc/api/src/optim/schedulers/step_lr.cpp +++ b/torch/csrc/api/src/optim/schedulers/step_lr.cpp @@ -3,20 +3,21 @@ namespace torch { namespace optim { -StepLR::StepLR(torch::optim::Optimizer& optimizer, - const unsigned step_size, - const double gamma) : - LRScheduler(optimizer), - step_size_(step_size), - gamma_(gamma) {} +StepLR::StepLR( + torch::optim::Optimizer& optimizer, + const unsigned step_size, + const double gamma) + : LRScheduler(optimizer), step_size_(step_size), gamma_(gamma) {} std::vector StepLR::get_lrs() { - if(step_count_ == 0 || step_count_ % step_size_ != 0) + if (step_count_ == 0 || step_count_ % step_size_ != 0) return get_current_lrs(); else { std::vector lrs = get_current_lrs(); - std::transform(lrs.begin(), lrs.end(), lrs.begin(), - [this](const double& v){ return this->gamma_ * v; }); + std::transform( + lrs.begin(), lrs.end(), lrs.begin(), [this](const double& v) { + return this->gamma_ * v; + }); return lrs; } } diff --git a/torch/csrc/api/src/optim/sgd.cpp b/torch/csrc/api/src/optim/sgd.cpp index 96e11d6e8f2f1b..02f1ed5ac1f95c 100644 --- a/torch/csrc/api/src/optim/sgd.cpp +++ b/torch/csrc/api/src/optim/sgd.cpp @@ -18,11 +18,10 @@ namespace optim { SGDOptions::SGDOptions(double lr) : lr_(lr) {} bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) { - return (lhs.lr() == rhs.lr()) && - (lhs.momentum() == rhs.momentum()) && - (lhs.dampening() == rhs.dampening()) && - (lhs.weight_decay() == rhs.weight_decay()) && - (lhs.nesterov() == rhs.nesterov()); + return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) && + (lhs.dampening() == rhs.dampening()) && + (lhs.weight_decay() == rhs.weight_decay()) && + (lhs.nesterov() == rhs.nesterov()); } void SGDOptions::serialize(torch::serialize::OutputArchive& archive) const { @@ -61,7 +60,7 @@ void SGDParamState::serialize(torch::serialize::InputArchive& archive) { _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, momentum_buffer); } -Tensor SGD::step(LossClosure closure) { +Tensor SGD::step(LossClosure closure) { NoGradGuard no_grad; Tensor loss = {}; if (closure != nullptr) { @@ -85,14 +84,17 @@ Tensor SGD::step(LossClosure closure) { } if (momentum != 0) { Tensor buf; - auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); - if(param_state == state_.end()) { + auto param_state = + state_.find(c10::guts::to_string(p.unsafeGetTensorImpl())); + if (param_state == state_.end()) { buf = torch::clone(d_p).detach(); auto state = std::make_unique(); state->momentum_buffer(buf); - state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(p.unsafeGetTensorImpl())] = + std::move(state); } else { - buf = static_cast(*param_state->second).momentum_buffer(); + buf = static_cast(*param_state->second) + .momentum_buffer(); buf.mul_(momentum).add_(d_p, 1 - dampening); } if (nesterov) { @@ -115,19 +117,21 @@ void SGD::load(serialize::InputArchive& archive) { IValue pytorch_version; if (archive.try_read("pytorch_version", pytorch_version)) { serialize(*this, archive); - } - else { // deserializing archives saved in old format (prior to version 1.5.0) + } else { // deserializing archives saved in old format (prior to + // version 1.5.0) TORCH_WARN( - "Your serialized SGD optimizer is still using the old serialization format. " - "You should re-save your SGD optimizer to use the new serialization format."); + "Your serialized SGD optimizer is still using the old serialization format. " + "You should re-save your SGD optimizer to use the new serialization format."); std::vector momentum_buffers; torch::optim::serialize(archive, "momentum_buffers", momentum_buffers); - // since there were no param_groups prior to version 1.5.0, assuming all tensors are now in one param_group + // since there were no param_groups prior to version 1.5.0, assuming all + // tensors are now in one param_group std::vector params = param_groups_.at(0).params(); - for(const auto idx : c10::irange(momentum_buffers.size())) { + for (const auto idx : c10::irange(momentum_buffers.size())) { auto state = std::make_unique(); state->momentum_buffer(momentum_buffers[idx]); - state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = std::move(state); + state_[c10::guts::to_string(params[idx].unsafeGetTensorImpl())] = + std::move(state); } } } diff --git a/torch/csrc/api/src/python/init.cpp b/torch/csrc/api/src/python/init.cpp index ad094c66b64d25..84ec2570272df1 100644 --- a/torch/csrc/api/src/python/init.cpp +++ b/torch/csrc/api/src/python/init.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index 2e9367e7f4ef72..e8497a7f22b567 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -14,5 +14,4 @@ torch::IValue pickle_load(const std::vector& data) { return jit::pickle_load(data); } - } // namespace torch diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 8005fae953352b..938e9ea07d1f0a 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -3,10 +3,10 @@ #include #include -#include -#include -#include #include +#include +#include +#include #include #include @@ -16,15 +16,14 @@ namespace torch { namespace serialize { -InputArchive::InputArchive() : module_("Module", std::make_shared()) {} +InputArchive::InputArchive() + : module_("Module", std::make_shared()) {} void InputArchive::read(const std::string& key, c10::IValue& ivalue) { ivalue = module_.attr(key); } -bool InputArchive::try_read( - const std::string& key, - c10::IValue& ivalue) { +bool InputArchive::try_read(const std::string& key, c10::IValue& ivalue) { if (!module_.hasattr(key)) { return false; } @@ -92,13 +91,15 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { "'"); } -void InputArchive::load_from(const std::string& filename, +void InputArchive::load_from( + const std::string& filename, c10::optional device /*= c10::nullopt*/) { // NOLINTNEXTLINE(performance-move-const-arg) module_ = torch::jit::load(filename, std::move(device)); } -void InputArchive::load_from(std::istream& stream, +void InputArchive::load_from( + std::istream& stream, c10::optional device /*= c10::nullopt*/) { // NOLINTNEXTLINE(performance-move-const-arg) module_ = torch::jit::load(stream, std::move(device)); @@ -110,14 +111,14 @@ void InputArchive::load_from( c10::optional device /*= c10::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { - public: - OurAdapter(const char* data, size_t size) - : data_(data), size_(size) { + public: + OurAdapter(const char* data, size_t size) : data_(data), size_(size) {} + size_t size() const override { + return size_; } - size_t size() const override { return size_; } size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") - const override { - (void) what; + const override { + (void)what; if (pos >= size_) { return 0; } @@ -125,7 +126,8 @@ void InputArchive::load_from( memcpy(buf, data_ + pos, nread); return nread; } - private: + + private: const char* data_; size_t size_; }; @@ -140,19 +142,21 @@ void InputArchive::load_from( c10::optional device /*= c10::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { - public: - OurAdapter(const std::function& read_func, - const std::function& size_func) - : read_func_(read_func), - size_func_(size_func) { + public: + OurAdapter( + const std::function& read_func, + const std::function& size_func) + : read_func_(read_func), size_func_(size_func) {} + size_t size() const override { + return size_func_(); } - size_t size() const override { return size_func_(); } size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") - const override { + const override { (void)what; return read_func_(pos, buf, n); } - private: + + private: const std::function& read_func_; const std::function& size_func_; }; @@ -165,7 +169,8 @@ std::vector InputArchive::keys() { std::vector all_keys; all_keys.reserve(module_.named_attributes(/*recurse=*/false).size()); - for (const torch::jit::NameValue& s : module_.named_attributes(/*recurse=*/false)) { + for (const torch::jit::NameValue& s : + module_.named_attributes(/*recurse=*/false)) { all_keys.push_back(s.name); } diff --git a/torch/csrc/api/src/serialize/output-archive.cpp b/torch/csrc/api/src/serialize/output-archive.cpp index 75cb60e7e718c7..f467a6105518d0 100644 --- a/torch/csrc/api/src/serialize/output-archive.cpp +++ b/torch/csrc/api/src/serialize/output-archive.cpp @@ -3,8 +3,8 @@ #include #include -#include #include +#include #include diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7bb3683f296182..6eb9afc40f7be8 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1,36 +1,33 @@ #include -#include -#include #include - +#include +#include #include #include #include -#include -#include #include #include -#include -#include -#include #include #include +#include #include #include #include +#include #include +#include +#include +#include #include +#include #include #include -#include -#include - -#include #include -#include +#include #include +#include // Helper functions for autogenerated code // These used to be inlined into the codegened Functions.cpp @@ -39,13 +36,14 @@ namespace autograd { namespace generated { namespace details { -using at::Tensor; -using at::Scalar; +using at::areAnyTensorSubclassLike; using at::IntArrayRef; +using at::Scalar; +using at::Tensor; using at::TensorList; -using at::areAnyTensorSubclassLike; -const char* kCudnnDoubleBackwardMsg = "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)"; +const char* kCudnnDoubleBackwardMsg = + "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)"; Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) { if (reduction == at::Reduction::Mean) { @@ -69,22 +67,29 @@ Tensor toNonOptFwGrad(const c10::optional& t) { } Tensor toNonOptPrimal(const c10::optional& t) { - return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor(); + return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) + : Tensor(); } -void copy_range(variable_list& out, IndexRange range, const Tensor & t) { +void copy_range(variable_list& out, IndexRange range, const Tensor& t) { AT_ASSERT(range.second <= out.size()); - AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output"); + AT_ASSERTM( + range.second - range.first == 1, "inconsistent range for Tensor output"); out[range.first] = t; } void copy_range(variable_list& out, IndexRange range, at::ArrayRef t) { AT_ASSERT(range.second <= out.size()); - AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output"); + AT_ASSERTM( + range.second - range.first == t.size(), + "inconsistent range for TensorList output"); std::copy(t.begin(), t.end(), out.begin() + range.first); } -Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) { +Tensor copysign_tensor_self_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& result) { auto ratio = result / self; ratio.masked_fill_(self == 0, 0); return grad * ratio; @@ -92,7 +97,8 @@ Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, c template T not_implemented_base(const char* name, const char* reason) { - std::string msg = c10::str("the derivative for '", name, "' is not implemented."); + std::string msg = + c10::str("the derivative for '", name, "' is not implemented."); if (strlen(reason) > 0) { msg = c10::str(msg, " ", reason); }; @@ -107,11 +113,11 @@ std::vector not_implemented_list(const char* name, const char* reason) { return not_implemented_base>(name, reason); } -Tensor maybe_multiply(const Tensor & t, const Scalar & s) { +Tensor maybe_multiply(const Tensor& t, const Scalar& s) { bool is_one = false; if (s.isFloatingPoint()) { is_one = s.toDouble() == 1; - } else if(s.isIntegral(true)) { + } else if (s.isIntegral(true)) { is_one = s.toLong() == 1; } @@ -150,7 +156,10 @@ Tensor handle_r_to_c(Tensor self, Tensor gradient_result) { return gradient_result; } -Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) { +Tensor restore_reduced_dims( + const Tensor& output, + IntArrayRef dims, + bool keepdim) { if (keepdim) { return output; } @@ -164,17 +173,25 @@ Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim } int64_t j = 0; for (int64_t i : output.sizes()) { - while (target_shape[j] > 0) j++; + while (target_shape[j] > 0) + j++; target_shape[j++] = i; } return output.reshape(target_shape); } -Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) { +Tensor scale_grad_by_count( + const Tensor& grad, + const Tensor& mask, + IntArrayRef dims) { return (grad / mask.sum(dims, true)) * mask; } -std::tuple _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) { +std::tuple _euclidean_dist_backward( + const Tensor& grad, + const Tensor& x1, + const Tensor& x2, + const Tensor& res) { if (!grad.defined()) { return std::tuple(Tensor(), Tensor()); } @@ -182,20 +199,31 @@ std::tuple _euclidean_dist_backward(const Tensor & grad, const T Tensor ratio = grad / res; ratio.masked_fill_(res == 0, 0); return std::tuple{ - x1 * ratio.sum(-1, true) - ratio.matmul(x2), - x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)}; + x1 * ratio.sum(-1, true) - ratio.matmul(x2), + x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)}; } -Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional & p_, const Tensor& norm) { +Tensor norm_backward( + const Tensor& grad, + const Tensor& self, + const optional& p_, + const Tensor& norm) { return norm_backward(grad, self, p_, norm, {}, true); } Tensor norm_backward( - Tensor grad, const Tensor& self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { - // NB: We mask fill the NaNs in the output to be zero but still do float division - // by zero, which ASAN complains about. One way to appease ASAN is to fill the problematic - // values with something arbitrary before the division, but we decide not to due to - // the perf hit. Instead we just silence ASAN where necessary + Tensor grad, + const Tensor& self, + const optional& p_, + Tensor norm, + IntArrayRef dim, + bool keepdim) { + // NB: We mask fill the NaNs in the output to be zero but still do float + // division + // by zero, which ASAN complains about. One way to appease ASAN is to fill + // the problematic values with something arbitrary before the division, + // but we decide not to due to the perf hit. Instead we just silence ASAN + // where necessary size_t ndim = self.sizes().size(); double p = p_.value_or(2.0).toDouble(); Tensor self_scaled; @@ -214,12 +242,14 @@ Tensor norm_backward( return self * (grad / norm).masked_fill_(norm == 0, 0); } else if (std::isinf(p)) { // Derivative of amax(abs(self), dim, keepdim) but respecting nans - // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's NaN + // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's + // NaN auto self_abs = self.abs(); auto mask = self_abs.eq(norm).logical_or(self_abs.isnan()); return self.sgn() * ((grad / mask.sum(dim, true)) * mask); } else if (p < 1.0) { - self_scaled = self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0); + self_scaled = + self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0); return self_scaled * grad * norm.pow(1 - p); } else if (p < 2.0) { self_scaled = self.sgn() * self.abs().pow_(p - 1); @@ -236,23 +266,26 @@ Tensor norm_backward( // See norm_backward above for a note on ignoring the sanitizer Tensor norm_jvp( - const Tensor& self_p, const Tensor& self_t, - const optional & p_, - Tensor norm, - IntArrayRef dim, - bool keepdim -) { - // NB: currently norm_jvp is also reused for dist's jvp (which haas two differentiable inputs) - // but self_t still cannot be a ZT because that would require both self_t and other_t to be ZT + const Tensor& self_p, + const Tensor& self_t, + const optional& p_, + Tensor norm, + IntArrayRef dim, + bool keepdim) { + // NB: currently norm_jvp is also reused for dist's jvp (which haas two + // differentiable inputs) + // but self_t still cannot be a ZT because that would require both self_t + // and other_t to be ZT TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor()); - size_t ndim = self_p.dim(); // composite compliance? + size_t ndim = self_p.dim(); // composite compliance? double p = p_.value_or(2.0).toDouble(); if (p == 0.0) { return at::zeros_like(norm); } else if (p == 1.0) { auto result = self_p.sgn(); - result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj()) : result.mul_(self_t.conj()); + result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj()) + : result.mul_(self_t.conj()); result = at::real(result); return result.sum(dim, keepdim); } else if (p == 2.0) { @@ -266,57 +299,93 @@ Tensor norm_jvp( } const auto self_isnan = self_p.isnan(); const auto norm_isnan = norm.isnan(); - const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm}) ? - self_isnan.logical_and(norm_isnan) : - self_isnan.logical_and_(norm_isnan); - const auto is_eq_max = (self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm); + const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm}) + ? self_isnan.logical_and(norm_isnan) + : self_isnan.logical_and_(norm_isnan); + const auto is_eq_max = + (self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm); auto nb_max = is_eq_max.count_nonzero(dim); if (self_p.dim() != 0) { nb_max = unsqueeze_multiple(nb_max, dim, ndim); } - return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max).sum(dim, keepdim); + return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max) + .sum(dim, keepdim); } else if (p < 1.0) { - auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) * at::real(self_p.sgn() * self_t.conj())).sum(dim, keepdim); + auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) * + at::real(self_p.sgn() * self_t.conj())) + .sum(dim, keepdim); return sumpow_t * norm.pow(1 - p); } else if (p < 2.0) { - auto sumpow_t = (self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj())).sum(dim, keepdim); + auto sumpow_t = + (self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj())) + .sum(dim, keepdim); auto out = sumpow_t / norm.pow(p - 1); return out.masked_fill_(norm == 0, 0); } else { - auto sumpow_t = (self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj())).sum(dim, keepdim); + auto sumpow_t = + (self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj())) + .sum(dim, keepdim); auto out = sumpow_t / norm.pow(p - 1); return out.masked_fill_(norm == 0, 0); } } -Tensor norm_jvp(const Tensor& self_p, const Tensor& self_t, const optional & p_, Tensor norm) { +Tensor norm_jvp( + const Tensor& self_p, + const Tensor& self_t, + const optional& p_, + Tensor norm) { return norm_jvp(self_p, self_t, p_, norm, {}, true); } -Tensor linalg_vector_norm_jvp(const Tensor& self_p, const Tensor& self_t, const Scalar& scalar_ord, Tensor norm, const at::OptionalIntArrayRef& opt_dim, bool keepdim) { - // No need to handle the dtype arg as it's handled via broadcasting in the function +Tensor linalg_vector_norm_jvp( + const Tensor& self_p, + const Tensor& self_t, + const Scalar& scalar_ord, + Tensor norm, + const at::OptionalIntArrayRef& opt_dim, + bool keepdim) { + // No need to handle the dtype arg as it's handled via broadcasting in the + // function auto dim = opt_dim.value_or(IntArrayRef({})); return norm_jvp(self_p, self_t, scalar_ord, norm, dim, keepdim); } -Tensor linalg_vector_norm_backward(Tensor grad, const Tensor& self, const Scalar& scalar_ord, Tensor norm, const at::OptionalIntArrayRef& opt_dim, bool keepdim) { - // No need to handle the dtype arg as it's handled via broadcasting in the function +Tensor linalg_vector_norm_backward( + Tensor grad, + const Tensor& self, + const Scalar& scalar_ord, + Tensor norm, + const at::OptionalIntArrayRef& opt_dim, + bool keepdim) { + // No need to handle the dtype arg as it's handled via broadcasting in the + // function auto dim = opt_dim.value_or(IntArrayRef({})); return norm_backward(grad, self, scalar_ord, norm, dim, keepdim); } -Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) { +Tensor pow_backward(Tensor grad, const Tensor& self, const Scalar& exponent) { if (exponent.equal(0.0)) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { - auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); }; - Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble()); + auto grad_lambda = [&](auto exp) { + return grad * (exp * self.pow(exp - 1)).conj(); + }; + Tensor out = (exponent.isComplex()) + ? grad_lambda(exponent.toComplexDouble()) + : grad_lambda(exponent.toDouble()); return handle_r_to_c(self, out); } } -Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) { - auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj()); +Tensor pow_backward_self( + Tensor grad, + const Tensor& self, + const Tensor& exponent) { + auto out = at::where( + exponent == 0.0, + at::zeros({}, grad.options()), + grad * (exponent * self.pow(exponent - 1)).conj()); return handle_r_to_c(self, out); } @@ -328,33 +397,44 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone // We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as // d(a^b)/db = 0 for a > 0 and b -> +0. // Currently, tensorflow agrees with us. -Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { +Tensor pow_backward_exponent( + Tensor grad, + const Tensor& self, + const Tensor& exponent, + Tensor result) { Tensor cond; if (exponent.is_complex()) { - auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); + auto is_real_exp = + at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); cond = at::logical_and(self == 0, is_real_exp); } else { cond = at::logical_and(self == 0, exponent >= 0); } - auto out = grad * at::where(cond, - at::zeros({}, grad.options()), - (result * self.log()).conj()); + auto out = + grad * + at::where( + cond, at::zeros({}, grad.options()), (result * self.log()).conj()); return handle_r_to_c(exponent, out); } -Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { +Tensor pow_backward_exponent( + Tensor grad, + const Scalar& base, + const Tensor& exponent, + Tensor result) { auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); }; if (base.equal(0.0)) { auto cond = [](auto exp) { if (exp.is_complex()) { return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); } else { - return exp >=0; + return exp >= 0; } }; - auto out = grad * at::where(cond(exponent), - at::zeros({}, grad.options()), - grad_lambda(result, base)); + auto out = grad * + at::where(cond(exponent), + at::zeros({}, grad.options()), + grad_lambda(result, base)); return handle_r_to_c(exponent, out); } else { auto out = grad * grad_lambda(result, base); @@ -364,14 +444,17 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp Tensor angle_backward(Tensor grad, const Tensor& self) { if (self.is_complex()) { - return at::where(self == 0.0, at::zeros({}, self.options()), - grad * self / self.abs().pow(2) * Scalar(c10::complex{0.0, 1.0})); + return at::where( + self == 0.0, + at::zeros({}, self.options()), + grad * self / self.abs().pow(2) * + Scalar(c10::complex{0.0, 1.0})); } else { return at::zeros_like(self, at::MemoryFormat::Preserve); } } -Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) { +Tensor mvlgamma_backward(Tensor grad, const Tensor& self, int64_t p) { Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options()); args = args.add(self.unsqueeze(-1)); return grad * args.digamma_().sum(-1); @@ -382,7 +465,10 @@ Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { auto abs = at::abs(self); // C -> C // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 - return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result))); + return at::where( + abs == 0.0, + at::zeros({}, grad.options()), + (grad / abs - (at::real(grad / self) * result))); } else { return at::zeros_like(self, at::MemoryFormat::Preserve); } @@ -393,7 +479,11 @@ Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { return handle_r_to_c(self_st, out); } -Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode) { +Tensor div_tensor_self_backward( + Tensor grad, + Tensor other, + ScalarType self_st, + const c10::optional& rounding_mode) { if (rounding_mode.has_value()) { return at::zeros_like(grad, grad.options().dtype(self_st)); } @@ -406,7 +496,11 @@ Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { return div_tensor_self_backward(grad, other, self_st, c10::nullopt); } -Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode) { +Tensor div_tensor_other_backward( + Tensor grad, + Tensor self, + Tensor other, + const c10::optional& rounding_mode) { if (rounding_mode.has_value()) { return at::zeros_like(grad, grad.options().dtype(other.scalar_type())); } @@ -419,40 +513,46 @@ Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { return div_tensor_other_backward(grad, self, other, c10::nullopt); } -Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { +Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { // invert the permutation auto ndims = fwd_dims.size(); std::vector dims(ndims); - for(const auto i : c10::irange(ndims)) { + for (const auto i : c10::irange(ndims)) { dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; } return grad.permute(dims); } Tensor rad2deg_backward(const Tensor& grad) { - constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564; + constexpr double M_180_PI = + 57.295779513082320876798154814105170332405472466564; return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_180_PI))); } Tensor deg2rad_backward(const Tensor& grad) { - constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417; + constexpr double M_PI_180 = + 0.017453292519943295769236907684886127134428718885417; return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180))); } -Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); - Tensor res = t; - for(const auto i : c10::irange(n_dims)){ - if (dims_to_unsqueeze[i]) { - res = res.unsqueeze(i); - } +Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) { + auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); + Tensor res = t; + for (const auto i : c10::irange(n_dims)) { + if (dims_to_unsqueeze[i]) { + res = res.unsqueeze(i); } - return res; + } + return res; } -Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) { +Tensor sum_backward( + const Tensor& grad, + IntArrayRef sizes, + IntArrayRef dims, + bool keepdim) { if (!keepdim && sizes.size() > 0) { - if (dims.size()==1) { + if (dims.size() == 1) { return grad.unsqueeze(dims[0]).expand(sizes); } else { Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); @@ -463,10 +563,14 @@ Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bo } } -Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) { +Tensor nansum_backward( + const Tensor& grad, + const Tensor& self, + IntArrayRef dims, + bool keepdim) { auto sizes = self.sizes(); if (!keepdim && sizes.size() > 0) { - if (dims.size()==1) { + if (dims.size() == 1) { return grad.unsqueeze(dims[0]).expand(sizes) * self.isnan().logical_not(); } else { Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); @@ -487,11 +591,15 @@ std::vector reverse_list(const IntArrayRef list) { } Tensor reverse_dim(const Tensor& t, int64_t dim) { - Tensor index = at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong)); + Tensor index = + at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong)); return t.index_select(dim, index); } -Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) { +Tensor prod_safe_zeros_backward( + const Tensor& grad, + const Tensor& inp, + int64_t dim) { if (inp.size(dim) == 1) { return grad; } @@ -499,12 +607,15 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d auto ones_size = inp.sizes().vec(); ones_size[dim] = 1; Tensor ones = at::ones(ones_size, grad.options()); - Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim); + Tensor exclusive_normal_nocp = + at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim); Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim); - Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim); + Tensor narrow_reverse = + reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim); Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim); - Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim); + Tensor exclusive_reverse = + reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim); return grad * (exclusive_normal * exclusive_reverse).conj(); } @@ -516,7 +627,10 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d // cumprod(exclusive, reverse): [b * c, c, 1] // product: [b * c, a * c, a * b] // and this is safe under input with 0s. -Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) { +Tensor prod_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& result) { if (input.dim() == 0) { return grad; } @@ -526,11 +640,17 @@ Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& resu } else if (zero_idx.size(0) > 1) { return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { - return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input); + return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0) + .view_as(input); } } -Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) { +Tensor prod_backward( + Tensor grad, + const Tensor& input, + Tensor result, + int64_t dim, + bool keepdim) { if (input.dim() == 0) { return grad; } @@ -552,18 +672,21 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di template static Tensor generic_solve_jvp( - solve_f solve, - const Tensor& X, const Tensor& A, - const Tensor& dA, const Tensor& dB) { + solve_f solve, + const Tensor& X, + const Tensor& A, + const Tensor& dA, + const Tensor& dB) { auto is_vector_case = at::native::linalg_solve_is_vector_rhs(dA, dB); - auto dA_contrib = is_vector_case ? dA.matmul(X.unsqueeze(-1)).squeeze(-1) : dA.matmul(X); + auto dA_contrib = + is_vector_case ? dA.matmul(X.unsqueeze(-1)).squeeze(-1) : dA.matmul(X); // In general, - // dX = solve(A, dB - dA_contrib), but this behavior is different for lu_solve. - // For refer to lu_solve_jvp for more details on this. + // dX = solve(A, dB - dA_contrib), but this behavior is different for + // lu_solve. For refer to lu_solve_jvp for more details on this. return solve(A, dB, dA_contrib); } -Tensor cumsum_backward(const Tensor & grad, int64_t dim) { +Tensor cumsum_backward(const Tensor& grad, int64_t dim) { // Trivial case if (grad.numel() <= 1 || grad.size(dim) == 1) { return grad; @@ -571,7 +694,12 @@ Tensor cumsum_backward(const Tensor & grad, int64_t dim) { return grad.flip(dim).cumsum(dim).flip(dim); } -Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { +Tensor logsumexp_backward( + Tensor grad, + const Tensor& self, + Tensor result, + IntArrayRef dim, + bool keepdim) { if (!keepdim && self.dim() != 0) { grad = unsqueeze_multiple(grad, dim, self.sizes().size()); result = unsqueeze_multiple(result, dim, self.sizes().size()); @@ -579,14 +707,19 @@ Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntAr return grad * (self - result).exp(); } -Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) { +Tensor logcumsumexp_backward( + Tensor grad, + const Tensor& self, + Tensor result, + int64_t dim) { if (grad.dim() == 0 || grad.numel() == 0) { return grad; } // Reference: https://github.com/tensorflow/tensorflow/blob/ // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 - return AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, + return AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, at::typeMetaToScalarType(grad.dtype()), "logcumsumexp_backward", [grad, self, result, dim]() { @@ -625,11 +758,11 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) { return at::stack(grads_tensors, dim); } -Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { +Tensor unsqueeze_to(const Tensor& self, IntArrayRef sizes) { auto result = self; int64_t nDims = sizes.size(); - for(const auto dim : c10::irange(nDims)) { + for (const auto dim : c10::irange(nDims)) { if (sizes[dim] == 1) { result = result.unsqueeze(dim); } @@ -637,17 +770,21 @@ Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { return result; } -Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) { +Tensor unsqueeze_to(const Tensor& self, int64_t dim, IntArrayRef sizes) { dim = at::maybe_wrap_dim(dim, sizes.size()); - // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided - // unsqueezing in the backward. + // in NumPy it's not an error to unsqueeze a scalar, but we still need to + // avoided unsqueezing in the backward. if (sizes.size() > 0 && sizes[dim] == 1) { return self.unsqueeze(dim); } return self; } -std::vector cat_tensors_backward(const Tensor & grad, const std::vector> &sizes, const std::vector &dtypes, int64_t dim) { +std::vector cat_tensors_backward( + const Tensor& grad, + const std::vector>& sizes, + const std::vector& dtypes, + int64_t dim) { std::vector grad_inputs(sizes.size()); if (!grad.defined()) { return grad_inputs; @@ -681,7 +818,10 @@ std::vector cat_tensors_backward(const Tensor & grad, const std::vector< return grad_inputs; } -std::vector block_diag_backward(const Tensor & grad, const std::vector> &sizes, const std::vector &dtypes) { +std::vector block_diag_backward( + const Tensor& grad, + const std::vector>& sizes, + const std::vector& dtypes) { std::vector grad_inputs(sizes.size()); if (!grad.defined()) { return grad_inputs; @@ -697,7 +837,9 @@ std::vector block_diag_backward(const Tensor & grad, const std::vector C - Tensor grad_val = (!at::isComplexType(dtypes[i]) && grad_is_complex) ? real_view_of_grad : grad; + Tensor grad_val = (!at::isComplexType(dtypes[i]) && grad_is_complex) + ? real_view_of_grad + : grad; auto& shape = sizes[i]; // If input was empty tensor, gradInput should be empty tensor. @@ -712,11 +854,12 @@ std::vector block_diag_backward(const Tensor & grad, const std::vector block_diag_backward(const Tensor & grad, const std::vector & min, const optional & max) { - // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases. +Tensor clamp_backward( + const Tensor& grad, + const Tensor& self, + const optional& min, + const optional& max) { + // clamp: gradients not defined on min and max, so we return the subgradient 1 + // for these cases. if (max && min) { auto zero = at::scalar_tensor(0., grad.options()); return where((self >= *min).logical_and_(self <= *max), grad, zero); @@ -745,15 +893,20 @@ Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional= min; const auto self_le_max = self <= max; - const auto& pred = areAnyTensorSubclassLike({self, min, max}) ? - self_ge_min.logical_and(self_le_max) : - self_ge_min.logical_and_(self_le_max); + const auto& pred = areAnyTensorSubclassLike({self, min, max}) + ? self_ge_min.logical_and(self_le_max) + : self_ge_min.logical_and_(self_le_max); return where(pred, grad, zero); } else if (min.defined()) { auto zero = at::scalar_tensor(0., grad.options()); @@ -767,7 +920,10 @@ Tensor clamp_backward(const Tensor & grad, const Tensor &self, const Tensor& min } std::tuple clamp_backward_min_max( - const Tensor& grad, const Tensor& self, const Tensor& min, const Tensor& max, + const Tensor& grad, + const Tensor& self, + const Tensor& min, + const Tensor& max, const std::array& grad_input_mask) { // If min > max, min has no gradient std::tuple ret; @@ -780,17 +936,17 @@ std::tuple clamp_backward_min_max( if (grad_input_mask[0]) { const auto self_lt_min = self < min; const auto min_lt_max = min < max; - const auto& pred = areAnyTensorSubclassLike({self, min, max}) ? - self_lt_min.logical_and(min_lt_max) : - self_lt_min.logical_and_(min_lt_max); + const auto& pred = areAnyTensorSubclassLike({self, min, max}) + ? self_lt_min.logical_and(min_lt_max) + : self_lt_min.logical_and_(min_lt_max); std::get<0>(ret) = where(pred, grad, zero); } if (grad_input_mask[1]) { const auto self_gt_max = self > max; const auto max_lt_min = max < min; - const auto& pred = areAnyTensorSubclassLike({self, min, max}) ? - self_gt_max.logical_or(max_lt_min) : - self_gt_max.logical_or_(max_lt_min); + const auto& pred = areAnyTensorSubclassLike({self, min, max}) + ? self_gt_max.logical_or(max_lt_min) + : self_gt_max.logical_or_(max_lt_min); std::get<1>(ret) = where(pred, grad, zero); } } else if (min.defined() && grad_input_mask[0]) { @@ -802,12 +958,17 @@ std::tuple clamp_backward_min_max( } at::Tensor clamp_jvp( - const Tensor& self_p, const Tensor& self_t, - const Tensor& min_p, const Tensor& min_t, - const Tensor& max_p, const Tensor& max_t -) { + const Tensor& self_p, + const Tensor& self_t, + const Tensor& min_p, + const Tensor& min_t, + const Tensor& max_p, + const Tensor& max_t) { if (min_p.defined() && max_p.defined()) { - return where(min_p > max_p, max_t, where(self_p < min_p, min_t, where(self_p > max_p, max_t, self_t))); + return where( + min_p > max_p, + max_t, + where(self_p < min_p, min_t, where(self_p > max_p, max_t, self_t))); } else if (min_p.defined()) { return where(self_p > min_p, self_t, min_t); } else if (max_p.defined()) { @@ -818,32 +979,91 @@ at::Tensor clamp_jvp( } Tensor convolution_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups) { - auto bias_t_opt = bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups) { + auto bias_t_opt = + bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; return ( - at::convolution(input_t, weight_p, c10::nullopt, stride, padding, dilation, transposed, output_padding, groups) - + at::convolution(input_p, weight_t, bias_t_opt, stride, padding, dilation, transposed, output_padding, groups)); + at::convolution( + input_t, + weight_p, + c10::nullopt, + stride, + padding, + dilation, + transposed, + output_padding, + groups) + + at::convolution( + input_p, + weight_t, + bias_t_opt, + stride, + padding, + dilation, + transposed, + output_padding, + groups)); } Tensor _convolution_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups, - bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { - auto bias_t_opt = bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + bool benchmark, + bool deterministic, + bool cudnn_enabled, + bool allow_tf32) { + auto bias_t_opt = + bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; return ( at::_convolution( - input_t, weight_p, c10::nullopt, stride, padding, dilation, transposed, output_padding, - groups, benchmark, deterministic, cudnn_enabled, allow_tf32) - + at::_convolution( - input_p, weight_t, bias_t_opt, stride, padding, dilation, transposed, output_padding, - groups, benchmark, deterministic, cudnn_enabled, allow_tf32)); + input_t, + weight_p, + c10::nullopt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32) + + at::_convolution( + input_p, + weight_t, + bias_t_opt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32)); } Tensor convolution_backward_jvp_grad_bias( @@ -873,9 +1093,9 @@ Tensor convolution_backward_jvp_grad_bias( // calls that appear in derivative formulas. If the tensor has requires_grad // set, this function returns its strides or throws an error if the tensor // is sparse. If requires_grad is not set, an empty array is returned since -// there will be no backward pass. There has one special case, if input is MKLDNN -// tensor and has requires_grad set, just return an empty array, the reason is -// that MKLDNN tensor is a opaque tensor which has not stride info. +// there will be no backward pass. There has one special case, if input is +// MKLDNN tensor and has requires_grad set, just return an empty array, the +// reason is that MKLDNN tensor is a opaque tensor which has not stride info. // // This function only supports the case where `input` is the tensor whose // single derivative is being calculated. @@ -885,27 +1105,40 @@ Tensor convolution_backward_jvp_grad_bias( // Args: // input Tensor to call .strides() on // input_name Name of `input` tensor, from derivative formula -at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) { +at::IntArrayRef strides_or_error( + const Tensor& input, + c10::string_view const& input_name) { // TODO: Ideally, this function would never be called if requires_grad is // not set. Once codegen is updated to avoid the call, we can remove this // check. if (input.requires_grad()) { TORCH_CHECK( - !input.is_sparse(), - "The backward pass for this operation requires the '", input_name, - "' tensor to be strided, but a sparse tensor was given instead. ", - "Please either use a strided tensor or set requires_grad=False for '", - input_name, "'"); - if (input.is_mkldnn()) return IntArrayRef({}); - if (input.is_sparse_csr()) return IntArrayRef({}); + !input.is_sparse(), + "The backward pass for this operation requires the '", + input_name, + "' tensor to be strided, but a sparse tensor was given instead. ", + "Please either use a strided tensor or set requires_grad=False for '", + input_name, + "'"); + if (input.is_mkldnn()) + return IntArrayRef({}); + if (input.is_sparse_csr()) + return IntArrayRef({}); return input.strides(); } else { return IntArrayRef({}); } } -Tensor mm_mat1_backward(const Tensor& grad, const Tensor& mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, c10::Layout mat1_layout, const Scalar& alpha) { - if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && mat1_layout == c10::kStrided) { +Tensor mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::IntArrayRef mat1_sizes, + at::IntArrayRef mat1_strides, + c10::Layout mat1_layout, + const Scalar& alpha) { + if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1_layout == c10::kStrided) { // if input was column-major, return grad as column-order for efficiency if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) { return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha.conj()); @@ -916,8 +1149,15 @@ Tensor mm_mat1_backward(const Tensor& grad, const Tensor& mat2, at::IntArrayRef return maybe_multiply(grad.mm(mat2.t().conj()), alpha.conj()); } -Tensor mm_mat2_backward(const Tensor& grad, const Tensor& mat1, IntArrayRef mat2_sizes, IntArrayRef mat2_strides, c10::Layout mat2_layout, const Scalar& alpha) { - if (grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided && mat2_layout == c10::kStrided) { +Tensor mm_mat2_backward( + const Tensor& grad, + const Tensor& mat1, + IntArrayRef mat2_sizes, + IntArrayRef mat2_strides, + c10::Layout mat2_layout, + const Scalar& alpha) { + if (grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided && + mat2_layout == c10::kStrided) { // if input was column-major, return grad as column-order for efficiency if (mat2_strides[0] == 1 && mat2_strides[1] == mat2_sizes[0]) { return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj()); @@ -928,29 +1168,45 @@ Tensor mm_mat2_backward(const Tensor& grad, const Tensor& mat1, IntArrayRef mat2 return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj()); } -Tensor mm_mat1_sparse_backward(const Tensor& grad, const Tensor& mat1, const Tensor& mat2, const Scalar& alpha) { - if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && mat1.is_sparse()) { +Tensor mm_mat1_sparse_backward( + const Tensor& grad, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& alpha) { + if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1.is_sparse()) { auto sparse = mat1.coalesce(); Tensor grad_sparse = maybe_multiply(grad.mm(mat2.conj().t()), alpha); return grad_sparse.sparse_mask(sparse); - } else if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && mat1.is_sparse_csr()) { - return at::sparse_sampled_addmm(at::zeros_like(mat1, mat1.options()), grad, mat2.mH(), 1.0, alpha); - } else if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && mat1.layout() == c10::kStrided) { + } else if ( + grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1.is_sparse_csr()) { + return at::sparse_sampled_addmm( + at::zeros_like(mat1, mat1.options()), grad, mat2.mH(), 1.0, alpha); + } else if ( + grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1.layout() == c10::kStrided) { return maybe_multiply(grad.mm(mat2.mH()), alpha); } - TORCH_CHECK(false, "sparse_addmm_sparse_backward: unsupported combination of layouts", - ", grad: ", grad.layout(), - ", mat1: ", mat1.layout(), - ", mat2: ", mat2.layout()); -} - -// This function return a new SparseTensor with values from Tensor `input` filtered by indices of `mask` -// and values are ignored. `input` and `mask` are sparse matrices, a sparse tensor with sparse_dim=2 and dense_dim=2, -// and they must have the same shape. -// Note that the `output` must have the same `indices` as the `mask` so we are using just a clone. -// However, to get `values` we have to use specific helper function for CPU/CUDA and use the `mask` data to filter `values` -// That's why we created this `_sparse_mask_helper` function. -Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){ + TORCH_CHECK( + false, + "sparse_addmm_sparse_backward: unsupported combination of layouts", + ", grad: ", + grad.layout(), + ", mat1: ", + mat1.layout(), + ", mat2: ", + mat2.layout()); +} + +// This function return a new SparseTensor with values from Tensor `input` +// filtered by indices of `mask` and values are ignored. `input` and `mask` are +// sparse matrices, a sparse tensor with sparse_dim=2 and dense_dim=2, and they +// must have the same shape. Note that the `output` must have the same `indices` +// as the `mask` so we are using just a clone. However, to get `values` we have +// to use specific helper function for CPU/CUDA and use the `mask` data to +// filter `values` That's why we created this `_sparse_mask_helper` function. +Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask) { Tensor output = at::empty_like(mask); Tensor mask_indices = mask._indices().clone(); Tensor r_values; @@ -959,7 +1215,8 @@ Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){ } else { r_values = _sparse_mask_helper(input, mask_indices.contiguous()); } - at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values); + at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe( + mask_indices, r_values); return output; } @@ -969,8 +1226,8 @@ Tensor sparse_sparse_matmul_backward( const Tensor& b, int64_t grad_order) { /* - To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition - for dense tensors: + To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we + can start from the following definition for dense tensors: c = a @ b then @@ -995,7 +1252,12 @@ Tensor sparse_sparse_matmul_backward( return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce()); } -Tensor renorm_backward(const Tensor & grad, const Tensor & self, const Scalar& p_s, int64_t dim, const Scalar& maxnorm) { +Tensor renorm_backward( + const Tensor& grad, + const Tensor& self, + const Scalar& p_s, + int64_t dim, + const Scalar& maxnorm) { auto self_sizes = self.sizes(); dim = c10::maybe_wrap_dim(dim, self_sizes.size()); at::DimVector reduce_dims(self_sizes.size()); @@ -1010,8 +1272,7 @@ Tensor renorm_backward(const Tensor & grad, const Tensor & self, const Scalar& p norm = at::linalg_vector_norm( self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type); } else { - norm = at::linalg_vector_norm( - self, p, reduce_dims, /*keepdim=*/true); + norm = at::linalg_vector_norm(self, p, reduce_dims, /*keepdim=*/true); } const auto real_acc_type = c10::toRealValueType(acc_type); @@ -1020,16 +1281,20 @@ Tensor renorm_backward(const Tensor & grad, const Tensor & self, const Scalar& p if (real_acc_type != acc_type) { grad_output = at::real(grad_output); } - grad_output = grad_output.sum( - reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type); - auto nb = norm_backward(grad_output, self, p, norm, reduce_dims, /*keepdim=*/true); + grad_output = + grad_output.sum(reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type); + auto nb = + norm_backward(grad_output, self, p, norm, reduce_dims, /*keepdim=*/true); auto invnorm = (norm + 1e-7).reciprocal(); auto grad_norm = maxnorm * invnorm * (grad - invnorm * nb); return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad); } -Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) { +Tensor repeat_backward( + Tensor grad, + IntArrayRef repeats, + IntArrayRef input_shape) { auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0); if (find_iter != repeats.cend()) { return at::zeros(input_shape, grad.options()); @@ -1047,49 +1312,64 @@ Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape // Reshape gradient (repeat > 1) // Index: [..., dim , ...] [..., dim , dim+1 , ...] // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...] - // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor. - // Then, sum up gradients over repeated tensors along 'dim', and reduce shape - // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize'). - // Example: + // The gradient tensor at 'dim' is reshaped to 'repeat' times of input + // tensor. Then, sum up gradients over repeated tensors along 'dim', and + // reduce shape from 'repeat * dimsize/repeat' to 'dimsize/repeat' + // ('input_dimsize'). Example: // Size(3, 2) Size(6, 2) - // [[v1_0, v1_1], - // [v1_2, v1_3], - // [[v0, v1], repeat(2, 1) [v1_4, v1_5], - // [v2, v3], -------------> [v2_0, v2_1], - // [v4, v5]] [v2_2, v2_3], - // [v2_4, v2_5]] + // [[v1_0, + // v1_1], + // [v1_2, + // v1_3], + // [[v0, v1], repeat(2, 1) [v1_4, + // v1_5], + // [v2, v3], -------------> [v2_0, + // v2_1], [v4, v5]] [v2_2, v2_3], + // [v2_4, + // v2_5]] // - // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) - // [[[g1_0, g1_1], [[g1_0, g1_1], - // [g1_2, g1_3], [g1_2, g1_3], - // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], - // [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1], - // [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], - // [g2_2, g2_3], [g2_4, g2_5]] - // [g2_4, g2_5]]] - // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and then - // sum over 'dim+1'. The gradient for input is not correctly aligned with input. - // Example: - // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) + // input grad (3, 2) reshape (2, 3, 2) output grad (6, + // 2) + // [[[g1_0, g1_1], [[g1_0, + // g1_1], + // [g1_2, g1_3], [g1_2, + // g1_3], + // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, + // g1_5], + // [g1_0+g2_0, g1_1+g2_1], [g2_0, + // g2_1], [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], + // [g2_2, g2_3], [g2_4, + // g2_5]] [g2_4, g2_5]]] + // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and + // then sum over 'dim+1'. The gradient for input is not correctly aligned + // with input. Example: + // input grad (3, 2) reshape (3, 2, 2) output grad (6, + // 2) // [[[g1_0, g1_1], - // [g1_2, g1_3]], [[g1_0, g1_1], - // [g1_2, g1_3], - // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], - // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], - // [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], - // [[g2_2, g2_3], [g2_4, g2_5]] + // [g1_2, g1_3]], [[g1_0, + // g1_1], + // [g1_2, + // g1_3], + // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, + // g1_5], + // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, + // g2_1], [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], + // [[g2_2, g2_3], [g2_4, + // g2_5]] // [g2_4, g2_5]]] if (repeat != 1) { grad_size.push_back(repeat); sum_dims.push_back(grad_size.size() - 1); } - // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1) + // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == + // 1) grad_size.push_back(input_shape[dim]); } // One-time Reshape & Sum // Reshape gradient to grad_size: // 1. If repeat equals to 1, append input size at that dimension, - // 2. If repeat is larger than 1, append both repeat and input size at that dimension. + // 2. If repeat is larger than 1, append both repeat and input size at that + // dimension. // Sum over all "repeat" dimensions from sum_dims: // Example: // Input Size (2, 3, 4, 5) @@ -1098,8 +1378,9 @@ Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape // grad_size [4, 2, 3, 9, 4, 3, 5] // sum_dims [0, 3, 5] - // When repeat 1 time over all original dimensions, the empty sum_dims will reduce - // the whole grad tensor into a scalar rather than keeping original dimensions. + // When repeat 1 time over all original dimensions, the empty sum_dims will + // reduce the whole grad tensor into a scalar rather than keeping original + // dimensions. if (!sum_dims.empty()) { grad = grad.reshape(grad_size); grad = grad.sum(sum_dims); @@ -1118,44 +1399,66 @@ Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) { } // scale == (1 / (1 - prob)) -Tensor infinitely_differentiable_native_dropout_backward(const Tensor& grad, const Tensor& mask, double scale) { +Tensor infinitely_differentiable_native_dropout_backward( + const Tensor& grad, + const Tensor& mask, + double scale) { return grad * (mask.type_as(grad) * scale); } -Tensor native_dropout_double_backward(const Tensor& ggI, const Tensor& grad, const Tensor& mask, double scale) { +Tensor native_dropout_double_backward( + const Tensor& ggI, + const Tensor& grad, + const Tensor& mask, + double scale) { return ggI.type_as(grad) * (mask.type_as(grad) * scale); } -Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) { - bool any_tensor_subclass_like = areAnyTensorSubclassLike({grad, input, value}); +Tensor evenly_distribute_backward( + Tensor grad, + const Tensor& input, + const Tensor& value) { + bool any_tensor_subclass_like = + areAnyTensorSubclassLike({grad, input, value}); if (any_tensor_subclass_like || input.is_cuda()) { const auto input_isnan = input.isnan(); const auto value_isnan = value.isnan(); - const auto& input_and_value_isnan = any_tensor_subclass_like ? - input_isnan.logical_and(value_isnan) : - input_isnan.logical_and_(value_isnan); + const auto& input_and_value_isnan = any_tensor_subclass_like + ? input_isnan.logical_and(value_isnan) + : input_isnan.logical_and_(value_isnan); const auto mask = (input == value).logical_or_(input_and_value_isnan); return mask * (grad / mask.sum()); } else { auto mask = value.isnan().item() ? input.isnan() : input == value; - return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum()); + return grad.new_zeros(input.sizes(), input.options()) + .masked_fill_(mask, grad / mask.sum()); } } -Tensor evenly_read_jvp(const Tensor& fw_grad, const Tensor & input, const Tensor & value) { +Tensor evenly_read_jvp( + const Tensor& fw_grad, + const Tensor& input, + const Tensor& value) { auto mask = (input == value); auto count = mask.sum(); auto grad_output = fw_grad / count; return at::sum(mask * grad_output); } -static Tensor var_backward(const Tensor & grad, const Tensor & self, int64_t correction) { +static Tensor var_backward( + const Tensor& grad, + const Tensor& self, + int64_t correction) { // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) return (2.0 / (self.numel() - correction)) * grad * (self - self.mean()); } -Tensor var_backward(Tensor grad, const Tensor& self, at::OptionalIntArrayRef dim_opt, - c10::optional correction_opt, bool keepdim) { +Tensor var_backward( + Tensor grad, + const Tensor& self, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, + bool keepdim) { auto correction = correction_opt.value_or(1); if (self.dim() == 0 || !dim_opt.has_value()) { return var_backward(grad, self, correction); @@ -1169,25 +1472,44 @@ Tensor var_backward(Tensor grad, const Tensor& self, at::OptionalIntArrayRef dim return (2.0 / dof) * grad * (self - self.mean(dim, /*keepdim=*/true)); } -Tensor var_jvp(const Tensor& self_t, const Tensor& self_p, const Tensor& result, at::OptionalIntArrayRef dim_opt, - c10::optional correction_opt, bool keepdim) { +Tensor var_jvp( + const Tensor& self_t, + const Tensor& self_p, + const Tensor& result, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, + bool keepdim) { auto correction = correction_opt.value_or(1); if (self_p.dim() == 0 || !dim_opt.has_value()) { - return var_backward(self_t.conj(), self_p, correction).sum().expand_as(result).conj(); + return var_backward(self_t.conj(), self_p, correction) + .sum() + .expand_as(result) + .conj(); } auto dim = dim_opt.value(); const int64_t dof = _safe_size(self_p.sizes(), dim) - correction; - return ((2.0 / dof) * self_t.conj() * (self_p - self_p.mean(dim, /*keepdim=*/true))).sum(dim, keepdim).conj(); + return ((2.0 / dof) * self_t.conj() * + (self_p - self_p.mean(dim, /*keepdim=*/true))) + .sum(dim, keepdim) + .conj(); } Tensor std_backward( - const Tensor& result, const Tensor& grad, const Tensor& self, - at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim) { + const Tensor& result, + const Tensor& grad, + const Tensor& self, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim) { auto grad_var = (grad / (result * 2)).masked_fill_(result == 0, 0); return var_backward(grad_var, self, dim, correction, keepdim); } -Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) { +Tensor mean_backward( + Tensor grad, + const IntArrayRef sizes, + IntArrayRef dim, + bool keepdim) { return sum_backward(grad, sizes, dim, keepdim) / _safe_size(sizes, dim); } @@ -1196,8 +1518,11 @@ Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int64_t numel) { } static Tensor mean_backward( - const Tensor& grad, const IntArrayRef sizes, int64_t numel, - at::OptionalIntArrayRef dim, bool keepdim) { + const Tensor& grad, + const IntArrayRef sizes, + int64_t numel, + at::OptionalIntArrayRef dim, + bool keepdim) { if (dim.has_value()) { return mean_backward(grad, sizes, *dim, keepdim); } else { @@ -1206,22 +1531,31 @@ static Tensor mean_backward( } Tensor var_std_mean_backward( - const variable_list& grads, const Tensor& self, const Tensor& r1, - const Tensor& r2, at::OptionalIntArrayRef dim, - c10::optional correction, bool keepdim, bool is_std) { + const variable_list& grads, + const Tensor& self, + const Tensor& r1, + const Tensor& r2, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim, + bool is_std) { Tensor grad; if (grads[0].defined()) { grad = is_std ? std_backward(r1, grads[0], self, dim, correction, keepdim) : var_backward(grads[0], self, dim, correction, keepdim); } if (grads[1].defined()) { - Tensor mean_grad = mean_backward(grads[1], self.sizes(), self.numel(), dim, keepdim); + Tensor mean_grad = + mean_backward(grads[1], self.sizes(), self.numel(), dim, keepdim); grad = grad.defined() ? grad + mean_grad : mean_grad; } return grad; } -Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) { +Tensor masked_scatter_backward( + const Tensor& grad, + const Tensor& mask, + IntArrayRef sizes) { int64_t numel = 1; for (auto size : sizes) { numel *= size; @@ -1229,8 +1563,9 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra auto mask_selected = grad.masked_select(mask); auto diff_nelem = numel - mask_selected.numel(); if (diff_nelem > 0) { - // because mask_selected returns a 1-d tensor with size of masked elements that are 1, - // we need to fill out the rest with zeros then reshape back to tensor2's size. + // because mask_selected returns a 1-d tensor with size of masked elements + // that are 1, we need to fill out the rest with zeros then reshape back to + // tensor2's size. auto zeros_fillin = at::zeros({diff_nelem}, grad.options()); mask_selected = at::cat({mask_selected, zeros_fillin}, 0); } @@ -1263,13 +1598,13 @@ Tensor cholesky_backward(const Tensor& gL, bool upper, const Tensor& L) { // From cholesky_jvp we have that // dL = L\pi(L^{-1}dA(L^-H)) // - // Let gL be the projection into the lower-triangular gradient wrt L. Taking adjoints we have - // gA = L^{-H}\pi^*((L^HgL).tril())L^{-1} - // where \pi^*(X) = 0.5 * (X + X^H - diag(X)) - // The only non-standard point of this derivation is noting that the adjoint to multiplying - // on the left by a lower triangular matrix L is multiplying by L^H and then projecting back to - // the lower triangular matrices (hence the .tril() projection) - // Note that the gradient is symmetric and not triangular. + // Let gL be the projection into the lower-triangular gradient wrt L. Taking + // adjoints we have gA = L^{-H}\pi^*((L^HgL).tril())L^{-1} where \pi^*(X) = + // 0.5 * (X + X^H - diag(X)) The only non-standard point of this derivation is + // noting that the adjoint to multiplying on the left by a lower triangular + // matrix L is multiplying by L^H and then projecting back to the lower + // triangular matrices (hence the .tril() projection) Note that the gradient + // is symmetric and not triangular. auto L_ = upper ? L.mH() : L; auto gL_ = upper ? gL.mH() : gL; @@ -1286,7 +1621,11 @@ Tensor cholesky_backward(const Tensor& gL, bool upper, const Tensor& L) { return gA; } -Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) { +Tensor cholesky_inverse_backward( + Tensor grad, + Tensor L, + bool upper, + Tensor inverse) { at::NoTF32Guard disable_tf32; Tensor grad_L; if (grad.defined()) { @@ -1311,7 +1650,11 @@ Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inver // If X = (U^H U)^{-1} with U upper-triangular with a real positive diagonal, // then K becomes // K = -X dU^H X U -Tensor cholesky_inverse_jvp(const Tensor& F, const Tensor& dF, const Tensor& X, bool upper) { +Tensor cholesky_inverse_jvp( + const Tensor& F, + const Tensor& dF, + const Tensor& X, + bool upper) { at::NoTF32Guard disable_tf32; const auto CF = upper ? F : F.mH(); const auto dCF = upper ? dF.mH() : dF; @@ -1321,9 +1664,9 @@ Tensor cholesky_inverse_jvp(const Tensor& F, const Tensor& dF, const Tensor& X, // The formula for forward AD is adapted from // -// Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses and Nonlinear -// Least Squares Problems Whose Variables Separate." -// SIAM Journal on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036 +// Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses +// and Nonlinear Least Squares Problems Whose Variables Separate." SIAM Journal +// on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036 // // We present a short derivation below: // Let Ap := pinv(A), then Ap is the unique matrix such that @@ -1335,34 +1678,30 @@ Tensor cholesky_inverse_jvp(const Tensor& F, const Tensor& dF, const Tensor& X, // // dAp = dAp A Ap + Ap dA Ap + Ap A dAp [3] // -// In the rhs of [3] the products involving dAp could be expressed as products of -// Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). -// To prove that, note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by -// taking the product between the SVD decompositions of A and Ap. -// Consider the conjugate-tranposed [2]: -// (A Ap A)^H = A^H (A Ap) = A^H. By differentiating it we get: -// dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the left by Ap^H -// and using Ap^H A^H = (A Ap)^H = A Ap: -// Ap^H dA^H A Ap + A Ap dA Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left -// by Ap and by applying [1] and [2] repeatedly until impossible we get: -// Ap Ap^H dA^H A Ap + Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms: +// In the rhs of [3] the products involving dAp could be expressed as products +// of Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). To prove that, +// note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by taking the +// product between the SVD decompositions of A and Ap. Consider the +// conjugate-tranposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating it +// we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the +// left by Ap^H and using Ap^H A^H = (A Ap)^H = A Ap: Ap^H dA^H A Ap + A Ap dA +// Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left by Ap and by +// applying [1] and [2] repeatedly until impossible we get: Ap Ap^H dA^H A Ap + +// Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms: // // Ap A dAp = -Ap dA Ap + Ap Ap^H dA^H (I - A Ap) [4], // which is one of the summands in [3]. // -// Similar, by differentiating the transpose-conjugated [2] written differently, i.e. -// (A Ap A)^H = Ap A A^H = A^H we will get an expression for dAp A Ap, which is +// Similar, by differentiating the transpose-conjugated [2] written differently, +// i.e. (A Ap A)^H = Ap A A^H = A^H we will get an expression for dAp A Ap, +// which is // // dAp A Ap = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap [5]. // // By plugging in [4] and [5] into [3] we get the forward AD formula for pinv: // // dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap). -Tensor pinv_jvp( - const Tensor& A, - const Tensor& pinvA, - const Tensor& dA -) { +Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA) { at::NoTF32Guard disable_tf32; auto m = A.size(-2); auto n = A.size(-1); @@ -1371,22 +1710,17 @@ Tensor pinv_jvp( // optimization to produce matrices of the smallest dimension if (m <= n) { auto K = pinvAh.matmul(dAh); - return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA))) - + (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA)); - } - else { + return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA))) + + (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA)); + } else { auto K = pinvA.matmul(dA); auto Kh = K.mH(); - return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA) - + (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA)); + return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA) + + (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA)); } } -Tensor pinv_backward( - const Tensor& grad, - const Tensor& pinvA, - const Tensor& A -) { +Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A) { at::NoTF32Guard disable_tf32; auto m = A.size(-2); auto n = A.size(-1); @@ -1396,25 +1730,28 @@ Tensor pinv_backward( if (m <= n) { auto K = gradh.matmul(pinvA); auto KpinvAh = K.matmul(pinvAh); - return - (pinvA.matmul(K)).mH() - + KpinvAh - (A.matmul(pinvA)).matmul(KpinvAh) - + (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A)); - } - else { + return -(pinvA.matmul(K)).mH() + KpinvAh - + (A.matmul(pinvA)).matmul(KpinvAh) + + (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A)); + } else { auto K = pinvA.matmul(gradh); auto pinvAhK = pinvAh.matmul(K); - return - (K.matmul(pinvA)).mH() - + (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh) - + pinvAhK - pinvAhK.matmul(pinvA).matmul(A); + return -(K.matmul(pinvA)).mH() + + (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh) + pinvAhK - + pinvAhK.matmul(pinvA).matmul(A); } } -Tensor split_with_sizes_backward(const std::vector &grads, - IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { +Tensor split_with_sizes_backward( + const std::vector& grads, + IntArrayRef split_sizes, + int64_t dim, + IntArrayRef sizes, + const at::TensorOptions& options) { dim = at::maybe_wrap_dim(dim, sizes.size()); - // it's possible some of the grads are not defined (represents tensors of all 0s). - // Since at::cat can't handle those, let's define them + // it's possible some of the grads are not defined (represents tensors of all + // 0s). Since at::cat can't handle those, let's define them std::vector grads_all_defined(grads.size()); for (const auto j : c10::irange(grads.size())) { if (grads[j].defined()) { @@ -1427,21 +1764,29 @@ Tensor split_with_sizes_backward(const std::vector &g } } - auto ret = at::cat(grads_all_defined, dim); + auto ret = at::cat(grads_all_defined, dim); return ret; } -Tensor split_backward(const std::vector &grads, - int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { +Tensor split_backward( + const std::vector& grads, + int64_t split_size, + int64_t dim, + IntArrayRef sizes, + const at::TensorOptions& options) { dim = at::maybe_wrap_dim(dim, sizes.size()); int64_t dim_size = sizes[dim]; int64_t num_splits = grads.size(); std::vector split_sizes(num_splits, split_size); - split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); + split_sizes[num_splits - 1] = + split_size - (split_size * num_splits - dim_size); return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); } -Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { +Tensor max_pool_double_backward( + const Tensor& grad, + const Tensor& indices, + int dim) { AT_ASSERT(indices.dim() >= dim); // handle non-empty inputs if (indices.numel()) { @@ -1449,7 +1794,10 @@ Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int size.push_back(-1); auto indices_view = indices.view(size); const auto memory_format = indices.suggest_memory_format(); - return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes()); + return grad.contiguous(memory_format) + .view(size) + .gather(-1, indices_view) + .view(indices.sizes()); } // handle empty inputs else { @@ -1457,7 +1805,11 @@ Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int } } -Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) { +Tensor glu_double_backward( + const Tensor& grad, + const Tensor& grad_output, + const Tensor& input, + int64_t dim) { auto& gO = grad_output; auto input_size = input.size(dim) / 2; auto first_half = input.narrow(dim, 0, input_size); @@ -1471,17 +1823,25 @@ Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, cons auto ggI_second_half_times_first_half = ggI_second_half * first_half; auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig; - auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig; - auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig; + auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - + sig_second_half * sig_one_sub_sig; + auto gI_second_half = + ggI_second_half_times_first_half * gO * second_order_sh + + ggI_first_half * gO * sig_one_sub_sig; return at::cat({gI_first_half, gI_second_half}, dim); } -Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) { - if (dim < 0) dim += input.dim(); +Tensor glu_double_backward_grad_output( + const Tensor& grad, + const Tensor& input, + int64_t dim) { + if (dim < 0) + dim += input.dim(); auto sizes = input.sizes().vec(); sizes[dim] /= 2; auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim); - return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]); + return tmp.narrow(dim, 0, sizes[dim]) + + tmp.narrow(dim, sizes[dim], sizes[dim]); } Tensor infinitely_differentiable_silu_backward( @@ -1497,7 +1857,8 @@ Tensor infinitely_differentiable_mish_backward( const Tensor sigmoid = input.sigmoid(); const Tensor softplus = input.exp().log1p(); const Tensor tanh_softplus = softplus.tanh(); - return grad_output * (tanh_softplus + input * sigmoid * (1.0 - tanh_softplus * tanh_softplus)); + return grad_output * + (tanh_softplus + input * sigmoid * (1.0 - tanh_softplus * tanh_softplus)); } Tensor infinitely_differentiable_logit_backward( @@ -1520,8 +1881,14 @@ Tensor infinitely_differentiable_logit_backward( } } -Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) { - auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target); +Tensor kl_div_double_backward_grad_output( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + int64_t reduction, + bool log_target) { + auto result = + kl_div_backward(grad, input, target, at::Reduction::None, log_target); if (reduction == at::Reduction::Mean) { return result.mean(); } else if (reduction == at::Reduction::Sum) { @@ -1531,18 +1898,26 @@ Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & in } // Compute derivatives for targets. -Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) { +Tensor kl_div_target_backward( + Tensor grad_output, + Tensor self, + Tensor target, + int64_t reduction, + bool log_target) { Tensor grad_target; if (!log_target) { - if (!areAnyTensorSubclassLike({self, target}) && !grad_output._is_zerotensor()) { - grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.); + if (!areAnyTensorSubclassLike({self, target}) && + !grad_output._is_zerotensor()) { + grad_target = grad_output.mul(target.log().add_(1).sub_(self)) + .masked_fill_(target == 0, 0.); } else { - grad_target = grad_output.mul(target.log().add(1).sub(self)).masked_fill(target == 0, 0.); + grad_target = grad_output.mul(target.log().add(1).sub(self)) + .masked_fill(target == 0, 0.); } - } - else { + } else { if (!areAnyTensorSubclassLike({self, target})) { - grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp())); + grad_target = + grad_output.mul(target.add(1).sub_(self).mul_(target.exp())); } else { grad_target = grad_output.mul(target.add(1).sub(self).mul_(target.exp())); } @@ -1560,11 +1935,11 @@ Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, in } Tensor binary_cross_entropy_target_backward( - const Tensor& grad, - const Tensor& self, - const Tensor& target, - const c10::optional& weight, - int64_t reduction) { + const Tensor& grad, + const Tensor& self, + const Tensor& target, + const c10::optional& weight, + int64_t reduction) { auto grad_target = (1. - self).log_().sub_(self.log()); if (!areAnyTensorSubclassLike({grad})) { grad_target.mul_(grad); @@ -1588,25 +1963,22 @@ Tensor binary_cross_entropy_target_backward( } Tensor binary_cross_entropy_double_backward_target( - const Tensor& grad, - const Tensor& grad_output, - const Tensor& self, - const Tensor& target, - const c10::optional& weight, - int64_t reduction -) { + const Tensor& grad, + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight, + int64_t reduction) { auto res = -grad * grad_output; if (isDefined(weight)) { - res = isTensorSubclassLike(weight.value()) - ? res.mul(weight.value()) - : res.mul_(weight.value()); + res = isTensorSubclassLike(weight.value()) ? res.mul(weight.value()) + : res.mul_(weight.value()); } auto neg_self = 1 - self; - auto denom = isTensorSubclassLike(self) - ? neg_self.mul(self) - : neg_self.mul_(self); + auto denom = + isTensorSubclassLike(self) ? neg_self.mul(self) : neg_self.mul_(self); { at::NoGradGuard guard; // Default eps in binary_cross_entropy for ALL dtypes @@ -1615,9 +1987,7 @@ Tensor binary_cross_entropy_double_backward_target( denom.clamp_min_(eps); } - res = isTensorSubclassLike(denom) - ? res.div(denom) - : res.div_(denom); + res = isTensorSubclassLike(denom) ? res.div(denom) : res.div_(denom); if (reduction == at::Reduction::Mean) { res.div_(target.numel()); @@ -1626,13 +1996,25 @@ Tensor binary_cross_entropy_double_backward_target( return res; } -Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction) { +Tensor binary_cross_entropy_with_logits_target_backward( + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight, + const c10::optional& pos_weight, + int64_t reduction) { Tensor grad_target; if (isDefined(pos_weight)) { if (!areAnyTensorSubclassLike({*pos_weight, grad_output})) { - grad_target = (1. - self.sigmoid()).log_().sub_(pos_weight->mul(self.sigmoid().log_())).mul_(grad_output); + grad_target = (1. - self.sigmoid()) + .log_() + .sub_(pos_weight->mul(self.sigmoid().log_())) + .mul_(grad_output); } else { - grad_target = (1. - self.sigmoid()).log_().sub(pos_weight->mul(self.sigmoid().log_())).mul(grad_output); + grad_target = (1. - self.sigmoid()) + .log_() + .sub(pos_weight->mul(self.sigmoid().log_())) + .mul(grad_output); } } else { grad_target = self.mul(-grad_output); @@ -1653,13 +2035,18 @@ Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_outpu return grad_target; } -Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) { +Tensor log_sigmoid_double_backward(const Tensor& grad, const Tensor& input) { auto z = input.sigmoid(); return grad * (z - 1) * z; } -Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, int dim, const Tensor& output) { - return grad_output * grad - (output * grad_output).sum(dim, true) * grad - grad_output * (output * grad).sum(dim, true); +Tensor softmax_double_backward( + const Tensor& grad, + const Tensor& grad_output, + int dim, + const Tensor& output) { + return grad_output * grad - (output * grad_output).sum(dim, true) * grad - + grad_output * (output * grad).sum(dim, true); } // NOTE: [How to write vmap-compatible backward formulas] @@ -1671,8 +2058,8 @@ Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, in // then as developers we have the following options: // // - If the in-place operation directly followed the creation of a tensor with -// a factory function like at::zeros(...), we should replace the factory with a -// corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call +// a factory function like at::zeros(...), we should replace the factory with +// a corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call // propagates the batch dims to the resulting tensor. // For example: // Before: at::zeros(input.sizes(), grad.options()).copy_(grad) @@ -1684,19 +2071,27 @@ Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, in // areAnyTensorSubclassLike to guard the operation. For example: // c = a * b // Before: c.mul_(grad) -// After: c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad) : c * grad +// After: c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad) : c * +// grad // // - If we don't want to vmap directly over the backward formula (e.g., if the // backward formula is too complicated or has a lot of vmap-incompatible -// operations, then register the backward formula as an operator and eventually -// write a batching rule for it. +// operations, then register the backward formula as an operator and +// eventually write a batching rule for it. -Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional& weight, int64_t reduction) { +Tensor binary_cross_entropy_double_backward( + const Tensor& grad_output, + const Tensor& grad, + const Tensor& input, + const Tensor& target, + const c10::optional& weight, + int64_t reduction) { auto eps = 1e-12; auto inp_pl_eps = input + eps; auto one_m_inp_pl_eps = 1 - input + eps; // gradient wrt input - auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); + auto gI = (input * input - 2 * input * target + target) / + (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); if (!areAnyTensorSubclassLike({gI, grad})) { gI *= (grad * grad_output); } else { @@ -1717,7 +2112,12 @@ Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Te return gI; } -Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional& weight, int64_t reduction) { +Tensor binary_cross_entropy_double_backward_grad_output( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + const c10::optional& weight, + int64_t reduction) { auto eps = 1e-12; // gradient wrt grad_output auto ggO = (input - target) / ((input + eps) * (1 - input + eps)); @@ -1740,7 +2140,12 @@ Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, con return ggO; } -Tensor l1_loss_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & self, const Tensor & other, int64_t reduction) { +Tensor l1_loss_double_backward( + const Tensor& grad, + const Tensor& grad_output, + const Tensor& self, + const Tensor& other, + int64_t reduction) { if (!self.is_complex()) { return at::zeros_like(grad); } else { @@ -1753,18 +2158,29 @@ Tensor l1_loss_double_backward(const Tensor & grad, const Tensor & grad_output, } } -Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { - auto output = at::l1_loss_backward(grad.conj(), input, target, at::Reduction::None); +Tensor l1_loss_double_backward_grad_output( + const Tensor& grad, + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction) { + auto output = + at::l1_loss_backward(grad.conj(), input, target, at::Reduction::None); if (reduction == at::Reduction::Mean) { output /= input.numel(); } return handle_r_to_c(grad_output, output); } -Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { +Tensor smooth_l1_loss_double_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta) { // special case to protect against a divide-by-zero. if (beta == 0) { - return at::zeros(grad.sizes(), grad.options()); + return at::zeros(grad.sizes(), grad.options()); } auto d = (input - target).abs(); auto grad_input = grad * (d < beta).type_as(grad) / beta; @@ -1774,7 +2190,12 @@ Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, return grad_input; } -Tensor huber_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double delta) { +Tensor huber_loss_double_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double delta) { auto d = (input - target).abs(); auto grad_input = grad * (d < delta); if (reduction == at::Reduction::Mean) { @@ -1783,15 +2204,25 @@ Tensor huber_loss_double_backward(const Tensor & grad, const Tensor & input, con return grad_input; } -Tensor huber_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double delta) { +Tensor huber_loss_double_backward_grad_output( + const Tensor& grad, + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double delta) { if (reduction == at::Reduction::None) { return huber_loss_backward(grad, input, target, reduction, delta); } - auto r = huber_loss_backward(ones_like(grad_output), input, target, reduction, delta); + auto r = huber_loss_backward( + ones_like(grad_output), input, target, reduction, delta); return (r * grad).sum(); } -Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) { +Tensor mse_loss_double_backward( + const Tensor& grad, + const Tensor& input, + int64_t reduction) { auto grad_input = 2 * grad; if (reduction == at::Reduction::Mean) { grad_input /= input.numel(); @@ -1799,7 +2230,11 @@ Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64 return grad_input; } -Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { +Tensor soft_margin_loss_double_backward( + const Tensor& grad, + const Tensor& input, + const Tensor& target, + int64_t reduction) { auto z = (input * -target).exp(); auto zplus1 = z + 1; auto grad_input = grad * (target * target) * z / (zplus1 * zplus1); @@ -1809,20 +2244,30 @@ Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & inpu return grad_input; } -Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { +Tensor soft_margin_loss_double_backward_grad_output( + const Tensor& grad, + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction) { if (reduction == at::Reduction::None) { return soft_margin_loss_backward(grad, input, target, reduction); } - auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction); + auto r = soft_margin_loss_backward( + ones_like(grad_output), input, target, reduction); return (r * grad).sum(); } -Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const Scalar& beta, const Scalar& threshold) { +Tensor softplus_double_backward( + const Tensor& grad, + const Tensor& input, + const Scalar& beta, + const Scalar& threshold) { auto x = (input * beta); - return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) * beta; + return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) * + beta; } - // NOTE [ as_strided Backward and layout-aware/agnostic autograd ] // // `storage_offset` is ignored for simplicity in this note. If you just want the @@ -1971,7 +2416,8 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // Denote the set of all input indices that pointing to the same storage // location `storage[n]` as `S(n)`, i.e., // -// S(n) = { index : == n, index is valid given input.size() }, +// S(n) = { index : == n, index is valid given +// input.size() }, // // where `` is the dot product between `x` and `y`. // @@ -2024,7 +2470,6 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to // roughly detech overlapping memory. - // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // // Checking memory overlap within a strided tensor is the special case of @@ -2072,7 +2517,8 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // // Notation: // Consider a single such block B, -// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ... +// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ +// B_next[0], ... // start--^^^^ ^^^^^^^^^^^^--end // Each B[i] denotes a dimension index such that B[i] = B[0] + i. // @@ -2091,9 +2537,11 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // // By (*), we know that // stride[B[0]] -// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]] -// < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d] -// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d] +// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + +// stride[B[-1]] < \sum_{i > 0} (size[B[i]] - 1) * +// stride[B[i]] + stride[d] +// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + +// (size[d] - 1) * stride[d] // <= \sum{j > B[0]} (size[j] - 1) * stride[j], // // where the first < comes from sorting and @@ -2119,7 +2567,8 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // // By (*), we know that for all i // stride[i] = stride[B[-1]] + -// \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]] +// \sum_{j=i+1}^{k} (size[B[j]] - 1) * +// stride[B[j]] // // Then the invariant is obviously satisfied at every dimension // in this block if it is satisfied at dimnesion B[-1]. It only @@ -2133,8 +2582,9 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // By (*), we know that the following holds for both input and // output, for any block B: // \sum_{i > B[-1]} (size[i] - 1) * stride[i] -// = \sum_{block B' after B} \prod_{j in B'} size[B[j]] * stride[B'[-1]] -// = \sum_{block B' after B} numel(B') * stride[B'[-1]]. +// = \sum_{block B' after B} \prod_{j in B'} size[B[j]] * +// stride[B'[-1]] = \sum_{block B' after B} numel(B') * +// stride[B'[-1]]. // ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^ // By (**), we know that, this quantity in the above equation // remains the same in input and output. So both @@ -2212,7 +2662,8 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // // Each dimension d in input with stride[d] > stride[j] has // stride'[d] = stride[d] -// > (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j] +// > (size[i] - 1) * stride[i] + (size[j] - 1) * +// stride[j] // >= stride[i] + stride[j] // = stride[k]. // So, considering the sorted dimensions, this is effectively @@ -2232,12 +2683,16 @@ Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const // This implements steps (2)~(4) of the algorithm in // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // Helper for as_strided_backward -static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) { +static inline bool _maybe_overlapping_memory( + IntArrayRef sizes, + IntArrayRef strides) { if (sizes.size() > 0) { std::vector argsort(sizes.size()); std::iota(argsort.begin(), argsort.end(), 0); - std::sort(argsort.begin(), argsort.end(), - [&](std::size_t i, std::size_t j){ return strides[i] < strides[j]; }); + std::sort( + argsort.begin(), argsort.end(), [&](std::size_t i, std::size_t j) { + return strides[i] < strides[j]; + }); int64_t max_index_in_slice = 0; for (auto i : argsort) { @@ -2251,12 +2706,15 @@ static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef stri return false; } -// Returns the minimum storage size needed to contain a tensor of sizes, strides, and storage_offset -// Helper for as_strided_backward -static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { +// Returns the minimum storage size needed to contain a tensor of sizes, +// strides, and storage_offset Helper for as_strided_backward +static inline int64_t _min_storage_size( + IntArrayRef sizes, + IntArrayRef strides, + int64_t storage_offset) { int64_t storage_size = storage_offset + 1; int64_t dim = sizes.size(); - for(const auto i : c10::irange(dim)) { + for (const auto i : c10::irange(dim)) { auto size_i = sizes[i]; if (size_i == 0) { return storage_offset; @@ -2266,16 +2724,24 @@ static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, return storage_size; } -// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation -Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional storage_offset_) { +// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for +// explanation +Tensor as_strided_backward( + Tensor grad, + TensorGeometry input_geometry, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset_) { // For output geometry, // check for size 0 dimensions, // skip size 1 dimensions, // reduce grad on expanded dims (stride=0, size>1) - // Step (0) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ] - // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] + // Step (0) for the algorithm in NOTE [ as_strided Backward and + // layout-aware/agnostic autograd ] Step (0)~(1) for the algorithm in NOTE [ + // Detecting Memory Overlap Within A Strided Tensor ] // on output geometry - auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset()); + auto storage_offset = + storage_offset_.value_or(input_geometry.storage_offset()); auto odim = grad.dim(); std::vector out_sizes_, out_strides_; out_sizes_.reserve(odim); @@ -2294,17 +2760,20 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR out_strides_.insert(out_strides_.begin(), stride_i); } } - // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] + // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A + // Strided Tensor ] // on output geometry auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_); // For input geometry, // check for size 0 dimensions, // skip size 1 dimensions, - // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] + // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A + // Strided Tensor ] // on input geometry auto idim = input_geometry.dim(); - IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides(); + IntArrayRef inp_sizes = input_geometry.sizes(), + inp_strides = input_geometry.strides(); std::vector inp_sizes_, inp_strides_; inp_sizes_.reserve(idim); inp_strides_.reserve(idim); @@ -2318,13 +2787,14 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR inp_strides_.insert(inp_strides_.begin(), stride_i); } } - // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] + // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A + // Strided Tensor ] // on input geometry auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_); - // Rest of this function implements - // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ] + // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and + // layout-aware/agnostic autograd ] // TODO: Raise if not all output values are visible in input geometry. // Technically speaking, if you treat those values as constants, not // raising is fine, and mathematically correct. However, these values @@ -2333,69 +2803,91 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayR // more sensible to raise here. // Step (1): create underlying tensor as "storage" - auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset); + auto shared_offset = + std::min(input_geometry.storage_offset(), storage_offset); auto inp_effective_offset = input_geometry.storage_offset() - shared_offset; auto out_effective_offset = storage_offset - shared_offset; auto base_size = std::max( - _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset), - _min_storage_size(out_sizes_, out_strides_, out_effective_offset) - ); + _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset), + _min_storage_size(out_sizes_, out_strides_, out_effective_offset)); auto storage = grad.new_zeros({base_size}); // prepare indices tensor if we will do index_add_ later c10::optional flatten_full_indices; if (inp_maybe_overlap || out_maybe_overlap) { - flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong)); + flatten_full_indices = + at::arange(0, base_size, grad.options().dtype(at::kLong)); } // Step (2): use output geometry to scatter gradients into storage if (out_maybe_overlap) { - auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset); + auto out_indices = flatten_full_indices->as_strided( + out_sizes_, out_strides_, out_effective_offset); storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1)); } else { // assume that new tensors have 0 storage offset - storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad); + storage.as_strided(out_sizes_, out_strides_, out_effective_offset) + .copy_(grad); } // Step (3): if input tensor has overlapping memory, divide scattered gradient // at storage[i] by the number of times i shows up in input geometry if (inp_maybe_overlap) { auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1); - count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices)); + auto inp_indices = + flatten_full_indices + ->as_strided(inp_sizes_, inp_strides_, inp_effective_offset) + .reshape(-1); + count.index_add_( + 0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices)); storage.div_(count); // this will give nan outside visible range } // Step (4): return as_strided view of the storage tensor with input geometry return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset); } -Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional storage_offset) { +Tensor as_strided_scatter_backward( + Tensor grad, + TensorGeometry input_geometry, + TensorGeometry src_geometry, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset) { // Note [as_strided_scatter backward support] - // as_strided_scatter handling for autograd is a beast, and is non-trivial to implement for arbitrarily strided inputs. - // Most uses for as_strided with functionalization only care about the contiguous case anyway, - // So for now this is not implemented. - // When autograd is being used, we ban non-contiguous inputs. - // We can assume that the input was a contiguous tensor. - // Also, we'll take the perf hit and contiguify grad for now. + // as_strided_scatter handling for autograd is a beast, and is non-trivial to + // implement for arbitrarily strided inputs. Most uses for as_strided with + // functionalization only care about the contiguous case anyway, So for now + // this is not implemented. When autograd is being used, we ban non-contiguous + // inputs. We can assume that the input was a contiguous tensor. Also, we'll + // take the perf hit and contiguify grad for now. auto grad_ = grad.contiguous(); auto grad_slice = grad_.as_strided(sizes, strides, storage_offset); - auto result = grad_.new_empty_strided(input_geometry.sizes(), input_geometry.strides()); + auto result = + grad_.new_empty_strided(input_geometry.sizes(), input_geometry.strides()); auto result_slice = result.as_strided(sizes, strides, storage_offset); result_slice.copy_(grad_slice); return result; } -std::tuple atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array output_mask) { +std::tuple atan2_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& other, + std::array output_mask) { if (!grad.defined()) { return std::tuple{Tensor(), Tensor()}; } auto recip = (self * self + other * other).reciprocal(); - return std::tuple{ - output_mask[0] ? grad * other * recip : Tensor(), - output_mask[1] ? grad * -self * recip : Tensor() }; + return std::tuple{ + output_mask[0] ? grad * other * recip : Tensor(), + output_mask[1] ? grad * -self * recip : Tensor()}; } -Tensor prelu_jvp(const Tensor& x, const Tensor& dx, const Tensor& w, const Tensor& dw) { +Tensor prelu_jvp( + const Tensor& x, + const Tensor& dx, + const Tensor& w, + const Tensor& dw) { const auto ndim = x.dim(); auto as_nd = [ndim](const Tensor& t) { std::vector sizes(ndim, 1), strides(ndim, 0); @@ -2415,22 +2907,28 @@ Tensor prelu_jvp(const Tensor& x, const Tensor& dx, const Tensor& w, const Tenso // each output separately; there is not all that much sharing // of computation going on here. std::tuple prelu_double_backward( - const Tensor & grad_grad_input, - const Tensor & grad_grad_weight, - const Tensor & grad_out, - const Tensor & input_, - const Tensor & weight_) { - - if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) { + const Tensor& grad_grad_input, + const Tensor& grad_grad_weight, + const Tensor& grad_out, + const Tensor& input_, + const Tensor& weight_) { + if (!(grad_grad_input.defined() || grad_grad_weight.defined() || + grad_out.defined())) { return std::tuple(Tensor(), Tensor(), Tensor()); } - auto input = input_.contiguous(); - auto weight = weight_.contiguous(); + auto input = input_.contiguous(); + auto weight = weight_.contiguous(); // Zero-fill undefined grads (TODO: do this more efficiently) - auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto ggI = grad_grad_input.defined() + ? grad_grad_input.contiguous() + : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto ggW = grad_grad_weight.defined() + ? grad_grad_weight.contiguous() + : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto gO = grad_out.defined() + ? grad_out.contiguous() + : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto positive_mask = (input > 0).type_as(ggI); auto nonpositive_mask = (input <= 0).type_as(ggW); @@ -2440,63 +2938,68 @@ std::tuple prelu_double_backward( // = w * i if i <= 0 // gI = df/di * gO = gO if i > 0 gW = df/dw * gO = 0 if i > 0 // = gO * w if i <= 0 = gO * i if i <= 0 - // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly. + // The rest is taking derivatives of these wrt i, w, gO and summing/expanding + // properly. if (weight.numel() == 1) { - // from PReLU.forward: num_parameters == 0 is used indicate that a - // single weight is shared among all input channels. - - // this is a little tricky because PReLU currently doesn't take a shape so the weight may be - // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable - // and tensor are merged). So, use weight and ggW as 0-dim in this case. - bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1); - auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight; - auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW; - - auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input); - auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input); - return std::tuple( - ggO, - ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask, - (ggI * gO * nonpositive_mask).sum().expand_as(weight) - ); + // from PReLU.forward: num_parameters == 0 is used indicate that a + // single weight is shared among all input channels. + + // this is a little tricky because PReLU currently doesn't take a shape so + // the weight may be 1-d when the input is a scalar (and there isn't a good + // Parameter API for that anyway until Variable and tensor are merged). So, + // use weight and ggW as 0-dim in this case. + bool scalar_input_1d_weight = + (positive_mask.dim() == 0 && weight.dim() == 1); + auto weight_maybe_squeeze = + scalar_input_1d_weight ? weight.squeeze() : weight; + auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW; + + auto mask = positive_mask + + nonpositive_mask * weight_maybe_squeeze.expand_as(input); + auto ggO = ggI * mask + + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input); + return std::tuple( + ggO, + ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask, + (ggI * gO * nonpositive_mask).sum().expand_as(weight)); } else { - // Expand ggW to match size of ggI; a simple expand doesn't work because - // ggW is the size of the input channel (dim==1 unless there is only 1 dimension). For example, - // let ggI be size (3,4,5,6,7) and ggW be size (4). Then we unsqueeze ggW to be size (4,1,1,1) - // so the expand succeeds. - auto dims_to_unsqueeze = std::max(input.dim() - 2, 0); - auto ggW_expanded = ggW; - for(const auto i : c10::irange(dims_to_unsqueeze)) { - (void)i; // Suppress unused variable warning - ggW_expanded = ggW_expanded.unsqueeze(1); - } - ggW_expanded = ggW_expanded.expand_as(ggI); + // Expand ggW to match size of ggI; a simple expand doesn't work because + // ggW is the size of the input channel (dim==1 unless there is only 1 + // dimension). For example, let ggI be size (3,4,5,6,7) and ggW be size + // (4). Then we unsqueeze ggW to be size (4,1,1,1) so the expand succeeds. + auto dims_to_unsqueeze = std::max(input.dim() - 2, 0); + auto ggW_expanded = ggW; + for (const auto i : c10::irange(dims_to_unsqueeze)) { + (void)i; // Suppress unused variable warning + ggW_expanded = ggW_expanded.unsqueeze(1); + } + ggW_expanded = ggW_expanded.expand_as(ggI); - auto gI = ggW_expanded * gO * nonpositive_mask; + auto gI = ggW_expanded * gO * nonpositive_mask; - auto gW = ggI * gO * nonpositive_mask; - if (input.dim() > 1) { - gW = gW.sum(0); - } - while (gW.dim() > 1) { - gW = gW.sum(1); - } + auto gW = ggI * gO * nonpositive_mask; + if (input.dim() > 1) { + gW = gW.sum(0); + } + while (gW.dim() > 1) { + gW = gW.sum(1); + } - Tensor ggO; - if (gO.requires_grad()) { - // expand weight as input as in ggW/ggI above - auto weight_expanded = weight; - for(const auto i : c10::irange(dims_to_unsqueeze)) { - (void)i; // Suppress unused variable warning - weight_expanded = weight_expanded.unsqueeze(1); - } - weight_expanded = weight_expanded.expand_as(input); - - auto mask = positive_mask + nonpositive_mask * weight_expanded; - ggO = ggI * mask + ggW_expanded * nonpositive_mask * input; + Tensor ggO; + if (gO.requires_grad()) { + // expand weight as input as in ggW/ggI above + auto weight_expanded = weight; + for (const auto i : c10::irange(dims_to_unsqueeze)) { + (void)i; // Suppress unused variable warning + weight_expanded = weight_expanded.unsqueeze(1); } - return std::tuple{ggO, gI, gW}; + weight_expanded = weight_expanded.expand_as(input); + + auto mask = positive_mask + nonpositive_mask * weight_expanded; + ggO = ggI * mask + ggW_expanded * nonpositive_mask * input; + } + return std::tuple{ggO, gI, gW}; } } @@ -2505,8 +3008,7 @@ Tensor prelu_backward_self_jvp( const Tensor& w, const Tensor& dw, const Tensor& g, - const Tensor& dg -) { + const Tensor& dg) { const auto ndim = x.dim(); auto as_nd = [ndim](const Tensor& t) { std::vector sizes(ndim, 1), strides(ndim, 0); @@ -2527,9 +3029,9 @@ Tensor prelu_backward_weight_jvp( const Tensor& x, const Tensor& dx, const Tensor& g, - const Tensor& dg -) { - const auto dw_full = at::where(x >= 0, at::zeros({}, x.options()), g * dx + dg * x); + const Tensor& dg) { + const auto dw_full = + at::where(x >= 0, at::zeros({}, x.options()), g * dx + dg * x); const auto ndim = x.dim(); std::vector reduction_dims; @@ -2551,11 +3053,12 @@ Tensor prelu_backward_weight_jvp( } Tensor gelu_double_backward( - const Tensor & ggI, - const Tensor & gO, - const Tensor & input, - c10::string_view approximate) { - //if (at::native::get_gelutype_enum(approximate) == at::native::GeluType::Tanh) { + const Tensor& ggI, + const Tensor& gO, + const Tensor& input, + c10::string_view approximate) { + // if (at::native::get_gelutype_enum(approximate) == + // at::native::GeluType::Tanh) { if (approximate == "tanh") { constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; constexpr auto kKappa = 0.044715; @@ -2599,12 +3102,19 @@ Tensor elu_double_backward( const Scalar& input_scale, bool is_result, const Tensor& self_or_result) { - - if (is_result) { - return grad * grad_output * input_scale * (self_or_result < 0).type_as(grad); - } else { - return at::elu_backward(grad * grad_output * input_scale, alpha, scale, input_scale, is_result, self_or_result) * (self_or_result < 0).type_as(grad); - } + if (is_result) { + return grad * grad_output * input_scale * + (self_or_result < 0).type_as(grad); + } else { + return at::elu_backward( + grad * grad_output * input_scale, + alpha, + scale, + input_scale, + is_result, + self_or_result) * + (self_or_result < 0).type_as(grad); + } } Tensor slice_backward_wrapper( @@ -2620,11 +3130,12 @@ Tensor slice_backward_wrapper( return slice_backward(grad, input_sizes, dim, start_val, end_val, step); } -std::tuple linalg_svd_jvp(const Tensor& dA, - const Tensor& U_, - const Tensor& S, - const Tensor& Vh_, - const bool full_matrices) { +std::tuple linalg_svd_jvp( + const Tensor& dA, + const Tensor& U_, + const Tensor& S, + const Tensor& Vh_, + const bool full_matrices) { at::NoTF32Guard disable_tf32; // See svd_backward for the derivation // With sym(X) = X + X^H, we implement @@ -2657,16 +3168,17 @@ std::tuple linalg_svd_jvp(const Tensor& dA, auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V)) : at::matmul(at::matmul(U.mH(), dA), V); - auto dS = is_complex ? at::real(dP.diagonal(0, -2, -1)) - : dP.diagonal(0, -2, -1); + auto dS = + is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1); // dX = dP - dS dP = dP - dS.diag_embed(); - auto E = [&S]{ + auto E = [&S] { const auto S2 = S * S; auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1); - // Any number a != 0 would, as we are just going to use it to compute 0 / a later on + // Any number a != 0 would, as we are just going to use it to compute 0 / a + // later on ret.diagonal(0, -2, -1).fill_(1); return ret; }(); @@ -2674,8 +3186,7 @@ std::tuple linalg_svd_jvp(const Tensor& dA, const auto sym = [](const Tensor& X) { return X + X.mH(); }; // diag(dP) / (2S) - auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S) - : Tensor{}; + auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S) : Tensor{}; // dU = U (sym(dP S) / E) + i Im(diag(dP)) / (2S) auto dU = [&] { @@ -2687,10 +3198,11 @@ std::tuple linalg_svd_jvp(const Tensor& dA, }(); if (m > n) { // dU += (I_m - UU^H) dA V S^{-1} - const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2)) ; + const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2)); dU = dU + dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv)); - // To "fix" the full_matrices case (the full_matrices case should not be differentiable...) + // To "fix" the full_matrices case (the full_matrices case should not be + // differentiable...) if (full_matrices) { auto shape = dU.sizes().vec(); shape.end()[-1] = m - n; @@ -2709,10 +3221,11 @@ std::tuple linalg_svd_jvp(const Tensor& dA, }(); if (m < n) { // dVh += S^{-1} U^H dA (I_n - VV^H) - const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA) ; + const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA); dVh = dVh + UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh); - // To "fix" the full_matrices case (the full_matrices case should not be differentiable...) + // To "fix" the full_matrices case (the full_matrices case should not be + // differentiable...) if (full_matrices) { auto shape = dVh.sizes().vec(); shape.end()[-2] = n - m; @@ -2723,15 +3236,17 @@ std::tuple linalg_svd_jvp(const Tensor& dA, return std::make_tuple(std::move(dU), std::move(dS), std::move(dVh)); } -Tensor svd_backward(const Tensor& gU, - const Tensor& gS, - const Tensor& gVh, - const Tensor& U, - const Tensor& S, - const Tensor& Vh) { +Tensor svd_backward( + const Tensor& gU, + const Tensor& gS, + const Tensor& gVh, + const Tensor& U, + const Tensor& S, + const Tensor& Vh) { at::NoTF32Guard disable_tf32; - // Throughout both the real and complex case we assume A has distinct singular values. - // Furthermore, if A is rectangular or complex, we assume it's full-rank. + // Throughout both the real and complex case we assume A has distinct singular + // values. Furthermore, if A is rectangular or complex, we assume it's + // full-rank. // // // The real case (A \in R) @@ -2750,76 +3265,76 @@ Tensor svd_backward(const Tensor& gU, // // The complex case (A \in C) // This one is trickier because the svd is not locally unique. - // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL, S, VL) is - // another valid SVD decomposition of A as - // A = ULS(VL)^H = ULSL^{-1}V^H = USV^H, - // since L, S and L^{-1} commute, since they are all diagonal. + // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL, + // S, VL) is another valid SVD decomposition of A as A = ULS(VL)^H = + // ULSL^{-1}V^H = USV^H, since L, S and L^{-1} commute, since they are all + // diagonal. // - // Assume wlog that n >= k in what follows, as otherwise we could reason about A^H. - // Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex Stiefel manifold. - // What this invariance means is that the svd decomposition is not a map - // svd: C^{n x k} -> St_k(C^n) x R^n x St_k(C^k) - // (where St_k(C^k) is simply the unitary group U(k)) but a map - // svd: C^{n x k} -> M x R^n - // where M is the manifold given by quotienting St_k(C^n) x U(n) by the action (U, V) -> (UL, VL) - // with L as above. - // Note that M is a manifold, because the action is free and proper (as U(1)^k \iso (S^1)^k is compact). - // For this reason, pi : St_k(C^n) x U(n) -> M forms a principal bundle. + // Assume wlog that n >= k in what follows, as otherwise we could reason about + // A^H. Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex + // Stiefel manifold. What this invariance means is that the svd decomposition + // is not a map svd: C^{n x k} -> St_k(C^n) x R^n x St_k(C^k) (where St_k(C^k) + // is simply the unitary group U(k)) but a map svd: C^{n x k} -> M x R^n where + // M is the manifold given by quotienting St_k(C^n) x U(n) by the action (U, + // V) -> (UL, VL) with L as above. Note that M is a manifold, because the + // action is free and proper (as U(1)^k \iso (S^1)^k is compact). For this + // reason, pi : St_k(C^n) x U(n) -> M forms a principal bundle. // // To think about M, consider the case case k = 1. The, we have the bundle // pi : St_1(C^n) x U(1) -> M - // now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere of dimension 2n-1 in C^n \iso R^{2n} - // S^{2n-1} = { z \in C^n | z^H z = 1}. + // now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere + // of dimension 2n-1 in C^n \iso R^{2n} S^{2n-1} = { z \in C^n | z^H z = 1}. // Then, in this case, we're quotienting out U(1) completely, so we get that // pi : S^{2n-1} x U(1) -> CP(n-1) // where CP(n-1) is the complex projective space of dimension n-1. - // In other words, M is just the complex projective space, and pi is (pretty similar to) - // the usual principal bundle from S^{2n-1} to CP(n-1). - // The case k > 1 is the same, but requiring a linear inependence condition between the + // In other words, M is just the complex projective space, and pi is (pretty + // similar to) the usual principal bundle from S^{2n-1} to CP(n-1). The case k + // > 1 is the same, but requiring a linear inependence condition between the // vectors from the different S^{2n-1} or CP(n-1). // - // Note that this is a U(1)^k-bundle. In plain words, this means that the fibres of this bundle, - // i.e. pi^{-1}(x) for x \in M are isomorphic to U(1) x ... x U(1). - // This is obvious as, if pi(U,V) = x, - // pi^{-1}(x) = {(U diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k} + // Note that this is a U(1)^k-bundle. In plain words, this means that the + // fibres of this bundle, i.e. pi^{-1}(x) for x \in M are isomorphic to U(1) x + // ... x U(1). This is obvious as, if pi(U,V) = x, pi^{-1}(x) = {(U + // diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k} // = {(U diag(z), V diag(z)) | z \in U(1)^k} // since U(1) = {z \in C | |z| = 1}. // - // The big issue here is that M with its induced metric is not locally isometric to St_k(C^n) x U(k). - // [The why is rather technical, but you can see that the horizontal distribution is not involutive, - // and hence integrable due to Frobenius' theorem] - // What this means in plain words is that, no matter how we choose to return the U and V from the - // SVD, we won't be able to simply differentiate wrt. U and V and call it a day. - // An example of a case where we can do this is when performing an eigendecomposition on a real - // matrix that happens to have real eigendecomposition. In this case, even though you can rescale - // the eigenvectors by any real number, you can choose them of norm 1 and call it a day. - // In the eigenvector case, we are using that you can isometrically embed S^{n-1} into R^n. - // In the svd case, we need to work with the "quotient manifold" M explicitly, which is + // The big issue here is that M with its induced metric is not locally + // isometric to St_k(C^n) x U(k). [The why is rather technical, but you can + // see that the horizontal distribution is not involutive, and hence + // integrable due to Frobenius' theorem] What this means in plain words is + // that, no matter how we choose to return the U and V from the SVD, we won't + // be able to simply differentiate wrt. U and V and call it a day. An example + // of a case where we can do this is when performing an eigendecomposition on + // a real matrix that happens to have real eigendecomposition. In this case, + // even though you can rescale the eigenvectors by any real number, you can + // choose them of norm 1 and call it a day. In the eigenvector case, we are + // using that you can isometrically embed S^{n-1} into R^n. In the svd case, + // we need to work with the "quotient manifold" M explicitly, which is // slightly more technically challenging. // - // Since the columns of U and V are not uniquely defined, but are representatives of certain - // classes of equivalence which represent elements M, the user may not depend on the particular - // representative that we return from the SVD. In particular, if the loss function depends on U - // or V, it must be invariant under the transformation (U, V) -> (UL, VL) with - // L = diag(e^{i\theta})), for every \theta \in R^k. - // In more geometrical terms, this means that the loss function should be constant on the fibres, - // or, in other words, the gradient along the fibres should be zero. - // We may see this by checking that the gradients as element in the tangent space - // T_{(U, V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the map - // (U, V) -> (UL, VL), we see that the space tangent to the fibres is given by + // Since the columns of U and V are not uniquely defined, but are + // representatives of certain classes of equivalence which represent elements + // M, the user may not depend on the particular representative that we return + // from the SVD. In particular, if the loss function depends on U or V, it + // must be invariant under the transformation (U, V) -> (UL, VL) with L = + // diag(e^{i\theta})), for every \theta \in R^k. In more geometrical terms, + // this means that the loss function should be constant on the fibres, or, in + // other words, the gradient along the fibres should be zero. We may see this + // by checking that the gradients as element in the tangent space T_{(U, + // V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the map (U, + // V) -> (UL, VL), we see that the space tangent to the fibres is given by // Vert_{(U, V)}(St(n,k) x U(k)) = { i[U, V]diag(\theta) | \theta in R^k} - // where [U, V] denotes the vertical concatenation of U and V to form an (n+k, k) matrix. - // Then, solving - // = 0 for two matrices S, T \in T_{(U, V)}(St(n,k) x U(k)) - // where = Re tr(A^H B) is the canonical (real) inner product in C^{n x k} - // we get that the function is invariant under action of U(1)^k iff - // Im(diag(U^H gU + V^H gV)) = 0 + // where [U, V] denotes the vertical concatenation of U and V to form an (n+k, + // k) matrix. Then, solving = 0 for two matrices + // S, T \in T_{(U, V)}(St(n,k) x U(k)) where = Re tr(A^H B) is the + // canonical (real) inner product in C^{n x k} we get that the function is + // invariant under action of U(1)^k iff Im(diag(U^H gU + V^H gV)) = 0 // - // Using this in the derviaton for the forward AD, one sees that, with the notation from those notes - // Using this and writing sym(X) = X + X^H, we get that the forward AD for SVD in the complex - // case is given by - // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S)) - // if m > n + // Using this in the derviaton for the forward AD, one sees that, with the + // notation from those notes Using this and writing sym(X) = X + X^H, we get + // that the forward AD for SVD in the complex case is given by dU = U (sym(dX + // S) / E + i Im(diag(dX)) / (2S)) if m > n // dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1} // dS = Re(diag(dP)) // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S)) @@ -2834,13 +3349,14 @@ Tensor svd_backward(const Tensor& gU, // Similarly, writing skew(X) = X - X^H // the adjoint wrt. the canonical metric is given by // if m == n - // gA = U [((skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S + S ((skew(V^H gV) / E)) + I o gS] V^H + // gA = U [((skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S + S ((skew(V^H gV) + // / E)) + I o gS] V^H // if m > n // gA = [term in m == n] + (I_m - UU^H)gU S^{-1} V^H // if m < n // gA = [term in m == n] + U S^{-1} (gV)^H (I_n - VV^H) - // where we have used that Im(diag(U^H gU)) = - Im(diag(V^h gV)) to group the diagonal imaginary terms into one - // that just depends on U^H gU. + // where we have used that Im(diag(U^H gU)) = - Im(diag(V^h gV)) to group the + // diagonal imaginary terms into one that just depends on U^H gU. // Checks compute_uv=true TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2); @@ -2868,25 +3384,28 @@ Tensor svd_backward(const Tensor& gU, // Check for the invariance of the loss function, i.e. // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0 if (is_complex) { - const auto imdiag_UhgU = gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1)) - : at::zeros_like(S); - const auto imdiag_VhgV = gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1)) - : at::zeros_like(S); + const auto imdiag_UhgU = + gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1)) : at::zeros_like(S); + const auto imdiag_VhgV = + gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1)) : at::zeros_like(S); // Rather lax atol and rtol, as we don't want false positives - TORCH_CHECK(at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2), - "svd_backward: The singular vectors in the complex case are specified up to multiplication " - "by e^{i phi}. The specified loss function depends on this phase term, making " - "it ill-defined."); + TORCH_CHECK( + at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2), + "svd_backward: The singular vectors in the complex case are specified up to multiplication " + "by e^{i phi}. The specified loss function depends on this phase term, making " + "it ill-defined."); } - // gA = ((U^H gU) / E) S + S (((V^H gV) / E) + I o (gS + diag(U^H gU) / (2 * S)) + // gA = ((U^H gU) / E) S + S (((V^H gV) / E) + I o (gS + diag(U^H gU) / (2 * + // S)) Tensor gA = [&] { // ret holds everything but the diagonal of gA auto ret = [&] { - const auto E = [&S]{ + const auto E = [&S] { const auto S2 = S * S; auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1); - // Any number a != 0 would, as we are just going to use it to compute 0 / a later on + // Any number a != 0 would, as we are just going to use it to compute 0 + // / a later on ret.diagonal(0, -2, -1).fill_(1); return ret; }(); @@ -2913,7 +3432,7 @@ Tensor svd_backward(const Tensor& gU, if (m > n && gU.defined()) { // gA = [UgA + (I_m - UU^H)gU S^{-1}]V^H - gA = at::matmul(U, gA); + gA = at::matmul(U, gA); const auto gUSinv = gU / S.unsqueeze(-2); gA = gA + gUSinv - at::matmul(U, at::matmul(U.mH(), gUSinv)); gA = at::matmul(gA, Vh); @@ -2933,16 +3452,22 @@ Tensor svd_backward(const Tensor& gU, } // The implementation follows: -// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation" +// "An extended collection of matrix derivative results for forward and reverse +// mode algorithmic differentiation" // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf -// However, the reference does not cover the constraints on eigenvectors to have 1-norm. -// See the details below. -Tensor eig_backward(const std::vector &grads, const Tensor& self, - bool is_eigvec_tensor_nonempty, const Tensor& eigenvalues, const Tensor& eigenvectors) { +// However, the reference does not cover the constraints on eigenvectors to have +// 1-norm. See the details below. +Tensor eig_backward( + const std::vector& grads, + const Tensor& self, + bool is_eigvec_tensor_nonempty, + const Tensor& eigenvalues, + const Tensor& eigenvectors) { at::NoTF32Guard disable_tf32; - TORCH_CHECK(is_eigvec_tensor_nonempty, - "eig_backward: torch.eig(eigenvalues=False) is not differentiable. ", - "Please use torch.linalg.eigvals"); + TORCH_CHECK( + is_eigvec_tensor_nonempty, + "eig_backward: torch.eig(eigenvalues=False) is not differentiable. ", + "Please use torch.linalg.eigvals"); // variable names correspond to the ones in the reference document auto D = eigenvalues; @@ -2953,9 +3478,9 @@ Tensor eig_backward(const std::vector &grads, const T // The condition below is trying to marry torch.eig and torch.linalg.eig // for real inputs. // - // For real inputs torch.eig returns a real 2D tensor representing real and complex - // components of eigenvalues, while torch.linalg.eig will most likely always - // return complex eigenvalues. + // For real inputs torch.eig returns a real 2D tensor representing real and + // complex components of eigenvalues, while torch.linalg.eig will most likely + // always return complex eigenvalues. if (!self.is_complex()) { auto is_imag_eigvals_zero = false; // path for torch.eig with always a "real" 2D tensor of eigenvalues @@ -2975,10 +3500,9 @@ Tensor eig_backward(const std::vector &grads, const T } // No support for complex eigenvalues for real inputs yet. TORCH_CHECK( - is_imag_eigvals_zero, - "eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment."); - } - else { + is_imag_eigvals_zero, + "eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment."); + } else { // torch.eig returns 2d tensors for eigenvalues, // while torch.linalg.eig returns 1d. // Hence we insert additional dimension for complex input, @@ -3003,28 +3527,27 @@ Tensor eig_backward(const std::vector &grads, const T // normalized to 1 norm, and the reference does not take that into account. // Hence, we have to modify the formula accordingly. // - // Normalization to 1 norm imposes the following constraint on the eigenvectors, i.e. - // (U^H U) * I = I, where I is an identity matrix. - // Forward AD for this expression yields: - // (dU^H U + U^H dU) * I = 0 => U^H dU * I = 0 <=> diag(U^H dU) = 0, which means - // that each i-th column of U is orthogonal to the i-th column of dU. - // Now, the value of dU which does not take this constraint into consideration - // comes straight from the reference: - // dU = U(F * U^{-1} dA U). - // To make sure that U^H dU * I = 0, and using U^H U * I = I (normalization), - // we propose a modifed forward AD for U: - // dU_new = dU - U(U^H dU * I) (think of Gram-Schmidt) + // Normalization to 1 norm imposes the following constraint on the + // eigenvectors, i.e. (U^H U) * I = I, where I is an identity matrix. Forward + // AD for this expression yields: (dU^H U + U^H dU) * I = 0 => U^H dU * I = 0 + // <=> diag(U^H dU) = 0, which means that each i-th column of U is orthogonal + // to the i-th column of dU. Now, the value of dU which does not take this + // constraint into consideration comes straight from the reference: dU = U(F * + // U^{-1} dA U). To make sure that U^H dU * I = 0, and using U^H U * I = I + // (normalization), we propose a modifed forward AD for U: dU_new = dU - U(U^H + // dU * I) (think of Gram-Schmidt) // - // The rest is very similar to what is done in the reference and we finally arrive at: + // The rest is very similar to what is done in the reference and we finally + // arrive at: // - // A_grad = U^{-H} (D_grad + (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj()) U^H + // A_grad = U^{-H} (D_grad + (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj()) + // U^H // = U^{-H} (eigenvalues_contribs + eigenvectors_contrib) U^H, where // eigenvalues_contribs := D_grad, // eigenvectors_contribs := (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj(). - // The contributions from the eigenvectors and the eigenvalues are computed below, - // and then we solve the system - // U^H A_grad = (eigenvalues_contribs + eigenvectors_contribs) U_H - // to produce A_grad. + // The contributions from the eigenvectors and the eigenvalues are computed + // below, and then we solve the system U^H A_grad = (eigenvalues_contribs + + // eigenvectors_contribs) U_H to produce A_grad. // contribution from the eigenvectors Tensor U_contrib; @@ -3035,8 +3558,7 @@ Tensor eig_backward(const std::vector &grads, const T if (!F.is_complex()) { F.diagonal(0, -2, -1).fill_(INFINITY); F.pow_(-1); - } - else { + } else { // The F matrix construction for complex eigenvalues // if different from its real counterpart. // There is no complex INFINITY, and we cannot use @@ -3054,9 +3576,10 @@ Tensor eig_backward(const std::vector &grads, const T } auto U_grad_proj_onto_U = at::matmul(U.mH(), U_grad); auto Uh_U = at::matmul(U.mH(), U); - U_contrib = (U_grad_proj_onto_U - Uh_U * U_grad_proj_onto_U.diagonal(0, -2, -1).unsqueeze(-2)) * F.conj(); - } - else { + U_contrib = (U_grad_proj_onto_U - + Uh_U * U_grad_proj_onto_U.diagonal(0, -2, -1).unsqueeze(-2)) * + F.conj(); + } else { U_contrib = at::zeros_like(self, at::MemoryFormat::Contiguous); } @@ -3065,20 +3588,21 @@ Tensor eig_backward(const std::vector &grads, const T if (D_grad.defined()) { // narrow extracts the column corresponding to the real part D_contrib = D_grad.narrow(-1, 0, 1); - } - else { + } else { D_contrib = at::zeros_like(D, at::MemoryFormat::Contiguous); } - return at::linalg_solve(U.mH(), at::matmul(U_contrib, U.mH()) + D_contrib * U.mH()); + return at::linalg_solve( + U.mH(), at::matmul(U_contrib, U.mH()) + D_contrib * U.mH()); } -Tensor linalg_eig_backward(const Tensor& gL, - const Tensor& gV, - const Tensor& L, - const Tensor& V, - const bool is_hermitian, - const bool symeig_eigenvectors) { +Tensor linalg_eig_backward( + const Tensor& gL, + const Tensor& gV, + const Tensor& L, + const Tensor& V, + const bool is_hermitian, + const bool symeig_eigenvectors) { at::NoTF32Guard disable_tf32; // https://arxiv.org/pdf/1701.00392.pdf Eq 4.77 // For A = VLV^{-1}, denoting the gradients gA, gV and gL, we have @@ -3089,15 +3613,18 @@ Tensor linalg_eig_backward(const Tensor& gL, // - diag_embed takes a vector into a diagonal matrix // - diag zeroes out elements outside of the diagonal - // Note: the term '-V^HV diag(real(V^H gV))' comes from the fact that the eigenvalue - // decomposition is returned with eigenvectors normalized to have norm one. + // Note: the term '-V^HV diag(real(V^H gV))' comes from the fact that the + // eigenvalue decomposition is returned with eigenvectors normalized to have + // norm one. - // Note: The Hermitian case is a simplification of this formula using that V^{-1} = V^H and that L is real + // Note: The Hermitian case is a simplification of this formula using that + // V^{-1} = V^H and that L is real // This check just can be triggered in the backwards of torch.symeig - TORCH_CHECK(symeig_eigenvectors, - "linalg_eig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ", - "Use torch.linalg.eigvalsh(A) instead."); + TORCH_CHECK( + symeig_eigenvectors, + "linalg_eig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ", + "Use torch.linalg.eigvalsh(A) instead."); // Trivial case if (!gL.defined() && !gV.defined()) { @@ -3117,24 +3644,32 @@ Tensor linalg_eig_backward(const Tensor& gL, const auto diag_VhgV = VhgV.diagonal(0, -2, -1); if (V.is_complex()) { - // Check invariance of the loss function wrt the transformation V -> V e^{i\phi} + // Check invariance of the loss function wrt the transformation V -> V + // e^{i\phi} const auto imdiag_VhgV = at::imag(diag_VhgV); - TORCH_CHECK(at::allclose(imdiag_VhgV, at::zeros_like(imdiag_VhgV), /*rtol=*/1e-2, /*atol=*/1e-2), - is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward", - ": The eigenvectors in the complex case are specified up to multiplication ", - "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined."); + TORCH_CHECK( + at::allclose( + imdiag_VhgV, + at::zeros_like(imdiag_VhgV), + /*rtol=*/1e-2, + /*atol=*/1e-2), + is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward", + ": The eigenvectors in the complex case are specified up to multiplication ", + "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined."); } if (is_hermitian) { - // Project onto the tangent space at the identity of U(n), that is, the skew-Hermitian matrices + // Project onto the tangent space at the identity of U(n), that is, the + // skew-Hermitian matrices VhgV = 0.5 * (VhgV - VhgV.mH()); } else { - // Project onto the tangent space at V^H V of complex matrices with columns of norm 1 + // Project onto the tangent space at V^H V of complex matrices with columns + // of norm 1 VhgV = VhgV - at::matmul(V.mH(), V * at::real(diag_VhgV).unsqueeze(-2)); } - auto gA = [&, VhgV = std::move(VhgV)]{ - auto Econj = [&L]{ + auto gA = [&, VhgV = std::move(VhgV)] { + auto Econj = [&L] { auto Lconj = L.conj(); auto ret = Lconj.unsqueeze(-2) - Lconj.unsqueeze(-1); ret.diagonal(0, -2, -1).fill_(1.); @@ -3149,7 +3684,6 @@ Tensor linalg_eig_backward(const Tensor& gL, return ret; }(); - // Conjugate by V^{-H} if (is_hermitian) { return at::matmul(V, at::matmul(gA, V.mH())); @@ -3158,31 +3692,31 @@ Tensor linalg_eig_backward(const Tensor& gL, } } -std::tuple linalg_eig_jvp(const Tensor& dA, - const Tensor& L, - const Tensor& V, - const bool is_hermitian) { +std::tuple linalg_eig_jvp( + const Tensor& dA, + const Tensor& L, + const Tensor& V, + const bool is_hermitian) { at::NoTF32Guard disable_tf32; // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf // see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63) - // Note that neither of the formulas in these pdfs are correct, as they do not assume that - // the eigenvectors are of unit norm. As such, they are missing the diagonal term in dV - // dL = diag(dP) - // dV = dX - V Re(diag V^H dX)) - // where - // dP = V^{-1} dA V - // dX = V ((dP - diag(dP)) / E) - // E_{ij} = L_j - L_i if i != j + // Note that neither of the formulas in these pdfs are correct, as they do not + // assume that the eigenvectors are of unit norm. As such, they are missing + // the diagonal term in dV dL = diag(dP) dV = dX - V Re(diag V^H dX)) where dP + // = V^{-1} dA V dX = V ((dP - diag(dP)) / E) E_{ij} = L_j - L_i if i != j // 1 otherwise // Precondition: if is_hermitian == true, then dA is Hermitian - const auto to_complex = [](const Tensor& A){ return A.to(c10::toComplexType(A.scalar_type())); }; + const auto to_complex = [](const Tensor& A) { + return A.to(c10::toComplexType(A.scalar_type())); + }; - const auto dP = is_hermitian ? at::matmul(at::matmul(V.mH(), dA), V) - : at::linalg_solve(V, at::matmul(to_complex(dA), V)); + const auto dP = is_hermitian + ? at::matmul(at::matmul(V.mH(), dA), V) + : at::linalg_solve(V, at::matmul(to_complex(dA), V)); auto dL = is_hermitian && dA.is_complex() ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1); - auto dV = [&dP, &V, &L, is_hermitian]{ + auto dV = [&dP, &V, &L, is_hermitian] { const auto dX = [&] { auto ret = dP / (L.unsqueeze(-2) - L.unsqueeze(-1)); ret.diagonal(0, -2, -1).zero_(); @@ -3193,18 +3727,19 @@ std::tuple linalg_eig_jvp(const Tensor& dA, if (is_hermitian) { return dX; } else { - return dX - V * at::real(at::matmul(V.mH(), dX).diagonal(0, -2, -1)).unsqueeze(-2); + return dX - + V * + at::real(at::matmul(V.mH(), dX).diagonal(0, -2, -1)).unsqueeze(-2); } }(); return std::make_pair(std::move(dL), std::move(dV)); } Tensor linalg_lstsq_jvp( - const Tensor& A, - const Tensor& B, - const Tensor& dA, - const Tensor& dB -) { + const Tensor& A, + const Tensor& B, + const Tensor& dA, + const Tensor& dB) { at::NoTF32Guard disable_tf32; auto pinvA = at::linalg_pinv(A); auto dpinvA = pinv_jvp(A, pinvA, dA); @@ -3213,13 +3748,12 @@ Tensor linalg_lstsq_jvp( } std::tuple linalg_lstsq_backward( - const Tensor& grad, - const Tensor& A, - const Tensor& B, - const c10::optional rcond, - const c10::optional driver, - const std::array& grad_input_mask -) { + const Tensor& grad, + const Tensor& A, + const Tensor& B, + const c10::optional rcond, + const c10::optional driver, + const std::array& grad_input_mask) { at::NoTF32Guard disable_tf32; Tensor A_grad, B_grad; if (!grad.defined()) { @@ -3241,8 +3775,9 @@ std::tuple linalg_lstsq_backward( pinvA = at::linalg_pinv(A); } // Equivalent to - // B_grad = std::get<0>(at::linalg_lstsq(A.transpose(-1, -2).conj(), grad, rcond, driver)); - // but we avoid this approach as `gelsy` is non-deterministic + // B_grad = std::get<0>(at::linalg_lstsq(A.transpose(-1, -2).conj(), grad, + // rcond, driver)); but we avoid this approach as `gelsy` is + // non-deterministic B_grad = pinvA.transpose(-1, -2).conj().matmul(grad); } @@ -3250,11 +3785,10 @@ std::tuple linalg_lstsq_backward( } std::tuple linalg_qr_jvp( - const Tensor& dA, - const Tensor& Q, - const Tensor& R, - const c10::string_view mode -) { + const Tensor& dA, + const Tensor& Q, + const Tensor& R, + const c10::string_view mode) { // dA = dQR + QdR // // Case m >= n @@ -3278,23 +3812,27 @@ std::tuple linalg_qr_jvp( // trilIm(Q^H dQ) = trilIm(Q^H A_1 R_1^{-1}) // and define trilIminv(X) = X - X^H - i*Im diag(X). This is the inverse of // trilIm : Skew_C(m) -> Tril(m, imaginary diag) - // Note that it is just the inverse when the inputs are skew-Hermitian, not necessarily - // when the inputs are arbitrary matrices. We then get - // dQ = Q trilImInv(trilIm(Q^H A_1 R_1^{-1})) + // Note that it is just the inverse when the inputs are skew-Hermitian, not + // necessarily when the inputs are arbitrary matrices. We then get dQ = Q + // trilImInv(trilIm(Q^H A_1 R_1^{-1})) at::NoTF32Guard disable_tf32; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); - TORCH_CHECK(compute_q, "The derivative of linalg.qr depends on Q, which is not computed when " - "mode='r'. Please use linalg.qr(A, mode='reduced') if you are " - "going to differentiate through linalg.qr."); + TORCH_CHECK( + compute_q, + "The derivative of linalg.qr depends on Q, which is not computed when " + "mode='r'. Please use linalg.qr(A, mode='reduced') if you are " + "going to differentiate through linalg.qr."); auto m = dA.size(-2); auto n = dA.size(-1); - TORCH_CHECK(reduced || m <= n, "The QR decomposition is not differentiable when " - "mode='complete' and nrows > ncols."); + TORCH_CHECK( + reduced || m <= n, + "The QR decomposition is not differentiable when " + "mode='complete' and nrows > ncols."); if (m >= n) { const auto sym = [](const Tensor& X) { return X + X.mH(); }; const auto syminv = [](const Tensor& X) { @@ -3302,7 +3840,8 @@ std::tuple linalg_qr_jvp( ret.diagonal(0, -2, -1).mul_(0.5); return ret; }; - auto dARinv = at::linalg_solve_triangular(R, dA, /*upper=*/true, /*left=*/false); + auto dARinv = + at::linalg_solve_triangular(R, dA, /*upper=*/true, /*left=*/false); auto dR = syminv(sym(Q.mH().matmul(dARinv))); auto dQ = dARinv - Q.matmul(dR); dR = dR.matmul(R); @@ -3323,12 +3862,16 @@ std::tuple linalg_qr_jvp( ret.diagonal(0, -2, -1).mul_(0.5); return ret; } else { - return X - X.mT() ; + return X - X.mT(); } }; auto QHdA = Q.mH().matmul(dA); - auto QHdA1Rinv = at::linalg_solve_triangular(R.narrow(-1, 0, m), QHdA.narrow(-1, 0, m), /*upper=*/true, /*left=*/false); + auto QHdA1Rinv = at::linalg_solve_triangular( + R.narrow(-1, 0, m), + QHdA.narrow(-1, 0, m), + /*upper=*/true, + /*left=*/false); auto dQ = triliminv(trilim(QHdA1Rinv)); auto dR = QHdA - dQ.matmul(R); dQ = Q.matmul(dQ); @@ -3336,9 +3879,12 @@ std::tuple linalg_qr_jvp( } } -Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, - const Tensor& Q, const Tensor& R, - const c10::string_view mode) { +Tensor linalg_qr_backward( + const Tensor& gQ, + const Tensor& gR, + const Tensor& Q, + const Tensor& R, + const c10::string_view mode) { // Nb. We won't be too formal below, as writing this proof formaly is a pain // We'll link here a formal writing of all this at some point in the future // @@ -3346,11 +3892,10 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, // dQ = dAR^{-1} - Qsyminv(sym(Q^H dA R^{-1})) // dR = syminv(sym(Q^H dA R^{-1}))R // - // With the notation from the JVP formla, the only two computations that we need are - // syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R)) - // sym*(X) = 2 * X - // Using these, after a few simplifications we get that - // gA = (gQ + syminvadj(triu(gR R^H - Q^H gQ)))R^{-H} + // With the notation from the JVP formla, the only two computations that we + // need are syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R)) sym*(X) = + // 2 * X Using these, after a few simplifications we get that gA = (gQ + + // syminvadj(triu(gR R^H - Q^H gQ)))R^{-H} // // Case m < n // dR = Q^H dA - Q^H dQ R @@ -3372,15 +3917,19 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); - TORCH_CHECK(compute_q, "The derivative of linalg.qr depends on Q, which is not computed when " - "mode='r'. Please use linalg.qr(A, mode='reduced') if you are " - "going to differentiate through linalg.qr."); + TORCH_CHECK( + compute_q, + "The derivative of linalg.qr depends on Q, which is not computed when " + "mode='r'. Please use linalg.qr(A, mode='reduced') if you are " + "going to differentiate through linalg.qr."); auto m = Q.size(-2); auto n = R.size(-1); - TORCH_CHECK(reduced || m <= n, "The QR decomposition is not differentiable when " - "mode='complete' and nrows > ncols."); + TORCH_CHECK( + reduced || m <= n, + "The QR decomposition is not differentiable when " + "mode='complete' and nrows > ncols."); if (!gQ.defined() && !gR.defined()) { return {}; @@ -3394,7 +3943,7 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, gA = -Q.mH().matmul(gQ); } } else { - gA = gR.matmul(R.mH()); + gA = gR.matmul(R.mH()); } if (m >= n) { const auto syminvadj = [](const Tensor& X) { @@ -3406,7 +3955,8 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, if (gQ.defined()) { gA = gA + gQ; } - gA = at::linalg_solve_triangular(R.mH(), gA, /*upper*/false, /*left*/false); + gA = at::linalg_solve_triangular( + R.mH(), gA, /*upper*/ false, /*left*/ false); return gA; } else { auto trilImInvAdjSkew = [](const Tensor& X) { @@ -3417,7 +3967,8 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, return ret; }; gA = Q.matmul(trilImInvAdjSkew(-gA)); - gA = at::linalg_solve_triangular(R.narrow(-1, 0, m).mH(), gA, /*upper*/false, /*left*/false); + gA = at::linalg_solve_triangular( + R.narrow(-1, 0, m).mH(), gA, /*upper*/ false, /*left*/ false); auto shape = R.sizes().vec(); shape.end()[-1] = n - m; gA = at::cat({gA, gA.new_zeros(shape)}, /*dim=*/-1); @@ -3436,12 +3987,14 @@ Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, template Tensor differential_analytic_matrix_function( - const Tensor& self, const Tensor& grad, + const Tensor& self, + const Tensor& grad, const func_t& matrix_function, - const bool adjoint // Choose between forward (adjoint=false) or backward AD (adjoint=true) - ) { - // Given an analytic matrix function, this computes the differential (forward AD) - // or the adjoint of the differential (backward AD) + const bool adjoint // Choose between forward (adjoint=false) or backward AD + // (adjoint=true) +) { + // Given an analytic matrix function, this computes the differential (forward + // AD) or the adjoint of the differential (backward AD) auto A = adjoint ? self.transpose(-2, -1).conj() : self; auto meta_grad_sizes = A.sizes().vec(); meta_grad_sizes[A.dim() - 2] *= 2; @@ -3456,18 +4009,23 @@ Tensor differential_analytic_matrix_function( return matrix_function(meta_grad).narrow(-2, 0, n).narrow(-1, n, n); } -Tensor linalg_matrix_exp_differential(const Tensor& self, const Tensor& grad, bool adjoint) { +Tensor linalg_matrix_exp_differential( + const Tensor& self, + const Tensor& grad, + bool adjoint) { at::NoTF32Guard disable_tf32; - return differential_analytic_matrix_function(self, grad, at::linalg_matrix_exp, /* adjoint */ adjoint); + return differential_analytic_matrix_function( + self, grad, at::linalg_matrix_exp, /* adjoint */ adjoint); } -Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) { +Tensor det_backward(const Tensor& grad, const Tensor& self, const Tensor& det) { if (self.numel() == 0) { return at::empty_like(self); } - auto det_backward_nonsingular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { + auto det_backward_nonsingular = + [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { // Derived from Jacobi's formula for partial derivative, which can be found // at https://en.wikipedia.org/wiki/Jacobi%27s_formula // i.e. if A is the input matrix, then @@ -3480,7 +4038,8 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) return at::linalg_solve(self.mH(), d); }; - auto det_backward_singular = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { + auto det_backward_singular = + [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { // Derived from the gradient formula that would be used if `self`'s // determinant is calculated using SVD, like so: // u, s, vh = svd(self) @@ -3496,7 +4055,8 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) auto u_det_grad = grad * (vh_det * s_prod).conj(); auto u_grad = det_backward_nonsingular(u_det_grad, u, u_det); - auto s_prod_grad = handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj()); + auto s_prod_grad = + handle_r_to_c(s_prod.scalar_type(), grad * (u_det * vh_det).conj()); auto s_grad = prod_backward(s_prod_grad, s, s_prod, -1, false); auto vh_det_grad = grad * (u_det * s_prod).conj(); @@ -3527,23 +4087,26 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) Tensor self_grad = self.new_empty(self.sizes(), self.options()); - auto nonzero_det_list = at::native::toListOfOptionalTensors(nonzero_det_mask); + auto nonzero_det_list = + at::native::toListOfOptionalTensors(nonzero_det_mask); self_grad.index_put_( - /*indices=*/nonzero_det_list, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/det_backward_nonsingular( - grad.index(nonzero_det_list), - self.index(nonzero_det_list), - det.index(nonzero_det_list))); + /*indices=*/nonzero_det_list, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + det_backward_nonsingular( + grad.index(nonzero_det_list), + self.index(nonzero_det_list), + det.index(nonzero_det_list))); auto zero_det_list = at::native::toListOfOptionalTensors(zero_det_mask); self_grad.index_put_( - /*indices=*/zero_det_list, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/det_backward_singular( - grad.index(zero_det_list), - self.index(zero_det_list), - det.index(zero_det_list))); + /*indices=*/zero_det_list, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + det_backward_singular( + grad.index(zero_det_list), + self.index(zero_det_list), + det.index(zero_det_list))); return self_grad; } @@ -3552,12 +4115,11 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) // The backward for this function is just a specialized version of // lu.backward, which is implemented in /torch/_autograd_functions.py Tensor _det_lu_based_helper_backward( - const Tensor& det_grad, - const Tensor& det, - const Tensor& self, - const Tensor& lu, - const Tensor& pivs -) { + const Tensor& det_grad, + const Tensor& det, + const Tensor& self, + const Tensor& lu, + const Tensor& pivs) { if (!self.numel()) { return at::zeros_like(self, at::MemoryFormat::Contiguous); } @@ -3565,17 +4127,17 @@ Tensor _det_lu_based_helper_backward( return Tensor(); } - // run det_backward only if backward is run on _det_lu_based_helper_backward. - // _det_lu_based_helper_backward is more stable for forward det computing functions, - // but it fails with double backward gradient checks (gradgradcheck). - // det_backward, on the other hand, is less stable (due to restrictions on svd_backward, - // namely, svd_backward requries distinct singular values which are sufficiently different - // from each other), yet, if its computation is stable, so is its double backward. - // Hence, if only single backward is run, we use _det_lu_based_helper_backward, - // for the double backward case we use det_backward. The latter approach could produce - // unstable gradients, therefore we DO NOT recommend double backpropagation through - // det computing functions. + // _det_lu_based_helper_backward is more stable for forward det computing + // functions, but it fails with double backward gradient checks + // (gradgradcheck). det_backward, on the other hand, is less stable (due to + // restrictions on svd_backward, namely, svd_backward requries distinct + // singular values which are sufficiently different from each other), yet, if + // its computation is stable, so is its double backward. Hence, if only single + // backward is run, we use _det_lu_based_helper_backward, for the double + // backward case we use det_backward. The latter approach could produce + // unstable gradients, therefore we DO NOT recommend double backpropagation + // through det computing functions. if (at::GradMode::is_enabled()) { return det_backward(det_grad, self, det); } @@ -3583,11 +4145,16 @@ Tensor _det_lu_based_helper_backward( // we use a sequence of kernels to avoid memory copies and checks, // as in the implementation of this function we are 100% sure that // `lu` and `pivs` are coming from a LAPACK routine. - return at::_det_lu_based_helper_backward_helper(det_grad, det, self, lu, pivs); + return at::_det_lu_based_helper_backward_helper( + det_grad, det, self, lu, pivs); } -Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) { - auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { +Tensor logdet_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& logdet) { + auto singular_case_backward = [&](const Tensor& grad, + const Tensor& self) -> Tensor { Tensor u, sigma, vh; std::tie(u, sigma, vh) = at::linalg_svd(self, false); // logdet = \sum log(sigma) @@ -3595,7 +4162,8 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo return svd_backward({}, gsigma, {}, u, sigma, vh); }; - auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { + auto nonsingular_case_backward = [&](const Tensor& grad, + const Tensor& self) -> Tensor { return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().mT(); }; @@ -3607,42 +4175,55 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo return singular_case_backward(grad, self); } } else { - auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); + auto finite_logdet_indices = + at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); c10::optional first_finite_logdet_index = finite_logdet_indices[0]; - if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular) + if (first_finite_logdet_index->size(0) == + logdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad, self); } - auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); + auto neginf_logdet_indices = + at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); c10::optional first_neginf_logdet_index = neginf_logdet_indices[0]; - if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular) + if (first_neginf_logdet_index->size(0) == + logdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad, self); } Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // invertible case - grad_logdet.index_put_(/*indices=*/finite_logdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/nonsingular_case_backward(grad.index(finite_logdet_indices), - self.index(finite_logdet_indices))); + grad_logdet.index_put_( + /*indices=*/finite_logdet_indices, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + nonsingular_case_backward( + grad.index(finite_logdet_indices), + self.index(finite_logdet_indices))); // non-invertible case, uses SVD - grad_logdet.index_put_(/*indices=*/neginf_logdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/singular_case_backward(grad.index(neginf_logdet_indices), - self.index(neginf_logdet_indices))); + grad_logdet.index_put_( + /*indices=*/neginf_logdet_indices, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + singular_case_backward( + grad.index(neginf_logdet_indices), + self.index(neginf_logdet_indices))); return grad_logdet; } } -Tensor slogdet_backward(const Tensor& grad_logabsdet, - const Tensor& self, - const Tensor& signdet, const Tensor& logabsdet) { - auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { +Tensor slogdet_backward( + const Tensor& grad_logabsdet, + const Tensor& self, + const Tensor& signdet, + const Tensor& logabsdet) { + auto singular_case_backward = [&](const Tensor& grad_logabsdet, + const Tensor& self) -> Tensor { Tensor u, sigma, vh; std::tie(u, sigma, vh) = at::linalg_svd(self, false); Tensor v = vh.mH(); @@ -3653,46 +4234,60 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, return svd_backward({}, gsigma, {}, u, sigma, vh); }; - auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { + auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, + const Tensor& self) -> Tensor { // TODO: replace self.inverse with linalg_inverse - return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().mH(); + return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * + self.inverse().mH(); }; if (self.dim() == 2) { - bool is_singular = self.is_complex() ? signdet.abs().item() == 0 : signdet.item() == 0; + bool is_singular = self.is_complex() ? signdet.abs().item() == 0 + : signdet.item() == 0; if (is_singular) { return singular_case_backward(grad_logabsdet, self); } else { return nonsingular_case_backward(grad_logabsdet, self); } } else { - auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(self.is_complex() ? at::where(signdet.abs()) : at::where(signdet)); - c10::optional first_nonzero_signdet_index = nonzero_signdet_indices[0]; + auto nonzero_signdet_indices = at::native::toListOfOptionalTensors( + self.is_complex() ? at::where(signdet.abs()) : at::where(signdet)); + c10::optional first_nonzero_signdet_index = + nonzero_signdet_indices[0]; - if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) + if (first_nonzero_signdet_index->size(0) == + logabsdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad_logabsdet, self); } - auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0)); + auto zero_signdet_indices = + at::native::toListOfOptionalTensors(at::where(signdet == 0)); c10::optional first_zero_signdet_index = zero_signdet_indices[0]; - if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) + if (first_zero_signdet_index->size(0) == + logabsdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad_logabsdet, self); } Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // invertible case - grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices), - self.index(nonzero_signdet_indices))); + grad_slogdet.index_put_( + /*indices=*/nonzero_signdet_indices, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + nonsingular_case_backward( + grad_logabsdet.index(nonzero_signdet_indices), + self.index(nonzero_signdet_indices))); // non-invertible case, uses SVD - grad_slogdet.index_put_(/*indices=*/zero_signdet_indices, - // NOLINTNEXTLINE(bugprone-argument-comment) - /*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices), - self.index(zero_signdet_indices))); + grad_slogdet.index_put_( + /*indices=*/zero_signdet_indices, + // NOLINTNEXTLINE(bugprone-argument-comment) + /*value=*/ + singular_case_backward( + grad_logabsdet.index(zero_signdet_indices), + self.index(zero_signdet_indices))); return grad_slogdet; } @@ -3702,21 +4297,28 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf // Sec. 2.3.1 Matrix inverse product std::tuple triangular_solve_backward( - const Tensor & grad_x, const Tensor & grad_m, - const Tensor & b, const Tensor & a, const Tensor & x, - const bool upper, const bool transpose, const bool unitriangular, + const Tensor& grad_x, + const Tensor& grad_m, + const Tensor& b, + const Tensor& a, + const Tensor& x, + const bool upper, + const bool transpose, + const bool unitriangular, std::array output_mask) { at::NoTF32Guard disable_tf32; Tensor grad_b, grad_a; if (grad_x.defined() || grad_m.defined()) { if (grad_x.defined()) { - grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); + grad_b = std::get<0>( + grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); if (output_mask[1]) { - grad_a = transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH()); + grad_a = + transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH()); if (upper) { - grad_a = grad_a.triu((int) unitriangular); + grad_a = grad_a.triu((int)unitriangular); } else { - grad_a = grad_a.tril(-((int) unitriangular)); + grad_a = grad_a.tril(-((int)unitriangular)); } } } @@ -3734,18 +4336,22 @@ std::tuple triangular_solve_backward( } Tensor triangular_solve_jvp( - const Tensor& X, const Tensor& A, - const Tensor& dA, const Tensor& dB, - const bool upper, - const bool transpose, - const bool unitriangular -) { + const Tensor& X, + const Tensor& A, + const Tensor& dA, + const Tensor& dB, + const bool upper, + const bool transpose, + const bool unitriangular) { return generic_solve_jvp( - [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { - return std::get<0>(at::triangular_solve(dB - dA_contrib, A, upper, transpose, unitriangular)); - }, - X, A, dA, dB - ); + [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { + return std::get<0>(at::triangular_solve( + dB - dA_contrib, A, upper, transpose, unitriangular)); + }, + X, + A, + dA, + dB); } Tensor linalg_solve_triangular_forward_AD( @@ -3761,8 +4367,9 @@ Tensor linalg_solve_triangular_forward_AD( // For the derivation see: // [Note: Forward / Backward AD solve_triangular] const Tensor proj_A_t = upper ? A_t.triu(static_cast(unitriangular)) - : A_t.tril(- static_cast(unitriangular)); - const Tensor X_t = B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t)); + : A_t.tril(-static_cast(unitriangular)); + const Tensor X_t = + B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t)); return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular); } @@ -3796,26 +4403,28 @@ std::tuple linalg_solve_triangular_backward( // Note that you don't need to store B for forward nor backward // // These formulas work for a general solver of linear equations. - // Let's prove now that when A is triangular, G_A is the projection onto the triangular matrices - // of the formula above, i.e. simply taking triu (resp. tril) in the formula above. - // This is because, since the triangular matrices form a vector space, the tangent space at any - // point is itself the space of triangular matrices. The result follows from a reasoning as that - // at the end of [Note: eigh backward] - // Something similar happens for `unitriangular`, only that int his case the tangent space is - // the set of lower-triangular matrices with zeros on the diagonal. + // Let's prove now that when A is triangular, G_A is the projection onto the + // triangular matrices of the formula above, i.e. simply taking triu (resp. + // tril) in the formula above. This is because, since the triangular matrices + // form a vector space, the tangent space at any point is itself the space of + // triangular matrices. The result follows from a reasoning as that at the end + // of [Note: eigh backward] Something similar happens for `unitriangular`, + // only that int his case the tangent space is the set of lower-triangular + // matrices with zeros on the diagonal. if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) { - return std::make_tuple(Tensor{}, Tensor{}); + return std::make_tuple(Tensor{}, Tensor{}); } // We always need to comput G_B const Tensor A_H = A.mH(); - const Tensor G_B = at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular); + const Tensor G_B = + at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular); if (A_requires_grad) { const Tensor X_H = X.mH(); Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B); G_A = upper ? G_A.triu(static_cast(unitriangular)) - : G_A.tril(- static_cast(unitriangular)); + : G_A.tril(-static_cast(unitriangular)); return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{}); } else { return std::make_tuple(Tensor{}, G_B); @@ -3823,8 +4432,11 @@ std::tuple linalg_solve_triangular_backward( } std::tuple cholesky_solve_backward( - const Tensor& grad_x, const Tensor& self, - const Tensor& input2, const Tensor& result, const bool upper) { + const Tensor& grad_x, + const Tensor& self, + const Tensor& input2, + const Tensor& result, + const bool upper) { at::NoTF32Guard disable_tf32; Tensor grad_self, grad_input2; if (grad_x.defined()) { @@ -3843,25 +4455,28 @@ std::tuple cholesky_solve_backward( } Tensor cholesky_solve_jvp( - const Tensor& X, - const Tensor& U, - const Tensor& dU, - const Tensor& dB, - const bool upper -) { + const Tensor& X, + const Tensor& U, + const Tensor& dU, + const Tensor& dB, + const bool upper) { at::NoTF32Guard disable_tf32; - auto dK = upper ? dU.mH().matmul(U) - : dU.matmul(U.mH()); + auto dK = upper ? dU.mH().matmul(U) : dU.matmul(U.mH()); auto dA = dK + dK.mH(); return generic_solve_jvp( - [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { - return at::cholesky_solve(dB - dA_contrib, A, upper); - }, - X, /*A=*/U, dA, dB - ); + [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { + return at::cholesky_solve(dB - dA_contrib, A, upper); + }, + X, + /*A=*/U, + dA, + dB); } -Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) { +Tensor fft_c2r_backward( + const Tensor& grad, + IntArrayRef dim, + int64_t normalization) { // Forward is C2R (onesided) // Think of onesided C2R irfft as // 1. fill the other half by conjugate symmetry @@ -3892,14 +4507,18 @@ Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalizati auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true); auto double_length = grad.size(dim.back()) - gI.size(dim.back()); - if (double_length > 0) { // also covers case when signal size is zero + if (double_length > 0) { // also covers case when signal size is zero gI.narrow(dim.back(), 1, double_length).mul_(2); } return gI; } -Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, - bool onesided, int64_t last_dim_size) { +Tensor fft_r2c_backward( + const Tensor& grad, + IntArrayRef dim, + int64_t normalization, + bool onesided, + int64_t last_dim_size) { if (!onesided) { return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false)); } @@ -3920,15 +4539,17 @@ Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalizati new_grad_shape[last_dim] = last_dim_size; const auto zero_length = last_dim_size - grad.size(dim.back()); - auto complex_full_grad = zero_length > 0 ? grad.new_zeros(new_grad_shape) : grad; + auto complex_full_grad = + zero_length > 0 ? grad.new_zeros(new_grad_shape) : grad; if (zero_length > 0) { complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad); } - return at::real(at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false)); + return at::real( + at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false)); } // Helper for batchnorm_double_backward -Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) { +Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) { auto r = to_sum.sum(0, keepdim); int64_t start_point_exclusive = keepdim ? 1 : 0; for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) { @@ -3963,20 +4584,19 @@ Tensor expand_as_dim1(const Tensor& src, const Tensor& target) { } std::tuple batchnorm_double_backward( - const Tensor & input, - const c10::optional & gamma, - const Tensor & ggI, - const Tensor & ggG, - const Tensor & ggB, - const Tensor & gO, - const c10::optional & running_mean, - const c10::optional & running_var, + const Tensor& input, + const c10::optional& gamma, + const Tensor& ggI, + const Tensor& ggG, + const Tensor& ggB, + const Tensor& gO, + const c10::optional& running_mean, + const c10::optional& running_var, bool training, double eps, - const c10::optional & save_mean, - const c10::optional & save_invstd, - std::array output_mask) { - + const c10::optional& save_mean, + const c10::optional& save_invstd, + std::array output_mask) { bool affine = isDefined(gamma); // TODO: Do we have a ScalarOrTensor type? Would such a thing exist? Tensor gamma_expanded; @@ -4000,7 +4620,10 @@ std::tuple batchnorm_double_backward( } // for half inputs, save_mean, save_invstd are float (ideally, we would cast // everything else, but not now) - auto mu = unsqueeze_dim1(training ? toNonOptTensor(save_mean).to(input.scalar_type()) : toNonOptTensor(running_mean), input); + auto mu = unsqueeze_dim1( + training ? toNonOptTensor(save_mean).to(input.scalar_type()) + : toNonOptTensor(running_mean), + input); auto input_sub_mu = input - mu; auto sigma2_eps_neg_1_2 = unsqueeze_dim1( training ? toNonOptTensor(save_invstd).to(input.scalar_type()) @@ -4018,11 +4641,15 @@ std::tuple batchnorm_double_backward( if (ggI.defined() && training) { auto ggI_sum = sum_exclude_dim1(ggI); auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); - auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_( - (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M)); + auto all_sub = + ((ggI_sum * gO_sum).div_(M)) + .sub_(sum_exclude_dim1(gO * ggI)) + .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M)); auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); - auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); - auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI); + auto gI_1t = + (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); + auto gI_2t = + (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI); gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t)); } @@ -4032,7 +4659,8 @@ std::tuple batchnorm_double_backward( if (training) { auto t0 = gO * sigma2_eps_neg_1_2; auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M); - auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M); + auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)) + .div_(-M); gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2)); gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; } else { @@ -4042,10 +4670,14 @@ std::tuple batchnorm_double_backward( } // this is the first backward's grad_input - auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor { + auto first_back_grad_input = [&](const Tensor& gO, + const Tensor& gamma) -> Tensor { auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M); - auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_( - input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu)); + auto h1 = (M * gO) + .sub_(sum_exclude_dim1(gO)) + .sub_( + input_sub_mu.mul(sigma2_eps_neg_1) * + sum_exclude_dim1(gO * input_sub_mu)); return h0 * h1; }; @@ -4053,8 +4685,10 @@ std::tuple batchnorm_double_backward( Tensor gG; if (affine && ggI.defined()) { if (training) { - // gG is just the first backwards with the gamma term removed (then shaped properly) - gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options())); + // gG is just the first backwards with the gamma term removed (then shaped + // properly) + gG = ggI * + first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options())); gG = sum_exclude_dim1(gG, false); } else { gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false); @@ -4085,7 +4719,6 @@ std::tuple batchnorm_double_backward( } return std::tuple{gI, gG, ggO}; - } std::tuple layer_norm_double_backward( @@ -4099,7 +4732,6 @@ std::tuple layer_norm_double_backward( const Tensor& save_invstd_t, IntArrayRef normalized_shape, std::array output_mask) { - const int normalized_ndim = normalized_shape.size(); const auto input_shape = input_t.sizes(); const auto input_ndim = input_t.dim(); @@ -4109,7 +4741,7 @@ std::tuple layer_norm_double_backward( c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); const int64_t N = c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend()); - //printf("M: %ld, N: %ld", M, N); + // printf("M: %ld, N: %ld", M, N); auto input = input_t.reshape({M, N}); auto gO = gO_t.reshape({M, N}); @@ -4149,7 +4781,6 @@ std::tuple layer_norm_double_backward( auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2; if (ggI.defined()) { - auto gxhat = gO * gamma_expanded; auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true); auto gxhat_sum = gxhat.sum(1, true); @@ -4157,11 +4788,15 @@ std::tuple layer_norm_double_backward( auto ggI_sum = ggI_expanded.sum(1, true); auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true); - auto all_sub = ((ggI_sum * gxhat_sum).div_(N)).sub_((ggI_expanded * gxhat).sum(1, true)).add_( - (sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N)); + auto all_sub = + ((ggI_sum * gxhat_sum).div_(N)) + .sub_((ggI_expanded * gxhat).sum(1, true)) + .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N)); auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N); - auto gI_1t = (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat); - auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * (ggI_sum.div(N) - ggI_expanded); + auto gI_1t = + (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat); + auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * + (ggI_sum.div(N) - ggI_expanded); gI = (gI_0t.add_(gI_1t).add_(gI_2t)); } @@ -4170,29 +4805,35 @@ std::tuple layer_norm_double_backward( if (affine && ggG.defined()) { auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2; auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N); - auto t2 = (input_mu_sigma2_neg_3_2 * (gO * ggG_expanded * input_sub_mu).sum(1,true)).div_(-N); + auto t2 = (input_mu_sigma2_neg_3_2 * + (gO * ggG_expanded * input_sub_mu).sum(1, true)) + .div_(-N); auto gI_G_term = t0.add_(t1).add_(t2); gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; } - if (gI.defined()) { - //printf("=== computing gI\n"); + // printf("=== computing gI\n"); gI = gI.reshape_as(input_t); } // this is the grad_input for the first backward function - auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, const Tensor& gamma_local) -> Tensor { + auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, + const Tensor& gamma_local) -> Tensor { auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N); - auto h1 = (N * gO_local).sub_(gO_local.sum(1,true)).sub_( - input_sub_mu.mul(sigma2_eps_neg_1) * (gO_local * input_sub_mu).sum(1,true)); + auto h1 = (N * gO_local) + .sub_(gO_local.sum(1, true)) + .sub_( + input_sub_mu.mul(sigma2_eps_neg_1) * + (gO_local * input_sub_mu).sum(1, true)); return h0 * h1; }; // calculate gG Tensor gG; if (affine && ggI.defined()) { - gG = first_bwd_fn_grad_input(ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options())); + gG = first_bwd_fn_grad_input( + ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options())); gG = (gO * gG).sum(0); gG = gG.reshape_as(*gamma); } @@ -4302,7 +4943,9 @@ infinitely_differentiable_native_group_norm_backward( } } if (grad_input_mask[1] && dY.defined()) { - dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toNonOptTensor(gamma)); + dgamma = ((ds - db * mean_tensor) * rstd_tensor) + .sum(0) + .reshape_as(toNonOptTensor(gamma)); } if (grad_input_mask[2] && dY.defined()) { dbeta = db.sum(0).reshape_as(toNonOptTensor(gamma)); @@ -4311,17 +4954,27 @@ infinitely_differentiable_native_group_norm_backward( return std::make_tuple(dX, dgamma, dbeta); } -std::tuple _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3, - IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, - IntArrayRef sumdim, std::array grad_mask) { +std::tuple _trilinear_backward( + const Tensor& grad_out, + const Tensor& i1, + const Tensor& i2, + const Tensor& i3, + IntArrayRef expand1, + IntArrayRef expand2, + IntArrayRef expand3, + IntArrayRef sumdim, + std::array grad_mask) { Tensor grad_i1, grad_i2, grad_i3; if (grad_out.defined()) { if (grad_mask[0]) - grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1); + grad_i1 = + at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1); if (grad_mask[1]) - grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2); + grad_i2 = + at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2); if (grad_mask[2]) - grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3); + grad_i3 = + at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3); } return std::tuple(grad_i1, grad_i2, grad_i3); } @@ -4329,10 +4982,10 @@ std::tuple _trilinear_backward(const Tensor& grad_out, c Tensor log1p_backward(const Tensor& grad, const Tensor& self) { if (self.is_sparse()) { AT_ERROR( - "log1p of a sparse tensor is made to be non-differentiable since ", - "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ", - "Use a different mathematical operation which preserves sparsity of gradients, ", - "or report a bug if you think this is an error."); + "log1p of a sparse tensor is made to be non-differentiable since ", + "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ", + "Use a different mathematical operation which preserves sparsity of gradients, ", + "or report a bug if you think this is an error."); } return grad / (self + 1).conj(); } @@ -4340,23 +4993,34 @@ Tensor log1p_backward(const Tensor& grad, const Tensor& self) { Tensor sinc_backward(const Tensor& grad, const Tensor& self) { auto self_pi = self * M_PI; auto self_squared_pi = self * self * M_PI; - auto out = grad * ((self_pi * self_pi.cos() - self_pi.sin()) / self_squared_pi).conj(); + auto out = grad * + ((self_pi * self_pi.cos() - self_pi.sin()) / self_squared_pi).conj(); return at::where(self_squared_pi == 0.0, at::zeros({}, grad.options()), out); } -Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices) { +Tensor sparse_constructor_values_backward( + const Tensor& sparse_grad_out, + const Tensor& indices) { return _sparse_mask_helper(sparse_grad_out.coalesce(), indices.contiguous()); } -// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads]) +// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p +// in pads]) Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { auto negated_pad = pad.vec(); // NOLINTNEXTLINE(modernize-use-transparent-functors) - std::transform(negated_pad.cbegin(), negated_pad.cend(), negated_pad.begin(), std::negate()); + std::transform( + negated_pad.cbegin(), + negated_pad.cend(), + negated_pad.begin(), + std::negate()); return at::constant_pad_nd(grad, negated_pad, 0); } -Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { +Tensor embedding_dense_double_backward( + const Tensor& grad, + const Tensor& indices, + int64_t padding_idx) { // since first backward takes care of scaling by frequency, // we don't need to worry about it here. auto gg_weight = grad.index_select(0, indices.reshape(-1)); @@ -4371,14 +5035,21 @@ Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indic return gg_weight.view(size); } -Tensor index_backward(Tensor zeros_like_self, const torch::List>& indices, const Tensor& grad) { +Tensor index_backward( + Tensor zeros_like_self, + const torch::List>& indices, + const Tensor& grad) { return (areAnyTensorSubclassLike({zeros_like_self, grad}) || areAnyOptionalTensorSubclassLike(indices)) ? zeros_like_self.index_put(indices, grad, true) : at::_index_put_impl_(zeros_like_self, indices, grad, true, true); } -Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) { +Tensor _cudnn_ctc_loss_backward( + const Tensor& grad_out, + const Tensor& loss, + const Tensor& raw_grad, + bool zero_infinity) { if (zero_infinity) { return at::where( loss.unsqueeze(0).unsqueeze(2) == 0, @@ -4400,66 +5071,75 @@ bool any_variable_defined(const variable_list& variables) { // Derivations for the householder_product.backward method. // -// Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1, ..., tau_k, -// the torch.linalg.householder_product computes the firt n columns of the following product: -// Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k v_k^H). -// Let -// H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(sigma_1) ... H_k(sigma_k)[:, :k]; -// H_i_minus = H_1(tau_1) ... H_{i - 1}(tau_{i - 1}), with H_1_minus := I; -// H_i_plus = H_{i + 1}(tau_{i + 1}) ... H_k(tau_k) with H_k_plus := I; +// Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1, +// ..., tau_k, the torch.linalg.householder_product computes the firt n columns +// of the following product: Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k +// v_k^H). Let +// H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(sigma_1) ... +// H_k(sigma_k)[:, :k]; H_i_minus = H_1(tau_1) ... H_{i - 1}(tau_{i - 1}), +// with H_1_minus := I; H_i_plus = H_{i + 1}(tau_{i + 1}) ... H_k(tau_k) +// with H_k_plus := I; // // Forward AD: -// dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H) H_i_plus. +// dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i +// v_i dv_i^H) H_i_plus. // // Backward AD: -// Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)). -// Let K_i := H_i_plus Q_grad^H H_i_minus, then the gradients are -// v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i v_i, -// tau_i_grad = Tr(-v_i^H K_i v_i).conj(). -// NOTE: the algorithms ignores that only n columns of Q are observed, so there is no need in -// recomputing Q to full completion. -// -// Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so we can compute v_i_grad, tau_i_grad one by one -// by just efficiently updating K_i if that is possible. Multiplying with H_i from the right could be -// done with matrix-vector products, but what about the inverse H_{i + 1}^{-1} and does it even exist? -// Luckily, under some assumptions, H_{i + 1}^{-1} exists and admits a representation as H_i(sigma_i) for some -// sigma_i, so the left update is also could be done with matrix-vector and not matrix-matrix products. +// Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i +// v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)). Let K_i := H_i_plus Q_grad^H +// H_i_minus, then the gradients are v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i +// v_i, tau_i_grad = Tr(-v_i^H K_i v_i).conj(). NOTE: the algorithms ignores +// that only n columns of Q are observed, so there is no need in recomputing Q +// to full completion. +// +// Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so we can compute v_i_grad, +// tau_i_grad one by one by just efficiently updating K_i if that is possible. +// Multiplying with H_i from the right could be done with matrix-vector +// products, but what about the inverse H_{i + 1}^{-1} and does it even exist? +// Luckily, under some assumptions, H_{i + 1}^{-1} exists and admits a +// representation as H_i(sigma_i) for some sigma_i, so the left update is also +// could be done with matrix-vector and not matrix-matrix products. // // Let H(tau) := I - tau v v^H. -// H(tau) has eigenvalues 1 with multiplicity (m - 1) with eigenvectors orthogonal to v, -// and an eigenvalue (1 - tau ||v||^2) with the corresponding eigenvector v / ||v||. -// If (1 - tau ||v||^2) != 0, H(tau) is invertible. -// If (1 - tau ||v||^2) != 0, then with sigma defined as -// sigma := tau / (||v||^2 tau - 1) we get that -// H(tau) H(sigma) = H(sigma) H(tau) = I, so H(sigma) is the inverse of H(tau). +// H(tau) has eigenvalues 1 with multiplicity (m - 1) with eigenvectors +// orthogonal to v, and an eigenvalue (1 - tau ||v||^2) with the corresponding +// eigenvector v / ||v||. If (1 - tau ||v||^2) != 0, H(tau) is invertible. If (1 +// - tau ||v||^2) != 0, then with sigma defined as sigma := tau / (||v||^2 tau - +// 1) we get that H(tau) H(sigma) = H(sigma) H(tau) = I, so H(sigma) is the +// inverse of H(tau). // // WARNING: the algorithm below assumes that H_i(tau_i) are all invertible, so // it expects that (1 - tau_i ||v_i||^2) != 0 for all i. -// We would like to point out that if there is H_i(tau_i) which is not invertible, -// the householder_product is still differentiable! We will not be able to compute K_i -// efficiently in such cases, however, as evaluating of each K_i will amount to calls -// to ORGQR to be able to compute H_i_plus. +// We would like to point out that if there is H_i(tau_i) which is not +// invertible, the householder_product is still differentiable! We will not be +// able to compute K_i efficiently in such cases, however, as evaluating of each +// K_i will amount to calls to ORGQR to be able to compute H_i_plus. // This function computes either the product between -// (I - tau u v^H) and K (in-place or not) with `condition_with_I = true`, or between +// (I - tau u v^H) and K (in-place or not) with `condition_with_I = true`, or +// between // (-tau u v^H) and K (out-of-place only) with `condition_with_I = false`. // Parameter `left` controls whether the matrix K is multiplied from the left or // from the right. -// Additionally, when the computation is done in-place, we exploit that the first -// `k` coordinates of `u_full/v_full` are zeros. +// Additionally, when the computation is done in-place, we exploit that the +// first `k` coordinates of `u_full/v_full` are zeros. Tensor apply_simple_transformation( - int64_t m, int64_t k, - const Tensor& u_full, const Tensor& v_full, - const Tensor& t, Tensor& K, + int64_t m, + int64_t k, + const Tensor& u_full, + const Tensor& v_full, + const Tensor& t, + Tensor& K, bool modify_K_in_place = true, bool condition_with_I = true, bool left = true) { - // we assume u_full is a vector of dimension (..., m, 1), t is a scalar of dimension (..., 1) + // we assume u_full is a vector of dimension (..., m, 1), t is a scalar of + // dimension (..., 1) - // TODO: matrix-vector products in the code below are dispatched to matrix-matrix products. - // We either need to extend matmul to support batched matrix-vector products, or - // implement a batched variant of mv. - // We could enable mv for inputs which are not batched, but it is not done to eliminate + // TODO: matrix-vector products in the code below are dispatched to + // matrix-matrix products. We either need to extend matmul to support batched + // matrix-vector products, or implement a batched variant of mv. We could + // enable mv for inputs which are not batched, but it is not done to eliminate // the code duplication. // returns (I - t u v^H) K or -t u v^H K @@ -4469,8 +5149,7 @@ Tensor apply_simple_transformation( auto u = v_full.narrow(-2, k, m - k).mH().matmul(K.narrow(-2, k, m - k)); K.narrow(-2, k, m - k).sub_((t.unsqueeze(-1) * v) * u); return K; - } - else { + } else { auto transformation = (t.unsqueeze(-1) * u_full) * v_full.mH().matmul(K); return condition_with_I ? K - transformation : -transformation; } @@ -4479,18 +5158,22 @@ Tensor apply_simple_transformation( else { if (modify_K_in_place) { auto v = u_full.narrow(-2, k, m - k); - auto u = K.narrow(-1, k, m - k).matmul(t.unsqueeze(-1) * v_full.narrow(-2, k, m - k)); + auto u = K.narrow(-1, k, m - k) + .matmul(t.unsqueeze(-1) * v_full.narrow(-2, k, m - k)); K.narrow(-1, k, m - k).sub_(u * v.mH()); return K; - } - else { + } else { auto transformation = K.matmul(t.unsqueeze(-1) * u_full) * v_full.mH(); return condition_with_I ? K - transformation : -transformation; } } }; -std::tuple householder_product_backward(const Tensor& grad, const Tensor& result, const Tensor& input_, const Tensor& tau) { +std::tuple householder_product_backward( + const Tensor& grad, + const Tensor& result, + const Tensor& input_, + const Tensor& tau) { if (!grad.defined() || !input_.numel() || !tau.numel()) { return std::tuple(Tensor(), Tensor()); } @@ -4509,51 +5192,67 @@ std::tuple householder_product_backward(const Tensor& grad, cons // compute sigma such that // H(sigma_i) == H(tau_i)^{-1}. // If the input to householder_product comes from GEQRF, - // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be invertible. - // This follows from the documentation https://www.netlib.org/lapack/lug/node128.html, - // and tau always satisfying the condition |tau|^2 ||v||^2 == 2 * Re(tau). + // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be + // invertible. This follows from the documentation + // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying + // the condition |tau|^2 ||v||^2 == 2 * Re(tau). auto input_first_k_cols = input.narrow(-1, 0, k); - auto input_first_k_cols_norm_squared = ( - input_first_k_cols * input_first_k_cols.conj() - ).sum(-2); + auto input_first_k_cols_norm_squared = + (input_first_k_cols * input_first_k_cols.conj()).sum(-2); auto sigma = tau / (tau * input_first_k_cols_norm_squared - 1.0); auto K = result.matmul(grad.mH()); - // The algorithm updates K by multiplying it from the left/right with Householder reflectors. - // If only single backward is run, we modify K in-place and exploit triangularity of the input. - // With higher order derivatives we cannot rewrite the storage of K, hence we use much less efficient - // out-of-place methods. + // The algorithm updates K by multiplying it from the left/right with + // Householder reflectors. If only single backward is run, we modify K + // in-place and exploit triangularity of the input. With higher order + // derivatives we cannot rewrite the storage of K, hence we use much less + // efficient out-of-place methods. // - // if only first-order derivative is expected, we can modify K in-place for better performance + // if only first-order derivative is expected, we can modify K in-place for + // better performance bool modify_K_in_place = !at::GradMode::is_enabled(); - // This method exploites that at k-th iteration vector v_k has only elements v_k[k:] which are non-zero. - auto update_grad = [&m](int64_t k, const Tensor& v_full, const Tensor& t, const Tensor& K) -> std::tuple { - // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension (..., 1) + // This method exploites that at k-th iteration vector v_k has only elements + // v_k[k:] which are non-zero. + auto update_grad = [&m]( + int64_t k, + const Tensor& v_full, + const Tensor& t, + const Tensor& K) -> std::tuple { + // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension + // (..., 1) auto v = v_full.narrow(-2, k, m - k); auto vHK = v.mH().matmul(K.narrow(-2, k, m - k)); auto Kv = K.narrow(-1, k, m - k).matmul(v); auto t_unsqueezed = t.unsqueeze(-1); - auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) - (t_unsqueezed * Kv).squeeze(-1); + auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) - + (t_unsqueezed * Kv).squeeze(-1); auto tau_grad = -(vHK.narrow(-1, k, m - k).matmul(v)).conj(); return std::make_tuple(v_grad, tau_grad.squeeze(-1)); }; auto apply_householder_reflector = [m, modify_K_in_place]( - int64_t k, const Tensor& v_full, - const Tensor& t, Tensor& K, - bool left = true) -> Tensor { + int64_t k, + const Tensor& v_full, + const Tensor& t, + Tensor& K, + bool left = true) -> Tensor { return apply_simple_transformation( - m, k, v_full, v_full, t, K, modify_K_in_place, /*condition_with_I=*/true, left - ); + m, + k, + v_full, + v_full, + t, + K, + modify_K_in_place, + /*condition_with_I=*/true, + left); }; // K <- H_0^{-1} @ K K = apply_householder_reflector( - 0, input.narrow(-1, 0, 1), sigma.narrow(-1, 0, 1), - K, /*left=*/true - ); + 0, input.narrow(-1, 0, 1), sigma.narrow(-1, 0, 1), K, /*left=*/true); for (const auto i : c10::irange(k)) { // NOTE: narrow will unsqueeze(-1) auto v_i = input.narrow(-1, i, 1); @@ -4569,13 +5268,8 @@ std::tuple householder_product_backward(const Tensor& grad, cons auto v_i_next = input.narrow(-1, i + 1, 1); auto s_i_next = sigma.narrow(-1, i + 1, 1); K = apply_householder_reflector( - i + 1, v_i_next, s_i_next, - K, /*left=*/true - ); - K = apply_householder_reflector( - i, v_i, t_i, - K, /*left=*/false - ); + i + 1, v_i_next, s_i_next, K, /*left=*/true); + K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/false); } } @@ -4586,14 +5280,14 @@ std::tuple householder_product_backward(const Tensor& grad, cons return std::make_tuple(input_grad, tau_grad); } -// We refer to the derivations described above the method `apply_simple_transformation` +// We refer to the derivations described above the method +// `apply_simple_transformation` Tensor householder_product_jvp( const Tensor& dV_, const Tensor& dtau, const Tensor& prod, const Tensor& V_, - const Tensor& tau -) { + const Tensor& tau) { auto m = V_.size(-2); auto k = tau.size(-1); @@ -4606,41 +5300,56 @@ Tensor householder_product_jvp( // compute sigma such that // H(sigma_i) == H(tau_i)^{-1}. // If the input to householder_product comes from GEQRF, - // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be invertible. - // This follows from the documentation https://www.netlib.org/lapack/lug/node128.html, - // and tau always satisfying the condition |tau|^2 ||v||^2 == 2 * Re(tau). + // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be + // invertible. This follows from the documentation + // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying + // the condition |tau|^2 ||v||^2 == 2 * Re(tau). auto V_first_k_cols = V.narrow(-1, 0, k); - auto V_first_k_cols_norm_squared = ( - V_first_k_cols * V_first_k_cols.conj() - ).sum(-2); + auto V_first_k_cols_norm_squared = + (V_first_k_cols * V_first_k_cols.conj()).sum(-2); auto sigma = tau / (tau * V_first_k_cols_norm_squared - 1.0); - auto apply_householder_reflector = [m]( - const Tensor& v_full, - const Tensor& t, Tensor& K, - bool left = true) -> Tensor { + auto apply_householder_reflector = [m](const Tensor& v_full, + const Tensor& t, + Tensor& K, + bool left = true) -> Tensor { return apply_simple_transformation( - // setting `modify_K_in_place = true` causes CUDA memory leaks in OpInfo tests of forward AD - // for that reason we ignore `k` by passing zero - m, /*k=*/0, v_full, v_full, t, K, /*modify_K_in_place=*/false, /*condition_with_I=*/true, left - ); + // setting `modify_K_in_place = true` causes CUDA memory leaks in OpInfo + // tests of forward AD for that reason we ignore `k` by passing zero + m, + /*k=*/0, + v_full, + v_full, + t, + K, + /*modify_K_in_place=*/false, + /*condition_with_I=*/true, + left); }; // computes (-t u v^H) K - auto apply_simple_product = [m]( - const Tensor& u_full, const Tensor& v_full, - const Tensor& t, Tensor& K - ) -> Tensor { + auto apply_simple_product = [m](const Tensor& u_full, + const Tensor& v_full, + const Tensor& t, + Tensor& K) -> Tensor { return apply_simple_transformation( - // since ``modify_K_in_place = false`, we can ignore `k` and pass arbitrary value - m, /*k=*/0, u_full, v_full, t, K, /*modify_K_in_place=*/false, /*condition_with_I=*/false, /*left=*/true - ); + // since ``modify_K_in_place = false`, we can ignore `k` and pass + // arbitrary value + m, + /*k=*/0, + u_full, + v_full, + t, + K, + /*modify_K_in_place=*/false, + /*condition_with_I=*/false, + /*left=*/true); }; - auto H_plus = prod.detach().clone(); IntArrayRef batch_vector_shape(V.sizes().data(), V.dim() - 1); - auto H_minus = at::diag_embed(at::ones({1}, V.options()).expand(batch_vector_shape)); + auto H_minus = + at::diag_embed(at::ones({1}, V.options()).expand(batch_vector_shape)); auto dprod = at::zeros_like(prod); for (const auto i : c10::irange(k)) { @@ -4653,10 +5362,9 @@ Tensor householder_product_jvp( H_plus = apply_householder_reflector(v_i, sigma_i, H_plus, /*left=*/true); dprod.add_(H_minus.matmul( - apply_simple_product(v_i, v_i, dtau_i, H_plus) - + apply_simple_product(dv_i, v_i, tau_i, H_plus) - + apply_simple_product(v_i, dv_i, tau_i, H_plus) - )); + apply_simple_product(v_i, v_i, dtau_i, H_plus) + + apply_simple_product(dv_i, v_i, tau_i, H_plus) + + apply_simple_product(v_i, dv_i, tau_i, H_plus))); H_minus = apply_householder_reflector(v_i, tau_i, H_minus, /*left=*/false); } @@ -4752,31 +5460,32 @@ Tensor i1e_backward( // Note that for any invertible matrix A we have A A^{-1} = I, hence // dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}. // Inserting it back into the definition of dX gives: -// dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB -// dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB +// dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} +// L^{-1} P^T dB dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB // // Backward AD: // // Using the definition of dL, dU from above: -// Tr(L_grad^H dL) + Tr(U_grad^H dU) = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U)) +// Tr(L_grad^H dL) + Tr(U_grad^H dU) = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H +// (dLU * 1_U)) // = [using Tr(A (B * C)) = Tr((A * B^T) C) -// = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H * 1_U^T) dLU), +// = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H +// * 1_U^T) dLU), // hence // LU_grad = L_grad * 1_L + U_grad * 1_U (!!!) // // Then, transposing the formula for dX above we get: -// B_grad = P L^{-H} U^{-H} X_grad = lu_solve(X_grad, LU_data, LU_pivots, /*adjoint=*/true) -// U_grad = -U^{-H} X_grad X^H -// L_grad = L^{-H} U_grad U^H +// B_grad = P L^{-H} U^{-H} X_grad = lu_solve(X_grad, LU_data, LU_pivots, +// /*adjoint=*/true) U_grad = -U^{-H} X_grad X^H L_grad = L^{-H} U_grad U^H // After inserting U_grad and L_grad into (!!!) we get the value for LU_grad. Tensor linalg_lu_solve_LU( - const Tensor& gX, - const Tensor& LU, - const Tensor& pivots, - const Tensor& X, - const bool left, - const bool adjoint) { + const Tensor& gX, + const Tensor& LU, + const Tensor& pivots, + const Tensor& X, + const bool left, + const bool adjoint) { // From linalg_lu_solve_jvp we have that: // left = True, adjoint = True: A^HX = B // left = True, adjoint = False: AX = B @@ -4790,8 +5499,8 @@ Tensor linalg_lu_solve_LU( // dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S // else: // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S - // So computing the adjoint of this operation we get that, using an auxiliary variable gR - // if left != adjoint: + // So computing the adjoint of this operation we get that, using an auxiliary + // variable gR if left != adjoint: // gR = U^{-H}op_2(-gX)op_2(X)^H // gU = gR.triu() // gL = (L^{-H} gR U^H).tril(-1) @@ -4803,40 +5512,56 @@ Tensor linalg_lu_solve_LU( at::NoTF32Guard disable_tf32; Tensor P, L, U; - std::tie(P, L, U) = at::lu_unpack(LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint); - // TODO Optimise the order of the operations to avoid operating on large tensors unnecessarily - // The logic should be: if n < k == left then multiply the gX and X first (as it's done now) - // Otherwise multiply them last + std::tie(P, L, U) = at::lu_unpack( + LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint); + // TODO Optimise the order of the operations to avoid operating on large + // tensors unnecessarily + // The logic should be: if n < k == left then multiply the gX and X first + // (as it's done now) Otherwise multiply them last if (left != adjoint) { // gR = U^{-H}op_2(-gX)op_2(X)^H - auto gR = at::linalg_solve_triangular(U.mH(), -(left ? gX : gX.mH()).matmul(left ? X.mH() : X), /*upper*/false); + auto gR = at::linalg_solve_triangular( + U.mH(), + -(left ? gX : gX.mH()).matmul(left ? X.mH() : X), + /*upper*/ false); // gL = (L^{-H} gR U^H).tril(-1) - auto gL = at::linalg_solve_triangular(L.mH(), gR.matmul(U.mH()), /*upper*/true, /*left*/true, /*unitriangular*/true).tril(-1);; + auto gL = at::linalg_solve_triangular( + L.mH(), + gR.matmul(U.mH()), + /*upper*/ true, + /*left*/ true, + /*unitriangular*/ true) + .tril(-1); + ; return gL + gR.triu(); } else { // gR = -P^T op_3(X)op_1(op_2(gX))P - auto gR = -P.mT().matmul(left ? X : X.mH()).matmul(left ? gX.mH() : gX).matmul(P); + auto gR = + -P.mT().matmul(left ? X : X.mH()).matmul(left ? gX.mH() : gX).matmul(P); // gR = gR L^{-H} - gR = at::linalg_solve_triangular(L.mH(), gR, /*upper*/true, /*left*/false, /*unitriangular*/true); + gR = at::linalg_solve_triangular( + L.mH(), gR, /*upper*/ true, /*left*/ false, /*unitriangular*/ true); // gU = (L^H gR U^{-H}).triu() - auto gU = at::linalg_solve_triangular(U.mH(), L.mH().matmul(gR), /*upper*/false, /*left*/false).triu(); + auto gU = at::linalg_solve_triangular( + U.mH(), L.mH().matmul(gR), /*upper*/ false, /*left*/ false) + .triu(); return gR.tril(-1) + gU; } } Tensor linalg_lu_solve_jvp( - const Tensor& X, - const Tensor& LU, - const Tensor& pivots, - const Tensor& dLU, - const Tensor& dB, - const bool left, - const bool adjoint) { - // We write the derivation in terms of some adjoint operations, as otherwise we would need to - // write down 4 different proofs with 4 different implementations and it'd be painful to derive - // and maintain - // Below, we just use that X -> X^H is linear, so it commutes with the derivative - // The derivation follows by differentiating op_1(PLU)op_2(X) = op_3(B) + const Tensor& X, + const Tensor& LU, + const Tensor& pivots, + const Tensor& dLU, + const Tensor& dB, + const bool left, + const bool adjoint) { + // We write the derivation in terms of some adjoint operations, as otherwise + // we would need to write down 4 different proofs with 4 different + // implementations and it'd be painful to derive and maintain Below, we just + // use that X -> X^H is linear, so it commutes with the derivative The + // derivation follows by differentiating op_1(PLU)op_2(X) = op_3(B) // left = True, adjoint = True: A^HX = B // left = True, adjoint = False: AX = B @@ -4851,12 +5576,18 @@ Tensor linalg_lu_solve_jvp( at::NoTF32Guard disable_tf32; auto S = at::linalg_lu_solve(LU, pivots, dB, left, adjoint); if (left != adjoint) { - // We see that when left != adjoint, op_1(A) = A, and we can substitute A^{-1}op_3(B) by op_2(X) - // dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S + // We see that when left != adjoint, op_1(A) = A, and we can substitute + // A^{-1}op_3(B) by op_2(X) dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S // Let R = -U^{-1}(dU + L^{-1}dL U) - auto R = at::linalg_solve_triangular(LU, dLU.tril(-1), /*upper*/false, /*left*/true, /*unitriangular*/true); + auto R = at::linalg_solve_triangular( + LU, + dLU.tril(-1), + /*upper*/ false, + /*left*/ true, + /*unitriangular*/ true); auto U = LU.triu(); - R = -at::linalg_solve_triangular(U, dLU.triu() + R.matmul(U), /*upper*/true); + R = -at::linalg_solve_triangular( + U, dLU.triu() + R.matmul(U), /*upper*/ true); // dX = op_2(R op_2(X)) + S return (left ? R.matmul(X) : X.matmul(R.mH())) + S; } else { @@ -4871,31 +5602,44 @@ Tensor linalg_lu_solve_jvp( // Compute V = op_3(X)^H auto V = left ? X.mH() : X; // Compute the inner parens LdUU^{-1} + dL - auto R = at::linalg_solve_triangular(U, L.matmul(dLU.triu()), /*upper*/true, /*left*/false) + dLU.tril(-1); + auto R = at::linalg_solve_triangular( + U, L.matmul(dLU.triu()), /*upper*/ true, /*left*/ false) + + dLU.tril(-1); // dX = op_2(op_1(-op_3(X)^H PRL^{-1} P^T)) + S - R = at::linalg_solve_triangular(L, -V.matmul(P).matmul(R), /*upper*/false, /*left*/false, /*unitriangular*/true).matmul(P.mT()); + R = at::linalg_solve_triangular( + L, + -V.matmul(P).matmul(R), + /*upper*/ false, + /*left*/ false, + /*unitriangular*/ true) + .matmul(P.mT()); // dX = op_2(R^H) + S return (left ? R.mH() : R) + S; } } Tensor linalg_solve_jvp( - const Tensor& dA, - const Tensor& dB, - const Tensor& X, - const Tensor& LU, - const Tensor& pivots, - const bool left, - const bool use_A_T) { + const Tensor& dA, + const Tensor& dB, + const Tensor& X, + const Tensor& LU, + const Tensor& pivots, + const bool left, + const bool use_A_T) { at::NoTF32Guard disable_tf32; // For left=True (left=False is analogous) // dX = A^{-1}(dB - dAX) // [NumPy compat] Case where the rhs is a vector. - // We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)` + // We denote with an underscore vectors that have been converted to matrices + // by `unsqueeze(-1)` const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X); - const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; }; - const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; }; + const auto vector_to_matrix = [vector_case](const Tensor& X) { + return vector_case ? X.unsqueeze(-1) : X; + }; + const auto matrix_to_vector = [vector_case](const Tensor& X) { + return vector_case ? X.squeeze(-1) : X; + }; // This case is disallowed in the primal operation as A.shape = (*, 1, 1) TORCH_INTERNAL_ASSERT(left || !vector_case); @@ -4903,18 +5647,19 @@ Tensor linalg_solve_jvp( auto X_ = vector_to_matrix(X); auto dB_ = vector_to_matrix(dB); auto R_ = left ? dA.matmul(X_) : X_.matmul(dA); - auto dX_ = at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/use_A_T); + auto dX_ = + at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/ use_A_T); return matrix_to_vector(dX_); } std::tuple linalg_solve_backward( - const Tensor& gX, - const Tensor& X, - const Tensor& A, - const Tensor& LU, - const Tensor& pivots, - const bool left, - const bool B_requires_grad) { + const Tensor& gX, + const Tensor& X, + const Tensor& A, + const Tensor& LU, + const Tensor& pivots, + const bool left, + const bool B_requires_grad) { // for X = A^{-1}B // gB = A^{-H}gX // gA = -gB X^H @@ -4925,18 +5670,25 @@ std::tuple linalg_solve_backward( } // [NumPy compat] Case where the rhs is a vector. - // We denote with an underscore vectors that have been converted to matrices by `unsqueeze(-1)` + // We denote with an underscore vectors that have been converted to matrices + // by `unsqueeze(-1)` const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X); - const auto vector_to_matrix = [vector_case](const Tensor& X) { return vector_case ? X.unsqueeze(-1) : X; }; - const auto matrix_to_vector = [vector_case](const Tensor& X) { return vector_case ? X.squeeze(-1) : X; }; + const auto vector_to_matrix = [vector_case](const Tensor& X) { + return vector_case ? X.unsqueeze(-1) : X; + }; + const auto matrix_to_vector = [vector_case](const Tensor& X) { + return vector_case ? X.squeeze(-1) : X; + }; - // If the user is going to compute higher order gradients, then we need to recompute the LU and the pivots + // If the user is going to compute higher order gradients, then we need to + // recompute the LU and the pivots Tensor gB_; if (at::GradMode::is_enabled()) { gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left); } else { const auto use_A_T = A.is_contiguous() && !A.is_complex(); - gB_ = at::linalg_lu_solve(LU, pivots, vector_to_matrix(gX), left, /*adjoint*/!use_A_T); + gB_ = at::linalg_lu_solve( + LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T); } Tensor gA_; @@ -4944,40 +5696,49 @@ std::tuple linalg_solve_backward( auto X_ = vector_to_matrix(X); gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_); } - return std::make_tuple(A_requires_grad ? matrix_to_vector(gA_) : Tensor{}, - B_requires_grad ? matrix_to_vector(gB_) : Tensor{}); + return std::make_tuple( + A_requires_grad ? matrix_to_vector(gA_) : Tensor{}, + B_requires_grad ? matrix_to_vector(gB_) : Tensor{}); } Tensor solve_jvp( - const Tensor& X, - const Tensor& A, - const Tensor& dA, - const Tensor& dB -) { + const Tensor& X, + const Tensor& A, + const Tensor& dA, + const Tensor& dB) { return generic_solve_jvp( - [](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { - return at::linalg_solve(A, dB - dA_contrib); - }, - X, A, dA, dB - ); + [](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) { + return at::linalg_solve(A, dB - dA_contrib); + }, + X, + A, + dA, + dB); } Tensor lu_unpack_backward( - const Tensor& L_grad, - const Tensor& U_grad, - const int64_t m, - const int64_t n -) { + const Tensor& L_grad, + const Tensor& U_grad, + const int64_t m, + const int64_t n) { if (!L_grad.defined() && !U_grad.defined()) { return {}; } const auto k = std::min(m, n); // Getters for the principal and complementary part of the matrices - const auto get_L1 = [m, k](const Tensor& L) { return m == k ? L.tril(-1) : L.narrow(-2, 0, k).tril(-1); }; - const auto get_L2 = [m, k](const Tensor& L) { return L.narrow(-2, k, m - k); }; - const auto get_U1 = [n, k](const Tensor& U) { return n == k ? U.triu() : U.narrow(-1, 0, k).triu(); }; - const auto get_U2 = [n, k](const Tensor& U) { return U.narrow(-1, k, n - k); }; + const auto get_L1 = [m, k](const Tensor& L) { + return m == k ? L.tril(-1) : L.narrow(-2, 0, k).tril(-1); + }; + const auto get_L2 = [m, k](const Tensor& L) { + return L.narrow(-2, k, m - k); + }; + const auto get_U1 = [n, k](const Tensor& U) { + return n == k ? U.triu() : U.narrow(-1, 0, k).triu(); + }; + const auto get_U2 = [n, k](const Tensor& U) { + return U.narrow(-1, k, n - k); + }; if (L_grad.defined()) { if (U_grad.defined()) { @@ -4995,7 +5756,8 @@ Tensor lu_unpack_backward( } else { auto size = L_grad.sizes().vec(); size.end()[-1] = n - m; - return at::cat({L_grad.tril(-1), at::zeros(size, L_grad.options())}, /*dim=*/-1); + return at::cat( + {L_grad.tril(-1), at::zeros(size, L_grad.options())}, /*dim=*/-1); } } } else { @@ -5004,7 +5766,8 @@ Tensor lu_unpack_backward( } else { auto size = U_grad.sizes().vec(); size.end()[-2] = m - n; - return at::cat({U_grad.triu(), at::zeros(size, U_grad.options())}, /*dim=*/-2); + return at::cat( + {U_grad.triu(), at::zeros(size, U_grad.options())}, /*dim=*/-2); } } } @@ -5013,15 +5776,18 @@ Tensor cat_jvp(at::TensorList tensors, int64_t dim) { Tensor out_fw_grad; auto any_defined = false; - for (const auto& t: tensors) { + for (const auto& t : tensors) { any_defined |= isFwGradDefined(t); } if (any_defined) { std::vector fw_grads; - for (auto& t: tensors) { - fw_grads.push_back(isFwGradDefined(t)? t._fw_grad(/*level*/ 0): at::_efficientzerotensor(t.sizes(), t.options())); + for (auto& t : tensors) { + fw_grads.push_back( + isFwGradDefined(t) + ? t._fw_grad(/*level*/ 0) + : at::_efficientzerotensor(t.sizes(), t.options())); } out_fw_grad = at::cat(fw_grads, dim); @@ -5034,7 +5800,7 @@ Tensor block_diag_jvp(at::TensorList tensors) { Tensor out_fw_grad; auto any_defined = false; - for (const auto& t: tensors) { + for (const auto& t : tensors) { any_defined |= isFwGradDefined(t); } @@ -5042,8 +5808,11 @@ Tensor block_diag_jvp(at::TensorList tensors) { std::vector fw_grads; fw_grads.reserve(tensors.size()); - for (const auto& t: tensors) { - fw_grads.push_back(isFwGradDefined(t)? t._fw_grad(/*level*/ 0): at::_efficientzerotensor(t.sizes(), t.options())); + for (const auto& t : tensors) { + fw_grads.push_back( + isFwGradDefined(t) + ? t._fw_grad(/*level*/ 0) + : at::_efficientzerotensor(t.sizes(), t.options())); } out_fw_grad = at::block_diag(fw_grads); @@ -5058,15 +5827,18 @@ Tensor stack_jvp(at::TensorList tensors, int64_t dim) { Tensor out_fw_grad; auto any_defined = false; - for (const auto& t: tensors) { + for (const auto& t : tensors) { any_defined |= isFwGradDefined(t); } if (any_defined) { std::vector fw_grads; - for (auto& t: tensors) { - fw_grads.push_back(isFwGradDefined(t)? t._fw_grad(/*level*/ 0): at::_efficientzerotensor(t.sizes(), t.options())); + for (auto& t : tensors) { + fw_grads.push_back( + isFwGradDefined(t) + ? t._fw_grad(/*level*/ 0) + : at::_efficientzerotensor(t.sizes(), t.options())); } out_fw_grad = at::stack(fw_grads, dim); } @@ -5100,19 +5872,26 @@ Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) { auto mask_after_first_zero = mask_first_zero.cumsum(dim); // Do the final replacement - return at::where(mask_after_first_zero.to(ScalarType::Bool), new_grad, gradient); + return at::where( + mask_after_first_zero.to(ScalarType::Bool), new_grad, gradient); } } // Helper for {batch,layer,group}_norms below // Computes the jvp for `1 / input.std(dims, keepdim)` static Tensor _invstd_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& mean_p, const Tensor& invstd_p, - IntArrayRef dims, int64_t numel, bool keepdim) { + const Tensor& input_p, + const Tensor& input_t, + const Tensor& mean_p, + const Tensor& invstd_p, + IntArrayRef dims, + int64_t numel, + bool keepdim) { Tensor invstd_t; - if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) { - invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * (input_p - mean_p); + if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || + input_t._is_zerotensor()) { + invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * + (input_p - mean_p); } else { invstd_t = input_t - input_t.mean(dims, true); invstd_t *= input_p - mean_p; @@ -5126,13 +5905,19 @@ static Tensor _invstd_jvp( // Helper for {batch,layer,group}_norms below only // Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)` static Tensor _norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& mean_p, const Tensor& invstd_p, - IntArrayRef dims, int64_t numel) { - auto invstd_t = _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& mean_p, + const Tensor& invstd_p, + IntArrayRef dims, + int64_t numel) { + auto invstd_t = + _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true); Tensor result_t; - if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor()) { - result_t = (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || + input_t._is_zerotensor()) { + result_t = (input_t - input_t.mean(dims, true)) * invstd_p + + (input_p - mean_p) * invstd_t; } else { result_t = input_t - input_t.mean(dims, true); result_t *= invstd_p; @@ -5144,17 +5929,21 @@ static Tensor _norm_jvp( } // Helper for {batch,layer,group}_norms below only -// Computes the jvp for `input * weight + bias` where weight and bias may be undefined -// Possibly modifies the input inplace +// Computes the jvp for `input * weight + bias` where weight and bias may be +// undefined Possibly modifies the input inplace static Tensor _affine_jvp( - const c10::optional& input_p, Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, + const c10::optional& input_p, + Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, const Tensor& bias_t) { // We allow input_p to be optional because if weight_p isn't defined, // it may be possible to avoid computing input_p TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined()); if (weight_p.defined()) { - if (areAnyTensorSubclassLike({input_p.value(), input_t, weight_p, weight_t}) || input_t._is_zerotensor() || weight_t._is_zerotensor()) { + if (areAnyTensorSubclassLike( + {input_p.value(), input_t, weight_p, weight_t}) || + input_t._is_zerotensor() || weight_t._is_zerotensor()) { input_t = input_t * weight_p + input_p.value() * weight_t; } else { input_t *= weight_p; @@ -5164,7 +5953,8 @@ static Tensor _affine_jvp( } } if (bias_t.defined()) { - if (areAnyTensorSubclassLike({input_t, bias_t}) || input_t._is_zerotensor()) { + if (areAnyTensorSubclassLike({input_t, bias_t}) || + input_t._is_zerotensor()) { input_t = input_t + bias_t; } else { input_t += bias_t; @@ -5174,12 +5964,16 @@ static Tensor _affine_jvp( } Tensor batch_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, const c10::optional& running_mean, const c10::optional& running_var, - const Tensor& saved_mean, const Tensor& saved_invstd, + const Tensor& saved_mean, + const Tensor& saved_invstd, bool train, double eps) { auto dims = std::vector{}; @@ -5204,27 +5998,35 @@ Tensor batch_norm_jvp( running_mean.has_value() && running_var.has_value(), "Expect running_mean and running_var to have value when train=false"); TORCH_CHECK( - !running_mean.value()._fw_grad(/*level=*/0).defined() && !running_var.value()._fw_grad(/*level=*/0).defined(), + !running_mean.value()._fw_grad(/*level=*/0).defined() && + !running_var.value()._fw_grad(/*level=*/0).defined(), "batch_norm is not differentiable wrt running_mean and running_var, they cannot have forward grad defined"); mean_p = running_mean.value().view(view_size); - invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size); + invstd_p = + (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size); result_t = input_t * invstd_p; } c10::optional result_p = weight_p.defined() - ? c10::optional((input_p - mean_p) * invstd_p) : c10::nullopt; + ? c10::optional((input_p - mean_p) * invstd_p) + : c10::nullopt; return _affine_jvp( - result_p, result_t, + result_p, + result_t, weight_p.defined() ? weight_p.view(view_size) : weight_p, weight_t.defined() ? weight_t.view(view_size) : weight_t, bias_t.defined() ? bias_t.view(view_size) : bias_t); } Tensor layer_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - const Tensor& saved_mean, const Tensor& saved_invstd, + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + const Tensor& saved_mean, + const Tensor& saved_invstd, IntArrayRef normalized_shape) { auto dims = std::vector{}; auto view_size = input_t.sizes().vec(); @@ -5245,19 +6047,25 @@ Tensor layer_norm_jvp( auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel); c10::optional result_p = weight_p.defined() - ? c10::optional((input_p - mean_p) * invstd_p) : c10::nullopt; + ? c10::optional((input_p - mean_p) * invstd_p) + : c10::nullopt; return _affine_jvp( - result_p, result_t, + result_p, + result_t, weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } Tensor group_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - const Tensor& saved_mean, const Tensor& saved_invstd, + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + const Tensor& saved_mean, + const Tensor& saved_invstd, int64_t groups) { auto input_shape = input_p.sizes(); int64_t N = input_p.size(0); @@ -5267,30 +6075,43 @@ Tensor group_norm_jvp( auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1}); auto result_t = batch_norm_jvp( - input_p_reshaped, input_t_reshaped, - /*weight_p=*/{}, /*weight_t=*/{}, - /*bias_p=*/{}, /*bias_t=*/{}, - /*running_mean=*/{}, /*running_var=*/{}, - saved_mean, saved_invstd, /*train=*/true, /*eps=*/0).view(input_shape); + input_p_reshaped, + input_t_reshaped, + /*weight_p=*/{}, + /*weight_t=*/{}, + /*bias_p=*/{}, + /*bias_t=*/{}, + /*running_mean=*/{}, + /*running_var=*/{}, + saved_mean, + saved_invstd, + /*train=*/true, + /*eps=*/0) + .view(input_shape); c10::optional result_p = c10::nullopt; if (weight_p.defined()) { std::vector view_size(input_t_reshaped.dim(), 1); view_size[1] = input_t_reshaped.size(1); - result_p = ((input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size)).view(input_shape); + result_p = ((input_p_reshaped - saved_mean.view(view_size)) * + saved_invstd.view(view_size)) + .view(input_shape); } std::vector affine_param_shape(input_p.dim(), 1); affine_param_shape[1] = C; return _affine_jvp( - result_p, result_t, + result_p, + result_t, weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p, weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t, bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t); } Tensor group_norm_mean_jvp( - const Tensor& input_t, const Tensor& mean_p, int64_t groups) { + const Tensor& input_t, + const Tensor& mean_p, + int64_t groups) { int64_t N = input_t.size(0); std::array view_shape = {1, N * groups, N ? -1 : 1}; auto input_t_reshaped = input_t.view(view_shape); @@ -5298,8 +6119,10 @@ Tensor group_norm_mean_jvp( } Tensor group_norm_invstd_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& mean_p, const Tensor& invstd_p, + const Tensor& input_p, + const Tensor& input_t, + const Tensor& mean_p, + const Tensor& invstd_p, int64_t groups) { int64_t N = input_p.size(0); @@ -5309,11 +6132,21 @@ Tensor group_norm_invstd_jvp( auto input_p_reshaped = input_p.view(view_shape); return _invstd_jvp( - input_t_reshaped, input_p_reshaped, mean_p.view(view_shape), invstd_p.view(view_shape), - /*dims=*/{2}, /*numel=*/input_t_reshaped.size(2), /*keepdim=*/false).view_as(invstd_p); -} - -Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) { + input_t_reshaped, + input_p_reshaped, + mean_p.view(view_shape), + invstd_p.view(view_shape), + /*dims=*/{2}, + /*numel=*/input_t_reshaped.size(2), + /*keepdim=*/false) + .view_as(invstd_p); +} + +Tensor gather_with_keepdimed_indices( + const Tensor& input, + int64_t dim, + const Tensor& indices, + bool keepdim) { auto full_indices = indices; if (!keepdim) { full_indices = indices.unsqueeze(dim); @@ -5340,18 +6173,19 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Ten // Below we derive the backward algorithm for the case when m <= n. // The case m > n could be obtained using the same idea. // Since we assume m <= n, the LU decomposition of A could be written as -// A = (A1 | A2) = P L (U1 | U2) where A1, U1 in \C^{m \times m}, A2, U2 in \C^{m, n - m} +// A = (A1 | A2) = P L (U1 | U2) where A1, U1 in \C^{m \times m}, A2, U2 in +// \C^{m, n - m} // // Forward AD: // // dA = P dL U + P L dU => [left-multiply P^T] // (P^T dA1 | P^T dA2) = (dL U1 + L dU1 | dL U2 + L dU2) (*) // From (*): -// P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}] -// L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**). -// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular. -// Also, since the diagonal of L (all ones) is never exposed explicity (packed representation), -// the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0. +// P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by +// U1^{-1}] L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**). Note, L is +// lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular. +// Also, since the diagonal of L (all ones) is never exposed explicity (packed +// representation), the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0. // Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular. // Combining these observations we conclude: // @@ -5375,9 +6209,11 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Ten // // Tr(A_grad^H dA) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then // -// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dA1 U1^{-1}) o 1_L] = [using (!)] -// = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dA1 U1^{-1}) = [using the cyclic property of Tr] -// = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dA1) +// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dA1 U1^{-1}) o 1_L] = [using +// (!)] +// = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dA1 U1^{-1}) = [using +// the cyclic property of Tr] = Tr(U1^{-1} (L_grad^H L o 1_L^T) +// L^{-1} P^T dA1) // // Similar, using (!) and the cyclic property of the trace operator: // Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2) @@ -5385,18 +6221,18 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Ten // + Tr(U2_grad^H L^{-1} P^T dA2) // - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dA1) // -// By combining the matrices to the left from dA1 and dA2 and then applying conjugate transposition, -// we finally arrive at: +// By combining the matrices to the left from dA1 and dA2 and then applying +// conjugate transposition, we finally arrive at: // -// A1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H}, -// A2_grad = P L^{-H} U2_grad +// A1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o +// 1_L] U1^{-H}, A2_grad = P L^{-H} U2_grad Tensor linalg_lu_backward( - const Tensor& L_grad, - const Tensor& U_grad, - const Tensor& P, - const Tensor& L, - const Tensor& U, - const bool pivot) { + const Tensor& L_grad, + const Tensor& U_grad, + const Tensor& P, + const Tensor& L, + const Tensor& U, + const bool pivot) { at::NoTF32Guard disable_tf32; // Return early if there's nothing to do if (!L_grad.defined() && !U_grad.defined()) { @@ -5414,41 +6250,54 @@ Tensor linalg_lu_backward( auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad).tril(-1) : Tensor{}; if (U_grad.defined()) { A_grad = A_grad.defined() ? A_grad + U_grad.matmul(U.mH()).triu() - : U_grad.matmul(U.mH()).triu(); + : U_grad.matmul(U.mH()).triu(); } - A_grad = at::linalg_solve_triangular(U.mH(), A_grad, - /*upper=*/false, - /*left=*/false); - A_grad = at::linalg_solve_triangular(L.mH(), A_grad, - /*upper=*/true, - /*left=*/true, - /*unitriangular=*/true); + A_grad = at::linalg_solve_triangular( + U.mH(), + A_grad, + /*upper=*/false, + /*left=*/false); + A_grad = at::linalg_solve_triangular( + L.mH(), + A_grad, + /*upper=*/true, + /*left=*/true, + /*unitriangular=*/true); return pivot ? P.matmul(std::move(A_grad)) : A_grad; } else if (m < n) { // Wide case - // A1_grad = P L^{-H} [U1_grad + (L^H L_grad o 1_L - U_grad U^H o 1_U) U1^{-H}) U^{-H}] - // A2_grad = P L^{-H} U2_grad - const auto get_U1 = [n, k] (const Tensor& U) { return n == k ? U : U.narrow(-1, 0, k); }; - const auto get_U2 = [n, k] (const Tensor& U) { return U.narrow(-1, k, n - k); }; + // A1_grad = P L^{-H} [U1_grad + (L^H L_grad o 1_L - U_grad U^H o 1_U) + // U1^{-H}) U^{-H}] A2_grad = P L^{-H} U2_grad + const auto get_U1 = [n, k](const Tensor& U) { + return n == k ? U : U.narrow(-1, 0, k); + }; + const auto get_U2 = [n, k](const Tensor& U) { + return U.narrow(-1, k, n - k); + }; auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad) : Tensor{}; if (U_grad.defined()) { A_grad = A_grad.defined() ? A_grad - U_grad.triu().matmul(U.mH()) - : - U_grad.triu().matmul(U.mH()); + : -U_grad.triu().matmul(U.mH()); } - A_grad = at::linalg_solve_triangular(get_U1(U).mH(), A_grad.tril(-1), - /*upper=*/false, - /*left=*/false); + A_grad = at::linalg_solve_triangular( + get_U1(U).mH(), + A_grad.tril(-1), + /*upper=*/false, + /*left=*/false); if (U_grad.defined()) { - A_grad = at::cat({A_grad + get_U1(U_grad).triu(), get_U2(U_grad)}, /*dim=*/-1); + A_grad = + at::cat({A_grad + get_U1(U_grad).triu(), get_U2(U_grad)}, /*dim=*/-1); } - A_grad = at::linalg_solve_triangular(L.mH(), A_grad, - /*upper=*/true, - /*left=*/true, - /*unitriangular=*/true); + A_grad = at::linalg_solve_triangular( + L.mH(), + A_grad, + /*upper=*/true, + /*left=*/true, + /*unitriangular=*/true); if (!U_grad.defined()) { A_grad = at::cat({A_grad, at::zeros_like(get_U2(U))}, /*dim=*/-1); @@ -5459,29 +6308,38 @@ Tensor linalg_lu_backward( return A_grad; } else { // Tall case - // A1_grad = P [L1_grad + L^{-H} (U_grad U^H o 1_U - L^H L_grad o 1_L)]U^{-H} - // A2_grad = P L2_grad U^{-H} + // A1_grad = P [L1_grad + L^{-H} (U_grad U^H o 1_U - L^H L_grad o + // 1_L)]U^{-H} A2_grad = P L2_grad U^{-H} - const auto get_L1 = [m, k] (const Tensor& L) { return m == k ? L : L.narrow(-2, 0, k); }; - const auto get_L2 = [m, k] (const Tensor& L) { return L.narrow(-2, k, m - k); }; + const auto get_L1 = [m, k](const Tensor& L) { + return m == k ? L : L.narrow(-2, 0, k); + }; + const auto get_L2 = [m, k](const Tensor& L) { + return L.narrow(-2, k, m - k); + }; auto A_grad = U_grad.defined() ? U_grad.matmul(U.mH()) : Tensor{}; if (L_grad.defined()) { A_grad = A_grad.defined() ? A_grad - L.mH().matmul(L_grad.tril(-1)) - : - L.mH().matmul(L_grad.tril(-1)); + : -L.mH().matmul(L_grad.tril(-1)); } - A_grad = at::linalg_solve_triangular(get_L1(L).mH(), A_grad.triu(), - /*upper=*/true, - /*left=*/true, - /*unitriangular=*/true); + A_grad = at::linalg_solve_triangular( + get_L1(L).mH(), + A_grad.triu(), + /*upper=*/true, + /*left=*/true, + /*unitriangular=*/true); if (L_grad.defined()) { - A_grad = at::cat({A_grad + get_L1(L_grad).tril(-1), get_L2(L_grad)}, /*dim=*/-2); + A_grad = at::cat( + {A_grad + get_L1(L_grad).tril(-1), get_L2(L_grad)}, /*dim=*/-2); } - A_grad = at::linalg_solve_triangular(U.mH(), A_grad, - /*upper=*/false, - /*left=*/false); + A_grad = at::linalg_solve_triangular( + U.mH(), + A_grad, + /*upper=*/false, + /*left=*/false); if (!L_grad.defined()) { A_grad = at::cat({A_grad, at::zeros_like(get_L2(L))}, /*dim=*/-2); @@ -5494,12 +6352,13 @@ Tensor linalg_lu_backward( } Tensor lu_factor_ex_backward( - const Tensor& grad, - const Tensor& LU, - const Tensor& pivs, - const bool pivot) { + const Tensor& grad, + const Tensor& LU, + const Tensor& pivs, + const bool pivot) { Tensor P, L, U; - std::tie(P, L, U) = at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/pivot); + std::tie(P, L, U) = + at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot); // L.shape == (..., m, k) // U.shape == (..., k, n) @@ -5508,17 +6367,18 @@ Tensor lu_factor_ex_backward( const auto k = std::min(m, n); const auto L_grad = grad.narrow(-1, 0, k); const auto U_grad = grad.narrow(-2, 0, k); - return linalg_lu_backward(/*L_grad=*/L_grad, /*U_grad=*/U_grad, P, L, U, pivot); + return linalg_lu_backward( + /*L_grad=*/L_grad, /*U_grad=*/U_grad, P, L, U, pivot); } // This function is based on the forward AD derivations outlined // in the description to the linalg_lu_backward function. std::tuple linalg_lu_jvp( - const Tensor& dA, - const Tensor& P, - const Tensor& L, - const Tensor& U, - const bool pivot) { + const Tensor& dA, + const Tensor& P, + const Tensor& L, + const Tensor& U, + const bool pivot) { at::NoTF32Guard disable_tf32; auto m = dA.size(-2); @@ -5527,23 +6387,26 @@ std::tuple linalg_lu_jvp( auto PdA = pivot ? P.mT().matmul(dA) : dA; - // similar to the backward implementation, we also consider block structures such as: - // for a matrix A of size m x n we decompose it as - // A = (A1 | A2) with A1 of size m x m if m <= n and - // A = (A1^T | A2^T)^T with A1 of size n x n if m > n. + // similar to the backward implementation, we also consider block structures + // such as: for a matrix A of size m x n we decompose it as A = (A1 | A2) with + // A1 of size m x m if m <= n and A = (A1^T | A2^T)^T with A1 of size n x n if + // m > n. auto PdA1 = PdA.narrow(-2, 0, k).narrow(-1, 0, k); auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k); auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k); // We form using two triangular_solve the matrix, the second one in place // dK = L1^{-1} PdA1 U2^{-1} - auto dK = at::linalg_solve_triangular(L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/true); + auto dK = at::linalg_solve_triangular( + L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/ true); // TODO We should be able to do this in-place. At the moment it raises: // RuntimeError: linalg_solve_triangular(): functions with out=... - // arguments don't support automatic differentiation, but one of the arguments requires grad. + // arguments don't support automatic differentiation, but one of the + // arguments requires grad. - // at::linalg_solve_triangular_out(dK, U1, dK, /*upper=*/true, /*left=*/false); + // at::linalg_solve_triangular_out(dK, U1, dK, /*upper=*/true, + // /*left=*/false); dK = at::linalg_solve_triangular(U1, dK, /*upper=*/true, /*left=*/false); auto dL1 = L1.matmul(dK.tril(-1)); @@ -5556,28 +6419,33 @@ std::tuple linalg_lu_jvp( // dU2 := L1^{-1} PdA2 - dK.tril(-1) U2) const auto PdA2 = PdA.narrow(-1, k, n - k); const auto U2 = U.narrow(-1, k, n - k); - auto dU2 = at::linalg_solve_triangular(L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/true) - dK.tril(-1).matmul(U2); + auto dU2 = + at::linalg_solve_triangular( + L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/ true) - + dK.tril(-1).matmul(U2); return std::make_tuple(std::move(dL1), at::cat({dU1, dU2}, /*dim=*/-1)); } else { // we only need to update dL2 defined as // dL2 := PdA2 U^{-1} - L2 dK.triu() const auto PdA2 = PdA.narrow(-2, k, m - k); const auto L2 = L.narrow(-2, k, m - k); - auto dL2 = at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) - L2.matmul(dK.triu()); + auto dL2 = + at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) - + L2.matmul(dK.triu()); return std::make_tuple(at::cat({dL1, dL2}, /*dim=*/-2), std::move(dU1)); } } Tensor lu_factor_ex_jvp( - const Tensor& dA, - const Tensor& LU, - const Tensor& pivs, - const bool pivot -) { + const Tensor& dA, + const Tensor& LU, + const Tensor& pivs, + const bool pivot) { Tensor dL, dU; { Tensor P, L, U; - std::tie(P, L, U) = at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot); + std::tie(P, L, U) = + at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot); std::tie(dL, dU) = linalg_lu_jvp(dA, P, L, U, pivot); } @@ -5592,13 +6460,21 @@ Tensor lu_factor_ex_jvp( } } -Tensor logsumexp_jvp(const Tensor& self_p, const Tensor& self_t, IntArrayRef dim, bool keepdim) { - // NB: for simplicitly, we recompute some values that can be reused from forward - auto self_p_exp = (self_p - at::amax(self_p, dim, true)).exp(); // Use the exp-normalize trick +Tensor logsumexp_jvp( + const Tensor& self_p, + const Tensor& self_t, + IntArrayRef dim, + bool keepdim) { + // NB: for simplicitly, we recompute some values that can be reused from + // forward + auto self_p_exp = (self_p - at::amax(self_p, dim, true)) + .exp(); // Use the exp-normalize trick auto sumexp_p = self_p_exp.sum(dim, keepdim); - // NB: it's OK for logsumexp_jvp to be reused for formulas like softmax/log_softmax - // that only have one differentiable input, because that means self_t are never zerotensors + // NB: it's OK for logsumexp_jvp to be reused for formulas like + // softmax/log_softmax + // that only have one differentiable input, because that means self_t are + // never zerotensors TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor()) if (areAnyTensorSubclassLike({self_p, self_t})) { auto result = (self_p_exp * self_t).sum(dim, keepdim); @@ -5611,39 +6487,57 @@ Tensor logsumexp_jvp(const Tensor& self_p, const Tensor& self_t, IntArrayRef dim } } -Tensor warn_backwards(const Tensor &grad_output) { +Tensor warn_backwards(const Tensor& grad_output) { TORCH_WARN("Warn from backward"); return grad_output; } -// This function only exists because cuDNN does not support bias gradient computation and it's not easy -// to slice a std::tuple to return only grad_input / grad_weight from convolution_backward. It will -// be removed when the cudnn_convolution and cudnn_convolution_transpose go away. +// This function only exists because cuDNN does not support bias gradient +// computation and it's not easy to slice a std::tuple to return only grad_input +// / grad_weight from convolution_backward. It will be removed when the +// cudnn_convolution and cudnn_convolution_transpose go away. std::tuple _cudnn_convolution_backward( - const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, - at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, bool transposed, int64_t groups, - ::std::array output_mask) { + const at::Tensor& self, + const at::Tensor& grad_output, + const at::Tensor& weight, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + bool transposed, + int64_t groups, + ::std::array output_mask) { if (!grad_output.defined()) { return std::tuple(); } // Just call the general backward and ignore the bias gradient part. std::tuple grad_inputs = at::convolution_backward( - grad_output, self, weight, c10::nullopt, stride, padding, dilation, transposed, - output_padding, groups, {output_mask[0], output_mask[1], false}); - std::tuple result = std::make_tuple(std::get<0>(grad_inputs), std::get<1>(grad_inputs)); + grad_output, + self, + weight, + c10::nullopt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + {output_mask[0], output_mask[1], false}); + std::tuple result = + std::make_tuple(std::get<0>(grad_inputs), std::get<1>(grad_inputs)); return result; } std::tuple scatter_reduce_backward( - const Tensor& grad, - const Tensor& self, - int dim, - const Tensor& index, - const Tensor& src, - c10::string_view reduce, - bool include_self, - const Tensor& result) { + const Tensor& grad, + const Tensor& self, + int dim, + const Tensor& index, + const Tensor& src, + c10::string_view reduce, + bool include_self, + const Tensor& result) { Tensor grad_self, grad_src; // FIXME: complex gradients not handled correctly @@ -5660,23 +6554,30 @@ std::tuple scatter_reduce_backward( } else if (reduce == "prod") { // Explicitly compute exclusive prod for elements in self/src that are 0 Tensor masked_self = self.masked_fill(self == 0, 1); - Tensor masked_self_result = masked_self.scatter_reduce(dim, index, src, reduce, include_self); + Tensor masked_self_result = + masked_self.scatter_reduce(dim, index, src, reduce, include_self); grad_self = grad * masked_self_result / masked_self; Tensor src_zero = src == 0; - Tensor src_num_zeros = zeros_like(self).scatter_add(dim, index, src_zero.to(self.dtype())).gather(dim, index); + Tensor src_num_zeros = + zeros_like(self) + .scatter_add(dim, index, src_zero.to(self.dtype())) + .gather(dim, index); Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1); - // For src positions with src_single_zero, grad * result.gather(dim,index) / src.masked_fill(src_zero, 1) - // would incorrectly propagate zeros as the gradient + // For src positions with src_single_zero, grad * result.gather(dim,index) / + // src.masked_fill(src_zero, 1) would incorrectly propagate zeros as the + // gradient Tensor masked_src = src.masked_fill(src_single_zero, 1); - Tensor masked_src_result = self.scatter_reduce(dim, index, masked_src, reduce, include_self); - Tensor grad_src1 = where(src_single_zero, - (grad * masked_src_result).gather(dim, index), - (grad * result).gather(dim, index) / src.masked_fill(src_zero, 1)); + Tensor masked_src_result = + self.scatter_reduce(dim, index, masked_src, reduce, include_self); + Tensor grad_src1 = where( + src_single_zero, + (grad * masked_src_result).gather(dim, index), + (grad * result).gather(dim, index) / src.masked_fill(src_zero, 1)); if ((src_num_zeros > 1).any().item()) { auto node = std::make_shared( - "scatter_reduce(): Double backward is unsupported for src when >1 zeros in src are scattered to the same position in self", - /* num inputs */ 1); - auto result = node->apply({ grad_src1 }); + "scatter_reduce(): Double backward is unsupported for src when >1 zeros in src are scattered to the same position in self", + /* num inputs */ 1); + auto result = node->apply({grad_src1}); grad_src = result[0]; } else { grad_src = grad_src1; @@ -5693,12 +6594,16 @@ std::tuple scatter_reduce_backward( Tensor value = result.gather(dim, index); Tensor self_is_result = (self == result).to(self.scalar_type()); Tensor src_is_result = (src == value).to(self.scalar_type()); - Tensor N_to_distribute = self_is_result.scatter_add(dim, index, src_is_result); + Tensor N_to_distribute = + self_is_result.scatter_add(dim, index, src_is_result); Tensor grad_distributed = grad / N_to_distribute; grad_self = (self == result) * grad_distributed; grad_src = (src == value) * grad_distributed.gather(dim, index); } else { - AT_ERROR("Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ", reduce, "."); + AT_ERROR( + "Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ", + reduce, + "."); } if (!include_self) { @@ -5706,10 +6611,11 @@ std::tuple scatter_reduce_backward( } return std::make_tuple(grad_self, grad_src); - } -Tensor _to_copy_backward(const Tensor &grad_, const c10::TensorOptions &self_options) { +Tensor _to_copy_backward( + const Tensor& grad_, + const c10::TensorOptions& self_options) { // Handle R->C copies without raising a warning const auto self_type = self_options.dtype().toScalarType(); auto grad = c10::MaybeOwned::borrowed(grad_); @@ -5721,19 +6627,20 @@ Tensor _to_copy_backward(const Tensor &grad_, const c10::TensorOptions &self_opt } std::tuple index_reduce_backward( - const Tensor& grad, - const Tensor& self, - int dim, - const Tensor& index, - const Tensor& source, - c10::string_view reduce, - bool include_self, - const Tensor& result) { + const Tensor& grad, + const Tensor& self, + int dim, + const Tensor& index, + const Tensor& source, + c10::string_view reduce, + bool include_self, + const Tensor& result) { Tensor grad_self, grad_src; // FIXME: index_add's backward formula has a special case for source.dim == 0 - // but this case seems to throw the error "IndexError: dimension specified as 0 but tensor has no dimensions" - // look into whether this case is reachable and should be covered here + // but this case seems to throw the error "IndexError: dimension specified as + // 0 but tensor has no dimensions" look into whether this case is reachable + // and should be covered here if (!grad.defined()) { return std::make_tuple(grad_self, grad_src); @@ -5741,23 +6648,30 @@ std::tuple index_reduce_backward( if (reduce == "prod") { Tensor masked_self = self.masked_fill(self == 0, 1); - Tensor masked_self_result = masked_self.index_reduce(dim, index, source, reduce, include_self); + Tensor masked_self_result = + masked_self.index_reduce(dim, index, source, reduce, include_self); grad_self = grad * masked_self_result / masked_self; Tensor src_zero = source == 0; - Tensor src_num_zeros = zeros_like(self).index_add(dim, index, src_zero.to(self.dtype())).index_select(dim, index); + Tensor src_num_zeros = zeros_like(self) + .index_add(dim, index, src_zero.to(self.dtype())) + .index_select(dim, index); Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1); - // For src positions with src_single_zero, (grad * result).index_select(dim,index) / source.masked_fill(src_zero, 1) - // would incorrectly propagate zeros as the gradient + // For src positions with src_single_zero, (grad * + // result).index_select(dim,index) / source.masked_fill(src_zero, 1) would + // incorrectly propagate zeros as the gradient Tensor masked_src = source.masked_fill(src_single_zero, 1); - Tensor masked_src_result = self.index_reduce(dim, index, masked_src, reduce, include_self); - Tensor grad_src1 = where(src_single_zero, - (grad * masked_src_result).index_select(dim, index), - (grad * result).index_select(dim, index) / source.masked_fill(src_zero, 1)); + Tensor masked_src_result = + self.index_reduce(dim, index, masked_src, reduce, include_self); + Tensor grad_src1 = where( + src_single_zero, + (grad * masked_src_result).index_select(dim, index), + (grad * result).index_select(dim, index) / + source.masked_fill(src_zero, 1)); if ((src_num_zeros > 1).any().item()) { auto node = std::make_shared( - "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self", - /* num inputs */ 1); - auto result = node->apply({ grad_src1 }); + "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self", + /* num inputs */ 1); + auto result = node->apply({grad_src1}); grad_src = result[0]; } else { grad_src = grad_src1; @@ -5773,12 +6687,16 @@ std::tuple index_reduce_backward( Tensor value = result.index_select(dim, index); Tensor self_is_result = (self == result).to(self.scalar_type()); Tensor source_is_result = (source == value).to(self.scalar_type()); - Tensor N_to_distribute = self_is_result.index_add(dim, index, source_is_result); + Tensor N_to_distribute = + self_is_result.index_add(dim, index, source_is_result); Tensor grad_distributed = grad / N_to_distribute; grad_self = self_is_result * grad_distributed; grad_src = source_is_result * grad_distributed.index_select(dim, index); } else { - AT_ERROR("Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ", reduce, "."); + AT_ERROR( + "Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ", + reduce, + "."); } if (!include_self) { @@ -5786,10 +6704,8 @@ std::tuple index_reduce_backward( } return std::make_tuple(grad_self, grad_src); - } - } // namespace details } // namespace generated } // namespace autograd diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index d6c0cf0082e184..b61175736f64e3 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -9,8 +9,8 @@ #include #endif -#include #include +#include namespace torch { namespace autograd { @@ -26,9 +26,12 @@ struct IndexRangeGenerator { i += range_size; return {i - range_size, i}; } - size_t size() { return i; } - private: - size_t i = 0; + size_t size() { + return i; + } + + private: + size_t i = 0; }; Tensor toNonOptFwGrad(const c10::optional& t); @@ -37,163 +40,509 @@ Tensor toNonOptTensor(const c10::optional& t); Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction); bool any_variable_defined(const variable_list& variables); -void copy_range(variable_list& out, IndexRange range, const at::Tensor & t); -void copy_range(variable_list& out, IndexRange range, at::ArrayRef t); -at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result); -at::Tensor not_implemented(const char* name, const char* reason=""); -std::vector not_implemented_list(const char* name, const char* reason=""); +void copy_range(variable_list& out, IndexRange range, const at::Tensor& t); +void copy_range( + variable_list& out, + IndexRange range, + at::ArrayRef t); +at::Tensor copysign_tensor_self_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& result); +at::Tensor not_implemented(const char* name, const char* reason = ""); +std::vector not_implemented_list( + const char* name, + const char* reason = ""); at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result); -at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s); +at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s); int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim); -Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim); -Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims); -at::Tensor norm_backward(const at::Tensor & grad, const at::Tensor & self, const optional & p_, const at::Tensor & norm); -at::Tensor norm_backward(at::Tensor grad, const at::Tensor & self, const optional & p_, at::Tensor norm, at::IntArrayRef dim, bool keepdim); +Tensor restore_reduced_dims( + const Tensor& output, + IntArrayRef dims, + bool keepdim); +Tensor scale_grad_by_count( + const Tensor& grad, + const Tensor& mask, + IntArrayRef dims); +at::Tensor norm_backward( + const at::Tensor& grad, + const at::Tensor& self, + const optional& p_, + const at::Tensor& norm); +at::Tensor norm_backward( + at::Tensor grad, + const at::Tensor& self, + const optional& p_, + at::Tensor norm, + at::IntArrayRef dim, + bool keepdim); +Tensor norm_jvp( + const Tensor& self_p, + const Tensor& self_t, + const optional& p_, + Tensor norm, + IntArrayRef dim, + bool keepdim); Tensor norm_jvp( - const Tensor& self_p, const Tensor& self_t, - const optional & p_, - Tensor norm, - IntArrayRef dim, - bool keepdim -); -Tensor norm_jvp(const Tensor& grad, const Tensor& self, const optional & p_, Tensor norm); -Tensor linalg_vector_norm_jvp(const Tensor& self_p, const Tensor& self_t, const Scalar& scalar_ord, Tensor norm, const at::OptionalIntArrayRef& opt_dim, bool keepdim); -at::Tensor linalg_vector_norm_backward(at::Tensor grad, const at::Tensor & self, const at::Scalar & ord, at::Tensor norm, const at::OptionalIntArrayRef & opt_dim, bool keepdim); -at::Tensor pow_backward(at::Tensor grad, const at::Tensor & self, const at::Scalar & exponent_); -at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at::Tensor & exponent); -at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result); -at::Tensor pow_backward_exponent(at::Tensor grad, const at::Scalar & base, const at::Tensor& exponent, at::Tensor result); + const Tensor& grad, + const Tensor& self, + const optional& p_, + Tensor norm); +Tensor linalg_vector_norm_jvp( + const Tensor& self_p, + const Tensor& self_t, + const Scalar& scalar_ord, + Tensor norm, + const at::OptionalIntArrayRef& opt_dim, + bool keepdim); +at::Tensor linalg_vector_norm_backward( + at::Tensor grad, + const at::Tensor& self, + const at::Scalar& ord, + at::Tensor norm, + const at::OptionalIntArrayRef& opt_dim, + bool keepdim); +at::Tensor pow_backward( + at::Tensor grad, + const at::Tensor& self, + const at::Scalar& exponent_); +at::Tensor pow_backward_self( + at::Tensor grad, + const at::Tensor& self, + const at::Tensor& exponent); +at::Tensor pow_backward_exponent( + at::Tensor grad, + const at::Tensor& self, + const at::Tensor& exponent, + at::Tensor result); +at::Tensor pow_backward_exponent( + at::Tensor grad, + const at::Scalar& base, + const at::Tensor& exponent, + at::Tensor result); at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self); at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st); -at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st); +at::Tensor div_tensor_self_backward( + Tensor grad, + Tensor other, + ScalarType self_st); at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other); -at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional& rounding_mode); -at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional& rounding_mode); -at::Tensor mvlgamma_backward(at::Tensor grad, const at::Tensor & self, int64_t p); -at::Tensor permute_backwards(const at::Tensor & grad, at::IntArrayRef fwd_dims); +at::Tensor div_tensor_self_backward( + Tensor grad, + Tensor other, + ScalarType self_st, + const c10::optional& rounding_mode); +at::Tensor div_tensor_other_backward( + Tensor grad, + Tensor self, + Tensor other, + const c10::optional& rounding_mode); +at::Tensor mvlgamma_backward( + at::Tensor grad, + const at::Tensor& self, + int64_t p); +at::Tensor permute_backwards(const at::Tensor& grad, at::IntArrayRef fwd_dims); at::Tensor rad2deg_backward(const at::Tensor& grad); at::Tensor deg2rad_backward(const at::Tensor& grad); -at::Tensor unsqueeze_multiple(const at::Tensor & t, at::IntArrayRef dim, size_t n_dims); -at::Tensor sum_backward(const at::Tensor & grad, at::IntArrayRef sizes, at::IntArrayRef dims, bool keepdim); -at::Tensor nansum_backward(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dims, bool keepdim); +at::Tensor unsqueeze_multiple( + const at::Tensor& t, + at::IntArrayRef dim, + size_t n_dims); +at::Tensor sum_backward( + const at::Tensor& grad, + at::IntArrayRef sizes, + at::IntArrayRef dims, + bool keepdim); +at::Tensor nansum_backward( + const at::Tensor& grad, + const at::Tensor& self, + at::IntArrayRef dims, + bool keepdim); std::vector reverse_list(const at::IntArrayRef list); at::Tensor reverse_dim(const at::Tensor& t, int64_t dim); -at::Tensor prod_safe_zeros_backward(const at::Tensor &grad, const at::Tensor& inp, int64_t dim); -at::Tensor prod_backward(const at::Tensor& grad, const at::Tensor& input, const at::Tensor& result); -at::Tensor prod_backward(at::Tensor grad, const at::Tensor& input, at::Tensor result, int64_t dim, bool keepdim); -at::Tensor solve_jvp(const Tensor& X, const Tensor& A, const Tensor& dA, const Tensor& dB); -at::Tensor solve_backward_self(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & A); -at::Tensor solve_backward_A(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & A, const at::Tensor & solution); -at::Tensor cumsum_backward(const at::Tensor & grad, int64_t dim); -at::Tensor logsumexp_backward(at::Tensor grad, const at::Tensor & self, at::Tensor result, at::IntArrayRef dim, bool keepdim); -at::Tensor logsumexp_jvp(const at::Tensor& self_p, const at::Tensor& self_t, IntArrayRef dim, bool keepdim); -at::Tensor logcumsumexp_backward(at::Tensor grad, const at::Tensor & self, at::Tensor result, int64_t dim); +at::Tensor prod_safe_zeros_backward( + const at::Tensor& grad, + const at::Tensor& inp, + int64_t dim); +at::Tensor prod_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& result); +at::Tensor prod_backward( + at::Tensor grad, + const at::Tensor& input, + at::Tensor result, + int64_t dim, + bool keepdim); +at::Tensor solve_jvp( + const Tensor& X, + const Tensor& A, + const Tensor& dA, + const Tensor& dB); +at::Tensor solve_backward_self( + const at::Tensor& grad, + const at::Tensor& self, + const at::Tensor& A); +at::Tensor solve_backward_A( + const at::Tensor& grad, + const at::Tensor& self, + const at::Tensor& A, + const at::Tensor& solution); +at::Tensor cumsum_backward(const at::Tensor& grad, int64_t dim); +at::Tensor logsumexp_backward( + at::Tensor grad, + const at::Tensor& self, + at::Tensor result, + at::IntArrayRef dim, + bool keepdim); +at::Tensor logsumexp_jvp( + const at::Tensor& self_p, + const at::Tensor& self_t, + IntArrayRef dim, + bool keepdim); +at::Tensor logcumsumexp_backward( + at::Tensor grad, + const at::Tensor& self, + at::Tensor result, + int64_t dim); at::Tensor unbind_backward(const variable_list& grads, int64_t dim); -at::Tensor unsqueeze_to(const at::Tensor & self, at::IntArrayRef sizes); -at::Tensor unsqueeze_to(const at::Tensor & self, int64_t dim, at::IntArrayRef sizes); -std::vector cat_tensors_backward(const at::Tensor & grad, const std::vector> &sizes, const std::vector &dtypes, int64_t dim); -std::vector block_diag_backward(const at::Tensor & grad, const std::vector> &sizes, const std::vector &dtypes); -at::Tensor clamp_backward(const at::Tensor & grad, const at::Tensor &self, const optional& min, const optional& max); -at::Tensor clamp_backward(const at::Tensor & grad, const at::Tensor &self, const at::Tensor& min, const at::Tensor& max); -std::tuple clamp_backward_min_max(const at::Tensor& grad, const at::Tensor& self, const at::Tensor& min, const at::Tensor& max, const std::array&); +at::Tensor unsqueeze_to(const at::Tensor& self, at::IntArrayRef sizes); +at::Tensor unsqueeze_to( + const at::Tensor& self, + int64_t dim, + at::IntArrayRef sizes); +std::vector cat_tensors_backward( + const at::Tensor& grad, + const std::vector>& sizes, + const std::vector& dtypes, + int64_t dim); +std::vector block_diag_backward( + const at::Tensor& grad, + const std::vector>& sizes, + const std::vector& dtypes); +at::Tensor clamp_backward( + const at::Tensor& grad, + const at::Tensor& self, + const optional& min, + const optional& max); +at::Tensor clamp_backward( + const at::Tensor& grad, + const at::Tensor& self, + const at::Tensor& min, + const at::Tensor& max); +std::tuple clamp_backward_min_max( + const at::Tensor& grad, + const at::Tensor& self, + const at::Tensor& min, + const at::Tensor& max, + const std::array&); at::Tensor clamp_jvp( - const Tensor& self_p, const Tensor& self_t, - const Tensor& min_p, const Tensor& min_t, - const Tensor& max_p, const Tensor& max_t -); -at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name); -at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, c10::Layout mat1_layout, const Scalar & alpha); -at::Tensor mm_mat2_backward(const at::Tensor & grad, const at::Tensor & mat1, at::IntArrayRef sizes, at::IntArrayRef strides, c10::Layout layout, const at::Scalar & alpha); -at::Tensor mm_mat1_sparse_backward(const at::Tensor& grad, const at::Tensor& mat1, const at::Tensor& mat2, const at::Scalar& alpha); -at::Tensor sparse_sparse_matmul_backward(const at::Tensor& grad, const at::Tensor& mat1, const at::Tensor& mat2,int64_t grad_order); -at::Tensor renorm_backward(const at::Tensor & grad, const at::Tensor & self, const at::Scalar& p, int64_t dim, const at::Scalar& maxnorm); -at::Tensor repeat_backward(at::Tensor grad, at::IntArrayRef repeats, at::IntArrayRef input_shape); -at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m); -at::Tensor infinitely_differentiable_native_dropout_backward(const at::Tensor& grad, const at::Tensor& mask, double scale); -at::Tensor native_dropout_double_backward(const at::Tensor& ggI, const at::Tensor& grad, const at::Tensor& mask, double scale); -at::Tensor evenly_distribute_backward(at::Tensor grad, const at::Tensor & input, const at::Tensor & value); + const Tensor& self_p, + const Tensor& self_t, + const Tensor& min_p, + const Tensor& min_t, + const Tensor& max_p, + const Tensor& max_t); +at::IntArrayRef strides_or_error( + const Tensor& input, + c10::string_view const& input_name); +at::Tensor mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::IntArrayRef mat1_sizes, + at::IntArrayRef mat1_strides, + c10::Layout mat1_layout, + const Scalar& alpha); +at::Tensor mm_mat2_backward( + const at::Tensor& grad, + const at::Tensor& mat1, + at::IntArrayRef sizes, + at::IntArrayRef strides, + c10::Layout layout, + const at::Scalar& alpha); +at::Tensor mm_mat1_sparse_backward( + const at::Tensor& grad, + const at::Tensor& mat1, + const at::Tensor& mat2, + const at::Scalar& alpha); +at::Tensor sparse_sparse_matmul_backward( + const at::Tensor& grad, + const at::Tensor& mat1, + const at::Tensor& mat2, + int64_t grad_order); +at::Tensor renorm_backward( + const at::Tensor& grad, + const at::Tensor& self, + const at::Scalar& p, + int64_t dim, + const at::Scalar& maxnorm); +at::Tensor repeat_backward( + at::Tensor grad, + at::IntArrayRef repeats, + at::IntArrayRef input_shape); +at::Tensor _fused_dropout_backward( + at::Tensor grad, + at::Tensor mask, + double p1m); +at::Tensor infinitely_differentiable_native_dropout_backward( + const at::Tensor& grad, + const at::Tensor& mask, + double scale); +at::Tensor native_dropout_double_backward( + const at::Tensor& ggI, + const at::Tensor& grad, + const at::Tensor& mask, + double scale); +at::Tensor evenly_distribute_backward( + at::Tensor grad, + const at::Tensor& input, + const at::Tensor& value); at::Tensor sgn_backward(Tensor result, Tensor grad, Tensor self); -at::Tensor var_backward(at::Tensor grad, const at::Tensor& self, at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim); -at::Tensor var_jvp(const at::Tensor& self_t, const at::Tensor& self_p, const at::Tensor& result, at::OptionalIntArrayRef dim_opt, c10::optional correction_opt, bool keepdim); -at::Tensor std_backward(const at::Tensor& result, const at::Tensor& grad, const at::Tensor& self, at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim); -at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, at::IntArrayRef dim, bool keepdim); -at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t numel); -at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim, bool is_std); -at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes); -at::Tensor cholesky_backward(const at::Tensor& grad, bool upper, const at::Tensor& L); -at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& L, bool upper); -at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse); -at::Tensor cholesky_inverse_jvp(const at::Tensor& F, const at::Tensor& dF, const at::Tensor& X, bool upper); -Tensor pinv_jvp( - const Tensor& A, - const Tensor& pinvA, - const Tensor& dA -); -Tensor pinv_backward( - const Tensor& grad, - const Tensor& pinvA, - const Tensor& A -); -at::Tensor split_with_sizes_backward(const std::vector &grads, - IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options); -at::Tensor split_backward(const std::vector &grads, int64_t split_size, int64_t dim, at::IntArrayRef sizes, const at::TensorOptions &options); -at::Tensor max_pool_double_backward(const at::Tensor & grad, const at::Tensor & indices, int dim); -at::Tensor glu_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, int64_t dim); -at::Tensor glu_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, int64_t dim); -at::Tensor infinitely_differentiable_silu_backward(const at::Tensor& grad_output, const at::Tensor& input); -at::Tensor infinitely_differentiable_mish_backward(const at::Tensor& grad_output, const at::Tensor& input); -Tensor infinitely_differentiable_logit_backward(const Tensor& grad, const Tensor& self, c10::optional eps); -at::Tensor kl_div_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, bool log_target); +at::Tensor var_backward( + at::Tensor grad, + const at::Tensor& self, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim); +at::Tensor var_jvp( + const at::Tensor& self_t, + const at::Tensor& self_p, + const at::Tensor& result, + at::OptionalIntArrayRef dim_opt, + c10::optional correction_opt, + bool keepdim); +at::Tensor std_backward( + const at::Tensor& result, + const at::Tensor& grad, + const at::Tensor& self, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim); +at::Tensor mean_backward( + at::Tensor grad, + const at::IntArrayRef sizes, + at::IntArrayRef dim, + bool keepdim); +at::Tensor mean_backward( + at::Tensor grad, + const at::IntArrayRef sizes, + int64_t numel); +at::Tensor var_std_mean_backward( + const variable_list& grads, + const at::Tensor& self, + const at::Tensor& r1, + const at::Tensor& r2, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim, + bool is_std); +at::Tensor masked_scatter_backward( + const at::Tensor& grad, + const at::Tensor& mask, + at::IntArrayRef sizes); +at::Tensor cholesky_backward( + const at::Tensor& grad, + bool upper, + const at::Tensor& L); +at::Tensor cholesky_jvp( + const at::Tensor& input_tangent, + const at::Tensor& L, + bool upper); +at::Tensor cholesky_inverse_backward( + at::Tensor grad, + at::Tensor L, + bool upper, + at::Tensor inverse); +at::Tensor cholesky_inverse_jvp( + const at::Tensor& F, + const at::Tensor& dF, + const at::Tensor& X, + bool upper); +Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA); +Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A); +at::Tensor split_with_sizes_backward( + const std::vector& grads, + IntArrayRef split_sizes, + int64_t dim, + IntArrayRef sizes, + const at::TensorOptions& options); +at::Tensor split_backward( + const std::vector& grads, + int64_t split_size, + int64_t dim, + at::IntArrayRef sizes, + const at::TensorOptions& options); +at::Tensor max_pool_double_backward( + const at::Tensor& grad, + const at::Tensor& indices, + int dim); +at::Tensor glu_double_backward( + const at::Tensor& grad, + const at::Tensor& grad_output, + const at::Tensor& input, + int64_t dim); +at::Tensor glu_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& input, + int64_t dim); +at::Tensor infinitely_differentiable_silu_backward( + const at::Tensor& grad_output, + const at::Tensor& input); +at::Tensor infinitely_differentiable_mish_backward( + const at::Tensor& grad_output, + const at::Tensor& input); +Tensor infinitely_differentiable_logit_backward( + const Tensor& grad, + const Tensor& self, + c10::optional eps); +at::Tensor kl_div_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction, + bool log_target); Tensor binary_cross_entropy_target_backward( - const Tensor& grad, - const Tensor& self, - const Tensor& target, - const c10::optional& weight, - int64_t reduction); + const Tensor& grad, + const Tensor& self, + const Tensor& target, + const c10::optional& weight, + int64_t reduction); Tensor binary_cross_entropy_double_backward_target( - const Tensor& grad, - const Tensor& grad_output, - const Tensor& self, - const Tensor& target, - const c10::optional& weight, - int64_t reduction -); -at::Tensor binary_cross_entropy_with_logits_target_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction); -at::Tensor log_sigmoid_double_backward(const at::Tensor & grad, const at::Tensor & input); -at::Tensor softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output); -at::Tensor binary_cross_entropy_double_backward(const at::Tensor & grad_output, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction); -at::Tensor binary_cross_entropy_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction); -at::Tensor l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor smooth_l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double beta); -at::Tensor huber_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double delta); -at::Tensor huber_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double delta); -at::Tensor mse_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, int64_t reduction); -at::Tensor soft_margin_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor soft_margin_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction); -at::Tensor softplus_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Scalar& beta, const at::Scalar& threshold); -at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, const at::Tensor& logdet); -at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet); + const Tensor& grad, + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const c10::optional& weight, + int64_t reduction); +at::Tensor binary_cross_entropy_with_logits_target_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight, + const c10::optional& pos_weight, + int64_t reduction); +at::Tensor log_sigmoid_double_backward( + const at::Tensor& grad, + const at::Tensor& input); +at::Tensor softmax_double_backward( + const at::Tensor& grad, + const at::Tensor& grad_output, + int dim, + const at::Tensor& output); +at::Tensor binary_cross_entropy_double_backward( + const at::Tensor& grad_output, + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction); +at::Tensor binary_cross_entropy_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction); +at::Tensor l1_loss_double_backward( + const at::Tensor& grad, + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction); +at::Tensor l1_loss_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction); +at::Tensor smooth_l1_loss_double_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction, + double beta); +at::Tensor huber_loss_double_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction, + double delta); +at::Tensor huber_loss_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction, + double delta); +at::Tensor mse_loss_double_backward( + const at::Tensor& grad, + const at::Tensor& input, + int64_t reduction); +at::Tensor soft_margin_loss_double_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction); +at::Tensor soft_margin_loss_double_backward_grad_output( + const at::Tensor& grad, + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& target, + int64_t reduction); +at::Tensor softplus_double_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Scalar& beta, + const at::Scalar& threshold); +at::Tensor logdet_backward( + const at::Tensor& grad, + const at::Tensor& self, + const at::Tensor& logdet); +at::Tensor slogdet_backward( + const at::Tensor& grad_logabsdet, + const at::Tensor& self, + const at::Tensor& signdet, + const at::Tensor& logabsdet); at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self); -at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices); -at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx); -at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List>& indices, const at::Tensor& grad); -at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); -at::Tensor elu_double_backward(const Tensor& grad, const Tensor& grad_output, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result, const Tensor& self_or_result); +at::Tensor sparse_constructor_values_backward( + const at::Tensor& sparse_grad_out, + const at::Tensor& indices); +at::Tensor embedding_dense_double_backward( + const at::Tensor& grad, + const at::Tensor& indices, + int64_t padding_idx); +at::Tensor index_backward( + at::Tensor zeros_like_self, + const torch::List>& indices, + const at::Tensor& grad); +at::Tensor _cudnn_ctc_loss_backward( + const at::Tensor& grad_out, + const at::Tensor& loss, + const at::Tensor& raw_grad, + bool zero_infinity); +at::Tensor elu_double_backward( + const Tensor& grad, + const Tensor& grad_output, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result, + const Tensor& self_or_result); -Tensor svd_backward(const Tensor& gU, - const Tensor& gS, - const Tensor& gVh, - const Tensor& U, - const Tensor& S, - const Tensor& Vh); +Tensor svd_backward( + const Tensor& gU, + const Tensor& gS, + const Tensor& gVh, + const Tensor& U, + const Tensor& S, + const Tensor& Vh); -std::tuple linalg_svd_jvp(const Tensor& dA, - const Tensor& U, - const Tensor& S, - const Tensor& Vh, - const bool full_matrices); +std::tuple linalg_svd_jvp( + const Tensor& dA, + const Tensor& U, + const Tensor& S, + const Tensor& Vh, + const bool full_matrices); Tensor slice_backward_wrapper( const at::Tensor& grad, const c10::IntArrayRef& input_sizes, @@ -201,34 +550,41 @@ Tensor slice_backward_wrapper( c10::optional start, c10::optional end, int64_t step); -std::tuple linalg_eig_jvp(const Tensor& dA, - const Tensor& L, - const Tensor& V, - const bool is_hermitian); -Tensor linalg_eig_backward(const Tensor& gL, - const Tensor& gV, - const Tensor& L, - const Tensor& V, - const bool is_hermitian, - const bool symeig_eigenvectors=true); +std::tuple linalg_eig_jvp( + const Tensor& dA, + const Tensor& L, + const Tensor& V, + const bool is_hermitian); +Tensor linalg_eig_backward( + const Tensor& gL, + const Tensor& gV, + const Tensor& L, + const Tensor& V, + const bool is_hermitian, + const bool symeig_eigenvectors = true); Tensor linalg_lstsq_jvp( - const Tensor& A, - const Tensor& B, - const Tensor& dA, - const Tensor& dB -); + const Tensor& A, + const Tensor& B, + const Tensor& dA, + const Tensor& dB); std::tuple triangular_solve_backward( - const Tensor & grad_x, const Tensor & grad_m, - const Tensor & b, const Tensor & a, const Tensor & x, - const bool upper, const bool transpose, const bool unitriangular, + const Tensor& grad_x, + const Tensor& grad_m, + const Tensor& b, + const Tensor& a, + const Tensor& x, + const bool upper, + const bool transpose, + const bool unitriangular, std::array output_mask); Tensor triangular_solve_jvp( - const Tensor& X, const Tensor& A, - const Tensor& dA, const Tensor& dB, - const bool upper, - const bool transpose, - const bool unitriangular -); + const Tensor& X, + const Tensor& A, + const Tensor& dA, + const Tensor& dB, + const bool upper, + const bool transpose, + const bool unitriangular); Tensor linalg_solve_triangular_forward_AD( const Tensor& A_t, const Tensor& B_t, @@ -245,51 +601,100 @@ std::tuple linalg_solve_triangular_backward( const bool left, const bool unitriangular, std::array output_mask); -std::tuple _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3, - IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, - IntArrayRef sumdim, std::array grad_mask); -std::tuple linalg_qr_jvp(const Tensor& dA, const Tensor& Q, const Tensor& R, - const c10::string_view mode); -Tensor linalg_qr_backward(const Tensor& gQ, const Tensor& gR, const Tensor& Q, const Tensor& R, const c10::string_view mode); -Tensor eig_backward(const std::vector &grads, const Tensor& self, - bool eigenvectors, const Tensor& lambda, const Tensor& v); -Tensor linalg_matrix_exp_differential(const Tensor& self, const Tensor& grad, bool adjoint); -Tensor linalg_det_backward(const Tensor & grad, const Tensor& self, const Tensor& det); +std::tuple _trilinear_backward( + const Tensor& grad_out, + const Tensor& i1, + const Tensor& i2, + const Tensor& i3, + IntArrayRef expand1, + IntArrayRef expand2, + IntArrayRef expand3, + IntArrayRef sumdim, + std::array grad_mask); +std::tuple linalg_qr_jvp( + const Tensor& dA, + const Tensor& Q, + const Tensor& R, + const c10::string_view mode); +Tensor linalg_qr_backward( + const Tensor& gQ, + const Tensor& gR, + const Tensor& Q, + const Tensor& R, + const c10::string_view mode); +Tensor eig_backward( + const std::vector& grads, + const Tensor& self, + bool eigenvectors, + const Tensor& lambda, + const Tensor& v); +Tensor linalg_matrix_exp_differential( + const Tensor& self, + const Tensor& grad, + bool adjoint); +Tensor linalg_det_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& det); std::tuple batchnorm_double_backward( - const Tensor & input, - const c10::optional & gamma, - const Tensor & ggI, - const Tensor & ggG, - const Tensor & ggB, - const Tensor & gO, - const c10::optional & running_mean, - const c10::optional & running_var, + const Tensor& input, + const c10::optional& gamma, + const Tensor& ggI, + const Tensor& ggG, + const Tensor& ggB, + const Tensor& gO, + const c10::optional& running_mean, + const c10::optional& running_var, bool training, double eps, - const c10::optional & save_mean, - const c10::optional & save_invstd, - std::array output_mask); -std::tuple _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res); -Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target); -Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim, - bool complex_input, bool complex_output, - bool inverse, IntArrayRef checked_signal_sizes, - int64_t normalization, bool onesided, - IntArrayRef output_sizes); -Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, - bool onesided, int64_t last_dim_size); -Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization); + const c10::optional& save_mean, + const c10::optional& save_invstd, + std::array output_mask); +std::tuple _euclidean_dist_backward( + const Tensor& grad, + const Tensor& x1, + const Tensor& x2, + const Tensor& res); +Tensor kl_div_target_backward( + Tensor grad_output, + Tensor self, + Tensor target, + int64_t reduction, + bool log_target); +Tensor fft_backward( + const Tensor& self, + const Tensor& grad, + int64_t signal_ndim, + bool complex_input, + bool complex_output, + bool inverse, + IntArrayRef checked_signal_sizes, + int64_t normalization, + bool onesided, + IntArrayRef output_sizes); +Tensor fft_r2c_backward( + const Tensor& grad, + IntArrayRef dim, + int64_t normalization, + bool onesided, + int64_t last_dim_size); +Tensor fft_c2r_backward( + const Tensor& grad, + IntArrayRef dim, + int64_t normalization); Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad); std::tuple cholesky_solve_backward( - const Tensor& grad_x, const Tensor& self, - const Tensor& input2, const Tensor& result, const bool upper); + const Tensor& grad_x, + const Tensor& self, + const Tensor& input2, + const Tensor& result, + const bool upper); Tensor cholesky_solve_jvp( const Tensor& X, const Tensor& U, const Tensor& dU, const Tensor& dB, - const bool upper -); + const bool upper); std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -305,55 +710,75 @@ infinitely_differentiable_native_group_norm_backward( int64_t group, double eps, std::array grad_input_mask); -Tensor prelu_jvp(const Tensor& x, const Tensor& dx, const Tensor& w, const Tensor& dw); +Tensor prelu_jvp( + const Tensor& x, + const Tensor& dx, + const Tensor& w, + const Tensor& dw); std::tuple prelu_double_backward( - const Tensor & grad_grad_input, - const Tensor & grad_grad_weight, - const Tensor & grad_out, - const Tensor & input_, - const Tensor & weight_); + const Tensor& grad_grad_input, + const Tensor& grad_grad_weight, + const Tensor& grad_out, + const Tensor& input_, + const Tensor& weight_); Tensor prelu_backward_self_jvp( const Tensor& x, const Tensor& w, const Tensor& dw, const Tensor& g, - const Tensor& dg -); + const Tensor& dg); Tensor prelu_backward_weight_jvp( const Tensor& w, const Tensor& x, const Tensor& dx, const Tensor& g, - const Tensor& dg -); + const Tensor& dg); Tensor gelu_double_backward( - const Tensor & ggI, - const Tensor & gO, - const Tensor & input, + const Tensor& ggI, + const Tensor& gO, + const Tensor& input, c10::string_view approximate); -Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional storage_offset_); -Tensor as_strided_scatter_backward(Tensor grad, TensorGeometry input_geometry, TensorGeometry src_geometry, IntArrayRef sizes, IntArrayRef strides, optional storage_offset); -std::tuple atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array output_mask); +Tensor as_strided_backward( + Tensor grad, + TensorGeometry input_geometry, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset_); +Tensor as_strided_scatter_backward( + Tensor grad, + TensorGeometry input_geometry, + TensorGeometry src_geometry, + IntArrayRef sizes, + IntArrayRef strides, + optional storage_offset); +std::tuple atan2_backward( + const Tensor& grad, + const Tensor& self, + const Tensor& other, + std::array output_mask); std::tuple layer_norm_double_backward( - const Tensor & input, - const c10::optional & gamma, - const Tensor & ggI, - const Tensor & ggG, - const Tensor & ggB, - const Tensor & gO, - const Tensor & save_mean, - const Tensor & save_invstd, + const Tensor& input, + const c10::optional& gamma, + const Tensor& ggI, + const Tensor& ggG, + const Tensor& ggB, + const Tensor& gO, + const Tensor& save_mean, + const Tensor& save_invstd, IntArrayRef normalized_shape, - std::array output_mask); + std::array output_mask); -std::tuple householder_product_backward(const Tensor& grad, const Tensor& result, const Tensor& input, const Tensor& tau); +std::tuple householder_product_backward( + const Tensor& grad, + const Tensor& result, + const Tensor& input, + const Tensor& tau); Tensor householder_product_jvp( const Tensor& dV, const Tensor& dtau, const Tensor& prod, const Tensor& V, - const Tensor& tau -); + const Tensor& tau); std::tuple polar_backward( const Tensor& grad, const Tensor& result); @@ -366,180 +791,215 @@ Tensor i1e_backward( const Tensor& self, const Tensor& result); Tensor linalg_lu_solve_LU( - const Tensor& grad, - const Tensor& LU, - const Tensor& pivots, - const Tensor& X, - const bool left, - const bool adjoint); + const Tensor& grad, + const Tensor& LU, + const Tensor& pivots, + const Tensor& X, + const bool left, + const bool adjoint); Tensor linalg_lu_solve_jvp( - const Tensor& X, - const Tensor& LU, - const Tensor& pivots, - const Tensor& dLU, - const Tensor& dB, - const bool left, - const bool adjoint); + const Tensor& X, + const Tensor& LU, + const Tensor& pivots, + const Tensor& dLU, + const Tensor& dB, + const bool left, + const bool adjoint); std::tuple linalg_solve_backward( - const Tensor& gX, - const Tensor& X, - const Tensor& A, - const Tensor& LU, - const Tensor& pivots, - const bool left, - const bool B_requires_grad); + const Tensor& gX, + const Tensor& X, + const Tensor& A, + const Tensor& LU, + const Tensor& pivots, + const bool left, + const bool B_requires_grad); Tensor linalg_solve_jvp( - const Tensor& dA, - const Tensor& dB, - const Tensor& X, - const Tensor& LU, - const Tensor& pivots, - const bool left, - const bool use_A_T); + const Tensor& dA, + const Tensor& dB, + const Tensor& X, + const Tensor& LU, + const Tensor& pivots, + const bool left, + const bool use_A_T); Tensor lu_unpack_backward( - const Tensor& L_grad, - const Tensor& U_grad, - const int64_t m, - const int64_t n -); + const Tensor& L_grad, + const Tensor& U_grad, + const int64_t m, + const int64_t n); Tensor _det_lu_based_helper_backward( - const Tensor& det_grad, - const Tensor& det, - const Tensor& self, - const Tensor& lu, - const Tensor& pivs -); + const Tensor& det_grad, + const Tensor& det, + const Tensor& self, + const Tensor& lu, + const Tensor& pivs); std::tuple linalg_lstsq_backward( - const Tensor& grad, - const Tensor& A, - const Tensor& B, - const c10::optional rcond, - const c10::optional driver, - const std::array& grad_input_mask -); + const Tensor& grad, + const Tensor& A, + const Tensor& B, + const c10::optional rcond, + const c10::optional driver, + const std::array& grad_input_mask); Tensor linalg_lu_backward( - const Tensor& L_grad, - const Tensor& U_grad, - const Tensor& P, - const Tensor& L, - const Tensor& U, - const bool pivot); + const Tensor& L_grad, + const Tensor& U_grad, + const Tensor& P, + const Tensor& L, + const Tensor& U, + const bool pivot); std::tuple linalg_lu_jvp( - const Tensor& dA, - const Tensor& P, - const Tensor& L, - const Tensor& U, - const bool pivot); + const Tensor& dA, + const Tensor& P, + const Tensor& L, + const Tensor& U, + const bool pivot); Tensor lu_factor_ex_backward( - const Tensor& grad, - const Tensor& LU, - const Tensor& pivs, - const bool pivot -); + const Tensor& grad, + const Tensor& LU, + const Tensor& pivs, + const bool pivot); Tensor lu_factor_ex_jvp( - const Tensor& dX, - const Tensor& LU, - const Tensor& pivs, - const bool pivot -); + const Tensor& dX, + const Tensor& LU, + const Tensor& pivs, + const bool pivot); Tensor batch_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - const c10::optional& running_mean, - const c10::optional& running_var, - const Tensor& saved_mean, const Tensor& saved_invstd, - bool train, - double eps -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + const c10::optional& running_mean, + const c10::optional& running_var, + const Tensor& saved_mean, + const Tensor& saved_invstd, + bool train, + double eps); Tensor layer_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - const Tensor& saved_mean, const Tensor& saved_invstd, - IntArrayRef normalized_shape -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + const Tensor& saved_mean, + const Tensor& saved_invstd, + IntArrayRef normalized_shape); Tensor group_norm_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - const Tensor& saved_mean, const Tensor& saved_invstd, - int64_t groups -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + const Tensor& saved_mean, + const Tensor& saved_invstd, + int64_t groups); Tensor group_norm_mean_jvp( - const Tensor& input_t, - const Tensor& mean_p, - int64_t groups -); + const Tensor& input_t, + const Tensor& mean_p, + int64_t groups); Tensor group_norm_invstd_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& mean_p, const Tensor& invstd_p, - int64_t groups -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& mean_p, + const Tensor& invstd_p, + int64_t groups); Tensor convolution_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups); Tensor _convolution_jvp( - const Tensor& input_p, const Tensor& input_t, - const Tensor& weight_p, const Tensor& weight_t, - const Tensor& bias_p, const Tensor& bias_t, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups, - bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32 -); + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& bias_p, + const Tensor& bias_t, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool transposed, + IntArrayRef output_padding, + int64_t groups, + bool benchmark, + bool deterministic, + bool cudnn_enabled, + bool allow_tf32); -Tensor convolution_backward_jvp_grad_bias(const Tensor& grad_out_t, const Tensor& grad_bias); +Tensor convolution_backward_jvp_grad_bias( + const Tensor& grad_out_t, + const Tensor& grad_bias); Tensor cat_jvp(at::TensorList tensors, int64_t dim); Tensor block_diag_jvp(at::TensorList tensors); Tensor stack_jvp(at::TensorList tensors, int64_t dim); Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim); -Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim); -Tensor evenly_read_jvp(const Tensor& fw_grad, const Tensor & input, const Tensor & value); -Tensor warn_backwards(const Tensor &grad_output); +Tensor gather_with_keepdimed_indices( + const Tensor& input, + int64_t dim, + const Tensor& indices, + bool keepdim); +Tensor evenly_read_jvp( + const Tensor& fw_grad, + const Tensor& input, + const Tensor& value); +Tensor warn_backwards(const Tensor& grad_output); std::tuple _cudnn_convolution_backward( - const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, - at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, bool transposed, int64_t groups, - ::std::array output_mask); + const at::Tensor& self, + const at::Tensor& grad_output, + const at::Tensor& weight, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + bool transposed, + int64_t groups, + ::std::array output_mask); std::tuple scatter_reduce_backward( - const Tensor& grad, - const Tensor& self, - int dim, - const Tensor& index, - const Tensor& src, - c10::string_view reduce, - bool include_self, - const Tensor& result -); + const Tensor& grad, + const Tensor& self, + int dim, + const Tensor& index, + const Tensor& src, + c10::string_view reduce, + bool include_self, + const Tensor& result); -Tensor _to_copy_backward(const Tensor &grad, const c10::TensorOptions &self_options); +Tensor _to_copy_backward( + const Tensor& grad, + const c10::TensorOptions& self_options); std::tuple index_reduce_backward( - const Tensor& grad, - const Tensor& self, - int dim, - const Tensor& index, - const Tensor& source, - c10::string_view reduce, - bool include_self, - const Tensor& result -); + const Tensor& grad, + const Tensor& self, + int dim, + const Tensor& index, + const Tensor& source, + c10::string_view reduce, + bool include_self, + const Tensor& result); } // namespace details } // namespace generated diff --git a/torch/csrc/autograd/InferenceMode.h b/torch/csrc/autograd/InferenceMode.h index 4395082bb33173..f28a12e022e0d0 100644 --- a/torch/csrc/autograd/InferenceMode.h +++ b/torch/csrc/autograd/InferenceMode.h @@ -3,8 +3,10 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using InferenceMode = c10::InferenceMode; -}} +} +} // namespace torch diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index a96fa42abd172a..e39b3a120f0edb 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -10,13 +10,14 @@ using namespace at; -namespace torch { namespace TraceType { +namespace torch { +namespace TraceType { namespace { -Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { +Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) { jit::Value* output = nullptr; - if(torch::jit::tracer::isTracing()) { + if (torch::jit::tracer::isTracing()) { const jit::tracer::TracingState& state = *jit::tracer::getTracingState(); auto& graph = state.graph; if (state.force_outplace && self.storage().use_count() <= 1) { @@ -33,7 +34,8 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)}); jit::tracer::recordSourceLocation(output->node()); } - jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self); + jit::tracer::ensureUniqueIfOutOfPlaced( + "copy_ (possibly due to an assignment)", self); } { @@ -41,7 +43,7 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { self.copy_(src, non_blocking); } - if(torch::jit::tracer::isTracing()) { + if (torch::jit::tracer::isTracing()) { jit::tracer::setOutput(output, self); } return self; @@ -82,7 +84,7 @@ const Tensor& resize_as_( return self; } -Tensor detach(const Tensor & self) { +Tensor detach(const Tensor& self) { torch::jit::Node* node = nullptr; if (jit::tracer::isTracing()) { auto& graph = jit::tracer::getTracingState()->graph; @@ -103,7 +105,7 @@ Tensor detach(const Tensor & self) { return result; } -Tensor & detach_(Tensor & self) { +Tensor& detach_(Tensor& self) { torch::jit::Node* node = nullptr; if (jit::tracer::isTracing()) { auto& graph = jit::tracer::getTracingState()->graph; @@ -126,7 +128,8 @@ Tensor & detach_(Tensor & self) { } // Invariant: -// - Ops registered to DispatchKey::Tracer below must be included in `MANUAL_TRACER` in tools/autograd/gen_variable_type.py +// - Ops registered to DispatchKey::Tracer below must be included in +// `MANUAL_TRACER` in tools/autograd/gen_variable_type.py TORCH_LIBRARY_IMPL(aten, Tracer, m) { m.impl("resize_", resize_); m.impl("resize_as_", resize_as_); @@ -134,7 +137,8 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) { m.impl("detach_", detach_); m.impl("copy_", copy_); - // Skip tracing for the following ops by registering fallthrough kernel explicitly. + // Skip tracing for the following ops by registering fallthrough kernel + // explicitly. m.impl("_backward", CppFunction::makeFallthrough()); m.impl("set_data", CppFunction::makeFallthrough()); m.impl("data", CppFunction::makeFallthrough()); @@ -147,15 +151,14 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) { m.impl("_make_dual", CppFunction::makeFallthrough()); } -} // namespace +} // namespace -}} // namespace torch::TraceType +} // namespace TraceType +} // namespace torch -namespace torch{ -namespace jit{ -void general_trace_function( - const c10::OperatorHandle& op, - Stack* stack) { +namespace torch { +namespace jit { +void general_trace_function(const c10::OperatorHandle& op, Stack* stack) { const auto input_size = op.schema().arguments().size(); const auto output_size = op.schema().returns().size(); @@ -173,7 +176,7 @@ void general_trace_function( const auto& args = op.schema().arguments(); int i = 0; for (auto iter = stack->end() - input_size; iter != stack->end(); - ++iter, ++i) { + ++iter, ++i) { // TODO we need to refactor graph APIs (e.g., addInputs) // appropriately; after that, we can get rid of the giant if-else // block we will clean this tech debt together in the following PRs @@ -217,8 +220,7 @@ void general_trace_function( for (IValue iv : list) { objects.emplace_back(std::move(iv).toObject()); } - tracer::addInputs( - node, args[i].name().c_str(), objects, class_type); + tracer::addInputs(node, args[i].name().c_str(), objects, class_type); } else if (elem_type->kind() == TypeKind::FloatType) { AT_ASSERT(iter->isDoubleList()); // NB: now, tracer doesn't support tracing double list. We add @@ -231,8 +233,7 @@ void general_trace_function( tracer::recordSourceLocation(info[value_index]->node()); } node->addInput( - graph - ->insertNode(graph->createList(FloatType::get(), info)) + graph->insertNode(graph->createList(FloatType::get(), info)) ->output()); } else if (elem_type->kind() == TypeKind::IntType) { AT_ASSERT(iter->isIntList()); @@ -265,7 +266,7 @@ void general_trace_function( tracer::setTracingState(std::move(tracer_state)); int i = 0; for (auto iter = stack->end() - output_size; iter != stack->end(); - ++iter, ++i) { + ++iter, ++i) { const auto& type = op.schema().returns()[i].type(); if (type->isSubtypeOf(*TensorType::get())) { AT_ASSERT(iter->isTensor()); @@ -289,11 +290,10 @@ void general_trace_function( } } } - - } -TORCH_LIBRARY_IMPL(_, Tracer, m){ - m.fallback(CppFunction::makeFromBoxedFunction<&general_trace_function>()); +TORCH_LIBRARY_IMPL(_, Tracer, m) { + m.fallback(CppFunction::makeFromBoxedFunction<&general_trace_function>()); } -}}// namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index d0f7548e51a48d..486e5546b4b4ad 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -1,29 +1,34 @@ -#include +#include +#include +#include #include -#include +#include +#include #include +#include +#include #include #include -#include -#include -#include -#include -#include #include using namespace at; using namespace torch::autograd::generated; -using torch::autograd::CreationMeta; using torch::autograd::as_view; +using torch::autograd::CreationMeta; -namespace torch { namespace autograd { namespace VariableType { +namespace torch { +namespace autograd { +namespace VariableType { -std::vector allTypesForBackends(at::ArrayRef backends) { +std::vector allTypesForBackends( + at::ArrayRef backends) { std::vector res; res.reserve(backends.size()); for (auto p : backends) { - for(const auto s : c10::irange(static_cast(ScalarType::NumOptions))) { - auto& type = getDeprecatedTypeProperties(static_cast(p), static_cast(s)); + for (const auto s : + c10::irange(static_cast(ScalarType::NumOptions))) { + auto& type = getDeprecatedTypeProperties( + static_cast(p), static_cast(s)); res.emplace_back(&type); } } @@ -31,51 +36,64 @@ std::vector allTypesForBackends(at::ArrayRef allCPUTypes() { - return allTypesForBackends({ Backend::CPU, Backend::SparseCPU }); + return allTypesForBackends({Backend::CPU, Backend::SparseCPU}); } C10_EXPORT std::vector allCUDATypes() { at::globalContext().lazyInitCUDA(); - return allTypesForBackends({ Backend::CUDA, Backend::SparseCUDA }); + return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA}); } namespace { -const Variable & checked_cast_variable(const Tensor & t, const char * name, int pos) { +const Variable& checked_cast_variable( + const Tensor& t, + const char* name, + int pos) { if (!t.defined()) { - AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ", - "for argument #", pos, " '", name, "'"); + AT_ERROR( + "Expected a proper Tensor but got None (or an undefined Tensor in C++) ", + "for argument #", + pos, + " '", + name, + "'"); } return t; } -Variable & checked_cast_variable(Tensor & t, const char * name, int pos) { +Variable& checked_cast_variable(Tensor& t, const char* name, int pos) { if (!t.defined()) { - AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ", - "for argument #", pos, " '", name, "'"); + AT_ERROR( + "Expected a proper Tensor but got None (or an undefined Tensor in C++) ", + "for argument #", + pos, + " '", + name, + "'"); } return t; } -} +} // namespace -const Tensor & unpack(const Tensor & t, const char * name, int pos) { +const Tensor& unpack(const Tensor& t, const char* name, int pos) { return checked_cast_variable(t, name, pos); } -Tensor & unpack(Tensor & t, const char * name, int pos) { +Tensor& unpack(Tensor& t, const char* name, int pos) { return checked_cast_variable(t, name, pos); } -Tensor unpack_opt(const Tensor & t, const char * name, int pos) { +Tensor unpack_opt(const Tensor& t, const char* name, int pos) { if (!t.defined()) { return Tensor(); } return unpack(t, name, pos); } -std::vector unpack(at::TensorList tl, const char *name, int pos) { +std::vector unpack(at::TensorList tl, const char* name, int pos) { std::vector ret(tl.size()); for (const auto i : c10::irange(tl.size())) { - const auto &t = tl[i]; + const auto& t = tl[i]; if (!t.defined()) { continue; } @@ -87,21 +105,22 @@ std::vector unpack(at::TensorList tl, const char *name, int pos) { namespace { // Taken from codegened version -Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor & self, int64_t level) { +Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) { auto& self_ = unpack(self, "self", 0); std::shared_ptr grad_fn; - if (compute_requires_grad( self )) { + if (compute_requires_grad(self)) { grad_fn = std::make_shared(); - grad_fn->set_next_edges(collect_next_edges( self )); + grad_fn->set_next_edges(collect_next_edges(self)); } auto result = ([&]() { at::AutoDispatchBelowAutograd guard; - return at::redispatch::_fw_primal(ks & c10::after_autograd_keyset, self_, level); + return at::redispatch::_fw_primal( + ks & c10::after_autograd_keyset, self_, level); })(); if (grad_fn) { - set_history(flatten_tensor_args( result ), grad_fn); + set_history(flatten_tensor_args(result), grad_fn); } if (isFwGradDefined(self)) { // Modified from original codegen @@ -112,15 +131,24 @@ Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor & self, int64_t level) { return result; } -// NB: We need a manual variable type kernel so that set_fw_grad properly detects that _make_dual is -// not a forward-differentiable view +// NB: We need a manual variable type kernel so that set_fw_grad properly +// detects that _make_dual is not a forward-differentiable view // -// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients. -// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is. -// This function is backward differentiable. -Tensor _make_dual(c10::DispatchKeySet ks, const Tensor& primal, const Tensor& tangent, int64_t level) { - TORCH_CHECK(!primal._fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that " - "already has a forward gradient at the same level ", level, " is not supported."); +// This function can be used to create a dual Tensor that holds a tangent to +// compute forward mode gradients. Note that the dual Tensor's primal is a view +// of the given primal and the given tangent is used as-is. This function is +// backward differentiable. +Tensor _make_dual( + c10::DispatchKeySet ks, + const Tensor& primal, + const Tensor& tangent, + int64_t level) { + TORCH_CHECK( + !primal._fw_grad(level).defined(), + "Making a dual Tensor based on a Tensor that " + "already has a forward gradient at the same level ", + level, + " is not supported."); auto& primal_ = unpack(primal, "primal", 0); auto& tangent_ = unpack(tangent, "tangent", 0); std::shared_ptr grad_fn; @@ -132,7 +160,8 @@ Tensor _make_dual(c10::DispatchKeySet ks, const Tensor& primal, const Tensor& ta auto result = ([&]() { at::AutoDispatchBelowAutograd guard; - return at::redispatch::_make_dual(ks & c10::after_autograd_keyset, primal_, tangent_, level); + return at::redispatch::_make_dual( + ks & c10::after_autograd_keyset, primal_, tangent_, level); })(); if (grad_fn) { @@ -140,12 +169,16 @@ Tensor _make_dual(c10::DispatchKeySet ks, const Tensor& primal, const Tensor& ta } TORCH_CHECK(level == 0, "Invalid level given to _make_dual"); - result._set_fw_grad(tangent_, level, /* is_inplace_op */ false); + result._set_fw_grad(tangent_, level, /* is_inplace_op */ false); return result; } // We don't have an outplace copy, so this can't be generated automatically -Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool non_blocking) { +Tensor& copy_( + c10::DispatchKeySet ks, + Tensor& self, + const Tensor& src, + bool non_blocking) { // TODO: once copy is exposed in Declarations.yaml we may be able to bind // it automatically auto& self_ = unpack(self, "self", 0); @@ -161,9 +194,10 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n } { at::AutoDispatchBelowAutograd mode; - at::redispatch::copy_(ks & c10::after_autograd_keyset, self_, src_, non_blocking); + at::redispatch::copy_( + ks & c10::after_autograd_keyset, self_, src_, non_blocking); } - rebase_history(self , std::move(grad_fn)); + rebase_history(self, std::move(grad_fn)); if (isDifferentiableType(self.scalar_type()) && (isFwGradDefined(self) || isFwGradDefined(src))) { @@ -200,7 +234,8 @@ const Tensor& resize_( } { at::AutoDispatchBelowAutograd mode; - at::redispatch::resize_(ks & c10::after_autograd_keyset, self_, size, optional_memory_format); + at::redispatch::resize_( + ks & c10::after_autograd_keyset, self_, size, optional_memory_format); } if (self._fw_grad(/* level */ 0).defined()) { @@ -222,7 +257,11 @@ const Tensor& resize_as_( } { at::AutoDispatchBelowAutograd mode; - at::redispatch::resize_as_(ks & c10::after_autograd_keyset, self_, the_template_, optional_memory_format); + at::redispatch::resize_as_( + ks & c10::after_autograd_keyset, + self_, + the_template_, + optional_memory_format); } // Handle fw grad @@ -232,7 +271,7 @@ const Tensor& resize_as_( return self; } -Tensor detach(c10::DispatchKeySet ks, const Tensor & self) { +Tensor detach(c10::DispatchKeySet ks, const Tensor& self) { auto& self_ = unpack(self, "self", 0); RECORD_FUNCTION("detach", std::vector({self})); auto result = ([&]() { @@ -246,17 +285,18 @@ Tensor detach(c10::DispatchKeySet ks, const Tensor & self) { return result; } -Tensor & detach_(c10::DispatchKeySet ks, Tensor & self) { +Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) { RECORD_FUNCTION("detach_", std::vector({self})); if (self.is_view()) { // See NOTE [ View + Inplace detection ] - AT_ERROR("Can't detach views in-place. Use detach() instead. " - "If you are using DistributedDataParallel (DDP) for training, " - "and gradient_as_bucket_view is set as True, gradients are " - "views of DDP buckets, and hence detach_() cannot be called " - "on these gradients. To fix this error, please refer to the " - "Optimizer.zero_grad() function in torch/optim/optimizer.py " - "as the solution."); + AT_ERROR( + "Can't detach views in-place. Use detach() instead. " + "If you are using DistributedDataParallel (DDP) for training, " + "and gradient_as_bucket_view is set as True, gradients are " + "views of DDP buckets, and hence detach_() cannot be called " + "on these gradients. To fix this error, please refer to the " + "Optimizer.zero_grad() function in torch/optim/optimizer.py " + "as the solution."); } // I think the choice here is conservative. In principle, doing // an in-place detach should give us the ability to just clear @@ -277,104 +317,163 @@ Tensor & detach_(c10::DispatchKeySet ks, Tensor & self) { // (1) CompositeImplicitAutograd kernels // (2) Autograd kernels // (3) CompositeExplicitAutograd kernels and additionally Autograd kernels -// The reason for (3) is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU -// kernels) will skip picking up CompositeImplicitAutograd kernels for Autograd, so we register them to both -// CompositeExplicitAutograd and Autograd instead. See +// The reason for (3) is that ops that also use dispatch (e.g. register +// CPU/CUDA/QuantizedCPU kernels) will skip picking up CompositeImplicitAutograd +// kernels for Autograd, so we register them to both CompositeExplicitAutograd +// and Autograd instead. See // https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword // for more details. // Invariant: -// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py. +// - Ops registered to CompositeImplicitAutograd or CompositeExplicitAutograd +// below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py. // and they have manual_kernel_registration=True in native_functions.yaml. -// - Ops registered to DispatchKey::Autograd below must be included in `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py +// - Ops registered to DispatchKey::Autograd below must be included in +// `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py TORCH_LIBRARY_IMPL(aten, Autograd, m) { - m.impl("resize_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_))); - m.impl("resize_as_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_))); - m.impl("detach", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach))); - m.impl("detach_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_))); - m.impl("copy_", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_))); - m.impl("_fw_primal", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal))); - m.impl("_make_dual", torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual))); - + m.impl( + "resize_", + torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::resize_))); + m.impl( + "resize_as_", + torch::dispatch( + DispatchKey::Autograd, TORCH_FN(VariableType::resize_as_))); + m.impl( + "detach", + torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach))); + m.impl( + "detach_", + torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::detach_))); + m.impl( + "copy_", + torch::dispatch(DispatchKey::Autograd, TORCH_FN(VariableType::copy_))); + m.impl( + "_fw_primal", + torch::dispatch( + DispatchKey::Autograd, TORCH_FN(VariableType::_fw_primal))); + m.impl( + "_make_dual", + torch::dispatch( + DispatchKey::Autograd, TORCH_FN(VariableType::_make_dual))); } -} // namespace -}} // namespace autograd::VariableType +} // namespace +} // namespace VariableType +} // namespace autograd namespace ADInplaceOrView { - #define CREATION_META_DEFINITION InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE) +#define CREATION_META_DEFINITION \ + InferenceMode::is_enabled() \ + ? CreationMeta::INFERENCE_MODE \ + : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \ + : CreationMeta::NO_GRAD_MODE) - Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool non_blocking) { - { - at::AutoDispatchBelowADInplaceOrView guard; - at::redispatch::copy_(ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking); - } - torch::autograd::increment_version(self); - return self; +Tensor& copy_( + c10::DispatchKeySet ks, + Tensor& self, + const Tensor& src, + bool non_blocking) { + { + at::AutoDispatchBelowADInplaceOrView guard; + at::redispatch::copy_( + ks & c10::after_ADInplaceOrView_keyset, self, src, non_blocking); } + torch::autograd::increment_version(self); + return self; +} - Tensor detach(c10::DispatchKeySet ks, const Tensor & self) { - auto out = ([&]() { - at::AutoDispatchBelowADInplaceOrView guard; - return at::_ops::detach::redispatch(ks & c10::after_ADInplaceOrView_keyset, self); - })(); - // NB: we can't make detach() a normal view operator because the codegen generates - // allow_tensor_metadata_change = True for them. In the future we should have an - // option for this in the codegen. - std::function func=nullptr; - auto result = as_view(/* base */ self, /* output */ out, /* is_bw_differentiable */ false, - /* is_fw_differentiable */ false, /* view_func */ func, - /* creation_meta */ CreationMeta::DEFAULT, - /*allow_tensor_metadata_change=*/false); - - return result; - } +Tensor detach(c10::DispatchKeySet ks, const Tensor& self) { + auto out = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::_ops::detach::redispatch( + ks & c10::after_ADInplaceOrView_keyset, self); + })(); + // NB: we can't make detach() a normal view operator because the codegen + // generates allow_tensor_metadata_change = True for them. In the future we + // should have an option for this in the codegen. + std::function func = nullptr; + auto result = as_view( + /* base */ self, + /* output */ out, + /* is_bw_differentiable */ false, + /* is_fw_differentiable */ false, + /* view_func */ func, + /* creation_meta */ CreationMeta::DEFAULT, + /*allow_tensor_metadata_change=*/false); - Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor & self, int64_t level) { - auto tmp = ([&]() { - at::AutoDispatchBelowADInplaceOrView guard; - return at::alias(self); - })(); - std::function func=nullptr; - if (!self.unsafeGetTensorImpl()->support_as_strided()) { - auto size_vec = self.sizes().vec(); - func = [=](const at::Tensor& input_base) { - return input_base.view(size_vec); - }; - } - auto result = as_view(/* base */ self, /* output */ tmp, /* is_bw_differentiable */ true, - /* is_fw_differentiable */ false, /* view_func */ func, /* creation_meta */ CREATION_META_DEFINITION); + return result; +} - return result; +Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) { + auto tmp = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::alias(self); + })(); + std::function func = nullptr; + if (!self.unsafeGetTensorImpl()->support_as_strided()) { + auto size_vec = self.sizes().vec(); + func = [=](const at::Tensor& input_base) { + return input_base.view(size_vec); + }; } + auto result = as_view( + /* base */ self, + /* output */ tmp, + /* is_bw_differentiable */ true, + /* is_fw_differentiable */ false, + /* view_func */ func, + /* creation_meta */ CREATION_META_DEFINITION); - // NB: This does not redispatch any further - Tensor _make_dual(c10::DispatchKeySet ks, const Tensor & primal, const Tensor & tangent, int64_t level) { - auto tmp = ([&]() { - at::AutoDispatchBelowADInplaceOrView guard; - return at::alias(primal); - })(); - std::function func=nullptr; - if (!primal.unsafeGetTensorImpl()->support_as_strided()) { - auto size_vec = primal.sizes().vec(); - func = [=](const at::Tensor& input_base) { - return input_base.view(size_vec); - }; - } - auto result = as_view(/* base */ primal, /* output */ tmp, /* is_bw_differentiable */ true, - /* is_fw_differentiable */ false, /* view_func */ func, /* creation_meta */ CREATION_META_DEFINITION); + return result; +} - return result; +// NB: This does not redispatch any further +Tensor _make_dual( + c10::DispatchKeySet ks, + const Tensor& primal, + const Tensor& tangent, + int64_t level) { + auto tmp = ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::alias(primal); + })(); + std::function func = nullptr; + if (!primal.unsafeGetTensorImpl()->support_as_strided()) { + auto size_vec = primal.sizes().vec(); + func = [=](const at::Tensor& input_base) { + return input_base.view(size_vec); + }; } + auto result = as_view( + /* base */ primal, + /* output */ tmp, + /* is_bw_differentiable */ true, + /* is_fw_differentiable */ false, + /* view_func */ func, + /* creation_meta */ CREATION_META_DEFINITION); - namespace { - TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) { - m.impl("copy_", torch::dispatch(DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_))); - m.impl("detach", torch::dispatch(DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach))); - m.impl("_fw_primal", torch::dispatch(DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal))); - m.impl("_make_dual", torch::dispatch(DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual))); + return result; +} - } - } // namespace +namespace { +TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) { + m.impl( + "copy_", + torch::dispatch( + DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::copy_))); + m.impl( + "detach", + torch::dispatch( + DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::detach))); + m.impl( + "_fw_primal", + torch::dispatch( + DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_fw_primal))); + m.impl( + "_make_dual", + torch::dispatch( + DispatchKey::ADInplaceOrView, TORCH_FN(ADInplaceOrView::_make_dual))); +} +} // namespace } // namespace ADInplaceOrView } // namespace torch diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index b060ea769a6d3f..b6c92a396ec03c 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -2,16 +2,16 @@ #include -#include -#include #include +#include +#include +#include #include #include -#include -#include +#include -#include #include +#include #include #include @@ -30,14 +30,14 @@ #endif #endif -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // The requires_grad argument is used to know if the inplace operation needs // gradient to be setup for it. -// In particular, we can have tensor.requires_grad() != requires_grad when writing -// a Tensor that requires gradients inplace into a Tensor that does not require gradients: -// a = torch.rand(2) -// b = torch.rand(2, requires_grad=True) +// In particular, we can have tensor.requires_grad() != requires_grad when +// writing a Tensor that requires gradients inplace into a Tensor that does not +// require gradients: a = torch.rand(2) b = torch.rand(2, requires_grad=True) // a.copy_(b) inline void check_inplace(const at::Tensor& tensor, bool requires_grad) { if (requires_grad && GradMode::is_enabled()) { @@ -46,13 +46,13 @@ inline void check_inplace(const at::Tensor& tensor, bool requires_grad) { // This can throw or warn handle_view_on_rebase(diff_view_meta); if (tensor.requires_grad() && tensor._base().is_leaf()) { - AT_ERROR( + AT_ERROR( "a view of a leaf Variable that requires grad is being used in an in-place operation."); } } if (tensor.requires_grad() && tensor.is_leaf()) { AT_ERROR( - "a leaf Variable that requires grad is being used in an in-place operation."); + "a leaf Variable that requires grad is being used in an in-place operation."); } } } @@ -65,19 +65,26 @@ inline void check_inplace(const at::TensorList tensors, bool requires_grad) { inline void throw_error_out_requires_grad(const char* name) { AT_ERROR( - name, "(): functions with out=... arguments don't support automatic differentiation, " + name, + "(): functions with out=... arguments don't support automatic differentiation, " "but one of the arguments requires grad."); } -inline void throw_error_for_complex_autograd(const at::Tensor& tensor, const char* name) { +inline void throw_error_for_complex_autograd( + const at::Tensor& tensor, + const char* name) { if (tensor.requires_grad()) { - TORCH_CHECK(!tensor.is_complex(), name, - " does not support automatic differentiation for outputs with complex dtype."); + TORCH_CHECK( + !tensor.is_complex(), + name, + " does not support automatic differentiation for outputs with complex dtype."); } } -inline void throw_error_for_complex_autograd(const at::TensorList& tensorlist, const char* name) { - for (const auto& tensor: tensorlist) { +inline void throw_error_for_complex_autograd( + const at::TensorList& tensorlist, + const char* name) { + for (const auto& tensor : tensorlist) { throw_error_for_complex_autograd(tensor, name); } } @@ -91,7 +98,9 @@ inline void rebase_history(Variable& var, std::shared_ptr grad_fn) { } } -inline void rebase_history(std::vector&& vars, std::shared_ptr grad_fn) { +inline void rebase_history( + std::vector&& vars, + std::shared_ptr grad_fn) { if (grad_fn) { for (auto& var : vars) { if (var.defined()) { @@ -106,23 +115,27 @@ inline void rebase_history(std::vector&& vars, std::shared_ptr g } } -inline void increment_version(const at::Tensor & t) { +inline void increment_version(const at::Tensor& t) { impl::bump_version(t); } struct Flatten : IterArgs { Flatten(variable_list& out) : out(out) {} variable_list& out; - void operator()(const at::Tensor& x) { out.emplace_back(x); } + void operator()(const at::Tensor& x) { + out.emplace_back(x); + } void operator()(const c10::optional& x) { - if (x.has_value()) out.emplace_back(x.value()); + if (x.has_value()) + out.emplace_back(x.value()); } void operator()(at::ArrayRef xs) { out.insert(out.end(), xs.begin(), xs.end()); } }; -template inline variable_list flatten_tensor_args(Args&&... args) { +template +inline variable_list flatten_tensor_args(Args&&... args) { variable_list out; out.reserve(count_tensors(std::forward(args)...)); Flatten(out).apply(std::forward(args)...); @@ -130,29 +143,48 @@ template inline variable_list flatten_tensor_args(Args&&... ar } // See NOTE [ Autograd View Variables ] for details. -inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bool is_bw_differentiable, - bool is_fw_differentiable, std::function view_func=nullptr, - CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) { +inline at::Tensor as_view( + const at::Tensor& base, + const at::Tensor& tensor, + bool is_bw_differentiable, + bool is_fw_differentiable, + std::function view_func = nullptr, + CreationMeta creation_meta = CreationMeta::DEFAULT, + bool allow_tensor_metadata_change = true) { // Note [View of inference tensor] // For inference tensor this code can only be hit outside InferenceMode // since ADInplaceOrView is in the default_included_set. // If Inplace and View were separate dispatch keys we can just put Inplace // in the default_included_set, so that view ops on inference tensor doesn't // have to go through as_view even outside InferenceMode. - if (base.is_inference()) return tensor; + if (base.is_inference()) + return tensor; auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base); - // To speed up the most common case, we specially handle when both the forward and backward - // view infos are the same, and so a single shared ViewInfo can be used for both of them. - if ((!diff_view_meta || diff_view_meta->shared_view_info()) && is_bw_differentiable && is_fw_differentiable) { + // To speed up the most common case, we specially handle when both the forward + // and backward view infos are the same, and so a single shared ViewInfo can + // be used for both of them. + if ((!diff_view_meta || diff_view_meta->shared_view_info()) && + is_bw_differentiable && is_fw_differentiable) { if (diff_view_meta) { - creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); - return make_variable_differentiable_view(tensor, diff_view_meta->get_backward_view().chain(base, tensor, view_func), - c10::nullopt, /*shared_view_info*/ true, creation_meta, allow_tensor_metadata_change); + creation_meta = propagate_creation_meta( + diff_view_meta->get_creation_meta(), creation_meta); + return make_variable_differentiable_view( + tensor, + diff_view_meta->get_backward_view().chain(base, tensor, view_func), + c10::nullopt, + /*shared_view_info*/ true, + creation_meta, + allow_tensor_metadata_change); } else { - return make_variable_differentiable_view(tensor, ViewInfo(base, view_func), c10::nullopt, /*shared_view_info*/ true, - creation_meta, allow_tensor_metadata_change); + return make_variable_differentiable_view( + tensor, + ViewInfo(base, view_func), + c10::nullopt, + /*shared_view_info*/ true, + creation_meta, + allow_tensor_metadata_change); } } @@ -168,8 +200,9 @@ inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bo new_bw_info = ViewInfo(base, view_func); } } else { - TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, - "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); + TORCH_CHECK( + creation_meta == CreationMeta::DEFAULT, + "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); } if (is_fw_differentiable) { @@ -184,45 +217,68 @@ inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bo if (is_fw_differentiable || is_bw_differentiable) { if (diff_view_meta && diff_view_meta->has_bw_view()) { - creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); + creation_meta = propagate_creation_meta( + diff_view_meta->get_creation_meta(), creation_meta); } - return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info), /*shared_view_info*/ false, - creation_meta, allow_tensor_metadata_change); + return make_variable_differentiable_view( + tensor, + std::move(new_bw_info), + std::move(new_fw_info), + /*shared_view_info*/ false, + creation_meta, + allow_tensor_metadata_change); } else { - return make_variable_non_differentiable_view(base, tensor, allow_tensor_metadata_change); + return make_variable_non_differentiable_view( + base, tensor, allow_tensor_metadata_change); } } // See NOTE [ Autograd View Variables ] for details. -inline std::vector as_view(const at::Tensor & base, std::vector& tensors, bool is_bw_differentiable, - bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) { +inline std::vector as_view( + const at::Tensor& base, + std::vector& tensors, + bool is_bw_differentiable, + bool is_fw_differentiable, + CreationMeta creation_meta = CreationMeta::DEFAULT) { // See Note [View of inference tensor] - if (base.is_inference()) return tensors; + if (base.is_inference()) + return tensors; auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base); - // Special case when view info can be shared for forward and backward differentiable views - if ((!diff_view_meta || diff_view_meta->shared_view_info()) && is_bw_differentiable && is_fw_differentiable) { + // Special case when view info can be shared for forward and backward + // differentiable views + if ((!diff_view_meta || diff_view_meta->shared_view_info()) && + is_bw_differentiable && is_fw_differentiable) { c10::optional new_shared_info; if (diff_view_meta) { - // TODO: fix fb internal use-case so that it doesn't trigger this internal assert when the base is not a view. - // For now, we only do that same (wrong) thing as the old code which is to only check when the inputs is a - // backward differentiable view + // TODO: fix fb internal use-case so that it doesn't trigger this internal + // assert when the base is not a view. For now, we only do that same + // (wrong) thing as the old code which is to only check when the inputs is + // a backward differentiable view if (diff_view_meta->has_bw_view()) { - TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::NO_GRAD_MODE || creation_meta == CreationMeta::INFERENCE_MODE || - creation_meta == CreationMeta::MULTI_OUTPUT_NODE, + TORCH_INTERNAL_ASSERT( + creation_meta == CreationMeta::NO_GRAD_MODE || + creation_meta == CreationMeta::INFERENCE_MODE || + creation_meta == CreationMeta::MULTI_OUTPUT_NODE, "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive."); } - creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); + creation_meta = propagate_creation_meta( + diff_view_meta->get_creation_meta(), creation_meta); const auto& base_bw_info = diff_view_meta->get_backward_view(); new_shared_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr); } else { new_shared_info = ViewInfo(base, /* view_func */ nullptr); } - for(at::Tensor &tensor : tensors) { + for (at::Tensor& tensor : tensors) { if (is_fw_differentiable || is_bw_differentiable) { - tensor = make_variable_differentiable_view(tensor, new_shared_info, c10::nullopt, /*shared_view_info*/ true, creation_meta); + tensor = make_variable_differentiable_view( + tensor, + new_shared_info, + c10::nullopt, + /*shared_view_info*/ true, + creation_meta); } else { tensor = make_variable_non_differentiable_view(base, tensor); } @@ -237,31 +293,37 @@ inline std::vector as_view(const at::Tensor & base, std::vectorhas_bw_view()) { const auto& base_bw_info = diff_view_meta->get_backward_view(); - // TODO: fix fb internal use-case so that it doesn't trigger this internal assert when the base is not a view. - // In this code, the assert should be outside of the if statement. - TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::NO_GRAD_MODE || creation_meta == CreationMeta::INFERENCE_MODE || - creation_meta == CreationMeta::MULTI_OUTPUT_NODE, + // TODO: fix fb internal use-case so that it doesn't trigger this internal + // assert when the base is not a view. In this code, the assert should be + // outside of the if statement. + TORCH_INTERNAL_ASSERT( + creation_meta == CreationMeta::NO_GRAD_MODE || + creation_meta == CreationMeta::INFERENCE_MODE || + creation_meta == CreationMeta::MULTI_OUTPUT_NODE, "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive."); - // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are - // not allowed + // It is ok to create a ViewInfo where only the base is correct in this + // case as inplace operations on such views are not allowed new_bw_info = ViewInfo(base_bw_info.base_, /* view_func */ nullptr); } else { new_bw_info = ViewInfo(base, /* view_func */ nullptr); } } else { - TORCH_CHECK(creation_meta == CreationMeta::DEFAULT, - "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); + TORCH_CHECK( + creation_meta == CreationMeta::DEFAULT, + "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT"); } if (is_fw_differentiable) { // Check if base is a forward differentiabble view auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base); if (diff_view_meta && diff_view_meta->has_fw_view()) { const auto& base_fw_info = diff_view_meta->get_forward_view(); - TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::NO_GRAD_MODE || creation_meta == CreationMeta::INFERENCE_MODE || - creation_meta == CreationMeta::MULTI_OUTPUT_NODE, + TORCH_INTERNAL_ASSERT( + creation_meta == CreationMeta::NO_GRAD_MODE || + creation_meta == CreationMeta::INFERENCE_MODE || + creation_meta == CreationMeta::MULTI_OUTPUT_NODE, "Functions that result multiple view must have a creation meta reflecting this behavior or more restrictive."); - // It is ok to create a ViewInfo where only the base is correct in this case as inplace operations on such views are - // not allowed + // It is ok to create a ViewInfo where only the base is correct in this + // case as inplace operations on such views are not allowed new_fw_info = ViewInfo(base_fw_info.base_, /* view_func */ nullptr); } else { new_fw_info = ViewInfo(base, /* view_func */ nullptr); @@ -271,12 +333,18 @@ inline std::vector as_view(const at::Tensor & base, std::vector diff_view_meta auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base); - creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); + creation_meta = propagate_creation_meta( + diff_view_meta->get_creation_meta(), creation_meta); } - for(at::Tensor &tensor : tensors) { + for (at::Tensor& tensor : tensors) { if (is_fw_differentiable || is_bw_differentiable) { - tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, /*shared_view_info*/ false, creation_meta); + tensor = make_variable_differentiable_view( + tensor, + new_bw_info, + new_fw_info, + /*shared_view_info*/ false, + creation_meta); } else { tensor = make_variable_non_differentiable_view(base, tensor); } @@ -284,20 +352,34 @@ inline std::vector as_view(const at::Tensor & base, std::vector& tensor, const char* name, const char* fn_name="") { +inline void check_no_requires_grad( + const c10::optional& tensor, + const char* name, + const char* fn_name = "") { if (tensor.has_value()) { check_no_requires_grad(*tensor, name, fn_name); } } -inline void check_no_requires_grad(at::TensorList tensors, const char* name, const char* fn_name="") { +inline void check_no_requires_grad( + at::TensorList tensors, + const char* name, + const char* fn_name = "") { // GradMode check is expensive, so check it only once for TensorLists if (!GradMode::is_enabled()) { return; @@ -307,7 +389,10 @@ inline void check_no_requires_grad(at::TensorList tensors, const char* name, con } } -inline void check_no_requires_grad(const c10::List>& tensors, const char* name, const char* fn_name="") { +inline void check_no_requires_grad( + const c10::List>& tensors, + const char* name, + const char* fn_name = "") { // GradMode check is expensive, so check it only once for TensorLists if (!GradMode::is_enabled()) { return; @@ -320,20 +405,24 @@ inline void check_no_requires_grad(const c10::List>& t } // Assumed that saved tensor lists are never inplace outputs -inline std::vector make_saved_variable_list(at::TensorList tensors) { +inline std::vector make_saved_variable_list( + at::TensorList tensors) { return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable { - return SavedVariable{tensor, false /* is output */}; }); + return SavedVariable{tensor, false /* is output */}; + }); } // Assumed that saved tensor lists are never inplace outputs -inline std::vector make_saved_variable_list(const c10::List>& tensors) { - return fmap(tensors, [](const c10::optional& tensor) -> SavedVariable { - if (tensor.has_value()) { - return SavedVariable{*tensor, false /* is output */}; - } else { - return SavedVariable{at::Tensor(), false /* is output */}; - } - }); +inline std::vector make_saved_variable_list( + const c10::List>& tensors) { + return fmap( + tensors, [](const c10::optional& tensor) -> SavedVariable { + if (tensor.has_value()) { + return SavedVariable{*tensor, false /* is output */}; + } else { + return SavedVariable{at::Tensor(), false /* is output */}; + } + }); } inline std::vector> to_args_sizes(at::TensorList tensors) { @@ -344,7 +433,8 @@ inline std::vector> to_args_sizes(at::TensorList tensors) { return args_sizes; } -inline std::vector to_args_scalartypes(at::TensorList tensors) { +inline std::vector to_args_scalartypes( + at::TensorList tensors) { std::vector args_scalartypes(tensors.size()); for (const auto i : c10::irange(tensors.size())) { args_scalartypes[i] = tensors[i].scalar_type(); @@ -352,4 +442,5 @@ inline std::vector to_args_scalartypes(at::TensorList tensors) return args_scalartypes; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/anomaly_mode.h b/torch/csrc/autograd/anomaly_mode.h index 95251dcbb8e4ac..a741acdd38af04 100644 --- a/torch/csrc/autograd/anomaly_mode.h +++ b/torch/csrc/autograd/anomaly_mode.h @@ -1,10 +1,11 @@ #pragma once -#include -#include #include +#include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // forward declaration of Node from function.h struct Node; @@ -17,7 +18,7 @@ struct TORCH_API AnomalyMode { _enabled = enabled; } -private: + private: static bool _enabled; }; @@ -60,4 +61,5 @@ struct TORCH_API AnomalyMetadata { std::shared_ptr parent_; }; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp index 471c0cb8944c7d..4a2a51235b9e8e 100644 --- a/torch/csrc/autograd/autograd.cpp +++ b/torch/csrc/autograd/autograd.cpp @@ -14,8 +14,9 @@ namespace autograd { // NB: This code duplicates existing logic at torch/autograd/__init__.py and // torch._C._EngineBase.run_backward in torch/csrc/autograd/python_engine.cpp // This is a purely C++ API for Autograd without any dependencies on python -// it can be exposed in PyTorch C++ API and TorchScript. We will need to maintain -// the logic equality of this file and the python file together if one changes. +// it can be exposed in PyTorch C++ API and TorchScript. We will need to +// maintain the logic equality of this file and the python file together if one +// changes. // TODO: Make the Python API above to just call this C++ API. variable_list _make_grads( const variable_list& outputs, @@ -30,13 +31,18 @@ variable_list _make_grads( TORCH_CHECK( output.numel() == 1, "grad can be implicitly created only for scalar outputs"); - new_grads.emplace_back(at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); + new_grads.emplace_back( + at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); } } } else { TORCH_CHECK( num_tensors == num_gradients, - "got ", num_tensors, " tensors and ", num_gradients, " gradients"); + "got ", + num_tensors, + " tensors and ", + num_gradients, + " gradients"); for (const auto i : c10::irange(outputs.size())) { const Variable& output = outputs[i]; const Variable& grad_output = grad_outputs[i]; @@ -45,16 +51,22 @@ variable_list _make_grads( TORCH_CHECK( output.numel() == 1, "grad can be implicitly created only for scalar outputs"); - new_grads.emplace_back(at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); + new_grads.emplace_back( + at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); } } else { TORCH_CHECK( - grad_output.is_complex() == output.is_complex(), - "For complex Tensors, both grad_output and output are required ", - "to have the same dtype. Mismatch in dtype: grad_output[", - grad_output, "] has a dtype of ", grad_output.scalar_type(), - " and output[", output, "] has a dtype of ", output.scalar_type(), - "."); + grad_output.is_complex() == output.is_complex(), + "For complex Tensors, both grad_output and output are required ", + "to have the same dtype. Mismatch in dtype: grad_output[", + grad_output, + "] has a dtype of ", + grad_output.scalar_type(), + " and output[", + output, + "] has a dtype of ", + output.scalar_type(), + "."); // grad output is defined, just append to the new_grads new_grads.emplace_back(grad_output); } @@ -73,12 +85,14 @@ variable_list run_backward( size_t num_tensors = outputs.size(); edge_list roots; roots.reserve(num_tensors); - for(const auto i : c10::irange(num_tensors)) { + for (const auto i : c10::irange(num_tensors)) { const Variable& output = outputs[i]; auto gradient_edge = impl::gradient_edge(output); TORCH_CHECK( gradient_edge.function, - "element ", i, " of tensors does not require grad and does not have a grad_fn"); + "element ", + i, + " of tensors does not require grad and does not have a grad_fn"); roots.push_back(std::move(gradient_edge)); } @@ -109,7 +123,12 @@ variable_list run_backward( } variable_list grad_inputs = Engine::get_default_engine().execute( - roots, grad_outputs, keep_graph, create_graph, accumulate_grad, output_edges); + roots, + grad_outputs, + keep_graph, + create_graph, + accumulate_grad, + output_edges); // check if grad_inputs contains None or not base on the allow_unused flag if (!inputs.empty() && !allow_unused) { size_t num_inputs = inputs.size(); @@ -135,7 +154,14 @@ void backward( if (!retain_graph) { retain_graph = create_graph; } - run_backward(tensors, gradients, retain_graph.value(), create_graph, inputs, /*allow_unused=*/true, /*accumulate_grad=*/true); + run_backward( + tensors, + gradients, + retain_graph.value(), + create_graph, + inputs, + /*allow_unused=*/true, + /*accumulate_grad=*/true); } variable_list grad( @@ -150,10 +176,15 @@ variable_list grad( retain_graph = create_graph; } return run_backward( - outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false); + outputs, + gradients, + retain_graph.value(), + create_graph, + inputs, + allow_unused, + /*accumulate_grad=*/false); } - namespace forward_ad { uint64_t enter_dual_level() { diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index c4362ef1ab1efb..c85ff99b46142e 100644 --- a/torch/csrc/autograd/autograd.h +++ b/torch/csrc/autograd/autograd.h @@ -8,35 +8,43 @@ namespace autograd { /// Computes the sum of gradients of given tensors with respect to graph leaves. /// /// The graph is differentiated using the chain rule. If any of ``tensors`` -/// are non-scalar (i.e. their data has more than one element) and require gradient, -/// then the Jacobian-vector product would be computed, in this case the function -/// additionally requires specifying `grad_tensors`. It should be a sequence of -/// matching length, that contains the "vector" in the Jacobian-vector product, -/// usually the gradient of the differentiated function w.r.t. corresponding tensors +/// are non-scalar (i.e. their data has more than one element) and require +/// gradient, then the Jacobian-vector product would be computed, in this case +/// the function additionally requires specifying `grad_tensors`. It should be a +/// sequence of matching length, that contains the "vector" in the +/// Jacobian-vector product, usually the gradient of the differentiated function +/// w.r.t. corresponding tensors /// (`torch::Tensor()` is an acceptable value for all tensors that don't need /// gradient tensors). /// -/// This function accumulates gradients in the leaves - you might need to zero them -/// before calling it. +/// This function accumulates gradients in the leaves - you might need to zero +/// them before calling it. /// /// \param tensors Tensors of which the derivative will be computed. -/// \param grad_tensors The "vector" in the Jacobian-vector product, usually gradients -/// w.r.t. each element of corresponding tensors. `torch::Tensor()` values can be -/// specified for scalar Tensors or ones that don't require grad. If a `torch::Tensor()` value -/// would be acceptable for all grad_tensors, then this argument is optional. -/// \param retain_graph If `false`, the graph used to compute the grad will be freed. -/// Note that in nearly all cases setting this option to `true` is not needed -/// and often can be worked around in a much more efficient way. Defaults to the -/// value of `create_graph`. -/// \param create_graph If `true`, graph of the derivative will be constructed, allowing +/// \param grad_tensors The "vector" in the Jacobian-vector product, usually +/// gradients +/// w.r.t. each element of corresponding tensors. `torch::Tensor()` values +/// can be specified for scalar Tensors or ones that don't require grad. If +/// a `torch::Tensor()` value would be acceptable for all grad_tensors, then +/// this argument is optional. +/// \param retain_graph If `false`, the graph used to compute the grad will be +/// freed. +/// Note that in nearly all cases setting this option to `true` is not +/// needed and often can be worked around in a much more efficient way. +/// Defaults to the value of `create_graph`. +/// \param create_graph If `true`, graph of the derivative will be constructed, +/// allowing /// to compute higher order derivative products. Defaults to `false`. /// \param inputs Inputs w.r.t. which the gradient will be accumulated into -/// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, the gradient -/// is accumulated into all the leaf Tensors that were used to compute param `tensors`. +/// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, +/// the gradient is accumulated into all the leaf Tensors that were used to +/// compute param `tensors`. // When inputs are provided and a given input is not a leaf, -// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). -// It is an implementation detail on which the user should not rely. -// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. +// the current implementation will call its grad_fn (even though it is not +// strictly needed to get this gradients). It is an implementation detail +// on which the user should not rely. See +// https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for +// more details. TORCH_API void backward( const variable_list& tensors, const variable_list& grad_tensors = {}, @@ -44,7 +52,8 @@ TORCH_API void backward( bool create_graph = false, const variable_list& inputs = {}); -/// Computes and returns the sum of gradients of outputs with respect to the inputs. +/// Computes and returns the sum of gradients of outputs with respect to the +/// inputs. /// /// ``grad_outputs`` should be a sequence of length matching ``output`` /// containing the "vector" in Jacobian-vector product, usually the pre-computed @@ -55,13 +64,14 @@ TORCH_API void backward( /// \param inputs Inputs w.r.t. which the gradient will be /// returned (and not accumulated into ``at::Tensor::grad``). /// \param grad_outputs The "vector" in the Jacobian-vector product. -/// Usually gradients w.r.t. each output. `torch::Tensor()` values can be specified for scalar -/// Tensors or ones that don't require grad. If a `torch::Tensor()` value would be acceptable -/// for all grad_tensors, then this argument is optional. Default: `{}`. +/// Usually gradients w.r.t. each output. `torch::Tensor()` values can be +/// specified for scalar Tensors or ones that don't require grad. If a +/// `torch::Tensor()` value would be acceptable for all grad_tensors, then +/// this argument is optional. Default: `{}`. /// \param retain_graph If ``false``, the graph used to compute the grad -/// will be freed. Note that in nearly all cases setting this option to ``true`` -/// is not needed and often can be worked around in a much more efficient -/// way. Defaults to the value of ``create_graph``. +/// will be freed. Note that in nearly all cases setting this option to +/// ``true`` is not needed and often can be worked around in a much more +/// efficient way. Defaults to the value of ``create_graph``. /// \param create_graph If ``true``, graph of the derivative will /// be constructed, allowing to compute higher order derivative products. /// Default: ``false``. @@ -78,16 +88,17 @@ TORCH_API variable_list grad( namespace forward_ad { -/// Creates a new dual level and returns its index. This level index should then be used to call -/// into the other functions below. -/// This API supports entering a new level before the previous one is exited. We call them nested -/// forward AD levels. These can be used to compute higher order derivatives. +/// Creates a new dual level and returns its index. This level index should then +/// be used to call into the other functions below. This API supports entering a +/// new level before the previous one is exited. We call them nested forward AD +/// levels. These can be used to compute higher order derivatives. TORCH_API uint64_t enter_dual_level(); -/// Exits the given level. This will clear up all the gradients from this level and all dual Tensors -/// that had gradients for this level will become regular Tensors again. -/// This function can only be used to exit the innermost nesting level and so exiting must happen in -/// reverse order compared to the entering that was done with the function above. +/// Exits the given level. This will clear up all the gradients from this level +/// and all dual Tensors that had gradients for this level will become regular +/// Tensors again. This function can only be used to exit the innermost nesting +/// level and so exiting must happen in reverse order compared to the entering +/// that was done with the function above. TORCH_API void exit_dual_level(uint64_t level); } // namespace forward_ad diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index c84c613050acb1..e3c0db4eb6a8f6 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -16,13 +16,15 @@ namespace autograd { using at::Tensor; // [Forward Grad View/inplace] -// It is important to us to allow view and inplace to work with dual Tensors. These operations -// should either compute the right gradient or raise a user-friendly error. +// It is important to us to allow view and inplace to work with dual Tensors. +// These operations should either compute the right gradient or raise a +// user-friendly error. // The basic case where all Tensors are dual Tensors is as follows: // # Have: // # foo is a dual Tensor that is not a view -// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view +// # bar is a dual Tensor of appropriate size (depending on cases) that is +// not a view // // # Case 1: no view // foo.copy_(bar) @@ -39,22 +41,26 @@ using at::Tensor; // # In the second and third cases, the forward grad of view must match // # the one of foo for the subset they have in common. // -// All these cases can be handled by the following layout constraint on the forward grad: -// - A Tensor and its forward grad (for all levels) must have the same metadata (size, stride -// conj/neg bit and storage offset). Storage offset must be in this metadata because of -// as_strided. conj/neg bit must be part of this metadata because of ops like `real`. -// - View operations must create a forward grad that is a view of the base's forward grad. +// All these cases can be handled by the following layout constraint on the +// forward grad: +// - A Tensor and its forward grad (for all levels) must have the same +// metadata (size, stride +// conj/neg bit and storage offset). Storage offset must be in this metadata +// because of as_strided. conj/neg bit must be part of this metadata because +// of ops like `real`. +// - View operations must create a forward grad that is a view of the base's +// forward grad. // - Inplace operations must modify the input's forward grad inplace. // // This layout constraint is ensured in the `set_fw_grad` function below - // More complex cases arrise when non-dual Tensor interact with dual Tensors. // The two most important cases are: // // # Have: // # foo is a regular Tensor that is not a view -// # bar is a dual Tensor of appropriate size (depending on cases) that is not a view +// # bar is a dual Tensor of appropriate size (depending on cases) that is +// not a view // // # Case 4: Changes on the view must propagate to its base // view = foo[0] @@ -68,51 +74,62 @@ using at::Tensor; // base.copy_(bar) // # Now both view and foo are dual Tensor with appropriate forward grad // -// # NB there is a case 6 involving changes on a view propagating to other views -// # but it is fully described by the two others and is skipped in this discussion. +// # NB there is a case 6 involving changes on a view propagating to other +// views # but it is fully described by the two others and is skipped in +// this discussion. // -// Case 4 is handled by set_fw_grad by properly setting the forward grad of the base if needed. -// Case 5 is handled in fw_grad by reading the forward grad from the base if needed. - +// Case 4 is handled by set_fw_grad by properly setting the forward grad of the +// base if needed. Case 5 is handled in fw_grad by reading the forward grad from +// the base if needed. namespace { - // Check if two Tensor have the same storage offset, sizes and strides - bool has_same_meta(const Variable& base, const Variable& other) { - if (!base.defined() || !other.defined()) { - return false; - } - if (base.storage_offset() != other.storage_offset()) { - return false; - } - if (base.dim() != other.dim()) { - return false; - } - for (const auto i : c10::irange(base.dim())) { - if (base.sizes()[i] != other.sizes()[i]) { - return false; - } - if (base.strides()[i] != other.strides()[i] && base.sizes()[i] != 1 && base.sizes()[i] != 0) { - return false; - } - } - if (!at::_has_same_storage_numel(base, other)) { +// Check if two Tensor have the same storage offset, sizes and strides +bool has_same_meta(const Variable& base, const Variable& other) { + if (!base.defined() || !other.defined()) { + return false; + } + if (base.storage_offset() != other.storage_offset()) { + return false; + } + if (base.dim() != other.dim()) { + return false; + } + for (const auto i : c10::irange(base.dim())) { + if (base.sizes()[i] != other.sizes()[i]) { return false; } - if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) { + if (base.strides()[i] != other.strides()[i] && base.sizes()[i] != 1 && + base.sizes()[i] != 0) { return false; } - return true; } + if (!at::_has_same_storage_numel(base, other)) { + return false; + } + if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) { + return false; + } + return true; +} } // anonymous namespace -// This function is will ensure that the fw_grad_ is properly a view of the base for inplace ops on -// Tensors that do not have forward grad originally. -void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::TensorBase& self_base, uint64_t level, bool is_inplace_op) { - TORCH_CHECK(!new_grad_base._fw_grad(level).defined(), "Setting a forward grad that " - "itself has a forward gradient at the same level", level, " is not supported."); - TORCH_INTERNAL_ASSERT((new_grad_base.is_floating_point() || new_grad_base.is_complex()) && - (self_base.is_floating_point() || self_base.is_complex()), - "Expected both tensor and its forward grad to be floating point or complex"); +// This function is will ensure that the fw_grad_ is properly a view of the base +// for inplace ops on Tensors that do not have forward grad originally. +void AutogradMeta::set_fw_grad( + const at::TensorBase& new_grad_base, + const at::TensorBase& self_base, + uint64_t level, + bool is_inplace_op) { + TORCH_CHECK( + !new_grad_base._fw_grad(level).defined(), + "Setting a forward grad that " + "itself has a forward gradient at the same level", + level, + " is not supported."); + TORCH_INTERNAL_ASSERT( + (new_grad_base.is_floating_point() || new_grad_base.is_complex()) && + (self_base.is_floating_point() || self_base.is_complex()), + "Expected both tensor and its forward grad to be floating point or complex"); // Lazy initialization { std::lock_guard lock(mutex_); @@ -123,31 +140,44 @@ void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::Te if (fw_grad_->contains(level)) { // Setting the forward grad again is only allowed if it is a no-op. // We do allow this case to simplify writing codegen for inplace ops. - TORCH_INTERNAL_ASSERT(new_grad_base.defined(), "Cannot set a forward grad that is an undefined Tensor. Use " - "_fw_primal(level) to get a new Tensor with this forward grad unset."); + TORCH_INTERNAL_ASSERT( + new_grad_base.defined(), + "Cannot set a forward grad that is an undefined Tensor. Use " + "_fw_primal(level) to get a new Tensor with this forward grad unset."); - TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that " - "already has one."); + TORCH_INTERNAL_ASSERT( + is_inplace_op, + "Only inplace operations can re-set the forward grad of a Tensor that " + "already has one."); - TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_base), "Cannot set a value of a forward grad if it " - "already exists. Inplace operations should modify it inplace."); + TORCH_INTERNAL_ASSERT( + fw_grad_->value(level).is_same(new_grad_base), + "Cannot set a value of a forward grad if it " + "already exists. Inplace operations should modify it inplace."); } else { // TODO(alband) remove this spurious version counter bump Tensor new_grad(new_grad_base); at::OptionalTensorRef self_ref(self_base); - const Tensor &self = *self_ref; + const Tensor& self = *self_ref; - TORCH_CHECK(self.is_same_size(new_grad), "Trying to set a forward gradient that has a different size than that " - "of the original Tensor, this is not supported. Tensor is of size ", self.sizes(), " while the given " - "forward gradient is of size ", new_grad.sizes(), "."); + TORCH_CHECK( + self.is_same_size(new_grad), + "Trying to set a forward gradient that has a different size than that " + "of the original Tensor, this is not supported. Tensor is of size ", + self.sizes(), + " while the given " + "forward gradient is of size ", + new_grad.sizes(), + "."); if (is_inplace_op && is_view_) { auto this_view_meta = static_cast(this); - // For inplace ops on a Tensor that does not already have a forward grad and is a view, we propagate - // the tangent to the base and ensure that the new_grad is a view of that base's tangent. - // This ensure that case 4 from [Forward Grad View/inplace] above works fine - // What happens in this long if statement is: + // For inplace ops on a Tensor that does not already have a forward grad + // and is a view, we propagate the tangent to the base and ensure that the + // new_grad is a view of that base's tangent. This ensure that case 4 from + // [Forward Grad View/inplace] above works fine What happens in this long + // if statement is: // - Check if the base already has a grad // - If not, set a new fw_grad for it full of zeros // - Take a view of the base's forward grad @@ -158,14 +188,16 @@ void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::Te auto& base = view_info.base_; if (!base._fw_grad(level).defined()) { - // Enforce same meta here to make sure that the view op below is always valid + // Enforce same meta here to make sure that the view op below is + // always valid Tensor new_base_fw_grad; if (has_same_meta(new_grad, base) && has_same_meta(new_grad, self)) { - // TODO extend this special case to when the underlying storage of new_grad - // can be re-used. + // TODO extend this special case to when the underlying storage of + // new_grad can be re-used. new_base_fw_grad = new_grad; } else { - new_base_fw_grad = at::_new_zeros_with_same_feature_meta(new_grad, base); + new_base_fw_grad = + at::_new_zeros_with_same_feature_meta(new_grad, base); new_base_fw_grad._set_conj(base.is_conj()); new_base_fw_grad._set_neg(base.is_neg()); @@ -174,7 +206,8 @@ void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::Te if (view_info.has_view_fn()) { new_fw_grad_value = view_info.view_fn()(new_base_fw_grad); } else { - new_fw_grad_value = new_base_fw_grad.as_strided(self.sizes(), self.strides(), self.storage_offset()); + new_fw_grad_value = new_base_fw_grad.as_strided( + self.sizes(), self.strides(), self.storage_offset()); } new_fw_grad_value.copy_(new_grad); @@ -190,7 +223,8 @@ void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::Te if (!has_same_meta(new_grad, self)) { if (is_view_) { auto this_view_meta = static_cast(this); - TORCH_INTERNAL_ASSERT(!this_view_meta->has_fw_view(), + TORCH_INTERNAL_ASSERT( + !this_view_meta->has_fw_view(), "Expected the output of forward differentiable view operations to have the tangent have the same layout as primal") } auto res = at::_new_zeros_with_same_feature_meta(new_grad, self); @@ -204,7 +238,9 @@ void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::Te } } -const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self) const { +const Variable& AutogradMeta::fw_grad( + uint64_t level, + const at::TensorBase& self) const { // TLS that disables forward AD // This is only used for custom Function implementation if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) { @@ -214,16 +250,20 @@ const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self // Ensure that concurent fw_grad() "reads" are thread safe std::lock_guard lock(mutex_); - const auto& direct_fw_grad = fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad(); + const auto& direct_fw_grad = + fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad(); if (!direct_fw_grad.defined() && is_view_) { // For view that don't have a forward grad, check if their base has one that // has been defined by an inplace operation. // This ensure that case 5 from [Forward Grad View/inplace] above works fine - auto const_view_meta = static_cast(this); - // This is ok to do as we ONLY modify fw_grad_ and this field is properly locked in all methods + auto const_view_meta = + static_cast(this); + // This is ok to do as we ONLY modify fw_grad_ and this field is properly + // locked in all methods // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto this_view_meta = const_cast(const_view_meta); + auto this_view_meta = + const_cast(const_view_meta); if (this_view_meta->has_fw_view()) { const auto& view_info = this_view_meta->get_forward_view(); const auto& base = view_info.base_; @@ -237,7 +277,8 @@ const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self if (view_info.has_view_fn()) { new_val = view_info.view_fn()(base_val); } else { - new_val = base_val.as_strided(self.sizes(), self.strides(), self.storage_offset()); + new_val = base_val.as_strided( + self.sizes(), self.strides(), self.storage_offset()); } this_view_meta->fw_grad_->set_value(new_val, level); @@ -248,4 +289,5 @@ const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self return direct_fw_grad; } -}} // torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index db5aceded67b9b..65a47766acb784 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -5,16 +5,17 @@ #include #include +#include +#include #include #include #include #include -#include -#include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { namespace { @@ -28,7 +29,7 @@ void _foreach_tensor( int idx_tensor = 0; for (const auto idx_arg : c10::irange(size)) { auto& ivalue = (*stack)[stack_start + idx_arg]; - if (ivalue.isTensor()) { // true for optional tensor that has value + if (ivalue.isTensor()) { // true for optional tensor that has value const auto& tensor = ivalue.toTensor(); fn(idx_tensor, idx_arg, tensor); idx_tensor++; @@ -42,9 +43,12 @@ void _foreach_tensor( } } -} +} // namespace -void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { +void autogradNotImplementedFallbackImpl( + const c10::OperatorHandle& op, + c10::DispatchKeySet dispatch_keys, + torch::jit::Stack* stack) { // Mimics a subset of the logic of a VariableType NotImplemented kernel // See gen_variable_type.py const auto& schema = op.schema(); @@ -70,7 +74,6 @@ void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::Disp is_inplace_output.push_back(alias_info != nullptr && alias_info->isWrite()); any_is_inplace_output |= alias_info != nullptr && alias_info->isWrite(); is_aliased_output.push_back(alias_info != nullptr); - } int aliased_input_idx = -1; int aliased_output_idx = -1; @@ -78,10 +81,10 @@ void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::Disp const at::AliasInfo* alias_info = returns[i].alias_info(); if (alias_info != nullptr && !alias_info->isWrite()) { TORCH_CHECK( - aliased_output_idx == -1, - "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple outputs are aliased with inputs aren't supported." - "Please rewrite your function as a composite function."); + aliased_output_idx == -1, + "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple outputs are aliased with inputs aren't supported." + "Please rewrite your function as a composite function."); aliased_output_idx = i; } } @@ -89,118 +92,170 @@ void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::Disp const at::AliasInfo* alias_info = arguments[i].alias_info(); if (alias_info != nullptr && !alias_info->isWrite()) { TORCH_CHECK( - aliased_input_idx == -1, - "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " - "Please rewrite your function as a composite function."); + aliased_input_idx == -1, + "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " + "Please rewrite your function as a composite function."); aliased_input_idx = i; } } - size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks - _foreach_tensor([&](size_t _, size_t idx_arg, const at::Tensor& t) { - if (grad_mode && t.requires_grad()) { - tensors_requiring_grad_on_stack.push_back(&t); - } - num_tensor_inputs++; - TORCH_CHECK_NOT_IMPLEMENTED(!isFwGradDefined(t), "Trying to use forward AD with ", op_name, " that does not support it."); - }, stack, stack_start, num_arguments); + size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks + _foreach_tensor( + [&](size_t _, size_t idx_arg, const at::Tensor& t) { + if (grad_mode && t.requires_grad()) { + tensors_requiring_grad_on_stack.push_back(&t); + } + num_tensor_inputs++; + TORCH_CHECK_NOT_IMPLEMENTED( + !isFwGradDefined(t), + "Trying to use forward AD with ", + op_name, + " that does not support it."); + }, + stack, + stack_start, + num_arguments); const bool any_requires_grad = tensors_requiring_grad_on_stack.size() > 0; - _foreach_tensor([&](size_t _, size_t i, const at::Tensor& t) { - const at::AliasInfo* alias_info = arguments[i].alias_info(); - if (alias_info != nullptr && alias_info->isWrite()) { - check_inplace(t, any_requires_grad); - } - }, stack, stack_start, num_arguments); + _foreach_tensor( + [&](size_t _, size_t i, const at::Tensor& t) { + const at::AliasInfo* alias_info = arguments[i].alias_info(); + if (alias_info != nullptr && alias_info->isWrite()) { + check_inplace(t, any_requires_grad); + } + }, + stack, + stack_start, + num_arguments); std::shared_ptr grad_fn; if (any_requires_grad) { - grad_fn = std::shared_ptr(new NotImplemented(op_name), deleteNode); - grad_fn->set_next_edges(collect_next_edges(tensors_requiring_grad_on_stack)); + grad_fn = std::shared_ptr( + new NotImplemented(op_name), deleteNode); + grad_fn->set_next_edges( + collect_next_edges(tensors_requiring_grad_on_stack)); } - #ifndef NDEBUG +#ifndef NDEBUG // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] - auto stack_args_copy = std::vector(stack->begin() + stack_start, stack->end()); + auto stack_args_copy = + std::vector(stack->begin() + stack_start, stack->end()); std::vector> impl_saved; impl_saved.reserve(num_tensor_inputs); std::vector> storage_saved; storage_saved.reserve(num_tensor_inputs); - _foreach_tensor([&](size_t idx, size_t _, const at::Tensor& t) { - storage_saved.push_back(t.has_storage() ? c10::optional(t.storage()) : c10::nullopt); - impl_saved.push_back(t.getIntrusivePtr()); - }, &stack_args_copy, 0, num_arguments); - #endif + _foreach_tensor( + [&](size_t idx, size_t _, const at::Tensor& t) { + storage_saved.push_back( + t.has_storage() ? c10::optional(t.storage()) + : c10::nullopt); + impl_saved.push_back(t.getIntrusivePtr()); + }, + &stack_args_copy, + 0, + num_arguments); +#endif if (aliased_input_idx != -1 || any_is_inplace_output) { at::AutoDispatchBelowAutograd guard; op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack); } else { // If neither in-place nor view at::AutoDispatchBelowADInplaceOrView guard; - op.redispatchBoxed(dispatch_keys & c10::after_ADInplaceOrView_keyset, stack); + op.redispatchBoxed( + dispatch_keys & c10::after_ADInplaceOrView_keyset, stack); } - #ifndef NDEBUG - _foreach_tensor([&](size_t idx_tensor, size_t _, const at::Tensor& t) { - if (storage_saved.at(idx_tensor).has_value()) - TORCH_INTERNAL_ASSERT(storage_saved.at(idx_tensor).value().is_alias_of(t.storage()), op_name); - if (impl_saved.at(idx_tensor)) - TORCH_INTERNAL_ASSERT(impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name); - }, &stack_args_copy, 0, num_arguments); - _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { - if (at::impl::tensor_has_dispatch(t) || at::impl::dispatch_mode_enabled()) - return; - if (!is_inplace_output[idx_ret]) - TORCH_INTERNAL_ASSERT(t.use_count() <= 1, op_name); // Okay to return undefined tensor - if (!is_aliased_output[idx_ret] && t.has_storage()) - TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1); - }, stack, stack->size() - num_returns, num_returns); - // There should be only a single base-view pair, make sure their storage is aliased. +#ifndef NDEBUG + _foreach_tensor( + [&](size_t idx_tensor, size_t _, const at::Tensor& t) { + if (storage_saved.at(idx_tensor).has_value()) + TORCH_INTERNAL_ASSERT( + storage_saved.at(idx_tensor).value().is_alias_of(t.storage()), + op_name); + if (impl_saved.at(idx_tensor)) + TORCH_INTERNAL_ASSERT( + impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name); + }, + &stack_args_copy, + 0, + num_arguments); + _foreach_tensor( + [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { + if (at::impl::tensor_has_dispatch(t) || + at::impl::dispatch_mode_enabled()) + return; + if (!is_inplace_output[idx_ret]) + TORCH_INTERNAL_ASSERT( + t.use_count() <= 1, op_name); // Okay to return undefined tensor + if (!is_aliased_output[idx_ret] && t.has_storage()) + TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1); + }, + stack, + stack->size() - num_returns, + num_returns); + // There should be only a single base-view pair, make sure their storage is + // aliased. if (aliased_input_idx != -1 && aliased_output_idx != -1) { const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx]; - const c10::IValue& aliased_output_iv = (*stack)[stack->size() - num_returns + aliased_output_idx]; + const c10::IValue& aliased_output_iv = + (*stack)[stack->size() - num_returns + aliased_output_idx]; TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name); - TORCH_INTERNAL_ASSERT(aliased_output_iv.isTensor() || aliased_output_iv.isTensorList() , op_name); + TORCH_INTERNAL_ASSERT( + aliased_output_iv.isTensor() || aliased_output_iv.isTensorList(), + op_name); const at::Tensor& aliased_input = aliased_input_iv.toTensor(); if (aliased_input.has_storage()) { if (aliased_output_iv.isTensor()) { const at::Tensor& aliased_output = aliased_input_iv.toTensor(); - TORCH_INTERNAL_ASSERT(aliased_input.storage().is_alias_of(aliased_output.storage()), op_name); + TORCH_INTERNAL_ASSERT( + aliased_input.storage().is_alias_of(aliased_output.storage()), + op_name); } else { const auto aliased_output_vec = aliased_output_iv.toTensorVector(); for (const auto& aliased_output : aliased_output_vec) { - TORCH_INTERNAL_ASSERT(aliased_input.storage().is_alias_of(aliased_output.storage()), op_name); + TORCH_INTERNAL_ASSERT( + aliased_input.storage().is_alias_of(aliased_output.storage()), + op_name); } } } } - #endif +#endif if (any_requires_grad) { - _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { - if (isDifferentiableType(t.scalar_type())) { - if (is_inplace_output[idx_ret]) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - rebase_history(const_cast(t), grad_fn); - } else { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - set_history(const_cast(t), grad_fn); - } - } - }, stack, stack->size() - num_returns, num_returns); + _foreach_tensor( + [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { + if (isDifferentiableType(t.scalar_type())) { + if (is_inplace_output[idx_ret]) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + rebase_history(const_cast(t), grad_fn); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + set_history(const_cast(t), grad_fn); + } + } + }, + stack, + stack->size() - num_returns, + num_returns); } } torch::CppFunction autogradNotImplementedFallback() { - return torch::CppFunction::makeFromBoxedFunction<&autogradNotImplementedFallbackImpl>(); + return torch::CppFunction::makeFromBoxedFunction< + &autogradNotImplementedFallbackImpl>(); } -void autogradNotImplementedInplaceOrViewFallbackImpl(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { +void autogradNotImplementedInplaceOrViewFallbackImpl( + const c10::OperatorHandle& op, + c10::DispatchKeySet dispatch_keys, + torch::jit::Stack* stack) { // Mimics a subset of the logic from ADInplaceOrViewType kernel: // - see gen_inplace_or_view_type.py // - this should only be used with autogradNotImplementedFallback above - // - For more information see https://pytorch.org/tutorials/advanced/dispatcher + // - For more information see + // https://pytorch.org/tutorials/advanced/dispatcher // // NOTE [ Limitations of ADInplaceOrView boxed kernel ] // @@ -210,9 +265,11 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(const c10::OperatorHandle& // is not implemented" error is still raised. // // Just like the codegened kernel, we try to enforce some things: - // - For views: we enforce that the view relationship is between the first input + // - For views: we enforce that the view relationship is between the first + // input // and the first output (which may be either Tensor or vec of Tensors - // - For inplace (TODO?): enforce that the same op cannot be both a view and inplace + // - For inplace (TODO?): enforce that the same op cannot be both a view and + // inplace // that is not allowed in the gen_inplace_or_view logic const auto& schema = op.schema(); const auto& op_name = schema.operator_name().name; @@ -229,11 +286,11 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(const c10::OperatorHandle& const at::AliasInfo* alias_info = returns[i].alias_info(); if (alias_info != nullptr && !alias_info->isWrite()) { TORCH_CHECK( - aliased_output_idx == -1, - "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a " - "non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple outputs are aliased with inputs aren't supported." - "Please rewrite your function as a composite function."); + aliased_output_idx == -1, + "Fallback ADInplaceOrView kernel expects only a single output in the operator schema to have a " + "non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple outputs are aliased with inputs aren't supported." + "Please rewrite your function as a composite function."); aliased_output_idx = i; } } @@ -244,30 +301,36 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(const c10::OperatorHandle& if (alias_info != nullptr) { if (!alias_info->isWrite()) { TORCH_CHECK( - aliased_input_idx == -1, - "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a " - "non-write alias annotation (i.e., 'Tensor(a)'). " - "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " - "Please rewrite your function as a composite function."); + aliased_input_idx == -1, + "Fallback ADInplaceOrView kernel expects only a single input in the operator schema to have a " + "non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " + "Please rewrite your function as a composite function."); aliased_input_idx = i; - const c10::IValue& aliased_input_iv = (*stack)[stack_start + i]; // get a reference to an ivalue on the stack + const c10::IValue& aliased_input_iv = + (*stack)[stack_start + i]; // get a reference to an ivalue on the + // stack TORCH_CHECK(aliased_input_iv.isTensor()); - aliased_input = aliased_input_iv.toTensor(); // TODO: Can we avoid saving this tensor and incurring the refcount bump? + aliased_input = + aliased_input_iv + .toTensor(); // TODO: Can we avoid saving this tensor and + // incurring the refcount bump? } } } // See NOTE [ Limitations of ADInplaceOrView boxed kernel ] above TORCH_CHECK( - (aliased_input_idx == -1 && aliased_output_idx == -1) || - (aliased_input_idx == 0 && aliased_output_idx == 0), - "Fallback ADInplaceOrView kernel can only create view relationships between the first " - "input and the first output (the output can be a vector of tensors). Please change the " - "order of your operator's parameters so that this is the case."); + (aliased_input_idx == -1 && aliased_output_idx == -1) || + (aliased_input_idx == 0 && aliased_output_idx == 0), + "Fallback ADInplaceOrView kernel can only create view relationships between the first " + "input and the first output (the output can be a vector of tensors). Please change the " + "order of your operator's parameters so that this is the case."); const bool is_view = aliased_input_idx != -1; { at::AutoDispatchBelowADInplaceOrView guard; - op.redispatchBoxed(dispatch_keys & c10::after_ADInplaceOrView_keyset, stack); + op.redispatchBoxed( + dispatch_keys & c10::after_ADInplaceOrView_keyset, stack); } for (const auto i : c10::irange(num_returns)) { @@ -278,41 +341,60 @@ void autogradNotImplementedInplaceOrViewFallbackImpl(const c10::OperatorHandle& } if (is_view) { - c10::IValue& aliased_output_iv = (*stack)[stack->size() - num_returns + aliased_output_idx]; + c10::IValue& aliased_output_iv = + (*stack)[stack->size() - num_returns + aliased_output_idx]; if (aliased_output_iv.isTensorList()) { auto aliased_output = aliased_output_iv.toTensorVector(); // Only allow rebasing of the history if we return a single Tensor that is // why we don't have to care about the view_func logic below. // See NOTE [ View + Inplace detection ] for more details about this logic auto result = as_view( - /* base=*/aliased_input, - /* tensors=*/aliased_output, - /* is_bw_differentiable=*/true, - /* is_fw_differentiable=*/true, - /* creation_meta=*/InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE : CreationMeta::NO_GRAD_MODE)); - // ^ pass in creation meta unecessarily even if not isDifferentiableType, but we don't have that - // information here anyway. + /* base=*/aliased_input, + /* tensors=*/aliased_output, + /* is_bw_differentiable=*/true, + /* is_fw_differentiable=*/true, + /* creation_meta=*/ + InferenceMode::is_enabled() + ? CreationMeta::INFERENCE_MODE + : (at::GradMode::is_enabled() ? CreationMeta::MULTI_OUTPUT_NODE + : CreationMeta::NO_GRAD_MODE)); + // ^ pass in creation meta unecessarily even if not isDifferentiableType, + // but we don't have that + // information here anyway. stack->at(stack->size() - num_returns + aliased_output_idx) = result; } else { TORCH_CHECK(aliased_output_iv.isTensor()); auto result = as_view( - /* base=*/aliased_input, - /* tensor=*/std::move(aliased_output_iv).toTensor(), - /* is_bw_differentiable=*/true, - /* is_fw_differentiable=*/true, - /* view_func=*/[op_name=op_name](const at::Tensor&) { - // We always need this view_func because otherwise if we do in-place on this view, - // we would implicitly use AsStridedBackward instead of the NotImplemented node. - // For the cross-dtype/non-strided cases, we would create something like this anyway - TORCH_CHECK(false, "Mutating the view ", op_name, " which does not have a derivative implemented is forbidden."); return at::Tensor();}, - /* creation_meta=*/InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE)); + /* base=*/aliased_input, + /* tensor=*/std::move(aliased_output_iv).toTensor(), + /* is_bw_differentiable=*/true, + /* is_fw_differentiable=*/true, + /* view_func=*/ + [op_name = op_name](const at::Tensor&) { + // We always need this view_func because otherwise if we do in-place + // on this view, we would implicitly use AsStridedBackward instead + // of the NotImplemented node. For the cross-dtype/non-strided + // cases, we would create something like this anyway + TORCH_CHECK( + false, + "Mutating the view ", + op_name, + " which does not have a derivative implemented is forbidden."); + return at::Tensor(); + }, + /* creation_meta=*/ + InferenceMode::is_enabled() + ? CreationMeta::INFERENCE_MODE + : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT + : CreationMeta::NO_GRAD_MODE)); stack->at(stack->size() - num_returns + aliased_output_idx) = result; } } } torch::CppFunction autogradNotImplementedInplaceOrViewFallback() { - return torch::CppFunction::makeFromBoxedFunction<&autogradNotImplementedInplaceOrViewFallbackImpl>(); + return torch::CppFunction::makeFromBoxedFunction< + &autogradNotImplementedInplaceOrViewFallbackImpl>(); } } // namespace autograd diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.h b/torch/csrc/autograd/autograd_not_implemented_fallback.h index 315f3cdbd9bd5b..db9cfb6c007b39 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.h +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.h @@ -9,4 +9,5 @@ TORCH_API torch::CppFunction autogradNotImplementedFallback(); TORCH_API torch::CppFunction autogradNotImplementedInplaceOrViewFallback(); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/cpp_hook.cpp b/torch/csrc/autograd/cpp_hook.cpp index 981a774e32e542..e006ceb2579470 100644 --- a/torch/csrc/autograd/cpp_hook.cpp +++ b/torch/csrc/autograd/cpp_hook.cpp @@ -1,30 +1,35 @@ #include #include -#include #include +#include namespace { using torch::autograd::Variable; -void check_single_result (const at::TensorBase &value, const at::TensorBase &result, std::string hook_name) { +void check_single_result( + const at::TensorBase& value, + const at::TensorBase& result, + std::string hook_name) { if (!value.defined()) { - throw std::runtime_error("can't replace a empty gradient with a non-empty value"); + throw std::runtime_error( + "can't replace a empty gradient with a non-empty value"); } torch::autograd::check_variable_result(value, result, hook_name); } -} +} // namespace -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // NOLINTNEXTLINE(modernize-pass-by-value) -CppFunctionPreHook::CppFunctionPreHook(const std::shared_ptr &hooks, int value_idx) -: hooks_(hooks) -, value_idx_(value_idx) -{} +CppFunctionPreHook::CppFunctionPreHook( + const std::shared_ptr& hooks, + int value_idx) + : hooks_(hooks), value_idx_(value_idx) {} variable_list CppFunctionPreHook::operator()(const variable_list& values) { auto value = values[value_idx_]; for (const auto i : c10::irange(hooks_->size())) { - auto &hook = (*hooks_)[i]; + auto& hook = (*hooks_)[i]; if (!hook) { // hook was removed continue; @@ -42,4 +47,5 @@ variable_list CppFunctionPreHook::operator()(const variable_list& values) { return results; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/cpp_hook.h b/torch/csrc/autograd/cpp_hook.h index c3d9135f437bb7..7dcdf6884c36e7 100644 --- a/torch/csrc/autograd/cpp_hook.h +++ b/torch/csrc/autograd/cpp_hook.h @@ -3,15 +3,18 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -using hooks_list = std::vector>; +using hooks_list = + std::vector>; struct CppFunctionPreHook : public FunctionPreHook { - CppFunctionPreHook(const std::shared_ptr &hooks, int value_idx); + CppFunctionPreHook(const std::shared_ptr& hooks, int value_idx); variable_list operator()(const variable_list& values) override; std::shared_ptr hooks_; int value_idx_; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 553e8aa674702a..36df9bfac273e9 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -1,18 +1,18 @@ #include +#include #include #include -#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { VariableInfo::VariableInfo(const Variable& var) - : layout(var.layout()) - , device(var.device()) - , scalar_type(var.scalar_type()) - , size(var.sizes().vec()) - , requires_grad(var.requires_grad()) - , is_empty(false) { -} + : layout(var.layout()), + device(var.device()), + scalar_type(var.scalar_type()), + size(var.sizes().vec()), + requires_grad(var.requires_grad()), + is_empty(false) {} VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {} @@ -27,8 +27,9 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { } // This function has two main goals: -// 1) Use the user-provided jvp function to populate the the outputs' forward gradient -// 2) Perform error checking to ensure that view and inplace ops are properly handled +// 1) Use the user-provided jvp function to populate the the outputs' forward +// gradient 2) Perform error checking to ensure that view and inplace ops are +// properly handled // // For 1) we have to: // - Create a variable_list of grad_inputs based on the function inputs @@ -36,18 +37,21 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { // - Set the forward grad field on each output based on these grad_outputs // // For 2) we want to check the following: -// - If an output is a view, then the generated forward grad must be a view as well and +// - If an output is a view, then the generated forward grad must be a view as +// well and // the output's base's forward grad must be the output's forward grad's base. -// - If an input was modified inplace (it must be an output as well) we make sure that its -// forward grad was also modified inplace and already present on the corresponding output. -void _process_forward_mode_AD(const variable_list &inputs, - std::unordered_map inputs_mapping, - const at::ArrayRef> raw_outputs, - const optional_variable_list &outputs, - const std::unordered_set &non_differentiable, - const std::unordered_set &dirty_inputs, - _jvp_fn_t jvp_user_function) { - +// - If an input was modified inplace (it must be an output as well) we make +// sure that its +// forward grad was also modified inplace and already present on the +// corresponding output. +void _process_forward_mode_AD( + const variable_list& inputs, + std::unordered_map inputs_mapping, + const at::ArrayRef> raw_outputs, + const optional_variable_list& outputs, + const std::unordered_set& non_differentiable, + const std::unordered_set& dirty_inputs, + _jvp_fn_t jvp_user_function) { // TODO handle multiple levels here uint64_t level = 0; @@ -55,31 +59,35 @@ void _process_forward_mode_AD(const variable_list &inputs, const auto num_outputs = outputs.size(); // The tracking info below are used to perform the view and inplace checks. - // They are lazily initialized to reduce the cost of this function in the common - // case where the user is not using forward mode AD. + // They are lazily initialized to reduce the cost of this function in the + // common case where the user is not using forward mode AD. variable_list input_grads; std::vector grad_versions; std::vector grad_impls; std::unordered_map inputs_bases; - auto init_tracked_info = [&] () { + auto init_tracked_info = [&]() { input_grads.resize(num_inputs); grad_versions.resize(num_inputs); grad_impls.resize(num_inputs); - for (const auto i: c10::irange(num_inputs)) { + for (const auto i : c10::irange(num_inputs)) { const auto& inp = inputs[i]; if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) { - inputs_bases.emplace(impl::get_view_autograd_meta(inp)->get_forward_view().base_.unsafeGetTensorImpl(), i); + inputs_bases.emplace( + impl::get_view_autograd_meta(inp) + ->get_forward_view() + .base_.unsafeGetTensorImpl(), + i); } else { inputs_bases.emplace(inp.unsafeGetTensorImpl(), i); } - } }; bool any_input_has_grad = false; - // Extract the input's forward gradients and record any info we will need later + // Extract the input's forward gradients and record any info we will need + // later for (const auto i : c10::irange(num_inputs)) { const auto& inp = inputs[i]; if (!inp.defined()) { @@ -111,16 +119,26 @@ void _process_forward_mode_AD(const variable_list &inputs, // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const auto num_forward_grads = forward_grads.size(); // contrary to backward mode, we don't allow returning too many gradients - TORCH_CHECK(num_forward_grads == num_outputs, "Function's jvp returned " - "an invalid number of forward gradients (expected ", num_outputs, - " but got ", num_forward_grads, ")"); + TORCH_CHECK( + num_forward_grads == num_outputs, + "Function's jvp returned " + "an invalid number of forward gradients (expected ", + num_outputs, + " but got ", + num_forward_grads, + ")"); for (const auto i : c10::irange(num_outputs)) { - const auto& out = outputs[i].has_value()? outputs[i].value() : at::Tensor(); + const auto& out = + outputs[i].has_value() ? outputs[i].value() : at::Tensor(); const auto& out_grad = forward_grads[i]; if (!out.defined()) { - TORCH_CHECK(!out_grad.defined(), "Function's jvp returned a gradient at position ", i, ", but " - " the corresponding forward output is not a differentiable Tensor"); + TORCH_CHECK( + !out_grad.defined(), + "Function's jvp returned a gradient at position ", + i, + ", but " + " the corresponding forward output is not a differentiable Tensor"); continue; } @@ -130,63 +148,91 @@ void _process_forward_mode_AD(const variable_list &inputs, bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; if (is_modified) { - TORCH_CHECK(is_input, "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" - " is no need to pass it to mark_dirty()."); + TORCH_CHECK( + is_input, + "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" + " is no need to pass it to mark_dirty()."); auto inp_idx = inputs_mapping[out_tensor_impl]; if (grad_impls[inp_idx]) { // If there was already a forward grad for that input // Just make sure that it is modified inplace and returned as-is - TORCH_CHECK(out_grad._version() != grad_versions[inp_idx], "An inplace custom Function is not modifying the " - "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp " - "function must modify the corresponding gradient inplace.") - TORCH_CHECK(out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx], "An inplace custom Function is not returning the " - "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp " - "function must modify the gradient inplace and return it as-is.") + TORCH_CHECK( + out_grad._version() != grad_versions[inp_idx], + "An inplace custom Function is not modifying the " + "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp " + "function must modify the corresponding gradient inplace.") + TORCH_CHECK( + out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx], + "An inplace custom Function is not returning the " + "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp " + "function must modify the gradient inplace and return it as-is.") } else { - // If that Tensor didn't had gradients already, set the newly returned one - // We could also use inputs[inp_idx] here as it is the same as out + // If that Tensor didn't had gradients already, set the newly returned + // one We could also use inputs[inp_idx] here as it is the same as out out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); } } else { - // At this point, outputs[i] cannot be one of the input (raw_outputs[i] might be but was changed by the backward code) - TORCH_INTERNAL_ASSERT(inputs_mapping.count(out.unsafeGetTensorImpl()) == 0); + // At this point, outputs[i] cannot be one of the input (raw_outputs[i] + // might be but was changed by the backward code) + TORCH_INTERNAL_ASSERT( + inputs_mapping.count(out.unsafeGetTensorImpl()) == 0); if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) { // If the output is a view - const auto& out_view_info = impl::get_view_autograd_meta(out)->get_forward_view(); + const auto& out_view_info = + impl::get_view_autograd_meta(out)->get_forward_view(); if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) { - // And it is a view of an input (either that input is its base or they have a common base) - const auto matching_input_idx = inputs_bases[out_view_info.base_.unsafeGetTensorImpl()]; + // And it is a view of an input (either that input is its base or they + // have a common base) + const auto matching_input_idx = + inputs_bases[out_view_info.base_.unsafeGetTensorImpl()]; const auto& matching_input = inputs[matching_input_idx]; const auto& matching_input_grad = matching_input._fw_grad(level); - // If the matching input has a forward grad, the user should have returned a view of that Tensor + // If the matching input has a forward grad, the user should have + // returned a view of that Tensor if (matching_input_grad.defined()) { - TORCH_CHECK(out_grad.is_view() && impl::get_view_autograd_meta(out_grad)->has_fw_view(), - "A custom Function's forward is returning a view (or an input as-is) but the jvp is not " - "returning a view."); - const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)->get_forward_view().base_; - if (matching_input_grad.is_view() && impl::get_view_autograd_meta(matching_input_grad)->has_fw_view()) { - // If the matching input's grad is a view, ensure that the out_grad is a view of the same base - const auto& matching_input_grad_base = impl::get_view_autograd_meta(matching_input_grad)->get_forward_view().base_; - TORCH_CHECK(matching_input_grad_base.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(), - "A custom Function is returning a view but the jvp is not returning a view of the same base as " - "the given grad input."); + TORCH_CHECK( + out_grad.is_view() && + impl::get_view_autograd_meta(out_grad)->has_fw_view(), + "A custom Function's forward is returning a view (or an input as-is) but the jvp is not " + "returning a view."); + const auto& out_grad_base = impl::get_view_autograd_meta(out_grad) + ->get_forward_view() + .base_; + if (matching_input_grad.is_view() && + impl::get_view_autograd_meta(matching_input_grad) + ->has_fw_view()) { + // If the matching input's grad is a view, ensure that the + // out_grad is a view of the same base + const auto& matching_input_grad_base = + impl::get_view_autograd_meta(matching_input_grad) + ->get_forward_view() + .base_; + TORCH_CHECK( + matching_input_grad_base.unsafeGetTensorImpl() == + out_grad_base.unsafeGetTensorImpl(), + "A custom Function is returning a view but the jvp is not returning a view of the same base as " + "the given grad input."); } else { - // If the matching input's grad is not a view, then it must be the output gradient's base - TORCH_CHECK(matching_input_grad.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(), - "A custom Function is returning a view but the jvp is not returning a view of the given grad input."); + // If the matching input's grad is not a view, then it must be the + // output gradient's base + TORCH_CHECK( + matching_input_grad.unsafeGetTensorImpl() == + out_grad_base.unsafeGetTensorImpl(), + "A custom Function is returning a view but the jvp is not returning a view of the given grad input."); } } else { - // We have a view op where the input didn't have a forward grad but the user returned one for the output - // To ensure that we maintain the view/inplace constraints, we consider this as an inplace op - // This case CANNOT happen in codegen as all view ops are mapping from one Tensor to one Tensor and so the output - // of the view cannot have a forward grad if the base does not. + // We have a view op where the input didn't have a forward grad but + // the user returned one for the output To ensure that we maintain + // the view/inplace constraints, we consider this as an inplace op + // This case CANNOT happen in codegen as all view ops are mapping + // from one Tensor to one Tensor and so the output of the view + // cannot have a forward grad if the base does not. out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); return; } - } } @@ -202,31 +248,33 @@ at::Tensor _view_as_self_with_no_grad(at::Tensor self) { // so that we can attach a new grad_fn to the Variable. // Run in no_grad mode to mimic the behavior of the forward. // - // (2) Though it is not necessary for the purposes of attaching grad_fn, we also call - // this function when an output is non-differentiable (and does not require grad). - // to help custom forward AD UX more consistent. We'd like to uniformly say - // that returning an input as-is is treated as if `self.view_as(self)` were - // returned for that output. + // (2) Though it is not necessary for the purposes of attaching grad_fn, we + // also call this function when an output is non-differentiable (and does not + // require grad). to help custom forward AD UX more consistent. We'd like to + // uniformly say that returning an input as-is is treated as if + // `self.view_as(self)` were returned for that output. // - // Alternatively, we could have not disabled forward grad while performing this - // view, but it would mean that the user defined jvp may be silently ignored. + // Alternatively, we could have not disabled forward grad while performing + // this view, but it would mean that the user defined jvp may be silently + // ignored. at::AutoFwGradMode fw_grad_mode(false); AutoGradMode grad_mode(false); return self.view_as(self); } optional_variable_list _process_backward_mode_ad( - const std::unordered_map &inputs_mapping, - const std::unordered_set &non_differentiable, - const std::unordered_set &dirty_inputs, - const at::ArrayRef> raw_outputs, - const std::shared_ptr &cdata) { - - + const std::unordered_map& inputs_mapping, + const std::unordered_set& non_differentiable, + const std::unordered_set& dirty_inputs, + const at::ArrayRef> raw_outputs, + const std::shared_ptr& cdata) { int num_outputs = raw_outputs.size(); // Sets the grad_fn and output_nr of an output Variable. - auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified, + auto set_history = [&](Variable& var, + uint32_t output_nr, + bool is_input, + bool is_modified, bool is_differentiable) { if (!is_differentiable) { if (!var.requires_grad()) { @@ -235,8 +283,8 @@ optional_variable_list _process_backward_mode_ad( } return; } - // Return detached aliases of inputs, instead of changing their requires_grad - // property. + // Return detached aliases of inputs, instead of changing their + // requires_grad property. if (is_input) { var = var.detach(); } else if (!var.is_view()) { @@ -251,22 +299,27 @@ optional_variable_list _process_backward_mode_ad( // Here, `y` requires_grad (!). } else if (is_modified) { if (var.is_leaf() && var.requires_grad()) { - TORCH_CHECK(false, "a leaf Variable that requires grad has been used in an in-place operation."); + TORCH_CHECK( + false, + "a leaf Variable that requires grad has been used in an in-place operation."); } // No need to mark as modified Tensors that are not inputs. if (!is_input) { - TORCH_WARN("Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" - " is no need to pass it to mark_dirty()."); + TORCH_WARN( + "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" + " is no need to pass it to mark_dirty()."); } - // If the input is a view, the rebase will need to rewrite the graph and this only works if we have a single - // output to this Function. - TORCH_CHECK(!(var.is_view() && num_outputs > 1), "If your Function modifies inplace an input that is a view" - " of another Tensor, your Function cannot return more than one Tensor. This is not supported" - " by the current autograd engine. You should either make sure the input is not a view (using" - " .clone() for example) or make your Function only return one Tensor (potentially splitting" - " it into two Functions: one doing the inplace that returns a single Tensor and a second one" - " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if" - " you need help to do this change."); + // If the input is a view, the rebase will need to rewrite the graph and + // this only works if we have a single output to this Function. + TORCH_CHECK( + !(var.is_view() && num_outputs > 1), + "If your Function modifies inplace an input that is a view" + " of another Tensor, your Function cannot return more than one Tensor. This is not supported" + " by the current autograd engine. You should either make sure the input is not a view (using" + " .clone() for example) or make your Function only return one Tensor (potentially splitting" + " it into two Functions: one doing the inplace that returns a single Tensor and a second one" + " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if" + " you need help to do this change."); // If the input was modified, transplant the grad_fn in the graph: // grad_fn <- variable <- self ==> grad_fn <- self <- variable @@ -292,7 +345,6 @@ optional_variable_list _process_backward_mode_ad( outputs.reserve(num_outputs); int num_diff_outputs = 0; - for (const auto i : c10::irange(num_outputs)) { // For outputs that are not tensors, put a placeholder undefined input. if (!raw_outputs[i].has_value()) { @@ -309,8 +361,9 @@ optional_variable_list _process_backward_mode_ad( auto out_tensor_impl = var.unsafeGetTensorImpl(); bool is_input = inputs_mapping.count(out_tensor_impl) > 0; bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; - bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0 - && isDifferentiableType(var.scalar_type()); + bool is_differentiable = cdata && + non_differentiable.count(out_tensor_impl) == 0 && + isDifferentiableType(var.scalar_type()); if (cdata) { auto output_nr = cdata->add_input_metadata(var); @@ -318,10 +371,11 @@ optional_variable_list _process_backward_mode_ad( } set_history(var, i, is_input, is_modified, is_differentiable); - // For deprecation cycle. Can be removed after 1.6. In the case where we detected a view - // in no grad mode during the forward, only warn the user (do not change the flag if we - // return and input that is a view as is). - // See NOTE [ View + Inplace detection ] for why we replace everything by a warning. + // For deprecation cycle. Can be removed after 1.6. In the case where we + // detected a view in no grad mode during the forward, only warn the user + // (do not change the flag if we return and input that is a view as is). See + // NOTE [ View + Inplace detection ] for why we replace everything by a + // warning. if (!(is_input && is_modified) && var.is_view()) { // is_view() => diff_view_meta auto diff_view_meta = impl::get_view_autograd_meta(var); @@ -336,10 +390,10 @@ optional_variable_list _process_backward_mode_ad( outputs.emplace_back(var); } - // If multiple differentiable outputs are returned, we do not allow views to be modified inplace - // See NOTE [ View + Inplace detection ] for more details + // If multiple differentiable outputs are returned, we do not allow views to + // be modified inplace See NOTE [ View + Inplace detection ] for more details if (num_diff_outputs > 1) { - for (auto& var: outputs) { + for (auto& var : outputs) { if (var.has_value()) { auto diff_view_meta = impl::get_view_autograd_meta(var.value()); if (diff_view_meta && diff_view_meta->has_bw_view()) { @@ -349,41 +403,52 @@ optional_variable_list _process_backward_mode_ad( } } - // All the modified Tensors must be returned as is for the rewrite to be valid. + // All the modified Tensors must be returned as is for the rewrite to be + // valid. for (auto& dirty_input : dirty_inputs) { - TORCH_CHECK(outputs_impl.count(dirty_input) > 0, - "Some elements marked as dirty during the forward method were not returned as output. The" - " inputs that are modified inplace must all be outputs of the Function."); + TORCH_CHECK( + outputs_impl.count(dirty_input) > 0, + "Some elements marked as dirty during the forward method were not returned as output. The" + " inputs that are modified inplace must all be outputs of the Function."); } return outputs; } - - -optional_variable_list _wrap_outputs(const variable_list &input_vars, - const std::unordered_set &non_differentiable, - const std::unordered_set &dirty_inputs, - const at::ArrayRef> raw_outputs, - const std::shared_ptr &cdata, - _jvp_fn_t jvp_user_function) { - +optional_variable_list _wrap_outputs( + const variable_list& input_vars, + const std::unordered_set& non_differentiable, + const std::unordered_set& dirty_inputs, + const at::ArrayRef> raw_outputs, + const std::shared_ptr& cdata, + _jvp_fn_t jvp_user_function) { std::unordered_map inputs_mapping; inputs_mapping.reserve(input_vars.size()); - for (const auto i: c10::irange(input_vars.size())) { + for (const auto i : c10::irange(input_vars.size())) { inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i); } - auto outputs = _process_backward_mode_ad(inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata); + auto outputs = _process_backward_mode_ad( + inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata); - // This must happen after the backward processing as we expect the computations happening here to track - // backward mode gradients. - _process_forward_mode_AD(input_vars, inputs_mapping, raw_outputs, outputs, non_differentiable, dirty_inputs, jvp_user_function); + // This must happen after the backward processing as we expect the + // computations happening here to track backward mode gradients. + _process_forward_mode_AD( + input_vars, + inputs_mapping, + raw_outputs, + outputs, + non_differentiable, + dirty_inputs, + jvp_user_function); return outputs; } -void check_variable_result(const at::TensorBase& original, const at::TensorBase& result, std::string hook_name) { +void check_variable_result( + const at::TensorBase& original, + const at::TensorBase& result, + std::string hook_name) { if (!original.options().type_equal(result.options())) { std::stringstream ss; ss << "hook '" << hook_name << "' has changed the type of value ("; @@ -414,8 +479,8 @@ void AutogradContext::save_for_backward(variable_list to_save) { to_save_ = std::move(to_save); } -// The logic for handling saved variables here is the same as python_function.cpp -// See _save_variables() and unpack_saved_variables() +// The logic for handling saved variables here is the same as +// python_function.cpp See _save_variables() and unpack_saved_variables() void AutogradContext::save_variables() { saved_variables_.clear(); auto ptr = grad_fn_.lock(); @@ -444,18 +509,18 @@ variable_list AutogradContext::get_saved_variables() const { return saved; } -void AutogradContext::mark_dirty(const variable_list &inputs) { +void AutogradContext::mark_dirty(const variable_list& inputs) { dirty_inputs_.clear(); dirty_inputs_.reserve(inputs.size()); - for(auto& var : inputs) { + for (auto& var : inputs) { dirty_inputs_.insert(var.unsafeGetTensorImpl()); } } -void AutogradContext::mark_non_differentiable(const variable_list &outputs) { +void AutogradContext::mark_non_differentiable(const variable_list& outputs) { non_differentiable_.clear(); non_differentiable_.reserve(outputs.size()); - for(auto& var : outputs) { + for (auto& var : outputs) { non_differentiable_.insert(var.unsafeGetTensorImpl()); } } @@ -464,14 +529,17 @@ void AutogradContext::set_materialize_grads(bool value) { materialize_grads_ = value; } -const std::unordered_set& AutogradContext::get_and_bump_dirty() const { +const std::unordered_set& AutogradContext::get_and_bump_dirty() + const { for (auto& var : dirty_inputs_) { var->bump_version(); } return dirty_inputs_; } -const std::unordered_set& AutogradContext::get_non_differentiable() const { +const std::unordered_set& AutogradContext:: + get_non_differentiable() const { return non_differentiable_; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 7e67a1086ff709..2875c0376ed3c1 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -1,30 +1,33 @@ #pragma once -#include -#include #include #include #include +#include +#include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using optional_variable_list = std::vector>; using _jvp_fn_t = std::function; TORCH_API std::vector> _wrap_outputs( - const variable_list &input_vars, - const std::unordered_set &non_differentiable, - const std::unordered_set &dirty_inputs, - const at::ArrayRef> raw_outputs, - const std::shared_ptr &cdata, - _jvp_fn_t jvp_user_function); - -TORCH_API void check_variable_result(const at::TensorBase& original, - const at::TensorBase& result, std::string hook_name); + const variable_list& input_vars, + const std::unordered_set& non_differentiable, + const std::unordered_set& dirty_inputs, + const at::ArrayRef> raw_outputs, + const std::shared_ptr& cdata, + _jvp_fn_t jvp_user_function); + +TORCH_API void check_variable_result( + const at::TensorBase& original, + const at::TensorBase& result, + std::string hook_name); // Get the return type of the forward function of the custom Function class X -template +template using forward_t = decltype(X::forward(nullptr, std::declval()...)); /// To use custom autograd operations, implement a Function subclass with @@ -88,16 +91,18 @@ struct TORCH_API Function { // is not declared yet. // The enable_if check is to ensure that the user doesn't explicitly provide // the parameter X. - template - static auto apply(Args&&... args) -> std::enable_if_t::value, forward_t>; + template + static auto apply(Args&&... args) + -> std::enable_if_t::value, forward_t>; }; -/// Context to save information during `forward` that can be accessed in `backward` -/// in custom autograd operations (see `torch::autograd::Function` for details). +/// Context to save information during `forward` that can be accessed in +/// `backward` in custom autograd operations (see `torch::autograd::Function` +/// for details). struct TORCH_API AutogradContext { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AutogradContext() : materialize_grads_(true) {} - AutogradContext(const AutogradContext &other) = delete; + AutogradContext(const AutogradContext& other) = delete; AutogradContext& operator=(const AutogradContext& other) = delete; /// Can be used to save non-variable data for `backward`. @@ -110,22 +115,23 @@ struct TORCH_API AutogradContext { /// Marks variables in the list as modified in an in-place operation. This /// should be called at most once from inside of `forward` and all arguments /// should be inputs. - void mark_dirty(const variable_list &inputs); - /// Marks outputs in the list as not requiring gradients. This should be called - /// at most once from inside of `forward` and all arguments should be outputs. - void mark_non_differentiable(const variable_list &outputs); + void mark_dirty(const variable_list& inputs); + /// Marks outputs in the list as not requiring gradients. This should be + /// called at most once from inside of `forward` and all arguments should be + /// outputs. + void mark_non_differentiable(const variable_list& outputs); // Sets whether undefined output grad tensors should be expanded to tensors // full of zeros before calling backward function. Default value is true. void set_materialize_grads(bool value); /// Get the list of variables that were saved in `forward` using - /// `save_for_backward()`. Before returning them to the user, a check is made to - /// ensure that they were not modified by any in-place operations. + /// `save_for_backward()`. Before returning them to the user, a check is made + /// to ensure that they were not modified by any in-place operations. variable_list get_saved_variables() const; const std::unordered_set& get_and_bump_dirty() const; const std::unordered_set& get_non_differentiable() const; -private: + private: std::unordered_set non_differentiable_; std::unordered_set dirty_inputs_; std::vector saved_variables_; @@ -140,7 +146,8 @@ struct TORCH_API AutogradContext { void save_variables(); - template friend struct CppNode; + template + friend struct CppNode; }; struct TORCH_API VariableInfo { @@ -171,14 +178,15 @@ struct CppNode : public Node { void release_variables() override; - void set_ctx_grad_fn(const std::shared_ptr &node); + void set_ctx_grad_fn(const std::shared_ptr& node); void save_variables_to_ctx(); }; struct ExtractVariables : IterArgs { std::vector& is_var_; variable_list& list_; - ExtractVariables(std::vector& is_var, variable_list& list) : is_var_(is_var), list_(list) {} + ExtractVariables(std::vector& is_var, variable_list& list) + : is_var_(is_var), list_(list) {} void operator()(const c10::optional& x) { // NOLINTNEXTLINE(bugprone-branch-clone) if (x.has_value() && x.value().defined()) { @@ -199,13 +207,16 @@ struct ExtractVariables : IterArgs { }; template -inline void extract_vars(std::vector &is_var, variable_list& list, Args&&... args) { +inline void extract_vars( + std::vector& is_var, + variable_list& list, + Args&&... args) { ExtractVariables(is_var, list).apply(std::forward(args)...); } template -typename std::enable_if::value, T>::type to_output_type( - std::vector>& output_list) { +typename std::enable_if::value, T>::type +to_output_type(std::vector>& output_list) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) variable_list result; std::transform( @@ -217,9 +228,9 @@ typename std::enable_if::value, T>::type to_outpu } template -typename std::enable_if::value, T>::type to_output_type( - std::vector>& output_list) { - return *output_list[0]; +typename std::enable_if::value, T>::type +to_output_type(std::vector>& output_list) { + return *output_list[0]; } inline std::vector> to_optional(Variable& output) { @@ -229,14 +240,18 @@ inline std::vector> to_optional(Variable& output) { inline std::vector> to_optional(variable_list& output) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector> result; - std::transform(output.begin(), output.end(), std::back_inserter(result), - [](const Variable& var) { return var; }); + std::transform( + output.begin(), + output.end(), + std::back_inserter(result), + [](const Variable& var) { return var; }); return result; } -template -template -auto Function::apply(Args&&... args) -> std::enable_if_t::value, forward_t> { +template +template +auto Function::apply(Args&&... args) + -> std::enable_if_t::value, forward_t> { std::shared_ptr> node(new CppNode(), deleteNode); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) variable_list input_vars; @@ -248,15 +263,17 @@ auto Function::apply(Args&&... args) -> std::enable_if_t::v extract_vars(node->is_variable_input_, input_vars, args...); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool is_executable = GradMode::is_enabled() && any_variable_requires_grad(input_vars); - auto next_edges = (is_executable ? collect_next_edges(input_vars) : edge_list()); + bool is_executable = + GradMode::is_enabled() && any_variable_requires_grad(input_vars); + auto next_edges = + (is_executable ? collect_next_edges(input_vars) : edge_list()); node->set_ctx_grad_fn(node); node->set_next_edges(std::move(next_edges)); node->clear_input_metadata(); node->input_info_.reserve(input_vars.size()); for (auto& var : input_vars) { - node->input_info_.emplace_back(var); + node->input_info_.emplace_back(var); } using forward_return_t = forward_t; @@ -267,18 +284,21 @@ auto Function::apply(Args&&... args) -> std::enable_if_t::v outputs = T::forward(&node->ctx_, std::forward(args)...); } - _jvp_fn_t jvp_fn = [](variable_list inputs, variable_list gI) -> variable_list { - TORCH_CHECK(false, "jvp is not implemented for the c++ API of custom Function yet.", - "Please open a feature request on Github if you need this."); + _jvp_fn_t jvp_fn = [](variable_list inputs, + variable_list gI) -> variable_list { + TORCH_CHECK( + false, + "jvp is not implemented for the c++ API of custom Function yet.", + "Please open a feature request on Github if you need this."); }; auto wrapped_outputs = _wrap_outputs( - input_vars, - node->ctx_.get_non_differentiable(), - node->ctx_.get_and_bump_dirty(), - to_optional(outputs), - is_executable ? node : nullptr, - jvp_fn); + input_vars, + node->ctx_.get_non_differentiable(), + node->ctx_.get_and_bump_dirty(), + to_optional(outputs), + is_executable ? node : nullptr, + jvp_fn); node->output_info_.reserve(wrapped_outputs.size()); for (auto& output : wrapped_outputs) { @@ -295,12 +315,12 @@ auto Function::apply(Args&&... args) -> std::enable_if_t::v // wrapped_outputs will be a variable_list so, convert it to the correct // return type. Only Variable and variable_list are accepted as return types. - return to_output_type(wrapped_outputs); + return to_output_type(wrapped_outputs); } // The logic here is the same as PyNode::apply, so changes to it should be done // in both the places -template +template variable_list CppNode::apply(variable_list&& inputs) { at::OptionalDeviceGuard _device_guard; @@ -325,10 +345,11 @@ variable_list CppNode::apply(variable_list&& inputs) { auto outputs = T::backward(&ctx_, backward_inputs); - const auto num_forward_inputs = static_cast(is_variable_input_.size()); + const auto num_forward_inputs = + static_cast(is_variable_input_.size()); auto num_outputs = static_cast(outputs.size()); - // Returning too many results is ok, but only as long as they're all undefined. - // Truncate the result vector in that case. + // Returning too many results is ok, but only as long as they're all + // undefined. Truncate the result vector in that case. if (num_outputs > num_forward_inputs) { bool all_undef = true; for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { @@ -343,7 +364,7 @@ variable_list CppNode::apply(variable_list&& inputs) { if (num_outputs != num_forward_inputs) { std::string msg("function "); msg += name() + " returned an incorrect number of gradients (expected "; - msg += c10::to_string(num_forward_inputs) + ", got " ; + msg += c10::to_string(num_forward_inputs) + ", got "; msg += c10::to_string(num_outputs) + ")"; throw std::runtime_error(msg); } @@ -355,8 +376,10 @@ variable_list CppNode::apply(variable_list&& inputs) { if (!is_variable_input_[i]) { if (outputs[i].defined()) { std::string msg("function "); - msg += name() + " returned a gradient different that is defined at position "; - msg += c10::to_string(i + 1) + ", but the corresponding forward input was not a Variable"; + msg += name() + + " returned a gradient different that is defined at position "; + msg += c10::to_string(i + 1) + + ", but the corresponding forward input was not a Variable"; throw std::runtime_error(msg); } continue; @@ -366,7 +389,7 @@ variable_list CppNode::apply(variable_list&& inputs) { return results; } -template +template void CppNode::release_variables() { // lock to ensure thread safety, see [Thread Safety on Autograd Node] std::lock_guard lock(mutex_); @@ -374,14 +397,15 @@ void CppNode::release_variables() { ctx_.has_freed_buffers_ = true; } -template +template void CppNode::save_variables_to_ctx() { ctx_.save_variables(); } -template -void CppNode::set_ctx_grad_fn(const std::shared_ptr &node) { +template +void CppNode::set_ctx_grad_fn(const std::shared_ptr& node) { ctx_.grad_fn_ = node; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/edge.h b/torch/csrc/autograd/edge.h index 76e6baf8aa3874..e51291b3a2f4f2 100644 --- a/torch/csrc/autograd/edge.h +++ b/torch/csrc/autograd/edge.h @@ -6,7 +6,8 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct Node; @@ -37,7 +38,8 @@ struct Edge { /// The identifier of a particular input to the function. uint32_t input_nr; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch // The idiomatic way of enabling use of a custom type as the key of hash // containers in C++11. This method removes the requirement of having to pass diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 82e5cd21c11d91..6a603c57f62a5b 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include #include #include @@ -13,14 +13,14 @@ #include #include -#include -#include -#include #include -#include +#include +#include +#include +#include #include #include -#include +#include #include #include @@ -30,15 +30,16 @@ #include #include #include +#include #include +#include #include #include -#include #include -#include -#include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { namespace { static bool in_bad_autograd_fork = @@ -46,7 +47,9 @@ static bool in_bad_autograd_fork = // Called in the forked child if engine's thread pool has already been // initialized -static void forked_autograd_child() { in_bad_autograd_fork = true; } +static void forked_autograd_child() { + in_bad_autograd_fork = true; +} // Should be called before unsafe for forks (thread pool) calls static void track_bad_autograd_forks() { @@ -64,7 +67,7 @@ inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) { return false; } } -} +} // namespace // Threads spawned by the engine are assigned a 'worker_device' specifying // what device they process work for. This variable is initialized at: @@ -84,9 +87,11 @@ static thread_local bool checkpoint_valid = true; // Number of nested reentrant backwards calls currently on this thread static thread_local int current_depth = 0; -// For all device threads (i.e. CUDA, XLA), total_depth represents the total nested +// For all device threads (i.e. CUDA, XLA), total_depth represents the total +// nested // reentrant backwards depths over all device threads. -// For CPU devices, it is the total depth associated with the original backward call. +// For CPU devices, it is the total depth associated with the original backward +// call. static thread_local int total_depth = 0; // The current GraphTask being executed by this thread. This helps @@ -94,16 +99,17 @@ static thread_local int total_depth = 0; C10_DEFINE_TLS_static(std::shared_ptr, tls_current_graph_task); #define current_graph_task (tls_current_graph_task.get()) -// Every autograd worker thread is associated with a ready queue, which specifies -// the stream of work of this thread to do. This shared_ptr is a thread_local -// pointer to each thread's ready_queue, and it should be initialized via the -// Engine::init_local_ready_queue() call in each corresponding thread before execution. +// Every autograd worker thread is associated with a ready queue, which +// specifies the stream of work of this thread to do. This shared_ptr is a +// thread_local pointer to each thread's ready_queue, and it should be +// initialized via the Engine::init_local_ready_queue() call in each +// corresponding thread before execution. // // The CUDA, XLA threads are shared among all invocations of backwards via -// device_ready_queues_, while the caller thread is dedicated to processing work for -// devices returning true in should_run_in_cpu_ready_queue (most notably the CPU device). -// So any given graph task maintains its own cpu_ready_queue_ where you should send work -// for it to be done. +// device_ready_queues_, while the caller thread is dedicated to processing work +// for devices returning true in should_run_in_cpu_ready_queue (most notably the +// CPU device). So any given graph task maintains its own cpu_ready_queue_ where +// you should send work for it to be done. // // For reentrant backward calls, if we spawn new thread from the current thread // because we reached the maximum depth, the new thread will just reuse the same @@ -185,9 +191,11 @@ int NodeTask::getReentrantDepth() const { } } -CheckpointValidGuard::CheckpointValidGuard(const std::shared_ptr& graph_task) { +CheckpointValidGuard::CheckpointValidGuard( + const std::shared_ptr& graph_task) { prev_checkpoint_valid_state = checkpoint_valid; - checkpoint_valid = graph_task->can_checkpoint() && prev_checkpoint_valid_state; + checkpoint_valid = + graph_task->can_checkpoint() && prev_checkpoint_valid_state; } CheckpointValidGuard::~CheckpointValidGuard() { @@ -225,9 +233,10 @@ size_t ReadyQueue::size() const { auto ReadyQueue::pop() -> NodeTask { // Lock mutex for accesses to heap_ std::unique_lock lock(mutex_); - not_empty_.wait(lock, [this]{ return !heap_.empty(); }); + not_empty_.wait(lock, [this] { return !heap_.empty(); }); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto task = std::move(const_cast(heap_.top())); heap_.pop(); + auto task = std::move(const_cast(heap_.top())); + heap_.pop(); return task; } @@ -237,14 +246,16 @@ bool ReadyQueue::empty() const { return heap_.empty(); } -Engine::Engine() : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {} +Engine::Engine() + : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {} Engine::~Engine() { stop(); } -// Send shutdown tasks to all device_ready_queues_ if no backward tasks are running -// Even though readyQueue should be empty, shutdown tasks have the highest priority +// Send shutdown tasks to all device_ready_queues_ if no backward tasks are +// running Even though readyQueue should be empty, shutdown tasks have the +// highest priority void Engine::stop() { if (stopped_) { return; @@ -258,12 +269,12 @@ void Engine::stop() { } auto wait_duration = std::atof(wait_duration_str); bool noBackward = true; - for (auto& queue: device_ready_queues_) { - noBackward = noBackward && queue->empty(); + for (auto& queue : device_ready_queues_) { + noBackward = noBackward && queue->empty(); } if (noBackward && wait_duration > 0.0f) { for (auto& queue : device_ready_queues_) { - queue->pushShutdownTask(); + queue->pushShutdownTask(); } // Do not wait for termination of global threads on Windows // Because CRT terminates DLL threads before calling @@ -272,10 +283,12 @@ void Engine::stop() { using namespace std::chrono_literals; // Set a deadline for how long it is OK to wait device threads to shutdown - auto wait_deadline = std::chrono::steady_clock::now() + wait_duration * 1.0s; + auto wait_deadline = + std::chrono::steady_clock::now() + wait_duration * 1.0s; std::unique_lock lk(non_reentrant_device_thread_mutex_); - while(non_reentrant_device_thread_count_.load() != 0) { - if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) == std::cv_status::timeout) { + while (non_reentrant_device_thread_count_.load() != 0) { + if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) == + std::cv_status::timeout) { break; } } @@ -302,7 +315,10 @@ void Engine::decrement_non_reentrant_thread_count() { non_reentrant_device_thread_condvar_.notify_one(); } -void Engine::thread_init(int device, const std::shared_ptr& ready_queue, bool should_increment) { +void Engine::thread_init( + int device, + const std::shared_ptr& ready_queue, + bool should_increment) { if (should_increment) { increment_non_reentrant_thread_count(); } @@ -328,8 +344,8 @@ void Engine::thread_init(int device, const std::shared_ptr& ready_qu // better. set_device(device); - // initialize each device thread's thread local ready queue with the ready queue - // that is created before the thread initialization + // initialize each device thread's thread local ready queue with the ready + // queue that is created before the thread initialization init_local_ready_queue(ready_queue); std::shared_ptr graph_task = nullptr; @@ -394,7 +410,8 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { at::ROCmBackwardPassGuard in_backward; #endif - // local_ready_queue should already been initialized when we get into thread_main + // local_ready_queue should already been initialized when we get into + // thread_main TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr); while (graph_task == nullptr || !graph_task->future_result_->completed()) { // local_graph_task represents the graph_task we retrieve from the queue. @@ -421,10 +438,11 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { if (task.fn_ && !local_graph_task->has_error_.load()) { // Set the ThreadLocalState before calling the function. - // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask - // always saves ThreadLocalState without grad_mode. + // NB: The ThreadLocalStateGuard doesn't set the grad_mode because + // GraphTask always saves ThreadLocalState without grad_mode. at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); - c10::Warning::WarningHandlerGuard warnings_guard(&local_graph_task->warning_handler_); + c10::Warning::WarningHandlerGuard warnings_guard( + &local_graph_task->warning_handler_); try { // The guard sets the thread_local current_graph_task on construction @@ -485,10 +503,11 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { void Engine::reentrant_thread_init() { at::init_num_threads(); auto tp_shared = thread_pool_shared_; - while(true) { + while (true) { std::unique_lock lk(tp_shared->mutex_); ++thread_pool_shared_->num_workers_; - tp_shared->work_.wait(lk, [&tp_shared]{ return !tp_shared->graphtasks_queue_.empty();}); + tp_shared->work_.wait( + lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); }); --thread_pool_shared_->num_workers_; auto task = tp_shared->graphtasks_queue_.front(); tp_shared->graphtasks_queue_.pop(); @@ -499,8 +518,10 @@ void Engine::reentrant_thread_init() { continue; } set_device(graph_task->owner_); - // set the local_ready_queue to the ready queue on the graph_task->owner_ device - local_ready_queue = ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_); + // set the local_ready_queue to the ready queue on the graph_task->owner_ + // device + local_ready_queue = + ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_); total_depth = graph_task->reentrant_depth_; thread_main(graph_task); } @@ -568,13 +589,14 @@ void GraphTask::exec_post_processing() { // any grad on its device's current stream. if (leaf_streams.size() > 0) { for (const auto& leaf_stream : leaf_streams) { - // stash_current_streams() stashed streams for all device IDs that already had a - // CUDA context before the GraphTask executed. For inactive devices, it stashed - // a c10::nullopt. I don't expect GraphTask's backward pass ran leaf nodes on - // any new devices, so the stashed streams should be enough. + // stash_current_streams() stashed streams for all device IDs that already + // had a CUDA context before the GraphTask executed. For inactive devices, + // it stashed a c10::nullopt. I don't expect GraphTask's backward pass ran + // leaf nodes on any new devices, so the stashed streams should be enough. // If leaf_stream.device_index() happens to be for a new device, // operator* on the c10::nullopt should throw an error. - const auto caller_current_stream = *caller_current_streams_[leaf_stream.device_index()]; + const auto caller_current_stream = + *caller_current_streams_[leaf_stream.device_index()]; if (caller_current_stream != leaf_stream) { auto event = c10::Event{c10::DeviceType::CUDA}; @@ -592,11 +614,14 @@ void GraphTask::exec_post_processing() { } { - // final_callbacks run on the per-device caller_current_streams (the ambient streams - // surrounding the user's call to backward()). This has two benefits: - // 1. caller_current_streams have been synced with leaf_streams, so callbacks may + // final_callbacks run on the per-device caller_current_streams (the ambient + // streams surrounding the user's call to backward()). This has two + // benefits: + // 1. caller_current_streams have been synced with leaf_streams, so + // callbacks may // safely access any grad. - // 2. The callback's results can safely be used on (user-facing) caller_current_streams + // 2. The callback's results can safely be used on (user-facing) + // caller_current_streams // after backward(). c10::MultiStreamGuard g(caller_current_streams_filtered); @@ -641,7 +666,10 @@ static variable_list call_pre_hooks(Node& fn, variable_list inputs) { return inputs; } -static variable_list call_post_hooks(Node& fn, variable_list outputs, const variable_list& inputs) { +static variable_list call_post_hooks( + Node& fn, + variable_list outputs, + const variable_list& inputs) { for (const auto& hook : fn.post_hooks()) { outputs = (*hook)(outputs, inputs); } @@ -658,7 +686,8 @@ void set_device(int device) { // Don't use DeviceGuard here because its destructor may be called before the // device is reset. This is fine because the device is thread local. if (device != CPU_DEVICE) { - for(const auto i : c10::irange(static_cast(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) { + for (const auto i : c10::irange(static_cast( + c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) { auto* impl = c10::impl::device_guard_impl_registry[i].load(); if (impl && device < impl->deviceCount()) { impl->setDevice(at::Device(static_cast(i), device)); @@ -678,9 +707,10 @@ void validate_outputs( ss << edges.size() << ", but got " << grads.size(); AT_ERROR(format_error(ss.str())); } - for(const auto i : c10::irange(grads.size())) { + for (const auto i : c10::irange(grads.size())) { const auto& edge = edges[i]; - if (!edge.is_valid()) continue; + if (!edge.is_valid()) + continue; const auto& metadata = edge.function->input_metadata(edge.input_nr); auto& grad = grads[i]; @@ -701,36 +731,46 @@ void validate_outputs( } } - bool input_is_complex = isComplexType(c10::typeMetaToScalarType(metadata.options().dtype())); + bool input_is_complex = + isComplexType(c10::typeMetaToScalarType(metadata.options().dtype())); bool grad_is_complex = isComplexType(grad.scalar_type()); - TORCH_CHECK(isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex)); - if (c10::typeMetaToScalarType(metadata.options().dtype()) != grad.scalar_type()) { + TORCH_CHECK( + isFloatingType(grad.scalar_type()) || + (input_is_complex == grad_is_complex)); + if (c10::typeMetaToScalarType(metadata.options().dtype()) != + grad.scalar_type()) { grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype())); } if (grad.dtype() != metadata.dtype()) { - std::stringstream ss; - ss << "invalid gradient at index " << i << " - expected dtype "; - ss << metadata.dtype() << " but got " << grad.dtype(); - AT_ERROR(format_error(ss.str())); + std::stringstream ss; + ss << "invalid gradient at index " << i << " - expected dtype "; + ss << metadata.dtype() << " but got " << grad.dtype(); + AT_ERROR(format_error(ss.str())); } if (grad.layout() != metadata.layout()) { - // TODO: Currently we only support (*, Sparse) combination for (tensor.layout(), tensor.grad.layout()) - // In future, there will be an oppportunity to support more combinations of layouts if they are composable - // (example., operations like addition etc., are well defined between tensors of different layouts.), - // as well as all parts of autograd like AccumulateGrad correctly handle this. - // We allow grad to be Strided when metadata is SparseCsr - if (!grad.is_sparse() && !(grad.layout() == at::kStrided && metadata.layout() == at::kSparseCsr)) { + // TODO: Currently we only support (*, Sparse) combination for + // (tensor.layout(), tensor.grad.layout()) In future, there will be an + // oppportunity to support more combinations of layouts if they are + // composable (example., operations like addition etc., are well defined + // between tensors of different layouts.), as well as all parts of + // autograd like AccumulateGrad correctly handle this. We allow grad to be + // Strided when metadata is SparseCsr + if (!grad.is_sparse() && + !(grad.layout() == at::kStrided && + metadata.layout() == at::kSparseCsr)) { std::stringstream ss; ss << "invalid gradient at index " << i << " - expected layout "; ss << metadata.layout() << " but got " << grad.layout(); AT_ERROR(format_error(ss.str())); - } + } } if (grad.device() != metadata.device()) { - // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but should be eventually removed - if (!(metadata.is_tensor_subclass() || grad.unsafeGetTensorImpl()->is_python_dispatch())) { + // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but + // should be eventually removed + if (!(metadata.is_tensor_subclass() || + grad.unsafeGetTensorImpl()->is_python_dispatch())) { if (grad.dim() == 0) { grad = grad.to(metadata.device()); } else { @@ -786,11 +826,11 @@ static variable_list call_function( validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) { std::ostringstream ss; - ss << "Function " << fn.name() << " returned an " << msg; + ss << "Function " << fn.name() << " returned an " << msg; return ss.str(); }); - if(has_post_hooks){ + if (has_post_hooks) { // NOLINTNEXTLINE(bugprone-use-after-move) return call_post_hooks(fn, std::move(outputs), inputs); } @@ -859,19 +899,22 @@ void Engine::evaluate_function( at::OptionalDeviceGuard guard(device_of(output)); if (output.defined() && isnan(output).any().item()) { std::stringstream ss; - ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output."; + ss << "Function '" << fn.name() << "' returned nan values in its " << i + << "th output."; throw std::runtime_error(ss.str()); } } } - // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below + // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and + // cpu_ready_queue_ below std::lock_guard lock(graph_task->mutex_); for (const auto i : c10::irange(num_outputs)) { auto& output = outputs[i]; const auto& next = fn.next_edge(i); - if (!next.is_valid()) continue; + if (!next.is_valid()) + continue; // Check if the next function is ready to be computed bool is_ready = false; @@ -901,10 +944,8 @@ void Engine::evaluate_function( // Accumulates into buffer const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); - input_buffer.add(next.input_nr, - std::move(output), - opt_parent_stream, - opt_next_stream); + input_buffer.add( + next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); if (is_ready) { auto queue = ready_queue(cpu_ready_queue, input_buffer.device()); @@ -915,14 +956,12 @@ void Engine::evaluate_function( } } else { // The function already has a buffer - auto &input_buffer = not_ready_it->second; + auto& input_buffer = not_ready_it->second; // Accumulates into buffer const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); - input_buffer.add(next.input_nr, - std::move(output), - opt_parent_stream, - opt_next_stream); + input_buffer.add( + next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); if (is_ready) { auto queue = ready_queue(cpu_ready_queue, input_buffer.device()); queue->push( @@ -939,17 +978,20 @@ inline static uint64_t compute_min_topological_nr(const edge_list& outputs) { return 0; } auto min_topo_nr = std::numeric_limits::max(); - for (auto & output_edge : outputs) { + for (auto& output_edge : outputs) { auto topo_nr = output_edge.function.get()->topological_nr(); min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr; } return min_topo_nr; } -auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr) -> void { +auto Engine::compute_dependencies( + Node* root, + GraphTask& task, + uint64_t min_topo_nr) -> void { // Computes the number of dependencies for each function which requires grad std::unordered_set seen; - std::vector queue { root }; + std::vector queue{root}; bool might_use_cuda = at::globalContext().hasCUDA(); bool will_use_cuda = false; @@ -957,7 +999,8 @@ auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo // We no longer have to expand functions that don't require grad. auto& dependencies = task.dependencies_; while (!queue.empty()) { - auto fn = queue.back(); queue.pop_back(); + auto fn = queue.back(); + queue.pop_back(); if (fn->topological_nr() < min_topo_nr) { continue; } @@ -968,7 +1011,8 @@ auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo if (auto next_ptr = edge.function.get()) { dependencies[next_ptr] += 1; const bool was_inserted = seen.insert(next_ptr).second; - if (was_inserted) queue.push_back(next_ptr); + if (was_inserted) + queue.push_back(next_ptr); } } } @@ -980,35 +1024,37 @@ auto Engine::compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo } } -auto Engine::execute(const edge_list& roots, - const variable_list& inputs, - bool keep_graph, - bool create_graph, - bool accumulate_grad, - const edge_list& outputs) -> variable_list { +auto Engine::execute( + const edge_list& roots, + const variable_list& inputs, + bool keep_graph, + bool create_graph, + bool accumulate_grad, + const edge_list& outputs) -> variable_list { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - validate_outputs(roots, const_cast(inputs), [](const std::string& msg) { - return msg; - }); + validate_outputs( + roots, const_cast(inputs), [](const std::string& msg) { + return msg; + }); if (accumulate_grad && create_graph) { TORCH_WARN_ONCE( - "Using backward() with create_graph=True will create a reference cycle " - "between the parameter and its gradient which can cause a memory leak. " - "We recommend using autograd.grad when creating the graph to avoid this. " - "If you have to use this function, make sure to reset the .grad fields of " - "your parameters to None after use to break the cycle and avoid the leak."); + "Using backward() with create_graph=True will create a reference cycle " + "between the parameter and its gradient which can cause a memory leak. " + "We recommend using autograd.grad when creating the graph to avoid this. " + "If you have to use this function, make sure to reset the .grad fields of " + "your parameters to None after use to break the cycle and avoid the leak."); } // accumulate_grad is true if and only if the frontend call was to // grad(), not backward(). grad() returns the sum of the gradients // w.r.t. the inputs and thus needs the inputs to be present. - TORCH_CHECK_VALUE(accumulate_grad || !outputs.empty(), - "grad requires non-empty inputs."); + TORCH_CHECK_VALUE( + accumulate_grad || !outputs.empty(), "grad requires non-empty inputs."); - // A fresh first time Engine::execute call should start on the CPU device, initialize - // a new thread local ready queue on CPU or reuse the existing one (if there is one - // allocated already, i.e. consecutive backward calls, re-entrant backward calls), - // then memoize the local_ready_queue in GraphTask + // A fresh first time Engine::execute call should start on the CPU device, + // initialize a new thread local ready queue on CPU or reuse the existing one + // (if there is one allocated already, i.e. consecutive backward calls, + // re-entrant backward calls), then memoize the local_ready_queue in GraphTask init_local_ready_queue(); bool not_reentrant_backward_call = worker_device == NO_DEVICE; @@ -1020,16 +1066,17 @@ auto Engine::execute(const edge_list& roots, // If we receive a single root, skip creating extra root node bool skip_dummy_node = roots.size() == 1; - auto graph_root = skip_dummy_node ? - roots.at(0).function : - std::make_shared(roots, inputs); + auto graph_root = skip_dummy_node + ? roots.at(0).function + : std::make_shared(roots, inputs); auto min_topo_nr = compute_min_topological_nr(outputs); // Now compute the dependencies for all executable functions compute_dependencies(graph_root.get(), *graph_task, min_topo_nr); if (!outputs.empty()) { - graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr); + graph_task->init_to_execute( + *graph_root, outputs, accumulate_grad, min_topo_nr); } // Queue the root @@ -1038,15 +1085,15 @@ auto Engine::execute(const edge_list& roots, auto input = inputs.at(0); const auto input_stream = InputMetadata(input).stream(); - const auto opt_next_stream = roots.at(0).function->stream(c10::DeviceType::CUDA); - input_buffer.add(roots.at(0).input_nr, - std::move(input), - input_stream, - opt_next_stream); + const auto opt_next_stream = + roots.at(0).function->stream(c10::DeviceType::CUDA); + input_buffer.add( + roots.at(0).input_nr, std::move(input), input_stream, opt_next_stream); execute_with_graph_task(graph_task, graph_root, std::move(input_buffer)); } else { - execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list())); + execute_with_graph_task( + graph_task, graph_root, InputBuffer(variable_list())); } // Avoid a refcount bump for the Future, since we check for refcount in // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1) @@ -1058,10 +1105,12 @@ auto Engine::execute(const edge_list& roots, } void Engine::initialize_device_threads_pool() { - TORCH_CHECK(!in_bad_autograd_fork, - "Unable to handle autograd's threading in combination with fork-based multiprocessing. " - "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork"); - std::call_once(start_device_threads_flag_, &Engine::start_device_threads, this); + TORCH_CHECK( + !in_bad_autograd_fork, + "Unable to handle autograd's threading in combination with fork-based multiprocessing. " + "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork"); + std::call_once( + start_device_threads_flag_, &Engine::start_device_threads, this); } c10::intrusive_ptr Engine::execute_with_graph_task( @@ -1077,36 +1126,38 @@ c10::intrusive_ptr Engine::execute_with_graph_task( // worker_device == NO_DEVICE it's a CPU thread and it's trying to drive the // autograd engine with corresponding GraphTask, and its NOT a re-entrant call if (worker_device == NO_DEVICE) { - // We set the worker_device to CPU_DEVICE only if worker_device was previously - // NO_DEVICE. Setting it to CPU afterwards allow us to detect whether this is - // a re-entrant call or not. + // We set the worker_device to CPU_DEVICE only if worker_device was + // previously NO_DEVICE. Setting it to CPU afterwards allow us to detect + // whether this is a re-entrant call or not. set_device(CPU_DEVICE); // set the graph_task owner to the current device graph_task->owner_ = worker_device; - // Now that all the non-thread safe fields of the graph_task have been populated, - // we can enqueue it. - queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + // Now that all the non-thread safe fields of the graph_task have been + // populated, we can enqueue it. + queue->push( + NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); - // The owning thread start to drive the engine execution for any CPU task that - // was just pushed or will be added later from other worker threads + // The owning thread start to drive the engine execution for any CPU task + // that was just pushed or will be added later from other worker threads lock.unlock(); thread_main(graph_task); TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed()); - // reset the worker_device after the completion of the graph_task, this is so - // that the initial state of the engine remains the same across every backward() - // or grad() call, we don't need to reset local_ready_queue as we could possibly - // reuse it for new backward calls. + // reset the worker_device after the completion of the graph_task, this is + // so that the initial state of the engine remains the same across every + // backward() or grad() call, we don't need to reset local_ready_queue as we + // could possibly reuse it for new backward calls. worker_device = NO_DEVICE; } else { // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant // backward call from that device. graph_task->owner_ = worker_device; - // Now that all the non-thread safe fields of the graph_task have been populated, - // we can enqueue it. - queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); + // Now that all the non-thread safe fields of the graph_task have been + // populated, we can enqueue it. + queue->push( + NodeTask(graph_task, std::move(graph_root), std::move(input_buffer))); if (current_depth >= max_recursion_depth_) { // See Note [Reentrant backwards] @@ -1152,7 +1203,6 @@ void set_default_engine_stub(EngineStub stub) { engine_stub.store(stub); } - Engine& Engine::get_default_engine() { return engine_stub.load()(); } @@ -1172,37 +1222,49 @@ bool Engine::is_checkpoint_valid() { void Engine::init_local_ready_queue(std::shared_ptr ready_queue) { if (ready_queue) { - // if ready_queue provided in the caller, use the caller's ready_queue to initialize local_ready_queue + // if ready_queue provided in the caller, use the caller's ready_queue to + // initialize local_ready_queue local_ready_queue = std::move(ready_queue); - } else if (!local_ready_queue){ + } else if (!local_ready_queue) { // otherwise if local_ready_queue not allocated, allocate a new ready_queue local_ready_queue = std::make_shared(); } } -// CPU ready queue is per GraphTask, but CUDA device ready queues are shared across all graph tasks -auto Engine::ready_queue(std::shared_ptr cpu_ready_queue, at::Device device) -> std::shared_ptr{ +// CPU ready queue is per GraphTask, but CUDA device ready queues are shared +// across all graph tasks +auto Engine::ready_queue( + std::shared_ptr cpu_ready_queue, + at::Device device) -> std::shared_ptr { if (should_run_in_cpu_ready_queue(device.type())) { // return the cpu ready queue passed in TORCH_INTERNAL_ASSERT(cpu_ready_queue); return cpu_ready_queue; } else { - TORCH_INTERNAL_ASSERT(0 <= device.index() && device.index() < static_cast(device_ready_queues_.size())); + TORCH_INTERNAL_ASSERT( + 0 <= device.index() && + device.index() < + static_cast(device_ready_queues_.size())); // See Note [Allocating GPUs to autograd threads] return device_ready_queues_.at(device.index()); } } -auto Engine::ready_queue_by_index(std::shared_ptr cpu_ready_queue, int device_index) -> std::shared_ptr { +auto Engine::ready_queue_by_index( + std::shared_ptr cpu_ready_queue, + int device_index) -> std::shared_ptr { if (device_index == CPU_DEVICE) { // return the cpu ready queue passed in TORCH_INTERNAL_ASSERT(cpu_ready_queue); return cpu_ready_queue; } else { - TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast(device_ready_queues_.size())); + TORCH_INTERNAL_ASSERT( + 0 <= device_index && + device_index < + static_cast(device_ready_queues_.size())); // See Note [Allocating GPUs to autograd threads] - // NB: This function would become obsolete if we truly allocated a CPU thread - // per device, rather than colocate. + // NB: This function would become obsolete if we truly allocated a CPU + // thread per device, rather than colocate. return device_ready_queues_.at(device_index); } } @@ -1232,9 +1294,10 @@ auto Engine::start_device_threads() -> void { track_bad_autograd_forks(); // allocate one thread for every GPU device (but colocate GPUs of different - // types), and pre-allocate the device_ready_queues_ to ensure safe reading on it. + // types), and pre-allocate the device_ready_queues_ to ensure safe reading on + // it. device_ready_queues_ = std::vector>(num_devices); - for (auto& queue : device_ready_queues_) { + for (auto& queue : device_ready_queues_) { queue = std::make_shared(); } @@ -1245,7 +1308,8 @@ auto Engine::start_device_threads() -> void { // Wait for the threads to start { std::unique_lock lk(non_reentrant_device_thread_mutex_); - while(non_reentrant_device_thread_count_.load() != static_cast(num_devices)) { + while (non_reentrant_device_thread_count_.load() != + static_cast(num_devices)) { non_reentrant_device_thread_condvar_.wait(lk); } } @@ -1256,7 +1320,9 @@ void Engine::add_thread_pool_task(const std::weak_ptr& graph_task) { // There may already be some items on the graphtasks_queue_ added by other // threads but not enough workers to get to the new task that will be // added - bool create_thread = (thread_pool_shared_->num_workers_ <= thread_pool_shared_->graphtasks_queue_.size()); + bool create_thread = + (thread_pool_shared_->num_workers_ <= + thread_pool_shared_->graphtasks_queue_.size()); thread_pool_shared_->graphtasks_queue_.push(graph_task); // Don't need to be holding the lock while actually creating the thread lck.unlock(); @@ -1272,22 +1338,26 @@ void Engine::add_thread_pool_task(const std::weak_ptr& graph_task) { } // Remembers current streams on all devices where a context has been created. -// Only called if Engine::execute detects at least one node runs on a cuda stream. +// Only called if Engine::execute detects at least one node runs on a cuda +// stream. void GraphTask::stash_current_streams() { const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; auto num_gpus = guard.deviceCount(); caller_current_streams_.resize(num_gpus); if (num_gpus > 0) { - for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) { + for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) { #if defined(USE_ROCM) && (ROCM_VERSION < 50000) - // If the build targets ROCM, stash streams for all visible devices unconditionally, to work around + // If the build targets ROCM, stash streams for all visible devices + // unconditionally, to work around // https://github.com/pytorch/pytorch/issues/59750. - // TODO: Remove ROCM-specific behavior when https://github.com/pytorch/pytorch/issues/59750 is fixed. + // TODO: Remove ROCM-specific behavior when + // https://github.com/pytorch/pytorch/issues/59750 is fixed. if (true) { #else if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) { #endif - caller_current_streams_[idx] = guard.getStream({c10::DeviceType::CUDA, idx}); + caller_current_streams_[idx] = + guard.getStream({c10::DeviceType::CUDA, idx}); } else { caller_current_streams_[idx] = c10::nullopt; } @@ -1295,13 +1365,18 @@ void GraphTask::stash_current_streams() { } } -void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr) { - // Populates exec_info so nodes that should be executed have `exec_info[node].needed_ = true` - // Only nodes that have a path to any edge in `outputs` should be executed. - // The code below populates exec_info using recursion, but the actual code does this - // iteratively. Refer to the numbering to see how the actual code corresponds. - // A difference to note is that in the iterative version, when you are working with - // the current Node, you are reponsible to update your parent's is_needed after all your +void GraphTask::init_to_execute( + Node& graph_root, + const edge_list& outputs, + bool accumulate_grad, + uint64_t min_topo_nr) { + // Populates exec_info so nodes that should be executed have + // `exec_info[node].needed_ = true` Only nodes that have a path to any edge in + // `outputs` should be executed. The code below populates exec_info using + // recursion, but the actual code does this iteratively. Refer to the + // numbering to see how the actual code corresponds. A difference to note is + // that in the iterative version, when you are working with the current Node, + // you are reponsible to update your parent's is_needed after all your // children have been updated. // // is_needed = {fn: True for fn in outputs} # (0) @@ -1319,16 +1394,18 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool // return is_needed[fn] // compute_is_needed(graph_root) // - // NB: you might be wondering why we don't populate `seen` with outputs. We cannot - // because in the case where two outputs lie on the same path, we still need to explore past - // the first output or we would miss the nodes that are required to compute the second output. + // NB: you might be wondering why we don't populate `seen` with outputs. We + // cannot because in the case where two outputs lie on the same path, we still + // need to explore past the first output or we would miss the nodes that are + // required to compute the second output. int output_idx = 0; - for (auto & output_edge : outputs) { + for (auto& output_edge : outputs) { // (0) `is_needed` above corresponds to `exec_info_[fn].needed_` - Node *output = output_edge.function.get(); - auto & info = exec_info_[output]; + Node* output = output_edge.function.get(); + auto& info = exec_info_[output]; if (accumulate_grad) { - // if called through `.backward()` we directly set `needed_` for all the outputs to true + // if called through `.backward()` we directly set `needed_` for all the + // outputs to true info.needed_ = true; } else { // otherwise it is `.grad()` and we set exec_info[fn].captures_ instead @@ -1343,22 +1420,23 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool captured_vars_.resize(output_idx); struct Frame { - Frame (Node *fn) : fn_(fn), next_next_fn_(0) {} - Node *fn_; + Frame(Node* fn) : fn_(fn), next_next_fn_(0) {} + Node* fn_; size_t next_next_fn_; Node* get_next_fn() { - const auto & next = fn_->next_edges(); + const auto& next = fn_->next_edges(); auto num_next = next.size(); while (next_next_fn_ < num_next) { auto fn = next[next_next_fn_++].function.get(); - if (fn) return fn; + if (fn) + return fn; } return nullptr; } }; - auto nodeShouldExecute = [this](Node *fn) { + auto nodeShouldExecute = [this](Node* fn) { auto it = exec_info_.find(fn); return it != exec_info_.end() && it->second.should_execute(); }; @@ -1369,11 +1447,11 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool exec_info_.emplace(stack.back().fn_, ExecInfo()); while (!stack.empty()) { - auto &frame = stack.back(); + auto& frame = stack.back(); const auto fn = frame.fn_; - Node *child_fn = nullptr; - while((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) { + Node* child_fn = nullptr; + while ((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) { // (1) next child exists AND has already been seen if (nodeShouldExecute(child_fn)) { exec_info_[fn].needed_ = true; @@ -1399,4 +1477,5 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool } } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 6aae048432ce68..e436e95b1b669c 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -4,8 +4,8 @@ // to "root" variables (variables created by the user with requires_grad=True). #include -#include #include +#include #include #include #include @@ -19,16 +19,19 @@ #include #include #include +#include #include #include #include -#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct ReadyQueue; -}} // namespace torch::autograd +} +} // namespace torch -namespace torch { namespace autograd { +namespace torch { +namespace autograd { static constexpr int NO_DEVICE = -2; static constexpr int CPU_DEVICE = -1; @@ -48,7 +51,7 @@ void validate_outputs( const std::function& format_error); // GraphTask holds metadata needed for a single execution of backward() -struct GraphTask: std::enable_shared_from_this { +struct GraphTask : std::enable_shared_from_this { std::atomic outstanding_tasks_{0}; // Indicates if an error occurred while executing any task. When this is // true, it signals all threads to stop executing. @@ -98,10 +101,10 @@ struct GraphTask: std::enable_shared_from_this { // Exec info has a bit complicated semantics. If it's empty, it means the task // is run in a "default" mode, which means that all next_edges we encounter // should get executed. If it's not empty, only functions that have an entry - // and this entry has needed == True should be executed. exec_info is only empty - // when the graph is executed via .backward() and the inputs parameter is not passed. - // Otherwise, when executed through .grad(), or when inputs arg is specified for - // .backward(), exec_info will be non-empty. + // and this entry has needed == True should be executed. exec_info is only + // empty when the graph is executed via .backward() and the inputs parameter + // is not passed. Otherwise, when executed through .grad(), or when inputs arg + // is specified for .backward(), exec_info will be non-empty. // // exec_info_ is safe to read without synchronization std::unordered_map exec_info_; @@ -110,8 +113,8 @@ struct GraphTask: std::enable_shared_from_this { // out of the GraphTask and are no longer valid. std::vector captured_vars_; - // Note: this field is not ready to be used until the proper `thread_locals_.set_grad_mode()` - // call in the constructor. + // Note: this field is not ready to be used until the proper + // `thread_locals_.set_grad_mode()` call in the constructor. at::ThreadLocalState thread_locals_ = at::ThreadLocalState(); std::unordered_set leaf_streams; @@ -123,7 +126,11 @@ struct GraphTask: std::enable_shared_from_this { // Collects caller_current_streams_ void stash_current_streams(); - void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr); + void init_to_execute( + Node& graph_root, + const edge_list& outputs, + bool accumulate_grad, + uint64_t min_topo_nr); // The value of worker_device in the thread that created this task. // See Note [Reentrant backwards] @@ -156,11 +163,11 @@ struct GraphTask: std::enable_shared_from_this { // an exception as soon as the autograd engine receives an exception. bool exit_on_error_; - // CPU threads are dedicated to processing CPU work for the backward they invoked. - // So any given graph task maintains its own cpu_ready_queue_ where you should send - // work for it to be done. We memoize the cpu_ready_queue_ per GraphTask so that - // we know which ready queue we should push to if we are on device thread (i.e. GPU) - // and but next NodeTask should be run on CPU. + // CPU threads are dedicated to processing CPU work for the backward they + // invoked. So any given graph task maintains its own cpu_ready_queue_ where + // you should send work for it to be done. We memoize the cpu_ready_queue_ per + // GraphTask so that we know which ready queue we should push to if we are on + // device thread (i.e. GPU) and but next NodeTask should be run on CPU. std::shared_ptr cpu_ready_queue_; // Future representing the completion of the graph task. Notified when all @@ -187,9 +194,11 @@ struct GraphTask: std::enable_shared_from_this { reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), cpu_ready_queue_(std::move(cpu_ready_queue)), - future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) { + future_result_(c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()))) { thread_locals_.set_grad_mode(grad_mode); - } + } + private: // run GraphTask post processing void exec_post_processing(); @@ -235,19 +244,20 @@ struct NodeTask { // Guard that sets and restores checkpoint_valid class CheckpointValidGuard { public: - explicit CheckpointValidGuard(const std::shared_ptr& graph_task); + explicit CheckpointValidGuard( + const std::shared_ptr& graph_task); ~CheckpointValidGuard(); + private: bool prev_checkpoint_valid_state; }; - struct ReadyQueue { private: // Returns true when t2 should be (weakly) BEFORE t1 in the queue. // Shutdown tasks are first and then empty NodeTask are next. struct CompareNodeTaskTime { - bool operator()(NodeTask const & t1, NodeTask const & t2) { + bool operator()(NodeTask const& t1, NodeTask const& t2) { // NOLINTNEXTLINE(bugprone-branch-clone) if (t2.isShutdownTask_) { return true; @@ -268,7 +278,8 @@ struct ReadyQueue { // To protect read and writes to heap_ mutable std::mutex mutex_; - std::priority_queue, CompareNodeTaskTime> heap_; + std::priority_queue, CompareNodeTaskTime> + heap_; public: // incrementOutstandingTasks indicates whether or not we should increment @@ -282,8 +293,9 @@ struct ReadyQueue { size_t size() const; }; -// A single instance of this struct should be created through the whole process lifetime. -// The worker thread creation logic and Engine's destructor rely on this. +// A single instance of this struct should be created through the whole process +// lifetime. The worker thread creation logic and Engine's destructor rely on +// this. struct TORCH_API Engine { /// Returns a reference to a static `Engine` instance. static Engine& get_default_engine(); @@ -356,10 +368,11 @@ struct TORCH_API Engine { Engine(); void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr); - // initialize the thread local ready queue with the ready queue that is created - // elsewhere (i.e. thread_init, Engine::execute, etc), or create a new + // initialize the thread local ready queue with the ready queue that is + // created elsewhere (i.e. thread_init, Engine::execute, etc), or create a new // ready queue if ready_queue is not provided. - void init_local_ready_queue(std::shared_ptr ready_queue = nullptr); + void init_local_ready_queue( + std::shared_ptr ready_queue = nullptr); std::shared_ptr ready_queue( std::shared_ptr cpu_ready_queue, @@ -379,7 +392,8 @@ struct TORCH_API Engine { // Ensures device_ready_queues_ are initialized only once // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::once_flag start_device_threads_flag_; - // Safe to read device_ready_queues_ without synchronization after initialization + // Safe to read device_ready_queues_ without synchronization after + // initialization // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::vector> device_ready_queues_; @@ -409,16 +423,16 @@ struct TORCH_API Engine { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ThreadPoolShared() : num_workers_(0) {} - }; + }; - // Temporary workaround until shutting down threads is done - // We need shared ownership of all these objects because the threads are leaked - // when Engine shuts down, so there may be threads waiting on work_ - // for the graphtasks_queue_ to be nonempty. - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::shared_ptr thread_pool_shared_; + // Temporary workaround until shutting down threads is done + // We need shared ownership of all these objects because the threads are + // leaked when Engine shuts down, so there may be threads waiting on work_ for + // the graphtasks_queue_ to be nonempty. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::shared_ptr thread_pool_shared_; -private: + private: // Number of non-reentrant threads std::atomic non_reentrant_device_thread_count_; // Destructor will wait for non-reentrant threads to finish @@ -435,4 +449,5 @@ struct TORCH_API Engine { using EngineStub = Engine& (*)(); TORCH_API void set_default_engine_stub(EngineStub stub); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/forward_grad.cpp b/torch/csrc/autograd/forward_grad.cpp index d881360e565298..f9e6945f21338f 100644 --- a/torch/csrc/autograd/forward_grad.cpp +++ b/torch/csrc/autograd/forward_grad.cpp @@ -1,72 +1,80 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { namespace { - // See discussion in forward_grad.h for why these are global variables and not - // thread local +// See discussion in forward_grad.h for why these are global variables and not +// thread local - std::mutex all_forward_levels_mutex_; - std::vector> all_forward_levels_; +std::mutex all_forward_levels_mutex_; +std::vector> all_forward_levels_; - const static at::Tensor singleton_undefined_tensor; -} +const static at::Tensor singleton_undefined_tensor; +} // namespace uint64_t ForwardADLevel::get_next_idx() { - std::lock_guard lock(all_forward_levels_mutex_); - auto next_idx = all_forward_levels_.size(); - TORCH_CHECK(next_idx == 0, "Nested forward mode AD is not supported at the moment"); - all_forward_levels_.push_back(std::make_shared(next_idx)); - return next_idx; + std::lock_guard lock(all_forward_levels_mutex_); + auto next_idx = all_forward_levels_.size(); + TORCH_CHECK( + next_idx == 0, "Nested forward mode AD is not supported at the moment"); + all_forward_levels_.push_back(std::make_shared(next_idx)); + return next_idx; } void ForwardADLevel::release_idx(uint64_t idx) { - std::unique_lock lock(all_forward_levels_mutex_); - TORCH_CHECK(idx + 1 == all_forward_levels_.size(), "Exiting a forward AD level that is not the " - "last that was created is not support. Ensure they are released in the reverse " - "order they were created."); - TORCH_INTERNAL_ASSERT(all_forward_levels_.size() > 0); - // Keep the level alive until we have released the lock - auto lvl = all_forward_levels_.back(); - all_forward_levels_.pop_back(); - lock.unlock(); + std::unique_lock lock(all_forward_levels_mutex_); + TORCH_CHECK( + idx + 1 == all_forward_levels_.size(), + "Exiting a forward AD level that is not the " + "last that was created is not support. Ensure they are released in the reverse " + "order they were created."); + TORCH_INTERNAL_ASSERT(all_forward_levels_.size() > 0); + // Keep the level alive until we have released the lock + auto lvl = all_forward_levels_.back(); + all_forward_levels_.pop_back(); + lock.unlock(); } std::shared_ptr ForwardADLevel::get_by_idx(uint64_t idx) { - std::lock_guard lock(all_forward_levels_mutex_); - TORCH_CHECK(idx < all_forward_levels_.size(), "Trying to access a forward AD level with an invalid index. " - "This index was either not created or is already deleted."); - return all_forward_levels_[idx]; + std::lock_guard lock(all_forward_levels_mutex_); + TORCH_CHECK( + idx < all_forward_levels_.size(), + "Trying to access a forward AD level with an invalid index. " + "This index was either not created or is already deleted."); + return all_forward_levels_[idx]; } std::shared_ptr ForwardADLevel::try_get_by_idx(uint64_t idx) { - std::lock_guard lock(all_forward_levels_mutex_); - if (idx < all_forward_levels_.size()) { - return all_forward_levels_[idx]; - } else { - return nullptr; - } + std::lock_guard lock(all_forward_levels_mutex_); + if (idx < all_forward_levels_.size()) { + return all_forward_levels_[idx]; + } else { + return nullptr; + } } ForwardADLevel::~ForwardADLevel() { - std::lock_guard lock(mutex_); - auto it = grads_.begin(); - while (it != grads_.end()) { - // Warning this will lock *it mutex - // This is ok as this function is the *only* one to call back into another class's method. - (*it)->reset(idx_, /* update_level */ false); - it = grads_.erase(it); - } + std::lock_guard lock(mutex_); + auto it = grads_.begin(); + while (it != grads_.end()) { + // Warning this will lock *it mutex + // This is ok as this function is the *only* one to call back into another + // class's method. + (*it)->reset(idx_, /* update_level */ false); + it = grads_.erase(it); + } } const at::Tensor& ForwardGrad::value(uint64_t level) const { - std::lock_guard lock(mutex_); - const auto& it = content_.find(level); - return it == content_.end() ? singleton_undefined_tensor : (*it).second; + std::lock_guard lock(mutex_); + const auto& it = content_.find(level); + return it == content_.end() ? singleton_undefined_tensor : (*it).second; } const at::Tensor& ForwardGrad::undef_grad() { - return singleton_undefined_tensor; + return singleton_undefined_tensor; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/forward_grad.h b/torch/csrc/autograd/forward_grad.h index b859ba72ea5ce4..ff682bfc5ba7fb 100644 --- a/torch/csrc/autograd/forward_grad.h +++ b/torch/csrc/autograd/forward_grad.h @@ -2,85 +2,101 @@ #include - -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // [ Using ForwardGrad ] -// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner design. But -// this shared_ptr must be uniquely associated with the object that stores it (as of -// writing, either AutogradMeta or SavedVariable). This object is called the "owning object" -// in the discussions below. This owning object must call `ForwardGrad::clear()` when it -// is destroyed to ensure that the ForwardGrad is properly de-allocated. +// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner +// design. But this shared_ptr must be uniquely associated with the object that +// stores it (as of writing, either AutogradMeta or SavedVariable). This object +// is called the "owning object" in the discussions below. This owning object +// must call `ForwardGrad::clear()` when it is destroyed to ensure that the +// ForwardGrad is properly de-allocated. struct ForwardGrad; -// This file contains two classes that are used to store forward AD gradients and -// ensure that they are scoped properly. -// Because forward AD runs concurrently with the evaluation of the function, we need -// a mechanism to separate different forward AD invocations and be able to compute the -// right gradients. We model such invocations as levels here. -// The particular scoping issue mentioned above has two main drivers: -// - Ensure that we can conveniently use forward AD within a high level API without +// This file contains two classes that are used to store forward AD gradients +// and ensure that they are scoped properly. Because forward AD runs +// concurrently with the evaluation of the function, we need a mechanism to +// separate different forward AD invocations and be able to compute the right +// gradients. We model such invocations as levels here. The particular scoping +// issue mentioned above has two main drivers: +// - Ensure that we can conveniently use forward AD within a high level API +// without // leaking the forward AD states outside. -// - Ensure that we can keep the level that we expose to the user API simple (an integer -// that represents the nesting depth) while avoiding confusions when the level index -// is re-used. +// - Ensure that we can keep the level that we expose to the user API simple +// (an integer +// that represents the nesting depth) while avoiding confusions when the +// level index is re-used. // The important external APIs from this file are: -// - ForwardADLevel::get_next_idx() that can be used to enter a new level and get its index +// - ForwardADLevel::get_next_idx() that can be used to enter a new level and +// get its index // - ForwardADLevel::release_idx() that can be used to exit a given level. -// - ForwardGrad() can be used to store a given forward gradient that will handle the level +// - ForwardGrad() can be used to store a given forward gradient that will +// handle the level // tracking automatically. // The basic implementation strategy is as follows: // Every tensor has a ForwardGrad, maintaining a map from levels to tangents. -// ForwardGrad is responsible for registering itself to the appropriate ForwardADLevel when a new -// tangent is added to it via ForwardGrad::set_value and to un-register itself from this same level -// if that tangent is removed via ForwardGrad::reset. -// The ForwardADLevel is created when a new level is entered via ForwardADLevel::get_next_idx. -// A reference to the new ForwardADLevel is stored into a global (for the whole process) vector that -// ensure it can be accessed via ForwardADLevel::get_by_idx. This reference is deleted when the index is -// released by the user when calling ForwardADLevel::release_idx. -// When it is destructed, the ForwardADLevel is responsible for clearing all the tangents for its -// level stored in all the ForwardGrad that registered with it. +// ForwardGrad is responsible for registering itself to the appropriate +// ForwardADLevel when a new tangent is added to it via ForwardGrad::set_value +// and to un-register itself from this same level if that tangent is removed via +// ForwardGrad::reset. The ForwardADLevel is created when a new level is entered +// via ForwardADLevel::get_next_idx. A reference to the new ForwardADLevel is +// stored into a global (for the whole process) vector that ensure it can be +// accessed via ForwardADLevel::get_by_idx. This reference is deleted when the +// index is released by the user when calling ForwardADLevel::release_idx. When +// it is destructed, the ForwardADLevel is responsible for clearing all the +// tangents for its level stored in all the ForwardGrad that registered with it. // -// This process-wide level design, compared to a thread local one, allows us to use very simple user facing -// handle for the level (an int) while enabling cross-thread forward AD. -// The only required synchronization for the user is when entering and exiting the levels. -// Some discussion on alternative design is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 -// and can be refined in the future. +// This process-wide level design, compared to a thread local one, allows us to +// use very simple user facing handle for the level (an int) while enabling +// cross-thread forward AD. The only required synchronization for the user is +// when entering and exiting the levels. Some discussion on alternative design +// is in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 and +// can be refined in the future. // Correctness of concurrency: -// Each class uses its own lock when reading or modifying internal storages. This allows in particular -// to safely remove tangents from ForwardGrad when the ForwardADLevel is being exited. -// We ensure no deadlock by ensuring that a methods never calls into another class's method while -// the local class's lock is held except in one single case: calling from ForwardADLevel's destructor +// Each class uses its own lock when reading or modifying internal storages. +// This allows in particular to safely remove tangents from ForwardGrad when the +// ForwardADLevel is being exited. We ensure no deadlock by ensuring that a +// methods never calls into another class's method while the local class's lock +// is held except in one single case: calling from ForwardADLevel's destructor // into ForwardGrad::reset with update_level=false. // The lifetime of these objects is as follows: // The ForwardADLevel can be in three states: -// - Initialized: where one of its reference is held by the global vector and there may be more +// - Initialized: where one of its reference is held by the global vector +// and there may be more // references held by temporary variables in ForwardGrad's methods. -// - About to be destructed: where "release_idx" has been called and the only reason for the -// ForwardADLevel not to be destructed right away is that some methods in ForwardGrad have -// owning reference to it. This is done so that a ForwardADLevel can never be destructed when -// a ForwardGrad is registered with it and in the process of adding something to its internal state. -// - Being destructed: Here the ForwardADLevel is not referenced anymore and can be safely reset -// all of the ForwardGrad. Note that we can have more than one reset being called here (which is ok) -// but we are guaranteed that there is at least one. -// The ForwardGrad is simpler as there is no intermediary state and no special destructor for. The logic to -// unregister it from the different ForwardADLevel is done when the owning object (AutogradMeta or -// SavedVariable) is being destroyed. +// - About to be destructed: where "release_idx" has been called and the +// only reason for the +// ForwardADLevel not to be destructed right away is that some methods in +// ForwardGrad have owning reference to it. This is done so that a +// ForwardADLevel can never be destructed when a ForwardGrad is +// registered with it and in the process of adding something to its +// internal state. +// - Being destructed: Here the ForwardADLevel is not referenced anymore +// and can be safely reset +// all of the ForwardGrad. Note that we can have more than one reset +// being called here (which is ok) but we are guaranteed that there is at +// least one. +// The ForwardGrad is simpler as there is no intermediary state and no special +// destructor for. The logic to unregister it from the different ForwardADLevel +// is done when the owning object (AutogradMeta or SavedVariable) is being +// destroyed. // Other considered design: -// To avoid having the ForwardGrad::clear, we considered storing weak_ptr inside the ForwardADLevel. While this -// would work, it would mean that the set inside the ForwardADLevel would only grow unless we do an -// expensive linear scan to remove all the dangling weak pointers. Hence this approach was not used. +// To avoid having the ForwardGrad::clear, we considered storing weak_ptr inside +// the ForwardADLevel. While this would work, it would mean that the set inside +// the ForwardADLevel would only grow unless we do an expensive linear scan to +// remove all the dangling weak pointers. Hence this approach was not used. // Data structures in this file are optimized for this maximum number of levels. // The number of levels corresponds to the degree of the gradient being -// computed using forward AD and we don't expect more than second order gradients -// to be common. +// computed using forward AD and we don't expect more than second order +// gradients to be common. #define EXPECTED_MAX_LEVEL 2 struct TORCH_API ForwardADLevel { @@ -102,11 +118,10 @@ struct TORCH_API ForwardADLevel { grads_.insert(grad); } -private: - std::unordered_set> grads_; - std::mutex mutex_; - uint64_t idx_; - + private: + std::unordered_set> grads_; + std::mutex mutex_; + uint64_t idx_; }; struct TORCH_API ForwardGrad : std::enable_shared_from_this { @@ -144,53 +159,53 @@ struct TORCH_API ForwardGrad : std::enable_shared_from_this { } void set_value(const at::Tensor& value, uint64_t level) { - // Owning reference to ensure the forward_level is not destroyed - // while we are updating our internal state - auto forward_level = ForwardADLevel::get_by_idx(level); - forward_level->insert(shared_from_this()); + // Owning reference to ensure the forward_level is not destroyed + // while we are updating our internal state + auto forward_level = ForwardADLevel::get_by_idx(level); + forward_level->insert(shared_from_this()); - std::lock_guard lock(mutex_); - content_.insert({level, value}); + std::lock_guard lock(mutex_); + content_.insert({level, value}); } // This function removes the tangent for a given level from this ForwardGrad // Use the update_level flag to disable notifying the level about this reset // This flag is most notably used by the ForwardADLevel destructor. - void reset(uint64_t level, bool update_level=true) { - if (update_level) { - ForwardADLevel::get_by_idx(level)->erase(shared_from_this()); - } + void reset(uint64_t level, bool update_level = true) { + if (update_level) { + ForwardADLevel::get_by_idx(level)->erase(shared_from_this()); + } - std::unique_lock lock(mutex_); - const auto& it = content_.find(level); - TORCH_INTERNAL_ASSERT(it != content_.end(), "Resetting a non-existent level."); - // Keep the Tensor alive until we have released the lock - // This is needed as we can be in a case where this function is called by - // ForwardADLevel destructor - auto t = (*it).second; - content_.erase(level); - lock.unlock(); + std::unique_lock lock(mutex_); + const auto& it = content_.find(level); + TORCH_INTERNAL_ASSERT( + it != content_.end(), "Resetting a non-existent level."); + // Keep the Tensor alive until we have released the lock + // This is needed as we can be in a case where this function is called by + // ForwardADLevel destructor + auto t = (*it).second; + content_.erase(level); + lock.unlock(); } const at::Tensor& value(uint64_t level) const; bool contains(uint64_t level) { - std::lock_guard lock(mutex_); - return content_.count(level) > 0; + std::lock_guard lock(mutex_); + return content_.count(level) > 0; } bool empty() const { - return content_.empty(); + return content_.empty(); } static const at::Tensor& undef_grad(); - -private: - // TODO(albanD): replace this with a SmallVector - std::unordered_map content_; - mutable std::mutex mutex_; - + private: + // TODO(albanD): replace this with a SmallVector + std::unordered_map content_; + mutable std::mutex mutex_; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp index 83509c9dae1060..5ab3447ca9ef37 100644 --- a/torch/csrc/autograd/function.cpp +++ b/torch/csrc/autograd/function.cpp @@ -14,7 +14,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // The current evaluating node. This is useful to assign the current node as a // parent of new nodes created during the evaluation of this node in anomaly @@ -61,25 +62,26 @@ static void gatherFunctions( } /* - * Fix for #5534: prevent stack overflow on deletion of deep computation graph - * - * Sometimes one can end up with a very big computation graph of Nodes - * and Edges. Each std::shared_ptr contains a list of Edge, and - * each Edge contains a std::shared_ptr. Deleting a - * std::shared_ptr can trigger the recursive deletion of other - * std::shared_ptr's: this can stack overflow if the graph - * is deep enough. Here is an example of such a graph: - * - * shared_ptr -> Edge -> shared_ptr -> Edge -> ... -> shared_ptr - * - * The solution here is to detect when we are decrementing away the last - * reference to a Node, and when doing so to buffer up the Node's - * that will be recursively decremented. We can then decrement (and free) - * the original Node without causing a recursive cascade, before - * draining the buffer applying the same behavior. This is, in effect, - * converting recursion to a loop, using a heap buffer in place of the - * recursive call stack. - */ + * Fix for #5534: prevent stack overflow on deletion of deep computation graph + * + * Sometimes one can end up with a very big computation graph of Nodes + * and Edges. Each std::shared_ptr contains a list of Edge, and + * each Edge contains a std::shared_ptr. Deleting a + * std::shared_ptr can trigger the recursive deletion of other + * std::shared_ptr's: this can stack overflow if the graph + * is deep enough. Here is an example of such a graph: + * + * shared_ptr -> Edge -> shared_ptr -> Edge -> ... -> + * shared_ptr + * + * The solution here is to detect when we are decrementing away the last + * reference to a Node, and when doing so to buffer up the Node's + * that will be recursively decremented. We can then decrement (and free) + * the original Node without causing a recursive cascade, before + * draining the buffer applying the same behavior. This is, in effect, + * converting recursion to a loop, using a heap buffer in place of the + * recursive call stack. + */ void deleteNode(Node* function) { // To avoid stack overflow on large computational graphs, // we need to track reference decrementing and freeing @@ -97,4 +99,5 @@ void deleteNode(Node* function) { } } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 9c18eced91e0cd..7c2de3be48c9e4 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -1,17 +1,17 @@ #pragma once +#include #include #include -#include -#include #include +#include #include #include #include +#include #include #include -#include #include #include @@ -28,7 +28,8 @@ C10_CLANG_DIAGNOSTIC_PUSH() C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") #endif -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct Edge; struct FunctionPostHook; @@ -100,19 +101,16 @@ class NodeGuard { // sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` // are created in one thread and `C` is created in a new thread, there are *no // guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. -// See NOTE [ Sequence Number] for more details on the usages of sequence number. +// See NOTE [ Sequence Number] for more details on the usages of sequence +// number. //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ struct TORCH_API Node : std::enable_shared_from_this { public: /// Construct a new `Node` with the given `next_edges` // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit Node( - uint64_t sequence_nr, - edge_list&& next_edges = edge_list()) - : sequence_nr_(sequence_nr), - next_edges_(std::move(next_edges)) { - - for (const Edge& edge: next_edges_) { + explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) + : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { + for (const Edge& edge : next_edges_) { update_topological_nr(edge); } @@ -133,8 +131,9 @@ struct TORCH_API Node : std::enable_shared_from_this { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit Node(edge_list&& next_edges = edge_list()) - : Node(/*sequence_nr=*/at::sequence_number::get_and_increment(), - std::move(next_edges)) {} + : Node( + /*sequence_nr=*/at::sequence_number::get_and_increment(), + std::move(next_edges)) {} /// Nodes are neither copyable nor moveable. Node(const Node& other) = delete; @@ -151,7 +150,8 @@ struct TORCH_API Node : std::enable_shared_from_this { // probably operate with names. at::NoNamesGuard no_names_guard; - auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); if (C10_UNLIKELY(step_callbacks.has_value())) { at::RecordFunction guard(std::move(*step_callbacks)); // Using sequence number and thread id to correlate with @@ -160,9 +160,10 @@ struct TORCH_API Node : std::enable_shared_from_this { if (guard.needsInputs()) { std::vector inputs_vec(inputs.begin(), inputs.end()); guard.before( - name(), - c10::ArrayRef(inputs_vec.data(), inputs_vec.size()), - sequence_nr()); + name(), + c10::ArrayRef( + inputs_vec.data(), inputs_vec.size()), + sequence_nr()); } else { guard.before(name(), sequence_nr()); } @@ -184,9 +185,9 @@ struct TORCH_API Node : std::enable_shared_from_this { /// Adds the type and shape metadata for a new input. Returns the index of /// of the new input. uint32_t add_input_metadata( - const at::TensorOptions& options, - at::IntArrayRef shape, - bool is_tensor_subclass) noexcept { + const at::TensorOptions& options, + at::IntArrayRef shape, + bool is_tensor_subclass) noexcept { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t input_nr = input_metadata_.size(); input_metadata_.emplace_back(options, shape, is_tensor_subclass); @@ -227,7 +228,8 @@ struct TORCH_API Node : std::enable_shared_from_this { */ c10::optional stream(const c10::DeviceType device_type) { for (const auto& metadata : input_metadata_) { - if (metadata.device().type() == device_type) return metadata.stream(); + if (metadata.device().type() == device_type) + return metadata.stream(); } return c10::nullopt; @@ -240,10 +242,11 @@ struct TORCH_API Node : std::enable_shared_from_this { // Outputs ("Next Edges") void update_topological_nr(const Edge& edge) { - TORCH_INTERNAL_ASSERT(!has_parent_, - "Cannot update a node's topological_nr after it already has a parent." - " If we allow this, we can no longer guarantee that a parent's" - " topo_nr is always greater than those of all its children") + TORCH_INTERNAL_ASSERT( + !has_parent_, + "Cannot update a node's topological_nr after it already has a parent." + " If we allow this, we can no longer guarantee that a parent's" + " topo_nr is always greater than those of all its children") Node* node = edge.function.get(); if (node) { auto topo_nr = node->topological_nr(); @@ -265,7 +268,7 @@ struct TORCH_API Node : std::enable_shared_from_this { void set_next_edges(edge_list&& next_edges) { next_edges_ = std::move(next_edges); - for(const auto& next_edge : next_edges_) { + for (const auto& next_edge : next_edges_) { update_topological_nr(next_edge); } } @@ -294,32 +297,35 @@ struct TORCH_API Node : std::enable_shared_from_this { /// The sequence_nr has two main usages in autograd: /// /// 1) Helps determine the node's execution priority in the engine. - /// All else being equal, nodes with higher priority numbers are executed first. - /// Thus, nodes corresponding to ops executed later are the first to be executed in - /// the backward pass. One caveat is that we prioritize AccumulateGrad nodes by - /// explicitly setting its sequence_nr to be UINT64_MAX. - /// 2) The sequence number of this `Node` is paired with with thread_id it was created in + /// All else being equal, nodes with higher priority numbers are executed + /// first. Thus, nodes corresponding to ops executed later are the first to + /// be executed in the backward pass. One caveat is that we prioritize + /// AccumulateGrad nodes by explicitly setting its sequence_nr to be + /// UINT64_MAX. + /// 2) The sequence number of this `Node` is paired with with thread_id it was + /// created in /// as a unique identifier by the profiler to annotate recorded events. - /// The purpose of this is to help users (and possibly programs) interpreting the profiler's - /// output to correlate backward nodes with its forward ops. - /// We need both sequence_nr and thread_id to identify a node because sequence_nr is - /// thread_local, i.e., starts counting up from zero in a new thread + /// The purpose of this is to help users (and possibly programs) + /// interpreting the profiler's output to correlate backward nodes with its + /// forward ops. We need both sequence_nr and thread_id to identify a node + /// because sequence_nr is thread_local, i.e., starts counting up from zero + /// in a new thread uint64_t sequence_nr() const noexcept { return sequence_nr_; } // NOTE [ Topological Number ] // - // topological_nr is used to prune branches in the DAG during autograd discovery as - // maintaining topological_nr helps us check in O(1) if there does NOT exist - // a directed path between two nodes. + // topological_nr is used to prune branches in the DAG during autograd + // discovery as maintaining topological_nr helps us check in O(1) if there + // does NOT exist a directed path between two nodes. // // The topological order number of this `Node` representing the length of the - // longest possible path from this Node to any leaf node. If you are leaf node, - // aka AccumulateGrad, this will be zero. This value has the property that - // For every pair of nodes X, Y in G, existence of a directed path from X to Y - // implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so we - // cannot prove existence of a path from X to Y, only non-existence. + // longest possible path from this Node to any leaf node. If you are leaf + // node, aka AccumulateGrad, this will be zero. This value has the property + // that For every pair of nodes X, Y in G, existence of a directed path from X + // to Y implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so + // we cannot prove existence of a path from X to Y, only non-existence. // // One assumption we make when using topo_nr is that once a node // has been used, i.e., has a parent node, its own topo_nr does not change @@ -327,12 +333,15 @@ struct TORCH_API Node : std::enable_shared_from_this { // // What NOT to do: // - // 1) 2 -> 1 -> 0 In this diagram we label nodes with their topo_nr. - // 2 -> 1 -> 0 We have two simple graphs that can each arise from + // 1) 2 -> 1 -> 0 In this diagram we label nodes with their + // topo_nr. + // 2 -> 1 -> 0 We have two simple graphs that can each + // arise from // `t.exp().exp()`, for example. // 2) 2 -> 1 -> 0 // / - // 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 already + // 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 + // already // has a parent. // 3) 2 -> 1 -> 0 // / @@ -397,8 +406,8 @@ struct TORCH_API Node : std::enable_shared_from_this { return reinterpret_cast(post_hooks_.back().get()); } - const std::vector>& post_hooks() const - noexcept { + const std::vector>& post_hooks() + const noexcept { return post_hooks_; } @@ -421,8 +430,8 @@ struct TORCH_API Node : std::enable_shared_from_this { pre_hooks_.push_back(std::move(pre_hook)); } - const std::vector>& pre_hooks() const - noexcept { + const std::vector>& pre_hooks() + const noexcept { return pre_hooks_; } @@ -478,8 +487,8 @@ struct TORCH_API Node : std::enable_shared_from_this { uint64_t topological_nr_ = 0; // Tracks whether this node has been added as the next_edge of another node - // via set_next_edge(s), which always calls topological_nr() of all its children - // See NOTE [ Topological Number ] for why we need this. + // via set_next_edge(s), which always calls topological_nr() of all its + // children See NOTE [ Topological Number ] for why we need this. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) mutable bool has_parent_ = false; @@ -489,40 +498,46 @@ struct TORCH_API Node : std::enable_shared_from_this { // Note [Thread Safety on Autograd Node] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Autograd Engine let the owning thread which calls Engine::execute to drive the - // GraphTask execution, there might be cases that part of the GraphTask is shared - // across different `backward()` or `grad()` calls, i.e. fork new threads in the - // middle of the forward and call `backward()` separately from different threads. - // We need to protect the thread safety on NodeTask to prevent data racing on - // shared variables read/write. + // Autograd Engine let the owning thread which calls Engine::execute to drive + // the GraphTask execution, there might be cases that part of the GraphTask is + // shared across different `backward()` or `grad()` calls, i.e. fork new + // threads in the middle of the forward and call `backward()` separately from + // different threads. We need to protect the thread safety on NodeTask to + // prevent data racing on shared variables read/write. // - // NB: This is only needed for Autograd Nodes that runs on CPU, technically "CUDA", - // "XLA" nodes don't need locking because device threads are always single threaded. + // NB: This is only needed for Autograd Nodes that runs on CPU, technically + // "CUDA", "XLA" nodes don't need locking because device threads are always + // single threaded. // - // Here we add a thread mutex to help protect the Node's thread safety, so that - // different threads cannot race the shared data when executing the same NodeTask - // from multiple CPU threads. It IS the user/developer responsibility to take - // advantage of this mutex to protect the thread safety of their autograd Node. - // The general strategy of thread safety on autograd Node: + // Here we add a thread mutex to help protect the Node's thread safety, so + // that different threads cannot race the shared data when executing the same + // NodeTask from multiple CPU threads. It IS the user/developer responsibility + // to take advantage of this mutex to protect the thread safety of their + // autograd Node. The general strategy of thread safety on autograd Node: // - // 1. User should lock the mutex during Node::release_variables() if the Node needs - // to release the variables on the fly, this serve the purpose that when we release - // saved_variables from one thread, no other threads can release the saved variables - // concurrently. call - // the Node::apply(), - // 2. User should lock the mutex during Node::apply(), this is to ensure Node that - // writing to the shared variable are not racing across threads (i.e. AccumulateGrad - // and custom C++ Autograd Node if writing to shared variables ) - // 3. item 2 and item 3 should work together so that when we release saved variables - // from one thread, no other threads can call Node::apply(), this ensures the variable - // references from other threads aren't dangling. - // 4. if the Node don't release any variables and no shared data read/write in the Node + // 1. User should lock the mutex during Node::release_variables() if the Node + // needs + // to release the variables on the fly, this serve the purpose that when we + // release saved_variables from one thread, no other threads can release + // the saved variables concurrently. call the Node::apply(), + // 2. User should lock the mutex during Node::apply(), this is to ensure Node + // that + // writing to the shared variable are not racing across threads (i.e. + // AccumulateGrad and custom C++ Autograd Node if writing to shared + // variables ) + // 3. item 2 and item 3 should work together so that when we release saved + // variables + // from one thread, no other threads can call Node::apply(), this ensures + // the variable references from other threads aren't dangling. + // 4. if the Node don't release any variables and no shared data read/write in + // the Node // i.e. purely functional, user don't need to lock the mutex // - // This way we could protect the thread safety on Autograd Node, but we could still - // not protect the thread safety on Node pre/post C++ hooks (python hooks are - // automatically thread safe), we rely on the user to write thread safe C++ hooks - // if they want the hook to be correctly applied in multithreading environment. + // This way we could protect the thread safety on Autograd Node, but we could + // still not protect the thread safety on Node pre/post C++ hooks (python + // hooks are automatically thread safe), we rely on the user to write thread + // safe C++ hooks if they want the hook to be correctly applied in + // multithreading environment. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::mutex mutex_; @@ -619,6 +634,7 @@ edge_list collect_next_edges(Variables&&... variables) { make.apply(std::forward(variables)...); return std::move(make.next_edges); } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch C10_CLANG_DIAGNOSTIC_POP() diff --git a/torch/csrc/autograd/function_hook.cpp b/torch/csrc/autograd/function_hook.cpp index edaf19b06aa8fc..b1746ef140c91c 100644 --- a/torch/csrc/autograd/function_hook.cpp +++ b/torch/csrc/autograd/function_hook.cpp @@ -1,8 +1,10 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { FunctionPreHook::~FunctionPreHook() = default; FunctionPostHook::~FunctionPostHook() = default; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index c09889af604840..6e3297cec2d26c 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -1,12 +1,13 @@ #pragma once -#include -#include #include +#include +#include // A hook that's called on gradients -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using Variable = at::Tensor; using variable_list = std::vector; @@ -19,8 +20,9 @@ struct TORCH_API FunctionPreHook { struct TORCH_API FunctionPostHook { virtual ~FunctionPostHook(); virtual variable_list operator()( - const variable_list& outputs /* grad_inputs */, - const variable_list& inputs /* grad_outputs */) = 0; + const variable_list& outputs /* grad_inputs */, + const variable_list& inputs /* grad_outputs */) = 0; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/accumulate_grad.cpp b/torch/csrc/autograd/functions/accumulate_grad.cpp index 6fe1b01dfaf3d4..53a8f1364cbdb5 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.cpp +++ b/torch/csrc/autograd/functions/accumulate_grad.cpp @@ -1,10 +1,10 @@ #include -#include -#include #include #include #include +#include +#include #include #include @@ -12,13 +12,13 @@ using at::Tensor; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // AccumulateGrad sets sequence_nr to the max value so it's always called // ASAP during backwards. AccumulateGrad::AccumulateGrad(Variable variable_) - : Node(/*sequence_nr=*/UINT64_MAX), - variable(std::move(variable_)) { + : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) { add_input_metadata(variable); } @@ -60,4 +60,5 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { return variable_list(); } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 0d6f8c595cf522..c18886a707e738 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -1,11 +1,11 @@ #pragma once -#include -#include -#include -#include #include #include +#include +#include +#include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -15,18 +15,26 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -#define CHECK_RESULT(RESULT, VAR) \ - if (!(RESULT.is_sparse() || VAR.is_sparse() || RESULT.is_sparse_csr() || VAR.is_sparse_csr())) { \ - if (!utils::obeys_layout_contract(RESULT, VAR)) { \ - TORCH_WARN_ONCE("grad and param do not obey the gradient layout contract. " \ - "This is not an error, but may impair performance.\n" \ - "grad.sizes() = ", RESULT.sizes(), \ - ", strides() = ", RESULT.strides(), "\n", \ - "param.sizes() = ", VAR.sizes(), \ - ", strides() = ", VAR.strides()); \ - } \ +#define CHECK_RESULT(RESULT, VAR) \ + if (!(RESULT.is_sparse() || VAR.is_sparse() || RESULT.is_sparse_csr() || \ + VAR.is_sparse_csr())) { \ + if (!utils::obeys_layout_contract(RESULT, VAR)) { \ + TORCH_WARN_ONCE( \ + "grad and param do not obey the gradient layout contract. " \ + "This is not an error, but may impair performance.\n" \ + "grad.sizes() = ", \ + RESULT.sizes(), \ + ", strides() = ", \ + RESULT.strides(), \ + "\n", \ + "param.sizes() = ", \ + VAR.sizes(), \ + ", strides() = ", \ + VAR.strides()); \ + } \ } struct TORCH_API AccumulateGrad : public Node { @@ -34,9 +42,7 @@ struct TORCH_API AccumulateGrad : public Node { variable_list apply(variable_list&& grads) override; - static at::Tensor callHooks( - const Variable& variable, - at::Tensor new_grad) { + static at::Tensor callHooks(const Variable& variable, at::Tensor new_grad) { for (auto& hook : impl::hooks(variable)) { new_grad = (*hook)({new_grad})[0]; } @@ -64,20 +70,20 @@ struct TORCH_API AccumulateGrad : public Node { // // If variable's grad already exists (variable_grad.defined()), new_grad must // be added to variable_grad. If we aren't setting up for double backward - // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad += new_grad" - // in-place, which keeps variable_grad's layout. We assume (hope) variable_grad - // was created obeying (1) or (2) at some point in the past. + // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad += + // new_grad" in-place, which keeps variable_grad's layout. We assume (hope) + // variable_grad was created obeying (1) or (2) at some point in the past. // // If we are setting up for double backward, AccumulateGrad updates the grad - // out-of-place via "variable_grad + new_grad." TensorIterator operator+ decides - // result's layout. Typically TensorIterator matches strides of the first arg, - // so we once again assume (hope) variable_grad was originally created obeying - // (1) or (2). + // out-of-place via "variable_grad + new_grad." TensorIterator operator+ + // decides result's layout. Typically TensorIterator matches strides of the + // first arg, so we once again assume (hope) variable_grad was originally + // created obeying (1) or (2). // - // AccumulateGrad does not enforce the contract with 100% certainty. Examples: + // AccumulateGrad does not enforce the contract with 100% certainty. Examples: // - If a user manually permutes a param or its grad, then runs a fwd+bwd, - // variable_grad += new_grad keeps variable_grad's layout without rechecking - // the contract. + // variable_grad += new_grad keeps variable_grad's layout without + // rechecking the contract. // - If TensorIterator changes its corner cases about operator+'s result // (for example, giving more or less priority to channels_last inputs, see // https://github.com/pytorch/pytorch/pull/37968) the result may not obey. @@ -104,17 +110,19 @@ struct TORCH_API AccumulateGrad : public Node { size_t num_expected_refs, const T& update_grad) { if (!variable_grad.defined()) { - if (!GradMode::is_enabled() && - !new_grad.is_sparse() && !new_grad.is_sparse_csr() && + if (!GradMode::is_enabled() && !new_grad.is_sparse() && + !new_grad.is_sparse_csr() && !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) && new_grad.use_count() <= num_expected_refs && - (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) { + (new_grad.is_mkldnn() || + utils::obeys_layout_contract(new_grad, variable))) { // we aren't setting up for double-backward // not sparse // no other user-visible tensor references new_grad - // new_grad obeys the "Gradient Layout Contract", there has a special case, - // For MKLDNN tensor, which is a opaque tensor, assuming it obeys layout_contract. - // Under these conditions, we can steal new_grad without a deep copy. + // new_grad obeys the "Gradient Layout Contract", there has a special + // case, For MKLDNN tensor, which is a opaque tensor, assuming it obeys + // layout_contract. Under these conditions, we can steal new_grad + // without a deep copy. update_grad(new_grad.detach()); } else if ( !GradMode::is_enabled() && new_grad.is_sparse() && @@ -166,8 +174,8 @@ struct TORCH_API AccumulateGrad : public Node { } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) { // Ideally we'd perform an in-place operation to avoid changing // the grad tensor. However, if that's impossible because the grads - // are vmap-incompatible (See NOTE: [vmap-incompatible in-place operations]), - // then we just add them out-of-place. + // are vmap-incompatible (See NOTE: [vmap-incompatible in-place + // operations]), then we just add them out-of-place. auto result = variable_grad + new_grad; CHECK_RESULT(result, variable); update_grad(std::move(result)); @@ -199,7 +207,8 @@ struct TORCH_API AccumulateGrad : public Node { // } else { // result = at::empty_strided(variable.sizes(), variable.strides(), // variable.options().memory_format(c10::nullopt)); - // update_grad(at::native::add_out(result, variable_grad, new_grad, 1.0); + // update_grad(at::native::add_out(result, variable_grad, + // new_grad, 1.0); // } // However, that accumulation is sometimes in place and sometimes not, // which may break user code. @@ -207,11 +216,13 @@ struct TORCH_API AccumulateGrad : public Node { } else { at::Tensor result; if (variable_grad.is_sparse() && !new_grad.is_sparse()) { - // CPU backend throws an error on sparse + dense, so prefer dense + sparse here. + // CPU backend throws an error on sparse + dense, so prefer dense + + // sparse here. result = new_grad + variable_grad; } else { // Assumes operator+ result typically matches strides of first arg, - // and hopes variable_grad was originally created obeying layout contract. + // and hopes variable_grad was originally created obeying layout + // contract. result = variable_grad + new_grad; } CHECK_RESULT(result, variable); @@ -225,12 +236,12 @@ struct TORCH_API AccumulateGrad : public Node { // such that the stashed grad is likely to have the right strides if // either variable_grad or new_grad already has the right strides. // We could enforce the contract with certainty by saying - // auto result = variable_grad + new_grad (or vice versa), checking result's - // layout, and copying to an obedient clone if necessary before update_grad. - // The copy would require another gmem pass. We can't create empty result with - // the right layout then add_out into it with a single kernel, because GradMode - // is enabled in this branch, and add_out isn't differentiable. - // Maybe more trouble than it's worth. + // auto result = variable_grad + new_grad (or vice versa), checking + // result's layout, and copying to an obedient clone if necessary before + // update_grad. The copy would require another gmem pass. We can't create + // empty result with the right layout then add_out into it with a single + // kernel, because GradMode is enabled in this branch, and add_out isn't + // differentiable. Maybe more trouble than it's worth. } } diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp index 4798a5f76c65f5..1a8e75d21519b0 100644 --- a/torch/csrc/autograd/functions/basic_ops.cpp +++ b/torch/csrc/autograd/functions/basic_ops.cpp @@ -1,15 +1,16 @@ #include #include -#include #include +#include #include #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { auto Error::apply(variable_list&& inputs) -> variable_list { throw std::runtime_error(msg); @@ -31,14 +32,16 @@ auto UndefinedGrad::apply(variable_list&& inputs) -> variable_list { tensor_list outputs; outputs.reserve(inputs.size()); for (auto& var : inputs) { - outputs.emplace_back(var.defined() ? var.clone().tensor_data() : at::Tensor()); + outputs.emplace_back( + var.defined() ? var.clone().tensor_data() : at::Tensor()); } return wrap_outputs(inputs, std::move(outputs), [&](edge_list&& next_edges) { return std::make_shared(std::move(next_edges)); }); } -auto UndefinedGradBackward::apply(variable_list&& output_grads) -> variable_list { +auto UndefinedGradBackward::apply(variable_list&& output_grads) + -> variable_list { tensor_list input_grads; output_grads.reserve(input_grads.size()); for (auto& grad : output_grads) { @@ -52,4 +55,5 @@ auto Identity::apply(variable_list&& grads) -> variable_list { return std::move(grads); } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index 2675791c4c3dfd..bc40baccf851f6 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -9,15 +9,14 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct TORCH_API Error : public Node { Error(std::string msg, edge_list&& next_edges) - : Node(std::move(next_edges)) - , msg(std::move(msg)) {} + : Node(std::move(next_edges)), msg(std::move(msg)) {} - Error(std::string msg) - : msg(std::move(msg)) {} + Error(std::string msg) : msg(std::move(msg)) {} variable_list apply(variable_list&& inputs) override; @@ -29,20 +28,21 @@ struct TORCH_API Error : public Node { // special case with a new NotImplemented function here. struct TORCH_API NotImplemented : public Error { NotImplemented(const std::string& forward_fn, edge_list&& next_edges) - : Error("derivative for " + forward_fn + " is not implemented", + : Error( + "derivative for " + forward_fn + " is not implemented", std::move(next_edges)) {} NotImplemented(const std::string& forward_fn) - : Error("derivative for " + forward_fn + " is not implemented") {} + : Error("derivative for " + forward_fn + " is not implemented") {} }; -// Identity in forward, Error in backward. Used to implement @once_differentiable +// Identity in forward, Error in backward. Used to implement +// @once_differentiable struct TORCH_API DelayedError : public Node { - DelayedError(std::string msg, int num_inputs) - : msg(std::move(msg)) { + DelayedError(std::string msg, int num_inputs) : msg(std::move(msg)) { // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) for (const auto i : c10::irange(num_inputs)) { - (void)i; //Suppress unused variable warning + (void)i; // Suppress unused variable warning add_input_metadata(Node::undefined_input()); } } @@ -61,8 +61,7 @@ struct TORCH_API UndefinedGrad : public Node { }; struct TORCH_API UndefinedGradBackward : public Node { - UndefinedGradBackward(edge_list&& next_edges) - : Node(std::move(next_edges)) {} + UndefinedGradBackward(edge_list&& next_edges) : Node(std::move(next_edges)) {} UndefinedGradBackward() = default; @@ -72,10 +71,10 @@ struct TORCH_API UndefinedGradBackward : public Node { struct TORCH_API GraphRoot : public Node { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GraphRoot(edge_list functions, variable_list inputs) - : Node(std::move(functions)), - outputs(std::move(inputs)) { - // Ensures calls to stream() on a GraphRoot instance reflect current stream(s) - // on devices of root grad tensors at the time the instance is constructed. + : Node(std::move(functions)), outputs(std::move(inputs)) { + // Ensures calls to stream() on a GraphRoot instance reflect current + // stream(s) on devices of root grad tensors at the time the instance is + // constructed. for (const auto& t : outputs) { add_input_metadata(t); } @@ -92,4 +91,5 @@ struct TORCH_API Identity : public Node { variable_list apply(variable_list&& inputs) override; }; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp index cf35256e13ee1b..994f999c71f13b 100644 --- a/torch/csrc/autograd/functions/comm.cpp +++ b/torch/csrc/autograd/functions/comm.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include #include #include @@ -23,7 +23,8 @@ Scatter::Scatter( const c10::optional>& chunk_sizes, int64_t dim, // NOLINTNEXTLINE(modernize-pass-by-value) - const c10::optional>>& streams, + const c10::optional>>& + streams, bool unsqueeze_scalars) : devices_(std::move(devices)), chunk_sizes_(chunk_sizes), @@ -49,7 +50,11 @@ variable_list Scatter::apply(variable_list&& inputs) { }); auto tensors = torch::cuda::scatter( // NOLINTNEXTLINE(performance-move-const-arg) - std::move(input), device_indices, chunk_sizes_, dim_, streams_); + std::move(input), + device_indices, + chunk_sizes_, + dim_, + streams_); std::vector variables; variables.reserve(tensors.size()); diff --git a/torch/csrc/autograd/functions/comm.h b/torch/csrc/autograd/functions/comm.h index 1a89919d1d631b..d765c3fff7e6d5 100644 --- a/torch/csrc/autograd/functions/comm.h +++ b/torch/csrc/autograd/functions/comm.h @@ -1,12 +1,12 @@ #pragma once +#include #include #include -#include #include -#include #include +#include #include #include @@ -19,8 +19,8 @@ struct TORCH_CUDA_CU_API Scatter : public Node { std::vector devices, const c10::optional>& chunk_sizes = c10::nullopt, int64_t dim = 0, - const c10::optional>>& streams = - c10::nullopt, + const c10::optional>>& + streams = c10::nullopt, bool unsqueeze_scalars = false); ~Scatter() override; diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index a5dfbf5ce92334..7a02ce308fb08e 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -11,20 +11,23 @@ #endif #include #include -#include #include +#include using namespace torch::autograd; struct DelayedErrorCtor { DelayedError* operator()(PyObject* args) { - - TORCH_CHECK(PyTuple_GET_SIZE(args) == 2, "Requires two arguments, got ", PyTuple_GET_SIZE(args)); + TORCH_CHECK( + PyTuple_GET_SIZE(args) == 2, + "Requires two arguments, got ", + PyTuple_GET_SIZE(args)); auto arg1 = PyTuple_GET_ITEM(args, 0); TORCH_CHECK(THPUtils_checkString(arg1), "argument 'msg' must be a string"); std::string msg = THPUtils_unpackString(arg1); auto arg2 = PyTuple_GET_ITEM(args, 1); - TORCH_CHECK(THPUtils_checkLong(arg2), "argument 'num_inputs' must be an int"); + TORCH_CHECK( + THPUtils_checkLong(arg2), "argument 'num_inputs' must be an int"); int num_inputs = THPUtils_unpackLong(arg2); return new DelayedError(msg, num_inputs); } @@ -32,7 +35,10 @@ struct DelayedErrorCtor { struct UndefinedGradCtor { UndefinedGrad* operator()(PyObject* args) { - TORCH_CHECK(PyTuple_GET_SIZE(args) == 0, "Requires zero arguments, got ", PyTuple_GET_SIZE(args)); + TORCH_CHECK( + PyTuple_GET_SIZE(args) == 0, + "Requires zero arguments, got ", + PyTuple_GET_SIZE(args)); return new UndefinedGrad(); } }; @@ -43,26 +49,35 @@ struct NoCtor { } }; -template -static void addClass(PyObject* module, PyTypeObject& type, const char* name, - PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr) -{ - createForwardFunctionPyTypeObject(type, name, function_properties, function_methods); +template +static void addClass( + PyObject* module, + PyTypeObject& type, + const char* name, + PyGetSetDef* function_properties = nullptr, + PyMethodDef* function_methods = nullptr) { + createForwardFunctionPyTypeObject( + type, name, function_properties, function_methods); Py_INCREF(&type); PyModule_AddObject(module, name, (PyObject*)&type); registerCppFunction(typeid(C), &type); } -template -PyObject* getTupleAttr(PyObject* obj, void* _unused) -{ +template < + typename T, + typename ValueT, + typename ParamsT, + ValueT ParamsT::*ptr, + typename ConvertArgT, + PyObject* (*Convert)(ConvertArgT)> +PyObject* getTupleAttr(PyObject* obj, void* _unused) { HANDLE_TH_ERRORS THPCppFunction* self = (THPCppFunction*)obj; auto& arr = ((T*)(self->cdata.get()))->*ptr; auto num_elems = arr.size(); THPObjectPtr py_tuple(PyTuple_New(num_elems)); - if (!py_tuple) return nullptr; + if (!py_tuple) + return nullptr; for (const auto i : c10::irange(num_elems)) { PyTuple_SET_ITEM(py_tuple.get(), i, Convert(arr[i])); } @@ -70,10 +85,14 @@ PyObject* getTupleAttr(PyObject* obj, void* _unused) END_HANDLE_TH_ERRORS } -template -PyObject* getValueAttr(PyObject* obj, void* _unused) -{ +template < + typename T, + typename ValueT, + typename ParamsT, + ValueT ParamsT::*ptr, + typename ConvertArgT, + PyObject* (*Convert)(ConvertArgT)> +PyObject* getValueAttr(PyObject* obj, void* _unused) { HANDLE_TH_ERRORS THPCppFunction* self = (THPCppFunction*)obj; auto& val = ((T*)(self->cdata.get()))->*ptr; @@ -81,8 +100,7 @@ PyObject* getValueAttr(PyObject* obj, void* _unused) END_HANDLE_TH_ERRORS } -static PyObject* accumulateGradVar(PyObject *_self, void* _unused) -{ +static PyObject* accumulateGradVar(PyObject* _self, void* _unused) { THPCppFunction* self = (THPCppFunction*)_self; auto grad_acc = (AccumulateGrad*)self->cdata.get(); return THPVariable_Wrap(grad_acc->variable); @@ -90,33 +108,40 @@ static PyObject* accumulateGradVar(PyObject *_self, void* _unused) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static struct PyGetSetDef accumulate_grad_properties[] = { - THP_FUNCTION_DEFAULT_PROPERTIES, - {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr}, - {nullptr} -}; + THP_FUNCTION_DEFAULT_PROPERTIES, + {(char*)"variable", accumulateGradVar, nullptr, nullptr, nullptr}, + {nullptr}}; -void THPAutograd_initFunctions() -{ +void THPAutograd_initFunctions() { THPObjectPtr module(PyModule_New("torch._C._functions")); - if (!module) throw python_error(); + if (!module) + throw python_error(); static PyTypeObject AccumulateGradClass; - addClass(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties); + addClass( + module, + AccumulateGradClass, + "AccumulateGrad", + accumulate_grad_properties); static PyTypeObject ErrorClass; addClass(module, ErrorClass, "Error"); static PyTypeObject NotImplementedClass; - addClass(module, NotImplementedClass, "NotImplemented"); + addClass( + module, NotImplementedClass, "NotImplemented"); static PyTypeObject DelayedErrorClass; - addClass(module, DelayedErrorClass, "DelayedError"); + addClass( + module, DelayedErrorClass, "DelayedError"); static PyTypeObject UndefinedGradBackwardClass; - addClass(module, UndefinedGradBackwardClass, "UndefinedGradBackward"); + addClass( + module, UndefinedGradBackwardClass, "UndefinedGradBackward"); static PyTypeObject UndefinedGradClass; - addClass(module, UndefinedGradClass, "UndefinedGrad"); + addClass( + module, UndefinedGradClass, "UndefinedGrad"); static PyTypeObject CopyBackwardsClass; addClass(module, CopyBackwardsClass, "CopyBackwards"); @@ -133,7 +158,8 @@ void THPAutograd_initFunctions() generated::initialize_autogenerated_functions(); auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C")); - if (!c_module) throw python_error(); + if (!c_module) + throw python_error(); Py_INCREF(module); if (PyModule_AddObject(c_module, "_functions", module) < 0) { diff --git a/torch/csrc/autograd/functions/pybind.h b/torch/csrc/autograd/functions/pybind.h index dc47a3bafafa9b..5a97e6c4e9f97d 100644 --- a/torch/csrc/autograd/functions/pybind.h +++ b/torch/csrc/autograd/functions/pybind.h @@ -1,14 +1,14 @@ #pragma once -#include #include #include +#include -#include #include +#include namespace py = pybind11; -namespace pybind11 { namespace detail { - -}} // namespace pybind11::detail +namespace pybind11 { +namespace detail {} +} // namespace pybind11 diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index a4df5e5ff5168a..2fe5a3befbcc43 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -13,7 +13,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { auto CopyBackwards::apply(variable_list&& grads) -> variable_list { check_input_variables("CopyBackwards", grads, 1, -1, true); @@ -87,10 +88,10 @@ auto CopySlices::apply(variable_list&& inputs) -> variable_list { // TODO: We clone grad_slice because we modify it below and "fn" might save // it for the backward of res. We might be able to avoid the clone() if // double-backprop is disabled. - auto res = (*fn)({ grad_slice.clone(at::MemoryFormat::Contiguous) }); + auto res = (*fn)({grad_slice.clone(at::MemoryFormat::Contiguous)}); variable_list grad_inputs(num_outputs()); - for(const auto i : c10::irange(res.size())) { + for (const auto i : c10::irange(res.size())) { if (should_compute_output(i)) { AT_ASSERT(res[i].defined()); if (i == 0) { @@ -112,4 +113,5 @@ void CopySlices::release_variables() { fn = nullptr; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index 8fc7a8dc554f3b..cd77c8ceb72446 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include @@ -11,7 +11,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct TORCH_API CopyBackwards : public Node { variable_list apply(variable_list&& grads) override; @@ -22,8 +23,8 @@ struct TORCH_API CopyBackwards : public Node { // Note [View + Inplace update for base tensor] // Performs grad_view = fn(grad_view), but out-of-place. // view_fn_ is an optional lambda function saved in DifferentiableViewMeta -// from forward pass, so that we can recover we when as_strided is not supported. -// It preserves the invariants: +// from forward pass, so that we can recover we when as_strided is not +// supported. It preserves the invariants: // view = view_fn_(base) // grad_view = view_fn_(grad_base) // @@ -99,4 +100,5 @@ struct TORCH_API CopySlices : public Node { std::shared_ptr fn; }; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp index 1705c75498f50a..07a5b2dc9e8696 100644 --- a/torch/csrc/autograd/functions/utils.cpp +++ b/torch/csrc/autograd/functions/utils.cpp @@ -8,10 +8,13 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, - const function_constructor& ctr) { +variable_list wrap_outputs( + const variable_list& inputs, + tensor_list&& outputs, + const function_constructor& ctr) { variable_list result; result.reserve(outputs.size()); if (!any_variable_requires_grad(inputs)) { @@ -23,10 +26,12 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, } } } else { - auto grad_fn = ctr(GradMode::is_enabled() ? collect_next_edges(inputs) : edge_list()); + auto grad_fn = + ctr(GradMode::is_enabled() ? collect_next_edges(inputs) : edge_list()); for (auto& output : outputs) { if (output.defined()) { - auto variable = autograd::make_variable(output, /*requires_grad=*/false); + auto variable = + autograd::make_variable(output, /*requires_grad=*/false); autograd::create_gradient_edge(variable, grad_fn); result.push_back(std::move(variable)); } else { @@ -38,7 +43,12 @@ variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, return result; } -void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args, bool allow_undefined) { +void check_input_variables( + const char* name, + const variable_list& inputs, + int args, + int required_args, + bool allow_undefined) { if (required_args == -1) { required_args = args; } @@ -56,4 +66,5 @@ void check_input_variables(const char* name, const variable_list& inputs, int ar } } } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 0ee9b6310c6c6a..a2169f18656f07 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -1,10 +1,10 @@ #pragma once #include +#include #include #include #include -#include #include #include @@ -13,7 +13,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using function_constructor = std::function(edge_list&&)>; @@ -21,12 +22,20 @@ using function_constructor = std::function(edge_list&&)>; * Wraps the tensor outputs in variables and creates the grad_fn and sets the * grad_fn if necessary. */ -TORCH_API variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, - const function_constructor& ctr); +TORCH_API variable_list wrap_outputs( + const variable_list& inputs, + tensor_list&& outputs, + const function_constructor& ctr); -/// Checks that inputs contains exactly `args` items and that the first `required_args` +/// Checks that inputs contains exactly `args` items and that the first +/// `required_args` /// items are not nullptr. If not specified, `required_args` defaults to `args`. -TORCH_API void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1, bool allow_undefined=false); +TORCH_API void check_input_variables( + const char* name, + const variable_list& inputs, + int args, + int required_args = -1, + bool allow_undefined = false); struct ComputeRequiresGrad : IterArgs { bool out = false; @@ -60,11 +69,11 @@ inline void set_history( const std::shared_ptr& grad_fn) { AT_ASSERT(grad_fn); if (variable.defined()) { - // If the codegen triggers this, you most likely want to add your newly added function - // to the DONT_REQUIRE_DERIVATIVE list in tools/autograd/gen_variable_type.py + // If the codegen triggers this, you most likely want to add your newly + // added function to the DONT_REQUIRE_DERIVATIVE list in + // tools/autograd/gen_variable_type.py TORCH_INTERNAL_ASSERT(isDifferentiableType(variable.scalar_type())); - auto output_nr = - grad_fn->add_input_metadata(variable); + auto output_nr = grad_fn->add_input_metadata(variable); impl::set_gradient_edge(variable, {grad_fn, output_nr}); } else { grad_fn->add_input_metadata(Node::undefined_input()); @@ -91,4 +100,5 @@ inline bool isFwGradDefined(const c10::optional& t) { return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined(); } -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/grad_mode.h b/torch/csrc/autograd/grad_mode.h index 72b75524179117..359d703946060a 100644 --- a/torch/csrc/autograd/grad_mode.h +++ b/torch/csrc/autograd/grad_mode.h @@ -3,9 +3,11 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using GradMode = at::GradMode; using AutoGradMode = at::AutoGradMode; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index fb333d2629c85c..2ae3e87172c8b9 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -1,35 +1,35 @@ #include -#include +#include +#include +#include +#include #include #include +#include #include -#include #include +#include #include -#include -#include -#include -#include #include #include #include -#include -#include #include -#include -#include #include #include -#include +#include +#include +#include +#include #include -#include -#include +#include +#include +#include #include #include -PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { +PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { using namespace torch::autograd::profiler; using namespace torch::profiler::impl; auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor")); @@ -56,7 +56,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { auto _C_m = py::handle(torch_C_module).cast(); auto m = _C_m.def_submodule("_autograd", "autograd bindings"); - auto parameter_module = THPObjectPtr(PyImport_ImportModule("torch.nn.parameter")); + auto parameter_module = + THPObjectPtr(PyImport_ImportModule("torch.nn.parameter")); if (!parameter_module) return nullptr; @@ -85,10 +86,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .value("CUDA", ActivityType::CUDA); py::class_(m, "_ExperimentalConfig") - .def(py::init< - std::vector /* profiler_metrics */, - bool /* profiler_measure_per_kernel */ - >(), + .def( + py::init< + std::vector /* profiler_metrics */, + bool /* profiler_measure_per_kernel */ + >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" " profiler_metrics : a list of CUPTI profiler metrics used\n" @@ -98,20 +100,19 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { " or for the entire measurement duration.", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false) - .def(py::pickle( - [](const ExperimentalConfig &p) { // __getstate__ + .def(py::pickle( + [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; for (const auto& metric : p.profiler_metrics) { py::bytes mbytes(metric); py_metrics.append(mbytes); } /* Return a tuple that fully encodes the state of the config */ - return py::make_tuple( - py_metrics, p.profiler_measure_per_kernel); - }, - [](py::tuple t) { // __setstate__ + return py::make_tuple(py_metrics, p.profiler_measure_per_kernel); + }, + [](py::tuple t) { // __setstate__ if (t.size() != 2) { - throw std::runtime_error("Expected 2 values in state"); + throw std::runtime_error("Expected 2 values in state"); } py::list py_metrics = t[0].cast(); @@ -122,19 +123,18 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { } return ExperimentalConfig(std::move(metrics), t[1].cast()); - } - )); - + })); py::class_(m, "ProfilerConfig") - .def(py::init()); + .def(py::init< + ProfilerState, + bool, /* record_input_shapes */ + bool, /* profile_memory */ + bool, /* with_stack */ + bool, /* with_flops */ + bool, /* with_modules */ + ExperimentalConfig /* experimental_config */ + >()); py::class_(m, "ProfilerEvent") .def("kind", &LegacyEvent::kindStr) @@ -179,96 +179,80 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { py::class_(m, "_KinetoEvent") // name of the event - .def("name", [](const KinetoEvent& e) { - return e.name(); - }) + .def("name", [](const KinetoEvent& e) { return e.name(); }) // PyTorch thread id of the start callback - .def("start_thread_id", [](const KinetoEvent& e) { - return e.startThreadId(); - }) + .def( + "start_thread_id", + [](const KinetoEvent& e) { return e.startThreadId(); }) // PyTorch thread id of the end callback - .def("end_thread_id", [](const KinetoEvent& e) { - return e.endThreadId(); - }) + .def( + "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); }) // for events of scope BACKWARD_FUNCTION - PyTorch thread id // of the corresponding forward op - .def("fwd_thread_id", [](const KinetoEvent& e) { - return e.fwdThreadId(); - }) + .def( + "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); }) // together with fwd_thread_id, used to uniquely identify // the forward op - .def("sequence_nr", [](const KinetoEvent& e) { - return e.sequenceNr(); - }) + .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); }) // absolute start time (since unix epoch) in us - .def("start_us", [](const KinetoEvent& e) { - return e.startUs(); - }) + .def("start_us", [](const KinetoEvent& e) { return e.startUs(); }) // duration in us - .def("duration_us", [](const KinetoEvent& e) { - return e.durationUs(); - }) + .def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); }) // used for correlation between high-level PyTorch events // and low-level device events - .def("correlation_id", [](const KinetoEvent& e) { - return e.correlationId(); - }) + .def( + "correlation_id", + [](const KinetoEvent& e) { return e.correlationId(); }) // shapes of input tensors - .def("shapes", [](const KinetoEvent& e) { - if (e.hasShapes()) { - return e.shapes(); - } else { - return std::vector>(); - } - }) - .def("dtypes", [](const KinetoEvent& e) { - if (e.hasTypes()) { - return e.dtypes(); - } else { - return std::vector(); - } - }) + .def( + "shapes", + [](const KinetoEvent& e) { + if (e.hasShapes()) { + return e.shapes(); + } else { + return std::vector>(); + } + }) + .def( + "dtypes", + [](const KinetoEvent& e) { + if (e.hasTypes()) { + return e.dtypes(); + } else { + return std::vector(); + } + }) // stack traces of the PyTorch CPU events - .def("stack", [](const KinetoEvent& e) { - if (e.hasStack()) { - return e.stack(); - } else { - return std::vector(); - } - }) + .def( + "stack", + [](const KinetoEvent& e) { + if (e.hasStack()) { + return e.stack(); + } else { + return std::vector(); + } + }) // type of the RecordFunction that generated a PyTorch CPU event // (op, torchscript function, user label, etc) - .def("scope", [](const KinetoEvent& e) { - return e.scope(); - }) + .def("scope", [](const KinetoEvent& e) { return e.scope(); }) // device number, for CPU - process id - .def("device_index", [](const KinetoEvent& e) { - return e.deviceIndex(); - }) + .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); }) // for CUDA - stream id, for CPU - start thread id - .def("device_resource_id", [](const KinetoEvent& e) { - return e.deviceResourceId(); - }) + .def( + "device_resource_id", + [](const KinetoEvent& e) { return e.deviceResourceId(); }) // device type - .def("device_type", [](const KinetoEvent& e) { - return e.deviceType(); - }) + .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); }) // correlation id of a linked event - .def("linked_correlation_id", [](const KinetoEvent& e) { - return e.linkedCorrelationId(); - }) + .def( + "linked_correlation_id", + [](const KinetoEvent& e) { return e.linkedCorrelationId(); }) // compute flops - .def("flops", [](const KinetoEvent& e) { - return e.flops(); - }) + .def("flops", [](const KinetoEvent& e) { return e.flops(); }) // Whether this is async event or not - .def("is_async", [](const KinetoEvent& e) { - return e.isAsync(); - }) + .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); }) .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs) - .def("nbytes", [](const KinetoEvent& e) { - return e.nBytes(); - }); + .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }); { using torch::profiler::impl::Result; @@ -285,67 +269,82 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { } py::class_(m, "_ProfilerResult") - .def("trace_start_us", &ProfilerResult::trace_start_us) - .def("events", &ProfilerResult::events) - .def("experimental_event_tree", &ProfilerResult::event_tree) + .def("trace_start_us", &ProfilerResult::trace_start_us) + .def("events", &ProfilerResult::events) + .def("experimental_event_tree", &ProfilerResult::event_tree) #ifdef USE_KINETO - .def("save", &ProfilerResult::save) + .def("save", &ProfilerResult::save) #endif // USE_KINETO - ; + ; - m.def("_enable_profiler", - &enableProfiler, - py::arg("config"), - py::arg("activities"), - py::arg("scopes") = std::unordered_set()); + m.def( + "_enable_profiler", + &enableProfiler, + py::arg("config"), + py::arg("activities"), + py::arg("scopes") = std::unordered_set()); m.def("_disable_profiler", disableProfiler); m.def("_prepare_profiler", prepareProfiler); - m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set - m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set + m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set + m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; }); // PyTorch profiler execution graph internal interface. - m.def("_add_execution_graph_observer", &torch::profiler::impl::addExecutionGraphObserver, py::arg("output_file_name")); - m.def("_remove_execution_graph_observer", &torch::profiler::impl::removeExecutionGraphObserver); - m.def("_enable_execution_graph_observer", &torch::profiler::impl::enableExecutionGraphObserver); - m.def("_disable_execution_graph_observer", &torch::profiler::impl::disableExecutionGraphObserver); + m.def( + "_add_execution_graph_observer", + &torch::profiler::impl::addExecutionGraphObserver, + py::arg("output_file_name")); + m.def( + "_remove_execution_graph_observer", + &torch::profiler::impl::removeExecutionGraphObserver); + m.def( + "_enable_execution_graph_observer", + &torch::profiler::impl::enableExecutionGraphObserver); + m.def( + "_disable_execution_graph_observer", + &torch::profiler::impl::disableExecutionGraphObserver); // NOTICE: These record functions are not torch operators and may not show up // in TorchScript tracing, FX transforms, or operator serialization. For these // use cases, please use `torch.profiler.record_function`. // Creates a new profiling scope using RecordFunction and invokes its starting // callbacks. - m.def("_record_function_with_args_enter", [](const std::string& name, py::args args) { - using torch::autograd::profiler::PythonRecordFunction; - auto python_rec = c10::make_intrusive(at::RecordScope::USER_SCOPE); - auto *rec = &python_rec->record; - if (rec->isActive()) { - if (rec->needsInputs()) { - auto iv_inputs = std::vector(); - for (const auto& arg : args) { - iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg)); + m.def( + "_record_function_with_args_enter", + [](const std::string& name, py::args args) { + using torch::autograd::profiler::PythonRecordFunction; + auto python_rec = c10::make_intrusive( + at::RecordScope::USER_SCOPE); + auto* rec = &python_rec->record; + if (rec->isActive()) { + if (rec->needsInputs()) { + auto iv_inputs = std::vector(); + for (const auto& arg : args) { + iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg)); + } + rec->before( + name, + c10::ArrayRef( + iv_inputs.data(), iv_inputs.size())); + } else { + rec->before(name); + } } - rec->before(name, c10::ArrayRef(iv_inputs.data(), iv_inputs.size())); - } else { - rec->before(name); - } - } - return torch::jit::toPyObject(std::move(python_rec)); - }); + return torch::jit::toPyObject(std::move(python_rec)); + }); // Ends the profiling scope created with record_function_with_param_enter. - m.def("_record_function_with_args_exit", - [](const py::object &obj) { - using torch::autograd::profiler::PythonRecordFunction; - auto python_record = torch::jit::toCustomClass(obj); + m.def("_record_function_with_args_exit", [](const py::object& obj) { + using torch::autograd::profiler::PythonRecordFunction; + auto python_record = torch::jit::toCustomClass(obj); - // We don't actually need to do anything with handle just need to persist the - // lifetime until now. - python_record->record.end(); - }); + // We don't actually need to do anything with handle just need to persist + // the lifetime until now. + python_record->record.end(); + }); m.def("_supported_activities", []() { - std::set activities {ActivityType::CPU}; + std::set activities{ActivityType::CPU}; #if defined(USE_KINETO) && !defined(LIBKINETO_NOCUPTI) if (at::getNumGPUs() > 0 && !at::hasHIP()) { activities.insert(ActivityType::CUDA); @@ -367,34 +366,37 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { at::enableRecordFunction(enable); }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { - auto cb = at::RecordFunctionCallback(nullptr) - .needsInputs(true) - .samplingProb(sampling_prob); + auto cb = + at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb( + sampling_prob); if (is_global) { at::addGlobalCallback(cb); } else { at::addThreadLocalCallback(cb); } }); - m.def("_clear_callbacks", []() { - at::clearCallbacks(); - }); - m.def("_push_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) { - torch::autograd::PyDefaultSavedVariableHooks::push_hooks(pack_hook, unpack_hook); - }); + m.def("_clear_callbacks", []() { at::clearCallbacks(); }); + m.def( + "_push_saved_tensors_default_hooks", + [](py::function& pack_hook, py::function& unpack_hook) { + torch::autograd::PyDefaultSavedVariableHooks::push_hooks( + pack_hook, unpack_hook); + }); m.def("_pop_saved_tensors_default_hooks", []() { torch::autograd::PyDefaultSavedVariableHooks::pop_hooks(); }); - _C_m.def("_register_py_class_for_device", [](const std::string& device, py::object python_type_class) { - auto cls = python_type_class.ptr(); - registerPythonTensorClass(device, cls); - }); + _C_m.def( + "_register_py_class_for_device", + [](const std::string& device, py::object python_type_class) { + auto cls = python_type_class.ptr(); + registerPythonTensorClass(device, cls); + }); - py::class_(_C_m, "_InferenceMode") - .def(py::init()); + py::class_(_C_m, "_InferenceMode").def(py::init()); - py::class_(_C_m, "_RestorePythonTLSSnapshot") + py::class_( + _C_m, "_RestorePythonTLSSnapshot") .def(py::init<>()); // TODO: line up this binding with DisableTorchFunction @@ -402,21 +404,31 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .def(py::init<>()); py::class_(m, "SavedTensor") - .def(py::init([]()->torch::autograd::SavedVariable { - TORCH_CHECK(false, "Trying to create a SavedTensor object from Python is forbidden."); - })) - .def("register_hooks", [](torch::autograd::SavedVariable &s, py::function &pack_hook, py::function &unpack_hook) { - // Because we use a py::object, pybind will increment the refcount of the hook functions for us - s.register_hooks(std::make_unique(pack_hook, unpack_hook)); - }); + .def(py::init([]() -> torch::autograd::SavedVariable { + TORCH_CHECK( + false, + "Trying to create a SavedTensor object from Python is forbidden."); + })) + .def( + "register_hooks", + [](torch::autograd::SavedVariable& s, + py::function& pack_hook, + py::function& unpack_hook) { + // Because we use a py::object, pybind will increment the refcount + // of the hook functions for us + s.register_hooks( + std::make_unique( + pack_hook, unpack_hook)); + }); torch::autograd::profiler::python_tracer::init(); Py_RETURN_TRUE; } -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -static PyObject * set_autocast_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); @@ -426,7 +438,7 @@ static PyObject * set_autocast_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * is_autocast_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (at::autocast::is_enabled()) { Py_RETURN_TRUE; @@ -436,7 +448,7 @@ static PyObject * is_autocast_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); @@ -446,7 +458,7 @@ static PyObject * set_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * is_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (at::autocast::is_cpu_enabled()) { Py_RETURN_TRUE; @@ -456,7 +468,7 @@ static PyObject * is_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_autocast_gpu_dtype(PyObject* _unused, PyObject *arg) { +static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!THPDtype_Check(arg)) { throw TypeError( @@ -468,7 +480,7 @@ static PyObject * set_autocast_gpu_dtype(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_autocast_cpu_dtype(PyObject* _unused, PyObject *arg) { +static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!THPDtype_Check(arg)) { throw TypeError( @@ -480,7 +492,7 @@ static PyObject * set_autocast_cpu_dtype(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){ +static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype(); auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); @@ -489,7 +501,7 @@ static PyObject * get_autocast_gpu_dtype(PyObject* _unused, PyObject *arg){ END_HANDLE_TH_ERRORS } -static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){ +static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype(); auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); @@ -498,26 +510,26 @@ static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){ END_HANDLE_TH_ERRORS } -static PyObject * clear_autocast_cache(PyObject* _unused, PyObject *arg) { +static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS at::autocast::clear_cache(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * autocast_increment_nesting(PyObject* _unused, PyObject *arg) { +static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS return THPUtils_packInt64(at::autocast::increment_nesting()); END_HANDLE_TH_ERRORS } -static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) { +static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS return THPUtils_packInt64(at::autocast::decrement_nesting()); END_HANDLE_TH_ERRORS } -static PyObject * is_autocast_cache_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (at::autocast::is_autocast_cache_enabled()) { Py_RETURN_TRUE; @@ -527,7 +539,7 @@ static PyObject * is_autocast_cache_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_autocast_cache_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); @@ -537,7 +549,7 @@ static PyObject * set_autocast_cache_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); @@ -547,7 +559,7 @@ static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (GradMode::is_enabled()) { Py_RETURN_TRUE; @@ -557,7 +569,7 @@ static PyObject * is_grad_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * is_inference_mode_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (c10::InferenceMode::is_enabled()) { Py_RETURN_TRUE; @@ -567,7 +579,7 @@ static PyObject * is_inference_mode_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * set_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* set_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); @@ -577,7 +589,7 @@ static PyObject * set_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) { +static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (AnomalyMode::is_enabled()) { Py_RETURN_TRUE; @@ -587,19 +599,20 @@ static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) { END_HANDLE_TH_ERRORS } -static PyObject * python_enter_dual_level(PyObject* _unused, PyObject* arg) { +static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - // It is unlikely that the depth of forward nesting will overflow int64_t so we - // just static cast here. + // It is unlikely that the depth of forward nesting will overflow int64_t so + // we just static cast here. return utils::wrap(static_cast(forward_ad::enter_dual_level())); END_HANDLE_TH_ERRORS } -static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyObject* kwargs) { +static PyObject* python_exit_dual_level( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "exit_dual_level(int64_t level)" - }); + static PythonArgParser parser({"exit_dual_level(int64_t level)"}); ParsedArgs<1> parsed_args; auto _r = parser.parse(args, kwargs, parsed_args); @@ -625,7 +638,9 @@ static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) { END_HANDLE_TH_ERRORS } -static PyObject* get_torch_dispatch_mode(PyObject* _unused, PyObject* _unused2) { +static PyObject* get_torch_dispatch_mode( + PyObject* _unused, + PyObject* _unused2) { HANDLE_TH_ERRORS const auto& mode = at::impl::TorchDispatchModeTLS::get_state(); if (!mode) { @@ -638,19 +653,22 @@ static PyObject* get_torch_dispatch_mode(PyObject* _unused, PyObject* _unused2) END_HANDLE_TH_ERRORS } -static PyObject * set_torch_function_mode(PyObject* _unused, PyObject* arg) { +static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (arg == Py_None) { at::impl::PythonTorchFunctionTLS::set_mode(nullptr); } else { Py_INCREF(arg); - at::impl::PythonTorchFunctionTLS::set_mode(std::make_shared(arg, getPyInterpreter())); + at::impl::PythonTorchFunctionTLS::set_mode( + std::make_shared(arg, getPyInterpreter())); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * get_torch_function_mode(PyObject* _unused, PyObject* _unused2) { +static PyObject* get_torch_function_mode( + PyObject* _unused, + PyObject* _unused2) { HANDLE_TH_ERRORS const auto& mode = at::impl::PythonTorchFunctionTLS::get_mode(); if (!mode) { @@ -665,35 +683,50 @@ static PyObject * get_torch_function_mode(PyObject* _unused, PyObject* _unused2) // autograd methods on torch._C static PyMethodDef methods[] = { // NOLINT - {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, - {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, - {"is_inference_mode_enabled", is_inference_mode_enabled, METH_NOARGS, nullptr}, - {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr}, - {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr}, - {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr}, - {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr}, - {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr}, - {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr}, - {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr}, - {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr}, - {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr}, - {"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr}, - {"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr}, - {"is_autocast_cache_enabled", is_autocast_cache_enabled, METH_NOARGS, nullptr}, - {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr}, - {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr}, - {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr}, - {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, - {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr}, - {"_get_torch_dispatch_mode", get_torch_dispatch_mode, METH_NOARGS, nullptr}, - {"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr}, - {"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; + {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, + {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, + {"is_inference_mode_enabled", + is_inference_mode_enabled, + METH_NOARGS, + nullptr}, + {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr}, + {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr}, + {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr}, + {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr}, + {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr}, + {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr}, + {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr}, + {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr}, + {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr}, + {"autocast_increment_nesting", + autocast_increment_nesting, + METH_NOARGS, + nullptr}, + {"autocast_decrement_nesting", + autocast_decrement_nesting, + METH_NOARGS, + nullptr}, + {"is_autocast_cache_enabled", + is_autocast_cache_enabled, + METH_NOARGS, + nullptr}, + {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr}, + {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr}, + {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr}, + {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, + {"_exit_dual_level", + castPyCFunctionWithKeywords(python_exit_dual_level), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr}, + {"_get_torch_dispatch_mode", get_torch_dispatch_mode, METH_NOARGS, nullptr}, + {"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr}, + {"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { return methods; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 71cc6e06d4d649..49d7e1cb5ad85f 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -5,90 +5,101 @@ #include #include -#include #include +#include #include #include #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { namespace { - // look what you made me do >.< - // Divergent paths for per-Impl stream recording that leak implementation - // details of the impls should not be needed here. - // See https://github.com/pytorch/pytorch/issues/60306 - // TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is improved - void record_stream_any_impl(Variable& var, c10::Stream& stream) { - const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA); +// look what you made me do >.< +// Divergent paths for per-Impl stream recording that leak implementation +// details of the impls should not be needed here. +// See https://github.com/pytorch/pytorch/issues/60306 +// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is +// improved +void record_stream_any_impl(Variable& var, c10::Stream& stream) { + const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA); - if (C10_UNLIKELY(at::isBatchedTensor(var))) { - auto* impl = at::maybeGetBatchedImpl(var); - if (impl) { - guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream); - } else { - TORCH_INTERNAL_ASSERT(false, "Expected batched tensor"); - } + if (C10_UNLIKELY(at::isBatchedTensor(var))) { + auto* impl = at::maybeGetBatchedImpl(var); + if (impl) { + guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream); } else { - switch (var.layout()) { - case c10::kSparseCsr: - case c10::kSparseCsc: - case c10::kSparseBsr: - case c10::kSparseBsc: - { - auto* impl = at::sparse_csr::get_sparse_csr_impl(var); - guard.recordDataPtrOnStream(impl->values().storage().data_ptr(), stream); - guard.recordDataPtrOnStream(impl->compressed_indices().storage().data_ptr(), stream); - guard.recordDataPtrOnStream(impl->plain_indices().storage().data_ptr(), stream); - break; - } - case c10::kSparse: - { - auto* impl = at::sparse::get_sparse_impl(var); - guard.recordDataPtrOnStream(impl->values().storage().data_ptr(), stream); - guard.recordDataPtrOnStream(impl->indices().storage().data_ptr(), stream); - break; - } - case c10::kStrided: - guard.recordDataPtrOnStream(var.storage().data_ptr(), stream); - break; - default: - TORCH_INTERNAL_ASSERT(false, "Unknown layout in record_stream_any_impl"); + TORCH_INTERNAL_ASSERT(false, "Expected batched tensor"); + } + } else { + switch (var.layout()) { + case c10::kSparseCsr: + case c10::kSparseCsc: + case c10::kSparseBsr: + case c10::kSparseBsc: { + auto* impl = at::sparse_csr::get_sparse_csr_impl(var); + guard.recordDataPtrOnStream( + impl->values().storage().data_ptr(), stream); + guard.recordDataPtrOnStream( + impl->compressed_indices().storage().data_ptr(), stream); + guard.recordDataPtrOnStream( + impl->plain_indices().storage().data_ptr(), stream); + break; + } + case c10::kSparse: { + auto* impl = at::sparse::get_sparse_impl(var); + guard.recordDataPtrOnStream( + impl->values().storage().data_ptr(), stream); + guard.recordDataPtrOnStream( + impl->indices().storage().data_ptr(), stream); + break; } + case c10::kStrided: + guard.recordDataPtrOnStream(var.storage().data_ptr(), stream); + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown layout in record_stream_any_impl"); } } -} // anonymous namespace +} +} // anonymous namespace - static void accumulate(std::vector& buffer, - const size_t pos, - Variable&& var) { - TORCH_INTERNAL_ASSERT(pos < buffer.size()); - auto& old_var = buffer[pos]; - // ATen doesn't route sparse additions correctly... - // do dense + sparse in-place if possible - if (old_var.is_sparse()) { - // It is safe to change the Tensor inplace if the Tensor is only used in this buffer (this could be the gradient passed by the - // user) and that no other Tensor is using the same storage. - if (!var.is_sparse() && var.is_contiguous() && var.use_count() == 1 && var.storage().use_count() == 1) { - buffer[pos] = var.add_(old_var); - } else { - buffer[pos] = var + old_var; - } +static void accumulate( + std::vector& buffer, + const size_t pos, + Variable&& var) { + TORCH_INTERNAL_ASSERT(pos < buffer.size()); + auto& old_var = buffer[pos]; + // ATen doesn't route sparse additions correctly... + // do dense + sparse in-place if possible + if (old_var.is_sparse()) { + // It is safe to change the Tensor inplace if the Tensor is only used in + // this buffer (this could be the gradient passed by the user) and that no + // other Tensor is using the same storage. + if (!var.is_sparse() && var.is_contiguous() && var.use_count() == 1 && + var.storage().use_count() == 1) { + buffer[pos] = var.add_(old_var); } else { - if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && old_var.use_count() == 1 && old_var.storage().use_count() == 1) { - buffer[pos] = old_var.add_(var); - } else { - buffer[pos] = old_var + var; - } + buffer[pos] = var + old_var; + } + } else { + if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && + old_var.use_count() == 1 && old_var.storage().use_count() == 1) { + buffer[pos] = old_var.add_(var); + } else { + buffer[pos] = old_var + var; } } +} - void InputBuffer::add(size_t pos, - Variable&& var, - const c10::optional& opt_producer_stream, - const c10::optional& opt_consumer_stream) { +void InputBuffer::add( + size_t pos, + Variable&& var, + const c10::optional& opt_producer_stream, + const c10::optional& opt_consumer_stream) { TORCH_INTERNAL_ASSERT(pos < buffer.size()); if (!var.defined()) { return; @@ -97,28 +108,32 @@ namespace { // Switches to accumulate device // The device (and stream) chosen for accumulation is: // (1) var is not a CUDA variable. Accumulation happens on var's device. - // (2) var is a CUDA variable and it, the consumer, and the producer share the same device: + // (2) var is a CUDA variable and it, the consumer, and the producer share + // the same device: // (2a) Uses the consumer's stream as the accumulation stream - // (2b) Syncs the accumulation stream with the producer's stream (if different) - // (2c) Accumulates. - // (3) var is a CUDA variable and it shares a device with the consumer but not the producer: + // (2b) Syncs the accumulation stream with the producer's stream (if + // different) (2c) Accumulates. + // (3) var is a CUDA variable and it shares a device with the consumer but + // not the producer: // (3a) Uses the consumer's stream as the accumulation stream - // (3b) Syncs the accumulation stream with the consumer device's default stream - // (3c) Accumulates. - // (4) var is a CUDA variable and it shares a device with the producer but not the consumer: - // (4a) Uses the producer device's default stream as the accumulation stream - // (4b) Syncs the accumulation stream with the the producer's stream - // (4c) Accumulates. - // (5) var is a CUDA variable and it does not share a device with the consumer or producer. + // (3b) Syncs the accumulation stream with the consumer device's default + // stream (3c) Accumulates. + // (4) var is a CUDA variable and it shares a device with the producer but + // not the consumer: + // (4a) Uses the producer device's default stream as the accumulation + // stream (4b) Syncs the accumulation stream with the the producer's + // stream (4c) Accumulates. + // (5) var is a CUDA variable and it does not share a device with the + // consumer or producer. // Accumulation happens on the var device's default stream. TORCH_INTERNAL_ASSERT(device_of(var)); c10::optional opt_accumulate_stream = c10::nullopt; if (device_of(var)->is_cuda()) { - const auto on_producer = opt_producer_stream - && device_of(var) == opt_producer_stream->device(); - const auto on_consumer = opt_consumer_stream - && device_of(var) == opt_consumer_stream->device(); + const auto on_producer = + opt_producer_stream && device_of(var) == opt_producer_stream->device(); + const auto on_consumer = + opt_consumer_stream && device_of(var) == opt_consumer_stream->device(); if (on_producer && on_consumer) { // (2a) @@ -139,7 +154,8 @@ namespace { opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device()); } else if (on_producer && !on_consumer) { // (4a) - opt_accumulate_stream = guard.getDefaultStream(opt_producer_stream->device()); + opt_accumulate_stream = + guard.getDefaultStream(opt_producer_stream->device()); opt_sync_stream = opt_producer_stream; } else { // (5) @@ -196,4 +212,5 @@ auto InputBuffer::variables(InputBuffer&& g) -> std::vector { return result; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index e13d6b17ee900f..dfa65cfaea4511 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -5,44 +5,48 @@ // values in-place (adding an input twice will accumulate the result). // This behaviour is needed and used only in backward graphs. -#include -#include -#include #include +#include +#include +#include -#include -#include #include +#include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct InputBuffer { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit InputBuffer(size_t size) - : buffer(size) {} + explicit InputBuffer(size_t size) : buffer(size) {} InputBuffer(const InputBuffer& other) = delete; InputBuffer(InputBuffer&& other) = default; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit InputBuffer(variable_list&& inputs): buffer(std::move(inputs)) {}; + explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)){}; InputBuffer& operator=(InputBuffer&& other) = default; // Accumulates the variable at a specified index. // The optional CUDA streams determine which stream the accumulation // is run on and how the addition is synchronized. - void add(size_t pos, - Variable&& var, - const c10::optional& opt_producer_stream, - const c10::optional& opt_consumer_stream); + void add( + size_t pos, + Variable&& var, + const c10::optional& opt_producer_stream, + const c10::optional& opt_consumer_stream); at::Device device() const; - Variable operator[](size_t pos) { return buffer[pos]; } + Variable operator[](size_t pos) { + return buffer[pos]; + } // Returns the inputs as a list of variables. Destroys given InputBuffer. static std::vector variables(InputBuffer&& g); -private: + private: std::vector buffer; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index c4d50fcadfa154..b24126b3978964 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -1,11 +1,11 @@ #pragma once +#include #include #include #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -15,11 +15,13 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { /** - * Records TensorOptions, shape of the tensor, whether or not the Python dispatch key is set (tensor subclass), - * and, where applicable, the stream the corresponding operation took place on. + * Records TensorOptions, shape of the tensor, whether or not the Python + * dispatch key is set (tensor subclass), and, where applicable, the stream the + * corresponding operation took place on. * * If is_valid() is false, then the corresponding input is not used and may be * an undefined tensor. @@ -27,14 +29,22 @@ namespace torch { namespace autograd { struct InputMetadata { InputMetadata() = default; - InputMetadata(const at::TensorOptions options, at::IntArrayRef shape, bool is_tensor_subclass) - : options_{options}, shape_{shape}, is_tensor_subclass_{is_tensor_subclass} { + InputMetadata( + const at::TensorOptions options, + at::IntArrayRef shape, + bool is_tensor_subclass) + : options_{options}, + shape_{shape}, + is_tensor_subclass_{is_tensor_subclass} { auto device_ = options.device(); stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_); } InputMetadata(const at::Tensor& t) - : InputMetadata(t.options(), t.sizes(), t.unsafeGetTensorImpl()->is_python_dispatch()) { } + : InputMetadata( + t.options(), + t.sizes(), + t.unsafeGetTensorImpl()->is_python_dispatch()) {} const at::TensorOptions options() const { return options_; @@ -90,11 +100,12 @@ struct InputMetadata { return ss; } -private: + private: const at::TensorOptions options_; at::DimVector shape_; c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device()); bool is_tensor_subclass_ = false; }; -}} // torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h index 7ac44096cda70c..519f49005f7764 100644 --- a/torch/csrc/autograd/profiler.h +++ b/torch/csrc/autograd/profiler.h @@ -1,4 +1,4 @@ #pragma once -#include #include +#include diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 9294d4551b1635..cb795d71e757ce 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -2,11 +2,11 @@ #include #include +#include #include #include #include #include -#include #include #include @@ -59,15 +59,15 @@ inline int64_t getTimeUs() { #endif // USE_KINETO } -using torch::profiler::impl::ProfilerThreadLocalStateBase; using torch::profiler::impl::ActiveProfilerType; +using torch::profiler::impl::dtypesToStr; using torch::profiler::impl::EventType; using torch::profiler::impl::ExtraFields; +using torch::profiler::impl::ProfilerThreadLocalStateBase; using torch::profiler::impl::Result; -using torch::profiler::impl::kineto::annotation_t; using torch::profiler::impl::shapesToStr; -using torch::profiler::impl::dtypesToStr; using torch::profiler::impl::stacksToStr; +using torch::profiler::impl::kineto::annotation_t; using torch::profiler::impl::kineto::KinetoActivityType; struct EventFieldsVisitor { @@ -75,8 +75,7 @@ struct EventFieldsVisitor { std::shared_ptr& result, KinetoEvent& kineto_event, const post_process_t& post_process) - : kineto_event_{kineto_event}, - post_process_{post_process} { + : kineto_event_{kineto_event}, post_process_{post_process} { pushPythonMetadata(result->parent_.lock()); c10::visit(*this, result->extra_fields_); handleStack(result->parent_); @@ -153,7 +152,8 @@ struct EventFieldsVisitor { "Total Allocated", std::to_string(alloc.total_allocated_)); } if (alloc.total_reserved_ >= 0) { - annotations_.emplace_back("Total Reserved", std::to_string(alloc.total_reserved_)); + annotations_.emplace_back( + "Total Reserved", std::to_string(alloc.total_reserved_)); } } @@ -354,7 +354,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { // NB: also sets fields on `kineto_events_.back()`. auto visitor = EventFieldsVisitor( - e, kineto_events_.back(), getEventPostProcessingCallback()); + e, kineto_events_.back(), getEventPostProcessingCallback()); cpu_trace_.addCPUActivity( e->name(), @@ -368,7 +368,8 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { } } - void finalizeCPUTrace(std::unique_ptr& cpu_trace) { + void finalizeCPUTrace( + std::unique_ptr& cpu_trace) { #ifndef USE_KINETO } #else // USE_KINETO @@ -378,10 +379,11 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { // Low-16bits of startThreadId and low-48bits seqNum are concatenated into // one uint64_t variable as key. - // From the time being, we need disable the forward/backward correlation feature to - // workaround the crash bug. + // From the time being, we need disable the forward/backward correlation + // feature to workaround the crash bug. // TODO: by Mike Guo - // reenable the forward/backward correlation when kineto fix the following raw pointer + // reenable the forward/backward correlation when kineto fix the following + // raw pointer // GenericTraceActivity.flow.linkedActivity /* @@ -443,7 +445,8 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase { } #endif // USE_KINETO - void addTraceEvents(torch::profiler::impl::kineto::ActivityTraceWrapper& trace) { + void addTraceEvents( + torch::profiler::impl::kineto::ActivityTraceWrapper& trace) { #ifdef USE_KINETO const auto& events = *(trace.get()->activities()); for (const auto& ev_ptr : events) { @@ -520,15 +523,16 @@ class GlobalStateManager { std::shared_ptr state_; }; -template +template static KinetoThreadLocalState* getStatePtr() { return c10::guts::if_constexpr( [] { return GlobalStateManager::get(); }, [] { return KinetoThreadLocalState::getTLS(); }); } -template -std::unique_ptr onFunctionEnter(const at::RecordFunction& fn) { +template +std::unique_ptr onFunctionEnter( + const at::RecordFunction& fn) { auto state_ptr = getStatePtr(); if (!state_ptr) { return nullptr; @@ -543,18 +547,22 @@ std::unique_ptr onFunctionEnter(const at::RecordFunction& f } // @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter -template -void onFunctionExit(const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) { +template +void onFunctionExit( + const at::RecordFunction& fn, + at::ObserverContext* ctx_ptr) { auto state_ptr = getStatePtr(); if (!state_ptr) { return; } const auto& config = state_ptr->config(); auto* kineto_ctx_ptr = - static_cast(ctx_ptr); + static_cast(ctx_ptr); TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr); - kineto_ctx_ptr->event_->end_time_ = torch::profiler::impl::getApproximateTime(); - kineto_ctx_ptr->event_->basic_fields_.end_tid_ = at::RecordFunction::currentThreadId(); + kineto_ctx_ptr->event_->end_time_ = + torch::profiler::impl::getApproximateTime(); + kineto_ctx_ptr->event_->basic_fields_.end_tid_ = + at::RecordFunction::currentThreadId(); if (config.state == ProfilerState::KINETO_GPU_FALLBACK) { try { auto fallback = kineto_ctx_ptr->fallback_; @@ -586,8 +594,7 @@ void pushProfilingCallbacks(const std::unordered_set& scopes) { auto handle = c10::guts::if_constexpr( [&] { return at::addGlobalCallback(recordFunctionCallback); }, - [&] { return at::addThreadLocalCallback(recordFunctionCallback); - }); + [&] { return at::addThreadLocalCallback(recordFunctionCallback); }); registration_state_ptr->setCallbackHandle(handle); } @@ -687,7 +694,9 @@ void enableProfiler( if (config.state == ProfilerState::KINETO_ONDEMAND) { GlobalStateManager::init(config, activities); - TORCH_INTERNAL_ASSERT(activities.count(ActivityType::CPU), "Ondemand profiling must enable CPU tracing"); + TORCH_INTERNAL_ASSERT( + activities.count(ActivityType::CPU), + "Ondemand profiling must enable CPU tracing"); pushProfilingCallbacks(scopes); } } @@ -714,7 +723,8 @@ std::unique_ptr disableProfiler() { // Traces are converged via libkineto automatically for ondemand flow if (state_ptr->config().state == ProfilerState::KINETO_ONDEMAND) { - (void)std::static_pointer_cast(state_ptr)->finalizeTrace(); + (void)std::static_pointer_cast(state_ptr) + ->finalizeTrace(); return std::make_unique(); } @@ -726,7 +736,8 @@ std::unique_ptr disableProfiler() { if (config.state == ProfilerState::KINETO || config.state == ProfilerState::KINETO_GPU_FALLBACK) { - auto kineto_state_ptr = std::static_pointer_cast(state_ptr); + auto kineto_state_ptr = + std::static_pointer_cast(state_ptr); auto trace = kineto_state_ptr->finalizeTrace(); result = std::make_unique( kineto_state_ptr->start_time_, @@ -743,7 +754,8 @@ int64_t KinetoEvent::cudaElapsedUs() const { return -1; } try { - return (int64_t)torch::profiler::impl::cudaStubs()->elapsed(&cuda_event_start_, &cuda_event_end_); + return (int64_t)torch::profiler::impl::cudaStubs()->elapsed( + &cuda_event_start_, &cuda_event_end_); } catch (std::exception& e) { LOG(WARNING) << "Failed to measure time between two CUDA events. " << e.what(); diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 6525c52750854e..8c5a0cdfa2ad87 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -127,7 +127,8 @@ struct TORCH_API KinetoEvent { return *module_hierarchy_; } - KinetoEvent& moduleHierarchy(const std::vector& module_hierarchy) { + KinetoEvent& moduleHierarchy( + const std::vector& module_hierarchy) { module_hierarchy_ = module_hierarchy; return *this; } @@ -208,7 +209,7 @@ struct TORCH_API KinetoEvent { return correlation_id_; } - KinetoEvent& correlationId(uint64_t correlation_id) { + KinetoEvent& correlationId(uint64_t correlation_id) { correlation_id_ = correlation_id; return *this; } @@ -352,13 +353,15 @@ TORCH_API void enableProfiler( * Additionally, it takes a functor that does in-place post processing of * events, e.g. populate stack trace or module hierarchy information lazily * using debug_handle. - * Example usage is with lite interpreter that has recording scope of LITE_INTERPRETER. - * In this case lite interpreter runtime, records debug handles in RecordFunction, along - * with other information. Debug handles are eventually passed down to KinetoEvent and - * recorded as part of the event. KinetoEdgeCPUProfiler, - * in torch/csrc/jit/mobile/profiler_edge.cpp, enables profiler using post-processing - * callback, via enableProfilerWithEventPostProcess, that takes these debug handles - * and generates stack trace and module hierarchy information, once profiling is done. + * Example usage is with lite interpreter that has recording scope of + * LITE_INTERPRETER. In this case lite interpreter runtime, records debug + * handles in RecordFunction, along with other information. Debug handles are + * eventually passed down to KinetoEvent and recorded as part of the event. + * KinetoEdgeCPUProfiler, in torch/csrc/jit/mobile/profiler_edge.cpp, enables + * profiler using post-processing callback, via + * enableProfilerWithEventPostProcess, that takes these debug handles and + * generates stack trace and module hierarchy information, once profiling is + * done. */ using post_process_t = std::function& activities); } // namespace profiler -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 72601c804ea409..22d9c5ecf99021 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -23,7 +23,9 @@ #include -namespace torch { namespace autograd { namespace profiler { +namespace torch { +namespace autograd { +namespace profiler { // We decompose the profiler logic into the following components: // @@ -117,20 +119,21 @@ namespace torch { namespace autograd { namespace profiler { // namespace { -using torch::profiler::impl::ProfilerThreadLocalStateBase; using torch::profiler::impl::ActiveProfilerType; +using torch::profiler::impl::ProfilerThreadLocalStateBase; struct ProfilerLegacyThreadLocalState : public ProfilerThreadLocalStateBase { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit ProfilerLegacyThreadLocalState(const torch::profiler::impl::ProfilerConfig& config) - : ProfilerThreadLocalStateBase(config), remoteProfiledEvents_{c10::nullopt} {} + explicit ProfilerLegacyThreadLocalState( + const torch::profiler::impl::ProfilerConfig& config) + : ProfilerThreadLocalStateBase(config), + remoteProfiledEvents_{c10::nullopt} {} ~ProfilerLegacyThreadLocalState() override = default; static ProfilerLegacyThreadLocalState* getTLS() { auto tls = ProfilerThreadLocalStateBase::getTLS(); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - tls == nullptr || - tls->profilerType() == ActiveProfilerType::LEGACY); + tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY); return static_cast(tls); } @@ -200,7 +203,8 @@ void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) { EventKind::Mark, at::StringView(std::move(name)), at::RecordFunction::currentThreadId(), - include_cuda && config_.state == torch::profiler::impl::ProfilerState::CUDA); + include_cuda && + config_.state == torch::profiler::impl::ProfilerState::CUDA); evt.setNodeId(at::RecordFunction::getDefaultNodeId()); getEventList().record(std::move(evt)); } @@ -225,8 +229,9 @@ void ProfilerLegacyThreadLocalState::pushRange( return; } if (config_.state == torch::profiler::impl::ProfilerState::NVTX) { - torch::profiler::impl::cudaStubs()->nvtxRangePushA(torch::profiler::impl::getNvtxStr( - fn.name(), fn.seqNr(), shapes).c_str()); + torch::profiler::impl::cudaStubs()->nvtxRangePushA( + torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes) + .c_str()); } else { LegacyEvent evt( EventKind::PushRange, @@ -242,17 +247,21 @@ void ProfilerLegacyThreadLocalState::pushRange( evt.setScope((uint8_t)fn.scope()); if (config_.with_flops) { evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn)); - evt.setFlops(torch::profiler::impl::computeFlops(std::string(fn.name()), evt.extraArgs())); + evt.setFlops(torch::profiler::impl::computeFlops( + std::string(fn.name()), evt.extraArgs())); } // TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon. #if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE // backward nodes source range corresponds to the forward node // TODO: consider using C++ stack trace - if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { - auto cs = torch::profiler::impl::prepareCallstack(jit::currentCallstack()); + if (config_.with_stack && + fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { + auto cs = + torch::profiler::impl::prepareCallstack(jit::currentCallstack()); if (cs.empty()) { - cs = torch::profiler::impl::prepareCallstack(jit::tracer::pythonCallstack()); + cs = torch::profiler::impl::prepareCallstack( + jit::tracer::pythonCallstack()); } evt.setStack(callstackStr(cs)); } @@ -261,7 +270,9 @@ void ProfilerLegacyThreadLocalState::pushRange( } } -void ProfilerLegacyThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_cuda) { +void ProfilerLegacyThreadLocalState::popRange( + const at::RecordFunction& fn, + const bool record_cuda) { if (config_.state == torch::profiler::impl::ProfilerState::Disabled) { return; } @@ -289,7 +300,8 @@ void ProfilerLegacyThreadLocalState::reportMemoryUsage( int64_t /* total_allocated, unused for legacy */, int64_t /* total_reserved, unused for legacy */, c10::Device device) { - if (config_.profile_memory && config_.state != torch::profiler::impl::ProfilerState::Disabled) { + if (config_.profile_memory && + config_.state != torch::profiler::impl::ProfilerState::Disabled) { uint64_t thread_id = at::RecordFunction::currentThreadId(); LegacyEvent evt( EventKind::MemoryAlloc, @@ -301,7 +313,8 @@ void ProfilerLegacyThreadLocalState::reportMemoryUsage( } } -RangeEventList& ProfilerLegacyThreadLocalState::getEventList(int64_t thread_id) { +RangeEventList& ProfilerLegacyThreadLocalState::getEventList( + int64_t thread_id) { if (thread_id < 0) { thread_id = at::RecordFunction::currentThreadId(); } @@ -335,69 +348,81 @@ enum EventIValueIdx { }; const std::unordered_set disable_cuda_profiling = { - "aten::view", - "aten::t", - "aten::transpose", - "aten::stride", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::as_strided", - "aten::expand", - "aten::resize_", - "aten::squeeze", - "aten::unsqueeze", - "aten::slice", - "aten::_unsafe_view", - "aten::size" -}; + "aten::view", + "aten::t", + "aten::transpose", + "aten::stride", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::as_strided", + "aten::expand", + "aten::resize_", + "aten::squeeze", + "aten::unsqueeze", + "aten::slice", + "aten::_unsafe_view", + "aten::size"}; void pushProfilingCallbacksLegacy() { auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS(); TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set"); - auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) -> std::unique_ptr { - auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); - if (!state_ptr || state_ptr->config().state == torch::profiler::impl::ProfilerState::Disabled) { - return nullptr; - } - bool record_cuda = - state_ptr->config().state == torch::profiler::impl::ProfilerState::CUDA; - if (record_cuda && disable_cuda_profiling.find(fn.name()) != disable_cuda_profiling.end()) { - record_cuda = false; - } - - if (state_ptr->config().report_input_shapes) { - auto sizes = torch::profiler::impl::inputSizes(fn); - state_ptr->pushRange(fn, record_cuda, std::move(sizes)); - } else { - state_ptr->pushRange(fn, record_cuda); - } - - return nullptr; - }, - [](const at::RecordFunction& fn, at::ObserverContext*) { - auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); - if (!state_ptr || state_ptr->config().state == torch::profiler::impl::ProfilerState::Disabled) { - return; - } - bool record_cuda = - state_ptr->config().state == torch::profiler::impl::ProfilerState::CUDA; - if (record_cuda && disable_cuda_profiling.find(fn.name()) != disable_cuda_profiling.end()) { - record_cuda = false; - } - state_ptr->popRange(fn, record_cuda); - }) - .needsInputs(registration_state_ptr->config().report_input_shapes) - .needsIds(true)); + auto handle = at::addThreadLocalCallback( + at::RecordFunctionCallback( + [](const at::RecordFunction& fn) + -> std::unique_ptr { + auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); + if (!state_ptr || + state_ptr->config().state == + torch::profiler::impl::ProfilerState::Disabled) { + return nullptr; + } + bool record_cuda = state_ptr->config().state == + torch::profiler::impl::ProfilerState::CUDA; + if (record_cuda && + disable_cuda_profiling.find(fn.name()) != + disable_cuda_profiling.end()) { + record_cuda = false; + } + + if (state_ptr->config().report_input_shapes) { + auto sizes = torch::profiler::impl::inputSizes(fn); + state_ptr->pushRange(fn, record_cuda, std::move(sizes)); + } else { + state_ptr->pushRange(fn, record_cuda); + } + + return nullptr; + }, + [](const at::RecordFunction& fn, at::ObserverContext*) { + auto state_ptr = ProfilerLegacyThreadLocalState::getTLS(); + if (!state_ptr || + state_ptr->config().state == + torch::profiler::impl::ProfilerState::Disabled) { + return; + } + bool record_cuda = state_ptr->config().state == + torch::profiler::impl::ProfilerState::CUDA; + if (record_cuda && + disable_cuda_profiling.find(fn.name()) != + disable_cuda_profiling.end()) { + record_cuda = false; + } + state_ptr->popRange(fn, record_cuda); + }) + .needsInputs(registration_state_ptr->config().report_input_shapes) + .needsIds(true)); registration_state_ptr->setCallbackHandle(handle); } } // namespace -void enableProfilerLegacy(const torch::profiler::impl::ProfilerConfig& new_config) { - TORCH_CHECK(new_config.state != torch::profiler::impl::ProfilerState::NVTX || torch::profiler::impl::cudaStubs()->enabled(), - "Can't use NVTX profiler - PyTorch was compiled without CUDA"); +void enableProfilerLegacy( + const torch::profiler::impl::ProfilerConfig& new_config) { + TORCH_CHECK( + new_config.state != torch::profiler::impl::ProfilerState::NVTX || + torch::profiler::impl::cudaStubs()->enabled(), + "Can't use NVTX profiler - PyTorch was compiled without CUDA"); TORCH_CHECK(new_config.state != torch::profiler::impl::ProfilerState::KINETO); @@ -411,26 +436,35 @@ void enableProfilerLegacy(const torch::profiler::impl::ProfilerConfig& new_confi state->mark("__start_profile", false); } -thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions) { - auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true; - auto consolidate = profilerDisableOptions ? profilerDisableOptions->consolidate : true; - // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard +thread_event_lists disableProfilerLegacy( + c10::optional profilerDisableOptions) { + auto cleanupTLSState = + profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true; + auto consolidate = + profilerDisableOptions ? profilerDisableOptions->consolidate : true; + // all the DebugInfoBase objects are scope based and supposed to use + // DebugInfoGuard std::shared_ptr state; if (cleanupTLSState) { state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE); } else { - state = c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE); + state = + c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE); } auto state_ptr = static_cast(state.get()); - TORCH_CHECK(state_ptr && state_ptr->config().state != torch::profiler::impl::ProfilerState::Disabled, + TORCH_CHECK( + state_ptr && + state_ptr->config().state != + torch::profiler::impl::ProfilerState::Disabled, "Can't disable profiler when it's not running"); if (cleanupTLSState) { at::removeCallback(state_ptr->callbackHandle()); } - if (!consolidate || state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) { + if (!consolidate || + state_ptr->config().state == torch::profiler::impl::ProfilerState::NVTX) { return thread_event_lists(); } @@ -453,7 +487,8 @@ void LegacyEvent::record(bool record_cuda) { cpu_ns_ = torch::profiler::impl::getTime(); } -/* static */ LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) { +/* static */ LegacyEvent LegacyEvent::fromIValue( + const at::IValue& eventIValue) { TORCH_INTERNAL_ASSERT( eventIValue.isList(), "Expected IValue to contain type c10::impl::GenericList"); @@ -467,9 +502,8 @@ void LegacyEvent::record(bool record_cuda) { // Reconstruct input shapes from ivalues. auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES); TORCH_INTERNAL_ASSERT( - shapeListIValue.isList(), - "Expected profiler shapes IValue to contain type c10::impl::GenericList." - ); + shapeListIValue.isList(), + "Expected profiler shapes IValue to contain type c10::impl::GenericList."); auto shapeList = shapeListIValue.toList(); std::vector> shapes; @@ -551,7 +585,8 @@ double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const { TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0); return static_cast(e.cuda_us_ - cuda_us_); } - return torch::profiler::impl::cudaStubs()->elapsed(&cuda_event, &e.cuda_event); + return torch::profiler::impl::cudaStubs()->elapsed( + &cuda_event, &e.cuda_event); } static const at::jit::CodeTemplate event_template(R"( @@ -565,7 +600,9 @@ static const at::jit::CodeTemplate event_template(R"( "args": {} })"); -void writeProfilerEventsToStream(std::ostream& out, const std::vector& events) { +void writeProfilerEventsToStream( + std::ostream& out, + const std::vector& events) { TORCH_CHECK(out, "Could not open file"); LegacyEvent* profiler_start = nullptr; for (LegacyEvent* e : events) { @@ -577,12 +614,17 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector p) const - noexcept { - return std::hash()(p.first) ^ std::hash()(p.second); + size_t operator()( + std::pair p) const noexcept { + return std::hash()(p.first) ^ + std::hash()(p.second); } }; - std::unordered_map, LegacyEvent*, PairHash> events_map; + std::unordered_map< + std::pair, + LegacyEvent*, + PairHash> + events_map; out << "[\n"; bool first = true; for (LegacyEvent* evt : events) { @@ -609,18 +651,18 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector events; for (auto& l : event_lists) { for (auto& e : l) { - events.push_back(&e); + events.push_back(&e); } } processEvents(events); @@ -644,4 +686,6 @@ void RecordProfile::processEvents(const std::vector& events) { writeProfilerEventsToStream(out_, events); } -}}} +} // namespace profiler +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 5dad7de1a250e4..a2b7626aa9d17e 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -1,20 +1,21 @@ #pragma once +#include +#include #include -#include #include -#include -#include -#include +#include #include -#include +#include #include +#include #include -#include #include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct Node; @@ -96,10 +97,14 @@ struct TORCH_API LegacyEvent { std::string kindStr() const { switch (kind_) { - case EventKind::Mark: return "mark"; - case EventKind::PushRange: return "push"; - case EventKind::PopRange: return "pop"; - case EventKind::MemoryAlloc: return "memory_alloc"; + case EventKind::Mark: + return "mark"; + case EventKind::PushRange: + return "push"; + case EventKind::PopRange: + return "pop"; + case EventKind::MemoryAlloc: + return "memory_alloc"; } throw std::runtime_error("unknown event kind"); } @@ -122,7 +127,7 @@ struct TORCH_API LegacyEvent { double cpuElapsedUs(const LegacyEvent& e) const { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) - return static_cast(e.cpu_ns_ - cpu_ns_)/(1000.0); + return static_cast(e.cpu_ns_ - cpu_ns_) / (1000.0); } void setCpuUs(int64_t cpu_us) { @@ -144,11 +149,10 @@ struct TORCH_API LegacyEvent { } void updateMemoryStats(int64_t alloc_size, c10::Device device) { - if (device.is_cuda() || - device.type() == c10::DeviceType::HIP) { + if (device.is_cuda() || device.type() == c10::DeviceType::HIP) { cuda_memory_usage_ = alloc_size; - } else if (device.is_cpu() || - device.type() == c10::DeviceType::MKLDNN || + } else if ( + device.is_cpu() || device.type() == c10::DeviceType::MKLDNN || device.type() == c10::DeviceType::IDEEP) { cpu_memory_usage_ = alloc_size; } else { @@ -169,7 +173,7 @@ struct TORCH_API LegacyEvent { } // Node ID corresponding to this event. - int nodeId( ) const { + int nodeId() const { return node_id_; } @@ -257,7 +261,7 @@ struct TORCH_API LegacyEvent { EventKind kind_; uint64_t thread_id_; uint64_t fwd_thread_id_; - at::RecordFunctionHandle handle_ {0}; + at::RecordFunctionHandle handle_{0}; std::vector> shapes_; int64_t cpu_memory_usage_ = 0; int64_t cuda_memory_usage_ = 0; @@ -286,7 +290,7 @@ struct RangeEventList { events_.reserve(kReservedCapacity); } - template + template void record(Args&&... args) { std::lock_guard guard(mutex_); events_.emplace_back(std::forward(args)...); @@ -335,15 +339,20 @@ struct TORCH_API ProfilerDisableOptions { // NOTE: profiler mode is thread local, with automatic propagation // across thread boundary (e.g. at::launch tasks) -TORCH_API void enableProfilerLegacy(const torch::profiler::impl::ProfilerConfig&); +TORCH_API void enableProfilerLegacy( + const torch::profiler::impl::ProfilerConfig&); using thread_event_lists = std::vector>; -TORCH_API thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions = c10::nullopt); +TORCH_API thread_event_lists disableProfilerLegacy( + c10::optional profilerDisableOptions = + c10::nullopt); // adds profiledEvents to the current thread local recorded events. Each event // will be marked with node ID given by fromNodeId. TORCH_API void addEventList(std::vector&& profiledEvents); // Writes profiled events to a stream. -TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); +TORCH_API void writeProfilerEventsToStream( + std::ostream& out, + const std::vector& events); // Usage: // { @@ -356,16 +365,16 @@ struct TORCH_API RecordProfile { RecordProfile(const std::string& filename); ~RecordProfile(); -private: + + private: void init(); std::unique_ptr file_; std::ostream& out_; void processEvents(const std::vector& events); }; -// A guard that enables the legacy profiler, taking in an optional callback to process -// the results -// Usage: +// A guard that enables the legacy profiler, taking in an optional callback to +// process the results Usage: // { // TLSLegacyProfilerGuard g([](thread_event_lists profilerResults) { // // process profilerResults @@ -386,7 +395,8 @@ struct TORCH_API TLSLegacyProfilerGuard { } ~TLSLegacyProfilerGuard() { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - thread_event_lists event_lists = disableProfilerLegacy(profilerDisableOptions_); + thread_event_lists event_lists = + disableProfilerLegacy(profilerDisableOptions_); if (cb_) { try { (*cb_)(event_lists); @@ -401,4 +411,6 @@ struct TORCH_API TLSLegacyProfilerGuard { const c10::optional profilerDisableOptions_; }; -}}} // namespace torch::autograd::profiler +} // namespace profiler +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index e4d4dfa8affa43..0301aa65a64e52 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include namespace py = pybind11; @@ -144,10 +144,10 @@ class CallTypeHelper final { // 2) Add a specialization of Config which defined key_t and cache_t. // 3) Add a specialization of ValueCache::store and ValueCache::load. -template +template struct Config; -template<> +template <> struct Config { using key_t = CodeLocation; using cache_t = ska::flat_hash_map; @@ -165,7 +165,7 @@ struct Config { static constexpr EventType event_type = EventType::PyCall; }; -template<> +template <> struct Config { using key_t = torch::profiler::impl::PyCFunction; using cache_t = ska::flat_hash_map; @@ -299,7 +299,8 @@ ExtraFields::args_t ValueCache::load( // TODO: Use re2. void ValueCache::trimPrefixes() { static auto prefixes = py::module::import("torch.profiler.python_tracer") - .attr("_prefix_regex")().cast>(); + .attr("_prefix_regex")() + .cast>(); for (auto& it : std::get(state_)) { std::string filename = it.second.filename_.str(); @@ -358,51 +359,50 @@ struct TraceKeyCacheState { // `PyEval_SetProfile`. struct ThreadLocalResults; struct TraceContext { - PyObject_HEAD - ThreadLocalResults* thread_local_results_; + PyObject_HEAD ThreadLocalResults* thread_local_results_; }; // CPython boilerplate to define `TraceContext` as a proper python object. static PyTypeObject TraceContextType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "TraceContext", /* tp_name */ - sizeof(TraceContext), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "Python tracer TLS", /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ - nullptr /* tp_free */ + PyVarObject_HEAD_INIT(nullptr, 0) "TraceContext", /* tp_name */ + sizeof(TraceContext), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, + /* tp_vectorcall_offset */ // NOLINT: modernize-use-nullptr + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "Python tracer TLS", /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ + nullptr /* tp_free */ }; // ============================================================================ @@ -468,7 +468,10 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PythonTracer(); void recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame); - void recordCCall(ThreadLocalResults& tls, PyFrameObject* frame, PyObject* arg); + void recordCCall( + ThreadLocalResults& tls, + PyFrameObject* frame, + PyObject* arg); torch::profiler::impl::RecordQueue* queue_; PyObject* module_call_code_; @@ -560,7 +563,8 @@ void PythonTracer::stop() { } void PythonTracer::clear() { - TORCH_CHECK(queue_ == nullptr, "Cannot clear state while PythonTracer is active."); + TORCH_CHECK( + queue_ == nullptr, "Cannot clear state while PythonTracer is active."); thread_local_results_.clear(); value_cache_ = ValueCache(); } @@ -643,9 +647,7 @@ class PostProcess { } template - void addExits( - AppendOnlyList& exits, - size_t python_tid) { + void addExits(AppendOnlyList& exits, size_t python_tid) { for (const auto i : exits) { get_state().exits_.push({time_converter_(i), python_tid}); } @@ -702,7 +704,7 @@ class PostProcess { template auto& get_state() { - return std::get(state_); + return std::get < E == EventType::PyCall ? 0 : 1 > (state_); } std::function time_converter_; @@ -738,10 +740,9 @@ std::vector> PythonTracer::getEvents( PostProcess post_process(time_converter, thread_local_results_, value_cache_); auto out = post_process.run(enters); - std::stable_sort( - out.begin(), out.end(), [](const auto& a, const auto& b) { - return a->start_time_ns_ < b->start_time_ns_; - }); + std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) { + return a->start_time_ns_ < b->start_time_ns_; + }); PythonIDVisitor id_visitor; for (auto& i : out) { diff --git a/torch/csrc/autograd/profiler_python.h b/torch/csrc/autograd/profiler_python.h index 947217efb4bfe2..d16bce8f3b6e58 100644 --- a/torch/csrc/autograd/profiler_python.h +++ b/torch/csrc/autograd/profiler_python.h @@ -1,7 +1,13 @@ #pragma once -namespace torch { namespace autograd { namespace profiler { namespace python_tracer { +namespace torch { +namespace autograd { +namespace profiler { +namespace python_tracer { void init(); -}}}} // namespace torch::autograd::profiler::python_tracer +} +} // namespace profiler +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_anomaly_mode.cpp b/torch/csrc/autograd/python_anomaly_mode.cpp index 3e831d544c0b1d..4cde43b6dae9aa 100644 --- a/torch/csrc/autograd/python_anomaly_mode.cpp +++ b/torch/csrc/autograd/python_anomaly_mode.cpp @@ -1,7 +1,7 @@ -#include #include #include #include +#include #include #include #include @@ -10,7 +10,8 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void PyAnomalyMetadata::store_stack() { pybind11::gil_scoped_acquire gil; @@ -54,20 +55,24 @@ void PyAnomalyMetadata::print_stack(const std::string& current_node_name) { throw python_error(); } const std::string parent_name(parent_name_char); - PyObject* parent_stack = PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY); + PyObject* parent_stack = + PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY); _print_stack(parent_stack, parent_name, true); - // get the parent of this node, if this node is a root, pyparent is simply null + // get the parent of this node, if this node is a root, pyparent is simply + // null pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY); } } -void PyAnomalyMetadata::assign_parent(const std::shared_ptr& parent_node) { +void PyAnomalyMetadata::assign_parent( + const std::shared_ptr& parent_node) { // assign the python object of parent_node in metadata["parent_"] // if parent_node is nullptr, then do nothing (it can mean that "parent_" key // is not in metadata) pybind11::gil_scoped_acquire gil; - if (!parent_node) return; + if (!parent_node) + return; THPObjectPtr parent_node_(functionToPyObject(parent_node)); if (!parent_node_) { @@ -78,11 +83,17 @@ void PyAnomalyMetadata::assign_parent(const std::shared_ptr& parent_node) } } -void _print_stack(PyObject* stack, const std::string& current_node_name, bool is_parent) { +void _print_stack( + PyObject* stack, + const std::string& current_node_name, + bool is_parent) { if (!stack) { - TORCH_WARN("Error detected in ", current_node_name, ". ", - "No forward pass information available. Enable detect anomaly " - "during forward pass for more information."); + TORCH_WARN( + "Error detected in ", + current_node_name, + ". ", + "No forward pass information available. Enable detect anomaly " + "during forward pass for more information."); return; } @@ -99,15 +110,22 @@ void _print_stack(PyObject* stack, const std::string& current_node_name, bool is } if (!is_parent) { - TORCH_WARN("Error detected in ", current_node_name, ". ", - "Traceback of forward call that caused the error:\n", - THPUtils_unpackString(msg.get())); + TORCH_WARN( + "Error detected in ", + current_node_name, + ". ", + "Traceback of forward call that caused the error:\n", + THPUtils_unpackString(msg.get())); } else { - TORCH_WARN("\n\n", - "Previous calculation was induced by ", current_node_name, ". " - "Traceback of forward call that induced the previous calculation:\n", - THPUtils_unpackString(msg.get())); + TORCH_WARN( + "\n\n", + "Previous calculation was induced by ", + current_node_name, + ". " + "Traceback of forward call that induced the previous calculation:\n", + THPUtils_unpackString(msg.get())); } } -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_anomaly_mode.h b/torch/csrc/autograd/python_anomaly_mode.h index f0bdd5e68028f1..d58efd34ebaa45 100644 --- a/torch/csrc/autograd/python_anomaly_mode.h +++ b/torch/csrc/autograd/python_anomaly_mode.h @@ -5,7 +5,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct PyAnomalyMetadata : public AnomalyMetadata { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-writable-strings) @@ -32,9 +33,13 @@ struct PyAnomalyMetadata : public AnomalyMetadata { return dict_; } -private: + private: PyObject* dict_; }; -void _print_stack(PyObject* trace_stack, const std::string& current_node_name, bool is_parent); +void _print_stack( + PyObject* trace_stack, + const std::string& current_node_name, + bool is_parent); -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_autograd.h b/torch/csrc/autograd/python_autograd.h index 04771dca2af581..aca84af9c3f8b0 100644 --- a/torch/csrc/autograd/python_autograd.h +++ b/torch/csrc/autograd/python_autograd.h @@ -1,17 +1,19 @@ #ifndef THP_AUTOGRAD_H #define THP_AUTOGRAD_H -PyObject * THPAutograd_initExtension(PyObject *_unused, PyObject *unused); +PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused); void THPAutograd_initFunctions(); -namespace torch { namespace autograd { +namespace torch { +namespace autograd { PyMethodDef* python_functions(); -}} +} +} // namespace torch +#include #include #include -#include #endif diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 3ae46289fb8e4e..3c4805fcc18bb1 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -2,29 +2,32 @@ #include #include -#include #include +#include #include #include +#include +#include +#include +#include #include -#include #include -#include -#include +#include #include #include -#include -#include using namespace torch::autograd; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { namespace { -PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs) -{ +PyObject* THPCppFunction_call( + PyObject* self, + PyObject* args, + PyObject* kwargs) { if (kwargs && PyDict_Size(kwargs) != 0) { return PyErr_Format(PyExc_TypeError, "keyword arguments are not supported"); } @@ -32,8 +35,11 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs) int num_inputs = PyTuple_GET_SIZE(args); int num_inputs_required = ((THPCppFunction*)self)->cdata->num_inputs(); if (num_inputs != num_inputs_required) { - return PyErr_Format(PyExc_TypeError, "expected %d arguments, got %d instead", - num_inputs_required, num_inputs); + return PyErr_Format( + PyExc_TypeError, + "expected %d arguments, got %d instead", + num_inputs_required, + num_inputs); } variable_list vars(num_inputs); for (int i = 0; i != num_inputs; ++i) { @@ -68,8 +74,7 @@ PyObject* THPCppFunction_call(PyObject* self, PyObject* args, PyObject *kwargs) return tuple.release(); } -int THPCppFunction_traverse(PyObject* self, visitproc visit, void *arg) -{ +int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) { auto& fn = *((THPCppFunction*)self)->cdata; for (const auto& hook : fn.pre_hooks()) { if (auto pyhook = dynamic_cast(hook.get())) { @@ -84,8 +89,7 @@ int THPCppFunction_traverse(PyObject* self, visitproc visit, void *arg) return 0; } -int THPCppFunction_clear(PyObject* self) -{ +int THPCppFunction_clear(PyObject* self) { auto f = (THPCppFunction*)self; // Remove the weak ref of the c++ object if it exist if (f->cdata) { @@ -95,8 +99,7 @@ int THPCppFunction_clear(PyObject* self) return 0; } -void THPCppFunction_dealloc(PyObject* self) -{ +void THPCppFunction_dealloc(PyObject* self) { THPCppFunction_clear(self); ((THPCppFunction*)self)->cdata.~shared_ptr(); Py_TYPE(self)->tp_free(self); @@ -104,83 +107,86 @@ void THPCppFunction_dealloc(PyObject* self) } // namespace -PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) -{ +PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook) { const auto num_next = self->cdata->num_outputs(); THPObjectPtr py_functions(PyTuple_New(num_next)); - if (!py_functions) return nullptr; + if (!py_functions) + return nullptr; for (const auto i : c10::irange(num_next)) { auto& c_tuple = self->cdata->next_edge(i); THPObjectPtr tuple(PyTuple_New(2)); - if (!tuple) return nullptr; - PyObject *py_fn = functionToPyObject(c_tuple.function); - if (!py_fn) return nullptr; + if (!tuple) + return nullptr; + PyObject* py_fn = functionToPyObject(c_tuple.function); + if (!py_fn) + return nullptr; PyTuple_SET_ITEM(tuple.get(), 0, py_fn); - PyObject *py_idx = THPUtils_packUInt32(c_tuple.input_nr); - if (!py_idx) return nullptr; + PyObject* py_idx = THPUtils_packUInt32(c_tuple.input_nr); + if (!py_idx) + return nullptr; PyTuple_SET_ITEM(tuple.get(), 1, py_idx); PyTuple_SET_ITEM(py_functions.get(), i, tuple.release()); } return py_functions.release(); } -PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused) -{ - auto metadata = static_cast(self->cdata->metadata())->dict(); +PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused) { + auto metadata = + static_cast(self->cdata->metadata())->dict(); Py_INCREF(metadata); return metadata; } -PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void *unused) { +PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* unused) { Py_RETURN_TRUE; } -PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) -{ +PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var) { if (!THPVariable_Check(_var)) { - return PyErr_Format(PyExc_TypeError, "_register_hook_dict expected a variable"); + return PyErr_Format( + PyExc_TypeError, "_register_hook_dict expected a variable"); } auto var = (THPVariable*)_var; auto& fn = *((THPCppFunction*)self)->cdata; - std::unique_ptr hook( - new PyFunctionPreHook(var->backward_hooks, THPVariable_Unpack(var).output_nr())); + std::unique_ptr hook(new PyFunctionPreHook( + var->backward_hooks, THPVariable_Unpack(var).output_nr())); fn.add_pre_hook(std::move(hook)); Py_RETURN_NONE; } -PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) -{ +PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) { auto& fn = *((THPCppFunction*)self)->cdata; return registerFunctionHook(fn, hook); } -PyObject* THPCppFunction_name(PyObject* self, PyObject *noargs) { +PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs) { auto& fn = *((THPCppFunction*)self)->cdata; return THPUtils_packString(fn.name()); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyMethodDef default_methods[] = { - THP_FUNCTION_DEFAULT_METHODS, - {nullptr} -}; + THP_FUNCTION_DEFAULT_METHODS, + {nullptr}}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyGetSetDef default_properties[] = { - THP_FUNCTION_DEFAULT_PROPERTIES, - {nullptr} -}; - -PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name, - PyGetSetDef* function_properties, PyMethodDef* function_methods) -{ + THP_FUNCTION_DEFAULT_PROPERTIES, + {nullptr}}; + +PyTypeObject* _initFunctionPyTypeObject( + PyTypeObject& type, + const char* name, + PyGetSetDef* function_properties, + PyMethodDef* function_methods) { type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC; type.tp_name = name; type.tp_basicsize = sizeof(THPCppFunction); type.tp_call = THPCppFunction_call; type.tp_methods = function_methods ? function_methods : default_methods; - type.tp_getset = function_properties ? function_properties : default_properties; + type.tp_getset = + function_properties ? function_properties : default_properties; type.tp_dealloc = THPCppFunction_dealloc; type.tp_traverse = THPCppFunction_traverse; type.tp_clear = THPCppFunction_clear; @@ -202,8 +208,7 @@ struct DefaultFunctionType { PyTypeObject type; }; -PyObject* functionToPyObject(const std::shared_ptr& cdata) -{ +PyObject* functionToPyObject(const std::shared_ptr& cdata) { static DefaultFunctionType default_type; if (!cdata) { @@ -230,7 +235,8 @@ PyObject* functionToPyObject(const std::shared_ptr& cdata) } THPObjectPtr obj(type->tp_alloc(type, 0)); - if (!obj) return nullptr; + if (!obj) + return nullptr; THPCppFunction* f = (THPCppFunction*)obj.get(); new (&f->cdata) std::shared_ptr(cdata); @@ -241,14 +247,12 @@ PyObject* functionToPyObject(const std::shared_ptr& cdata) return cdata->pyobj(); } -void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) -{ +void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) { Py_INCREF((PyObject*)pytype); cpp_function_types[std::type_index(type)] = THPObjectPtr((PyObject*)pytype); } -PyObject* registerFunctionHook(Node& fn, PyObject* hook) -{ +PyObject* registerFunctionHook(Node& fn, PyObject* hook) { PyObject* dict = Py_None; for (const auto& hook : fn.post_hooks()) { if (auto pyhook = dynamic_cast(hook.get())) { @@ -257,10 +261,14 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) } } - THPObjectPtr register_fn(PyObject_GetAttrString(THPFunctionClass, "_register_hook")); - if (!register_fn) return nullptr; - THPObjectPtr res(PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr)); - if (!res) return nullptr; + THPObjectPtr register_fn( + PyObject_GetAttrString(THPFunctionClass, "_register_hook")); + if (!register_fn) + return nullptr; + THPObjectPtr res( + PyObject_CallFunctionObjArgs(register_fn.get(), dict, hook, nullptr)); + if (!res) + return nullptr; if (dict == Py_None) { dict = PyTuple_GET_ITEM(res.get(), 0); @@ -273,4 +281,5 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) return handle; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h index fb1ba9c1a278c3..9c8cf19483fbb9 100644 --- a/torch/csrc/autograd/python_cpp_function.h +++ b/torch/csrc/autograd/python_cpp_function.h @@ -4,22 +4,25 @@ #include #include +#include #include #include -#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct THPCppFunction { - PyObject_HEAD - std::shared_ptr cdata; + PyObject_HEAD std::shared_ptr cdata; }; -template -PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) -{ +template +PyObject* CppFunction_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { THPObjectPtr obj(type->tp_alloc(type, 0)); - if (!obj) return nullptr; + if (!obj) + return nullptr; THPCppFunction* f = (THPCppFunction*)obj.get(); HANDLE_TH_ERRORS new (&f->cdata) std::shared_ptr(Ctor()(args)); @@ -30,37 +33,60 @@ PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) return obj.release(); } -#define THP_FUNCTION_DEFAULT_METHODS \ - {(char*)"_register_hook_dict", THPCppFunction_register_hook_dict, METH_O, nullptr}, \ - {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \ - {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr} +#define THP_FUNCTION_DEFAULT_METHODS \ + {(char*)"_register_hook_dict", \ + THPCppFunction_register_hook_dict, \ + METH_O, \ + nullptr}, \ + {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \ + { \ + (char*)"name", THPCppFunction_name, METH_NOARGS, nullptr \ + } -#define THP_FUNCTION_DEFAULT_PROPERTIES \ - {(char*)"next_functions", (getter)THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \ - {(char*)"requires_grad", (getter)THPCppFunction_requires_grad, nullptr, nullptr, nullptr}, \ - {(char*)"metadata", (getter)THPCppFunction_metadata, nullptr, nullptr, nullptr} +#define THP_FUNCTION_DEFAULT_PROPERTIES \ + {(char*)"next_functions", \ + (getter)THPCppFunction_next_functions, \ + nullptr, \ + nullptr, \ + nullptr}, \ + {(char*)"requires_grad", \ + (getter)THPCppFunction_requires_grad, \ + nullptr, \ + nullptr, \ + nullptr}, \ + { \ + (char*)"metadata", (getter)THPCppFunction_metadata, nullptr, nullptr, \ + nullptr \ + } PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook); -PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused); -PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void *_unused); +PyObject* THPCppFunction_metadata(THPCppFunction* self, void* _unused); +PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void* _unused); PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var); PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook); -PyObject* THPCppFunction_name(PyObject* self, PyObject *noargs); +PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs); -PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name, - PyGetSetDef* function_properties, PyMethodDef* function_methods); +PyTypeObject* _initFunctionPyTypeObject( + PyTypeObject& type, + const char* name, + PyGetSetDef* function_properties, + PyMethodDef* function_methods); PyObject* registerFunctionHook(Node& fn, PyObject* hook); -template -PyTypeObject* createForwardFunctionPyTypeObject(PyTypeObject& type, const char* name, - PyGetSetDef* function_properties=nullptr, PyMethodDef* function_methods=nullptr) -{ +template +PyTypeObject* createForwardFunctionPyTypeObject( + PyTypeObject& type, + const char* name, + PyGetSetDef* function_properties = nullptr, + PyMethodDef* function_methods = nullptr) { type.tp_new = &CppFunction_pynew; - return _initFunctionPyTypeObject(type, name, function_properties, function_methods); + return _initFunctionPyTypeObject( + type, name, function_properties, function_methods); } void registerCppFunction(const std::type_info& type, PyTypeObject* pytype); PyObject* functionToPyObject(const std::shared_ptr& cdata); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 6d6c54b1bd3cd4..203ccea416f2ae 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -1,5 +1,9 @@ #include +#include +#include +#include +#include #include #include #include @@ -7,30 +11,28 @@ #include #include #include -#include #include +#include #include -#include -#include -#include -#include #ifndef _WIN32 #include #endif -#include #include // for unique_ptr +#include using namespace torch::autograd; struct THPEngine { - PyObject_HEAD + PyObject_HEAD }; static bool _reinitialize_engine = false; -namespace torch { namespace autograd { namespace python { +namespace torch { +namespace autograd { +namespace python { PythonEngine::PythonEngine() = default; @@ -59,14 +61,17 @@ PythonEngine::~PythonEngine() { #define IS_PYTHON_3_9_PLUS #endif -void PythonEngine::thread_init(int device, const std::shared_ptr& ready_queue, bool should_increment) { +void PythonEngine::thread_init( + int device, + const std::shared_ptr& ready_queue, + bool should_increment) { // Increment thread usage count before acquiring the GIL if (should_increment) { increment_non_reentrant_thread_count(); } - // Create a PyThreadState, but release the GIL. This lets pybind11::gil_scoped_acquire calls - // inside thread_main acquire the GIL without having to create a new - // PyThreadState each time. + // Create a PyThreadState, but release the GIL. This lets + // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL + // without having to create a new PyThreadState each time. #if defined(IS_PYTHON_3_9_PLUS) || defined(USE_DEPLOY) auto gil = std::make_unique(); #else @@ -81,11 +86,14 @@ void PythonEngine::thread_init(int device, const std::shared_ptr& re } #if defined(IS_PYTHON_3_9_PLUS) || defined(USE_DEPLOY) - // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if runtime is finalizing + // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if + // runtime is finalizing if (!Py_IsInitialized()) { no_gil.disarm(); - // TODO: call disarm rather than leak gil_scoped_acquired once PyThreadState_Clear can safely be called from finalize - // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct PyThreadState, so avoid use-after-free here. + // TODO: call disarm rather than leak gil_scoped_acquired once + // PyThreadState_Clear can safely be called from finalize NOTE: deploy.cpp + // calls `PyInterpreterState_Delete` to destruct PyThreadState, so avoid + // use-after-free here. gil.release(); } #endif @@ -106,7 +114,8 @@ std::unique_ptr PythonEngine::make_anomaly_metadata() { return std::unique_ptr(new PyAnomalyMetadata()); } -std::unique_ptr PythonEngine::get_default_saved_variable_hooks() { +std::unique_ptr PythonEngine:: + get_default_saved_variable_hooks() { return PyDefaultSavedVariableHooks::get_hooks(); } @@ -117,12 +126,15 @@ variable_list PythonEngine::execute( bool create_graph, bool accumulate_grad, const edge_list& outputs) { - TORCH_CHECK(!PyGILState_Check(), "The autograd engine was called while holding the GIL. If you are using the C++ " - "API, the autograd engine is an expensive operation that does not require the " - "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'" - ". If you are not using the C++ API, please report a bug to the pytorch team.") + TORCH_CHECK( + !PyGILState_Check(), + "The autograd engine was called while holding the GIL. If you are using the C++ " + "API, the autograd engine is an expensive operation that does not require the " + "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'" + ". If you are not using the C++ API, please report a bug to the pytorch team.") try { - return Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs); + return Engine::execute( + roots, inputs, keep_graph, create_graph, accumulate_grad, outputs); } catch (python_error& e) { e.restore(); throw; @@ -134,7 +146,8 @@ c10::intrusive_ptr PythonEngine::execute_with_graph_task( std::shared_ptr graph_root, InputBuffer&& input_buffer) { try { - return Engine::execute_with_graph_task(graph_task, graph_root, std::move(input_buffer)); + return Engine::execute_with_graph_task( + graph_task, graph_root, std::move(input_buffer)); } catch (python_error& e) { pybind11::gil_scoped_acquire gil; if (!PyErr_Occurred()) { @@ -144,41 +157,73 @@ c10::intrusive_ptr PythonEngine::execute_with_graph_task( throw; } } -}}} // namespace torch::autograd::python +} // namespace python +} // namespace autograd +} // namespace torch -PyObject *THPEngineClass = nullptr; +PyObject* THPEngineClass = nullptr; // Implementation of torch._C._EngineBase.run_backward -PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs) -{ +PyObject* THPEngine_run_backward( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - PyObject *tensors = nullptr; - PyObject *grad_tensors = nullptr; + PyObject* tensors = nullptr; + PyObject* grad_tensors = nullptr; unsigned char keep_graph = 0; unsigned char create_graph = 0; - PyObject *inputs = nullptr; + PyObject* inputs = nullptr; unsigned char allow_unreachable = 0; - unsigned char accumulate_grad = 0; // Indicate whether to accumulate grad into leaf Tensors or capture - const char *accepted_kwargs[] = { // NOLINT - "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs", - "allow_unreachable", "accumulate_grad", nullptr - }; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs, - &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad)) + unsigned char accumulate_grad = + 0; // Indicate whether to accumulate grad into leaf Tensors or capture + const char* accepted_kwargs[] = {// NOLINT + "tensors", + "grad_tensors", + "keep_graph", + "create_graph", + "inputs", + "allow_unreachable", + "accumulate_grad", + nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "OObb|Obb", + (char**)accepted_kwargs, + &tensors, + &grad_tensors, + &keep_graph, + &create_graph, + &inputs, + &allow_unreachable, + &accumulate_grad)) return nullptr; - THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to " - "be a tuple, but got %s", THPUtils_typename(tensors)); - THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is " - "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors)); + THPUtils_assert( + PyTuple_Check(tensors), + "tensors argument is expected to " + "be a tuple, but got %s", + THPUtils_typename(tensors)); + THPUtils_assert( + PyTuple_Check(grad_tensors), + "grad_tensors argument is " + "expected to be a tuple, but got %s", + THPUtils_typename(grad_tensors)); Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors); Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors); - THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld " - "gradients", num_tensors, num_gradients); - - // The user either called autograd.backward(...) or autograd.grad(...) to get here + THPUtils_assert( + num_tensors == num_gradients, + "got %ld tensors and %ld " + "gradients", + num_tensors, + num_gradients); + + // The user either called autograd.backward(...) or autograd.grad(...) to get + // here bool backward_api_called = accumulate_grad; - TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0, + TORCH_CHECK( + !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0, "backward() called inside torch.vmap. This is not supported, " "please call backward() outside torch.vmap or instead use " "torch.autograd.grad inside torch.vmap"); @@ -187,37 +232,49 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg roots.reserve(num_tensors); variable_list grads; grads.reserve(num_tensors); - for(const auto i : c10::irange(num_tensors)) { - PyObject *_tensor = PyTuple_GET_ITEM(tensors, i); - THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors " - "tuple is not a Tensor", i); + for (const auto i : c10::irange(num_tensors)) { + PyObject* _tensor = PyTuple_GET_ITEM(tensors, i); + THPUtils_assert( + THPVariable_Check(_tensor), + "element %d of tensors " + "tuple is not a Tensor", + i); const auto& variable = THPVariable_Unpack(_tensor); - TORCH_CHECK(!isBatchedTensor(variable), + TORCH_CHECK( + !isBatchedTensor(variable), "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ", "torch.vmap. We do not support the case where any outputs are ", - "vmapped tensors (output ", i, " is being vmapped over). Please " + "vmapped tensors (output ", + i, + " is being vmapped over). Please " "call autograd.grad() outside torch.vmap or file a bug report " "with your use case.") auto gradient_edge = torch::autograd::impl::gradient_edge(variable); - THPUtils_assert(gradient_edge.function, - "element %d of tensors does not require grad and does not have a grad_fn", i); + THPUtils_assert( + gradient_edge.function, + "element %d of tensors does not require grad and does not have a grad_fn", + i); roots.push_back(std::move(gradient_edge)); - PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i); + PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i); if (THPVariable_Check(grad)) { const Variable& grad_var = THPVariable_Unpack(grad); if (grad_var.has_names()) { TORCH_WARN( - "Autograd was passed a named grad tensor with dims ", grad_var.names(), + "Autograd was passed a named grad tensor with dims ", + grad_var.names(), ". Autograd does not yet support named tensor semantics, so all names ", "will be ignored. In practice all computed gradients will still be correct " "according to regular tensor semantics."); } grads.push_back(grad_var); } else { - THPUtils_assert(grad == Py_None, - "element %d of gradients tuple is not a Tensor or None", i); - THPUtils_assert(!variable.requires_grad(), + THPUtils_assert( + grad == Py_None, + "element %d of gradients tuple is not a Tensor or None", + i); + THPUtils_assert( + !variable.requires_grad(), "element %d of gradients tuple is None, but the corresponding Tensor requires grad"); } } @@ -227,14 +284,19 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg int num_inputs = PyTuple_GET_SIZE(inputs); output_edges.reserve(num_inputs); for (const auto i : c10::irange(num_inputs)) { - PyObject *input = PyTuple_GET_ITEM(inputs, i); - THPUtils_assert(THPVariable_Check(input), - "all inputs have to be Tensors, but got %s", THPUtils_typename(input)); + PyObject* input = PyTuple_GET_ITEM(inputs, i); + THPUtils_assert( + THPVariable_Check(input), + "all inputs have to be Tensors, but got %s", + THPUtils_typename(input)); const auto& tensor = THPVariable_Unpack(input); - TORCH_CHECK(!isBatchedTensor(tensor), + TORCH_CHECK( + !isBatchedTensor(tensor), "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ", "torch.vmap. We do not support the case where any inputs are ", - "vmapped tensors (input ", i, " is being vmapped over). Please " + "vmapped tensors (input ", + i, + " is being vmapped over). Please " "call autograd.grad() outside torch.vmap or file a bug report " "with your use case.") const auto output_nr = tensor.output_nr(); @@ -245,14 +307,16 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg if (accumulate_grad) { tensor.retain_grad(); } - THPUtils_assert(tensor.requires_grad(), + THPUtils_assert( + tensor.requires_grad(), "One of the differentiated Tensors does not require grad"); if (!grad_fn) { // NOTE [ Autograd Unreachable Input ] - // Since input has no grad_accumulator, its guaranteed to be unreachable. - // We initialize an edge pointing to a non-nullptr Node so nodes in the graph - // (e.g., mul when an operand is scalar) that have edges pointing to nullptr - // don't get erroneously assigned `needed = True` in exec_info. + // Since input has no grad_accumulator, its guaranteed to be + // unreachable. We initialize an edge pointing to a non-nullptr Node so + // nodes in the graph (e.g., mul when an operand is scalar) that have + // edges pointing to nullptr don't get erroneously assigned `needed = + // True` in exec_info. output_edges.emplace_back(std::make_shared(), 0); } else { output_edges.emplace_back(grad_fn, output_nr); @@ -264,18 +328,22 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg { pybind11::gil_scoped_release no_gil; auto& engine = python::PythonEngine::get_python_engine(); - outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges); + outputs = engine.execute( + roots, grads, keep_graph, create_graph, accumulate_grad, output_edges); } if (!backward_api_called && inputs != nullptr) { int num_inputs = PyTuple_GET_SIZE(inputs); - THPObjectPtr py_outputs {PyTuple_New(num_inputs)}; - if (!py_outputs) return nullptr; - for(const auto i : c10::irange(num_inputs)) { - THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the " - "differentiated Tensors appears to not have been used " - "in the graph. Set allow_unused=True if this is the " - "desired behavior."); + THPObjectPtr py_outputs{PyTuple_New(num_inputs)}; + if (!py_outputs) + return nullptr; + for (const auto i : c10::irange(num_inputs)) { + THPUtils_assert( + allow_unreachable || outputs[i].defined(), + "One of the " + "differentiated Tensors appears to not have been used " + "in the graph. Set allow_unused=True if this is the " + "desired behavior."); PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i])); } return py_outputs.release(); @@ -285,24 +353,28 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg END_HANDLE_TH_ERRORS } -PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) { +PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) { HANDLE_TH_ERRORS auto& engine = python::PythonEngine::get_python_engine(); - std::shared_ptr callback(_callback, [](PyObject *obj) { pybind11::gil_scoped_acquire gil; Py_DECREF(obj); }); + std::shared_ptr callback(_callback, [](PyObject* obj) { + pybind11::gil_scoped_acquire gil; + Py_DECREF(obj); + }); Py_INCREF(_callback); engine.queue_callback([callback]() { pybind11::gil_scoped_acquire gil; - THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)}; - if (!result) throw python_error(); + THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)}; + if (!result) + throw python_error(); }); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject* THPEngine_is_checkpoint_valid(PyObject *self, PyObject *noargs) { +PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS auto& engine = python::PythonEngine::get_python_engine(); - if(engine.is_checkpoint_valid()) { + if (engine.is_checkpoint_valid()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -310,69 +382,68 @@ PyObject* THPEngine_is_checkpoint_valid(PyObject *self, PyObject *noargs) { END_HANDLE_TH_ERRORS } -PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) { return type->tp_alloc(type, 0); } // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMethodDef THPEngine_methods[] = { - {(char*)"run_backward", - castPyCFunctionWithKeywords(THPEngine_run_backward), - METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr}, - {(char*)"is_checkpoint_valid", THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr}, - {nullptr} -}; - + {(char*)"run_backward", + castPyCFunctionWithKeywords(THPEngine_run_backward), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr}, + {(char*)"is_checkpoint_valid", + THPEngine_is_checkpoint_valid, + METH_NOARGS, + nullptr}, + {nullptr}}; PyTypeObject THPEngineType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._EngineBase", /* tp_name */ - sizeof(THPEngine), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPEngine_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPEngine_new /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */ + sizeof(THPEngine), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPEngine_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPEngine_new /* tp_new */ }; static void child_atfork() { _reinitialize_engine = true; } -bool THPEngine_initModule(PyObject *module) -{ +bool THPEngine_initModule(PyObject* module) { #ifndef _WIN32 if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) { throw std::runtime_error("unable to set pthread_atfork handler"); @@ -381,7 +452,7 @@ bool THPEngine_initModule(PyObject *module) if (PyType_Ready(&THPEngineType) < 0) return false; Py_INCREF(&THPEngineType); - PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType); + PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType); set_default_engine_stub(python::PythonEngine::get_python_engine); return true; } diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h index 7f877874fd297b..d7c060330dc790 100644 --- a/torch/csrc/autograd/python_engine.h +++ b/torch/csrc/autograd/python_engine.h @@ -2,17 +2,20 @@ #include -#include #include +#include -bool THPEngine_initModule(PyObject *module); +bool THPEngine_initModule(PyObject* module); -namespace torch { namespace autograd { namespace python { +namespace torch { +namespace autograd { +namespace python { struct PythonEngine : public Engine { static Engine& get_python_engine(); ~PythonEngine() override; - void thread_init(int device, + void thread_init( + int device, const std::shared_ptr& ready_queue, bool should_increment) override; void thread_on_exception( @@ -33,9 +36,13 @@ struct PythonEngine : public Engine { InputBuffer&& input_buffer) override; std::unique_ptr make_anomaly_metadata() override; - std::unique_ptr get_default_saved_variable_hooks() override; - private: - PythonEngine(); + std::unique_ptr get_default_saved_variable_hooks() + override; + + private: + PythonEngine(); }; -}}} // namespace torch::autograd::python +} // namespace python +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_enum_tag.h b/torch/csrc/autograd/python_enum_tag.h index 7a95cd98229922..51ef784be3c81a 100644 --- a/torch/csrc/autograd/python_enum_tag.h +++ b/torch/csrc/autograd/python_enum_tag.h @@ -3,6 +3,7 @@ #include namespace torch { - namespace autograd { - void initEnumTag(PyObject* module); -}} // namespace torch::autograd +namespace autograd { +void initEnumTag(PyObject* module); +} +} // namespace torch diff --git a/torch/csrc/autograd/python_fft_functions.h b/torch/csrc/autograd/python_fft_functions.h index 18b664dacf1265..6f0796473f793a 100644 --- a/torch/csrc/autograd/python_fft_functions.h +++ b/torch/csrc/autograd/python_fft_functions.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void initFFTFunctions(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 43911fe18b993f..492585396ae227 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -1,29 +1,29 @@ #include -#include -#include #include #include #include #include +#include +#include +#include +#include +#include #include -#include #include #include #include +#include +#include #include #include #include -#include #include #include -#include #include +#include #include -#include -#include -#include #include #include @@ -41,10 +41,13 @@ using namespace torch::autograd; using namespace torch::jit; using at::Tensor; -PyObject *THPFunctionClass = nullptr; +PyObject* THPFunctionClass = nullptr; -#define THPFunction_assert(condition, ...) \ - if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } +#define THPFunction_assert(condition, ...) \ + if (!(condition)) { \ + THPUtils_setError(__VA_ARGS__); \ + throw python_error(); \ + } // Anonymous namespace for helpful functions used in this file namespace { @@ -65,13 +68,14 @@ void throw_python_error() { throw err; } -} +} // namespace -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -// NOTE: this function is written in a way that assumes it's only called for backward; -// it's used by engine.cpp. This is responsible for forwarding a call from -// C++'s Node::apply to a Python method "apply". +// NOTE: this function is written in a way that assumes it's only called for +// backward; it's used by engine.cpp. This is responsible for forwarding a call +// from C++'s Node::apply to a Python method "apply". auto PyNode::apply(variable_list&& inputs) -> variable_list { pybind11::gil_scoped_acquire gil; at::OptionalDeviceGuard _device_guard; @@ -80,7 +84,8 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { // Massage a C++ variable_list into a Python arguments tuple auto num_inputs = inputs.size(); THPObjectPtr pyInputs(PyTuple_New(num_inputs)); - if (!pyInputs) throw_python_error(); + if (!pyInputs) + throw_python_error(); auto& output_info = py_fn->output_info; for (const auto i : c10::irange(num_inputs)) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -90,14 +95,17 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { } else { input = THPVariable_Wrap(output_info[i].zeros(_device_guard)); } - if (!input) throw_python_error(); + if (!input) + throw_python_error(); PyTuple_SET_ITEM(pyInputs.get(), i, input); } THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply")); - if (!apply_fn) throw_python_error(); + if (!apply_fn) + throw_python_error(); THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get())); - if (!r) throw_python_error(); + if (!r) + throw_python_error(); ensure_tuple(r); auto& is_variable_input = py_fn->is_variable_input; @@ -113,7 +121,8 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { if (all_none) { num_outputs = num_forward_inputs; r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs); - if (!r) throw_python_error(); + if (!r) + throw_python_error(); } } @@ -121,7 +130,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { if (num_outputs != num_forward_inputs) { std::string msg("function "); msg += name() + " returned an incorrect number of gradients (expected "; - msg += std::to_string(num_forward_inputs) + ", got " ; + msg += std::to_string(num_forward_inputs) + ", got "; msg += std::to_string(num_outputs) + ")"; throw std::runtime_error(msg); } @@ -136,7 +145,8 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { if (output != Py_None) { std::string msg("function "); msg += name() + " returned a gradient different than None at position "; - msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable"; + msg += std::to_string(i + 1) + + ", but the corresponding forward input was not a Variable"; throw std::runtime_error(msg); } continue; @@ -159,10 +169,13 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { auto PyNode::is_traceable() -> bool { pybind11::gil_scoped_acquire gil; - THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")}; - if (!forward_class) throw_python_error(); - THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")}; - if (!traceable_py_bool) throw_python_error(); + THPObjectPtr forward_class{PyObject_GetAttrString(obj, "_forward_cls")}; + if (!forward_class) + throw_python_error(); + THPObjectPtr traceable_py_bool{ + PyObject_GetAttrString(forward_class, "is_traceable")}; + if (!traceable_py_bool) + throw_python_error(); return traceable_py_bool == Py_True; } @@ -173,7 +186,7 @@ auto PyNode::release_variables() -> void { // we just leak the saved objects. if (Py_IsInitialized()) { pybind11::gil_scoped_acquire gil; - auto f = (THPFunction*) obj; + auto f = (THPFunction*)obj; f->saved_variables.clear(); f->has_freed_buffers = 1; } @@ -181,18 +194,19 @@ auto PyNode::release_variables() -> void { auto PyNode::name() const -> std::string { pybind11::gil_scoped_acquire gil; - auto f = (THPFunction*) obj; + auto f = (THPFunction*)obj; auto name = std::string(Py_TYPE(f)->tp_name); return name; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch // Traverse and clear are required for supporting Python's GC cycle handling. -static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) -{ +static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) { // cdata could be null if the PyNode has already gone out of scope - // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn only). + // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn + // only). // // TODO: I'm not really sure if we're actually obligated to traverse PyObject // that is stored in PyNode, since we don't really own that C++ object. @@ -215,8 +229,7 @@ static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) return 0; } -static int THPFunction_clear(THPFunction *self) -{ +static int THPFunction_clear(THPFunction* self) { // Note that the cdata might not be expired yet in the case where this // object is part of a cycle and the GC happens to tp_clear this PyObject // before the other ones that trigger the de-allocation of the cdata @@ -236,8 +249,7 @@ static int THPFunction_clear(THPFunction *self) return 0; } -static void THPFunction_dealloc(THPFunction* self) -{ +static void THPFunction_dealloc(THPFunction* self) { // Why is this guaranteed to be true? Suppose that self->cdata is non-null // (otherwise the condition is trivially true). Then there is a PyNode // which contains an owning reference to this object. But we are only @@ -262,10 +274,13 @@ static void THPFunction_dealloc(THPFunction* self) Py_TYPE(self)->tp_free((PyObject*)self); } -PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +PyObject* THPFunction_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { PyObject* obj = type->tp_alloc(type, 0); - if (!obj) return nullptr; + if (!obj) + return nullptr; // Python zero-initializes the object memory, so there's no need to initialize // most fields THPFunction* self = (THPFunction*)obj; @@ -285,21 +300,27 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) // Bump the counters of all recorded dirty input tensors, adding each of them // into dirty_inputs. Also does some sanity checking. -static std::unordered_set _mark_dirty(THPFunction *self) -{ +static std::unordered_set _mark_dirty(THPFunction* self) { // Increase versions of modified tensors std::unordered_set dirty_inputs; - if (!self->dirty_tensors) return dirty_inputs; + if (!self->dirty_tensors) + return dirty_inputs; - THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd " + THPFunction_assert( + PyTuple_Check(self->dirty_tensors), + "autograd " "internal error: dirty_tensors attribute is expected to be a tuple " - "but is %s", THPUtils_typename(self->dirty_tensors)); + "but is %s", + THPUtils_typename(self->dirty_tensors)); Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors); dirty_inputs.reserve(num_dirty); - for(const auto i : c10::irange(num_dirty)) { - PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i); - THPFunction_assert(THPVariable_Check(obj), "mark_dirty can " - "only accept variables, but argument %d is of type %s", i, + for (const auto i : c10::irange(num_dirty)) { + PyObject* obj = PyTuple_GET_ITEM(self->dirty_tensors, i); + THPFunction_assert( + THPVariable_Check(obj), + "mark_dirty can " + "only accept variables, but argument %d is of type %s", + i, THPUtils_typename(obj)); const auto& tensor = THPVariable_Unpack(obj); @@ -311,7 +332,8 @@ static std::unordered_set _mark_dirty(THPFunction *self) return dirty_inputs; } -static std::unordered_set _parse_non_differentiable(THPFunction *self); +static std::unordered_set _parse_non_differentiable( + THPFunction* self); // Given a Python tuple of raw output tensors (raw_output), set each of // the corresponding entries in a different Python tuple (outputs) with @@ -324,9 +346,13 @@ static std::unordered_set _parse_non_differentiable(THPFunction // the set of dirty tensors (dirty_inputs) is used to figure out what to // do in this case. After this method is run, t2var is extended with // mappings for output tensors as well. -static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *self, - const variable_list &input_vars, PyObject *raw_output, PyObject *outputs, bool is_executable) -{ +static void _wrap_outputs( + const std::shared_ptr& cdata, + THPFunction* self, + const variable_list& input_vars, + PyObject* raw_output, + PyObject* outputs, + bool is_executable) { auto cdata_if_executable = is_executable ? cdata : nullptr; Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); if (is_executable) { @@ -349,14 +375,17 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel } } - _jvp_fn_t jvp_user_function = [self](variable_list inputs, variable_list grad_inputs) { + _jvp_fn_t jvp_user_function = [self]( + variable_list inputs, + variable_list grad_inputs) { pybind11::gil_scoped_acquire gil; // Massage a C++ variable_list into a Python arguments tuple // Making sure to introduce the proper None for non-Tensor inputs auto num_inputs = self->is_variable_input.size(); THPObjectPtr pyInputs(PyTuple_New(num_inputs)); - if (!pyInputs) throw_python_error(); + if (!pyInputs) + throw_python_error(); int64_t variable_idx = 0; for (const auto i : c10::irange(num_inputs)) { PyObject* input = nullptr; @@ -377,10 +406,13 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel PyTuple_SET_ITEM(pyInputs.get(), i, input); } - THPObjectPtr apply_jvp_fn(PyObject_GetAttrString((PyObject*)self, "apply_jvp")); - if (!apply_jvp_fn) throw_python_error(); + THPObjectPtr apply_jvp_fn( + PyObject_GetAttrString((PyObject*)self, "apply_jvp")); + if (!apply_jvp_fn) + throw_python_error(); THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get())); - if (!r) throw_python_error(); + if (!r) + throw_python_error(); ensure_tuple(r); // Massage the Python results tuple back into a C++ variable_list @@ -389,13 +421,18 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel const int num_outputs = PyTuple_GET_SIZE(r.get()); variable_list results; results.reserve(num_outputs); - for(const auto i : c10::irange(num_outputs)) { + for (const auto i : c10::irange(num_outputs)) { PyObject* output = PyTuple_GET_ITEM(r.get(), i); if (output == Py_None) { results.emplace_back(); } else { - TORCH_CHECK(THPVariable_Check(output), "expected Variable or None (got ", - THPUtils_typename(output), ") for grad output ", i, ".") + TORCH_CHECK( + THPVariable_Check(output), + "expected Variable or None (got ", + THPUtils_typename(output), + ") for grad output ", + i, + ".") results.emplace_back(THPVariable_Unpack(output)); } } @@ -404,10 +441,15 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel }; // Wrap only the tensor outputs. - auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, - raw_output_vars, cdata_if_executable, jvp_user_function); + auto wrapped_outputs = _wrap_outputs( + input_vars, + non_differentiable, + dirty_inputs, + raw_output_vars, + cdata_if_executable, + jvp_user_function); - for(const auto i : c10::irange(num_outputs)) { + for (const auto i : c10::irange(num_outputs)) { PyObject* obj = PyTuple_GetItem(raw_output, i); // Keep the non-tensor outputs as is. if (!THPVariable_Check(obj)) { @@ -426,18 +468,22 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel } // Save any variables that requested by to_save -static void _save_variables(const std::shared_ptr& cdata_ptr, THPFunction* self) -{ - if (!self->to_save) return; +static void _save_variables( + const std::shared_ptr& cdata_ptr, + THPFunction* self) { + if (!self->to_save) + return; - THPFunction_assert(PyTuple_Check(self->to_save), "autograd internal " + THPFunction_assert( + PyTuple_Check(self->to_save), + "autograd internal " "error: to_save attribute is expected to be a tuple but is %s", THPUtils_typename(self->to_save)); Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save); self->saved_variables.clear(); self->saved_variables.reserve(num_saved); - for(const auto i : c10::irange(num_saved)) { - PyObject *obj = PyTuple_GET_ITEM(self->to_save, i); + for (const auto i : c10::irange(num_saved)) { + PyObject* obj = PyTuple_GET_ITEM(self->to_save, i); if (obj == Py_None) { self->saved_variables.emplace_back(); continue; @@ -448,29 +494,38 @@ static void _save_variables(const std::shared_ptr& cdata_ptr, THPFunctio } else { throw torch::TypeError( "save_for_backward can only save variables, but argument %ld is of " - "type %s", i, Py_TYPE(obj)->tp_name); + "type %s", + i, + Py_TYPE(obj)->tp_name); } } // Free .to_save Py_CLEAR(self->to_save); } -// Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable) -static std::unordered_set -_parse_non_differentiable(THPFunction *self) -{ +// Mark requires_grad = 0 on non-differentiable variables (as per +// non_differentiable) +static std::unordered_set _parse_non_differentiable( + THPFunction* self) { std::unordered_set set; - if (!self->non_differentiable) return set; + if (!self->non_differentiable) + return set; - THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd " + THPFunction_assert( + PyTuple_Check(self->non_differentiable), + "autograd " "internal error: non_differentiable attribute is expected to be a " - "tuple but is %s", THPUtils_typename(self->non_differentiable)); + "tuple but is %s", + THPUtils_typename(self->non_differentiable)); Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable); set.reserve(num_nondiff); - for(const auto i : c10::irange(num_nondiff)) { - PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i); - THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable " - "only accepts variable arguments, but got %s", THPUtils_typename(t)); + for (const auto i : c10::irange(num_nondiff)) { + PyObject* t = PyTuple_GET_ITEM(self->non_differentiable, i); + THPFunction_assert( + THPVariable_Check(t), + "mark_non_differentiable " + "only accepts variable arguments, but got %s", + THPUtils_typename(t)); set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl()); } Py_CLEAR(self->non_differentiable); @@ -489,24 +544,25 @@ struct InputFlags { std::vector is_variable_input; }; -template -std::pair unpack_input(PyObject *args) { +template +std::pair unpack_input(PyObject* args) { UnpackedInput unpacked; InputFlags flags; auto num_args = PyTuple_GET_SIZE(args); unpacked.input_tuple = PyTuple_New(num_args); flags.needs_input_grad = PyTuple_New(num_args); - for(const auto i : c10::irange(num_args)) { - PyObject *arg = PyTuple_GET_ITEM(args, i); + for (const auto i : c10::irange(num_args)) { + PyObject* arg = PyTuple_GET_ITEM(args, i); bool is_variable = THPVariable_Check(arg); flags.is_variable_input.push_back(is_variable); if (!is_variable) { - // TODO: remove this code path once Variable and Tensor are merged in Python + // TODO: remove this code path once Variable and Tensor are merged in + // Python if (enforce_variables) { - THPUtils_setError("expected a Tensor argument, but got %s", - THPUtils_typename(arg)); + THPUtils_setError( + "expected a Tensor argument, but got %s", THPUtils_typename(arg)); throw python_error(); } Py_INCREF(Py_False); @@ -522,14 +578,17 @@ std::pair unpack_input(PyObject *args) { PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg); } - flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars); - flags.next_edges = (flags.is_executable ? collect_next_edges(unpacked.input_vars) : edge_list()); + flags.is_executable = + GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars); + flags.next_edges = + (flags.is_executable ? collect_next_edges(unpacked.input_vars) + : edge_list()); return std::make_pair(std::move(unpacked), std::move(flags)); } static torch::jit::Node* _trace_pre_record( PyObject* op_obj, - PyObject *input_objects, + PyObject* input_objects, const variable_list& input_vars) { if (!jit::tracer::isTracing()) { return nullptr; @@ -541,8 +600,8 @@ static torch::jit::Node* _trace_pre_record( std::string arg_types; arg_types.reserve(num_args); scalar_args.reserve(num_args); - for(const auto i : c10::irange(num_args)) { - PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i); + for (const auto i : c10::irange(num_args)) { + PyObject* arg_object = PyTuple_GET_ITEM(input_objects, i); if (THPVariable_Check(arg_object)) { arg_types.push_back('d'); } else { @@ -562,7 +621,7 @@ static void _trace_post_record( torch::jit::Node* node, PyObject* op_obj, const variable_list& input_vars, - PyObject *output_objects, + PyObject* output_objects, bool is_inplace, bool unpack_output) { if (!jit::tracer::isTracing()) { @@ -570,9 +629,10 @@ static void _trace_post_record( } node->i_(jit::attr::inplace, is_inplace); - if (PyObject* module_name = PyDict_GetItemString(((PyTypeObject*)op_obj)->tp_dict, "__module__")) { + if (PyObject* module_name = PyDict_GetItemString( + ((PyTypeObject*)op_obj)->tp_dict, "__module__")) { if (auto ptr = PyUnicode_AsUTF8(module_name)) { - node->s_(jit::attr::module, std::string(ptr)); + node->s_(jit::attr::module, std::string(ptr)); } } @@ -610,21 +670,28 @@ static void _trace_post_record( new_tuple_values.push_back(ptr); } TypePtr tuple_type = TupleType::create(std::move(new_tuple_values)); - // The i-th tuple element receives a new tensor type with element type and shape. + // The i-th tuple element receives a new tensor type with element type and + // shape. old_node->output()->setType(tuple_type); } } -PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata, - THPFunction* grad_fn, const UnpackedInput& unpacked, - PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable, - torch::jit::Node* node) { +PyObject* process_outputs( + PyObject* op_obj, + const std::shared_ptr& cdata, + THPFunction* grad_fn, + const UnpackedInput& unpacked, + PyObject* inputs, + THPObjectPtr&& raw_output, + bool is_executable, + torch::jit::Node* node) { bool unpack_output = ensure_tuple(raw_output); auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); THPObjectPtr outputs(PyTuple_New(num_outputs)); - if (!outputs) throw python_error(); + if (!outputs) + throw python_error(); cdata->clear_input_metadata(); @@ -638,11 +705,14 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata } bool is_inplace = static_cast(grad_fn->dirty_tensors); - _wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable); - _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output); - - // It is important that creating the SavedVariables happen after the output wrapping as the - // outputs must have their grad_fn/fw_grad properly set before we save them. + _wrap_outputs( + cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable); + _trace_post_record( + node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output); + + // It is important that creating the SavedVariables happen after the output + // wrapping as the outputs must have their grad_fn/fw_grad properly set before + // we save them. if (is_executable) { _save_variables(cdata, grad_fn); } else { @@ -658,7 +728,7 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata // Unpack the output, unless .forward() returned a tuple if (unpack_output) { - PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0); + PyObject* output = PyTuple_GET_ITEM(outputs.get(), 0); Py_INCREF(output); return output; } @@ -666,21 +736,21 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata return outputs.release(); } -PyObject* THPFunction_name(PyObject *self, PyObject* noargs) { +PyObject* THPFunction_name(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS auto cdata = ((THPFunction*)self)->cdata.lock(); - TORCH_CHECK(cdata, - "Attribute 'name' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + TORCH_CHECK( + cdata, + "Attribute 'name' is invalid for this instance of _C._FunctionBase. " + "Accessing this attribute directly on an instance of autograd.Function is a legacy " + "access pattern that is no longer supported. For examples on how to use new-style " + "autograd functions, see " + "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); return THPUtils_packString(cdata->name()); END_HANDLE_TH_ERRORS } -PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) -{ +PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { HANDLE_TH_ERRORS // save a local copy of seq_id before it gets incremented @@ -692,9 +762,10 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) // Call record function after all the inputs have been decoded, but // before context has been allocated. RECORD_FUNCTION( - ((PyTypeObject*)cls)->tp_name, - std::vector(unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()), - seq_id); + ((PyTypeObject*)cls)->tp_name, + std::vector( + unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()), + seq_id); // Temporary hack to improve functorch UX. We'll find a better solution. const auto& functorch_tls = at::functorch::functorchTLSAccessor(); @@ -703,12 +774,15 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) } THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); - if (!backward_cls) return nullptr; + if (!backward_cls) + return nullptr; THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr)); - if (!ctx_obj) return nullptr; + if (!ctx_obj) + return nullptr; THPFunction* ctx = (THPFunction*)ctx_obj.get(); - auto cdata = std::shared_ptr(new PyNode(std::move(ctx_obj)), deleteNode); + auto cdata = + std::shared_ptr(new PyNode(std::move(ctx_obj)), deleteNode); ctx->cdata = cdata; // Record input nodes if tracing @@ -720,15 +794,15 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = std::move(input_info.is_variable_input); - // Prepend ctx to input_tuple, in preparation for static method call auto num_args = PyTuple_GET_SIZE(inputs); THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); - if (!ctx_input_tuple) return nullptr; + if (!ctx_input_tuple) + return nullptr; Py_INCREF(ctx); PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx); for (const auto i : c10::irange(num_args)) { - PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); + PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); Py_INCREF(arg); PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); } @@ -739,62 +813,74 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) AutoGradMode grad_mode(false); at::AutoFwGradMode fw_grad_mode(false); THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); - if (!forward_fn) return nullptr; + if (!forward_fn) + return nullptr; tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); - if (!tensor_outputs) return nullptr; + if (!tensor_outputs) + return nullptr; } - return process_outputs(cls, cdata, ctx, unpacked_input, inputs, std::move(tensor_outputs), - is_executable, node); + return process_outputs( + cls, + cdata, + ctx, + unpacked_input, + inputs, + std::move(tensor_outputs), + is_executable, + node); END_HANDLE_TH_ERRORS } - //////////////////////////////////////////////////////////////////////////////// // Other methods / attributes //////////////////////////////////////////////////////////////////////////////// -PyObject* THPFunction__register_hook_dict(PyObject *_self, PyObject *_var) -{ +PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) { HANDLE_TH_ERRORS - THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a Tensor"); + THPUtils_assert( + THPVariable_Check(_var), "_register_hook_dict expected a Tensor"); THPVariable* var = reinterpret_cast(_var); const auto& tensor = THPVariable_Unpack(var); - std::unique_ptr hook(new PyFunctionPreHook( - var->backward_hooks, tensor.output_nr())); + std::unique_ptr hook( + new PyFunctionPreHook(var->backward_hooks, tensor.output_nr())); auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); - TORCH_CHECK(cdata, - "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + TORCH_CHECK( + cdata, + "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. " + "Accessing this attribute directly on an instance of autograd.Function is a legacy " + "access pattern that is no longer supported. For examples on how to use new-style " + "autograd functions, see " + "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); cdata->add_pre_hook(std::move(hook)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject* THPFunction_register_hook(PyObject *_self, PyObject *hook) -{ +PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) { HANDLE_TH_ERRORS - auto self= (THPFunction*)_self; + auto self = (THPFunction*)_self; auto cdata = self->cdata.lock(); - TORCH_CHECK(cdata, - "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + TORCH_CHECK( + cdata, + "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. " + "Accessing this attribute directly on an instance of autograd.Function is a legacy " + "access pattern that is no longer supported. For examples on how to use new-style " + "autograd functions, see " + "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); return torch::autograd::registerFunctionHook(*cdata, hook); END_HANDLE_TH_ERRORS } -int THPFunction_set_materialize_grads(THPFunction *self, PyObject *value, void *unused) -{ +int THPFunction_set_materialize_grads( + THPFunction* self, + PyObject* value, + void* unused) { HANDLE_TH_ERRORS if (!PyBool_Check(value)) { - THPUtils_invalidArguments(value, nullptr, "set_materialize_grads", 1, "(bool)"); + THPUtils_invalidArguments( + value, nullptr, "set_materialize_grads", 1, "(bool)"); return -1; } self->materialize_grads = (value == Py_True); @@ -802,10 +888,9 @@ int THPFunction_set_materialize_grads(THPFunction *self, PyObject *value, void * END_HANDLE_TH_ERRORS_RET(-1) } -static PyObject *unpack_saved_variables( - THPFunction *self, - const std::function& unpack_fn) -{ +static PyObject* unpack_saved_variables( + THPFunction* self, + const std::function& unpack_fn) { THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); auto& saved_variables = self->saved_variables; if (saved_variables.empty()) @@ -824,7 +909,7 @@ static PyObject *unpack_saved_variables( // because we will never hit this line of code if the buffers are freed-- // and in any case saved_for will be non-NULL.) TORCH_INTERNAL_ASSERT(saved_for); - for(const auto i : c10::irange(num_saved)) { + for (const auto i : c10::irange(num_saved)) { auto unpacked_var = saved_variables[i].unpack(saved_for); THPObjectPtr value; if (!unpacked_var.defined()) { @@ -838,34 +923,32 @@ static PyObject *unpack_saved_variables( return saved.release(); } -PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused) -{ +PyObject* THPFunction_saved_tensors(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS if (self->saved_for_forward) { Py_INCREF(self->saved_for_forward); return self->saved_for_forward; } else { - return unpack_saved_variables(self, [](const Variable& var) { - return THPVariable_Wrap(var); - }); + return unpack_saved_variables( + self, [](const Variable& var) { return THPVariable_Wrap(var); }); } END_HANDLE_TH_ERRORS } -PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused) -{ +PyObject* THPFunction_saved_variables(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS - auto r = PyErr_WarnEx(PyExc_DeprecationWarning, - "'saved_variables' is deprecated; use 'saved_tensors'", 0); - if (r != 0) throw python_error(); - return unpack_saved_variables(self, [](const Variable& var) { - return THPVariable_Wrap(var); - }); + auto r = PyErr_WarnEx( + PyExc_DeprecationWarning, + "'saved_variables' is deprecated; use 'saved_tensors'", + 0); + if (r != 0) + throw python_error(); + return unpack_saved_variables( + self, [](const Variable& var) { return THPVariable_Wrap(var); }); END_HANDLE_TH_ERRORS } -PyObject *THPFunction_raw_saved_tensors(THPFunction *self, void *_unused) -{ +PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS // User tries to access saved variables after they have been freed THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); @@ -877,34 +960,37 @@ PyObject *THPFunction_raw_saved_tensors(THPFunction *self, void *_unused) if (!saved) { return nullptr; } - for(const auto i : c10::irange(num_saved)) { - py::object obj = py::cast(saved_variables[i], py::return_value_policy::reference); + for (const auto i : c10::irange(num_saved)) { + py::object obj = + py::cast(saved_variables[i], py::return_value_policy::reference); PyTuple_SET_ITEM(saved.get(), i, obj.release().ptr()); } return saved.release(); END_HANDLE_TH_ERRORS } -PyObject *THPFunction_next_functions(THPFunction *self, void *_unused) -{ +PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS auto cdata = self->cdata.lock(); - TORCH_CHECK(cdata, - "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. " - "Accessing this attribute directly on an instance of autograd.Function is a legacy " - "access pattern that is no longer supported. For examples on how to use new-style " - "autograd functions, see " - "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); + TORCH_CHECK( + cdata, + "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. " + "Accessing this attribute directly on an instance of autograd.Function is a legacy " + "access pattern that is no longer supported. For examples on how to use new-style " + "autograd functions, see " + "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); const auto num_outputs = cdata->num_outputs(); THPObjectPtr result(PyTuple_New(num_outputs)); if (!result) return nullptr; for (const auto i : c10::irange(num_outputs)) { THPObjectPtr fn_tuple(PyTuple_New(2)); - if (!fn_tuple) return nullptr; + if (!fn_tuple) + return nullptr; const auto& edge = cdata->next_edge(i); PyObject* fn = functionToPyObject(edge.function); - if (!fn) return nullptr; + if (!fn) + return nullptr; PyTuple_SET_ITEM(fn_tuple.get(), 0, fn); PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr)); PyTuple_SET_ITEM(result.get(), i, fn_tuple.release()); @@ -913,8 +999,7 @@ PyObject *THPFunction_next_functions(THPFunction *self, void *_unused) END_HANDLE_TH_ERRORS } -PyObject *THPFunction_metadata(THPFunction *self, void *_unused) -{ +PyObject* THPFunction_metadata(THPFunction* self, void* _unused) { HANDLE_TH_ERRORS auto cdata = self->cdata.lock(); // The correct way to solve this problem is to stop exposing grad_fn @@ -923,13 +1008,14 @@ PyObject *THPFunction_metadata(THPFunction *self, void *_unused) // mean that you no longer get the property that grad_fn is a subclass // of the autograd function class that you defined in the custom case, // so I didn't fix it here. - TORCH_CHECK(cdata, - "You attempted to access the anomaly metadata of a custom autograd function " - "but the underlying PyNode has already been deallocated. The most likely " - "reason this occurred is because you assigned x.grad_fn to a local variable " - "and then let the original variable get deallocated. Don't do that! If " - "you really have no way of restructuring your code so this is the case, " - "please file an issue reporting that you are affected by this."); + TORCH_CHECK( + cdata, + "You attempted to access the anomaly metadata of a custom autograd function " + "but the underlying PyNode has already been deallocated. The most likely " + "reason this occurred is because you assigned x.grad_fn to a local variable " + "and then let the original variable get deallocated. Don't do that! If " + "you really have no way of restructuring your code so this is the case, " + "please file an issue reporting that you are affected by this."); auto metadata = static_cast(cdata->metadata())->dict(); Py_INCREF(metadata); @@ -937,12 +1023,12 @@ PyObject *THPFunction_metadata(THPFunction *self, void *_unused) END_HANDLE_TH_ERRORS } -typedef PyObject *(*getter)(PyObject *, void *); -typedef int (*setter)(PyObject *, PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); +typedef int (*setter)(PyObject*, PyObject*, void*); namespace { -template +template PyObject* getObject(PyObject* obj, void* _unused) { auto self = (THPFunction*)obj; PyObject* value = self->*ptr; @@ -953,7 +1039,7 @@ PyObject* getObject(PyObject* obj, void* _unused) { return value; } -template +template int setObject(PyObject* obj, PyObject* value, void* _unused) { auto self = (THPFunction*)obj; if (value == Py_None) { @@ -965,13 +1051,13 @@ int setObject(PyObject* obj, PyObject* value, void* _unused) { return 0; } -template +template PyObject* getMember(PyObject* obj, void* _unused) { auto self = (THPFunction*)obj; return Convert(self->*ptr); } -template +template PyObject* getImplMember(PyObject* obj, void* _unused) { auto self = (THPFunction*)obj; return Convert(self->cdata.*ptr); @@ -981,80 +1067,120 @@ PyObject* getRequiresGrad(PyObject* obj, void* _unused) { Py_RETURN_TRUE; } -} +} // namespace // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPFunction_properties[] = { - {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr}, - {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr}, - {"_raw_saved_tensors", (getter)THPFunction_raw_saved_tensors, nullptr, nullptr, nullptr}, - {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr}, - {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr}, - {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr}, - {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr}, - {"saved_for_forward", &getObject<&THPFunction::saved_for_forward>, &setObject<&THPFunction::saved_for_forward>, nullptr, nullptr}, - {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr}, - {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr}, - {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr}, - {"materialize_grads", nullptr, (setter)THPFunction_set_materialize_grads, nullptr, nullptr}, - {nullptr} -}; + {"saved_tensors", + (getter)THPFunction_saved_tensors, + nullptr, + nullptr, + nullptr}, + {"saved_variables", + (getter)THPFunction_saved_variables, + nullptr, + nullptr, + nullptr}, + {"_raw_saved_tensors", + (getter)THPFunction_raw_saved_tensors, + nullptr, + nullptr, + nullptr}, + {"next_functions", + (getter)THPFunction_next_functions, + nullptr, + nullptr, + nullptr}, + {"to_save", + &getObject<&THPFunction::to_save>, + &setObject<&THPFunction::to_save>, + nullptr, + nullptr}, + {"non_differentiable", + &getObject<&THPFunction::non_differentiable>, + &setObject<&THPFunction::non_differentiable>, + nullptr, + nullptr}, + {"dirty_tensors", + &getObject<&THPFunction::dirty_tensors>, + &setObject<&THPFunction::dirty_tensors>, + nullptr, + nullptr}, + {"saved_for_forward", + &getObject<&THPFunction::saved_for_forward>, + &setObject<&THPFunction::saved_for_forward>, + nullptr, + nullptr}, + {"needs_input_grad", + &getObject<&THPFunction::needs_input_grad>, + nullptr, + nullptr, + nullptr}, + {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr}, + {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr}, + {"materialize_grads", + nullptr, + (setter)THPFunction_set_materialize_grads, + nullptr, + nullptr}, + {nullptr}}; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyMethodDef THPFunction_methods[] = { - {(char*)"name", THPFunction_name, METH_NOARGS, nullptr}, - {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, - {(char*)"_register_hook_dict", THPFunction__register_hook_dict, METH_O, nullptr}, - {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr}, - {nullptr} -}; + {(char*)"name", THPFunction_name, METH_NOARGS, nullptr}, + {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, + {(char*)"_register_hook_dict", + THPFunction__register_hook_dict, + METH_O, + nullptr}, + {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr}, + {nullptr}}; PyTypeObject THPFunctionType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._FunctionBase", /* tp_name */ - sizeof(THPFunction), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPFunction_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ - nullptr, /* tp_doc */ - (traverseproc)THPFunction_traverse, /* tp_traverse */ - (inquiry)THPFunction_clear, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THPFunction_methods, /* tp_methods */ - nullptr, /* tp_members */ - THPFunction_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPFunction_new /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._FunctionBase", /* tp_name */ + sizeof(THPFunction), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THPFunction_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_HAVE_GC, /* tp_flags */ + nullptr, /* tp_doc */ + (traverseproc)THPFunction_traverse, /* tp_traverse */ + (inquiry)THPFunction_clear, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THPFunction_methods, /* tp_methods */ + nullptr, /* tp_members */ + THPFunction_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPFunction_new /* tp_new */ }; -bool THPFunction_initModule(PyObject *module) -{ +bool THPFunction_initModule(PyObject* module) { if (PyType_Ready(&THPFunctionType) < 0) return false; Py_INCREF(&THPFunctionType); - PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType); + PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType); return true; } diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 77147b090fb281..cc606b97e39433 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -5,19 +5,24 @@ #include #include #include -#include #include +#include #include -#include #include +#include -#include -#include #include +#include +#include -namespace torch { namespace jit { struct Graph; }} -namespace torch { namespace autograd { +namespace torch { +namespace jit { +struct Graph; +} +} // namespace torch +namespace torch { +namespace autograd { // A Function which is implemented by a Python object (i.e., a THPFunction). // Calls to 'apply' are forwarded to the Python method implementation. @@ -54,64 +59,66 @@ inline bool ensure_tuple(THPObjectPtr& obj) { if (PyTuple_Check(obj.get())) return false; - PyObject *tuple = PyTuple_New(1); - if (!tuple) throw python_error(); + PyObject* tuple = PyTuple_New(1); + if (!tuple) + throw python_error(); PyTuple_SET_ITEM(tuple, 0, obj.release()); obj = tuple; return true; } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct THPFunction { - PyObject_HEAD - - PyObject *needs_input_grad; - - // Python tuple of tensors whose variables we should save. Set - // by Python with 'save_for_backward'. If nullptr, no tensors were - // saved. - PyObject *to_save; - // Python tuple of tensors which are not differentiable. Set by - // Python with 'mark_non_differentiable'. If nullptr, no tensors were - // non-differentiable. - PyObject *non_differentiable; - // Python tuple of tensors which had inplace updates in the forward() - // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were - // modified inplace. - PyObject *dirty_tensors; - - // boolean indicating whether to materialize undefined output grad tensors - // into tensors full of zeros. Set by Python with 'set_materialize_grads'. - // Default is true. - bool materialize_grads; - - std::vector output_info; - std::vector input_info; - std::vector saved_variables; - // For each input, true if the input is a THPVariable - std::vector is_variable_input; - char has_freed_buffers; - - PyObject *saved_for_forward; - // The actual PyNode (in the autograd graph) that this data was - // saved for. This field may be NULL (because a user can construct - // a THPFunction directly from Python), but when this field is non-NULL, - // it is guaranteed that cdata.lock()->obj == this - // - // In most ordinary use, this field should always be non-NULL; e.g., - // when we allocate a THPFunction because we are running Node.apply, - // after constructing a THPFunction, we immediately allocate a PyNode - // for it. We can't enforce this directly in the constructor of - // THPFunction though, because there's no way to keep it live long enough - // to save an owning reference to PyNode into the grad_fn of a Variable. - std::weak_ptr cdata; + PyObject_HEAD + + PyObject* needs_input_grad; + + // Python tuple of tensors whose variables we should save. Set + // by Python with 'save_for_backward'. If nullptr, no tensors were + // saved. + PyObject* to_save; + // Python tuple of tensors which are not differentiable. Set by + // Python with 'mark_non_differentiable'. If nullptr, no tensors were + // non-differentiable. + PyObject* non_differentiable; + // Python tuple of tensors which had inplace updates in the forward() + // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were + // modified inplace. + PyObject* dirty_tensors; + + // boolean indicating whether to materialize undefined output grad tensors + // into tensors full of zeros. Set by Python with 'set_materialize_grads'. + // Default is true. + bool materialize_grads; + + std::vector output_info; + std::vector input_info; + std::vector saved_variables; + // For each input, true if the input is a THPVariable + std::vector is_variable_input; + char has_freed_buffers; + + PyObject* saved_for_forward; + // The actual PyNode (in the autograd graph) that this data was + // saved for. This field may be NULL (because a user can construct + // a THPFunction directly from Python), but when this field is non-NULL, + // it is guaranteed that cdata.lock()->obj == this + // + // In most ordinary use, this field should always be non-NULL; e.g., + // when we allocate a THPFunction because we are running Node.apply, + // after constructing a THPFunction, we immediately allocate a PyNode + // for it. We can't enforce this directly in the constructor of + // THPFunction though, because there's no way to keep it live long enough + // to save an owning reference to PyNode into the grad_fn of a Variable. + std::weak_ptr cdata; }; -bool THPFunction_initModule(PyObject *module); +bool THPFunction_initModule(PyObject* module); extern PyTypeObject THPFunctionType; -extern PyObject *THPFunctionClass; +extern PyObject* THPFunctionClass; inline bool THPFunction_Check(PyObject* obj) { return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType); diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp index 5e30bc040273e1..44d58711357133 100644 --- a/torch/csrc/autograd/python_hook.cpp +++ b/torch/csrc/autograd/python_hook.cpp @@ -2,30 +2,31 @@ #include #include +#include #include #include #include #include -#include #include -using torch::autograd::variable_list; using torch::autograd::Variable; +using torch::autograd::variable_list; static PyObject* wrap_variables(const variable_list& c_variables); static variable_list unwrap_variables(PyObject* py_variables); static std::string hook_name(PyObject* hook); static void check_result(PyObject* original, PyObject* result, PyObject* hook); -static void check_single_result(PyObject* original, PyObject* result, PyObject* hook); - +static void check_single_result( + PyObject* original, + PyObject* result, + PyObject* hook); -namespace torch { namespace autograd { +namespace torch { +namespace autograd { PyFunctionPreHook::PyFunctionPreHook(PyObject* dict, int value_idx) - : dict(dict) - , value_idx(value_idx) -{ + : dict(dict), value_idx(value_idx) { Py_INCREF(dict); } @@ -37,12 +38,13 @@ PyFunctionPreHook::~PyFunctionPreHook() { } } -auto PyFunctionPreHook::operator()(const variable_list& values) -> variable_list -{ +auto PyFunctionPreHook::operator()(const variable_list& values) + -> variable_list { pybind11::gil_scoped_acquire gil; THPObjectPtr value(THPVariable_Wrap(values.at(value_idx))); - if (!value) throw python_error(); + if (!value) + throw python_error(); // Note: [Extend Hook Lifetime] // Hold a reference to hooks till we iterate over them. @@ -58,14 +60,17 @@ auto PyFunctionPreHook::operator()(const variable_list& values) -> variable_list for (Py_ssize_t idx = 0; idx < len; ++idx) { const auto hook = PyList_GetItem(hooks, idx); THPObjectPtr res(PyObject_CallFunctionObjArgs(hook, value.get(), nullptr)); - if (!res) throw python_error(); - if (res == Py_None) continue; + if (!res) + throw python_error(); + if (res == Py_None) + continue; check_single_result(value.get(), res.get(), hook); value = std::move(res); } variable_list results(values); - if (value != Py_None) results[value_idx] = THPVariable_Unpack(value.get()); + if (value != Py_None) + results[value_idx] = THPVariable_Unpack(value.get()); return results; } @@ -83,8 +88,7 @@ PyFunctionPostHook::~PyFunctionPostHook() { auto PyFunctionPostHook::operator()( const variable_list& _outputs, /* grad_inputs */ - const variable_list& _inputs /* grad_outputs */) -> variable_list -{ + const variable_list& _inputs /* grad_outputs */) -> variable_list { pybind11::gil_scoped_acquire gil; THPObjectPtr outputs(wrap_variables(_outputs)); @@ -97,8 +101,10 @@ auto PyFunctionPostHook::operator()( const auto hook = PyList_GetItem(hooks, idx); THPObjectPtr res(PyObject_CallFunctionObjArgs( hook, outputs.get(), inputs.get(), nullptr)); - if (!res) throw python_error(); - if (res == Py_None) continue; + if (!res) + throw python_error(); + if (res == Py_None) + continue; check_result(outputs, res, hook); outputs = std::move(res); } @@ -106,25 +112,26 @@ auto PyFunctionPostHook::operator()( return unwrap_variables(outputs.get()); } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch - -static PyObject *wrap_variables(const variable_list& c_variables) -{ +static PyObject* wrap_variables(const variable_list& c_variables) { size_t num_vars = c_variables.size(); THPObjectPtr tuple(PyTuple_New(num_vars)); - if (!tuple) throw python_error(); + if (!tuple) + throw python_error(); for (const auto i : c10::irange(num_vars)) { THPObjectPtr var(THPVariable_Wrap(c_variables[i])); - if (!var) throw python_error(); + if (!var) + throw python_error(); PyTuple_SET_ITEM(tuple.get(), i, var.release()); } return tuple.release(); } -static variable_list unwrap_variables(PyObject* py_variables) { +static variable_list unwrap_variables(PyObject* py_variables) { variable_list results(PyTuple_GET_SIZE(py_variables)); - for(const auto i : c10::irange(results.size())) { + for (const auto i : c10::irange(results.size())) { PyObject* item = PyTuple_GET_ITEM(py_variables, i); if (item == Py_None) { continue; @@ -142,7 +149,9 @@ static variable_list unwrap_variables(PyObject* py_variables) { static void check_result(PyObject* prev, PyObject* result, PyObject* hook) { if (!PyTuple_Check(result)) { - PyErr_Format(PyExc_TypeError, "expected tuple, but hook returned '%s'", + PyErr_Format( + PyExc_TypeError, + "expected tuple, but hook returned '%s'", THPUtils_typename(result)); throw python_error(); } @@ -159,19 +168,27 @@ static void check_result(PyObject* prev, PyObject* result, PyObject* hook) { } for (const auto i : c10::irange(prev_size)) { - check_single_result(PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook); + check_single_result( + PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook); } } -static void check_single_result(PyObject* _original, PyObject* _result, PyObject* hook) { - if (_result == Py_None) return; +static void check_single_result( + PyObject* _original, + PyObject* _result, + PyObject* hook) { + if (_result == Py_None) + return; if (_original == Py_None) { - throw std::runtime_error("can't replace a None gradient with a non-None value"); + throw std::runtime_error( + "can't replace a None gradient with a non-None value"); } if (!PyObject_IsInstance(_result, THPVariableClass)) { - PyErr_Format(PyExc_TypeError, "expected Variable, but hook returned '%s'", + PyErr_Format( + PyExc_TypeError, + "expected Variable, but hook returned '%s'", THPUtils_typename(_result)); throw python_error(); } @@ -185,7 +202,8 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject static std::string hook_name(PyObject* hook) { if (PyObject_HasAttrString(hook, "__name__")) { THPObjectPtr name(PyObject_GetAttrString(hook, "__name__")); - if (!name) throw python_error(); + if (!name) + throw python_error(); if (name && THPUtils_checkString(name.get())) { return THPUtils_unpackString(name.get()); diff --git a/torch/csrc/autograd/python_hook.h b/torch/csrc/autograd/python_hook.h index 9783d7263ae63d..0d5a6d306a6527 100644 --- a/torch/csrc/autograd/python_hook.h +++ b/torch/csrc/autograd/python_hook.h @@ -1,10 +1,11 @@ #pragma once -#include #include +#include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct PyFunctionPreHook : public FunctionPreHook { PyFunctionPreHook(PyObject* dict, int value_idx); @@ -17,8 +18,11 @@ struct PyFunctionPreHook : public FunctionPreHook { struct PyFunctionPostHook : public FunctionPostHook { PyFunctionPostHook(PyObject* dict); ~PyFunctionPostHook() override; - variable_list operator()(const variable_list& outputs, const variable_list& inputs) override; + variable_list operator()( + const variable_list& outputs, + const variable_list& inputs) override; PyObject* dict; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index f4e732efc94984..4da5333546ba02 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -5,43 +5,61 @@ #include #include #include -#include #include +#include using namespace at; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject *kwds) { +static PyObject* THPVariable_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { HANDLE_TH_ERRORS THPObjectPtr _data; - PyObject *data = nullptr; - PyObject *grad_fn = nullptr; + PyObject* data = nullptr; + PyObject* grad_fn = nullptr; char is_volatile = 0; char requires_grad = 0; const char* name = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args, - &data, &requires_grad, &is_volatile, &grad_fn, &name)) + const char* accepted_args[] = { + "data", "requires_grad", "volatile", "_grad_fn", "name", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwds, + "|ObbOz", + (char**)accepted_args, + &data, + &requires_grad, + &is_volatile, + &grad_fn, + &name)) return nullptr; if (grad_fn == Py_None) grad_fn = nullptr; if (is_volatile) { - auto r = PyErr_WarnEx(PyExc_UserWarning, + auto r = PyErr_WarnEx( + PyExc_UserWarning, "volatile was removed and now has no effect. Use `with torch.no_grad():` " - "instead.", 1); - if (r != 0) throw python_error(); + "instead.", + 1); + if (r != 0) + throw python_error(); } if (is_volatile && requires_grad) { - throw ValueError("Variable can't be volatile and require_grad at the same time!"); + throw ValueError( + "Variable can't be volatile and require_grad at the same time!"); } if (grad_fn && !THPFunction_Check(grad_fn)) { - throw TypeError("_grad_fn has to be a Function object or None, but got %s", + throw TypeError( + "_grad_fn has to be a Function object or None, but got %s", Py_TYPE(grad_fn)->tp_name); } Variable var; @@ -51,17 +69,17 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject auto dispatch_key = torch::tensors::get_default_dispatch_key(); auto scalar_type = torch::tensors::get_default_scalar_type(); auto options = TensorOptions(scalar_type) - .device(dispatchKeyToDeviceType(dispatch_key)) - .layout(dispatchKeyToLayout(dispatch_key)); + .device(dispatchKeyToDeviceType(dispatch_key)) + .layout(dispatchKeyToLayout(dispatch_key)); var = at::empty({0}, options); } else if (THPVariable_Check(data)) { var = THPVariable_Unpack(data).detach(); } else { - throw torch::TypeError("Variable data has to be a tensor, but got %s", - Py_TYPE(data)->tp_name); + throw torch::TypeError( + "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); } - // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we want to - // allow the following use case for backward compatibility: + // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we + // want to allow the following use case for backward compatibility: // // ```python // var = Variable(torch.randn(2, 3)) @@ -69,18 +87,20 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject // ``` var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true); - TORCH_CHECK(!grad_fn, - "_grad_fn argument to legacy Variable constructor is no longer supported. " - "Instead, please invoke your _grad_fn to produce a variable with it as the " - "_grad_fn."); + TORCH_CHECK( + !grad_fn, + "_grad_fn argument to legacy Variable constructor is no longer supported. " + "Instead, please invoke your _grad_fn to produce a variable with it as the " + "_grad_fn."); var.set_requires_grad(requires_grad); if (name) { impl::set_name(var, name); } - if (jit::tracer::isTracing() && data && data != Py_None && THPVariable_Check(data)) { - if (auto *v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) { + if (jit::tracer::isTracing() && data && data != Py_None && + THPVariable_Check(data)) { + if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) { jit::tracer::setValueTrace(var, v); } } @@ -90,47 +110,48 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject } PyTypeObject THPLegacyVariableType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._LegacyVariableBase", /* tp_name */ - 0, /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THPVariable_pynew /* tp_new */ + PyVarObject_HEAD_INIT( + nullptr, + 0) "torch._C._LegacyVariableBase", /* tp_name */ + 0, /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THPVariable_pynew /* tp_new */ }; -void init_legacy_variable(PyObject *module) { +void init_legacy_variable(PyObject* module) { if (PyType_Ready(&THPLegacyVariableType) < 0) { throw python_error(); } @@ -141,4 +162,5 @@ void init_legacy_variable(PyObject *module) { } } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_legacy_variable.h b/torch/csrc/autograd/python_legacy_variable.h index 47a5f071723b6b..94e231171bd527 100644 --- a/torch/csrc/autograd/python_legacy_variable.h +++ b/torch/csrc/autograd/python_legacy_variable.h @@ -5,8 +5,10 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -void init_legacy_variable(PyObject *module); +void init_legacy_variable(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_linalg_functions.h b/torch/csrc/autograd/python_linalg_functions.h index 79aca560bb7b1f..b94f395376e476 100644 --- a/torch/csrc/autograd/python_linalg_functions.h +++ b/torch/csrc/autograd/python_linalg_functions.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void initLinalgFunctions(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_nn_functions.h b/torch/csrc/autograd/python_nn_functions.h index 4ea251a91ab45f..7d75cf9d299394 100644 --- a/torch/csrc/autograd/python_nn_functions.h +++ b/torch/csrc/autograd/python_nn_functions.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void initNNFunctions(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_return_types.h b/torch/csrc/autograd/python_return_types.h index c4f68deefa5d0c..7b18802e9d375f 100644 --- a/torch/csrc/autograd/python_return_types.h +++ b/torch/csrc/autograd/python_return_types.h @@ -1,8 +1,10 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { PyTypeObject* get_namedtuple(std::string name); void initReturnTypes(PyObject* module); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp index 4d224f2982b234..2bec0e75b1a704 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.cpp +++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp @@ -1,78 +1,89 @@ -#include #include +#include #include namespace py = pybind11; -namespace torch { namespace autograd { - PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) : - // steals the reference (we will decref ourselves) - pack_hook_(pack_hook.release().ptr()), - unpack_hook_(unpack_hook.release().ptr()) {} +namespace torch { +namespace autograd { +PySavedVariableHooks::PySavedVariableHooks( + py::function& pack_hook, + py::function& unpack_hook) + : // steals the reference (we will decref ourselves) + pack_hook_(pack_hook.release().ptr()), + unpack_hook_(unpack_hook.release().ptr()) {} - void PySavedVariableHooks::call_pack_hook(const at::Tensor &tensor) { - py::gil_scoped_acquire acquire; - auto pack_hook = py::reinterpret_borrow(pack_hook_); - auto wrapped = THPVariable_Wrap(tensor); - py::object obj = py::reinterpret_steal(wrapped); - py::object packed = pack_hook(obj); - data_ = packed.release().ptr(); - // pack_hook, obj are decrefed on exit - // wrapped and packed had their references stolen - // pack_hook_ and data_ will be manually decrefed when the saved variable is released - } +void PySavedVariableHooks::call_pack_hook(const at::Tensor& tensor) { + py::gil_scoped_acquire acquire; + auto pack_hook = py::reinterpret_borrow(pack_hook_); + auto wrapped = THPVariable_Wrap(tensor); + py::object obj = py::reinterpret_steal(wrapped); + py::object packed = pack_hook(obj); + data_ = packed.release().ptr(); + // pack_hook, obj are decrefed on exit + // wrapped and packed had their references stolen + // pack_hook_ and data_ will be manually decrefed when the saved variable is + // released +} - at::Tensor PySavedVariableHooks::call_unpack_hook() { - py::gil_scoped_acquire acquire; - auto unpack_hook = py::reinterpret_borrow(unpack_hook_); - py::object obj = py::cast(data_); - py::object res = unpack_hook(obj); - PyObject* ptr = res.ptr(); - TORCH_CHECK_TYPE(THPVariable_Check(ptr), "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ", THPUtils_typename(ptr)); - return THPVariable_Unpack(ptr); - // unpack_hook, obj and res are decrefed on exit - // ptr is only alive as long as res is - // unpack_hook_ will be manually decrefed when the saved variable is released - } +at::Tensor PySavedVariableHooks::call_unpack_hook() { + py::gil_scoped_acquire acquire; + auto unpack_hook = py::reinterpret_borrow(unpack_hook_); + py::object obj = py::cast(data_); + py::object res = unpack_hook(obj); + PyObject* ptr = res.ptr(); + TORCH_CHECK_TYPE( + THPVariable_Check(ptr), + "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ", + THPUtils_typename(ptr)); + return THPVariable_Unpack(ptr); + // unpack_hook, obj and res are decrefed on exit + // ptr is only alive as long as res is + // unpack_hook_ will be manually decrefed when the saved variable is released +} - PySavedVariableHooks::~PySavedVariableHooks() { - // If python is already dead, leak the wrapped python objects - if (Py_IsInitialized()) { - py::gil_scoped_acquire gil; - Py_XDECREF(pack_hook_); - Py_XDECREF(unpack_hook_); - Py_XDECREF(data_); - } +PySavedVariableHooks::~PySavedVariableHooks() { + // If python is already dead, leak the wrapped python objects + if (Py_IsInitialized()) { + py::gil_scoped_acquire gil; + Py_XDECREF(pack_hook_); + Py_XDECREF(unpack_hook_); + Py_XDECREF(data_); } +} - void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) { - at::SavedTensorDefaultHooks::enable(); - at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr()); - } +void PyDefaultSavedVariableHooks::push_hooks( + py::function& pack_hook, + py::function& unpack_hook) { + at::SavedTensorDefaultHooks::enable(); + at::SavedTensorDefaultHooks::push_hooks( + pack_hook.release().ptr(), unpack_hook.release().ptr()); +} - void PyDefaultSavedVariableHooks::pop_hooks() { - PyObject *pack_hook(nullptr), *unpack_hook(nullptr); - std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); - TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); - if (Py_IsInitialized()) { - py::gil_scoped_acquire gil; - Py_XDECREF(pack_hook); - Py_XDECREF(unpack_hook); - } - at::SavedTensorDefaultHooks::pop_hooks(); +void PyDefaultSavedVariableHooks::pop_hooks() { + PyObject *pack_hook(nullptr), *unpack_hook(nullptr); + std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); + TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); + if (Py_IsInitialized()) { + py::gil_scoped_acquire gil; + Py_XDECREF(pack_hook); + Py_XDECREF(unpack_hook); } + at::SavedTensorDefaultHooks::pop_hooks(); +} - std::unique_ptr PyDefaultSavedVariableHooks::get_hooks() { - PyObject *pack_hook(nullptr), *unpack_hook(nullptr); - std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); - if (!pack_hook || !unpack_hook) { - return nullptr; - } - py::gil_scoped_acquire gil; - py::function pack_hook_ = py::reinterpret_borrow(pack_hook); - py::function unpack_hook_ = py::reinterpret_borrow(unpack_hook); - return std::make_unique(pack_hook_, unpack_hook_); +std::unique_ptr PyDefaultSavedVariableHooks::get_hooks() { + PyObject *pack_hook(nullptr), *unpack_hook(nullptr); + std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); + if (!pack_hook || !unpack_hook) { + return nullptr; } + py::gil_scoped_acquire gil; + py::function pack_hook_ = py::reinterpret_borrow(pack_hook); + py::function unpack_hook_ = py::reinterpret_borrow(unpack_hook); + return std::make_unique(pack_hook_, unpack_hook_); +} -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_saved_variable_hooks.h b/torch/csrc/autograd/python_saved_variable_hooks.h index 4500bf6b19f252..835b2ed04cef14 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.h +++ b/torch/csrc/autograd/python_saved_variable_hooks.h @@ -1,32 +1,34 @@ #pragma once +#include #include +#include #include #include #include -#include -#include namespace py = pybind11; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct PySavedVariableHooks : public SavedVariableHooks { - PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook); - void call_pack_hook(const at::Tensor &tensor) override; + PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook); + void call_pack_hook(const at::Tensor& tensor) override; at::Tensor call_unpack_hook() override; ~PySavedVariableHooks() override; -private: + private: PyObject* pack_hook_; PyObject* unpack_hook_; PyObject* data_ = nullptr; }; struct PyDefaultSavedVariableHooks { - static void push_hooks(py::function &pack_hook, py::function &unpack_hook); + static void push_hooks(py::function& pack_hook, py::function& unpack_hook); static void pop_hooks(); static std::unique_ptr get_hooks(); }; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_sparse_functions.h b/torch/csrc/autograd/python_sparse_functions.h index 18a0dcb54fa914..a9a6617e997d9d 100644 --- a/torch/csrc/autograd/python_sparse_functions.h +++ b/torch/csrc/autograd/python_sparse_functions.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void initSparseFunctions(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_special_functions.h b/torch/csrc/autograd/python_special_functions.h index f2c42f4dabff86..8af7438d5b4e4e 100644 --- a/torch/csrc/autograd/python_special_functions.h +++ b/torch/csrc/autograd/python_special_functions.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace autograd { +namespace torch { +namespace autograd { void initSpecialFunctions(PyObject* module); -}} // namespace torch::autograd +} +} // namespace torch diff --git a/torch/csrc/autograd/python_torch_functions.h b/torch/csrc/autograd/python_torch_functions.h index 58257794812eec..0221bbafab2836 100644 --- a/torch/csrc/autograd/python_torch_functions.h +++ b/torch/csrc/autograd/python_torch_functions.h @@ -2,15 +2,18 @@ #include - -namespace torch { namespace autograd { +namespace torch { +namespace autograd { extern PyObject* THPVariableFunctionsModule; // Wrapper converts a raised TypeError into returning NotImplemented // Used to implement binary arithmetic operators template -inline PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { +inline PyObject* TypeError_to_NotImplemented_( + PyObject* self, + PyObject* args, + PyObject* kwargs) { PyObject* ret = Func(self, args, kwargs); if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { PyErr_Clear(); @@ -22,4 +25,5 @@ inline PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, P void initTorchFunctions(); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index e35af63ccf2eb5..bcb7bb64672619 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -1,54 +1,54 @@ -#include -#include -#include #include #include #include +#include +#include +#include +#include +#include +#include #include #include #include #include +#include #include #include #include -#include -#include -#include -#include #include #include -#include #include +#include #include #include -using at::Tensor; +using at::ArrayRef; +using at::Backend; using at::Device; +using at::DeviceGuard; +using at::Dimname; +using at::DimnameList; +using at::Generator; +using at::IntArrayRef; using at::Layout; +using at::OptionalDeviceGuard; using at::Scalar; using at::ScalarType; -using at::Backend; -using at::OptionalDeviceGuard; -using at::DeviceGuard; -using at::TensorOptions; -using at::IntArrayRef; -using at::Generator; +using at::Tensor; using at::TensorList; -using at::Dimname; -using at::DimnameList; -using at::ArrayRef; +using at::TensorOptions; using torch::utils::check_out_type_matches; using namespace torch::autograd::utils; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) PyObject* THPVariableFunctionsModule = nullptr; - inline Tensor dispatch_arange(const Scalar& end, Tensor result) { pybind11::gil_scoped_release no_gil; return at::arange_out(result, end); @@ -60,30 +60,43 @@ inline Tensor dispatch_arange(const Scalar& end, const TensorOptions& options) { return torch::arange(end, options); } -inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { +inline Tensor dispatch_arange( + const Scalar& start, + const Scalar& end, + const Scalar& step, + Tensor result) { pybind11::gil_scoped_release no_gil; return at::arange_out(result, start, end, step); } -inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { +inline Tensor dispatch_arange( + const Scalar& start, + const Scalar& end, + const Scalar& step, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; return torch::arange(start, end, step, options); } -static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_arange( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); + static PythonArgParser parser( + { + "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, + /*traceable=*/true); ParsedArgs<9> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + if (r.has_torch_function()) { + return handle_torch_function( + r, args, kwargs, THPVariableFunctionsModule, "torch"); } if (r.idx == 0) { @@ -92,17 +105,24 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) c10::optional scalarType = r.scalartypeOptional(2); const auto options = TensorOptions() - .dtype(scalarType) - .device(r.device(4)) - .layout(r.layout(3)) - .requires_grad(r.toBool(6)) - .pinned_memory(r.toBool(5)); + .dtype(scalarType) + .device(r.device(4)) + .layout(r.layout(3)) + .requires_grad(r.toBool(6)) + .pinned_memory(r.toBool(5)); return wrap(dispatch_arange(end, options)); } else { - TORCH_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), - r.device(4), r.isNone(4)); - return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6))); + TORCH_CHECK( + !r.toBool(5), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches( + r.tensor(1), + r.scalartype(2), + r.isNone(2), + r.layout(3), + r.device(4), + r.isNone(4)); + return wrap(dispatch_arange(r.scalar(0), r.tensor(1)) + .set_requires_grad(r.toBool(6))); } } else if (r.idx == 1) { if (r.isNone(3)) { @@ -112,41 +132,59 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) c10::optional scalarType = r.scalartypeOptional(4); const auto options = TensorOptions() - .dtype(scalarType) - .device(r.device(6)) - .layout(r.layout(5)) - .requires_grad(r.toBool(8)) - .pinned_memory(r.toBool(7)); + .dtype(scalarType) + .device(r.device(6)) + .layout(r.layout(5)) + .requires_grad(r.toBool(8)) + .pinned_memory(r.toBool(7)); return wrap(dispatch_arange(start, end, step, options)); } else { - TORCH_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), - r.device(6), r.isNone(6)); - return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8))); + TORCH_CHECK( + !r.toBool(7), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches( + r.tensor(3), + r.scalartype(4), + r.isNone(4), + r.layout(5), + r.device(6), + r.isNone(6)); + return wrap( + dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)) + .set_requires_grad(r.toBool(8))); } } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { +inline Tensor dispatch_range( + const Scalar& start, + const Scalar& end, + const Scalar& step, + Tensor result) { pybind11::gil_scoped_release no_gil; OptionalDeviceGuard device_guard(device_of(result)); return at::range_out(result, start, end, step); } -inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { +inline Tensor dispatch_range( + const Scalar& start, + const Scalar& end, + const Scalar& step, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; DeviceGuard device_guard(options.device()); return torch::range(start, end, step, options); } -static PyObject * THPVariable_range(PyObject* self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_range( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ - "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", }); ParsedArgs<8> parsed_args; @@ -159,18 +197,27 @@ static PyObject * THPVariable_range(PyObject* self, PyObject* args, PyObject* kw "because its behavior is inconsistent with Python's range builtin. " "Instead, use torch.arange, which produces values in [start, end).", 1); - if (ret != 0) throw python_error(); + if (ret != 0) + throw python_error(); if (r.isNone(3)) { const auto options = TensorOptions() - .dtype(r.scalartype(4)) - .device(r.device(6)) - .layout(r.layout(5)) - .requires_grad(r.toBool(7)); - return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options)); + .dtype(r.scalartype(4)) + .device(r.device(6)) + .layout(r.layout(5)) + .requires_grad(r.toBool(7)); + return wrap( + dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options)); } else { - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), - r.layout(5), r.device(6), r.isNone(6)); - return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(7))); + check_out_type_matches( + r.tensor(3), + r.scalartype(4), + r.isNone(4), + r.layout(5), + r.device(6), + r.isNone(6)); + return wrap( + dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)) + .set_requires_grad(r.toBool(7))); } } Py_RETURN_NONE; @@ -204,65 +251,89 @@ inline Tensor dispatch_full( return at::full_out(result, size, fill_val); } -static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject* THPVariable_full( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); + static PythonArgParser parser( + { + "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, + /*traceable=*/true); // Acquires (common) arguments ParsedArgs<8> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + if (r.has_torch_function()) { + return handle_torch_function( + r, args, kwargs, THPVariableFunctionsModule, "torch"); } auto size = r.intlist(0); auto fill_val = r.scalar(1); const auto options = TensorOptions{} - .dtype(r.scalartypeOptional(3)) - .layout(r.layout(4)) - .device(r.device(5)) - .pinned_memory(r.toBool(6)); + .dtype(r.scalartypeOptional(3)) + .layout(r.layout(4)) + .device(r.device(5)) + .pinned_memory(r.toBool(6)); if (r.idx == 0) { // full if (r.isNone(2)) { - return wrap(dispatch_full(size, fill_val, options).set_requires_grad(r.toBool(7))); + return wrap(dispatch_full(size, fill_val, options) + .set_requires_grad(r.toBool(7))); } // full.out // Validates out tensor and other kwargs auto result = r.tensor(2); - TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4), - r.device(5), r.isNone(5)); - - return wrap(dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7))); + TORCH_CHECK( + !r.toBool(6), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches( + result, + r.scalartype(3), + r.isNone(3), + r.layout(4), + r.device(5), + r.isNone(5)); + + return wrap( + dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7))); } else if (r.idx == 1) { // full.names if (r.isNone(2)) { - return wrap(dispatch_full(size, fill_val, c10::nullopt, options).set_requires_grad(r.toBool(7))); + return wrap(dispatch_full(size, fill_val, c10::nullopt, options) + .set_requires_grad(r.toBool(7))); } // Converts from c10::optional to c10::optional auto raw_names = r.toDimnameListOptional(2); c10::optional names(*raw_names); - return wrap(dispatch_full(size, fill_val, names, options).set_requires_grad(r.toBool(7))); + return wrap(dispatch_full(size, fill_val, names, options) + .set_requires_grad(r.toBool(7))); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { +inline Tensor dispatch_randint( + int64_t high, + IntArrayRef size, + c10::optional generator, + Tensor result) { pybind11::gil_scoped_release no_gil; return at::randint_out(result, high, size, generator); } -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { +inline Tensor dispatch_randint( + int64_t high, + IntArrayRef size, + c10::optional generator, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; return torch::randint(high, size, generator, options); @@ -271,43 +342,69 @@ inline Tensor dispatch_randint(int64_t high, IntArrayRef size, Tensor result) { pybind11::gil_scoped_release no_gil; return at::randint_out(result, high, size); } -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, const TensorOptions & options) { +inline Tensor dispatch_randint( + int64_t high, + IntArrayRef size, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; return torch::randint(high, size, options); } -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { +inline Tensor dispatch_randint( + int64_t low, + int64_t high, + IntArrayRef size, + c10::optional generator, + Tensor result) { pybind11::gil_scoped_release no_gil; return at::randint_out(result, low, high, size, generator); } -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { +inline Tensor dispatch_randint( + int64_t low, + int64_t high, + IntArrayRef size, + c10::optional generator, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; return torch::randint(low, high, size, generator, options); } -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, Tensor result) { +inline Tensor dispatch_randint( + int64_t low, + int64_t high, + IntArrayRef size, + Tensor result) { pybind11::gil_scoped_release no_gil; return at::randint_out(result, low, high, size); } -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, const TensorOptions & options) { +inline Tensor dispatch_randint( + int64_t low, + int64_t high, + IntArrayRef size, + const TensorOptions& options) { torch::utils::maybe_initialize_cuda(options); pybind11::gil_scoped_release no_gil; return torch::randint(low, high, size, options); } -static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_randint( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - }, /*traceable=*/false); + static PythonArgParser parser( + { + "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + }, + /*traceable=*/false); ParsedArgs<9> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + if (r.has_torch_function()) { + return handle_torch_function( + r, args, kwargs, THPVariableFunctionsModule, "torch"); } if (r.idx == 0) { @@ -319,15 +416,22 @@ static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* auto dtype = r.scalartypeWithDefault(4, at::ScalarType::Long); auto device = r.device(6); const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(5)) - .requires_grad(r.toBool(7)); + .dtype(dtype) + .device(device) + .layout(r.layout(5)) + .requires_grad(r.toBool(7)); return wrap(dispatch_randint(high, size, generator, options)); } else { - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), - r.layout(5), r.device(6), r.isNone(6)); - return wrap(dispatch_randint(r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3)).set_requires_grad(r.toBool(7))); + check_out_type_matches( + r.tensor(3), + r.scalartype(4), + r.isNone(4), + r.layout(5), + r.device(6), + r.isNone(6)); + return wrap(dispatch_randint( + r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3)) + .set_requires_grad(r.toBool(7))); } } else if (r.idx == 1) { if (r.isNone(4)) { @@ -339,25 +443,38 @@ static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* auto dtype = r.scalartypeWithDefault(5, at::ScalarType::Long); auto device = r.device(7); const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(6)) - .requires_grad(r.toBool(8)); + .dtype(dtype) + .device(device) + .layout(r.layout(6)) + .requires_grad(r.toBool(8)); return wrap(dispatch_randint(low, high, size, generator, options)); } else { - check_out_type_matches(r.tensor(4), r.scalartype(5), r.isNone(5), - r.layout(6), r.device(7), r.isNone(7)); - return wrap(dispatch_randint(r.toInt64(0), r.toInt64(1), r.intlist(2), r.generator(3), r.tensor(4)).set_requires_grad(r.toBool(8))); + check_out_type_matches( + r.tensor(4), + r.scalartype(5), + r.isNone(5), + r.layout(6), + r.device(7), + r.isNone(7)); + return wrap(dispatch_randint( + r.toInt64(0), + r.toInt64(1), + r.intlist(2), + r.generator(3), + r.tensor(4)) + .set_requires_grad(r.toBool(8))); } } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -// implemented on python object to allow torch.as_tensor to be constructed with arbitrarily nested -// python objects - list, tuple, np array, scalar, etc. -static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ +// implemented on python object to allow torch.as_tensor to be constructed with +// arbitrarily nested python objects - list, tuple, np array, scalar, etc. +static PyObject* THPVariable_as_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)", @@ -377,81 +494,108 @@ static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject END_HANDLE_TH_ERRORS } -// implemented on python object here because PyObject currently not natively declarable -// See: ATen/native/README.md for more context -static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg) -{ +// implemented on python object here because PyObject currently not natively +// declarable See: ATen/native/README.md for more context +static PyObject* THPVariable_from_numpy(PyObject* module, PyObject* arg) { HANDLE_TH_ERRORS jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR); return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg)); END_HANDLE_TH_ERRORS } -static Tensor dispatch_nonzero(const Tensor & self) { +static Tensor dispatch_nonzero(const Tensor& self) { pybind11::gil_scoped_release no_gil; OptionalDeviceGuard device_guard(device_of(self)); return self.nonzero(); } -static Tensor dispatch_nonzero(const Tensor & self, Tensor out) { +static Tensor dispatch_nonzero(const Tensor& self, Tensor out) { pybind11::gil_scoped_release no_gil; OptionalDeviceGuard device_guard(device_of(self)); return at::nonzero_out(out, self); } -static std::vector dispatch_nonzero_numpy(const Tensor & self) { +static std::vector dispatch_nonzero_numpy(const Tensor& self) { pybind11::gil_scoped_release no_gil; OptionalDeviceGuard device_guard(device_of(self)); return self.nonzero_numpy(); } -static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs); - -#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \ -static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject* kwargs) \ -{ \ - HANDLE_TH_ERRORS \ - static PythonArgParser parser SIGNATURES ; \ - ParsedArgs parsed_args; \ - auto r = parser.parse(args, kwargs, parsed_args); \ - if (r.has_torch_function()) { \ - return handle_torch_function(r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \ - } \ - jit::tracer::warn("torch." # NAME, jit::tracer::WARN_CONSTRUCTOR); \ - return THPVariable_Wrap(torch::utils::NAME ## _ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), r)); \ - END_HANDLE_TH_ERRORS \ -} +static PyObject* THPVariable_nonzero( + PyObject* self, + PyObject* args, + PyObject* kwargs); + +#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \ + static PyObject* THPVariable_##NAME( \ + PyObject* self, PyObject* args, PyObject* kwargs) { \ + HANDLE_TH_ERRORS \ + static PythonArgParser parser SIGNATURES; \ + ParsedArgs parsed_args; \ + auto r = parser.parse(args, kwargs, parsed_args); \ + if (r.has_torch_function()) { \ + return handle_torch_function( \ + r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \ + } \ + jit::tracer::warn("torch." #NAME, jit::tracer::WARN_CONSTRUCTOR); \ + return THPVariable_Wrap(torch::utils::NAME##_ctor( \ + torch::tensors::get_default_dispatch_key(), \ + torch::tensors::get_default_scalar_type(), \ + r)); \ + END_HANDLE_TH_ERRORS \ + } -THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_compressed_tensor, 9, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + sparse_compressed_tensor, + 9, ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_csr_tensor, 9, + "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + sparse_csr_tensor, + 9, ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_csc_tensor, 9, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + sparse_csc_tensor, + 9, ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_bsr_tensor, 9, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + sparse_bsr_tensor, + 9, ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_bsc_tensor, 9, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + sparse_bsc_tensor, + 9, ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) - -THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_compressed_tensor_unsafe, 8, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + _sparse_compressed_tensor_unsafe, + 8, ({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csr_tensor_unsafe, 7, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + _sparse_csr_tensor_unsafe, + 7, ({"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csc_tensor_unsafe, 7, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + _sparse_csc_tensor_unsafe, + 7, ({"_sparse_csc_tensor_unsafe(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_bsr_tensor_unsafe, 7, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + _sparse_bsr_tensor_unsafe, + 7, ({"_sparse_bsr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"})) -THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_bsc_tensor_unsafe, 7, +THPVARIABLE_SPARSE_COMPRESSED_CTOR( + _sparse_bsc_tensor_unsafe, + 7, ({"_sparse_bsc_tensor_unsafe(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"})) -static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_sparse_coo_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", @@ -473,8 +617,10 @@ static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__sparse_coo_tensor_unsafe( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "_sparse_coo_tensor_unsafe(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", @@ -486,7 +632,8 @@ static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject return handle_torch_function( r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } - jit::tracer::warn("torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); + jit::tracer::warn( + "torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor( torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), @@ -494,10 +641,12 @@ static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject END_HANDLE_TH_ERRORS } -// implemented on python object to allow torch.tensor to be constructed with arbitrarily nested -// python objects - list, tuple, np array, scalar, etc. -static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ +// implemented on python object to allow torch.tensor to be constructed with +// arbitrarily nested python objects - list, tuple, np array, scalar, etc. +static PyObject* THPVariable_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)", @@ -518,12 +667,16 @@ static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* k END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_get_device( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "get_device(Tensor input)", - }, /*traceable=*/false); + static PythonArgParser parser( + { + "get_device(Tensor input)", + }, + /*traceable=*/false); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -535,12 +688,16 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_frombuffer( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)", - }, /*traceable=*/false); + static PythonArgParser parser( + { + "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)", + }, + /*traceable=*/false); ParsedArgs<5> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -563,12 +720,16 @@ static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObje END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_asarray(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_asarray( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)", - }, /*traceable=*/false); + static PythonArgParser parser( + { + "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)", + }, + /*traceable=*/false); ParsedArgs<5> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -586,108 +747,160 @@ static PyObject * THPVariable_asarray(PyObject* self_, PyObject* args, PyObject* END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs); +static PyObject* THPVariable_numel( + PyObject* self_, + PyObject* args, + PyObject* kwargs); // linspace -static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_linspace( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "linspace(Scalar start, Scalar end, int64_t steps, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); + static PythonArgParser parser( + { + "linspace(Scalar start, Scalar end, int64_t steps, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, + /*traceable=*/true); ParsedArgs<9> parsed_args; auto _r = parser.parse(nullptr, args, kwargs, parsed_args); - if(_r.has_torch_function()) { - return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + if (_r.has_torch_function()) { + return handle_torch_function( + _r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } if (_r.isNone(3)) { - // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? + // dtype=None, Layout? layout=None, Device? device=None, bool? + // pin_memory=None) -> Tensor // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) // This leads to problem in the operator argument checks, // when either `start` or `end` is complex and dtype is None const auto options = TensorOptions() - .dtype(_r.scalartypeOptional(4)) - .device(_r.device(6)) - .layout(_r.layoutOptional(5)) - .requires_grad(_r.toBool(8)) - .pinned_memory(_r.toBool(7)); + .dtype(_r.scalartypeOptional(4)) + .device(_r.device(6)) + .layout(_r.layoutOptional(5)) + .requires_grad(_r.toBool(8)) + .pinned_memory(_r.toBool(7)); torch::utils::maybe_initialize_cuda(options); - auto dispatch_linspace = [](Scalar start, Scalar end, int64_t steps, TensorOptions options) -> Tensor { + auto dispatch_linspace = [](Scalar start, + Scalar end, + int64_t steps, + TensorOptions options) -> Tensor { pybind11::gil_scoped_release no_gil; return torch::linspace(start, end, steps, options); }; - return wrap(dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64(2), options)); + return wrap( + dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64(2), options)); } else { - // aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!) - check_out_type_matches(_r.tensor(3), _r.scalartype(4), - _r.isNone(4), _r.layoutOptional(5), - _r.device(6), _r.isNone(6)); - - auto dispatch_linspace_out = [](Tensor out, Scalar start, Scalar end, int64_t steps) -> Tensor { + // aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, + // Tensor(a!) out) -> Tensor(a!) + check_out_type_matches( + _r.tensor(3), + _r.scalartype(4), + _r.isNone(4), + _r.layoutOptional(5), + _r.device(6), + _r.isNone(6)); + + auto dispatch_linspace_out = + [](Tensor out, Scalar start, Scalar end, int64_t steps) -> Tensor { pybind11::gil_scoped_release no_gil; return at::linspace_out(out, start, end, steps); }; - return wrap(dispatch_linspace_out(_r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64(2)).set_requires_grad(_r.toBool(8))); + return wrap(dispatch_linspace_out( + _r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64(2)) + .set_requires_grad(_r.toBool(8))); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } // logspace -static PyObject * THPVariable_logspace(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_logspace( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "logspace(Scalar start, Scalar end, int64_t steps, double base=10.0, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); + static PythonArgParser parser( + { + "logspace(Scalar start, Scalar end, int64_t steps, double base=10.0, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, + /*traceable=*/true); ParsedArgs<10> parsed_args; auto _r = parser.parse(nullptr, args, kwargs, parsed_args); - if(_r.has_torch_function()) { - return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + if (_r.has_torch_function()) { + return handle_torch_function( + _r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } if (_r.isNone(4)) { - // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, + // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? + // pin_memory=None) -> Tensor // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) // This leads to problem in the operator argument checks, // when either `start` or `end` is complex and dtype is None const auto options = TensorOptions() - .dtype(_r.scalartypeOptional(5)) - .device(_r.device(7)) - .layout(_r.layoutOptional(6)) - .requires_grad(_r.toBool(9)) - .pinned_memory(_r.toBool(8)); + .dtype(_r.scalartypeOptional(5)) + .device(_r.device(7)) + .layout(_r.layoutOptional(6)) + .requires_grad(_r.toBool(9)) + .pinned_memory(_r.toBool(8)); torch::utils::maybe_initialize_cuda(options); - auto dispatch_logspace = [](Scalar start, Scalar end, int64_t steps, double base, TensorOptions options) -> Tensor { + auto dispatch_logspace = [](Scalar start, + Scalar end, + int64_t steps, + double base, + TensorOptions options) -> Tensor { pybind11::gil_scoped_release no_gil; return torch::logspace(start, end, steps, base, options); }; - return wrap(dispatch_logspace(_r.scalar(0), _r.scalar(1), _r.toInt64(2), _r.toDouble(3), options)); + return wrap(dispatch_logspace( + _r.scalar(0), _r.scalar(1), _r.toInt64(2), _r.toDouble(3), options)); } else { - // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) - check_out_type_matches(_r.tensor(4), _r.scalartype(5), - _r.isNone(5), _r.layoutOptional(6), - _r.device(7), _r.isNone(7)); - - auto dispatch_logspace_out = [](Tensor out, Scalar start, Scalar end, int64_t steps, double base) -> Tensor { + // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, + // *, Tensor(a!) out) -> Tensor(a!) + check_out_type_matches( + _r.tensor(4), + _r.scalartype(5), + _r.isNone(5), + _r.layoutOptional(6), + _r.device(7), + _r.isNone(7)); + + auto dispatch_logspace_out = [](Tensor out, + Scalar start, + Scalar end, + int64_t steps, + double base) -> Tensor { pybind11::gil_scoped_release no_gil; return at::logspace_out(out, start, end, steps, base); }; - return wrap(dispatch_logspace_out(_r.tensor(4), _r.scalar(0), _r.scalar(1), _r.toInt64(2), _r.toDouble(3)).set_requires_grad(_r.toBool(9))); + return wrap(dispatch_logspace_out( + _r.tensor(4), + _r.scalar(0), + _r.scalar(1), + _r.toInt64(2), + _r.toDouble(3)) + .set_requires_grad(_r.toBool(9))); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__to_functional_tensor(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__to_functional_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({"_to_functional_tensor(Tensor t)"}, /*traceable=*/true); + static PythonArgParser parser( + {"_to_functional_tensor(Tensor t)"}, /*traceable=*/true); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -697,10 +910,13 @@ static PyObject * THPVariable__to_functional_tensor(PyObject *self, PyObject* ar END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__from_functional_tensor(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__from_functional_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({"_from_functional_tensor(Tensor t)"}, /*traceable=*/true); + static PythonArgParser parser( + {"_from_functional_tensor(Tensor t)"}, /*traceable=*/true); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -710,10 +926,13 @@ static PyObject * THPVariable__from_functional_tensor(PyObject *self, PyObject* END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__is_functional_tensor(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__is_functional_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({"_is_functional_tensor(Tensor t)"}, /*traceable=*/true); + static PythonArgParser parser( + {"_is_functional_tensor(Tensor t)"}, /*traceable=*/true); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -726,8 +945,10 @@ static PyObject * THPVariable__is_functional_tensor(PyObject *self, PyObject* ar END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__sync(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__sync( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({"_sync(Tensor t)"}, /*traceable=*/true); @@ -740,30 +961,40 @@ static PyObject * THPVariable__sync(PyObject *self, PyObject* args, PyObject* kw END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__enable_functionalization(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__enable_functionalization( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({"_enable_functionalization(*, bool reapply_views=False)"}, /*traceable=*/true); + static PythonArgParser parser( + {"_enable_functionalization(*, bool reapply_views=False)"}, + /*traceable=*/true); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); const auto reapply_views = r.toBool(0); if (c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Functionalize)) { - TORCH_INTERNAL_ASSERT(false, "multiple layers of mode-style functionalization nesting is not" - " currently supported, outside of the functionalize() transform"); + TORCH_INTERNAL_ASSERT( + false, + "multiple layers of mode-style functionalization nesting is not" + " currently supported, outside of the functionalize() transform"); } - c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Functionalize, true); + c10::impl::tls_set_dispatch_key_included( + at::DispatchKey::Functionalize, true); if (reapply_views) { - at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true); + at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__disable_functionalization(PyObject *self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable__disable_functionalization( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Functionalize, false); + c10::impl::tls_set_dispatch_key_included( + at::DispatchKey::Functionalize, false); at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -774,61 +1005,156 @@ static PyObject * THPVariable__disable_functionalization(PyObject *self, PyObjec // being registered through native_functions.yaml, and be tagged cpp / JIT // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef torch_functions_manual[] = { - {"arange", castPyCFunctionWithKeywords(THPVariable_arange), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"asarray", castPyCFunctionWithKeywords(THPVariable_asarray), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr}, - {"frombuffer", castPyCFunctionWithKeywords(THPVariable_frombuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"linspace", castPyCFunctionWithKeywords(THPVariable_linspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"logspace", castPyCFunctionWithKeywords(THPVariable_logspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_is_functional_tensor", castPyCFunctionWithKeywords(THPVariable__is_functional_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_to_functional_tensor", castPyCFunctionWithKeywords(THPVariable__to_functional_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_from_functional_tensor", castPyCFunctionWithKeywords(THPVariable__from_functional_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sync", castPyCFunctionWithKeywords(THPVariable__sync), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_enable_functionalization", castPyCFunctionWithKeywords(THPVariable__enable_functionalization), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_disable_functionalization", castPyCFunctionWithKeywords(THPVariable__disable_functionalization), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_compressed_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_compressed_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_csc_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_bsr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"sparse_bsc_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_csc_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csc_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_bsr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_bsr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"_sparse_bsc_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_bsc_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, - {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"arange", + castPyCFunctionWithKeywords(THPVariable_arange), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"asarray", + castPyCFunctionWithKeywords(THPVariable_asarray), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"as_tensor", + castPyCFunctionWithKeywords(THPVariable_as_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr}, + {"frombuffer", + castPyCFunctionWithKeywords(THPVariable_frombuffer), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"full", + castPyCFunctionWithKeywords(THPVariable_full), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"linspace", + castPyCFunctionWithKeywords(THPVariable_linspace), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"logspace", + castPyCFunctionWithKeywords(THPVariable_logspace), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_is_functional_tensor", + castPyCFunctionWithKeywords(THPVariable__is_functional_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_to_functional_tensor", + castPyCFunctionWithKeywords(THPVariable__to_functional_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_from_functional_tensor", + castPyCFunctionWithKeywords(THPVariable__from_functional_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sync", + castPyCFunctionWithKeywords(THPVariable__sync), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_enable_functionalization", + castPyCFunctionWithKeywords(THPVariable__enable_functionalization), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_disable_functionalization", + castPyCFunctionWithKeywords(THPVariable__disable_functionalization), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"nonzero", + castPyCFunctionWithKeywords(THPVariable_nonzero), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"randint", + castPyCFunctionWithKeywords(THPVariable_randint), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"range", + castPyCFunctionWithKeywords(THPVariable_range), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_coo_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_coo_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_compressed_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_compressed_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_csr_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_csc_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_csc_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_bsr_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_bsr_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"sparse_bsc_tensor", + castPyCFunctionWithKeywords(THPVariable_sparse_bsc_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_csr_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_csc_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_csc_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_bsr_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_bsr_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"_sparse_bsc_tensor_unsafe", + castPyCFunctionWithKeywords(THPVariable__sparse_bsc_tensor_unsafe), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"tensor", + castPyCFunctionWithKeywords(THPVariable_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"get_device", + castPyCFunctionWithKeywords(THPVariable_get_device), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, + {"numel", + castPyCFunctionWithKeywords(THPVariable_numel), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, }; -static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_nonzero( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ - "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)", + "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)", }); ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); - if(r.has_torch_function()){ - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + if (r.has_torch_function()) { + return handle_torch_function( + r, args, kwargs, THPVariableFunctionsModule, "torch"); } const auto as_tuple = r.toBool(1); const auto has_out = !r.isNone(2); if (as_tuple) { - TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True"); + TORCH_CHECK( + !has_out, + "nonzero does not support the out kwarg when as_tuple is True"); return wrap(dispatch_nonzero_numpy(r.tensor(0))); } @@ -841,18 +1167,23 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) -{ +static PyObject* THPVariable_numel( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - static PythonArgParser parser({ - "numel(Tensor input)", - }, /*traceable=*/false); + static PythonArgParser parser( + { + "numel(Tensor input)", + }, + /*traceable=*/false); ParsedArgs<1> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); - if(r.has_torch_function()){ - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + if (r.has_torch_function()) { + return handle_torch_function( + r, args, kwargs, THPVariableFunctionsModule, "torch"); } if (r.idx == 0) { @@ -863,35 +1194,41 @@ static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* k } // Sharded function definitions -void gatherTorchFunctions_0(std::vector &torch_functions); -void gatherTorchFunctions_1(std::vector &torch_functions); -void gatherTorchFunctions_2(std::vector &torch_functions); - -void gatherTorchFunctions(std::vector &torch_functions) { - constexpr size_t num_functions = sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]); - torch_functions.assign(torch_functions_manual, - torch_functions_manual + num_functions); - // NOTE: Must be synced with num_shards in tools/autograd/gen_python_functions.py +void gatherTorchFunctions_0(std::vector& torch_functions); +void gatherTorchFunctions_1(std::vector& torch_functions); +void gatherTorchFunctions_2(std::vector& torch_functions); + +void gatherTorchFunctions(std::vector& torch_functions) { + constexpr size_t num_functions = + sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]); + torch_functions.assign( + torch_functions_manual, torch_functions_manual + num_functions); + // NOTE: Must be synced with num_shards in + // tools/autograd/gen_python_functions.py gatherTorchFunctions_0(torch_functions); gatherTorchFunctions_1(torch_functions); gatherTorchFunctions_2(torch_functions); - static std::array, 4> aliases{{ - // Canonical function, alias name - {"sspaddmm", "saddmm"}, - {"mm", "spmm"}, - {"mm", "dsmm"}, - {"hspmm", "hsmm"} - }}; + static std::array, 4> aliases{ + {// Canonical function, alias name + {"sspaddmm", "saddmm"}, + {"mm", "spmm"}, + {"mm", "dsmm"}, + {"hspmm", "hsmm"}}}; for (const auto& alias : aliases) { - auto it = std::find_if(torch_functions.begin(), torch_functions.end(), - [&](const PyMethodDef& def) { - return strcmp(def.ml_name, alias.first) == 0; - }); + auto it = std::find_if( + torch_functions.begin(), + torch_functions.end(), + [&](const PyMethodDef& def) { + return strcmp(def.ml_name, alias.first) == 0; + }); TORCH_INTERNAL_ASSERT( it != torch_functions.end(), - "Failed to create function alias from ", alias.first, " to ", alias.second); + "Failed to create function alias from ", + alias.first, + " to ", + alias.second); PyMethodDef alias_def = *it; alias_def.ml_name = alias.second; @@ -903,47 +1240,48 @@ void gatherTorchFunctions(std::vector &torch_functions) { } static PyTypeObject THPVariableFunctions = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._VariableFunctionsClass", /* tp_name */ - 0, /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr /* tp_new */ + PyVarObject_HEAD_INIT( + nullptr, + 0) "torch._C._VariableFunctionsClass", /* tp_name */ + 0, /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr /* tp_new */ }; -void initTorchFunctions(PyObject *module) { +void initTorchFunctions(PyObject* module) { static std::vector torch_functions; gatherTorchFunctions(torch_functions); THPVariableFunctions.tp_methods = torch_functions.data(); @@ -955,16 +1293,21 @@ void initTorchFunctions(PyObject *module) { // Steals Py_INCREF(&THPVariableFunctions); - if (PyModule_AddObject(module, "_VariableFunctionsClass", - reinterpret_cast(&THPVariableFunctions)) < 0) { + if (PyModule_AddObject( + module, + "_VariableFunctionsClass", + reinterpret_cast(&THPVariableFunctions)) < 0) { throw python_error(); } // PyType_GenericNew returns a new reference - THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None); + THPVariableFunctionsModule = + PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None); // PyModule_AddObject steals a reference - if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) { + if (PyModule_AddObject( + module, "_VariableFunctions", THPVariableFunctionsModule) < 0) { throw python_error(); } } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index e7d6bd83c53dd4..fdbba226c27171 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -22,21 +22,19 @@ #include #include #include +#include #include #include +#include #include #include -#include -#include -#include -#include #include +#include #include -#include +#include -#include #include - +#include #include @@ -47,14 +45,16 @@ #include #include - - using namespace at; using namespace torch; using namespace torch::autograd; -std::pair parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector& arguments) { - TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs"); +std::pair parseIValuesToPyArgsKwargs( + const c10::OperatorHandle& op, + const std::vector& arguments) { + TORCH_CHECK( + PyGILState_Check(), + "GIL must be held before you call parseIValuesToPyArgsKwargs"); const auto& schema = op.schema(); py::dict kwargs; // About all the pointers: @@ -99,28 +99,36 @@ std::pair parseIValuesToPyArgsKwargs(const c10::OperatorHa } } - auto args = py::reinterpret_steal(PyTuple_New(positional_default_start)); + auto args = + py::reinterpret_steal(PyTuple_New(positional_default_start)); auto schemaAwareToPyObject = [&](int64_t idx) -> py::object { const auto& arg = schema.arguments()[idx]; auto match = [&](c10::TypeKind kind) { const auto& t = arg.real_type(); - if (t->kind() == kind) return true; + if (t->kind() == kind) + return true; if (auto opt_t = t->cast()) { - if (opt_t->getElementType()->kind() == kind) return true; + if (opt_t->getElementType()->kind() == kind) + return true; } return false; }; if (arguments[idx].isNone()) { return py::none(); } else if (match(c10::ScalarTypeType::Kind)) { - auto* obj = getTHPDtype(static_cast(arguments[idx].toInt())); - return py::reinterpret_borrow(reinterpret_cast(obj)); + auto* obj = + getTHPDtype(static_cast(arguments[idx].toInt())); + return py::reinterpret_borrow( + reinterpret_cast(obj)); } else if (match(c10::LayoutType::Kind)) { - auto* obj = getTHPLayout(static_cast(arguments[idx].toInt())); - return py::reinterpret_borrow(reinterpret_cast(obj)); + auto* obj = + getTHPLayout(static_cast(arguments[idx].toInt())); + return py::reinterpret_borrow( + reinterpret_cast(obj)); } else if (match(c10::MemoryFormatType::Kind)) { - return torch::utils::getTHPMemoryFormat(static_cast(arguments[idx].toInt())); + return torch::utils::getTHPMemoryFormat( + static_cast(arguments[idx].toInt())); } else { return torch::jit::toPyObject(arguments[idx]); } @@ -128,13 +136,15 @@ std::pair parseIValuesToPyArgsKwargs(const c10::OperatorHa // Populate positional arguments for (const auto idx : c10::irange(positional_default_start)) { - PyTuple_SET_ITEM(args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr()); + PyTuple_SET_ITEM( + args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr()); } // Populate keyword arguments for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) { // But don't populate default keyword arguments - if (is_default(idx)) continue; + if (is_default(idx)) + continue; const auto& arg = schema.arguments()[idx]; kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx); } @@ -146,19 +156,28 @@ void pushPyOutToStack( torch::jit::Stack* stack, py::object out, const char* msg) { - TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call pushPyOutToStack"); + TORCH_CHECK( + PyGILState_Check(), "GIL must be held before you call pushPyOutToStack"); auto schema_returns = op.schema().returns(); const auto num_returns = schema_returns.size(); if (num_returns == 0) { // Check that we got a None return from Python. Anything else is an error. - TORCH_CHECK(out.is(py::none()), "Expected ", msg, " for ", op.operator_name(), - " to return None but it returned something else instead."); + TORCH_CHECK( + out.is(py::none()), + "Expected ", + msg, + " for ", + op.operator_name(), + " to return None but it returned something else instead."); } else if (num_returns == 1) { - torch::jit::push(stack, torch::jit::toIValue(out.ptr(), schema_returns[0].type())); + torch::jit::push( + stack, torch::jit::toIValue(out.ptr(), schema_returns[0].type())); } else { auto outs = py::cast(out); for (const auto idx : c10::irange(outs.size())) { - torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), schema_returns[idx].type())); + torch::jit::push( + stack, + torch::jit::toIValue(outs[idx].ptr(), schema_returns[idx].type())); } } } @@ -179,7 +198,10 @@ std::string concrete_name_fn(const c10::impl::PyInterpreter* self) { // One alternative to this is using PyObject_IsInstance // to get at this information. However, we don't want to risk an incorrect // `__instancecheck__` changing the semantics here. -void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj, bool is_tensor) { +void concrete_decref_fn( + const c10::impl::PyInterpreter* self, + PyObject* pyobj, + bool is_tensor) { // Leak the pyobj if not initialized. This can happen if we are running // exit handlers that are destructing tensors with residual (owned) // PyObjects stored in them. @@ -189,7 +211,8 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj, b pybind11::gil_scoped_acquire gil; // Two possibilities: // 1. We are decref-ing a tensor. Then we must be careful about - // PyObject resurrection (this only applies to Tensors, see THPVariable_clear). + // PyObject resurrection (this only applies to Tensors, see + // THPVariable_clear). // 2. We are decref-ing some other Python object. We don't do // PyObject resurrection on non-Tensors, so we just carry on as usual if (is_tensor && Py_REFCNT(pyobj) > 1) { @@ -199,26 +222,35 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj, b // so that it fails on subsequent uses. Don't raise an error here; // you're probably in a destructor. TORCH_WARN( - "Deallocating Tensor that still has live PyObject references. " - "This probably happened because you took out a weak reference to " - "Tensor and didn't call _fix_weakref() after dereferencing it. " - "Subsequent accesses to this tensor via the PyObject will now fail." - ); + "Deallocating Tensor that still has live PyObject references. " + "This probably happened because you took out a weak reference to " + "Tensor and didn't call _fix_weakref() after dereferencing it. " + "Subsequent accesses to this tensor via the PyObject will now fail."); ((THPVariable*)pyobj)->cdata = MaybeOwned(); } Py_DECREF(pyobj); }; -c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); +c10::intrusive_ptr concrete_detach_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); void concrete_dispatch_fn( const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack, const std::shared_ptr& type); -bool concrete_is_contiguous_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); -c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); -int64_t concrete_dim_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); -c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); +bool concrete_is_contiguous_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); +c10::Device concrete_device_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); +int64_t concrete_dim_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); +c10::IntArrayRef concrete_strides_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self); class PyInterpreterHolder { public: @@ -267,9 +299,9 @@ c10::impl::PyInterpreter* getPyInterpreter() { return self_interpreter.get(); } -PyObject *THPVariableClass = nullptr; +PyObject* THPVariableClass = nullptr; -PyObject *ParameterClass = nullptr; +PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, @@ -281,26 +313,31 @@ static const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; -static bool check_has_torch_dispatch(PyObject *obj) { - PyTypeObject *tp = Py_TYPE(obj); +static bool check_has_torch_dispatch(PyObject* obj) { + PyTypeObject* tp = Py_TYPE(obj); if (THPVariable_CheckTypeExact(tp)) { return false; } py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__"); - return (attr.ptr() != nullptr && - attr.ptr() != torch::disabled_torch_dispatch_impl() - ); + return ( + attr.ptr() != nullptr && + attr.ptr() != torch::disabled_torch_dispatch_impl()); } // NOLINTNEXTLINE -static PyObject* device_to_py_class_ [static_cast(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; +static PyObject* device_to_py_class_[static_cast( + c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; -void registerPythonTensorClass(const std::string& device, PyObject* python_tensor_class) { +void registerPythonTensorClass( + const std::string& device, + PyObject* python_tensor_class) { c10::Device dev(device); - TORCH_CHECK(dev.type() == kXLA, "Only the python class for XLA can be overriden"); + TORCH_CHECK( + dev.type() == kXLA, "Only the python class for XLA can be overriden"); if (device_to_py_class_[static_cast(dev.type())] != nullptr) { - TORCH_WARN("Overriding a previously registered python class for ", dev.str()); + TORCH_WARN( + "Overriding a previously registered python class for ", dev.str()); } device_to_py_class_[static_cast(dev.type())] = python_tensor_class; @@ -311,8 +348,7 @@ static PyObject* getPythonTensorClass(c10::Device d) { } // TODO: Make this take Variable by const reference -PyObject * THPVariable_Wrap(at::TensorBase var) -{ +PyObject* THPVariable_Wrap(at::TensorBase var) { if (!var.defined()) { Py_RETURN_NONE; } @@ -358,12 +394,11 @@ PyObject * THPVariable_Wrap(at::TensorBase var) if (C10_LIKELY(var.device().type() != c10::kXLA)) { return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, std::move(var), status); + (PyTypeObject*)THPVariableClass, std::move(var), status); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar( - (PyTypeObject*)clazz, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status); } return THPVariable_NewWithVar( @@ -426,7 +461,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) { // NB: this will overreport _Py_RefTotal but based on inspection of object.c // there is no way to avoid this #ifdef Py_TRACE_REFS - _Py_AddToAllObjects(reinterpret_cast(self), 1); + _Py_AddToAllObjects(reinterpret_cast(self), 1); #endif Py_INCREF(self); @@ -442,7 +477,6 @@ static bool THPVariable_tryResurrect(THPVariable* self) { return true; } - static int THPVariable_clear(THPVariable* self) { // Is it OK for an object to still be live after running // tp_clear? Yes. When Python is breaking reference cycles, it can't assume @@ -457,7 +491,8 @@ static int THPVariable_clear(THPVariable* self) { // resurrection). // 2. The PyObject is part of a reference cycle. This case should not actually - // be possible, due to the logic in our tp_traverse (THPVariable_subclass_traverse). + // be possible, due to the logic in our tp_traverse + // (THPVariable_subclass_traverse). // In fact, resurrecting here breaks the invariant that "C++ owns Python only // when PyObject's refcount would otherwise be 0". Most immediately, as we're @@ -527,8 +562,10 @@ static int THPVariable_clear(THPVariable* self) { return 0; } - -PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs); +PyObject* THPVariable_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs); static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { const auto& var = THPVariable_Unpack(self); @@ -537,17 +574,21 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { } // Instantiates a subclass of self with the same data. -static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObject* kwargs) { +static PyObject* THPVariable_as_subclass( + PyObject* _self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS const auto& self = THPVariable_Unpack(_self); static PythonArgParser parser({ - "as_subclass(PyObject* cls)", + "as_subclass(PyObject* cls)", }); ParsedArgs<1> parsed_args{}; auto r = parser.parse(_self, args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); if (!PyType_Check(cls)) { - throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name); + throw torch::TypeError( + "cls must be a type (got %s)", Py_TYPE(cls)->tp_name); } return THPVariable_NewWithVar( (PyTypeObject*)cls, @@ -556,26 +597,30 @@ static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObje END_HANDLE_TH_ERRORS } -static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, PyObject* kwargs) { +static PyObject* THPVariable_make_subclass( + PyObject* _ignored, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ - "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)", + "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)", }); ParsedArgs<5> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); if (!PyType_Check(cls)) { - throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name); + throw torch::TypeError( + "cls must be a type (got %s)", Py_TYPE(cls)->tp_name); } auto data = r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED) - // We set `data`'s `allow_tensor_metadata_change` to true here, because we want to - // allow the following use case for backward compatibility: + // We set `data`'s `allow_tensor_metadata_change` to true here, because we + // want to allow the following use case for backward compatibility: // // ```python // rnn = torch.nn.RNN(100, 100, 2) - // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights, ...)`, - // # which changes storage of `rnn`'s weights in-place + // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights, + // ...)`, # which changes storage of `rnn`'s weights in-place // rnn.flatten_parameters() // ``` data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true); @@ -595,7 +640,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P END_HANDLE_TH_ERRORS } -static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, PyObject* kwargs) { +static PyObject* THPVariable_make_wrapper_subclass( + PyObject*, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS // NB: pin_memory doesn't actually do anything // TODO: strides variant? @@ -609,7 +657,11 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE(PyType_Check(cls), "cls must be a type (got ", Py_TYPE(cls)->tp_name, ")"); + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); // This is an important safety check; without it, the default behavior will be // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch @@ -619,30 +671,35 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py // dispatch and then go again, triggering segfault. TBH I'm thinking I want // to delete this function entirely py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__"); - TORCH_CHECK_TYPE(attr.ptr() != nullptr && attr.ptr() != torch::disabled_torch_dispatch_impl() -, - ((PyTypeObject*)cls)->tp_name, " must define __torch_dispatch__"); + TORCH_CHECK_TYPE( + attr.ptr() != nullptr && + attr.ptr() != torch::disabled_torch_dispatch_impl(), + ((PyTypeObject*)cls)->tp_name, + " must define __torch_dispatch__"); const auto options = TensorOptions() - .dtype(r.scalartype(5)) - .device(r.device(7)) - .layout(r.layoutOptional(6)) - // NB: long standing issue, requires_grad is not respected here; you - // have to set it post facto, see https://github.com/pytorch/pytorch/issues/26428 - // .requires_grad(r.toBool(7)) - .pinned_memory(r.toBool(8)); + .dtype(r.scalartype(5)) + .device(r.device(7)) + .layout(r.layoutOptional(6)) + // NB: long standing issue, requires_grad is not + // respected here; you have to set it post facto, see + // https://github.com/pytorch/pytorch/issues/26428 + // .requires_grad(r.toBool(7)) + .pinned_memory(r.toBool(8)); // don't bother releasing GIL here, as we are not allocating any nontrivial // data // TODO: for_blob produces non-resizable tensors, we might want this to be // resizable (have to define a custom allocator in that case) - auto data = at::for_blob(nullptr, r.intlist(1)) - .strides(r.intlistOptional(2)) - .storage_offset(r.toInt64Optional(3)) - .context(nullptr, [](void *ctx) {}) - .target_device(options.device()) // TODO: this shouldn't be necessary if it came from options - .options(options) - .make_tensor(); + auto data = + at::for_blob(nullptr, r.intlist(1)) + .strides(r.intlistOptional(2)) + .storage_offset(r.toInt64Optional(3)) + .context(nullptr, [](void* ctx) {}) + .target_device(options.device()) // TODO: this shouldn't be necessary + // if it came from options + .options(options) + .make_tensor(); data.set_requires_grad(r.toBool(9)); const auto sizes_strides_policy = r.stringViewOptional(10); @@ -661,21 +718,20 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py END_HANDLE_TH_ERRORS } -typedef PyObject *(*getter)(PyObject *, void *); -typedef int (*setter)(PyObject *, PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); +typedef int (*setter)(PyObject*, PyObject*, void*); -PyObject *THPVariable_get_python_dispatch(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_python_dispatch(THPVariable* self, void* unused) { HANDLE_TH_ERRORS const auto& var = THPVariable_Unpack(self); - return torch::autograd::utils::wrap(var.unsafeGetTensorImpl()->is_python_dispatch()); + return torch::autograd::utils::wrap( + var.unsafeGetTensorImpl()->is_python_dispatch()); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_T(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_T(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "T"); } const auto& var = THPVariable_Unpack(self); @@ -683,10 +739,9 @@ PyObject *THPVariable_get_T(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_H(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_H(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "H"); } const auto& var = THPVariable_Unpack(self); @@ -694,10 +749,9 @@ PyObject *THPVariable_get_H(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_mT(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_mT(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "mT"); } const auto& var = THPVariable_Unpack(self); @@ -705,10 +759,9 @@ PyObject *THPVariable_get_mT(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_mH(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_mH(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "mH"); } const auto& var = THPVariable_Unpack(self); @@ -716,10 +769,9 @@ PyObject *THPVariable_get_mH(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_cdata(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_cdata"); } const auto& var = THPVariable_Unpack(self); @@ -727,10 +779,9 @@ PyObject *THPVariable_get_cdata(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_version(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_version(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_version"); } const auto& var = THPVariable_Unpack(self); @@ -738,10 +789,9 @@ PyObject *THPVariable_get_version(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_grad_fn(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "grad_fn"); } const auto& var = THPVariable_Unpack(self); @@ -752,33 +802,34 @@ PyObject *THPVariable_get_grad_fn(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj, void *unused) -{ +static int THPVariable_set_grad_fn( + THPVariable* self, + PyObject* obj, + void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "_grad_fn", obj); } - THPUtils_assertRet(-1, obj, "Deletion of _grad_fn not allowed. Detach tensor instead!"); + THPUtils_assertRet( + -1, obj, "Deletion of _grad_fn not allowed. Detach tensor instead!"); THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None"); THPVariable_Unpack(self).detach_(); return 0; END_HANDLE_TH_ERRORS_RET(-1) } -static PyObject *THPVariable_is_leaf(THPVariable *self, void *unused) -{ +static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_leaf"); } return PyBool_FromLong(!THPVariable_Unpack(self).grad_fn()); END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_get_data(THPVariable *self, void *unused) -{ +static PyObject* THPVariable_get_data(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "data"); } const auto& var = THPVariable_Unpack(self).variable_data(); @@ -786,15 +837,16 @@ static PyObject * THPVariable_get_data(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -int THPVariable_set_data(THPVariable *self, PyObject *data, void *unused) -{ +int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "data", data); } - THPUtils_assertRet(-1, data, "Deleting tensor data is not allowed. Delete tensor instead!"); + THPUtils_assertRet( + -1, data, "Deleting tensor data is not allowed. Delete tensor instead!"); if (!THPVariable_Check(data)) { - throw torch::TypeError("Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); + throw torch::TypeError( + "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); } THPVariable_Unpack(self).set_data(THPVariable_Unpack(data)); @@ -802,20 +854,18 @@ int THPVariable_set_data(THPVariable *self, PyObject *data, void *unused) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_grad(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_grad(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "grad"); } return THPVariable_Wrap(THPVariable_Unpack(self).grad()); END_HANDLE_TH_ERRORS } -int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) -{ +int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "grad", py_grad); } const auto& var = THPVariable_Unpack(self); @@ -824,22 +874,32 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) return 0; } - TORCH_CHECK_TYPE(THPVariable_Check(py_grad), - "assigned grad expected to be a Tensor or None but got grad of type", THPUtils_typename(py_grad)); - THPUtils_assertRet(-1, self != (THPVariable*)py_grad, + TORCH_CHECK_TYPE( + THPVariable_Check(py_grad), + "assigned grad expected to be a Tensor or None but got grad of type", + THPUtils_typename(py_grad)); + THPUtils_assertRet( + -1, + self != (THPVariable*)py_grad, "can't assign Variable as its own grad"); const auto& grad = THPVariable_Unpack(py_grad); - bool gradIsSparse = (var.dtype() == grad.dtype() && - var.device().type() == grad.device().type() && - grad.layout() == kSparse); - THPUtils_assertRet(-1, grad.options().type_equal(var.options()) || gradIsSparse, + bool gradIsSparse = + (var.dtype() == grad.dtype() && + var.device().type() == grad.device().type() && grad.layout() == kSparse); + THPUtils_assertRet( + -1, + grad.options().type_equal(var.options()) || gradIsSparse, "assigned grad has data of a different type"); if (var.is_cuda()) { - THPUtils_assertRet(-1, grad.get_device() == var.get_device(), + THPUtils_assertRet( + -1, + grad.get_device() == var.get_device(), "assigned grad has data located on a different device"); } - THPUtils_assertRet(-1, grad.sizes().equals(var.sizes()), + THPUtils_assertRet( + -1, + grad.sizes().equals(var.sizes()), "assigned grad has data of a different size"); var.mutable_grad() = grad; @@ -847,49 +907,48 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_volatile(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_volatile(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "volatile"); } const char* msg = "volatile was removed (Variable.volatile is always False)"; auto r = PyErr_WarnEx(PyExc_UserWarning, msg, 1); - if (r != 0) throw python_error(); + if (r != 0) + throw python_error(); Py_RETURN_FALSE; END_HANDLE_TH_ERRORS } -int THPVariable_set_volatile(THPVariable *self, PyObject *obj, void *unused) -{ +int THPVariable_set_volatile(THPVariable* self, PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "volatile", obj); } auto r = PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); - if (r != 0) throw python_error(); + if (r != 0) + throw python_error(); return 0; END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_output_nr(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_output_nr(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "output_nr"); } - const auto output_nr = static_cast(THPVariable_Unpack(self).output_nr()); + const auto output_nr = + static_cast(THPVariable_Unpack(self).output_nr()); return PyInt_FromLong(output_nr); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_requires_grad(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "requires_grad"); } - if(THPVariable_Unpack(self).requires_grad()) { + if (THPVariable_Unpack(self).requires_grad()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -897,13 +956,12 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_retains_grad(THPVariable *self, void *unused) -{ +PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "retains_grad"); } - if(THPVariable_Unpack(self).retains_grad()) { + if (THPVariable_Unpack(self).retains_grad()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -911,18 +969,16 @@ PyObject *THPVariable_retains_grad(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_ndim(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "ndim"); } return PyInt_FromLong(THPVariable_Unpack(self).dim()); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_names(PyObject *self, void *unused) -{ +PyObject* THPVariable_get_names(PyObject* self, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { return handle_torch_function_getter((THPVariable*)self, "names"); @@ -932,7 +988,8 @@ PyObject *THPVariable_get_names(PyObject *self, void *unused) const auto& tensor = THPVariable_Unpack(self); size_t size = tensor.dim(); THPObjectPtr tuple(PyTuple_New(size)); - if (!tuple) throw python_error(); + if (!tuple) + throw python_error(); const auto dimnames = tensor.names(); for (const auto i : c10::irange(size)) { @@ -945,12 +1002,14 @@ PyObject *THPVariable_get_names(PyObject *self, void *unused) // the refcount. // Sources: // - https://docs.python.org/3/c-api/tuple.html#c.PyTuple_SetItem - // - https://stackoverflow.com/questions/16400600/how-to-return-a-tuple-containing-a-none-value-from-the-c-api + // - + // https://stackoverflow.com/questions/16400600/how-to-return-a-tuple-containing-a-none-value-from-the-c-api Py_INCREF(Py_None); str = Py_None; } else { str = THPUtils_packString(dimnames[i].symbol().toUnqualString()); - if (!str) throw python_error(); + if (!str) + throw python_error(); } PyTuple_SET_ITEM(tuple.get(), i, str); } @@ -958,7 +1017,7 @@ PyObject *THPVariable_get_names(PyObject *self, void *unused) END_HANDLE_TH_ERRORS } -int THPVariable_set_names(PyObject *self, PyObject *names, void *unused) { +int THPVariable_set_names(PyObject* self, PyObject* names, void* unused) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { return handle_torch_function_setter((THPVariable*)self, "names", names); @@ -967,7 +1026,8 @@ int THPVariable_set_names(PyObject *self, PyObject *names, void *unused) { if (names == Py_None) { at::internal_set_names_inplace(var, at::nullopt); } else { - THPUtils_assertRet(-1, + THPUtils_assertRet( + -1, THPUtils_checkDimnameList(names), "names must either be None or a tuple of dim names"); at::internal_set_names_inplace(var, torch::parseDimnameList(names)); @@ -976,21 +1036,27 @@ int THPVariable_set_names(PyObject *self, PyObject *names, void *unused) { END_HANDLE_TH_ERRORS_RET(-1) } -int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused) -{ +int THPVariable_set_requires_grad( + THPVariable* self, + PyObject* obj, + void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "requires_grad", obj); } - THPUtils_assertRet(-1, obj && PyBool_Check(obj), "requires_grad must be a bool"); + THPUtils_assertRet( + -1, obj && PyBool_Check(obj), "requires_grad must be a bool"); const auto& var = THPVariable_Unpack(self); auto requires_grad = (obj == Py_True); if (!var.is_leaf()) { - THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str()); + THPUtils_setError( + autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str()); return -1; } - if (requires_grad && !isDifferentiableType(at::typeMetaToScalarType((var.dtype())))) { - THPUtils_setError("only Tensors of floating point and complex dtype can require gradients"); + if (requires_grad && + !isDifferentiableType(at::typeMetaToScalarType((var.dtype())))) { + THPUtils_setError( + "only Tensors of floating point and complex dtype can require gradients"); return -1; } var.set_requires_grad(requires_grad); @@ -998,9 +1064,8 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_name(THPVariable* self, void *unused) -{ - if (check_has_torch_function((PyObject *)self)) { +PyObject* THPVariable_get_name(THPVariable* self, void* unused) { + if (check_has_torch_function((PyObject*)self)) { HANDLE_TH_ERRORS return handle_torch_function_getter(self, "name"); END_HANDLE_TH_ERRORS @@ -1011,10 +1076,9 @@ PyObject *THPVariable_get_name(THPVariable* self, void *unused) return THPUtils_packString(tensor.name().c_str()); } -PyObject *THPVariable_get_backwards_hooks(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_backwards_hooks(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_backward_hooks"); } if (self->backward_hooks) { @@ -1025,10 +1089,12 @@ PyObject *THPVariable_get_backwards_hooks(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unused) -{ +int THPVariable_set_backwards_hooks( + THPVariable* self, + PyObject* obj, + void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "_backward_hooks", obj); } THPUtils_assertRet(-1, obj, "Deletion of _backwards_hooks not allowed!"); @@ -1041,16 +1107,16 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unus const auto& tensor = THPVariable_Unpack(self); torch::autograd::impl::clear_hooks(tensor); if (obj) { - torch::autograd::impl::add_hook(tensor, std::make_shared(obj, 0)); + torch::autograd::impl::add_hook( + tensor, std::make_shared(obj, 0)); } return 0; END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_base(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_base(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_base"); } const auto& tensor = THPVariable_Unpack(self); @@ -1085,23 +1151,22 @@ struct ConcretePythonGILHooks : public c10::impl::PythonGILHooks { // dead would be to leak it at destruction time. I didn't do that because // it's annoying to write the Registerer class for this case. ConcretePythonGILHooks python_gil_hooks; -static c10::impl::PythonGILHooksRegisterer python_gil_hooks_registerer(&python_gil_hooks); +static c10::impl::PythonGILHooksRegisterer python_gil_hooks_registerer( + &python_gil_hooks); #endif -PyObject *THPVariable_get_shape(THPVariable *self, void *unused) -{ +PyObject* THPVariable_get_shape(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "shape"); } return THPSize_New(THPVariable_Unpack(self)); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_cpu(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_cpu"); } auto& self_ = THPVariable_Unpack(self); @@ -1109,10 +1174,9 @@ PyObject *THPVariable_is_cpu(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_cuda(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_cuda"); } auto& self_ = THPVariable_Unpack(self); @@ -1140,10 +1204,9 @@ PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) { END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_sparse(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_sparse"); } auto& self_ = THPVariable_Unpack(self); @@ -1151,10 +1214,9 @@ PyObject *THPVariable_is_sparse(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_sparse_csr(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_sparse_csr"); } auto& self_ = THPVariable_Unpack(self); @@ -1162,10 +1224,9 @@ PyObject *THPVariable_is_sparse_csr(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_mkldnn"); } auto& self_ = THPVariable_Unpack(self); @@ -1173,10 +1234,9 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_mps(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_mps(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_mps"); } auto& self_ = THPVariable_Unpack(self); @@ -1184,10 +1244,9 @@ PyObject *THPVariable_is_mps(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_ort(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_ort(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_ort"); } auto& self_ = THPVariable_Unpack(self); @@ -1195,10 +1254,9 @@ PyObject *THPVariable_is_ort(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_vulkan"); } auto& self_ = THPVariable_Unpack(self); @@ -1206,10 +1264,9 @@ PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_quantized"); } auto& self_ = THPVariable_Unpack(self); @@ -1217,10 +1274,9 @@ PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_meta(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_meta(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_meta"); } auto& self_ = THPVariable_Unpack(self); @@ -1228,10 +1284,9 @@ PyObject *THPVariable_is_meta(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_complex(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_complex(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_complex"); } auto& self_ = THPVariable_Unpack(self); @@ -1239,10 +1294,9 @@ PyObject *THPVariable_is_complex(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_nested(THPVariable *self, void *unused) -{ +PyObject* THPVariable_is_nested(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_nested"); } auto& self_ = THPVariable_Unpack(self); @@ -1250,10 +1304,9 @@ PyObject *THPVariable_is_nested(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -static PyObject *THPVariable_dtype(THPVariable *self, void *unused) -{ +static PyObject* THPVariable_dtype(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "dtype"); } auto& self_ = THPVariable_Unpack(self); @@ -1261,9 +1314,9 @@ static PyObject *THPVariable_dtype(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_layout(THPVariable* self, void *unused) { +static PyObject* THPVariable_layout(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "layout"); } auto& self_ = THPVariable_Unpack(self); @@ -1271,19 +1324,18 @@ static PyObject * THPVariable_layout(THPVariable* self, void *unused) { END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_device(THPVariable* self, void *unused) { +static PyObject* THPVariable_device(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "device"); } return THPDevice_New(THPVariable_Unpack(self).device()); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_real(THPVariable* self, void *unused) -{ +PyObject* THPVariable_get_real(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "real"); } auto& self_ = THPVariable_Unpack(self); @@ -1292,10 +1344,9 @@ PyObject *THPVariable_get_real(THPVariable* self, void *unused) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_imag(THPVariable* self, void *unused) -{ +PyObject* THPVariable_get_imag(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject *)self)) { + if (check_has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "imag"); } auto& self_ = THPVariable_Unpack(self); @@ -1304,8 +1355,7 @@ PyObject *THPVariable_get_imag(THPVariable* self, void *unused) END_HANDLE_TH_ERRORS } -int THPVariable_set_real(PyObject* self, PyObject* real, void *unused) -{ +int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) { HANDLE_TH_ERRORS auto& self_ = THPVariable_Unpack(self); auto self_real = at::real(self_); @@ -1318,8 +1368,7 @@ int THPVariable_set_real(PyObject* self, PyObject* real, void *unused) END_HANDLE_TH_ERRORS_RET(-1) } -int THPVariable_set_imag(PyObject* self, PyObject* imag, void *unused) -{ +int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) { HANDLE_TH_ERRORS auto& self_ = THPVariable_Unpack(self); auto self_imag = at::imag(self_); @@ -1332,73 +1381,132 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void *unused) END_HANDLE_TH_ERRORS_RET(-1) } -// properties are registered here because we are currently only able to bind them -// manually. TODO: make declarable in native_functions +// properties are registered here because we are currently only able to bind +// them manually. TODO: make declarable in native_functions // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPVariable_properties[] = { - {"_python_dispatch", (getter)THPVariable_get_python_dispatch, nullptr, nullptr, nullptr}, - {"T", (getter)THPVariable_get_T, nullptr, nullptr, nullptr}, - {"H", (getter)THPVariable_get_H, nullptr, nullptr, nullptr}, - {"mT", (getter)THPVariable_get_mT, nullptr, nullptr, nullptr}, - {"mH", (getter)THPVariable_get_mH, nullptr, nullptr, nullptr}, - {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr}, - {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr}, - {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr}, - {"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn, nullptr, nullptr}, - {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr}, - {"retains_grad", (getter)THPVariable_retains_grad, nullptr, nullptr, nullptr}, - {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, nullptr, nullptr}, - {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, // Allows the python class to override .grad - {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, nullptr, nullptr}, - {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr}, - {"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile, nullptr, nullptr}, - {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr}, - {"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad, nullptr, nullptr}, - {"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks, nullptr, nullptr}, - {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr}, - {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr}, - {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr}, - {"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr}, - {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr}, - {"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr}, - {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr}, - {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr}, - {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr}, - {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr}, - {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr}, - {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr}, - {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr}, - {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr}, - {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr}, - {"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr}, - {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr}, - {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr}, - {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr}, - {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr}, - {"names", (getter)THPVariable_get_names, (setter)THPVariable_set_names, nullptr, nullptr}, - {"real", (getter)THPVariable_get_real, (setter)THPVariable_set_real, nullptr, nullptr}, - {"imag", (getter)THPVariable_get_imag, (setter)THPVariable_set_imag, nullptr, nullptr}, - {nullptr} -}; + {"_python_dispatch", + (getter)THPVariable_get_python_dispatch, + nullptr, + nullptr, + nullptr}, + {"T", (getter)THPVariable_get_T, nullptr, nullptr, nullptr}, + {"H", (getter)THPVariable_get_H, nullptr, nullptr, nullptr}, + {"mT", (getter)THPVariable_get_mT, nullptr, nullptr, nullptr}, + {"mH", (getter)THPVariable_get_mH, nullptr, nullptr, nullptr}, + {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr}, + {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr}, + {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr}, + {"_grad_fn", + (getter)THPVariable_get_grad_fn, + (setter)THPVariable_set_grad_fn, + nullptr, + nullptr}, + {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr}, + {"retains_grad", + (getter)THPVariable_retains_grad, + nullptr, + nullptr, + nullptr}, + {"data", + (getter)THPVariable_get_data, + (setter)THPVariable_set_data, + nullptr, + nullptr}, + {"_grad", + (getter)THPVariable_get_grad, + (setter)THPVariable_set_grad, + nullptr, + nullptr}, // Allows the python class to override .grad + {"grad", + (getter)THPVariable_get_grad, + (setter)THPVariable_set_grad, + nullptr, + nullptr}, + {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr}, + {"volatile", + (getter)THPVariable_get_volatile, + (setter)THPVariable_set_volatile, + nullptr, + nullptr}, + {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr}, + {"requires_grad", + (getter)THPVariable_get_requires_grad, + (setter)THPVariable_set_requires_grad, + nullptr, + nullptr}, + {"_backward_hooks", + (getter)THPVariable_get_backwards_hooks, + (setter)THPVariable_set_backwards_hooks, + nullptr, + nullptr}, + {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr}, + {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr}, + {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr}, + {"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr}, + {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr}, + {"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr}, + {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr}, + {"is_sparse_csr", + (getter)THPVariable_is_sparse_csr, + nullptr, + nullptr, + nullptr}, + {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr}, + {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr}, + {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr}, + {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr}, + {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr}, + {"is_quantized", + (getter)THPVariable_is_quantized, + nullptr, + nullptr, + nullptr}, + {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr}, + {"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr}, + {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr}, + {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr}, + {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr}, + {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr}, + {"names", + (getter)THPVariable_get_names, + (setter)THPVariable_set_names, + nullptr, + nullptr}, + {"real", + (getter)THPVariable_get_real, + (setter)THPVariable_set_real, + nullptr, + nullptr}, + {"imag", + (getter)THPVariable_get_imag, + (setter)THPVariable_set_imag, + nullptr, + nullptr}, + {nullptr}}; static PyMappingMethods THPVariable_as_mapping = { - THPVariable_length, - THPVariable_getitem, - THPVariable_setitem, + THPVariable_length, + THPVariable_getitem, + THPVariable_setitem, }; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static PyMethodDef extra_methods[] = { - {"as_subclass", castPyCFunctionWithKeywords(THPVariable_as_subclass), - METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_make_subclass", castPyCFunctionWithKeywords(THPVariable_make_subclass), - METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_make_wrapper_subclass", castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass), - METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_fix_weakref", THPVariable_fix_weakref, - METH_NOARGS, nullptr}, - {nullptr} -}; + {"as_subclass", + castPyCFunctionWithKeywords(THPVariable_as_subclass), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_make_subclass", + castPyCFunctionWithKeywords(THPVariable_make_subclass), + METH_STATIC | METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_make_wrapper_subclass", + castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass), + METH_STATIC | METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr}, + {nullptr}}; /* From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c If compiled as a shared library instead, some compilers don't allow addresses @@ -1413,47 +1521,48 @@ struct THPVariableMeta { PyHeapTypeObject base; }; -int THPVariableMetaType_init(PyObject *cls, PyObject *args, PyObject *kwargs); +int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs); PyTypeObject THPVariableMetaType = { - PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) - "torch._C._TensorMeta", /* tp_name */ - sizeof(THPVariableMeta), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - DEFERRED_ADDRESS(&PyType_Type), /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - THPVariableMetaType_init, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ + PyVarObject_HEAD_INIT( + DEFERRED_ADDRESS(&PyType_Type), + 0) "torch._C._TensorMeta", /* tp_name */ + sizeof(THPVariableMeta), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + DEFERRED_ADDRESS(&PyType_Type), /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + THPVariableMetaType_init, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ }; PyTypeObject THPVariableType = { @@ -1505,10 +1614,14 @@ PyTypeObject THPVariableType = { THPVariable_pynew, /* tp_new */ }; -PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ +PyObject* THPVariable_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - TORCH_CHECK(type != &THPVariableType, "Cannot directly construct _TensorBase; subclass it and then construct that"); + TORCH_CHECK( + type != &THPVariableType, + "Cannot directly construct _TensorBase; subclass it and then construct that"); jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); auto tensor = torch::utils::base_tensor_ctor(args, kwargs); // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was @@ -1653,17 +1766,23 @@ static PyObject* THPVariable_NewWithVar( // This function overwrite the Tensor's pyobj field without extra checks // Make sure it is not set otherwise we would leak memory auto mb_obj = _var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()); - TORCH_CHECK(!mb_obj.has_value() || !mb_obj.value(), "Creating a new Tensor subclass ", - type->tp_name, " but the raw Tensor object is already associated to a python object ", - "of type ", mb_obj.value()->ob_type->tp_name); + TORCH_CHECK( + !mb_obj.has_value() || !mb_obj.value(), + "Creating a new Tensor subclass ", + type->tp_name, + " but the raw Tensor object is already associated to a python object ", + "of type ", + mb_obj.value()->ob_type->tp_name); // Make sure that the reinterpret into a THPVariable* will be valid - TORCH_CHECK(PyType_IsSubtype(type, &THPVariableType), "Creating a Tensor subclass from a class ", - "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor."); + TORCH_CHECK( + PyType_IsSubtype(type, &THPVariableType), + "Creating a Tensor subclass from a class ", + "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor."); PyObject* obj = type->tp_alloc(type, 0); if (obj) { - auto v = (THPVariable*) obj; + auto v = (THPVariable*)obj; // TODO: named constructor to avoid default initialization new (&v->cdata) MaybeOwned(); v->cdata = MaybeOwned::owned(std::move(_var)); @@ -1678,43 +1797,48 @@ static PyObject* THPVariable_NewWithVar( /// NOTE [ PyObject Traversal ] /// -/// PyObjects that are wrapping c++ objects can lead to non-trivial traverse logic -/// and it can be tricky to know what to traverse and when. This note tries to -/// clarify what is the danger here and a simple algorithm to choose how to write -/// the tp_traverse and tp_clear functions. -/// If you're not already familiar with how the CPython GC works, you should read this -/// in-depth description: https://devguide.python.org/garbage_collector/ +/// PyObjects that are wrapping c++ objects can lead to non-trivial traverse +/// logic and it can be tricky to know what to traverse and when. This note +/// tries to clarify what is the danger here and a simple algorithm to choose +/// how to write the tp_traverse and tp_clear functions. If you're not already +/// familiar with how the CPython GC works, you should read this in-depth +/// description: https://devguide.python.org/garbage_collector/ /// /// The complexity for us comes from the fact that some c++ shared_ptr objects -/// own references to python objects and are also owned both by other python objects -/// and c++ objects. This means that to allow the GC to collect all cycles, we need to -/// properly implement the traverse/clear methods that take into account these C++ -/// ownership links. +/// own references to python objects and are also owned both by other python +/// objects and c++ objects. This means that to allow the GC to collect all +/// cycles, we need to properly implement the traverse/clear methods that take +/// into account these C++ ownership links. /// -/// The main danger here comes from the fact that, while all python-related code is -/// thread safe wrt the GC execution (thanks to the GIL), other threads might be using -/// our C++ objects arbitrarily which can lead to shared_ptr ref count going up or down -/// in between the different traverse/clear invocations. -/// The one constraint we add here that is not explicitly mentioned in the GC description -/// above is that for a given GC run (meaning while the GIL is held), the traverse/clear -/// pair should never report different ownership relations: if traverse visited a given -/// PyObject, then the clear within that same GC run must still be the sole owner and -/// clear that PyObject. +/// The main danger here comes from the fact that, while all python-related code +/// is thread safe wrt the GC execution (thanks to the GIL), other threads might +/// be using our C++ objects arbitrarily which can lead to shared_ptr ref count +/// going up or down in between the different traverse/clear invocations. The +/// one constraint we add here that is not explicitly mentioned in the GC +/// description above is that for a given GC run (meaning while the GIL is +/// held), the traverse/clear pair should never report different ownership +/// relations: if traverse visited a given PyObject, then the clear within that +/// same GC run must still be the sole owner and clear that PyObject. /// /// A more mechanical algorithm to know what to traverse/clear is as follows: -/// - Any field on this PyObject that contains a strong reference to another PyObject -/// must be visited and cleared. An example of that is the "backward_hooks" field of -/// the THPVariable. -/// - Any field that contains a C++ object that is uniquely owned by this PyObject (either -/// a unique_ptr or a shared_ptr with use_count==1) should have all the PyObject it owns -/// visited and cleared. An example would be here the tensor hooks. -/// - If that uniquely owned C++ object also uniquely owns other C++ objects, these should be +/// - Any field on this PyObject that contains a strong reference to another +/// PyObject +/// must be visited and cleared. An example of that is the "backward_hooks" +/// field of the THPVariable. +/// - Any field that contains a C++ object that is uniquely owned by this +/// PyObject (either +/// a unique_ptr or a shared_ptr with use_count==1) should have all the +/// PyObject it owns visited and cleared. An example would be here the +/// tensor hooks. +/// - If that uniquely owned C++ object also uniquely owns other C++ objects, +/// these should be /// visited and cleared as well if they contain any PyObject. /// -/// Caveat: to avoid slow runtime, we limit the depth of this exploration of C++ objects in -/// practice and we do not, for example, go through the whole autograd graph, even if it is -/// uniquely owned. This is a known place where users can create noncollectable cycles as described -/// in: https://github.com/pytorch/pytorch/issues/7343 +/// Caveat: to avoid slow runtime, we limit the depth of this exploration of C++ +/// objects in practice and we do not, for example, go through the whole +/// autograd graph, even if it is uniquely owned. This is a known place where +/// users can create noncollectable cycles as described in: +/// https://github.com/pytorch/pytorch/issues/7343 /// static int traverse_slots( @@ -1802,17 +1926,17 @@ static int THPVariable_subclass_traverse( if (!var->cdata.unsafeIsBorrowed()) { const auto& tensor = THPVariable_Unpack(var); if (tensor.defined()) { - // WARNING: The grad_fn traversal logic is very subtle, if you change this, - // be very careful not to re-introduce this bug: + // WARNING: The grad_fn traversal logic is very subtle, if you change + // this, be very careful not to re-introduce this bug: // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c - // We ensure that we follow NOTE [ PyObject Traversal ] he by checking that this - // python object is the sole owner of the underlying Tensor and that this Tensor - // is the sole owner of its grad_fn. - // In this case, the only way to get a new reference to the grad_fn is by using - // this python object, which requires the GIL to be accessed. - // Note that this is only valid as long as user don't share non-owning references - // across different threads (which is crazy and should never be done). + // We ensure that we follow NOTE [ PyObject Traversal ] he by checking + // that this python object is the sole owner of the underlying Tensor and + // that this Tensor is the sole owner of its grad_fn. In this case, the + // only way to get a new reference to the grad_fn is by using this python + // object, which requires the GIL to be accessed. Note that this is only + // valid as long as user don't share non-owning references across + // different threads (which is crazy and should never be done). if (tensor.use_count() == 1) { auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); @@ -1841,7 +1965,7 @@ static int THPVariable_subclass_traverse( return 0; } -int THPVariableMetaType_init(PyObject *cls, PyObject *args, PyObject *kwargs) { +int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) { return -1; } @@ -1851,11 +1975,12 @@ int THPVariableMetaType_init(PyObject *cls, PyObject *args, PyObject *kwargs) { return 0; } -namespace torch { namespace autograd { +namespace torch { +namespace autograd { // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) extern PyMethodDef variable_methods[]; -extern void initTorchFunctions(PyObject *module); +extern void initTorchFunctions(PyObject* module); void initTensorImplConversion(PyObject* module) { auto m = py::handle(module).cast(); @@ -1874,15 +1999,15 @@ void initTensorImplConversion(PyObject* module) { return t->getIntrusivePtr().get(); }); } -}} +} // namespace autograd +} // namespace torch -bool THPVariable_initModule(PyObject *module) -{ +bool THPVariable_initModule(PyObject* module) { THPVariableMetaType.tp_base = &PyType_Type; if (PyType_Ready(&THPVariableMetaType) < 0) return false; Py_INCREF(&THPVariableMetaType); - PyModule_AddObject(module, "_TensorMeta", (PyObject *)&THPVariableMetaType); + PyModule_AddObject(module, "_TensorMeta", (PyObject*)&THPVariableMetaType); static std::vector methods; THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods); @@ -1891,7 +2016,7 @@ bool THPVariable_initModule(PyObject *module) if (PyType_Ready(&THPVariableType) < 0) return false; Py_INCREF(&THPVariableType); - PyModule_AddObject(module, "_TensorBase", (PyObject *)&THPVariableType); + PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType); torch::autograd::initTorchFunctions(module); torch::autograd::initTensorImplConversion(module); return true; @@ -1903,14 +2028,21 @@ bool isPythonTensor(const Tensor& tensor) { return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); } - -py::object torchDispatchFromTensorImpl(const c10::TensorImpl* self, const char* func_name, PyObject* torch_api_function, const char* module_name) { - TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs"); +py::object torchDispatchFromTensorImpl( + const c10::TensorImpl* self, + const char* func_name, + PyObject* torch_api_function, + const char* module_name) { + TORCH_CHECK( + PyGILState_Check(), + "GIL must be held before you call parseIValuesToPyArgsKwargs"); std::vector overloaded_args; // TODO: there should be a shorter way to spell this // TODO: fix the constness of target - Tensor self_t = Tensor(c10::intrusive_ptr::unsafe_reclaim_from_nonowning(const_cast(self))); + Tensor self_t = Tensor( + c10::intrusive_ptr:: + unsafe_reclaim_from_nonowning(const_cast(self))); auto self_p = py::reinterpret_steal(THPVariable_Wrap(self_t)); TORCH_INTERNAL_ASSERT(isPythonTensor(self_t)); append_overloaded_tensor(&overloaded_args, self_p.ptr()); @@ -1932,8 +2064,8 @@ py::object torchDispatchFromTensorImpl(const c10::TensorImpl* self, const char* // NOTE [dispatch_fn's type argument] // `type` is nullable and represents the TorchDispatchMode going on. -// Right now we only support a single TorchDispatchMode, but in the future we could -// change this to a stack of TorchDispatchModes. +// Right now we only support a single TorchDispatchMode, but in the future we +// could change this to a stack of TorchDispatchModes. // // If `type` isn't null, then we consider the type for dispatch by prepending // it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser` @@ -1967,12 +2099,14 @@ void concrete_dispatch_fn( py::gil_scoped_acquire g; std::vector overloaded_args; - py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name); + py::handle torch_api_function = + py::module::import("torch").attr("ops").attr(ns).attr(func_name); py::handle torch_api_function_overload; if (overload_name == "") { torch_api_function_overload = torch_api_function.attr("default"); } else { - torch_api_function_overload = torch_api_function.attr(overload_name.c_str()); + torch_api_function_overload = + torch_api_function.attr(overload_name.c_str()); } std::string module_name_str = "torch.ops." + ns_str; @@ -2007,17 +2141,20 @@ void concrete_dispatch_fn( auto kwargs = std::move(args_kwargs.second); PyObject* obj = handle_torch_function_no_python_arg_parser( - overloaded_args, - args.ptr(), - kwargs.ptr(), - func_name, - torch_api_function_overload.ptr(), - module_name_str.c_str(), - TorchFunctionName::TorchDispatch); - pushPyOutToStack(op, stack, py::reinterpret_steal(obj), "__torch_dispatch__"); -} - -c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { + overloaded_args, + args.ptr(), + kwargs.ptr(), + func_name, + torch_api_function_overload.ptr(), + module_name_str.c_str(), + TorchFunctionName::TorchDispatch); + pushPyOutToStack( + op, stack, py::reinterpret_steal(obj), "__torch_dispatch__"); +} + +c10::intrusive_ptr concrete_detach_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { pybind11::gil_scoped_acquire gil; at::impl::MaybeSetTLSOnEntryGuard guard; @@ -2032,12 +2169,18 @@ c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter .ptr(), "torch.ops.aten"); - TORCH_CHECK(THPVariable_Check(out.ptr()), "detach returned invalid type ", py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), ", expected Tensor"); + TORCH_CHECK( + THPVariable_Check(out.ptr()), + "detach returned invalid type ", + py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), + ", expected Tensor"); const Tensor& res_t = THPVariable_Unpack(out.ptr()); return res_t.getIntrusivePtr(); } -bool concrete_is_contiguous_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { +bool concrete_is_contiguous_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { pybind11::gil_scoped_acquire gil; at::impl::MaybeSetTLSOnEntryGuard guard; @@ -2052,7 +2195,11 @@ bool concrete_is_contiguous_fn(const c10::impl::PyInterpreter*, const c10::Tenso .ptr(), "torch.ops.aten"); - TORCH_CHECK(PyBool_Check(out.ptr()), "is_contiguous returned invalid type ", py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), ", expected bool"); + TORCH_CHECK( + PyBool_Check(out.ptr()), + "is_contiguous returned invalid type ", + py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), + ", expected bool"); return PyObject_IsTrue(out.ptr()); } @@ -2083,7 +2230,9 @@ int64_t concrete_dim_fn( return THPUtils_unpackLong(out.ptr()); } -c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { +c10::Device concrete_device_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { pybind11::gil_scoped_acquire gil; at::impl::MaybeSetTLSOnEntryGuard guard; @@ -2101,18 +2250,16 @@ c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::Tenso return toDevice(out.ptr()); } -c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { +c10::IntArrayRef concrete_strides_fn( + const c10::impl::PyInterpreter*, + const c10::TensorImpl* self) { pybind11::gil_scoped_acquire gil; at::impl::MaybeSetTLSOnEntryGuard guard; auto out = torchDispatchFromTensorImpl( self, "stride", - py::module::import("torch") - .attr("ops") - .attr("aten") - .attr("stride") - .ptr(), + py::module::import("torch").attr("ops").attr("aten").attr("stride").ptr(), "torch.ops.aten"); if (out == Py_None) { @@ -2123,16 +2270,18 @@ c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10: c10::TensorImpl* ptr = const_cast(self); c10::optional mb_obj = ptr->check_pyobj(getPyInterpreter()); - TORCH_CHECK(mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); + TORCH_CHECK( + mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); PyObject* subclass = *mb_obj; Py_INCREF(subclass); py::object sub = py::reinterpret_steal(subclass); py::object os = py::module_::import("torch").attr("overrides"); - py::function get_buffer = py::reinterpret_borrow(os.attr("get_buffer")); + py::function get_buffer = + py::reinterpret_borrow(os.attr("get_buffer")); auto buffer = get_buffer(sub, values); auto result = THPUtils_unpackLongs(buffer.ptr()); - int64_t* start = (int64_t*) result[0]; + int64_t* start = (int64_t*)result[0]; int64_t len = result[1]; return c10::IntArrayRef(start, len); diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index c4856cdd4d12fe..92d00be2da33d2 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -1,14 +1,14 @@ #pragma once +#include #include #include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include namespace py = pybind11; @@ -23,13 +23,15 @@ struct THPVariable { PyObject* backward_hooks = nullptr; }; -TORCH_API void registerPythonTensorClass(const std::string& device, PyObject* python_tensor_class); +TORCH_API void registerPythonTensorClass( + const std::string& device, + PyObject* python_tensor_class); -TORCH_PYTHON_API extern PyObject *THPVariableClass; -TORCH_PYTHON_API extern PyObject *ParameterClass; +TORCH_PYTHON_API extern PyObject* THPVariableClass; +TORCH_PYTHON_API extern PyObject* ParameterClass; -bool THPVariable_initModule(PyObject *module); -TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var); +bool THPVariable_initModule(PyObject* module); +TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. @@ -37,23 +39,21 @@ static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Parameter, which is used for Python bookkeeping but is equivalent to // Tensor as far as C++ is concerned. return ( - tp == (PyTypeObject*)THPVariableClass || - tp == (PyTypeObject*)ParameterClass - ); + tp == (PyTypeObject*)THPVariableClass || + tp == (PyTypeObject*)ParameterClass); } -static inline bool THPVariable_CheckExact(PyObject *obj) { +static inline bool THPVariable_CheckExact(PyObject* obj) { return THPVariable_CheckTypeExact(Py_TYPE(obj)); } -inline bool THPVariable_Check(PyObject *obj) -{ +inline bool THPVariable_Check(PyObject* obj) { if (!THPVariableClass) - return false; + return false; const auto result = PyObject_IsInstance(obj, THPVariableClass); if (result == -1) - throw python_error(); + throw python_error(); return result; } @@ -67,6 +67,12 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); -std::pair parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector& arguments); +std::pair parseIValuesToPyArgsKwargs( + const c10::OperatorHandle& op, + const std::vector& arguments); -void pushPyOutToStack(const c10::OperatorHandle& op, torch::jit::Stack* stack, py::object out, const char* msg); +void pushPyOutToStack( + const c10::OperatorHandle& op, + torch::jit::Stack* stack, + py::object out, + const char* msg); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index a87f98710e0384..4dc62f007a55ba 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -6,34 +6,36 @@ #include #include #include +#include +#include +#include #include #include #include -#include -#include -#include #include #include #include #include #include +#include #include #include -#include -#include #include +#include using namespace at; using namespace torch::autograd::utils; -namespace torch { namespace autograd { +namespace torch { +namespace autograd { Py_ssize_t THPVariable_length(PyObject* self) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - py::object ret = py::reinterpret_steal(handle_torch_function(self, "__len__")); + py::object ret = py::reinterpret_steal( + handle_torch_function(self, "__len__")); Py_ssize_t length = PyLong_AsSsize_t(ret.ptr()); if (PyErr_Occurred()) { throw python_error(); @@ -48,7 +50,6 @@ Py_ssize_t THPVariable_length(PyObject* self) { END_HANDLE_TH_ERRORS_RET(-1) } - // We allow indexing by integers, slices, ellipsis, None, Variables, // and tuples of those types. We also handle bools as if they were a // Variable[ByteTensor]. @@ -57,10 +58,13 @@ static inline int64_t count_specified_dimensions(PyObject* index) { // Count the number of indexed dimensions (everything but ellipsis and None) // -1 is a sentinel for __torch_function__ int64_t count = 0; - auto size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + auto size = + PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) for (Py_ssize_t i = 0; i < size; i++) { - PyObject* obj = PyTuple_GET_ITEM(index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) - if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj)) return -1; + PyObject* obj = PyTuple_GET_ITEM( + index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj)) + return -1; if (THPVariable_Check(obj)) { const auto& var = THPVariable_Unpack(obj); const auto& var_scalar_type = var.scalar_type(); @@ -69,46 +73,61 @@ static inline int64_t count_specified_dimensions(PyObject* index) { } else { count++; } - } else if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + } else if ( + obj != Py_None && obj != Py_Ellipsis && obj != Py_True && + obj != Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) count++; } } return count; } -[[noreturn]] -static inline void invalid_index(PyObject* obj) { +[[noreturn]] static inline void invalid_index(PyObject* obj) { throw IndexError( - "only integers, slices (`:`), ellipsis (`...`), None and long or byte " - "Variables are valid indices (got %s)", Py_TYPE(obj)->tp_name); + "only integers, slices (`:`), ellipsis (`...`), None and long or byte " + "Variables are valid indices (got %s)", + Py_TYPE(obj)->tp_name); } -static inline Variable sequenceToVariable(c10::TensorOptions options, PyObject* seq) { - return torch::utils::indexing_tensor_from_data(options, kLong, c10::nullopt, seq); +static inline Variable sequenceToVariable( + c10::TensorOptions options, + PyObject* seq) { + return torch::utils::indexing_tensor_from_data( + options, kLong, c10::nullopt, seq); } -inline Variable valueToTensor(c10::TensorOptions options, PyObject* value, const at::Device& device) { +inline Variable valueToTensor( + c10::TensorOptions options, + PyObject* value, + const at::Device& device) { if (THPVariable_Check(value)) { return THPVariable_Unpack(value); } - at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove + at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; if (THPUtils_checkLong(value) || PyBool_Check(value)) { - return at::indexing::scalarToTensor(Scalar(THPUtils_unpackLong(value)), options, device); + return at::indexing::scalarToTensor( + Scalar(THPUtils_unpackLong(value)), options, device); } if (PyFloat_Check(value)) { - return at::indexing::scalarToTensor(Scalar(THPUtils_unpackDouble(value)), options, device); + return at::indexing::scalarToTensor( + Scalar(THPUtils_unpackDouble(value)), options, device); } if (PyComplex_Check(value)) { - return at::indexing::scalarToTensor(Scalar(THPUtils_unpackComplexDouble(value)), options, device); + return at::indexing::scalarToTensor( + Scalar(THPUtils_unpackComplexDouble(value)), options, device); } throw TypeError( - "can't assign a %s to a %s", - Py_TYPE(value)->tp_name, - torch::utils::options_to_string(options).c_str()); + "can't assign a %s to a %s", + Py_TYPE(value)->tp_name, + torch::utils::options_to_string(options).c_str()); } -static inline void checkUnpackSlice(PyObject* index, Py_ssize_t* start_ptr, Py_ssize_t* stop_ptr, Py_ssize_t* step_ptr) { +static inline void checkUnpackSlice( + PyObject* index, + Py_ssize_t* start_ptr, + Py_ssize_t* stop_ptr, + Py_ssize_t* step_ptr) { if (!THPUtils_unpackSlice(index, start_ptr, stop_ptr, step_ptr)) { throw python_error(); } @@ -117,18 +136,31 @@ static inline void checkUnpackSlice(PyObject* index, Py_ssize_t* start_ptr, Py_s static inline void recordSliceTrace(PyObject* obj) { PySliceObject* sliceobj = (PySliceObject*)obj; if (THPVariable_Check(sliceobj->start)) { - torch::jit::tracer::ArgumentStash::stashValue(std::string("start"), 1, THPVariable_Unpack(sliceobj->start), torch::jit::IntType::get()); + torch::jit::tracer::ArgumentStash::stashValue( + std::string("start"), + 1, + THPVariable_Unpack(sliceobj->start), + torch::jit::IntType::get()); } if (THPVariable_Check(sliceobj->stop)) { - torch::jit::tracer::ArgumentStash::stashValue(std::string("end"), 1, THPVariable_Unpack(sliceobj->stop), torch::jit::IntType::get()); + torch::jit::tracer::ArgumentStash::stashValue( + std::string("end"), + 1, + THPVariable_Unpack(sliceobj->stop), + torch::jit::IntType::get()); } if (THPVariable_Check(sliceobj->step)) { - torch::jit::tracer::ArgumentStash::stashValue(std::string("step"), 1, THPVariable_Unpack(sliceobj->step), torch::jit::IntType::get()); + torch::jit::tracer::ArgumentStash::stashValue( + std::string("step"), + 1, + THPVariable_Unpack(sliceobj->step), + torch::jit::IntType::get()); } } static inline void recordSelectTrace(const Tensor& index_tensor) { - torch::jit::tracer::ArgumentStash::stashValue(std::string("index"), 1, index_tensor, torch::jit::IntType::get()); + torch::jit::tracer::ArgumentStash::stashValue( + std::string("index"), 1, index_tensor, torch::jit::IntType::get()); } static inline Variable applySlicing( @@ -137,79 +169,90 @@ static inline Variable applySlicing( variable_list& outIndices, bool is_tracing, const at::Device& self_device, - const c10::optional & self_sizes, + const c10::optional& self_sizes, int64_t specified_dims) { - int64_t size = PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + int64_t size = + PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) int64_t dim = 0; // See NOTE [nested tensor size for indexing] if (self_sizes.has_value()) { TORCH_CHECK_INDEX( - specified_dims <= (int64_t)self_sizes->size(), - "too many indices for tensor of dimension ", (int)self_sizes->size()); + specified_dims <= (int64_t)self_sizes->size(), + "too many indices for tensor of dimension ", + (int)self_sizes->size()); } Variable result = self; - for(const auto i : c10::irange(size)) { - PyObject* obj = PyTuple_GET_ITEM(index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + for (const auto i : c10::irange(size)) { + PyObject* obj = PyTuple_GET_ITEM( + index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) // NOTE [nested tensor size for indexing] - // nested tensor does not have a size (yet) so for now we represent its size as null - // may need to be changed after we reach a better solution for nested tensor size - c10::optional result_sizes = result.is_nested() ? c10::optional(c10::nullopt) : c10::optional(result.sizes()); + // nested tensor does not have a size (yet) so for now we represent its size + // as null may need to be changed after we reach a better solution for + // nested tensor size + c10::optional result_sizes = result.is_nested() + ? c10::optional(c10::nullopt) + : c10::optional(result.sizes()); result = at::indexing::handleDimInMultiDimIndexing( - /*prev_dim_result=*/result, - /*original_tensor=*/self, - /*index=*/([&]() { - if (THPUtils_checkLong(obj)) { - if (is_tracing && THPVariable_Check(obj)) { - recordSelectTrace(THPVariable_Unpack(obj)); - } - return at::indexing::TensorIndex(THPUtils_unpackLong(obj)); - } else if (PySlice_Check(obj)) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Py_ssize_t start, stop, step; - checkUnpackSlice(obj, &start, &stop, &step); - if (is_tracing) { - recordSliceTrace(obj); - } - return at::indexing::TensorIndex(at::indexing::Slice(start, stop, step)); - } else if (obj == Py_Ellipsis) { - return at::indexing::TensorIndex(at::indexing::Ellipsis); - } else if (obj == Py_None) { - return at::indexing::TensorIndex(at::indexing::None); - } else if (PyBool_Check(obj)) { - return at::indexing::TensorIndex(obj == Py_True); - } else if (THPVariable_Check(obj)) { - Tensor tensor = THPVariable_Unpack(obj); - if (is_tracing) { - auto scalar_type = tensor.scalar_type(); - if (tensor.dim() == 0 && at::isIntegralType(scalar_type, /*includeBool=*/false) && scalar_type != at::kByte) { - recordSelectTrace(tensor); + /*prev_dim_result=*/result, + /*original_tensor=*/self, + /*index=*/([&]() { + if (THPUtils_checkLong(obj)) { + if (is_tracing && THPVariable_Check(obj)) { + recordSelectTrace(THPVariable_Unpack(obj)); } + return at::indexing::TensorIndex(THPUtils_unpackLong(obj)); + } else if (PySlice_Check(obj)) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + Py_ssize_t start, stop, step; + checkUnpackSlice(obj, &start, &stop, &step); + if (is_tracing) { + recordSliceTrace(obj); + } + return at::indexing::TensorIndex( + at::indexing::Slice(start, stop, step)); + } else if (obj == Py_Ellipsis) { + return at::indexing::TensorIndex(at::indexing::Ellipsis); + } else if (obj == Py_None) { + return at::indexing::TensorIndex(at::indexing::None); + } else if (PyBool_Check(obj)) { + return at::indexing::TensorIndex(obj == Py_True); + } else if (THPVariable_Check(obj)) { + Tensor tensor = THPVariable_Unpack(obj); + if (is_tracing) { + auto scalar_type = tensor.scalar_type(); + if (tensor.dim() == 0 && + at::isIntegralType(scalar_type, /*includeBool=*/false) && + scalar_type != at::kByte) { + recordSelectTrace(tensor); + } + } + return at::indexing::TensorIndex(std::move(tensor)); + } else if (PySequence_Check(obj)) { + return at::indexing::TensorIndex( + sequenceToVariable(self.options(), obj)); + } else { + auto idx = THPObjectPtr(PyNumber_Index(obj)); + if (!idx) { + PyErr_Clear(); + invalid_index(obj); + } + if (is_tracing && THPVariable_Check(idx)) { + recordSelectTrace(THPVariable_Unpack(idx)); + } + return at::indexing::TensorIndex(THPUtils_unpackLong(idx)); } - return at::indexing::TensorIndex(std::move(tensor)); - } else if (PySequence_Check(obj)) { - return at::indexing::TensorIndex(sequenceToVariable(self.options(), obj)); - } else { - auto idx = THPObjectPtr(PyNumber_Index(obj)); - if (!idx) { - PyErr_Clear(); - invalid_index(obj); - } - if (is_tracing && THPVariable_Check(idx)) { - recordSelectTrace(THPVariable_Unpack(idx)); - } - return at::indexing::TensorIndex(THPUtils_unpackLong(idx)); - } - })(), - /*dim_ptr=*/&dim, - /*specified_dims_ptr=*/&specified_dims, - /*real_dim=*/i, - /*outIndices=*/outIndices, - // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing functions from Python ] - /*disable_slice_optimization=*/is_tracing, - /*original_tensor_device=*/self_device, - /*prev_dim_result_sizes=*/result_sizes); + })(), + /*dim_ptr=*/&dim, + /*specified_dims_ptr=*/&specified_dims, + /*real_dim=*/i, + /*outIndices=*/outIndices, + // See NOTE [ Setting `disable_slice_optimization` when calling C++ + // tensor indexing functions from Python ] + /*disable_slice_optimization=*/is_tracing, + /*original_tensor_device=*/self_device, + /*prev_dim_result_sizes=*/result_sizes); } return result; } @@ -248,7 +291,8 @@ static inline bool treatSequenceAsTuple(PyObject* index) { PyErr_Clear(); return false; } - if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) || PySlice_Check(obj.get())) { + if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) || + PySlice_Check(obj.get())) { return true; } if (obj.get() == Py_Ellipsis || obj.get() == Py_None) { @@ -263,9 +307,11 @@ static inline THPObjectPtr wrapTuple(PyObject* index) { if (treatSequenceAsTuple(index)) { res = PySequence_Tuple(index); } else { - res = PyTuple_Pack(1, index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + res = PyTuple_Pack( + 1, index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) } - if (!res) throw python_error(); + if (!res) + throw python_error(); return res; } @@ -287,11 +333,11 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { // handle simple types: none, ellipsis if (index == Py_None) { - return THPVariable_Wrap( - at::indexing::get_item(self_, {at::indexing::TensorIndex(at::indexing::None)})); + return THPVariable_Wrap(at::indexing::get_item( + self_, {at::indexing::TensorIndex(at::indexing::None)})); } else if (index == Py_Ellipsis) { - return THPVariable_Wrap( - at::indexing::get_item(self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)})); + return THPVariable_Wrap(at::indexing::get_item( + self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)})); } bool is_tracing = torch::jit::tracer::isTracing(); @@ -301,8 +347,8 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { if (is_tracing && THPVariable_Check(index)) { recordSelectTrace(THPVariable_Unpack(index)); } - return THPVariable_Wrap( - at::indexing::get_item(self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))})); + return THPVariable_Wrap(at::indexing::get_item( + self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))})); } else if (PySlice_Check(index)) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Py_ssize_t start, stop, step; @@ -310,12 +356,14 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { if (is_tracing) { recordSliceTrace(index); } - return THPVariable_Wrap( - at::indexing::get_item(self_, {at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))})); + return THPVariable_Wrap(at::indexing::get_item( + self_, + {at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))})); } else if (index == Py_False || index == Py_True) { return THPVariable_Wrap(([&]() { pybind11::gil_scoped_release no_gil; - return at::indexing::get_item(self_, {at::indexing::TensorIndex(index == Py_True)}); + return at::indexing::get_item( + self_, {at::indexing::TensorIndex(index == Py_True)}); })()); } @@ -329,9 +377,16 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { } // See NOTE [nested tensor size for indexing] c10::optional self_sizes = c10::nullopt; - if (! self_.is_nested()) self_sizes = self_.sizes(); + if (!self_.is_nested()) + self_sizes = self_.sizes(); Variable sliced = applySlicing( - self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_.device(), self_sizes, specified_dims); + self_, + holder.get(), + variableIndices, + /*is_tracing=*/is_tracing, + self_.device(), + self_sizes, + specified_dims); if (variableIndices.empty()) { if (sliced.is_same(self_)) { // ensure we return a shallow copy for things like x[...] @@ -350,8 +405,11 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { END_HANDLE_TH_ERRORS } -void dispatch_set_item(const Tensor& self, ArrayRef indices, - const Tensor& value, bool disable_slice_optimization=false) { +void dispatch_set_item( + const Tensor& self, + ArrayRef indices, + const Tensor& value, + bool disable_slice_optimization = false) { pybind11::gil_scoped_release no_gil; at::indexing::set_item(self, indices, value, disable_slice_optimization); } @@ -370,16 +428,15 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { throw TypeError("Tensor does not support deleting items"); } if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) || - (!THPVariable_CheckExact(py_value) && check_has_torch_function(py_value))) { + (!THPVariable_CheckExact(py_value) && + check_has_torch_function(py_value))) { py::object ret = py::reinterpret_steal( - handle_torch_function_indexing(self, index, py_value) - ); + handle_torch_function_indexing(self, index, py_value)); return 0; } const auto& self_ = THPVariable_Unpack(self); - if (self_.is_sparse() || self_.is_sparse_csr()) - { + if (self_.is_sparse() || self_.is_sparse_csr()) { throw TypeError("Cannot assign to a sparse tensor"); } OptionalDeviceGuard device_guard(device_of(self_)); @@ -387,7 +444,8 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { Variable value; // TODO: This qint special case looks very suspicious... if (isQIntType(self_.scalar_type())) { - value = valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU)); + value = + valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU)); } else if (self_device.is_cuda()) { value = valueToTensor(self_.options(), py_value, at::Device(kCPU)); } else { @@ -396,14 +454,16 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { // handle simple types: ellipsis, none, bool if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) - // do nothing for false (technically we should check the size, but we don't have - // real 0-sized shapes. + // do nothing for false (technically we should check the size, but we don't + // have real 0-sized shapes. return 0; } else if (index == Py_Ellipsis) { - dispatch_set_item(self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value); + dispatch_set_item( + self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value); return 0; } else if (index == Py_None) { - dispatch_set_item(self_, {at::indexing::TensorIndex(at::indexing::None)}, value); + dispatch_set_item( + self_, {at::indexing::TensorIndex(at::indexing::None)}, value); return 0; } else if (index == Py_True) { dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value); @@ -417,7 +477,8 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { if (is_tracing && THPVariable_Check(index)) { recordSelectTrace(THPVariable_Unpack(index)); } - dispatch_set_item(self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}, value); + dispatch_set_item( + self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}, value); return 0; } else if (PySlice_Check(index)) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -426,9 +487,13 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { if (is_tracing) { recordSliceTrace(index); } - // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing functions from Python ] + // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor + // indexing functions from Python ] dispatch_set_item( - self_, {at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))}, value, /*disable_slice_optimization=*/is_tracing); + self_, + {at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))}, + value, + /*disable_slice_optimization=*/is_tracing); return 0; } @@ -439,12 +504,17 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { int64_t specified_dims = count_specified_dimensions(holder.get()); if (specified_dims == -1) { py::object val = py::reinterpret_steal( - handle_torch_function_indexing(self, index, py_value) - ); + handle_torch_function_indexing(self, index, py_value)); return 0; } Variable sliced = applySlicing( - self_, holder.get(), variableIndices, /*is_tracing=*/is_tracing, self_device, self_.sizes(), specified_dims); + self_, + holder.get(), + variableIndices, + /*is_tracing=*/is_tracing, + self_device, + self_.sizes(), + specified_dims); if (variableIndices.empty()) { pybind11::gil_scoped_release no_gil; at::indexing::copy_to(sliced, value); @@ -461,10 +531,12 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { } else { valuesSliced = value; } - at::indexing::dispatch_index_put_(sliced, std::move(variableIndices), valuesSliced); + at::indexing::dispatch_index_put_( + sliced, std::move(variableIndices), valuesSliced); return 0; } END_HANDLE_TH_ERRORS_RET(-1) } -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h index 027bffb6dc8a04..7dda1ad9e9d8af 100644 --- a/torch/csrc/autograd/python_variable_indexing.h +++ b/torch/csrc/autograd/python_variable_indexing.h @@ -1,14 +1,19 @@ #pragma once -#include #include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { Py_ssize_t THPVariable_length(PyObject* self); PyObject* THPVariable_getitem(PyObject* self, PyObject* index); int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value); -Variable valueToTensor(c10::TensorOptions options, PyObject* value, const at::Device& device); +Variable valueToTensor( + c10::TensorOptions options, + PyObject* value, + const at::Device& device); -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp index f5f09b3fe94080..716981fd6a562c 100644 --- a/torch/csrc/autograd/record_function_ops.cpp +++ b/torch/csrc/autograd/record_function_ops.cpp @@ -1,10 +1,10 @@ -#include +#include #include #include -#include +#include -#include #include +#include namespace caffe2 { // Required for cpp_custom_type_hack to work @@ -21,10 +21,11 @@ namespace profiler { void record_function_enter( const std::string& name, const c10::optional& args, - at::RecordFunction &rec) { + at::RecordFunction& rec) { if (rec.isActive()) { if (rec.needsInputs() && args.has_value()) { - rec.before(name, c10::ArrayRef{c10::IValue{args.value()}}); + rec.before( + name, c10::ArrayRef{c10::IValue{args.value()}}); } else { rec.before(name); } @@ -42,8 +43,10 @@ at::Tensor record_function_enter_legacy( // New signature using custom_class c10::intrusive_ptr record_function_enter_new( - const std::string &name, const c10::optional &args) { - auto rec = c10::make_intrusive(at::RecordScope::USER_SCOPE); + const std::string& name, + const c10::optional& args) { + auto rec = + c10::make_intrusive(at::RecordScope::USER_SCOPE); record_function_enter(name, args, rec->record); return rec; } @@ -54,12 +57,12 @@ at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) { } // Ends the profiling scope created with record_function_enter. -void record_function_exit(at::RecordFunction &rec) { +void record_function_exit(at::RecordFunction& rec) { rec.end(); } // Legacy signature using cpp_custom_type_hack -void record_function_exit_legacy(const at::Tensor &handle) { +void record_function_exit_legacy(const at::Tensor& handle) { // We don't actually need to do anything with handle just need to persist the // lifetime until now. auto& rec = getRecordFunctionFromTensor(handle); @@ -67,7 +70,8 @@ void record_function_exit_legacy(const at::Tensor &handle) { } // New signature using custom_class -void record_function_exit_new(const c10::intrusive_ptr &record) { +void record_function_exit_new( + const c10::intrusive_ptr& record) { record_function_exit(record->record); } @@ -81,27 +85,26 @@ c10::intrusive_ptr _call_end_callbacks_on_fut( [get_record = std::move(get_record)](c10::ivalue::Future& fut) { auto& rec = get_record(); rec.end(); - // Note: this future is returned to the user to ensure that a call to wait() - // ensures that profiling callbacks have ran. To ensure that this is - // transparent, we must make this future propagate the value of the RPC - // future. - // Use value() here instead of constValue() to ensure we propagate errors. + // Note: this future is returned to the user to ensure that a call to + // wait() ensures that profiling callbacks have ran. To ensure that this + // is transparent, we must make this future propagate the value of the + // RPC future. Use value() here instead of constValue() to ensure we + // propagate errors. return fut.value(); }; // Define a future that completes after the profiling callbacks are run. - auto profiledFut = fut->then(at::wrapPropagateTLSState( - std::move(futureProfilingFunc)), - fut->elementType() - ); + auto profiledFut = fut->then( + at::wrapPropagateTLSState(std::move(futureProfilingFunc)), + fut->elementType()); return profiledFut; } // Legacy signature using cpp_custom_type_hack c10::intrusive_ptr _call_end_callbacks_on_fut_legacy( - const at::Tensor &handle, + const at::Tensor& handle, const c10::intrusive_ptr& fut) { return _call_end_callbacks_on_fut( - [handle] () -> at::RecordFunction& { + [handle]() -> at::RecordFunction& { TORCH_INTERNAL_ASSERT( handle.defined(), "Undefined RecordFunction handle. This can happen if the handle is " @@ -110,53 +113,55 @@ c10::intrusive_ptr _call_end_callbacks_on_fut_legacy( return getRecordFunctionFromTensor(handle); }, - fut - ); + fut); } // New signature using custom_class c10::intrusive_ptr _call_end_callbacks_on_fut_new( - const c10::intrusive_ptr &record, + const c10::intrusive_ptr& record, const c10::intrusive_ptr& fut) { return _call_end_callbacks_on_fut( - [record] () -> at::RecordFunction& { return record->record; }, fut); + [record]() -> at::RecordFunction& { return record->record; }, fut); } // Internal only, do not use directly, use Python's record_function() TORCH_LIBRARY_FRAGMENT(profiler, m) { m.class_("_RecordFunction"); - m.def("_record_function_enter(str name, str? args=None) -> Tensor", - &record_function_enter_legacy); - m.def("_record_function_enter_new(str name, str? args=None) -> " - "__torch__.torch.classes.profiler._RecordFunction", - &record_function_enter_new); + m.def( + "_record_function_enter(str name, str? args=None) -> Tensor", + &record_function_enter_legacy); + m.def( + "_record_function_enter_new(str name, str? args=None) -> " + "__torch__.torch.classes.profiler._RecordFunction", + &record_function_enter_new); m.def("_record_function_exit", &record_function_exit_legacy); m.def("_record_function_exit._RecordFunction", &record_function_exit_new); torch::jit::registerOperator(torch::jit::Operator( - "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)", - [](jit::Stack& stack) { - // Pop inputs, which should be a future and a tensor - auto fut = jit::pop(stack).toFuture(); - auto tensor = jit::pop(stack).toTensor(); - auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut); - // return future that completes when profiling callbacks have run. - jit::push(stack, std::move(profiledFut)); - }, - c10::AliasAnalysisKind::FROM_SCHEMA)); + "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)", + [](jit::Stack& stack) { + // Pop inputs, which should be a future and a tensor + auto fut = jit::pop(stack).toFuture(); + auto tensor = jit::pop(stack).toTensor(); + auto profiledFut = _call_end_callbacks_on_fut_legacy(tensor, fut); + // return future that completes when profiling callbacks have run. + jit::push(stack, std::move(profiledFut)); + }, + c10::AliasAnalysisKind::FROM_SCHEMA)); torch::jit::registerOperator(torch::jit::Operator( - "profiler::_call_end_callbacks_on_jit_fut._RecordFunction(" - "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)", - [](c10::Stack &stack) { - // Pop inputs, which should be a future and a PythonRecordFunction - auto fut = torch::jit::pop(stack).toFuture(); - auto tensor = torch::jit::pop(stack).toCustomClass(); - auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut); - // return future that completes when profiling callbacks have run. - torch::jit::push(stack, std::move(profiledFut)); - }, - c10::AliasAnalysisKind::FROM_SCHEMA)); + "profiler::_call_end_callbacks_on_jit_fut._RecordFunction(" + "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)", + [](c10::Stack& stack) { + // Pop inputs, which should be a future and a PythonRecordFunction + auto fut = torch::jit::pop(stack).toFuture(); + auto tensor = + torch::jit::pop(stack).toCustomClass(); + auto profiledFut = _call_end_callbacks_on_fut_new(tensor, fut); + // return future that completes when profiling callbacks have run. + torch::jit::push(stack, std::move(profiledFut)); + }, + c10::AliasAnalysisKind::FROM_SCHEMA)); } } // namespace profiler diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h index 2c074f2dfe5b30..2011acb68b9a71 100644 --- a/torch/csrc/autograd/record_function_ops.h +++ b/torch/csrc/autograd/record_function_ops.h @@ -18,11 +18,12 @@ struct PythonRecordFunction : public torch::CustomClassHolder { // Creates a new profiling scope using RecordFunction and invokes its starting // callbacks. TORCH_API c10::intrusive_ptr record_function_enter_new( - const std::string &name, const c10::optional &args = c10::nullopt); + const std::string& name, + const c10::optional& args = c10::nullopt); // Schedules RecordFunction's end callbacks to be run on completion of a future. TORCH_API c10::intrusive_ptr _call_end_callbacks_on_fut_new( - const c10::intrusive_ptr &record, + const c10::intrusive_ptr& record, const c10::intrusive_ptr& fut); } // namespace profiler diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index b16fa87fda8208..c6ca8eda13d199 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -14,9 +14,13 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view) { +SavedVariable::SavedVariable( + const Variable& variable, + bool is_output, + bool is_inplace_on_view) { if (variable.defined()) { // Note [Inference tensor cannot be saved for backward] // Invariant: @@ -24,19 +28,22 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_i // If an inference tensor was saved for backward in an autograd session and // then you reenter inference mode and make an inplace update to the tensor // without bumping version_counter, it'll lead to silent wrong result when - // you do backward() for the previous autograd session. Technically we don't - // have to check here since it'll fail when querying `current_version` on - // the inference tensor, but we can give a much better error message here. + // you do backward() for the previous autograd session. Technically we + // don't have to check here since it'll fail when querying `current_version` + // on the inference tensor, but we can give a much better error message + // here. // // Note in the documentation we say "inference tensor cannot participate // in autograd" which is more restrictive than the invariant. In practice // the check is more permissive and only error out when an inference tensor - // is saved for backward. Whether a tensor is saved for backward is determined - // by derivative formula and thus varies op by op, so by saying "no inference - // tensor in autograd" it's easier for users to understand and follow. - TORCH_CHECK(!variable.is_inference(), - "Inference tensors cannot be saved for backward. To work around " - "you can make a clone to get a normal tensor and use it in autograd.") + // is saved for backward. Whether a tensor is saved for backward is + // determined by derivative formula and thus varies op by op, so by saying + // "no inference tensor in autograd" it's easier for users to understand and + // follow. + TORCH_CHECK( + !variable.is_inference(), + "Inference tensors cannot be saved for backward. To work around " + "you can make a clone to get a normal tensor and use it in autograd.") was_default_constructed_ = false; const auto& version_counter = impl::version_counter(variable); @@ -63,9 +70,9 @@ SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_i // 1. If the variable is not an output, its grad_fn has already been fully // created and in particular will be a different Node than the one // we are currently constructing (the one that owns this SavedVariable). - // 2. If the variable is a leaf, it only has weak reference to the grad_accumulator - // which cannot create a cycle. - // In those cases, we save the original variable and don't need further processing. + // 2. If the variable is a leaf, it only has weak reference to the + // grad_accumulator which cannot create a cycle. In those cases, we save the + // original variable and don't need further processing. if (!is_output || is_leaf_) { saved_original_ = true; data_ = variable; @@ -110,8 +117,14 @@ void SavedVariable::reset_data() { data_.reset(); } -SavedVariable::SavedVariable(const c10::optional& variable, bool is_output, bool is_inplace_on_view) - : SavedVariable(variable.has_value() ? *variable : Variable(), is_output, is_inplace_on_view) {} +SavedVariable::SavedVariable( + const c10::optional& variable, + bool is_output, + bool is_inplace_on_view) + : SavedVariable( + variable.has_value() ? *variable : Variable(), + is_output, + is_inplace_on_view) {} Variable SavedVariable::unpack(std::shared_ptr saved_for) const { if (was_default_constructed_) { @@ -126,8 +139,8 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { // if versions don't match auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock() - : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr - : grad_fn_; + : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr + : grad_fn_; if (!is_leaf_ && !grad_fn) { TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor"); @@ -137,36 +150,39 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { // Only check version counter in the case without hooks // If user provides hooks, we can't track versions through the hooks if (!hooks_) { - auto current_version = saved_original_ ? impl::version_counter(data_).current_version() - : version_counter_.current_version(); + auto current_version = saved_original_ + ? impl::version_counter(data_).current_version() + : version_counter_.current_version(); if (saved_version_ != current_version) { std::stringstream message; - message << "one of the variables needed for gradient computation has been " - "modified by an inplace operation: [" << data_.toString() << " " - << data_.sizes() << "]"; + message + << "one of the variables needed for gradient computation has been " + "modified by an inplace operation: [" + << data_.toString() << " " << data_.sizes() << "]"; if (grad_fn) { - message << ", which is output " << output_nr_ - << " of " << grad_fn->name() << ","; + message << ", which is output " << output_nr_ << " of " + << grad_fn->name() << ","; } - message << " is at version " << current_version - << "; expected version " << saved_version_ << " instead."; + message << " is at version " << current_version << "; expected version " + << saved_version_ << " instead."; if (!AnomalyMode::is_enabled()) { - message << " Hint: enable anomaly detection to find the operation " - "that failed to compute its gradient, with torch.autograd." - "set_detect_anomaly(True)."; - } - else { - message << " Hint: the backtrace further above shows the operation " - "that failed to compute its gradient. The variable in question " - "was changed in there or anywhere later. Good luck!"; + message << " Hint: enable anomaly detection to find the operation " + "that failed to compute its gradient, with torch.autograd." + "set_detect_anomaly(True)."; + } else { + message + << " Hint: the backtrace further above shows the operation " + "that failed to compute its gradient. The variable in question " + "was changed in there or anywhere later. Good luck!"; } TORCH_CHECK(false, message.str()); } } // The version counter is correct. - // Additionnally, if we deal with a non-leaf variable, we have its correct grad_fn. + // Additionnally, if we deal with a non-leaf variable, we have its correct + // grad_fn. // If we have the original variable, we simply return it if (!hooks_ && saved_original_) { @@ -193,8 +209,7 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { // graph. if (is_leaf_ && requires_grad_) { TORCH_INTERNAL_ASSERT( - !grad_accumulator_.expired(), - "No grad accumulator for a saved leaf"); + !grad_accumulator_.expired(), "No grad accumulator for a saved leaf"); } impl::set_grad_accumulator(var, grad_accumulator_); @@ -211,12 +226,15 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { return var; } -void SavedVariable::set_hooks_and_pack_data(std::unique_ptr&& hooks, const Variable& data) { +void SavedVariable::set_hooks_and_pack_data( + std::unique_ptr&& hooks, + const Variable& data) { hooks_ = std::move(hooks); at::NoGradGuard guard; const auto version = impl::version_counter(data).current_version(); hooks_->call_pack_hook(saved_original_ ? data.detach() : data); - TORCH_CHECK(version == impl::version_counter(data).current_version(), + TORCH_CHECK( + version == impl::version_counter(data).current_version(), "A saved tensor pack hook is modifying its input in place. " "Tensors provided as input to pack hook can not be modified by " "in-place operations as this can lead to unexpected side-effects. " @@ -224,22 +242,26 @@ void SavedVariable::set_hooks_and_pack_data(std::unique_ptr& "the input to a pack hook."); } -void SavedVariable::register_hooks(std::unique_ptr&& hooks) { +void SavedVariable::register_hooks( + std::unique_ptr&& hooks) { TORCH_INTERNAL_ASSERT(hooks); - TORCH_CHECK(!hooks_, - "Calling register_hooks on a saved tensor whose hooks have already been set. " - "Hint: only one pair of hooks is allowed at a time."); + TORCH_CHECK( + !hooks_, + "Calling register_hooks on a saved tensor whose hooks have already been set. " + "Hint: only one pair of hooks is allowed at a time."); if (!data_.defined()) { if (!was_default_constructed_) { - TORCH_CHECK(false, - "Calling register_hooks on a saved tensor after it has been freed. " - "Saved intermediate values of the graph are freed when you call " - ".backward() or autograd.grad(). Specify retain_graph=True if you " - "need to backward through the graph a second time or if you need to " - "access saved variables after calling backward."); + TORCH_CHECK( + false, + "Calling register_hooks on a saved tensor after it has been freed. " + "Saved intermediate values of the graph are freed when you call " + ".backward() or autograd.grad(). Specify retain_graph=True if you " + "need to backward through the graph a second time or if you need to " + "access saved variables after calling backward."); } else { - TORCH_CHECK(false, - "Calling register_hooks on a saved tensor with value None is forbidden"); + TORCH_CHECK( + false, + "Calling register_hooks on a saved tensor with value None is forbidden"); } } // If we didn't save the original variable, we already saved metadata @@ -257,4 +279,5 @@ const char* ERR_BACKWARD_TWICE = "retain_graph=True if you need to backward through the graph a second time or " "if you need to access saved tensors after calling backward."; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h index 2959d58fbcbf1a..9136ff1a620e0d 100644 --- a/torch/csrc/autograd/saved_variable.h +++ b/torch/csrc/autograd/saved_variable.h @@ -9,7 +9,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { using Variable = at::Tensor; struct Node; @@ -22,8 +23,14 @@ TORCH_API extern const char* ERR_BACKWARD_TWICE; class TORCH_API SavedVariable { public: SavedVariable() = default; - SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view=false); - SavedVariable(const c10::optional& variable, bool is_output, bool is_inplace_on_view=false); + SavedVariable( + const Variable& variable, + bool is_output, + bool is_inplace_on_view = false); + SavedVariable( + const c10::optional& variable, + bool is_output, + bool is_inplace_on_view = false); SavedVariable(SavedVariable&&) = default; SavedVariable& operator=(SavedVariable&&) = default; ~SavedVariable() { @@ -49,25 +56,25 @@ class TORCH_API SavedVariable { // If storing the variable itself would create a circular reference, // we fall into the second case and its metadata is also saved separately. // In that case, the grad_fn must be passed in to the unpack function when - // reconstructing the Variable (except when we are doing an inplace operation on - // a view, see below). - // The field saved_orignal_ below reflects the two cases: its value is true - // in the first case and false in the second case. + // reconstructing the Variable (except when we are doing an inplace operation + // on a view, see below). The field saved_orignal_ below reflects the two + // cases: its value is true in the first case and false in the second case. // The value data_.defined() can be false in three cases: - // 1. SavedVariable was constructed without a Tensor (the value to save is None), in - // that case was_default_constructed_ will be kept at true - // 2. The saved variable has been released by calling SavedVariable::reset_data(), typically - // during the backward pass - // 3. Hooks have been registered. In that case, hooks_ will be defined instead. - // Note that the value of saved_original_ only reflects what happened during the construction - // of the SavedVariable. If saved_original_ is true, we saved the original tensor in data_, - // but if the user registers hooks, we will no longer have it (despite the saved_original_ still - // being true) + // 1. SavedVariable was constructed without a Tensor (the value to save is + // None), in that case was_default_constructed_ will be kept at true + // 2. The saved variable has been released by calling + // SavedVariable::reset_data(), typically during the backward pass + // 3. Hooks have been registered. In that case, hooks_ will be defined + // instead. Note that the value of saved_original_ only reflects what happened + // during the construction of the SavedVariable. If saved_original_ is true, + // we saved the original tensor in data_, but if the user registers hooks, we + // will no longer have it (despite the saved_original_ still being true) at::Tensor data_; // This field is used to store the forward AD gradients associated with // the saved Tensor. Note that this shared_ptr must never be shared with - // either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ] + // either the saved Tensor or the unpacked Tensor. See note [ Using + // ForwardGrad ] std::shared_ptr fw_grad_; // Weak version of grad_fn_ that prevents leaks in rebase_history() for @@ -87,18 +94,23 @@ class TORCH_API SavedVariable { bool is_leaf_ = false; bool is_output_ = false; - // Hooks are a pair of functions pack_hook/unpack_hook that provides fine-grained control - // over how the SavedVariable should save its data. - // pack_hook is called upon registration, while unpack_hook is called when unpacking. + // Hooks are a pair of functions pack_hook/unpack_hook that provides + // fine-grained control over how the SavedVariable should save its data. + // pack_hook is called upon registration, while unpack_hook is called when + // unpacking. std::unique_ptr hooks_; - // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if hooks are defined. - // They are set before pack_hook is called and used after unpack_hook is called. + // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if + // hooks are defined. They are set before pack_hook is called and used after + // unpack_hook is called. std::shared_ptr grad_fn_; std::weak_ptr grad_accumulator_; bool requires_grad_ = false; void save_metadata(const Variable& data); static std::unique_ptr get_default_hooks(); - void set_hooks_and_pack_data(std::unique_ptr&& hooks, const Variable& data); + void set_hooks_and_pack_data( + std::unique_ptr&& hooks, + const Variable& data); }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/saved_variable_hooks.h b/torch/csrc/autograd/saved_variable_hooks.h index dc9c038353e592..ccd51f4e2ac387 100644 --- a/torch/csrc/autograd/saved_variable_hooks.h +++ b/torch/csrc/autograd/saved_variable_hooks.h @@ -2,12 +2,14 @@ #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct TORCH_API SavedVariableHooks { - virtual void call_pack_hook(const at::Tensor &tensor) = 0; + virtual void call_pack_hook(const at::Tensor& tensor) = 0; virtual at::Tensor call_unpack_hook() = 0; virtual ~SavedVariableHooks() = default; }; -}} +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/symbolic.h b/torch/csrc/autograd/symbolic.h index caba1945ad8e84..6d57aa0aabdfa3 100644 --- a/torch/csrc/autograd/symbolic.h +++ b/torch/csrc/autograd/symbolic.h @@ -4,7 +4,8 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct SymbolicContext { jit::Block* block; @@ -14,4 +15,5 @@ struct symbolic_unconvertible : public std::runtime_error { using std::runtime_error::runtime_error; }; -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/utils/error_messages.h b/torch/csrc/autograd/utils/error_messages.h index a83a7409ecf90e..da655883f3f69d 100644 --- a/torch/csrc/autograd/utils/error_messages.h +++ b/torch/csrc/autograd/utils/error_messages.h @@ -2,17 +2,21 @@ #include -namespace torch { namespace autograd { namespace utils { +namespace torch { +namespace autograd { +namespace utils { inline std::string requires_grad_leaf_error(bool requires_grad) { std::ostringstream oss; oss << "you can only change requires_grad flags of leaf variables."; if (requires_grad == false) { - oss << " If you want to use a computed variable in a subgraph " - "that doesn't require differentiation use " - "var_no_grad = var.detach()."; + oss << " If you want to use a computed variable in a subgraph " + "that doesn't require differentiation use " + "var_no_grad = var.detach()."; } return oss.str(); } -}}} // namespace torch::autograd::utils +} // namespace utils +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index c7e1bad9fb8a32..e7e1e14c66aed9 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -10,7 +10,9 @@ namespace utils { // torch/csrc/autograd/functions/accumulate_grad.h. // Checks if grad obeys the contract with variable. -inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& variable) { +inline bool obeys_layout_contract( + const at::Tensor& grad, + const at::Tensor& variable) { TORCH_INTERNAL_ASSERT(!grad.is_sparse()); TORCH_INTERNAL_ASSERT(!variable.is_sparse()); TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr()); @@ -27,10 +29,10 @@ inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& vari } } else { // This should not be needed but we don't check if a Tensor has views - // before stashing it. And 0-strided Tensors of size 1 are actually views - // for ops like cat. - // TODO: Actually detect views in the accumulateGrad function so that this - // Tensor is not considered at all. + // before stashing it. And 0-strided Tensors of size 1 are actually + // views for ops like cat. + // TODO: Actually detect views in the accumulateGrad function so that + // this Tensor is not considered at all. if (grad_strides[idx] == 0) { return false; } @@ -44,14 +46,19 @@ inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& vari // Creates a clone of new_grad that obeys the contract with variable. // The clone should attach to new_grad's history if GradMode::is_enabled(). -inline at::Tensor clone_obey_contract(const at::Tensor& new_grad, const at::Tensor& variable) { +inline at::Tensor clone_obey_contract( + const at::Tensor& new_grad, + const at::Tensor& variable) { if (variable.is_non_overlapping_and_dense()) { // (1) // Does this dicey-looking sequence attach the result to new_grad's // history if GradMode::is_enabled()? Yes, and @alband says it should. - return std::move(new_grad.new_empty_strided(variable.sizes(), variable.strides(), - variable.options().memory_format(c10::nullopt)) - .copy_(new_grad)); + return std::move(new_grad + .new_empty_strided( + variable.sizes(), + variable.strides(), + variable.options().memory_format(c10::nullopt)) + .copy_(new_grad)); } else { // (2) return new_grad.clone(at::MemoryFormat::Contiguous); diff --git a/torch/csrc/autograd/utils/python_arg_parsing.h b/torch/csrc/autograd/utils/python_arg_parsing.h index 989da5adaf63a1..7701e97fe9189c 100644 --- a/torch/csrc/autograd/utils/python_arg_parsing.h +++ b/torch/csrc/autograd/utils/python_arg_parsing.h @@ -1,35 +1,53 @@ #pragma once -#include #include +#include #include -namespace torch { namespace autograd { namespace utils { +namespace torch { +namespace autograd { +namespace utils { // The parameter allow_copy is to accept copy for Tensor.to (and by proxy // PackedSequences.to) but not nn.Module.to. -inline std::tuple, c10::optional, bool, bool, c10::optional> - parse_to_conversion(PythonArgs& r, bool allow_copy) { +inline std::tuple< + c10::optional, + c10::optional, + bool, + bool, + c10::optional> +parse_to_conversion(PythonArgs& r, bool allow_copy) { if (r.idx == 0) { if (!allow_copy && !r.isNone(3)) throw std::runtime_error(".to() does not accept copy argument"); - return std::make_tuple(r.deviceOptional(0), r.scalartypeOptional(1), r.toBool(2), r.toBool(3), r.memoryformatOptional(4)); + return std::make_tuple( + r.deviceOptional(0), + r.scalartypeOptional(1), + r.toBool(2), + r.toBool(3), + r.memoryformatOptional(4)); } else if (r.idx == 1) { if (!allow_copy && !r.isNone(2)) throw std::runtime_error(".to() does not accept copy argument"); - return std::make_tuple(c10::nullopt, r.scalartype(0), r.toBool(1), r.toBool(2), r.memoryformatOptional(3)); + return std::make_tuple( + c10::nullopt, + r.scalartype(0), + r.toBool(1), + r.toBool(2), + r.memoryformatOptional(3)); } else { auto tensor = r.tensor(0); if (!allow_copy && !r.isNone(2)) throw std::runtime_error(".to() does not accept copy argument"); return std::make_tuple( - tensor.device(), - tensor.scalar_type(), - r.toBool(1), - r.toBool(2), - r.memoryformatOptional(3) - ); + tensor.device(), + tensor.scalar_type(), + r.toBool(1), + r.toBool(2), + r.memoryformatOptional(3)); } } -}}} // namespace torch::autograd::utils +} // namespace utils +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/utils/warnings.cpp b/torch/csrc/autograd/utils/warnings.cpp index 681d9eadf02ba5..a8f989d51ce6ac 100644 --- a/torch/csrc/autograd/utils/warnings.cpp +++ b/torch/csrc/autograd/utils/warnings.cpp @@ -1,10 +1,12 @@ #include -namespace torch { namespace autograd { namespace utils { +namespace torch { +namespace autograd { +namespace utils { void DelayWarningHandler::process( - const at::SourceLocation &source_location, - const std::string &msg, + const at::SourceLocation& source_location, + const std::string& msg, const bool verbatim) { std::lock_guard lock(mutex_); warnings_.push_back({source_location, msg, verbatim}); @@ -20,4 +22,6 @@ void DelayWarningHandler::replay_warnings() { } } -}}} // namespace torch::autograd::utils +} // namespace utils +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/utils/warnings.h b/torch/csrc/autograd/utils/warnings.h index 76ced6fb7b53d4..d9e028acb4d877 100644 --- a/torch/csrc/autograd/utils/warnings.h +++ b/torch/csrc/autograd/utils/warnings.h @@ -1,25 +1,26 @@ #pragma once #include -#include #include +#include - -namespace torch { namespace autograd { namespace utils { +namespace torch { +namespace autograd { +namespace utils { // Warning handler for multi-threaded contexts. Gather warnings from // all threads into a single queue, then process together at the end // in the main thread. -class DelayWarningHandler: public at::WarningHandler { -public: +class DelayWarningHandler : public at::WarningHandler { + public: ~DelayWarningHandler() override = default; void replay_warnings(); -private: - - void process(const at::SourceLocation &source_location, - const std::string &msg, - bool verbatim) override; + private: + void process( + const at::SourceLocation& source_location, + const std::string& msg, + bool verbatim) override; struct Warning { c10::SourceLocation source_location; @@ -31,4 +32,6 @@ class DelayWarningHandler: public at::WarningHandler { std::mutex mutex_; }; -}}} // namespace torch::autograd::utils +} // namespace utils +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 114b53487368c7..370d6fdeff111a 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -2,23 +2,25 @@ // Wrap tensor operation outputs as PyObject* -#include #include +#include #include #include -#include #include +#include #include +#include #include #include #include #include #include #include -#include -namespace torch { namespace autograd { namespace utils { +namespace torch { +namespace autograd { +namespace utils { inline PyObject* wrap(bool value) { if (value) { @@ -46,7 +48,7 @@ inline PyObject* wrap(void* value) { return THPUtils_packInt64(reinterpret_cast(value)); } -inline PyObject* wrap(THPDtype *dtype) { +inline PyObject* wrap(THPDtype* dtype) { Py_INCREF(dtype); return (PyObject*)dtype; } @@ -55,7 +57,7 @@ inline PyObject* wrap(at::ScalarType scalarType) { return wrap(getTHPDtype(scalarType)); } -inline PyObject* wrap(THPLayout *layout) { +inline PyObject* wrap(THPLayout* layout) { Py_INCREF(layout); return (PyObject*)layout; } @@ -80,7 +82,8 @@ inline PyObject* wrap(at::QScheme qscheme) { inline PyObject* wrap(at::TensorList tl) { auto r = THPObjectPtr{PyTuple_New(tl.size())}; - if (!r) throw python_error(); + if (!r) + throw python_error(); for (const auto i : c10::irange(tl.size())) { PyTuple_SET_ITEM(r.get(), i, wrap(tl[i])); } @@ -89,7 +92,8 @@ inline PyObject* wrap(at::TensorList tl) { inline PyObject* wrap(at::IntArrayRef list) { auto r = THPObjectPtr{PyTuple_New(list.size())}; - if (!r) throw python_error(); + if (!r) + throw python_error(); for (const auto i : c10::irange(list.size())) { PyTuple_SET_ITEM(r.get(), i, wrap(list[i])); } @@ -97,38 +101,47 @@ inline PyObject* wrap(at::IntArrayRef list) { } namespace detail { -template -void apply_with_idx_impl(const F &f, Tuple &t, std::index_sequence /*indices*/) { - (void)std::initializer_list { - (f(std::get(t), Is), 0)... - }; +template +void apply_with_idx_impl( + const F& f, + Tuple& t, + std::index_sequence /*indices*/) { + (void)std::initializer_list{(f(std::get(t), Is), 0)...}; } // For tuple(a, b, c), calls f(a, 0), f(b, 1), f(c, 2) -template -void apply_with_idx(const F & f, std::tuple &t) { +template +void apply_with_idx(const F& f, std::tuple& t) { apply_with_idx_impl(f, t, std::index_sequence_for{}); } -} // namespace detail +} // namespace detail -template +template PyObject* wrap(std::tuple values) { auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))}; - if (!r) throw python_error(); - detail::apply_with_idx([&](auto &value, size_t idx) { - PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value))); - }, values); + if (!r) + throw python_error(); + detail::apply_with_idx( + [&](auto& value, size_t idx) { + PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value))); + }, + values); return r.release(); } -template -PyObject* wrap(PyTypeObject *type, std::tuple values) { +template +PyObject* wrap(PyTypeObject* type, std::tuple values) { auto r = THPObjectPtr{PyStructSequence_New(type)}; - if (!r) throw python_error(); - detail::apply_with_idx([&](auto &value, size_t idx) { - PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value))); - }, values); + if (!r) + throw python_error(); + detail::apply_with_idx( + [&](auto& value, size_t idx) { + PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value))); + }, + values); return r.release(); } -}}} // namespace torch::autograd::utils +} // namespace utils +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 22a9e76ec60beb..373c26aebab0bc 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include #include #include #include @@ -13,28 +13,28 @@ #include #include -#include #include +#include #include +#include #include #include #include #include #include -#include #include -#include +#include namespace torch { namespace autograd { - -DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, - c10::optional backward_info, - c10::optional forward_info, - bool shared_view_info, - CreationMeta creation_meta) +DifferentiableViewMeta::DifferentiableViewMeta( + at::TensorImpl* self_impl, + c10::optional backward_info, + c10::optional forward_info, + bool shared_view_info, + CreationMeta creation_meta) : AutogradMeta(self_impl), backward_info_(std::move(backward_info)), forward_info_(std::move(forward_info)), @@ -42,23 +42,31 @@ DifferentiableViewMeta::DifferentiableViewMeta(at::TensorImpl* self_impl, creation_meta_(creation_meta) { is_view_ = true; if (backward_info_.has_value()) { - self_impl->set_version_counter(impl::version_counter(backward_info_.value().base_)); + self_impl->set_version_counter( + impl::version_counter(backward_info_.value().base_)); attr_version_ = self_impl->version_counter().current_version(); } if (shared_view_info_) { - TORCH_INTERNAL_ASSERT(backward_info_.has_value(), "Shared view info require a backward view info."); - TORCH_INTERNAL_ASSERT(!forward_info_.has_value(), "Shared view info require forward view info to be empty") + TORCH_INTERNAL_ASSERT( + backward_info_.has_value(), + "Shared view info require a backward view info."); + TORCH_INTERNAL_ASSERT( + !forward_info_.has_value(), + "Shared view info require forward view info to be empty") } } // Chain this view info with the new view op between base and tensor -ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor, - std::function view_func) const { +ViewInfo ViewInfo::chain( + const Variable& base, + const Variable& tensor, + std::function view_func) const { // Set `view_func` using the root base as input. - // `view_func` is used to recover views in backward when either as_strided is not supported - // or the view function changes the metadata which is not recorded by as_strided - // See Note [View + Inplace update on base tensor] and [View + Inplace update on view tensor] - // for more details how we use this function in backward. + // `view_func` is used to recover views in backward when either as_strided is + // not supported or the view function changes the metadata which is not + // recorded by as_strided See Note [View + Inplace update on base tensor] and + // [View + Inplace update on view tensor] for more details how we use this + // function in backward. if (view_func) { // both current_view and it's parent have a view_func if (view_fn_) { @@ -79,22 +87,25 @@ ViewInfo ViewInfo::chain(const Variable & base, const Variable & tensor, return view_func(temp); }; } else { - // When base is a view but doesn't carry a view_fn in DifferentiableViewMeta, it's - // a view that doesn't support inplace update, e.g. unbind. - // In this case we should throw an error when inplace update happens in **forward**. - // One would naturally think the following function will be first called in backward pass. - // But the first call site is indeed in **forward** pass when we refresh `grad_fn` - // triggered by inplace update. - // Search Note [View + Inplace update for view tensor] to for the call site. + // When base is a view but doesn't carry a view_fn in + // DifferentiableViewMeta, it's a view that doesn't support inplace + // update, e.g. unbind. In this case we should throw an error when + // inplace update happens in **forward**. One would naturally think the + // following function will be first called in backward pass. But the + // first call site is indeed in **forward** pass when we refresh + // `grad_fn` triggered by inplace update. Search Note [View + Inplace + // update for view tensor] to for the call site. view_func = [=](const at::Tensor& root_base) { - TORCH_CHECK(false, "This view is the output of a function that returns multiple views." - "Such functions do not allow the output views to be modified inplace." - "You should replace the inplace operation by an out-of-place one"); + TORCH_CHECK( + false, + "This view is the output of a function that returns multiple views." + "Such functions do not allow the output views to be modified inplace." + "You should replace the inplace operation by an out-of-place one"); return root_base; }; } } - } else if(view_fn_) { + } else if (view_fn_) { // if current_view doesn't have a view_func but it's parent has one // Copy parent view function to gain ownership auto prev_view_fn = view_fn_; @@ -125,205 +136,222 @@ struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory { ConcreteAutogradMetaFactory meta_factory; -static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_factory); +static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer( + &meta_factory); -} +} // namespace namespace impl { - AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { - TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor"); - auto p = self.unsafeGetTensorImpl(); - if (!p->autograd_meta()) { - p->set_autograd_meta(std::make_unique()); - } - return get_autograd_meta(self); - } - - void rebase_history(const Variable& self, Edge gradient_edge) { - TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); - auto diff_view_meta = get_view_autograd_meta(self); - if (diff_view_meta && diff_view_meta->has_bw_view()) { - // See NOTE [ View + Inplace detection ] - auto creation_meta = diff_view_meta->get_creation_meta(); - // Do not use handle_view_on_rebase here as check_inplace should have been called before this - // and either throw an error - TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT); - TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0); - TORCH_INTERNAL_ASSERT(gradient_edge.function); - TORCH_CHECK( - gradient_edge.function->num_inputs() == 1, - "Functions which modify views in-place must return a single Variable"); - auto view_info = diff_view_meta->get_backward_view(); - diff_view_meta->output_nr_ = gradient_edge.input_nr; - auto copy_slices = std::make_shared( - view_info.base_, at::TensorGeometry(self), view_info.view_fn_, std::move(gradient_edge.function)); - set_gradient_edge(view_info.base_, {std::move(copy_slices), 0}); - self.grad_fn(); // trigger an update to the view's grad_fn - return; - } - - set_gradient_edge(self, std::move(gradient_edge)); +AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { + TORCH_CHECK( + self.defined(), + "cannot call materialize_autograd_meta() on undefined tensor"); + auto p = self.unsafeGetTensorImpl(); + if (!p->autograd_meta()) { + p->set_autograd_meta(std::make_unique()); } + return get_autograd_meta(self); +} - void create_cpp_hook(const at::TensorBase& self) { - auto &list = materialize_autograd_meta(self)->cpp_hooks_list_; - // NOLINTNEXTLINE(modernize-make-shared) - list.reset(new hooks_list()); - std::unique_ptr hook_ptr(new CppFunctionPreHook(list, self.output_nr())); - clear_hooks(self); - add_hook(self, std::make_shared(list, 0)); - const auto& fn = self.grad_fn(); - if (fn) { - fn->add_pre_hook(std::move(hook_ptr)); - } +void rebase_history(const Variable& self, Edge gradient_edge) { + TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); + auto diff_view_meta = get_view_autograd_meta(self); + if (diff_view_meta && diff_view_meta->has_bw_view()) { + // See NOTE [ View + Inplace detection ] + auto creation_meta = diff_view_meta->get_creation_meta(); + // Do not use handle_view_on_rebase here as check_inplace should have been + // called before this and either throw an error + TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT); + TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0); + TORCH_INTERNAL_ASSERT(gradient_edge.function); + TORCH_CHECK( + gradient_edge.function->num_inputs() == 1, + "Functions which modify views in-place must return a single Variable"); + auto view_info = diff_view_meta->get_backward_view(); + diff_view_meta->output_nr_ = gradient_edge.input_nr; + auto copy_slices = std::make_shared( + view_info.base_, + at::TensorGeometry(self), + view_info.view_fn_, + std::move(gradient_edge.function)); + set_gradient_edge(view_info.base_, {std::move(copy_slices), 0}); + self.grad_fn(); // trigger an update to the view's grad_fn + return; } - void set_grad_accumulator(const Variable& self, - std::weak_ptr grad_accumulator) { - materialize_autograd_meta(self)->grad_accumulator_ = std::move(grad_accumulator); - } + set_gradient_edge(self, std::move(gradient_edge)); +} - std::shared_ptr try_get_grad_accumulator(const Variable& self) { - if (get_autograd_meta(self)) { - return get_autograd_meta(self)->grad_accumulator_.lock(); - } else { - return nullptr; - } +void create_cpp_hook(const at::TensorBase& self) { + auto& list = materialize_autograd_meta(self)->cpp_hooks_list_; + // NOLINTNEXTLINE(modernize-make-shared) + list.reset(new hooks_list()); + std::unique_ptr hook_ptr( + new CppFunctionPreHook(list, self.output_nr())); + clear_hooks(self); + add_hook(self, std::make_shared(list, 0)); + const auto& fn = self.grad_fn(); + if (fn) { + fn->add_pre_hook(std::move(hook_ptr)); } +} - std::shared_ptr grad_accumulator(const Variable& self) { - auto autograd_meta = get_autograd_meta(self); - if (!autograd_meta) { - return nullptr; - } - if (autograd_meta->grad_fn_) { - throw std::logic_error( - "grad_accumulator() should be only called on leaf Variables"); - } - if (!autograd_meta->requires_grad_) { - return nullptr; - } +void set_grad_accumulator( + const Variable& self, + std::weak_ptr grad_accumulator) { + materialize_autograd_meta(self)->grad_accumulator_ = + std::move(grad_accumulator); +} - std::lock_guard lock(autograd_meta->mutex_); +std::shared_ptr try_get_grad_accumulator(const Variable& self) { + if (get_autograd_meta(self)) { + return get_autograd_meta(self)->grad_accumulator_.lock(); + } else { + return nullptr; + } +} - auto result = autograd_meta->grad_accumulator_.lock(); - if (result) - return result; +std::shared_ptr grad_accumulator(const Variable& self) { + auto autograd_meta = get_autograd_meta(self); + if (!autograd_meta) { + return nullptr; + } + if (autograd_meta->grad_fn_) { + throw std::logic_error( + "grad_accumulator() should be only called on leaf Variables"); + } + if (!autograd_meta->requires_grad_) { + return nullptr; + } - c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl()); - auto intrusive_from_this = c10::intrusive_ptr::reclaim(self.unsafeGetTensorImpl()); - result = std::make_shared(Variable(std::move(intrusive_from_this))); - autograd_meta->grad_accumulator_ = result; + std::lock_guard lock(autograd_meta->mutex_); + + auto result = autograd_meta->grad_accumulator_.lock(); + if (result) return result; - } - Edge gradient_edge(const Variable& self) { - // If grad_fn is null (as is the case for a leaf node), we instead - // interpret the gradient function to be a gradient accumulator, which will - // accumulate its inputs into the grad property of the variable. These - // nodes get suppressed in some situations, see "suppress gradient - // accumulation" below. Note that only variables which have `requires_grad = - // True` can have gradient accumulators. - if (const auto& gradient = self.grad_fn()) { - return Edge(gradient, self.output_nr()); - } else { - return Edge(grad_accumulator(self), 0); - } + c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl()); + auto intrusive_from_this = + c10::intrusive_ptr::reclaim(self.unsafeGetTensorImpl()); + result = std::make_shared( + Variable(std::move(intrusive_from_this))); + autograd_meta->grad_accumulator_ = result; + return result; +} + +Edge gradient_edge(const Variable& self) { + // If grad_fn is null (as is the case for a leaf node), we instead + // interpret the gradient function to be a gradient accumulator, which will + // accumulate its inputs into the grad property of the variable. These + // nodes get suppressed in some situations, see "suppress gradient + // accumulation" below. Note that only variables which have `requires_grad = + // True` can have gradient accumulators. + if (const auto& gradient = self.grad_fn()) { + return Edge(gradient, self.output_nr()); + } else { + return Edge(grad_accumulator(self), 0); } +} - void set_gradient_edge(const Variable& self, Edge edge) { - auto* meta = materialize_autograd_meta(self); - meta->grad_fn_ = std::move(edge.function); - meta->output_nr_ = edge.input_nr; - // For views, make sure this new grad_fn_ is not overwritten unless it is necessary - // in the VariableHooks::grad_fn below. - // This logic is only relevant for custom autograd Functions for which multiple - // operations can happen on a given Tensor before its gradient edge is set when - // exiting the custom Function. - auto diff_view_meta = get_view_autograd_meta(self); - if (diff_view_meta && diff_view_meta->has_bw_view()) { - diff_view_meta->set_attr_version(self._version()); - } +void set_gradient_edge(const Variable& self, Edge edge) { + auto* meta = materialize_autograd_meta(self); + meta->grad_fn_ = std::move(edge.function); + meta->output_nr_ = edge.input_nr; + // For views, make sure this new grad_fn_ is not overwritten unless it is + // necessary in the VariableHooks::grad_fn below. This logic is only relevant + // for custom autograd Functions for which multiple operations can happen on a + // given Tensor before its gradient edge is set when exiting the custom + // Function. + auto diff_view_meta = get_view_autograd_meta(self); + if (diff_view_meta && diff_view_meta->has_bw_view()) { + diff_view_meta->set_attr_version(self._version()); } +} - Node* grad_fn_unsafe(const Variable& self) { - if (get_autograd_meta(self)) { - return get_autograd_meta(self)->grad_fn_.get(); - } else { - return nullptr; - } +Node* grad_fn_unsafe(const Variable& self) { + if (get_autograd_meta(self)) { + return get_autograd_meta(self)->grad_fn_.get(); + } else { + return nullptr; } +} - // Versions - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Versions +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - void set_version_counter( - const Variable& self, - const c10::VariableVersion& version_counter) { - TORCH_CHECK(self.defined(), "cannot call set_version_counter() on undefined tensor"); - self.unsafeGetTensorImpl()->set_version_counter(version_counter); - } +void set_version_counter( + const Variable& self, + const c10::VariableVersion& version_counter) { + TORCH_CHECK( + self.defined(), "cannot call set_version_counter() on undefined tensor"); + self.unsafeGetTensorImpl()->set_version_counter(version_counter); +} - void bump_version(const Variable& self) { - TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor"); - self.unsafeGetTensorImpl()->bump_version(); - } +void bump_version(const Variable& self) { + TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor"); + self.unsafeGetTensorImpl()->bump_version(); +} - const c10::VariableVersion& version_counter(const Variable& self) { - TORCH_CHECK(self.defined(), "cannot call version_counter() on undefined tensor"); - return self.unsafeGetTensorImpl()->version_counter(); - } +const c10::VariableVersion& version_counter(const Variable& self) { + TORCH_CHECK( + self.defined(), "cannot call version_counter() on undefined tensor"); + return self.unsafeGetTensorImpl()->version_counter(); +} - // Hooks - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Hooks +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - void add_hook(const at::TensorBase& self, std::shared_ptr hook) { - materialize_autograd_meta(self)->hooks_.push_back(std::move(hook)); - } +void add_hook( + const at::TensorBase& self, + std::shared_ptr hook) { + materialize_autograd_meta(self)->hooks_.push_back(std::move(hook)); +} - namespace { - std::vector> empty_singleton; - } +namespace { +std::vector> empty_singleton; +} - // TODO: Return an ArrayRef instead (and delete the singleton while you're at - // it - const std::vector>& hooks(const Variable& self) - { - if (get_autograd_meta(self)) { - return get_autograd_meta(self)->hooks_; - } else { - return empty_singleton; - } +// TODO: Return an ArrayRef instead (and delete the singleton while you're at +// it +const std::vector>& hooks( + const Variable& self) { + if (get_autograd_meta(self)) { + return get_autograd_meta(self)->hooks_; + } else { + return empty_singleton; } +} - void clear_hooks(const at::TensorBase& self) { - // This is a little goofy, but usually this should be a no oop - materialize_autograd_meta(self)->hooks_.clear(); - } +void clear_hooks(const at::TensorBase& self) { + // This is a little goofy, but usually this should be a no oop + materialize_autograd_meta(self)->hooks_.clear(); +} - void set_name(const Variable& self, const std::string& name) { - materialize_autograd_meta(self)->name_ = name; - } +void set_name(const Variable& self, const std::string& name) { + materialize_autograd_meta(self)->name_ = name; +} - // Miscellaneous - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Miscellaneous +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - AutogradMeta* get_autograd_meta(const at::TensorBase& self) { - // NB: could return nullptr - TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor"); - return static_cast(self.unsafeGetTensorImpl()->autograd_meta()); - } +AutogradMeta* get_autograd_meta(const at::TensorBase& self) { + // NB: could return nullptr + TORCH_CHECK( + self.defined(), "cannot call get_autograd_meta() on undefined tensor"); + return static_cast( + self.unsafeGetTensorImpl()->autograd_meta()); +} - DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) { - // NB: return nullptr if self is not a view - AutogradMeta* meta = get_autograd_meta(self); - if (meta && meta->is_view_) { - return static_cast(meta); - } else { - return nullptr; - } +DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) { + // NB: return nullptr if self is not a view + AutogradMeta* meta = get_autograd_meta(self); + if (meta && meta->is_view_) { + return static_cast(meta); + } else { + return nullptr; } +} } // namespace impl @@ -332,34 +360,42 @@ using at::Tensor; struct VariableHooks final : at::impl::VariableHooksInterface { at::TensorBase tensor_data(const at::TensorBase&) const override; at::TensorBase variable_data(const at::TensorBase&) const override; - const std::shared_ptr& grad_fn(const at::TensorBase&) const override; + const std::shared_ptr& grad_fn( + const at::TensorBase&) const override; unsigned _register_hook( - const at::TensorBase&, std::function hook) const override; + const at::TensorBase&, + std::function hook) const override; void remove_hook(const at::TensorBase&, unsigned pos) const override; bool is_view(const at::TensorBase&) const override; const at::TensorBase& base(const at::TensorBase&) const override; const std::string& name(const at::TensorBase&) const override; bool is_leaf(const at::TensorBase&) const override; int64_t output_nr(const at::TensorBase&) const override; - void set_data(const at::TensorBase & self, const at::TensorBase & new_data) const override; - at::TensorBase data(const at::TensorBase & self) const override; - int64_t _version(const at::TensorBase & self) const override; + void set_data(const at::TensorBase& self, const at::TensorBase& new_data) + const override; + at::TensorBase data(const at::TensorBase& self) const override; + int64_t _version(const at::TensorBase& self) const override; void retain_grad(const at::TensorBase& self) const override; bool retains_grad(const at::TensorBase& self) const override; - void _backward(const Tensor& self, at::TensorList inputs, - const c10::optional& gradient, c10::optional keep_graph, - bool create_graph) const override; - void requires_grad_(const at::TensorBase& self, bool _requires_grad) const override; + void _backward( + const Tensor& self, + at::TensorList inputs, + const c10::optional& gradient, + c10::optional keep_graph, + bool create_graph) const override; + void requires_grad_(const at::TensorBase& self, bool _requires_grad) + const override; }; VariableHooks variableHooks; at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks); at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const { - TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); + TORCH_CHECK( + self.defined(), "cannot call variable_data() on undefined tensor"); auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/false); + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/false); self_impl_copy->set_autograd_meta(nullptr); return at::Tensor(self_impl_copy); } @@ -367,12 +403,13 @@ at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const { at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const { TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor"); auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( - /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(), - /*allow_tensor_metadata_change=*/self.unsafeGetTensorImpl()->allow_tensor_metadata_change()); + /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(), + /*allow_tensor_metadata_change=*/ + self.unsafeGetTensorImpl()->allow_tensor_metadata_change()); return at::Tensor(self_impl_copy); } -bool VariableHooks::is_leaf(const at::TensorBase & self) const { +bool VariableHooks::is_leaf(const at::TensorBase& self) const { if (impl::get_autograd_meta(self)) { return impl::get_autograd_meta(self)->grad_fn_ == nullptr; } else { @@ -380,7 +417,7 @@ bool VariableHooks::is_leaf(const at::TensorBase & self) const { } } -int64_t VariableHooks::output_nr(const at::TensorBase & self) const { +int64_t VariableHooks::output_nr(const at::TensorBase& self) const { if (impl::get_autograd_meta(self)) { return impl::get_autograd_meta(self)->output_nr_; } else { @@ -388,22 +425,25 @@ int64_t VariableHooks::output_nr(const at::TensorBase & self) const { } } -void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorBase & new_data_base) const { +void VariableHooks::set_data( + const at::TensorBase& self_base, + const at::TensorBase& new_data_base) const { at::OptionalTensorRef self_ref(self_base); - const Tensor &self = *self_ref; + const Tensor& self = *self_ref; at::OptionalTensorRef new_data_ref(new_data_base); - const Tensor &new_data = *new_data_ref; + const Tensor& new_data = *new_data_ref; // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields - // from `new_data` to `var`. It requires that `new_data` and `var` have compatible - // tensor type. + // from `new_data` to `var`. It requires that `new_data` and `var` have + // compatible tensor type. TORCH_CHECK( - _has_compatible_shallow_copy_type(self, new_data), - "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type."); + _has_compatible_shallow_copy_type(self, new_data), + "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type."); TORCH_CHECK( - !self.requires_grad() || isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())), - "data set to a tensor that requires gradients must be floating point or complex dtype"); + !self.requires_grad() || + isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())), + "data set to a tensor that requires gradients must be floating point or complex dtype"); // Resets gradient accumulator if metadata is out of date AutogradMeta* autograd_meta = impl::get_autograd_meta(self); @@ -414,33 +454,37 @@ void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorB const auto prior_device = prior_accumulator->input_metadata(0).device(); const auto new_device = new_data.device(); - if (!new_data.options().type_equal(self.options()) || prior_device != new_device) { + if (!new_data.options().type_equal(self.options()) || + prior_device != new_device) { autograd_meta->grad_accumulator_.reset(); } } } // Version counter is not shared when we replace a `Variable`'s tensor data - // by calling `set_data(...)`. The original version of the `Variable` is always preserved. - // See NOTE [ Version Counter Sharing ] for details. + // by calling `set_data(...)`. The original version of the `Variable` is + // always preserved. See NOTE [ Version Counter Sharing ] for details. // - // `var.set_data(new_data)` always ignores `var`'s `allow_tensor_metadata_change_`, because - // users need this API as an escape hatch for changing a tensor's metadata regardless of its - // `allow_tensor_metadata_change_` value, and the users are responsible for ensuring this is - // the behavior they want. + // `var.set_data(new_data)` always ignores `var`'s + // `allow_tensor_metadata_change_`, because users need this API as an escape + // hatch for changing a tensor's metadata regardless of its + // `allow_tensor_metadata_change_` value, and the users are responsible for + // ensuring this is the behavior they want. self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr()); } -at::TensorBase VariableHooks::data(const at::TensorBase & self) const { +at::TensorBase VariableHooks::data(const at::TensorBase& self) const { return self.variable_data(); } -int64_t VariableHooks::_version(const at::TensorBase & self) const { +int64_t VariableHooks::_version(const at::TensorBase& self) const { return self.unsafeGetTensorImpl()->version_counter().current_version(); } void VariableHooks::retain_grad(const at::TensorBase& self) const { - TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False"); + TORCH_CHECK( + self.requires_grad(), + "can't retain_grad on Tensor that has requires_grad=False"); // temporary hack to improve functorch UX. const auto& functorch_tls = at::functorch::functorchTLSAccessor(); @@ -448,7 +492,7 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const { functorch_tls->checkSupportsRetainGrad(); } - if (self.is_leaf()) { // no-op for leaves + if (self.is_leaf()) { // no-op for leaves return; } if (impl::get_autograd_meta(self)->retains_grad_) { @@ -491,18 +535,22 @@ void VariableHooks::_backward( const c10::optional& gradient, c10::optional keep_graph, bool create_graph) const { - // TODO torch::autograd::backward should take the c10::optional gradient directly - // instead of us having to unwrap it to Tensor _gradient here. + // TODO torch::autograd::backward should take the c10::optional + // gradient directly instead of us having to unwrap it to Tensor _gradient + // here. Tensor _gradient = gradient.has_value() ? *gradient : Tensor(); - std::vector input_vars(inputs.begin(), inputs.end()); - torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars); + std::vector input_vars( + inputs.begin(), inputs.end()); + torch::autograd::backward( + {self}, {_gradient}, keep_graph, create_graph, input_vars); } -void VariableHooks::requires_grad_(const at::TensorBase& self, bool _requires_grad) const { +void VariableHooks::requires_grad_( + const at::TensorBase& self, + bool _requires_grad) const { if (!self.is_leaf() && !_requires_grad) { throw std::runtime_error( - autograd::utils::requires_grad_leaf_error(_requires_grad) - ); + autograd::utils::requires_grad_leaf_error(_requires_grad)); } self.set_requires_grad(_requires_grad); } @@ -522,7 +570,9 @@ bool VariableHooks::is_view(const at::TensorBase& self) const { const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const { auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); if (diff_view_meta) { - TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor"); + TORCH_CHECK( + diff_view_meta->has_bw_view(), + "Can't get base of non-backward view Tensor"); return diff_view_meta->get_backward_view().base_; } else { throw std::runtime_error("Can't get base of non-view Tensor"); @@ -530,11 +580,12 @@ const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const { } namespace { - std::string singleton_string; +std::string singleton_string; } const std::string& VariableHooks::name(const at::TensorBase& self) const { - TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); + TORCH_CHECK( + self.defined(), "cannot call variable_data() on undefined tensor"); if (torch::autograd::impl::get_autograd_meta(self)) { return torch::autograd::impl::get_autograd_meta(self)->name_; } else { @@ -543,10 +594,11 @@ const std::string& VariableHooks::name(const at::TensorBase& self) const { } namespace { - std::shared_ptr singleton_shared_ptr; +std::shared_ptr singleton_shared_ptr; } -const std::shared_ptr& VariableHooks::grad_fn(const at::TensorBase& self) const { +const std::shared_ptr& VariableHooks::grad_fn( + const at::TensorBase& self) const { auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); if (diff_view_meta && diff_view_meta->has_bw_view()) { // See NOTE [ View + Inplace detection ] @@ -557,7 +609,8 @@ const std::shared_ptr& VariableHooks::grad_fn(const at::T } auto current_version = self._version(); if (diff_view_meta->get_attr_version() != current_version) { - // This is an indirect rebase_history due to another view or the base being modified inplace + // This is an indirect rebase_history due to another view or the base + // being modified inplace handle_view_on_rebase(diff_view_meta, /* indirect */ true); TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0); // Note [View + Inplace update for view tensor] @@ -569,37 +622,42 @@ const std::shared_ptr& VariableHooks::grad_fn(const at::T // self = view_op_n(view_n-1) // self = inplace_op(self) // - // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of - // view backward ops for effienciency. + // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to + // represent the chain of view backward ops for effienciency. // - // However in XLA backend we don't have full support of AsStridedBackward0, we instead run a full - // forward pass with a tensor that requires gradient to get proper grad_fn setup, - // then save it to DifferentiableViewMeta for future use. - // This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA). - // XLA Tensor only run thorugh VariableType dispatch and lower the forward pass to a XLA HLO graph, - // then we take grad_fn and never materialize the tensor content. - // So we only construct the graph but not execute it, which is a fairly cheap operation to do. + // However in XLA backend we don't have full support of + // AsStridedBackward0, we instead run a full forward pass with a tensor + // that requires gradient to get proper grad_fn setup, then save it to + // DifferentiableViewMeta for future use. This is fairly cheap for XLA + // lazy tensor approach (but would be really expensive for CPU/CUDA). XLA + // Tensor only run thorugh VariableType dispatch and lower the forward + // pass to a XLA HLO graph, then we take grad_fn and never materialize the + // tensor content. So we only construct the graph but not execute it, + // which is a fairly cheap operation to do. // - // See Note [View + Inplace update for base tensor] for what we do to base tensor when - // an in-place operation happens. + // See Note [View + Inplace update for base tensor] for what we do to base + // tensor when an in-place operation happens. // - // TODO: Potentially the following logic can be replaced by special logic in VariableType_x.cpp + // TODO: Potentially the following logic can be replaced by special logic + // in VariableType_x.cpp // that would provide a way to recreate the grad_fn chain. if (view_info.has_view_fn()) { auto view_fn = view_info.view_fn(); auto diff_view = view_fn(view_info.base_); diff_view_meta->grad_fn_ = diff_view.grad_fn(); } else { - auto fn = std::make_shared(); + auto fn = + std::make_shared(); fn->self_geometry = at::TensorGeometry(view_info.base_); fn->size = self.sizes().vec(); fn->stride = self.strides().vec(); fn->storage_offset = self.storage_offset(); - fn->set_next_edges(torch::autograd::collect_next_edges(view_info.base_)); + fn->set_next_edges( + torch::autograd::collect_next_edges(view_info.base_)); fn->add_input_metadata( - view_info.base_.options(), - self.sizes(), // Note: sizes(), not base_.sizes(), is intentional - self.unsafeGetTensorImpl()->is_python_dispatch()); + view_info.base_.options(), + self.sizes(), // Note: sizes(), not base_.sizes(), is intentional + self.unsafeGetTensorImpl()->is_python_dispatch()); diff_view_meta->grad_fn_ = std::move(fn); } diff_view_meta->set_attr_version(current_version); @@ -614,20 +672,26 @@ const std::shared_ptr& VariableHooks::grad_fn(const at::T } } -void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) const { - auto &list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_; - TORCH_CHECK(list && pos < list->size() , "Invalid index, no hook at position ", pos); +void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) + const { + auto& list = + torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_; + TORCH_CHECK( + list && pos < list->size(), "Invalid index, no hook at position ", pos); // Hook will be ignored (*list)[pos] = nullptr; } unsigned VariableHooks::_register_hook( - const at::TensorBase& self, std::function hook) const { - TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that " - "doesn't require gradient"); + const at::TensorBase& self, + std::function hook) const { + TORCH_CHECK( + self.requires_grad(), + "cannot register a hook on a variable that " + "doesn't require gradient"); // NB: materialize_autograd_meta unnecessary due to requires grad check - auto &list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_; - if(!list) { + auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_; + if (!list) { torch::autograd::impl::create_cpp_hook(self); } unsigned idx = list->size(); @@ -635,7 +699,9 @@ unsigned VariableHooks::_register_hook( return idx; } -void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect) { +void handle_view_on_rebase( + DifferentiableViewMeta* diff_view_meta, + bool indirect) { /// See NOTE [ View + Inplace detection ] for justification of the logic below auto creation_meta = diff_view_meta->get_creation_meta(); if (creation_meta != CreationMeta::DEFAULT) { @@ -649,36 +715,56 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect modified_obj = "is being"; } if (grad_fn) { - msg = c10::str("Output ", diff_view_meta->output_nr_, " of ", grad_fn->name(), " is a view and ", - modified_obj, " modified inplace."); + msg = c10::str( + "Output ", + diff_view_meta->output_nr_, + " of ", + grad_fn->name(), + " is a view and ", + modified_obj, + " modified inplace."); } else if (creation_meta == CreationMeta::INFERENCE_MODE) { - msg = c10::str("A view was created in inference mode and ", modified_obj, " modified inplace in normal mode."); + msg = c10::str( + "A view was created in inference mode and ", + modified_obj, + " modified inplace in normal mode."); } else { - msg = c10::str("A view was created in no_grad mode and ", modified_obj, " modified inplace with grad mode enabled."); + msg = c10::str( + "A view was created in no_grad mode and ", + modified_obj, + " modified inplace with grad mode enabled."); } if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { - msg = c10::str(msg, " This view is the output of a function that returns multiple views. Such functions do not" - " allow the output views to be modified inplace. You should replace the inplace operation by an" - " out-of-place one."); + msg = c10::str( + msg, + " This view is the output of a function that returns multiple views. Such functions do not" + " allow the output views to be modified inplace. You should replace the inplace operation by an" + " out-of-place one."); } else if (creation_meta == CreationMeta::NO_GRAD_MODE) { TORCH_INTERNAL_ASSERT(!grad_fn); - msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden." - " You can clarify your code by moving both the view and the inplace either both" - " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" - " the inplace to be tracked)."); + msg = c10::str( + msg, + " Given that this use case is ambiguous and error-prone, it is forbidden." + " You can clarify your code by moving both the view and the inplace either both" + " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" + " the inplace to be tracked)."); } else if (creation_meta == CreationMeta::INFERENCE_MODE) { TORCH_INTERNAL_ASSERT(!grad_fn); - msg = c10::str(msg, " Given that this use case is ambiguous and error-prone, it is forbidden." - " You can clarify your code by moving both the view and the inplace either both" - " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want" - " the inplace to be tracked)."); + msg = c10::str( + msg, + " Given that this use case is ambiguous and error-prone, it is forbidden." + " You can clarify your code by moving both the view and the inplace either both" + " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want" + " the inplace to be tracked)."); TORCH_CHECK(false, msg); } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { - msg = c10::str(msg, " This view was created inside a custom Function (or because an input was returned as-is) and the" - " autograd logic to handle view+inplace would override the custom backward associated with the custom" - " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by" - " cloning the output of the custom Function."); + msg = c10::str( + msg, + " This view was created inside a custom Function (or because an input was returned as-is) and the" + " autograd logic to handle view+inplace would override the custom backward associated with the custom" + " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by" + " cloning the output of the custom Function."); } else { TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state"); } @@ -687,6 +773,5 @@ void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect } } - - -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index e6478551628e05..c20f6b8d80f7e3 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -3,49 +3,53 @@ #include #include -#include -#include #include +#include #include +#include -#include #include +#include #include +#include #include #include #include #include #include #include -#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { -/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable = at::Tensor`). -/// This means you can perform all the usual mathematical and other -/// operations you can perform on `Tensor`s also on `Variable`s. +/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable = +/// at::Tensor`). This means you can perform all the usual mathematical and +/// other operations you can perform on `Tensor`s also on `Variable`s. /// -/// The only reason we are keeping the `Variable` class is backward compatibility -/// with external user's legacy C++ frontend code. Our intention is to eliminate -/// the `Variable` class in the near future. +/// The only reason we are keeping the `Variable` class is backward +/// compatibility with external user's legacy C++ frontend code. Our intention +/// is to eliminate the `Variable` class in the near future. using Variable = at::Tensor; } // namespace autograd } // namespace torch // The following are all internal APIs and should not be shown in libtorch docs. -// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS ... #endif` +// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS +// ... #endif` #ifndef DOXYGEN_SHOULD_SKIP_THIS -namespace torch { namespace autograd { +namespace torch { +namespace autograd { /// Check if this type is supported by the autograd engine. -/// If you change this, update the doc at the top of the torch/autograd/__init__.py file -/// and "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py +/// If you change this, update the doc at the top of the +/// torch/autograd/__init__.py file and +/// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py static inline bool isDifferentiableType(at::ScalarType t) { - return isFloatingType(t) || isComplexType(t); + return isFloatingType(t) || isComplexType(t); } struct Node; @@ -105,87 +109,94 @@ struct DifferentiableViewMeta; // on Tensor proper namespace impl { - // WARNING: This may return a nullptr. If you require AutogradMeta to return - // a materialized structure, use materialize_autograd_meta instead. - TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&); - - // WARNING: This will return a nullptr if the Tensor is not a view. - TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&); - - // Returns the current autograd meta, materializing it if it was previously - // none. This counts as a *mutating* operation, so do not call it on - // "read-only" operators; in particular, this is NOT thread safe - TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&); - - /// Set the gradient accumulator of the `Variable`. This is only applicable to - /// leaf variables. Interior variables should call `set_gradient_edge()`. - TORCH_API void set_grad_accumulator(const Variable&, std::weak_ptr grad_accumulator); - - /// Attempts to get a pointer to the gradient accumulator of the `Variable`, - /// if it still exists. If the gradient accumulator function has been - /// destroyed, returns a `nullptr`. - TORCH_API std::shared_ptr try_get_grad_accumulator(const Variable&); - - /// Gets the gradient accumulator of the `Variable` if it has one, or else - /// create one on the fly and return it. - TORCH_API std::shared_ptr grad_accumulator(const Variable&); - - /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the - /// gradient function if this is an interior `Variable`, or the gradient - /// accumulator otherwise. If the `Variable` is interior, the returned `Edge` - /// will store the input index of the `Node` to which this variable is - /// connected in its `input_nr` field. For leaves, the `input_nr` is always - /// zero. Note that `set_gradient_edge` and `gradient_edge` are not - /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and - /// `set_grad_accumulator` to set the accumulator. - TORCH_API Edge gradient_edge(const Variable&); - - /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the - /// `Variable`. - /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable, - /// and never the `grad_accumulator`. For the latter, use - /// `set_grad_accumulator`. This allows late construction of an interior - /// `Variable`. - TORCH_API void set_gradient_edge(const Variable&, Edge edge); - - // Autograd Graph Interaction - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - /// Update the `grad_fn` of an existing Variable. Called after in-place - /// modifications. - /// - /// For View Variables: - /// Called after in-place modifications. Modifies the grad_fn of the base - /// Variable. - TORCH_API void rebase_history(const Variable&, Edge gradient_edge); - - /// Gets the raw gradient function pointer, whatever it currently is. - TORCH_API Node* grad_fn_unsafe(const Variable&); - - /// Increments the version count of this `Variable`. - TORCH_API void bump_version(const Variable&); - TORCH_API void set_version_counter(const Variable&, const c10::VariableVersion& version_counter); - - /// Retrieves this `Variable`s version counter. - TORCH_API const c10::VariableVersion& version_counter(const Variable&); - - TORCH_API void set_name(const Variable&, const std::string& name); - - TORCH_API void add_hook(const at::TensorBase&, std::shared_ptr hook); - TORCH_API const std::vector>& hooks(const Variable&); - TORCH_API void clear_hooks(const at::TensorBase&); - - TORCH_API void create_cpp_hook(const at::TensorBase&); -} +// WARNING: This may return a nullptr. If you require AutogradMeta to return +// a materialized structure, use materialize_autograd_meta instead. +TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&); + +// WARNING: This will return a nullptr if the Tensor is not a view. +TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&); + +// Returns the current autograd meta, materializing it if it was previously +// none. This counts as a *mutating* operation, so do not call it on +// "read-only" operators; in particular, this is NOT thread safe +TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&); + +/// Set the gradient accumulator of the `Variable`. This is only applicable to +/// leaf variables. Interior variables should call `set_gradient_edge()`. +TORCH_API void set_grad_accumulator( + const Variable&, + std::weak_ptr grad_accumulator); + +/// Attempts to get a pointer to the gradient accumulator of the `Variable`, +/// if it still exists. If the gradient accumulator function has been +/// destroyed, returns a `nullptr`. +TORCH_API std::shared_ptr try_get_grad_accumulator(const Variable&); + +/// Gets the gradient accumulator of the `Variable` if it has one, or else +/// create one on the fly and return it. +TORCH_API std::shared_ptr grad_accumulator(const Variable&); + +/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the +/// gradient function if this is an interior `Variable`, or the gradient +/// accumulator otherwise. If the `Variable` is interior, the returned `Edge` +/// will store the input index of the `Node` to which this variable is +/// connected in its `input_nr` field. For leaves, the `input_nr` is always +/// zero. Note that `set_gradient_edge` and `gradient_edge` are not +/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and +/// `set_grad_accumulator` to set the accumulator. +TORCH_API Edge gradient_edge(const Variable&); + +/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the +/// `Variable`. +/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable, +/// and never the `grad_accumulator`. For the latter, use +/// `set_grad_accumulator`. This allows late construction of an interior +/// `Variable`. +TORCH_API void set_gradient_edge(const Variable&, Edge edge); + +// Autograd Graph Interaction +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Update the `grad_fn` of an existing Variable. Called after in-place +/// modifications. +/// +/// For View Variables: +/// Called after in-place modifications. Modifies the grad_fn of the base +/// Variable. +TORCH_API void rebase_history(const Variable&, Edge gradient_edge); + +/// Gets the raw gradient function pointer, whatever it currently is. +TORCH_API Node* grad_fn_unsafe(const Variable&); + +/// Increments the version count of this `Variable`. +TORCH_API void bump_version(const Variable&); +TORCH_API void set_version_counter( + const Variable&, + const c10::VariableVersion& version_counter); + +/// Retrieves this `Variable`s version counter. +TORCH_API const c10::VariableVersion& version_counter(const Variable&); + +TORCH_API void set_name(const Variable&, const std::string& name); + +TORCH_API void add_hook( + const at::TensorBase&, + std::shared_ptr hook); +TORCH_API const std::vector>& hooks( + const Variable&); +TORCH_API void clear_hooks(const at::TensorBase&); + +TORCH_API void create_cpp_hook(const at::TensorBase&); +} // namespace impl //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // AutogradMeta //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd -/// metadata fields that are necessary for tracking the Variable's autograd history. -/// As an optimization, a Variable may store a nullptr, in lieu of a default -/// constructed AutogradMeta. +/// metadata fields that are necessary for tracking the Variable's autograd +/// history. As an optimization, a Variable may store a nullptr, in lieu of a +/// default constructed AutogradMeta. struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { std::string name_; @@ -232,10 +243,12 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { /// Sets the `requires_grad` property of `Variable`. This should be true for /// leaf variables that want to accumulate gradients, and false for all other /// variables. - void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) override { + void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) + override { TORCH_CHECK( - !requires_grad || isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())), - "Only Tensors of floating point and complex dtype can require gradients"); + !requires_grad || + isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())), + "Only Tensors of floating point and complex dtype can require gradients"); requires_grad_ = requires_grad; } @@ -252,12 +265,20 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { return grad_; } - const Variable& fw_grad(uint64_t level, const at::TensorBase& self) const override; + const Variable& fw_grad(uint64_t level, const at::TensorBase& self) + const override; - void set_fw_grad(const at::TensorBase& new_grad, const at::TensorBase& self, uint64_t level, bool is_inplace_op) override; + void set_fw_grad( + const at::TensorBase& new_grad, + const at::TensorBase& self, + uint64_t level, + bool is_inplace_op) override; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) { + AutogradMeta( + at::TensorImpl* self_impl = nullptr, + bool requires_grad = false, + Edge gradient_edge = Edge()) { grad_fn_ = std::move(gradient_edge.function); requires_grad_ = false; retains_grad_ = false; @@ -276,9 +297,10 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { } ~AutogradMeta() override { - // If AutogradMeta is being destroyed, it means that there is no other reference to its - // corresponding Tensor. It implies that no other thread can be using this object and so there is - // no need to lock mutex_ here to guard the check if fw_grad_ is populated. + // If AutogradMeta is being destroyed, it means that there is no other + // reference to its corresponding Tensor. It implies that no other thread + // can be using this object and so there is no need to lock mutex_ here to + // guard the check if fw_grad_ is populated. if (fw_grad_) { // See note [ Using ForwardGrad ] fw_grad_->clear(); @@ -303,24 +325,26 @@ struct TORCH_API ViewInfo { } std::function view_fn() const { - TORCH_CHECK(has_view_fn(), "Can only access the view function if it exists."); + TORCH_CHECK( + has_view_fn(), "Can only access the view function if it exists."); return view_fn_; } - /// The chain function can be used to build a new ViewInfo for a differentiable view - /// function. It will return a new view info that accurately represents how "tensor" is - /// a view of this instance's "base_". - /// The "base" and "tensor" are respectively the input and output of the differentiable - /// view function that happened. They are required to properly set the optional - /// view_fn_ when it is not provided. - /// The "view_func", if provided, should be a function that allows to re-do the view - /// between "base" and "tensor". - ViewInfo chain(const Variable & base, const Variable & tensor, - std::function view_func=nullptr) const; - - ViewInfo(Variable base, std::function view_fn) : - base_(std::move(base)), - view_fn_(std::move(view_fn)) { + /// The chain function can be used to build a new ViewInfo for a + /// differentiable view function. It will return a new view info that + /// accurately represents how "tensor" is a view of this instance's "base_". + /// The "base" and "tensor" are respectively the input and output of the + /// differentiable view function that happened. They are required to properly + /// set the optional view_fn_ when it is not provided. The "view_func", if + /// provided, should be a function that allows to re-do the view between + /// "base" and "tensor". + ViewInfo chain( + const Variable& base, + const Variable& tensor, + std::function view_func = nullptr) const; + + ViewInfo(Variable base, std::function view_fn) + : base_(std::move(base)), view_fn_(std::move(view_fn)) { TORCH_CHECK(base_.defined(), "base is undefined"); } }; @@ -343,16 +367,17 @@ struct TORCH_API ViewInfo { /// /// Differentiable Views /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/// This class allows to track both forward and backward AD differentiable views. -/// These views can have different base as non-differentiable view for forward -/// and backward mode AD are not the same. +/// This class allows to track both forward and backward AD differentiable +/// views. These views can have different base as non-differentiable view for +/// forward and backward mode AD are not the same. /// /// Most function are either both forward and backward differentiable views (for /// example: view, select, narrow, transpose, etc) or both not forward and not /// backward differentiable views (for example: indices, values, eq, lt, etc). -/// But there are also functions that are forward but not backward differentiable -/// views (only detach for now) or functions that are backward but not forward -/// differentiable view (only make_dual and unpack dual for now). +/// But there are also functions that are forward but not backward +/// differentiable views (only detach for now) or functions that are backward +/// but not forward differentiable view (only make_dual and unpack dual for +/// now). /// /// A concrete example of two views with different bases is as follow: /// @@ -377,7 +402,8 @@ struct TORCH_API ViewInfo { /// # base.requires_grad = False /// # var.requires_grad = True /// base[1] = var # i.e., base[1].copy_(var) -/// torch.autograd.grad(base.sum(), var) <- should return an all ones tensor +/// torch.autograd.grad(base.sum(), var) <- should return an all ones +/// tensor /// /// (2) in-place operation on base after view is created, e.g., /// @@ -403,8 +429,8 @@ struct TORCH_API ViewInfo { /// base[1] = var # i.e., base[1].copy_(var) /// # Now, base is a dual Tensor /// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with -/// fw_grad[1] filled with all ones and -/// zeros everywhere else +/// fw_grad[1] filled with all ones +/// and zeros everywhere else /// /// (2) in-place operation on base after view is created, e.g., /// @@ -413,9 +439,11 @@ struct TORCH_API ViewInfo { /// # var is a dual Tensor whose tangent is all ones /// view = base[1] /// base.copy_(var) -/// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones tensor +/// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones +/// tensor /// -/// See Note [Forward Grad View/inplace] for more details on how we handle these hard cases. +/// See Note [Forward Grad View/inplace] for more details on how we handle these +/// hard cases. /// /// /// DifferentiableViewMeta is created to support gradient tracking of @@ -428,14 +456,14 @@ struct TORCH_API ViewInfo { /// called after every in-place op in VariableType.cpp, the grad_fn of base /// is updated. /// + if a single autograd Node returns multiple differentiable views, if any -/// output is modified by an inplace operation, the autograd engine will make -/// an equivalent graph (corresponding to the view operations) without using -/// equivalent graph, where each output is treated as if it were produced by a -/// distinct view operation. This discards the original (e.g., user provided) -/// grad_fn. If the provided grad_fn does more than the backward of the view, -/// then the DifferentiableViewMeta must be created with creation_meta= -/// CreationMeta::MULTI_OUTPUT_NODE to prevent the engine from ignoring the -/// provided grad_fn. +/// output is modified by an inplace operation, the autograd engine will +/// make an equivalent graph (corresponding to the view operations) without +/// using equivalent graph, where each output is treated as if it were +/// produced by a distinct view operation. This discards the original (e.g., +/// user provided) grad_fn. If the provided grad_fn does more than the +/// backward of the view, then the DifferentiableViewMeta must be created +/// with creation_meta= CreationMeta::MULTI_OUTPUT_NODE to prevent the +/// engine from ignoring the provided grad_fn. /// /// Interaction with GradMode: /// The particular case that we consider here is: @@ -448,15 +476,15 @@ struct TORCH_API ViewInfo { /// view.copy_(var) /// torch.autograd.grad(base.sum(), var) <- what should it return? /// -/// Given that this particular code example is ambiguous and can easily be replace by -/// either moving both inside the no_grad block or both outside, we explicitly forbid -/// it. For now, it is deprecated by a warning. This is achieved by setting -/// creation_meta=CreationMeta::NO_GRAD_MODE for all differentiable views created -/// in no_grad mode. +/// Given that this particular code example is ambiguous and can easily be +/// replace by either moving both inside the no_grad block or both outside, we +/// explicitly forbid it. For now, it is deprecated by a warning. This is +/// achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE for all +/// differentiable views created in no_grad mode. /// /// See Note [View + Inplace update for base tensor] -/// and Note [View + Inplace update for view tensor] for the details how autograd -/// handles inplace update with view ops. +/// and Note [View + Inplace update for view tensor] for the details how +/// autograd handles inplace update with view ops. /// /// Non-Differentiable Views /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -475,61 +503,75 @@ struct TORCH_API ViewInfo { /// These are called non-differentiable views as the gradients do not flow /// through the view relation. /// -/// Relevant logic for both differentiable and non-differentiable views is implemented in -/// make_variable_(non_)differentiable_view below, and wrap_output of gen_variable_type.py. - +/// Relevant logic for both differentiable and non-differentiable views is +/// implemented in make_variable_(non_)differentiable_view below, and +/// wrap_output of gen_variable_type.py. /// NOTE [ View + Inplace detection ] /// -/// We want to detect views followed by inplace as they are often forbidden to ensure -/// correctness of the computed gradients. But since we want to only notify the user -/// when both happen, we tag the DifferentiableViewMeta when the view is created -/// via the `make_variable_*_view()` functions. This tag is then checked by the -/// `check_inplace()` function from `VariableTypeUtils.h` that should be called before -/// every inplace operation and to detect cases where other views are modified and this -/// one is rebased by side effect, we also check in the `VariableHooks::grad_fn()`. +/// We want to detect views followed by inplace as they are often forbidden to +/// ensure correctness of the computed gradients. But since we want to only +/// notify the user when both happen, we tag the DifferentiableViewMeta when the +/// view is created via the `make_variable_*_view()` functions. This tag is then +/// checked by the `check_inplace()` function from `VariableTypeUtils.h` that +/// should be called before every inplace operation and to detect cases where +/// other views are modified and this one is rebased by side effect, we also +/// check in the `VariableHooks::grad_fn()`. /// Flag that gives more information about when this view was created: /// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom /// autograd Function is returned. -/// - NO_GRAD_MODE should be set when a view in created when GradMode is disabled -/// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code returns +/// - NO_GRAD_MODE should be set when a view in created when GradMode is +/// disabled +/// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code +/// returns /// multiple differentiable views -/// - Inference_MODE should be set when a view of normal tensor is created in InferenceMode. +/// - Inference_MODE should be set when a view of normal tensor is created in +/// InferenceMode. /// - DEFAULT is for all other cases -enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NODE, - NO_GRAD_MODE, INFERENCE_MODE}; - -/// Handles correctly propagating CreationMeta when a new view is created from a previous view. -/// In general, we don't want the new view to be _less_ restrictive than the previous view -/// (it's okay to be _more_ restrictive). -/// A CreationMeta value of DEFAULT is currently the least restrictive, as the behavior for -/// all other CreationMeta values is to error out for in-place ops. -/// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so it takes -/// precedence in propagation. -/// If this changes, the logic here will need to be updated to properly handle the new semantics. -inline CreationMeta propagate_creation_meta(CreationMeta prev_view_creation_meta, CreationMeta new_view_creation_meta) { - return (new_view_creation_meta == CreationMeta::DEFAULT) ? - prev_view_creation_meta : - (prev_view_creation_meta == CreationMeta::INFERENCE_MODE ? prev_view_creation_meta : new_view_creation_meta); +enum class CreationMeta : uint8_t { + DEFAULT, + IN_CUSTOM_FUNCTION, + MULTI_OUTPUT_NODE, + NO_GRAD_MODE, + INFERENCE_MODE +}; + +/// Handles correctly propagating CreationMeta when a new view is created from a +/// previous view. In general, we don't want the new view to be _less_ +/// restrictive than the previous view (it's okay to be _more_ restrictive). A +/// CreationMeta value of DEFAULT is currently the least restrictive, as the +/// behavior for all other CreationMeta values is to error out for in-place ops. +/// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so +/// it takes precedence in propagation. If this changes, the logic here will +/// need to be updated to properly handle the new semantics. +inline CreationMeta propagate_creation_meta( + CreationMeta prev_view_creation_meta, + CreationMeta new_view_creation_meta) { + return (new_view_creation_meta == CreationMeta::DEFAULT) + ? prev_view_creation_meta + : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE + ? prev_view_creation_meta + : new_view_creation_meta); } /// Unified function to handle error checking when rebase happens -/// indirect=true means that the caller is not doing the inplace, but the inplace happened -/// somewhere else. -TORCH_API void handle_view_on_rebase(DifferentiableViewMeta* diff_view_meta, bool indirect=false); +/// indirect=true means that the caller is not doing the inplace, but the +/// inplace happened somewhere else. +TORCH_API void handle_view_on_rebase( + DifferentiableViewMeta* diff_view_meta, + bool indirect = false); struct TORCH_API DifferentiableViewMeta : public AutogradMeta { -private: + private: /// Informations about the views c10::optional backward_info_; c10::optional forward_info_; // Optimization to reduce the number of ViewInfo we create. // In the (very common) case where backward_info_ == forward_info_, we only - // populate backward_info_ (that should be used as both the forward and backward - // view information) and set shared_view_info_ = true. - // Invariants: + // populate backward_info_ (that should be used as both the forward and + // backward view information) and set shared_view_info_ = true. Invariants: // - If shared_view_info_ is false, there is no special constraints on // backward_info_ and forward_info_ // - If shared_view_info_ is true, we must have: @@ -537,19 +579,21 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { // - forward_info_.has_value() == false bool shared_view_info_; - /// The two following fields are extra information that we track to ensure that - /// any operation on this backward view is valid. + /// The two following fields are extra information that we track to ensure + /// that any operation on this backward view is valid. /// The value of the version_counter at the time grad_fn was created. The - /// grad_fn field is stale if attr_version_ != version_counter.current_version(). + /// grad_fn field is stale if attr_version_ != + /// version_counter.current_version(). uint32_t attr_version_; CreationMeta creation_meta_; -public: - /// requires_grad is a backward AD field so we only use the view specific logic - /// for backward differentiable views + public: + /// requires_grad is a backward AD field so we only use the view specific + /// logic for backward differentiable views bool requires_grad() const override { - return requires_grad_ || grad_fn_ || (has_bw_view() && get_backward_view().base_.requires_grad()); + return requires_grad_ || grad_fn_ || + (has_bw_view() && get_backward_view().base_.requires_grad()); } bool shared_view_info() const { @@ -561,27 +605,32 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { } const ViewInfo& get_backward_view() const { - TORCH_CHECK(has_bw_view(), "backward view info can only exist for backward views."); + TORCH_CHECK( + has_bw_view(), "backward view info can only exist for backward views."); return backward_info_.value(); } uint32_t get_attr_version() const { - TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views."); + TORCH_CHECK( + has_bw_view(), "attr_version can only exist for backward views."); return attr_version_; } void set_attr_version(uint32_t new_attr_version) { - TORCH_CHECK(has_bw_view(), "attr_version can only exist for backward views."); + TORCH_CHECK( + has_bw_view(), "attr_version can only exist for backward views."); attr_version_ = new_attr_version; } CreationMeta get_creation_meta() const { - TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views."); + TORCH_CHECK( + has_bw_view(), "creation_meta can only exist for backward views."); return creation_meta_; } void set_creation_meta(CreationMeta new_creation_meta) { - TORCH_CHECK(has_bw_view(), "creation_meta can only exist for backward views."); + TORCH_CHECK( + has_bw_view(), "creation_meta can only exist for backward views."); creation_meta_ = new_creation_meta; } @@ -590,13 +639,20 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { } const ViewInfo& get_forward_view() const { - TORCH_CHECK(has_fw_view(), "forward view info can only exist for forward views."); - TORCH_CHECK(!shared_view_info_ || has_bw_view(), "forward view info can only exist for forward views."); + TORCH_CHECK( + has_fw_view(), "forward view info can only exist for forward views."); + TORCH_CHECK( + !shared_view_info_ || has_bw_view(), + "forward view info can only exist for forward views."); return shared_view_info_ ? backward_info_.value() : forward_info_.value(); } - DifferentiableViewMeta(at::TensorImpl* self_impl, c10::optional backward_info, - c10::optional forward_info, bool shared_view_info, CreationMeta creation_meta=CreationMeta::DEFAULT); + DifferentiableViewMeta( + at::TensorImpl* self_impl, + c10::optional backward_info, + c10::optional forward_info, + bool shared_view_info, + CreationMeta creation_meta = CreationMeta::DEFAULT); }; //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -612,10 +668,10 @@ struct TORCH_API DifferentiableViewMeta : public AutogradMeta { /// differentiable, i.e., whether the relation should be tracked by autograd. /// See NOTE [ Autograd View Variables ] for details. -/// NOTE: `allow_tensor_metadata_change` is set to true by default, because there -/// are a lot of call sites to these factory functions that need to change the -/// variable's size or storage afterwards, and they don't expect the original -/// tensor (where the variable is created from) to be updated. Setting +/// NOTE: `allow_tensor_metadata_change` is set to true by default, because +/// there are a lot of call sites to these factory functions that need to change +/// the variable's size or storage afterwards, and they don't expect the +/// original tensor (where the variable is created from) to be updated. Setting /// `allow_tensor_metadata_change_` to false by default would unnecessarily /// prevent those changes from happening and is undesirable. @@ -630,25 +686,34 @@ inline Variable make_variable_differentiable_view( bool allow_tensor_metadata_change = true) { if (data.defined()) { // If we already did a TensorImpl allocation for data, just reuse it. - // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as input), - // we have to use shallow_copy_and_detach to create a new TensorImpl to avoid - // moving leaf node into graph interior. This guarantees only 1 TensorImpl - // allocation happens in view ops. - if (data.getIntrusivePtr().unique() && data.getIntrusivePtr()->unique_version()) { + // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as + // input), we have to use shallow_copy_and_detach to create a new TensorImpl + // to avoid moving leaf node into graph interior. This guarantees only 1 + // TensorImpl allocation happens in view ops. + if (data.getIntrusivePtr().unique() && + data.getIntrusivePtr()->unique_version()) { at::TensorImpl* data_impl = data.unsafeGetTensorImpl(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); data_impl->set_autograd_meta(std::make_unique( - data_impl, std::move(backward_info), std::move(forward_info), - shared_view_info, creation_meta)); + data_impl, + std::move(backward_info), + std::move(forward_info), + shared_view_info, + creation_meta)); return data; } else { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - c10::intrusive_ptr data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); - data_impl_copy->set_autograd_meta(std::make_unique( - data_impl_copy.get(), std::move(backward_info), std::move(forward_info), - shared_view_info, creation_meta)); + c10::intrusive_ptr data_impl_copy = + data.getIntrusivePtr()->shallow_copy_and_detach( + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + data_impl_copy->set_autograd_meta( + std::make_unique( + data_impl_copy.get(), + std::move(backward_info), + std::move(forward_info), + shared_view_info, + creation_meta)); return Variable(data_impl_copy); } } @@ -666,19 +731,18 @@ inline Variable make_variable_non_differentiable_view( // share the same TensorImpl as their base Tensor. Thus a new TensorImpl // allocation here is required. auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/impl::version_counter(base), - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + /*version_counter=*/impl::version_counter(base), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); data_impl_copy->set_autograd_meta(nullptr); return Variable(data_impl_copy); } return Variable(); } -/// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`. -/// `requires_grad` should be -/// set only for leaves, and determines whether the `Variable` will accumulate -/// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic -/// type *must* be `Tensor`. +/// Creates a `Variable` from the given `Tensor`, copying its underlying +/// `TensorImpl`. `requires_grad` should be set only for leaves, and determines +/// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be +/// a `Variable` already. Its dynamic type *must* be `Tensor`. /// /// TODO: Eliminate this function as much as possible, as it can be expressed /// more clearly as detach() or a no-op in most call sites (especially when @@ -688,24 +752,26 @@ inline Variable make_variable( bool requires_grad = false, bool allow_tensor_metadata_change = true) { if (data.defined()) { - if (data.getIntrusivePtr().use_count() == 1 && data.getIntrusivePtr()->unique_version()) { + if (data.getIntrusivePtr().use_count() == 1 && + data.getIntrusivePtr()->unique_version()) { auto data_impl = data.unsafeReleaseIntrusivePtr(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); // NOLINTNEXTLINE(bugprone-branch-clone) if (requires_grad) { - data_impl->set_autograd_meta(std::make_unique(data_impl.get(), requires_grad)); + data_impl->set_autograd_meta( + std::make_unique(data_impl.get(), requires_grad)); } else { data_impl->set_autograd_meta(nullptr); } return Variable(std::move(data_impl)); } else { auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); // NOLINTNEXTLINE(bugprone-branch-clone) if (requires_grad) { data_impl_copy->set_autograd_meta(std::make_unique( - data_impl_copy.get(), requires_grad)); + data_impl_copy.get(), requires_grad)); } else { data_impl_copy->set_autograd_meta(nullptr); } @@ -715,26 +781,26 @@ inline Variable make_variable( return Variable(); } -/// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`. -/// `gradient_edge` should be a (function, input_nr) pair specifying the function -/// in the autograd graph, and what particular input of that function, this -/// variable is connected to. +/// Creates a `Variable` from the given `Tensor`, copying its underlying +/// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair +/// specifying the function in the autograd graph, and what particular input of +/// that function, this variable is connected to. inline Variable make_variable( at::Tensor data, Edge gradient_edge, bool allow_tensor_metadata_change = true) { if (data.defined()) { auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( - /*version_counter=*/0, - /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + /*version_counter=*/0, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); data_impl_copy->set_autograd_meta(std::make_unique( - data_impl_copy.get(), false, std::move(gradient_edge))); + data_impl_copy.get(), false, std::move(gradient_edge))); return Variable(data_impl_copy); } return Variable(); } - -}} // namespace torch::autograd +} // namespace autograd +} // namespace torch #endif /* DOXYGEN_SHOULD_SKIP_THIS */ diff --git a/torch/csrc/copy_utils.h b/torch/csrc/copy_utils.h index 6184b063999add..200b46a23f6395 100644 --- a/torch/csrc/copy_utils.h +++ b/torch/csrc/copy_utils.h @@ -1,23 +1,28 @@ #pragma once +#include #include #include -#include typedef std::function THPCopyFunction; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct THPCopyInfo { - PyTypeObject* srcType; // Python type of src tensor/storage - THPCopyFunction copy; // copy function - bool non_blocking; // true if copy implements an 'non_blocking' copy - bool broadcast; // true if the copy implements a broadcast copy + PyTypeObject* srcType; // Python type of src tensor/storage + THPCopyFunction copy; // copy function + bool non_blocking; // true if copy implements an 'non_blocking' copy + bool broadcast; // true if the copy implements a broadcast copy }; typedef std::vector THPCopyList; -inline bool tryTHPCopy(const THPCopyList& v, PyObject* dst, PyObject* src, bool non_blocking, bool broadcast) -{ +inline bool tryTHPCopy( + const THPCopyList& v, + PyObject* dst, + PyObject* src, + bool non_blocking, + bool broadcast) { for (auto& i : v) { - if (i.non_blocking == non_blocking && PyType_IsSubtype(Py_TYPE(src), i.srcType)) { + if (i.non_blocking == non_blocking && + PyType_IsSubtype(Py_TYPE(src), i.srcType)) { (i.copy)(dst, src, broadcast); return true; } @@ -25,15 +30,21 @@ inline bool tryTHPCopy(const THPCopyList& v, PyObject* dst, PyObject* src, bool return false; } -inline bool THPCopy(const THPCopyList& v, PyObject* dst, PyObject* src, bool non_blocking, bool broadcast) -{ +inline bool THPCopy( + const THPCopyList& v, + PyObject* dst, + PyObject* src, + bool non_blocking, + bool broadcast) { // NOLINTNEXTLINE(bugprone-branch-clone) if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) { return true; } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) { return true; } - THPUtils_setError("copy from %s to %s isn't implemented", - THPUtils_typename(src), THPUtils_typename(dst)); + THPUtils_setError( + "copy from %s to %s isn't implemented", + THPUtils_typename(src), + THPUtils_typename(dst)); return false; } diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index 4312b3aaf7b0c0..da81879bccebcb 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -1,31 +1,39 @@ #include +#include +#include #include #include #include -#include -#include #include #include #include -#include #include +#include -PyObject *THCPEventClass = nullptr; +PyObject* THCPEventClass = nullptr; -static PyObject * THCPEvent_pynew( - PyTypeObject *type, PyObject *args, PyObject *kwargs) { +static PyObject* THCPEvent_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS unsigned char enable_timing = 0; unsigned char blocking = 0; unsigned char interprocess = 0; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - static char *kwlist[] = - {"enable_timing", "blocking", "interprocess", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|bbb", kwlist, - &enable_timing, &blocking, &interprocess)) { + static char* kwlist[] = { + "enable_timing", "blocking", "interprocess", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "|bbb", + kwlist, + &enable_timing, + &blocking, + &interprocess)) { return nullptr; } @@ -35,26 +43,27 @@ static PyObject * THCPEvent_pynew( } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - THCPEvent* self = (THCPEvent *)ptr.get(); + THCPEvent* self = (THCPEvent*)ptr.get(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - unsigned int flags = - (blocking ? cudaEventBlockingSync : cudaEventDefault) | - (enable_timing ? cudaEventDefault : cudaEventDisableTiming) | - (interprocess ? cudaEventInterprocess : cudaEventDefault); + unsigned int flags = (blocking ? cudaEventBlockingSync : cudaEventDefault) | + (enable_timing ? cudaEventDefault : cudaEventDisableTiming) | + (interprocess ? cudaEventInterprocess : cudaEventDefault); new (&self->cuda_event) at::cuda::CUDAEvent(flags); - return (PyObject *)ptr.release(); + return (PyObject*)ptr.release(); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_from_ipc_handle( - PyObject *_type, PyObject *args, PyObject *kwargs) { +static PyObject* THCPEvent_from_ipc_handle( + PyObject* _type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS auto type = (PyTypeObject*)_type; static torch::PythonArgParser parser({ - "from_ipc_handle(Device device, std::string ipc_handle)", + "from_ipc_handle(Device device, std::string ipc_handle)", }); torch::ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -62,40 +71,46 @@ static PyObject * THCPEvent_from_ipc_handle( at::Device device = r.device(0); std::string handle_string = r.string(1); - TORCH_CHECK(handle_string.size() == sizeof(cudaIpcEventHandle_t), - "cudaIpcEventHandle_t expects byte-like object of size ", - sizeof(cudaIpcEventHandle_t), ", but got ", handle_string.size()); - TORCH_CHECK(device.type() == at::kCUDA, "Event can only be created on " - "CUDA devices, but got device type ", device.type()) + TORCH_CHECK( + handle_string.size() == sizeof(cudaIpcEventHandle_t), + "cudaIpcEventHandle_t expects byte-like object of size ", + sizeof(cudaIpcEventHandle_t), + ", but got ", + handle_string.size()); + TORCH_CHECK( + device.type() == at::kCUDA, + "Event can only be created on " + "CUDA devices, but got device type ", + device.type()) THPObjectPtr ptr(type->tp_alloc(type, 0)); if (!ptr) { return nullptr; } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - THCPEvent* self = (THCPEvent *)ptr.get(); + THCPEvent* self = (THCPEvent*)ptr.get(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cudaIpcEventHandle_t handle; std::memcpy(&handle, handle_string.c_str(), handle_string.size()); new (&self->cuda_event) at::cuda::CUDAEvent(device.index(), &handle); - return (PyObject *)ptr.release(); + return (PyObject*)ptr.release(); END_HANDLE_TH_ERRORS } -static void THCPEvent_dealloc(THCPEvent *self) { +static void THCPEvent_dealloc(THCPEvent* self) { self->cuda_event.~CUDAEvent(); Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THCPEvent_get_cuda_event(THCPEvent *self, void *unused) { +static PyObject* THCPEvent_get_cuda_event(THCPEvent* self, void* unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_event.event()); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_get_device(THCPEvent *self, void *unused) { +static PyObject* THCPEvent_get_device(THCPEvent* self, void* unused) { HANDLE_TH_ERRORS at::optional device = self->cuda_event.device(); if (!device) { @@ -105,7 +120,7 @@ static PyObject * THCPEvent_get_device(THCPEvent *self, void *unused) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_record(PyObject *_self, PyObject *_stream) { +static PyObject* THCPEvent_record(PyObject* _self, PyObject* _stream) { HANDLE_TH_ERRORS auto self = (THCPEvent*)_self; auto stream = (THCPStream*)_stream; @@ -114,9 +129,8 @@ static PyObject * THCPEvent_record(PyObject *_self, PyObject *_stream) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_wait(PyObject *_self, PyObject *_stream) { - HANDLE_TH_ERRORS - { +static PyObject* THCPEvent_wait(PyObject* _self, PyObject* _stream) { + HANDLE_TH_ERRORS { auto self = (THCPEvent*)_self; auto stream = (THCPStream*)_stream; pybind11::gil_scoped_release no_gil{}; @@ -126,14 +140,14 @@ static PyObject * THCPEvent_wait(PyObject *_self, PyObject *_stream) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_query(PyObject *_self, PyObject *noargs) { +static PyObject* THCPEvent_query(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THCPEvent*)_self; return PyBool_FromLong(self->cuda_event.query()); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_elapsed_time(PyObject *_self, PyObject *_other) { +static PyObject* THCPEvent_elapsed_time(PyObject* _self, PyObject* _other) { HANDLE_TH_ERRORS auto self = (THCPEvent*)_self; auto other = (THCPEvent*)_other; @@ -141,9 +155,8 @@ static PyObject * THCPEvent_elapsed_time(PyObject *_self, PyObject *_other) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_synchronize(PyObject *_self, PyObject *noargs) { - HANDLE_TH_ERRORS - { +static PyObject* THCPEvent_synchronize(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS { auto self = (THCPEvent*)_self; pybind11::gil_scoped_release no_gil{}; self->cuda_event.synchronize(); @@ -152,88 +165,86 @@ static PyObject * THCPEvent_synchronize(PyObject *_self, PyObject *noargs) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_ipc_handle(PyObject *_self, PyObject *noargs) { +static PyObject* THCPEvent_ipc_handle(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THCPEvent*)_self; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cudaIpcEventHandle_t handle; self->cuda_event.ipc_handle(&handle); - return PyBytes_FromStringAndSize((const char *)&handle, sizeof(handle)); + return PyBytes_FromStringAndSize((const char*)&handle, sizeof(handle)); END_HANDLE_TH_ERRORS } -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) static struct PyGetSetDef THCPEvent_properties[] = { - {"device", (getter)THCPEvent_get_device, nullptr, nullptr, nullptr}, - {"cuda_event", (getter)THCPEvent_get_cuda_event, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"device", (getter)THCPEvent_get_device, nullptr, nullptr, nullptr}, + {"cuda_event", (getter)THCPEvent_get_cuda_event, nullptr, nullptr, nullptr}, + {nullptr}}; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, modernize-avoid-c-arrays) static PyMethodDef THCPEvent_methods[] = { - {(char*)"from_ipc_handle", - castPyCFunctionWithKeywords(THCPEvent_from_ipc_handle), - METH_CLASS | METH_VARARGS | METH_KEYWORDS, nullptr}, - {(char*)"record", THCPEvent_record, METH_O, nullptr}, - {(char*)"wait", THCPEvent_wait, METH_O, nullptr}, - {(char*)"query", THCPEvent_query, METH_NOARGS, nullptr}, - {(char*)"elapsed_time", THCPEvent_elapsed_time, METH_O, nullptr}, - {(char*)"synchronize", THCPEvent_synchronize, - METH_NOARGS, nullptr}, - {(char*)"ipc_handle", THCPEvent_ipc_handle, - METH_NOARGS, nullptr}, - {nullptr} -}; + {(char*)"from_ipc_handle", + castPyCFunctionWithKeywords(THCPEvent_from_ipc_handle), + METH_CLASS | METH_VARARGS | METH_KEYWORDS, + nullptr}, + {(char*)"record", THCPEvent_record, METH_O, nullptr}, + {(char*)"wait", THCPEvent_wait, METH_O, nullptr}, + {(char*)"query", THCPEvent_query, METH_NOARGS, nullptr}, + {(char*)"elapsed_time", THCPEvent_elapsed_time, METH_O, nullptr}, + {(char*)"synchronize", THCPEvent_synchronize, METH_NOARGS, nullptr}, + {(char*)"ipc_handle", THCPEvent_ipc_handle, METH_NOARGS, nullptr}, + {nullptr}}; PyTypeObject THCPEventType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._CudaEventBase", /* tp_name */ - sizeof(THCPEvent), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THCPEvent_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THCPEvent_methods, /* tp_methods */ - nullptr, /* tp_members */ - THCPEvent_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THCPEvent_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._CudaEventBase", /* tp_name */ + sizeof(THCPEvent), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THCPEvent_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THCPEvent_methods, /* tp_methods */ + nullptr, /* tp_members */ + THCPEvent_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THCPEvent_pynew, /* tp_new */ }; -void THCPEvent_init(PyObject *module) { +void THCPEvent_init(PyObject* module) { THCPEventClass = (PyObject*)&THCPEventType; if (PyType_Ready(&THCPEventType) < 0) { throw python_error(); } Py_INCREF(&THCPEventType); - if (PyModule_AddObject( - module, "_CudaEventBase", (PyObject *)&THCPEventType) < 0) { + if (PyModule_AddObject(module, "_CudaEventBase", (PyObject*)&THCPEventType) < + 0) { throw python_error(); } } diff --git a/torch/csrc/cuda/Event.h b/torch/csrc/cuda/Event.h index 2d6037e68973f1..1b654415bb7f93 100644 --- a/torch/csrc/cuda/Event.h +++ b/torch/csrc/cuda/Event.h @@ -6,12 +6,11 @@ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct THCPEvent { - PyObject_HEAD - at::cuda::CUDAEvent cuda_event; + PyObject_HEAD at::cuda::CUDAEvent cuda_event; }; -extern PyObject *THCPEventClass; +extern PyObject* THCPEventClass; -void THCPEvent_init(PyObject *module); +void THCPEvent_init(PyObject* module); inline bool THCPEvent_Check(PyObject* obj) { return THCPEventClass && PyObject_IsInstance(obj, THCPEventClass); diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index beacefa3f88786..0866b82f659dd2 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -11,41 +11,42 @@ // and partially from csrc/cuda/Stream.cpp. // THCPStream_init is also declared at global scope. -// Because THCPGraph_init is forward declared in the only consumer (csrc/Module.cpp) -// I don't think we need a Graph.h. +// Because THCPGraph_init is forward declared in the only consumer +// (csrc/Module.cpp) I don't think we need a Graph.h. template using shared_ptr_class_ = py::class_>; -void THCPGraph_init(PyObject *module) { +void THCPGraph_init(PyObject* module) { // Pybind11 patch notes say "py::module_" is more up-to-date syntax, // but CI linter and some builds prefer "module". auto torch_C_m = py::handle(module).cast(); - torch_C_m - .def("_graph_pool_handle", - &::at::cuda::graph_pool_handle); + torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle); - shared_ptr_class_<::at::cuda::CUDAGraph> - (torch_C_m, - "_CUDAGraph") + shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph") .def(py::init<>()) - // I'm not sure this is the correct order of all the arguments. Pybind11 docs - // aren't clear. But it works. - .def("capture_begin", - &::at::cuda::CUDAGraph::capture_begin, - py::call_guard(), - py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) - .def("capture_end", - &::at::cuda::CUDAGraph::capture_end, - py::call_guard()) - .def("replay", - &::at::cuda::CUDAGraph::replay, - py::call_guard()) - .def("reset", - &::at::cuda::CUDAGraph::reset, - py::call_guard()) - .def("pool", - &::at::cuda::CUDAGraph::pool, - py::call_guard()); + // I'm not sure this is the correct order of all the arguments. Pybind11 + // docs aren't clear. But it works. + .def( + "capture_begin", + &::at::cuda::CUDAGraph::capture_begin, + py::call_guard(), + py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) + .def( + "capture_end", + &::at::cuda::CUDAGraph::capture_end, + py::call_guard()) + .def( + "replay", + &::at::cuda::CUDAGraph::replay, + py::call_guard()) + .def( + "reset", + &::at::cuda::CUDAGraph::reset, + py::call_guard()) + .def( + "pool", + &::at::cuda::CUDAGraph::pool, + py::call_guard()); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 8bf1a41e01fb37..86b7418dfa86b8 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -1,33 +1,33 @@ #include #include #include -#include -#include #include -#include #include #include #include +#include +#include +#include #ifdef USE_NCCL #include #endif #include -#include #include -#include +#include +#include +#include +#include #include +#include #include #include -#include -#include -#include #include -#include -#include #include #include +#include +#include #ifndef WIN32 #include @@ -35,7 +35,7 @@ using namespace torch; -static bool in_bad_fork = false; // True for children forked after cuda init +static bool in_bad_fork = false; // True for children forked after cuda init #ifndef WIN32 // Called in the forked child if cuda has already been initialized @@ -51,7 +51,7 @@ static void forked_child() { static void poison_fork() { #ifndef WIN32 static std::once_flag flag; - std::call_once(flag, []{ pthread_atfork(nullptr, nullptr, forked_child); }); + std::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); #endif } @@ -59,13 +59,11 @@ static void poison_fork() { // CUDA management methods //////////////////////////////////////////////////////////////////////////////// -void THCPModule_setDevice(int device) -{ +void THCPModule_setDevice(int device) { c10::cuda::set_device(static_cast(device)); } -PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg) -{ +PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to setDevice"); int64_t device = THPUtils_unpackLong(arg); @@ -77,8 +75,7 @@ PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs) -{ +PyObject* THCPModule_getDevice_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS torch::utils::cuda_lazy_init(); // NOLINTNEXTLINE(bugprone-signed-char-misuse) @@ -87,12 +84,11 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_canDeviceAccessPeer_wrap(PyObject *self, PyObject *args) -{ +PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS PyObject* arg1 = nullptr; PyObject* arg2 = nullptr; - if(!PyArg_ParseTuple(args, "OO", &arg1, &arg2)) { + if (!PyArg_ParseTuple(args, "OO", &arg1, &arg2)) { THPUtils_invalidArguments( args, nullptr, @@ -101,8 +97,10 @@ PyObject * THCPModule_canDeviceAccessPeer_wrap(PyObject *self, PyObject *args) "(int device, int peer_device);"); return nullptr; } - THPUtils_assert(THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer"); - THPUtils_assert(THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer"); + THPUtils_assert( + THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer"); + THPUtils_assert( + THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer"); int64_t device = THPUtils_unpackLong(arg1); int64_t peer_device = THPUtils_unpackLong(arg2); @@ -112,16 +110,14 @@ PyObject * THCPModule_canDeviceAccessPeer_wrap(PyObject *self, PyObject *args) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs) -{ +PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS poison_fork(); return THPUtils_packUInt64(at::cuda::device_count()); END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getArchFlags(PyObject *self, PyObject *noargs) -{ +PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS poison_fork(); #ifdef CUDA_ARCH_FLAGS @@ -133,48 +129,48 @@ PyObject * THCPModule_getArchFlags(PyObject *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THCPModule_isInBadFork(PyObject *self, PyObject *noargs) { +static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS return PyBool_FromLong(in_bad_fork); END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getCurrentStream_wrap( - PyObject * /* unused */, PyObject *device_index) { +PyObject* THCPModule_getCurrentStream_wrap( + PyObject* /* unused */, + PyObject* device_index) { HANDLE_TH_ERRORS THPUtils_assert( - THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); + THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); int64_t device = THPUtils_unpackLong(device_index); return PyLong_FromUnsignedLongLong( - at::cuda::getCurrentCUDAStream(device).pack()); + at::cuda::getCurrentCUDAStream(device).pack()); END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getCurrentStream_raw( - PyObject * /* unused */, PyObject *device_index) { +PyObject* THCPModule_getCurrentStream_raw( + PyObject* /* unused */, + PyObject* device_index) { HANDLE_TH_ERRORS THPUtils_assert( - THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); + THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); int64_t device = THPUtils_unpackLong(device_index); - return PyLong_FromVoidPtr( - at::cuda::getCurrentCUDAStream(device).stream()); + return PyLong_FromVoidPtr(at::cuda::getCurrentCUDAStream(device).stream()); END_HANDLE_TH_ERRORS } - -PyObject * THCPModule_getDefaultStream_wrap( - PyObject * /* unused */, PyObject *device_index) { +PyObject* THCPModule_getDefaultStream_wrap( + PyObject* /* unused */, + PyObject* device_index) { HANDLE_TH_ERRORS THPUtils_assert( - THPUtils_checkLong(device_index), "invalid argument to getDefaultStream"); + THPUtils_checkLong(device_index), "invalid argument to getDefaultStream"); int64_t device = THPUtils_unpackLong(device_index); return PyLong_FromUnsignedLongLong( - at::cuda::getDefaultCUDAStream(device).pack()); + at::cuda::getDefaultCUDAStream(device).pack()); END_HANDLE_TH_ERRORS } -PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj) -{ +PyObject* THCPModule_setStream_wrap(PyObject* self, PyObject* obj) { HANDLE_TH_ERRORS THPUtils_assert(PyLong_Check(obj), "invalid stream"); uint64_t bits = PyLong_AsUnsignedLongLong(obj); @@ -192,28 +188,28 @@ PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getCompiledVersion(PyObject *self, PyObject *noargs) -{ +PyObject* THCPModule_getCompiledVersion(PyObject* self, PyObject* noargs) { #if defined(USE_ROCM) - return THPUtils_packInt64((int64_t) ROCM_VERSION); + return THPUtils_packInt64((int64_t)ROCM_VERSION); #else - return THPUtils_packInt64((int64_t) CUDA_VERSION); + return THPUtils_packInt64((int64_t)CUDA_VERSION); #endif } -PyObject * THCPModule_cudaHostAllocator(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_cudaHostAllocator(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS c10::Allocator* allocator = at::cuda::getCachingHostAllocator(); return PyLong_FromVoidPtr(allocator); END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaCachingAllocator_raw_alloc(PyObject *_unused, PyObject *args){ +PyObject* THCPModule_cudaCachingAllocator_raw_alloc( + PyObject* _unused, + PyObject* args) { HANDLE_TH_ERRORS PyObject* size_o = nullptr; PyObject* stream_o = nullptr; - if(!PyArg_ParseTuple(args, "OO", &size_o, &stream_o)) { + if (!PyArg_ParseTuple(args, "OO", &size_o, &stream_o)) { THPUtils_invalidArguments( args, nullptr, @@ -226,15 +222,16 @@ PyObject * THCPModule_cudaCachingAllocator_raw_alloc(PyObject *_unused, PyObject // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cudaStream_t stream = static_cast(PyLong_AsVoidPtr(stream_o)); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* mem = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, stream); + void* mem = + c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, stream); return PyLong_FromVoidPtr(mem); END_HANDLE_TH_ERRORS } // Unpack a PyObject to at::Scalar, throw an exception if it fails at::Scalar as_scalar(PyObject* arg) { - // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently - // handle most NumPy scalar types except np.float64. + // Zero-dim tensors are converted to Scalars as-is. Note this doesn't + // currently handle most NumPy scalar types except np.float64. if (THPVariable_Check(arg)) { return THPVariable_Unpack(arg).item(); } @@ -255,7 +252,9 @@ at::Scalar as_scalar(PyObject* arg) { // Entrypoint for the callable created by torch.cuda.jiterator // See jiterator.py for more details -PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyObject *args){ +PyObject* THCPModule_cudaJiteratorCompileAndLaunchKernel( + PyObject* _unused, + PyObject* args) { HANDLE_TH_ERRORS PyObject* code_string_o = nullptr; @@ -264,8 +263,15 @@ PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyO PyObject* num_outputs_o = nullptr; PyObject* tensors_o = nullptr; PyObject* kwargs_o = nullptr; - if(!PyArg_ParseTuple(args, "OOOOO|O", - &code_string_o, &kernel_name_o, &return_by_ref_o, &num_outputs_o, &tensors_o, &kwargs_o)) { + if (!PyArg_ParseTuple( + args, + "OOOOO|O", + &code_string_o, + &kernel_name_o, + &return_by_ref_o, + &num_outputs_o, + &tensors_o, + &kwargs_o)) { return nullptr; } @@ -274,32 +280,43 @@ PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyO const bool return_by_ref = THPUtils_unpackBool(return_by_ref_o); const int num_outputs = static_cast(THPUtils_unpackLong(num_outputs_o)); - THPUtils_assert(PyTuple_Check(tensors_o), "tensors argument is expected to " - "be a tuple, but got %s", THPUtils_typename(tensors_o)); + THPUtils_assert( + PyTuple_Check(tensors_o), + "tensors argument is expected to " + "be a tuple, but got %s", + THPUtils_typename(tensors_o)); Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o); c10::SmallVector tensors; - for(const auto i : c10::irange(num_tensors)) { - PyObject *_tensor = PyTuple_GET_ITEM(tensors_o, i); - THPUtils_assert(THPVariable_Check(_tensor), "%d of input tensors tuple is not a Tensor", i); + for (const auto i : c10::irange(num_tensors)) { + PyObject* _tensor = PyTuple_GET_ITEM(tensors_o, i); + THPUtils_assert( + THPVariable_Check(_tensor), + "%d of input tensors tuple is not a Tensor", + i); tensors.emplace_back(THPVariable_Unpack(_tensor)); } c10::SmallVector extra_args; - PyObject *key = nullptr; - PyObject *value = nullptr; + PyObject* key = nullptr; + PyObject* value = nullptr; Py_ssize_t pos = 0; while (PyDict_Next(kwargs_o, &pos, &key, &value)) { extra_args.emplace_back(as_scalar(value)); } - c10::SmallVector outputs = - at::cuda::CompileAndLaunchKernel(code_string, kernel_name, num_outputs, tensors, extra_args, return_by_ref); + c10::SmallVector outputs = at::cuda::CompileAndLaunchKernel( + code_string, + kernel_name, + num_outputs, + tensors, + extra_args, + return_by_ref); if (num_outputs == 1) { return THPVariable_Wrap(outputs[0]); - } else{ + } else { PyObject* output_tuple = PyTuple_New(num_outputs); for (int i = 0; i < num_outputs; ++i) { PyTuple_SetItem(output_tuple, i, THPVariable_Wrap(outputs[i])); @@ -310,7 +327,9 @@ PyObject * THCPModule_cudaJiteratorCompileAndLaunchKernel(PyObject *_unused, PyO END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){ +PyObject* THCPModule_cudaCachingAllocator_raw_delete( + PyObject* _unused, + PyObject* obj) { HANDLE_TH_ERRORS void* mem_ptr = PyLong_AsVoidPtr(obj); c10::cuda::CUDACachingAllocator::raw_delete(mem_ptr); @@ -318,26 +337,24 @@ PyObject * THCPModule_cudaCachingAllocator_raw_delete(PyObject *_unused, PyObjec END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaSynchronize(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_cudaSynchronize(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS c10::cuda::device_synchronize(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaIPCCollect(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_cudaIPCCollect(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS torch::CudaIPCCollect(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles) -{ +PyObject* THCPModule_cudaSleep(PyObject* _unused, PyObject* cycles) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(cycles), "torch.cuda._sleep(): expected 'int'"); + THPUtils_assert( + THPUtils_checkLong(cycles), "torch.cuda._sleep(): expected 'int'"); at::cuda::sleep(THPUtils_unpackLong(cycles)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -351,8 +368,7 @@ PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles) // such thread). static PyGILState_STATE cudaMutexGILState; -PyObject * THCPModule_cudaLockMutex(PyObject *module, PyObject *noargs) -{ +PyObject* THCPModule_cudaLockMutex(PyObject* module, PyObject* noargs) { auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex(); // This has to be a busy loop because we **absolutely need to** hold the GIL // or it's a recipe for a deadlock otherwise (if we let other Python threads @@ -372,18 +388,17 @@ PyObject * THCPModule_cudaLockMutex(PyObject *module, PyObject *noargs) Py_RETURN_NONE; } -PyObject * THCPModule_cudaUnlockMutex(PyObject *module, PyObject *noargs) -{ +PyObject* THCPModule_cudaUnlockMutex(PyObject* module, PyObject* noargs) { auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex(); PyGILState_Release(cudaMutexGILState); mutex->unlock(); Py_RETURN_NONE; } -PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg) -{ +PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to has_primary_context"); + THPUtils_assert( + THPUtils_checkLong(arg), "invalid argument to has_primary_context"); int64_t device_index = static_cast(THPUtils_unpackLong(arg)); if (at::cuda::detail::hasPrimaryContext(device_index)) { Py_RETURN_TRUE; @@ -393,12 +408,11 @@ PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_setMemoryFraction(PyObject *_unused, PyObject *args) -{ +PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS PyObject* fraction_o = nullptr; PyObject* device_o = nullptr; - if(!PyArg_ParseTuple(args, "OO", &fraction_o, &device_o)) { + if (!PyArg_ParseTuple(args, "OO", &fraction_o, &device_o)) { THPUtils_invalidArguments( args, nullptr, @@ -415,24 +429,23 @@ PyObject * THCPModule_setMemoryFraction(PyObject *_unused, PyObject *args) Py_RETURN_NONE; } -PyObject * THCPModule_emptyCache(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS c10::cuda::CUDACachingAllocator::emptyCache(); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } -PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg) -{ +PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); - const int device = (int) THPUtils_unpackLong(arg); + THPUtils_assert( + THPUtils_checkLong(arg), "invalid argument to memory_allocated"); + const int device = (int)THPUtils_unpackLong(arg); - using c10::cuda::CUDACachingAllocator::StatType; + using c10::cuda::CUDACachingAllocator::DeviceStats; using c10::cuda::CUDACachingAllocator::Stat; using c10::cuda::CUDACachingAllocator::StatArray; - using c10::cuda::CUDACachingAllocator::DeviceStats; + using c10::cuda::CUDACachingAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; @@ -445,9 +458,8 @@ PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg) }; const auto statArrayToDict = [=](const StatArray& statArray) { - const std::array(StatType::NUM_TYPES)> statTypeNames = { - "all", "small_pool", "large_pool" - }; + const std::array(StatType::NUM_TYPES)> + statTypeNames = {"all", "small_pool", "large_pool"}; py::dict dict; for (const auto i : c10::irange(statTypeNames.size())) { dict[statTypeNames[i]] = statToDict(statArray[i]); @@ -455,7 +467,8 @@ PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg) return dict; }; - const DeviceStats stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device); + const DeviceStats stats = + c10::cuda::CUDACachingAllocator::getDeviceStats(device); py::dict result; result["num_alloc_retries"] = stats.num_alloc_retries; @@ -476,32 +489,34 @@ PyObject * THCPModule_memoryStats(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_resetAccumulatedMemoryStats(PyObject *_unused, PyObject *arg) -{ +PyObject* THCPModule_resetAccumulatedMemoryStats( + PyObject* _unused, + PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_accumulated_memory_stats"); - const int device = (int) THPUtils_unpackLong(arg); + THPUtils_assert( + THPUtils_checkLong(arg), + "invalid argument to reset_accumulated_memory_stats"); + const int device = (int)THPUtils_unpackLong(arg); c10::cuda::CUDACachingAllocator::resetAccumulatedStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } -PyObject * THCPModule_resetPeakMemoryStats(PyObject *_unused, PyObject *arg) -{ +PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); - const int device = (int) THPUtils_unpackLong(arg); + THPUtils_assert( + THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); + const int device = (int)THPUtils_unpackLong(arg); c10::cuda::CUDACachingAllocator::resetPeakStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } -PyObject * THCPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - using c10::cuda::CUDACachingAllocator::SegmentInfo; using c10::cuda::CUDACachingAllocator::BlockInfo; + using c10::cuda::CUDACachingAllocator::SegmentInfo; const auto segmentInfoToDict = [](const SegmentInfo& segmentInfo) { py::dict segmentDict; @@ -516,7 +531,10 @@ PyObject * THCPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) for (const auto& blockInfo : segmentInfo.blocks) { py::dict blockDict; blockDict["size"] = blockInfo.size; - blockDict["state"] = (blockInfo.allocated ? "active_allocated" : (blockInfo.active ? "active_pending_free" : "inactive")); + blockDict["state"] = + (blockInfo.allocated + ? "active_allocated" + : (blockInfo.active ? "active_pending_free" : "inactive")); blocks.append(blockDict); } segmentDict["blocks"] = blocks; @@ -524,7 +542,8 @@ PyObject * THCPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) return segmentDict; }; - const std::vector& snapshot = c10::cuda::CUDACachingAllocator::snapshot(); + const std::vector& snapshot = + c10::cuda::CUDACachingAllocator::snapshot(); py::list result; for (const auto& segmentInfo : snapshot) { @@ -535,38 +554,53 @@ PyObject * THCPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaSetSyncDebugMode(PyObject * _unused, PyObject * arg){ +PyObject* THCPModule_cudaSetSyncDebugMode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - TORCH_WARN_ONCE("Synchronization debug mode is a prototype feature and does not yet detect all " \ - "synchronizing operations"); - THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode"); + TORCH_WARN_ONCE( + "Synchronization debug mode is a prototype feature and does not yet detect all " + "synchronizing operations"); + THPUtils_assert( + THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode"); int64_t debug_mode = THPUtils_unpackLong(arg); - TORCH_CHECK(debug_mode >=0 && debug_mode <=2, "invalid value of debug_mode, expected one of 0,1,2"); + TORCH_CHECK( + debug_mode >= 0 && debug_mode <= 2, + "invalid value of debug_mode, expected one of 0,1,2"); c10::cuda::SyncDebugMode l; switch (debug_mode) { - case 0: l = c10::cuda::SyncDebugMode::L_DISABLED; break; - case 1: l = c10::cuda::SyncDebugMode::L_WARN; break; - case 2: l = c10::cuda::SyncDebugMode::L_ERROR; break; - default: l = c10::cuda::SyncDebugMode::L_DISABLED; break; // can't happen + case 0: + l = c10::cuda::SyncDebugMode::L_DISABLED; + break; + case 1: + l = c10::cuda::SyncDebugMode::L_WARN; + break; + case 2: + l = c10::cuda::SyncDebugMode::L_ERROR; + break; + default: + l = c10::cuda::SyncDebugMode::L_DISABLED; + break; // can't happen } c10::cuda::warning_state().set_sync_debug_mode(l); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaGetSyncDebugMode(PyObject *self, PyObject *noargs){ +PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS auto debug_mode = c10::cuda::warning_state().get_sync_debug_mode(); - switch (debug_mode){ - case c10::cuda::SyncDebugMode::L_DISABLED: return THPUtils_packInt32(0); - case c10::cuda::SyncDebugMode::L_WARN: return THPUtils_packInt32(1); - case c10::cuda::SyncDebugMode::L_ERROR: return THPUtils_packInt32(2); - default: return THPUtils_packInt32(-1); // can't happen + switch (debug_mode) { + case c10::cuda::SyncDebugMode::L_DISABLED: + return THPUtils_packInt32(0); + case c10::cuda::SyncDebugMode::L_WARN: + return THPUtils_packInt32(1); + case c10::cuda::SyncDebugMode::L_ERROR: + return THPUtils_packInt32(2); + default: + return THPUtils_packInt32(-1); // can't happen } END_HANDLE_TH_ERRORS } - //////////////////////////////////////////////////////////////////////////////// // Cuda module initialization //////////////////////////////////////////////////////////////////////////////// @@ -575,47 +609,54 @@ static void registerCudaDeviceProperties(PyObject* module) { // Add _cudaDevicePropertires class to torch._C auto m = py::handle(module).cast(); py::class_(m, "_CudaDeviceProperties") - .def_readonly("name", &cudaDeviceProp::name) - .def_readonly("major", &cudaDeviceProp::major) - .def_readonly("minor", &cudaDeviceProp::minor) - .def_readonly("is_multi_gpu_board", &cudaDeviceProp::isMultiGpuBoard) - .def_readonly("is_integrated", &cudaDeviceProp::integrated) - .def_readonly("multi_processor_count", &cudaDeviceProp::multiProcessorCount) - .def_readonly("total_memory", &cudaDeviceProp::totalGlobalMem) - .def("__repr__", [](const cudaDeviceProp &prop) { - std::ostringstream stream; - stream << "_CudaDeviceProperties(name='" << prop.name << "', major=" << prop.major - << ", minor=" << prop.minor << ", total_memory=" << prop.totalGlobalMem / (1024 * 1024) - << "MB, multi_processor_count=" << prop.multiProcessorCount << ")"; - return stream.str(); - }); + .def_readonly("name", &cudaDeviceProp::name) + .def_readonly("major", &cudaDeviceProp::major) + .def_readonly("minor", &cudaDeviceProp::minor) + .def_readonly("is_multi_gpu_board", &cudaDeviceProp::isMultiGpuBoard) + .def_readonly("is_integrated", &cudaDeviceProp::integrated) + .def_readonly( + "multi_processor_count", &cudaDeviceProp::multiProcessorCount) + .def_readonly("total_memory", &cudaDeviceProp::totalGlobalMem) + .def("__repr__", [](const cudaDeviceProp& prop) { + std::ostringstream stream; + stream << "_CudaDeviceProperties(name='" << prop.name + << "', major=" << prop.major << ", minor=" << prop.minor + << ", total_memory=" << prop.totalGlobalMem / (1024 * 1024) + << "MB, multi_processor_count=" << prop.multiProcessorCount + << ")"; + return stream.str(); + }); } static void bindGetDeviceProperties(PyObject* module) { // Add method to torch.cuda auto m = py::handle(module).cast(); - m.def("_get_device_properties", [](int device) -> cudaDeviceProp * { - return at::cuda::getDeviceProperties(device); - }, py::return_value_policy::reference); + m.def( + "_get_device_properties", + [](int device) -> cudaDeviceProp* { + return at::cuda::getDeviceProperties(device); + }, + py::return_value_policy::reference); } -// Callback for python part. Used for additional initialization of python classes -static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs) -{ +// Callback for python part. Used for additional initialization of python +// classes +static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { #if C10_ASAN_ENABLED TORCH_WARN( - "torch.cuda: your pytorch binary has address sanitizer (asan) built in, " - "asan is currently not compatible with torch.cuda module, " - "you might get unexpected behavior (eg. out of memory, crash, etc.), " - "please rebuild pytorch without asan if you need to use this module"); + "torch.cuda: your pytorch binary has address sanitizer (asan) built in, " + "asan is currently not compatible with torch.cuda module, " + "you might get unexpected behavior (eg. out of memory, crash, etc.), " + "please rebuild pytorch without asan if you need to use this module"); #endif HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); at::globalContext().lazyInitCUDA(); auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); - if (!m) throw python_error(); + if (!m) + throw python_error(); bool has_half = true; @@ -631,7 +672,7 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs) auto num_gpus = c10::cuda::device_count(); auto default_cuda_generators = PyTuple_New(static_cast(num_gpus)); - for(const auto i : c10::irange(num_gpus)) { + for (const auto i : c10::irange(num_gpus)) { // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) auto gen = at::cuda::detail::getDefaultCUDAGenerator(i); auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(gen); @@ -645,8 +686,9 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self, PyObject *noargs) -{ +PyObject* THCPModule_getCurrentBlasHandle_wrap( + PyObject* self, + PyObject* noargs) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -654,14 +696,14 @@ PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self, PyObject *noargs END_HANDLE_TH_ERRORS } -PyObject * THCPModule_rocm_is_backward_pass(PyObject *_unused, PyObject *noargs) -{ +PyObject* THCPModule_rocm_is_backward_pass( + PyObject* _unused, + PyObject* noargs) { HANDLE_TH_ERRORS #if USE_ROCM if (at::ROCmBackwardPassGuard::is_backward_pass()) { Py_RETURN_TRUE; - } - else { + } else { Py_RETURN_FALSE; } #else @@ -670,8 +712,9 @@ PyObject * THCPModule_rocm_is_backward_pass(PyObject *_unused, PyObject *noargs) END_HANDLE_TH_ERRORS } -static PyObject * THCPModule_isCurrentStreamCapturing_wrap(PyObject *self, PyObject *noargs) -{ +static PyObject* THCPModule_isCurrentStreamCapturing_wrap( + PyObject* self, + PyObject* noargs) { HANDLE_TH_ERRORS // If there's no cuda context, at::cuda::currentStreamCaptureStatus returns // CaptureStatus::None without initializing a context. @@ -683,62 +726,118 @@ static PyObject * THCPModule_isCurrentStreamCapturing_wrap(PyObject *self, PyObj END_HANDLE_TH_ERRORS } -// NOLINTNEXTLINE(modernize-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, cppcoreguidelines-avoid-c-arrays) +// NOLINTNEXTLINE(modernize-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, +// cppcoreguidelines-avoid-c-arrays) static struct PyMethodDef _THCPModule_methods[] = { - {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr}, - {"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr}, - {"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, - {"_cuda_getDeviceCount", THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr}, - {"_cuda_canDeviceAccessPeer", THCPModule_canDeviceAccessPeer_wrap, METH_VARARGS, nullptr}, - {"_cuda_getArchFlags", THCPModule_getArchFlags, METH_NOARGS, nullptr}, - {"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr}, - {"_cuda_getCurrentStream", - THCPModule_getCurrentStream_wrap, METH_O, nullptr}, - {"_cuda_getCurrentRawStream", - THCPModule_getCurrentStream_raw, METH_O, nullptr}, - {"_cuda_getDefaultStream", - THCPModule_getDefaultStream_wrap, METH_O, nullptr}, - {"_cuda_getCurrentBlasHandle", THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr}, - {"_cuda_isCurrentStreamCapturing", THCPModule_isCurrentStreamCapturing_wrap, METH_NOARGS, nullptr}, - {"_cuda_setStream", THCPModule_setStream_wrap, METH_O, nullptr}, - {"_cuda_getCompiledVersion", THCPModule_getCompiledVersion, METH_NOARGS, nullptr}, - {"_cuda_hasPrimaryContext", THCPModule_hasPrimaryContext, METH_O, nullptr}, - {"_cuda_setMemoryFraction", THCPModule_setMemoryFraction, METH_VARARGS, nullptr}, - {"_cuda_emptyCache", THCPModule_emptyCache, METH_NOARGS, nullptr}, - {"_cuda_memoryStats", THCPModule_memoryStats, METH_O, nullptr}, - {"_cuda_resetAccumulatedMemoryStats", THCPModule_resetAccumulatedMemoryStats, METH_O, nullptr}, - {"_cuda_resetPeakMemoryStats", THCPModule_resetPeakMemoryStats, METH_O, nullptr}, - {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr}, - {"_cuda_cudaHostAllocator", THCPModule_cudaHostAllocator, METH_NOARGS, nullptr}, - {"_cuda_cudaCachingAllocator_raw_alloc", THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, nullptr}, - {"_cuda_cudaCachingAllocator_raw_delete", THCPModule_cudaCachingAllocator_raw_delete, METH_O, nullptr}, - {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, - {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, - {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, - {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, - {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, - {"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, nullptr}, - {"_cuda_get_sync_debug_mode", THCPModule_cudaGetSyncDebugMode, METH_NOARGS, nullptr}, - {"_cuda_jiterator_compile_and_launch_kernel", THCPModule_cudaJiteratorCompileAndLaunchKernel, METH_VARARGS, nullptr}, + {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr}, + {"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr}, + {"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, + {"_cuda_getDeviceCount", + THCPModule_getDeviceCount_wrap, + METH_NOARGS, + nullptr}, + {"_cuda_canDeviceAccessPeer", + THCPModule_canDeviceAccessPeer_wrap, + METH_VARARGS, + nullptr}, + {"_cuda_getArchFlags", THCPModule_getArchFlags, METH_NOARGS, nullptr}, + {"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr}, + {"_cuda_getCurrentStream", + THCPModule_getCurrentStream_wrap, + METH_O, + nullptr}, + {"_cuda_getCurrentRawStream", + THCPModule_getCurrentStream_raw, + METH_O, + nullptr}, + {"_cuda_getDefaultStream", + THCPModule_getDefaultStream_wrap, + METH_O, + nullptr}, + {"_cuda_getCurrentBlasHandle", + THCPModule_getCurrentBlasHandle_wrap, + METH_NOARGS, + nullptr}, + {"_cuda_isCurrentStreamCapturing", + THCPModule_isCurrentStreamCapturing_wrap, + METH_NOARGS, + nullptr}, + {"_cuda_setStream", THCPModule_setStream_wrap, METH_O, nullptr}, + {"_cuda_getCompiledVersion", + THCPModule_getCompiledVersion, + METH_NOARGS, + nullptr}, + {"_cuda_hasPrimaryContext", THCPModule_hasPrimaryContext, METH_O, nullptr}, + {"_cuda_setMemoryFraction", + THCPModule_setMemoryFraction, + METH_VARARGS, + nullptr}, + {"_cuda_emptyCache", THCPModule_emptyCache, METH_NOARGS, nullptr}, + {"_cuda_memoryStats", THCPModule_memoryStats, METH_O, nullptr}, + {"_cuda_resetAccumulatedMemoryStats", + THCPModule_resetAccumulatedMemoryStats, + METH_O, + nullptr}, + {"_cuda_resetPeakMemoryStats", + THCPModule_resetPeakMemoryStats, + METH_O, + nullptr}, + {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr}, + {"_cuda_cudaHostAllocator", + THCPModule_cudaHostAllocator, + METH_NOARGS, + nullptr}, + {"_cuda_cudaCachingAllocator_raw_alloc", + THCPModule_cudaCachingAllocator_raw_alloc, + METH_VARARGS, + nullptr}, + {"_cuda_cudaCachingAllocator_raw_delete", + THCPModule_cudaCachingAllocator_raw_delete, + METH_O, + nullptr}, + {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, + {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, + {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, + {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, + {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, + {"_cuda_set_sync_debug_mode", + THCPModule_cudaSetSyncDebugMode, + METH_O, + nullptr}, + {"_cuda_get_sync_debug_mode", + THCPModule_cudaGetSyncDebugMode, + METH_NOARGS, + nullptr}, + {"_cuda_jiterator_compile_and_launch_kernel", + THCPModule_cudaJiteratorCompileAndLaunchKernel, + METH_VARARGS, + nullptr}, #ifdef USE_NCCL - {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr}, - {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr}, - {"_nccl_init_rank", THCPModule_nccl_init_rank, METH_VARARGS, nullptr}, - {"_nccl_reduce", THCPModule_nccl_reduce, METH_VARARGS, nullptr}, - {"_nccl_all_reduce", THCPModule_nccl_all_reduce, METH_VARARGS, nullptr}, - {"_nccl_broadcast", THCPModule_nccl_broadcast, METH_VARARGS, nullptr}, - {"_nccl_all_gather", THCPModule_nccl_all_gather, METH_VARARGS, nullptr}, - {"_nccl_reduce_scatter", THCPModule_nccl_reduce_scatter, METH_VARARGS, nullptr}, + {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr}, + {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr}, + {"_nccl_init_rank", THCPModule_nccl_init_rank, METH_VARARGS, nullptr}, + {"_nccl_reduce", THCPModule_nccl_reduce, METH_VARARGS, nullptr}, + {"_nccl_all_reduce", THCPModule_nccl_all_reduce, METH_VARARGS, nullptr}, + {"_nccl_broadcast", THCPModule_nccl_broadcast, METH_VARARGS, nullptr}, + {"_nccl_all_gather", THCPModule_nccl_all_gather, METH_VARARGS, nullptr}, + {"_nccl_reduce_scatter", + THCPModule_nccl_reduce_scatter, + METH_VARARGS, + nullptr}, #endif - {"_rocm_is_backward_pass", THCPModule_rocm_is_backward_pass, METH_NOARGS, nullptr}, - {nullptr} -}; + {"_rocm_is_backward_pass", + THCPModule_rocm_is_backward_pass, + METH_NOARGS, + nullptr}, + {nullptr}}; PyMethodDef* THCPModule_methods() { return _THCPModule_methods; } -namespace torch { namespace cuda { +namespace torch { +namespace cuda { namespace shared { @@ -750,7 +849,7 @@ void initCudnnBindings(PyObject* module); } // namespace shared -void initModule(PyObject *module) { +void initModule(PyObject* module) { python::initCommMethods(module); // As weird as it seems, this file is also compiled for ROCm, // so this condition might not always be true... @@ -762,4 +861,5 @@ void initModule(PyObject *module) { registerCudaDeviceProperties(module); } -}} +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/Module.h b/torch/csrc/cuda/Module.h index 6f903decd1fb27..23d9079f146f54 100644 --- a/torch/csrc/cuda/Module.h +++ b/torch/csrc/cuda/Module.h @@ -2,11 +2,11 @@ #define THCP_CUDA_MODULE_INC void THCPModule_setDevice(int idx); -PyObject * THCPModule_getDevice_wrap(PyObject *self); -PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg); -PyObject * THCPModule_getDeviceName_wrap(PyObject *self, PyObject *arg); -PyObject * THCPModule_getDriverVersion(PyObject *self); -PyObject * THCPModule_isDriverSufficient(PyObject *self); -PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self); +PyObject* THCPModule_getDevice_wrap(PyObject* self); +PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg); +PyObject* THCPModule_getDeviceName_wrap(PyObject* self, PyObject* arg); +PyObject* THCPModule_getDriverVersion(PyObject* self); +PyObject* THCPModule_isDriverSufficient(PyObject* self); +PyObject* THCPModule_getCurrentBlasHandle_wrap(PyObject* self); #endif diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index c23d2dd202bfc3..b604529f035c9b 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -1,19 +1,21 @@ #include -#include -#include -#include #include #include +#include +#include +#include #include -#include #include +#include -PyObject *THCPStreamClass = nullptr; +PyObject* THCPStreamClass = nullptr; -static PyObject * THCPStream_pynew( - PyTypeObject *type, PyObject *args, PyObject *kwargs) { +static PyObject* THCPStream_pynew( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS const auto current_device = c10::cuda::current_device(); @@ -23,9 +25,9 @@ static PyObject * THCPStream_pynew( uint64_t stream_ptr = 0; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - static char *kwlist[] = {"priority", "_cdata", "stream_ptr", nullptr}; + static char* kwlist[] = {"priority", "_cdata", "stream_ptr", nullptr}; if (!PyArg_ParseTupleAndKeywords( - args, kwargs, "|iKK", kwlist, &priority, &cdata, &stream_ptr)) { + args, kwargs, "|iKK", kwlist, &priority, &cdata, &stream_ptr)) { return nullptr; } @@ -35,69 +37,70 @@ static PyObject * THCPStream_pynew( } if (stream_ptr) { - TORCH_CHECK(priority == 0, "Priority was explicitly set for a external stream") + TORCH_CHECK( + priority == 0, "Priority was explicitly set for a external stream") } - at::cuda::CUDAStream stream = - cdata ? - at::cuda::CUDAStream::unpack(cdata) : - stream_ptr ? - at::cuda::getStreamFromExternal(reinterpret_cast(stream_ptr), current_device) : - at::cuda::getStreamFromPool( - /* isHighPriority */ priority < 0 ? true : false); + at::cuda::CUDAStream stream = cdata ? at::cuda::CUDAStream::unpack(cdata) + : stream_ptr + ? at::cuda::getStreamFromExternal( + reinterpret_cast(stream_ptr), current_device) + : at::cuda::getStreamFromPool( + /* isHighPriority */ priority < 0 ? true : false); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - THCPStream* self = (THCPStream *)ptr.get(); + THCPStream* self = (THCPStream*)ptr.get(); self->cdata = stream.pack(); new (&self->cuda_stream) at::cuda::CUDAStream(stream); - return (PyObject *)ptr.release(); + return (PyObject*)ptr.release(); END_HANDLE_TH_ERRORS } -static void THCPStream_dealloc(THCPStream *self) { +static void THCPStream_dealloc(THCPStream* self) { self->cuda_stream.~CUDAStream(); Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THCPStream_get_device(THCPStream *self, void *unused) { +static PyObject* THCPStream_get_device(THCPStream* self, void* unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cuda_stream.device()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_get_cuda_stream(THCPStream *self, void *unused) { +static PyObject* THCPStream_get_cuda_stream(THCPStream* self, void* unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_stream.stream()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_get_priority(THCPStream *self, void *unused) { +static PyObject* THCPStream_get_priority(THCPStream* self, void* unused) { HANDLE_TH_ERRORS return THPUtils_packInt64(self->cuda_stream.priority()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_priority_range(PyObject *_unused, PyObject* noargs) { +static PyObject* THCPStream_priority_range( + PyObject* _unused, + PyObject* noargs) { HANDLE_TH_ERRORS // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int least_priority, greatest_priority; std::tie(least_priority, greatest_priority) = - at::cuda::CUDAStream::priority_range(); + at::cuda::CUDAStream::priority_range(); return Py_BuildValue("(ii)", least_priority, greatest_priority); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_query(PyObject *_self, PyObject *noargs) { +static PyObject* THCPStream_query(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THCPStream*)_self; return PyBool_FromLong(self->cuda_stream.query()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_synchronize(PyObject *_self, PyObject *noargs) { - HANDLE_TH_ERRORS - { +static PyObject* THCPStream_synchronize(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS { pybind11::gil_scoped_release no_gil; auto self = (THCPStream*)_self; self->cuda_stream.synchronize(); @@ -106,7 +109,7 @@ static PyObject * THCPStream_synchronize(PyObject *_self, PyObject *noargs) { END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_eq(PyObject *_self, PyObject *_other) { +static PyObject* THCPStream_eq(PyObject* _self, PyObject* _other) { HANDLE_TH_ERRORS auto self = (THCPStream*)_self; auto other = (THCPStream*)_other; @@ -114,74 +117,77 @@ static PyObject * THCPStream_eq(PyObject *_self, PyObject *_other) { END_HANDLE_TH_ERRORS } -// NOLINTNEXTLINE(modernize-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, cppcoreguidelines-avoid-c-arrays) -static struct PyMemberDef THCPStream_members[] = { - {nullptr} -}; +// NOLINTNEXTLINE(modernize-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, +// cppcoreguidelines-avoid-c-arrays) +static struct PyMemberDef THCPStream_members[] = {{nullptr}}; -// NOLINTNEXTLINE(modernize-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, cppcoreguidelines-avoid-c-arrays) +// NOLINTNEXTLINE(modernize-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, +// cppcoreguidelines-avoid-c-arrays) static struct PyGetSetDef THCPStream_properties[] = { - {"cuda_stream", - (getter)THCPStream_get_cuda_stream, nullptr, nullptr, nullptr}, - {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr}, - {nullptr} -}; - -// NOLINTNEXTLINE(modernize-avoid-c-arrays, cppcoreguidelines-avoid-non-const-global-variables, cppcoreguidelines-avoid-c-arrays) + {"cuda_stream", + (getter)THCPStream_get_cuda_stream, + nullptr, + nullptr, + nullptr}, + {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr}, + {nullptr}}; + +// NOLINTNEXTLINE(modernize-avoid-c-arrays, +// cppcoreguidelines-avoid-non-const-global-variables, +// cppcoreguidelines-avoid-c-arrays) static PyMethodDef THCPStream_methods[] = { - {(char*)"query", THCPStream_query, METH_NOARGS, nullptr}, - {(char*)"synchronize", - THCPStream_synchronize, METH_NOARGS, nullptr}, - {(char*)"priority_range", - THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr}, - {(char*)"__eq__", THCPStream_eq, METH_O, nullptr}, - {nullptr} -}; + {(char*)"query", THCPStream_query, METH_NOARGS, nullptr}, + {(char*)"synchronize", THCPStream_synchronize, METH_NOARGS, nullptr}, + {(char*)"priority_range", + THCPStream_priority_range, + METH_STATIC | METH_NOARGS, + nullptr}, + {(char*)"__eq__", THCPStream_eq, METH_O, nullptr}, + {nullptr}}; PyTypeObject THCPStreamType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C._CudaStreamBase", /* tp_name */ - sizeof(THCPStream), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THCPStream_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - THCPStream_methods, /* tp_methods */ - THCPStream_members, /* tp_members */ - THCPStream_properties, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - nullptr, /* tp_alloc */ - THCPStream_pynew, /* tp_new */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._CudaStreamBase", /* tp_name */ + sizeof(THCPStream), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)THCPStream_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + THCPStream_methods, /* tp_methods */ + THCPStream_members, /* tp_members */ + THCPStream_properties, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + THCPStream_pynew, /* tp_new */ }; - -void THCPStream_init(PyObject *module) -{ +void THCPStream_init(PyObject* module) { Py_INCREF(THPStreamClass); THCPStreamType.tp_base = THPStreamClass; THCPStreamClass = (PyObject*)&THCPStreamType; @@ -190,7 +196,7 @@ void THCPStream_init(PyObject *module) } Py_INCREF(&THCPStreamType); if (PyModule_AddObject( - module, "_CudaStreamBase", (PyObject *)&THCPStreamType) < 0) { + module, "_CudaStreamBase", (PyObject*)&THCPStreamType) < 0) { throw python_error(); } } diff --git a/torch/csrc/cuda/Stream.h b/torch/csrc/cuda/Stream.h index d0995d298528a9..9b7197d74390c1 100644 --- a/torch/csrc/cuda/Stream.h +++ b/torch/csrc/cuda/Stream.h @@ -1,17 +1,17 @@ #ifndef THCP_STREAM_INC #define THCP_STREAM_INC -#include #include +#include #include // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -struct THCPStream : THPStream{ +struct THCPStream : THPStream { at::cuda::CUDAStream cuda_stream; }; -extern PyObject *THCPStreamClass; +extern PyObject* THCPStreamClass; -void THCPStream_init(PyObject *module); +void THCPStream_init(PyObject* module); inline bool THCPStream_Check(PyObject* obj) { return THCPStreamClass && PyObject_IsInstance(obj, THCPStreamClass); diff --git a/torch/csrc/cuda/THCP.h b/torch/csrc/cuda/THCP.h index 431d3adfa6abc4..12b7042ebb49dd 100644 --- a/torch/csrc/cuda/THCP.h +++ b/torch/csrc/cuda/THCP.h @@ -1,11 +1,11 @@ #ifndef THCP_H #define THCP_H -#include #include +#include #include #include -#include #include +#include #endif diff --git a/torch/csrc/cuda/Tensor.cpp b/torch/csrc/cuda/Tensor.cpp index 0801907df461c8..7bd8c498c68df1 100644 --- a/torch/csrc/cuda/Tensor.cpp +++ b/torch/csrc/cuda/Tensor.cpp @@ -1,19 +1,19 @@ #define __STDC_FORMAT_MACROS -#include #include +#include -#include +#include #include #include -#include +#include -#include -#include #include +#include +#include #include -//generic_include THC torch/csrc/generic/Tensor.cpp +// generic_include THC torch/csrc/generic/Tensor.cpp -#include #include +#include diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index b9f58910fad85f..117f6b571792b3 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -11,8 +11,8 @@ #include #include #include -#include #include +#include #include #include @@ -74,7 +74,7 @@ static inline std::vector& _broadcast_out_impl( std::vector& broadcast_out( const Tensor& tensor, std::vector& out_tensors) { - for(const auto i : c10::irange(out_tensors.size())) { + for (const auto i : c10::irange(out_tensors.size())) { TORCH_CHECK( out_tensors[i].is_cuda(), "Expected all output tensors to be CUDA tensors, but output tensor at index ", @@ -242,7 +242,7 @@ std::vector& scatter_out( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector chunk_sizes; chunk_sizes.reserve(out_tensors.size()); - for(const auto i : c10::irange(out_tensors.size())) { + for (const auto i : c10::irange(out_tensors.size())) { TORCH_CHECK( out_tensors[i].is_cuda(), "Expected all output tensors to be CUDA tensors, but output tensor at index ", @@ -285,7 +285,7 @@ std::vector& scatter_out( auto chunks = tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim); at::cuda::OptionalCUDAStreamGuard cuda_guard; - for(const auto i : c10::irange(chunks.size())) { + for (const auto i : c10::irange(chunks.size())) { if (i < (streams ? streams->size() : 0U) && (*streams)[i]) { const auto device_index = static_cast(out_tensors[i].get_device()); @@ -383,7 +383,7 @@ static inline at::Tensor& _gather_out_impl( } auto chunks = out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim); - for(const auto i : c10::irange(tensors.size())) { + for (const auto i : c10::irange(tensors.size())) { chunks[i].copy_(tensors[i], /*non_blocking=*/out_tensor.is_cuda()); } return out_tensor; @@ -400,7 +400,7 @@ at::Tensor& gather_out( dim = at::maybe_wrap_dim(dim, first); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector expected_size(first_size.begin(), first_size.end()); - for(const auto i : c10::irange(tensors.size())) { + for (const auto i : c10::irange(tensors.size())) { const auto& tensor = tensors[i]; TORCH_CHECK( tensor.is_cuda(), @@ -456,7 +456,7 @@ at::Tensor gather( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector expected_size(first_size.begin(), first_size.end()); auto memory_format = first.suggest_memory_format(); - for(const auto i : c10::irange(tensors.size())) { + for (const auto i : c10::irange(tensors.size())) { const auto& tensor = tensors[i]; TORCH_CHECK( tensor.is_cuda(), diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h index 7dd99e64141d82..5c12b5bda09414 100644 --- a/torch/csrc/cuda/comm.h +++ b/torch/csrc/cuda/comm.h @@ -1,15 +1,16 @@ #pragma once #include -#include #include #include #include +#include #include #include -namespace torch { namespace cuda { +namespace torch { +namespace cuda { using tensor_list2d = std::vector>; @@ -48,4 +49,5 @@ TORCH_CUDA_CU_API at::Tensor gather( at::TensorList tensors, int64_t dim, c10::optional destination_index); -}} +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/device_set.h b/torch/csrc/cuda/device_set.h index a939710c30b747..82fa34294d363f 100644 --- a/torch/csrc/cuda/device_set.h +++ b/torch/csrc/cuda/device_set.h @@ -7,4 +7,4 @@ namespace torch { static constexpr size_t MAX_CUDA_DEVICES = 64; using device_set = std::bitset; -} +} // namespace torch diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 5817449c1a4972..e5dc7992d56c74 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -1,6 +1,6 @@ -#include -#include #include +#include +#include #include #include @@ -15,7 +15,6 @@ #include #include - ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) { return reinterpret_cast(var); } @@ -50,7 +49,7 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) { } torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) { - switch (var) { + switch (var) { case ncclSuccess: return torch::cuda::nccl::ncclResult::Success; case ncclUnhandledCudaError: @@ -99,7 +98,10 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) { ncclDataType_t to_nccl_data_type(const at::Tensor& t) { if (!t.is_cuda()) { - TORCH_CHECK(false, "NCCL only supports CUDA tensors, but got a tensor on ", t.device()); + TORCH_CHECK( + false, + "NCCL only supports CUDA tensors, but got a tensor on ", + t.device()); } return to_nccl_data_type(t.scalar_type()); } @@ -137,7 +139,8 @@ struct AutoNcclGroup { void throw_nccl_error(torch::cuda::nccl::ncclResult status) { std::ostringstream err; - err << "NCCL Error " << static_cast(status) << ": " << ncclGetErrorString(to_nccl_result(status)); + err << "NCCL Error " << static_cast(status) << ": " + << ncclGetErrorString(to_nccl_result(status)); throw std::runtime_error(err.str()); } @@ -146,13 +149,13 @@ struct NcclCommList { int ndevices; NcclCommList(const std::vector& devices) : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) { - NCCL_CHECK( - ncclCommInitAll(to_nccl_comm(comms.get()), devices.size(), devices.data())); + NCCL_CHECK(ncclCommInitAll( + to_nccl_comm(comms.get()), devices.size(), devices.data())); } NcclCommList(NcclCommList&& foo) = default; ~NcclCommList() { if (comms) { - for(const auto i : c10::irange(ndevices)) { + for (const auto i : c10::irange(ndevices)) { int dummy_var; if (cudaGetDevice(&dummy_var) != cudaSuccess) { /* there are cases when this destructor is called after the @@ -185,16 +188,14 @@ ArrayRef get_communicators(TensorList inputs) { return it->second.ref(); } -static inline -void check_tensor( +static inline void check_tensor( const at::Tensor& input, const at::optional& output, int input_multiplier, int output_multiplier, int64_t ref_numel, ScalarType ref_dtype) { - - auto check_one = [&](const at::Tensor &tensor) { + auto check_one = [&](const at::Tensor& tensor) { if (!tensor.is_cuda() || tensor.is_sparse()) { throw std::runtime_error( "input and output elements have to be cuda dense Tensors"); @@ -256,11 +257,12 @@ void check_inputs( int64_t numel = inputs[0].numel(); auto dtype = inputs[0].scalar_type(); - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { auto input = inputs[i]; auto output = outputs[i]; - check_tensor(input, output, input_multiplier, output_multiplier, numel, dtype); + check_tensor( + input, output, input_multiplier, output_multiplier, numel, dtype); auto input_device = input.get_device(); // inputs must be on unique devices @@ -287,13 +289,16 @@ void check_inputs( int64_t numel = inputs[0].numel(); auto dtype = inputs[0].scalar_type(); - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { auto input = inputs[i]; check_tensor( - input, - i == root ? at::optional{output} : at::nullopt, - input_multiplier, output_multiplier, numel, dtype); + input, + i == root ? at::optional{output} : at::nullopt, + input_multiplier, + output_multiplier, + numel, + dtype); auto input_device = input.get_device(); // inputs must be on unique devices @@ -327,20 +332,18 @@ bool is_available(TensorList tensors) { std::uint64_t version() { #if defined(NCCL_MAJOR) - constexpr std::uint64_t ver = (((uint64_t) NCCL_MAJOR) << 32) | - (((uint64_t) NCCL_MINOR) << 16) | - ((uint64_t) NCCL_PATCH); + constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) | + (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH); return ver; #elif defined(USE_NCCL) // return major version "1" - return ((uint64_t) 1) << 32; + return ((uint64_t)1) << 32; #else return 0; #endif } -void get_unique_id(ncclUniqueId& id) -{ +void get_unique_id(ncclUniqueId& id) { #ifdef USE_NCCL using namespace torch::cuda::nccl::detail; NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id))); @@ -355,18 +358,14 @@ ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) { ncclComm_t comm; ncclUniqueId id = comm_id; NCCL_CHECK(ncclCommInitRank( - to_nccl_comm(&comm), - nranks, - *(to_nccl_unique_id(&id)), - rank)); + to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank)); return comm; #else return nullptr; #endif } -void comm_destroy(ncclComm_t comm) -{ +void comm_destroy(ncclComm_t comm) { /* * TODO(T30279827) Temporarily disable calling ncclCommDestroy * Calling ncclCommDestroy while program exiting is undefined @@ -437,7 +436,12 @@ void broadcast( ")"); ncclComm_t comm = comms[i]; NCCL_CHECK(ncclBcast( - tensors[i].data_ptr(), numel, data_type, 0, to_nccl_comm(comm), stream)); + tensors[i].data_ptr(), + numel, + data_type, + 0, + to_nccl_comm(comm), + stream)); } #else AT_ERROR("PyTorch built without NCCL support"); @@ -467,7 +471,7 @@ void reduce( AutoNcclGroup nccl_group_guard; at::cuda::OptionalCUDAGuard device_guard; - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { int device = inputs[i].device().index(); device_guard.set_index(device); // Default to the current stream @@ -519,7 +523,7 @@ void all_reduce( AutoNcclGroup nccl_group_guard; at::cuda::OptionalCUDAGuard device_guard; - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { int device = inputs[i].device().index(); device_guard.set_index(device); // Default to the current stream @@ -561,7 +565,7 @@ void reduce_scatter( AutoNcclGroup nccl_group_guard; at::cuda::OptionalCUDAGuard device_guard; - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { int device = inputs[i].device().index(); device_guard.set_index(device); // Default to the current stream @@ -602,7 +606,7 @@ void all_gather( AutoNcclGroup nccl_group_guard; at::cuda::OptionalCUDAGuard device_guard; - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { int device = inputs[i].device().index(); device_guard.set_index(device); // Default to the current stream @@ -612,21 +616,21 @@ void all_gather( ncclComm_t comm = comms_ref[i]; #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) - NCCL_CHECK(ncclAllGather( - inputs[i].data_ptr(), - outputs[i].data_ptr(), - count, - data_type, - to_nccl_comm(comm), - stream)); + NCCL_CHECK(ncclAllGather( + inputs[i].data_ptr(), + outputs[i].data_ptr(), + count, + data_type, + to_nccl_comm(comm), + stream)); #else - NCCL_CHECK(ncclAllGather( - inputs[i].data_ptr(), - count, - data_type, - outputs[i].data_ptr(), - to_nccl_comm(comm), - stream)); + NCCL_CHECK(ncclAllGather( + inputs[i].data_ptr(), + count, + data_type, + outputs[i].data_ptr(), + to_nccl_comm(comm), + stream)); #endif } #else @@ -634,13 +638,15 @@ void all_gather( #endif } -void all2all_single_equal_split(at::Tensor& input, - at::Tensor& output, - int size, - ncclComm_t _comm, - at::cuda::CUDAStream& stream) { +void all2all_single_equal_split( + at::Tensor& input, + at::Tensor& output, + int size, + ncclComm_t _comm, + at::cuda::CUDAStream& stream) { #ifdef USE_NCCL -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 using namespace torch::cuda::nccl::detail; int numranks; @@ -648,19 +654,21 @@ void all2all_single_equal_split(at::Tensor& input, size_t count = input.numel() / size; size_t rankdiff = input.nbytes() / size; const auto* sendbuff = reinterpret_cast(input.data_ptr()); - auto* recvbuff = reinterpret_cast(output.data_ptr()); + auto* recvbuff = reinterpret_cast(output.data_ptr()); auto comm = to_nccl_comm(_comm); #if defined(USE_ROCM) && ROCM_VERSION >= 50000 - NCCL_CHECK(ncclAllToAll(sendbuff , recvbuff , count, type, comm, stream)); + NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream)); #else NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclGroupStart()); - for(const auto r : c10::irange(numranks)) { + for (const auto r : c10::irange(numranks)) { // NCCL uses 0 byte message for synchronization // Avoid send/recv when message size is zero if (count != 0) { - NCCL_CHECK(ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream)); - NCCL_CHECK(ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); + NCCL_CHECK( + ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream)); + NCCL_CHECK( + ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream)); } } NCCL_CHECK(ncclGroupEnd()); @@ -685,7 +693,8 @@ void all2all_single_unequal_split( ncclComm_t _comm, at::cuda::CUDAStream& stream) { #ifdef USE_NCCL -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 using namespace torch::cuda::nccl::detail; auto type = to_nccl_data_type(_type); @@ -693,7 +702,7 @@ void all2all_single_unequal_split( int numranks; NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclGroupStart()); - for(const auto r : c10::irange(numranks)) { + for (const auto r : c10::irange(numranks)) { // NCCL uses 0 byte message for synchronization // Avoid send/recv when message size is zero if (sendcounts[r] != 0) { @@ -724,19 +733,21 @@ void all2all_single_unequal_split( #endif } -void all2all(std::vector& outputTensors, - std::vector& inputTensors, - ncclComm_t _comm, - at::cuda::CUDAStream& stream) { +void all2all( + std::vector& outputTensors, + std::vector& inputTensors, + ncclComm_t _comm, + at::cuda::CUDAStream& stream) { #ifdef USE_NCCL -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); NCCL_CHECK(ncclGroupStart()); - for(const auto r : c10::irange(outputTensors.size())) { - at::Tensor &input = inputTensors[r]; - at::Tensor &output = outputTensors[r]; + for (const auto r : c10::irange(outputTensors.size())) { + at::Tensor& input = inputTensors[r]; + at::Tensor& output = outputTensors[r]; if (input.numel() != 0) { NCCL_CHECK(ncclSend( input.data_ptr(), @@ -813,7 +824,6 @@ void recv( #endif } - void gather( const at::Tensor& inputs, std::vector& outputs, @@ -821,7 +831,8 @@ void gather( at::cuda::CUDAStream& stream, int32_t root) { #ifdef USE_NCCL -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); @@ -835,11 +846,10 @@ void gather( NCCL_CHECK(ncclGroupStart()); - if (cur_rank == root) - { + if (cur_rank == root) { for (const auto r : c10::irange(numranks)) { if (r != root) { - auto* recvbuff = reinterpret_cast(outputs[r].data_ptr()); + auto* recvbuff = reinterpret_cast(outputs[r].data_ptr()); NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream)); } else { // on its own rank, simply copy from the input @@ -866,7 +876,8 @@ void scatter( at::cuda::CUDAStream& stream, int32_t root) { #ifdef USE_NCCL -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 +#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27 using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); @@ -875,13 +886,12 @@ void scatter( NCCL_CHECK(ncclCommUserRank(comm, &cur_rank)); NCCL_CHECK(ncclGroupStart()); - if (cur_rank == root) - { + if (cur_rank == root) { for (const auto r : c10::irange(numranks)) { if (r != root) { size_t send_count = inputs[r].numel(); auto send_type = to_nccl_data_type(inputs[r]); - const auto* sendbuff = reinterpret_cast(inputs[r].data_ptr()); + const auto* sendbuff = reinterpret_cast(inputs[r].data_ptr()); NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream)); } else { // on its own rank, simply copy it to the output @@ -904,7 +914,6 @@ void scatter( #endif } - } // namespace nccl } // namespace cuda } // namespace torch diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index e383bce486f4b1..076ba5cbcddbea 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -8,7 +8,8 @@ #include #include -// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for HIP 3.1+ +// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for +// HIP 3.1+ #if defined(__CUDA_BF16_TYPES_EXIST__) #define HAS_NCCL_BF16_DATATYPE \ ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) @@ -22,52 +23,56 @@ namespace torch { namespace cuda { namespace nccl { -/* The following are copied from and redefined in torch::cuda::nccl namespace */ +/* The following are copied from and redefined in torch::cuda::nccl + * namespace */ /* pytorch should only use the following definition within pytorch scope */ -/* Opaque handle to communicator to ncclComm*, this will reinterpret as ncclComm in nccl.cpp */ +/* Opaque handle to communicator to ncclComm*, this will reinterpret as ncclComm + * in nccl.cpp */ typedef void* ncclComm_t; -/** redefine nccl unique ID in torch scope. this should be identical to native nccl impp. */ +/** redefine nccl unique ID in torch scope. this should be identical to native + * nccl impp. */ #define NCCL_UNIQUE_ID_BYTES 128 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; +typedef struct { + char internal[NCCL_UNIQUE_ID_BYTES]; +} ncclUniqueId; /* Error type */ enum class ncclResult { - Success = 0, - UnhandledCudaError = 1, - SystemError = 2, - InternalError = 3, - InvalidArgument = 4, - InvalidUsage = 5, - NumResults = 6 }; + Success = 0, + UnhandledCudaError = 1, + SystemError = 2, + InternalError = 3, + InvalidArgument = 4, + InvalidUsage = 5, + NumResults = 6 +}; /* Reduction operation selector */ -enum class ncclRedOp { - Sum = 0, - Prod = 1, - Max = 2, - Min = 3, - NumOps = 4 }; +enum class ncclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3, NumOps = 4 }; /* Data types */ enum class ncclDataType { - Int8 = 0, Char = 0, - Uint8 = 1, - Int32 = 2, Int = 2, - Uint32 = 3, - Int64 = 4, - Uint64 = 5, - Float16 = 6, Half = 6, - Float32 = 7, Float = 7, - Float64 = 8, Double = 8, - Bfloat16 = 9, - NumTypes = 10 + Int8 = 0, + Char = 0, + Uint8 = 1, + Int32 = 2, + Int = 2, + Uint32 = 3, + Int64 = 4, + Uint64 = 5, + Float16 = 6, + Half = 6, + Float32 = 7, + Float = 7, + Float64 = 8, + Double = 8, + Bfloat16 = 9, + NumTypes = 10 }; - - // NOTE: this is exposed only so that python_nccl.cpp can some of these helpers. // Don't use them outside of these files. namespace detail { diff --git a/torch/csrc/cuda/override_macros.h b/torch/csrc/cuda/override_macros.h index 96b0c9fa99f6a7..ffd1efd26e46ad 100644 --- a/torch/csrc/cuda/override_macros.h +++ b/torch/csrc/cuda/override_macros.h @@ -5,8 +5,9 @@ #define THWTensor THCTensor #define THWTensor_(NAME) THCTensor_(NAME) -#define THPTensor_(NAME) TH_CONCAT_4(THCP,Real,Tensor_,NAME) -#define THPTensor_stateless_(NAME) TH_CONCAT_4(THCP,Real,Tensor_stateless_,NAME) +#define THPTensor_(NAME) TH_CONCAT_4(THCP, Real, Tensor_, NAME) +#define THPTensor_stateless_(NAME) \ + TH_CONCAT_4(THCP, Real, Tensor_stateless_, NAME) #define THPTensor THCPTensor #define THPTensorStr THCPTensorStr #define THPTensorBaseStr THCPTensorBaseStr @@ -16,11 +17,11 @@ #define THPTensorStatelessType THCPTensorStatelessType #define THPTensorStateless THCPTensorStateless - #define THSPTensorPtr THCSPTensorPtr -#define THSPTensor_(NAME) TH_CONCAT_4(THCSP,Real,Tensor_,NAME) -#define THSPTensor_stateless_(NAME) TH_CONCAT_4(THCSP,Real,Tensor_stateless_,NAME) +#define THSPTensor_(NAME) TH_CONCAT_4(THCSP, Real, Tensor_, NAME) +#define THSPTensor_stateless_(NAME) \ + TH_CONCAT_4(THCSP, Real, Tensor_stateless_, NAME) #define THSPTensor THCSPTensor #define THSPTensorStr THCSPTensorStr #define THSPTensorBaseStr THCSPTensorBaseStr @@ -30,5 +31,4 @@ #define THSPTensorStatelessType THCSPTensorStatelessType #define THSPTensorStateless THCSPTensorStateless - #define TH_GENERIC_FILE THC_GENERIC_FILE diff --git a/torch/csrc/cuda/python_comm.cpp b/torch/csrc/cuda/python_comm.cpp index d2296a13181022..5c81daf732e212 100644 --- a/torch/csrc/cuda/python_comm.cpp +++ b/torch/csrc/cuda/python_comm.cpp @@ -1,17 +1,19 @@ -#include -#include +#include +#include #include #include -#include -#include +#include +#include #include #include #include -namespace torch { namespace cuda { namespace python { -void initCommMethods(PyObject *module) { +namespace torch { +namespace cuda { +namespace python { +void initCommMethods(PyObject* module) { auto m = py::cast(module); m.def( "_broadcast_coalesced", @@ -47,7 +49,8 @@ void initCommMethods(PyObject *module) { c10::optional> chunk_sizes, int64_t dim, c10::optional py_streams) { - c10::optional>> streams; + c10::optional>> + streams; if (py_streams) { py::handle handle = *py_streams; streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr()); @@ -67,7 +70,8 @@ void initCommMethods(PyObject *module) { std::vector& out_tensors, int64_t dim, c10::optional py_streams) { - c10::optional>> streams; + c10::optional>> + streams; if (py_streams) { py::handle handle = *py_streams; streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr()); @@ -95,12 +99,12 @@ void initCommMethods(PyObject *module) { "_gather_out", [](std::vector& tensors, at::Tensor& out_tensor, - int64_t dim) { - return gather_out(tensors, out_tensor, dim); - }, + int64_t dim) { return gather_out(tensors, out_tensor, dim); }, py::arg("tensors"), py::arg("out"), py::arg("dim"), py::call_guard()); } -}}} +} // namespace python +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/python_comm.h b/torch/csrc/cuda/python_comm.h index 519a11448cab27..35cefe5f2ca269 100644 --- a/torch/csrc/cuda/python_comm.h +++ b/torch/csrc/cuda/python_comm.h @@ -1,7 +1,11 @@ #pragma once -namespace torch { namespace cuda { namespace python { +namespace torch { +namespace cuda { +namespace python { -void initCommMethods(PyObject *module); +void initCommMethods(PyObject* module); -}}} +} +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 4fb841df700047..0eab79dc7a800c 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -1,13 +1,13 @@ #include +#include #include -#include #include #include #include #include #include -#include +#include #include #include @@ -52,7 +52,9 @@ static void destroy_nccl_comm(PyObject* capsule) { END_HANDLE_TH_ERRORS_RET() } -static std::vector> unpack_streams(PyObject* obj, size_t size) { +static std::vector> unpack_streams( + PyObject* obj, + size_t size) { if (obj == Py_None) { return std::vector>(size, c10::nullopt); } @@ -80,7 +82,7 @@ static std::vector unpack_comms(PyObject* obj, size_t size) { throw python_error(); auto size = PySequence_Fast_GET_SIZE(seq.get()); comms = std::vector(size); - for(const auto i : c10::irange(size)) { + for (const auto i : c10::irange(size)) { comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i)); } } @@ -125,14 +127,7 @@ PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) { int root, op; if (!PyArg_ParseTuple( - args, - "OOiiOO", - &_inputs, - &_output, - &root, - &op, - &_streams, - &_comms)) { + args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) { THPUtils_invalidArguments( args, nullptr, @@ -145,7 +140,8 @@ PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) { std::vector inputs = extract_tensors(_inputs); auto output = extract_tensor(_output); - std::vector> streams = unpack_streams(_streams, inputs.size()); + std::vector> streams = + unpack_streams(_streams, inputs.size()); auto user_comms = unpack_comms(_comms, inputs.size()); { @@ -283,16 +279,14 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) { END_HANDLE_TH_ERRORS } -static inline -at::Tensor extract_tensor(PyObject* obj) { +static inline at::Tensor extract_tensor(PyObject* obj) { if (!THPVariable_Check(obj)) { throw torch::TypeError("expected Tensor (got %s)", Py_TYPE(obj)->tp_name); } return THPVariable_Unpack(obj); } -static inline -std::vector extract_tensors(PyObject* obj) { +static inline std::vector extract_tensors(PyObject* obj) { auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence")); if (!seq) throw python_error(); diff --git a/torch/csrc/cuda/python_nccl.h b/torch/csrc/cuda/python_nccl.h index 0698f3a1e5d843..9932d659fb40fe 100644 --- a/torch/csrc/cuda/python_nccl.h +++ b/torch/csrc/cuda/python_nccl.h @@ -2,11 +2,11 @@ #include -PyObject * THCPModule_nccl_version(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_unique_id(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_init_rank(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_reduce(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_all_reduce(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_broadcast(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_all_gather(PyObject *self, PyObject *args); -PyObject * THCPModule_nccl_reduce_scatter(PyObject *self, PyObject *args); +PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args); +PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args); diff --git a/torch/csrc/cuda/restore_macros.h b/torch/csrc/cuda/restore_macros.h index 0dd5d8b56dee37..e76c600f12e5ee 100644 --- a/torch/csrc/cuda/restore_macros.h +++ b/torch/csrc/cuda/restore_macros.h @@ -1,12 +1,12 @@ #include -#define THWTensor TH_CONCAT_3(TH,Real,Tensor) -#define THWTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME) +#define THWTensor TH_CONCAT_3(TH, Real, Tensor) +#define THWTensor_(NAME) TH_CONCAT_4(TH, Real, Tensor_, NAME) -#define THPTensor TH_CONCAT_3(THP,Real,Tensor) -#define THPTensorStr TH_CONCAT_STRING_3(torch.,Real,Tensor) -#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass) -#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME) +#define THPTensor TH_CONCAT_3(THP, Real, Tensor) +#define THPTensorStr TH_CONCAT_STRING_3(torch., Real, Tensor) +#define THPTensorClass TH_CONCAT_3(THP, Real, TensorClass) +#define THPTensor_(NAME) TH_CONCAT_4(THP, Real, Tensor_, NAME) -#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr) -#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr) +#define THWTensorPtr TH_CONCAT_3(TH, Real, TensorPtr) +#define THPTensorPtr TH_CONCAT_3(THP, Real, TensorPtr) diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp index b0af4c0884e91b..8fafd1ccdc5cf7 100644 --- a/torch/csrc/cuda/shared/cudart.cpp +++ b/torch/csrc/cuda/shared/cudart.cpp @@ -1,16 +1,18 @@ -#include #include #include +#include #if !defined(USE_ROCM) #include #else #include #endif -#include #include +#include -namespace torch { namespace cuda { namespace shared { +namespace torch { +namespace cuda { +namespace shared { void initCudartBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -21,39 +23,72 @@ void initCudartBindings(PyObject* module) { // HIP rewrite rules from changing these names when building with HIP. #if !defined(USE_ROCM) - py::enum_(cudart, "cuda" "OutputMode") + py::enum_( + cudart, + "cuda" + "OutputMode") .value("KeyValuePair", cudaKeyValuePair) .value("CSV", cudaCSV); #endif - py::enum_(cudart, "cuda" "Error") + py::enum_( + cudart, + "cuda" + "Error") .value("success", cudaSuccess); - cudart.def("cuda" "GetErrorString", cudaGetErrorString); - cudart.def("cuda" "ProfilerStart", cudaProfilerStart); - cudart.def("cuda" "ProfilerStop", cudaProfilerStop); - cudart.def("cuda" "HostRegister", [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t { - return cudaHostRegister((void*)ptr, size, flags); - }); - cudart.def("cuda" "HostUnregister", [](uintptr_t ptr) -> cudaError_t { - return cudaHostUnregister((void*)ptr); - }); - cudart.def("cuda" "StreamCreate", [](uintptr_t ptr) -> cudaError_t { - return cudaStreamCreate((cudaStream_t*)ptr); - }); - cudart.def("cuda" "StreamDestroy", [](uintptr_t ptr) -> cudaError_t { - return cudaStreamDestroy((cudaStream_t)ptr); - }); + cudart.def( + "cuda" + "GetErrorString", + cudaGetErrorString); + cudart.def( + "cuda" + "ProfilerStart", + cudaProfilerStart); + cudart.def( + "cuda" + "ProfilerStop", + cudaProfilerStop); + cudart.def( + "cuda" + "HostRegister", + [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t { + return cudaHostRegister((void*)ptr, size, flags); + }); + cudart.def( + "cuda" + "HostUnregister", + [](uintptr_t ptr) -> cudaError_t { + return cudaHostUnregister((void*)ptr); + }); + cudart.def( + "cuda" + "StreamCreate", + [](uintptr_t ptr) -> cudaError_t { + return cudaStreamCreate((cudaStream_t*)ptr); + }); + cudart.def( + "cuda" + "StreamDestroy", + [](uintptr_t ptr) -> cudaError_t { + return cudaStreamDestroy((cudaStream_t)ptr); + }); #if !defined(USE_ROCM) - cudart.def("cuda" "ProfilerInitialize", cudaProfilerInitialize); + cudart.def( + "cuda" + "ProfilerInitialize", + cudaProfilerInitialize); #endif - cudart.def("cuda" "MemGetInfo", [](int device) -> std::pair { - c10::cuda::CUDAGuard guard(device); - size_t device_free = 0; - size_t device_total = 0; - cudaMemGetInfo(&device_free, &device_total); - return {device_free, device_total}; - }); + cudart.def( + "cuda" + "MemGetInfo", + [](int device) -> std::pair { + c10::cuda::CUDAGuard guard(device); + size_t device_free = 0; + size_t device_total = 0; + cudaMemGetInfo(&device_free, &device_total); + return {device_free, device_total}; + }); } } // namespace shared diff --git a/torch/csrc/cuda/shared/cudnn.cpp b/torch/csrc/cuda/shared/cudnn.cpp index 594a4a19c0c7c2..f6cdc152d12685 100644 --- a/torch/csrc/cuda/shared/cudnn.cpp +++ b/torch/csrc/cuda/shared/cudnn.cpp @@ -1,5 +1,6 @@ // The clang-tidy job seems to complain that it can't find cudnn.h without this. -// This file should only be compiled if this condition holds, so it should be safe. +// This file should only be compiled if this condition holds, so it should be +// safe. #if defined(USE_CUDNN) || defined(USE_ROCM) #include @@ -39,7 +40,7 @@ size_t getVersionInt() { #endif } -} +} // namespace #elif defined(USE_ROCM) #include #include @@ -53,7 +54,8 @@ version_tuple getCompileVersion() { version_tuple getRuntimeVersion() { // MIOpen doesn't include runtime version info before 2.3.0 -#if (MIOPEN_VERSION_MAJOR > 2) || (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2) +#if (MIOPEN_VERSION_MAJOR > 2) || \ + (MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2) size_t major, minor, patch; miopenGetVersion(&major, &minor, &patch); return version_tuple(major, minor, patch); @@ -69,10 +71,12 @@ size_t getVersionInt() { return major * 1000000 + minor * 1000 + patch; } -} +} // namespace #endif -namespace torch { namespace cuda { namespace shared { +namespace torch { +namespace cuda { +namespace shared { void initCudnnBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -80,10 +84,10 @@ void initCudnnBindings(PyObject* module) { auto cudnn = m.def_submodule("_cudnn", "libcudnn.so bindings"); py::enum_(cudnn, "RNNMode") - .value("rnn_relu", CUDNN_RNN_RELU) - .value("rnn_tanh", CUDNN_RNN_TANH) - .value("lstm", CUDNN_LSTM) - .value("gru", CUDNN_GRU); + .value("rnn_relu", CUDNN_RNN_RELU) + .value("rnn_tanh", CUDNN_RNN_TANH) + .value("lstm", CUDNN_LSTM) + .value("gru", CUDNN_GRU); // The runtime version check in python needs to distinguish cudnn from miopen #ifdef USE_CUDNN diff --git a/torch/csrc/cuda/shared/nvtx.cpp b/torch/csrc/cuda/shared/nvtx.cpp index 2e24ab16dd6288..6c4584a39b48ed 100644 --- a/torch/csrc/cuda/shared/nvtx.cpp +++ b/torch/csrc/cuda/shared/nvtx.cpp @@ -1,7 +1,9 @@ -#include #include +#include -namespace torch { namespace cuda { namespace shared { +namespace torch { +namespace cuda { +namespace shared { void initNvtxBindings(PyObject* module) { auto m = py::handle(module).cast(); diff --git a/torch/csrc/cuda/utils.cpp b/torch/csrc/cuda/utils.cpp index a274f261a2bd7f..b3b33230f36391 100644 --- a/torch/csrc/cuda/utils.cpp +++ b/torch/csrc/cuda/utils.cpp @@ -1,35 +1,40 @@ +#include #include #include #include -#include #include #ifdef USE_CUDA // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use // whatever the current stream of the device the input is associated with was. -std::vector> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj) { +std::vector> +THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) { if (!PySequence_Check(obj)) { - throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_CUDAStreamList"); + throw std::runtime_error( + "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList"); } THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr)); if (seq.get() == nullptr) { - throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj))); + throw std::runtime_error( + "expected PySequence, but got " + std::string(THPUtils_typename(obj))); } std::vector> streams; Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < length; i++) { - PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i); + PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i); if (PyObject_IsInstance(stream, THCPStreamClass)) { // Spicy hot reinterpret cast!! - streams.emplace_back( at::cuda::CUDAStream::unpack((reinterpret_cast(stream))->cdata) ); + streams.emplace_back(at::cuda::CUDAStream::unpack( + (reinterpret_cast(stream))->cdata)); } else if (stream == Py_None) { streams.emplace_back(); } else { // NOLINTNEXTLINE(bugprone-throw-keyword-missing) - std::runtime_error("Unknown data type found in stream list. Need torch.cuda.Stream or None"); + std::runtime_error( + "Unknown data type found in stream list. Need torch.cuda.Stream or None"); } } return streams; diff --git a/torch/csrc/cuda/utils.h b/torch/csrc/cuda/utils.h index 51084615cec37a..fa29394b563abc 100644 --- a/torch/csrc/cuda/utils.h +++ b/torch/csrc/cuda/utils.h @@ -1,14 +1,14 @@ #ifndef THCP_UTILS_H #define THCP_UTILS_H -#define THCPUtils_(NAME) TH_CONCAT_4(THCP,Real,Utils_,NAME) +#define THCPUtils_(NAME) TH_CONCAT_4(THCP, Real, Utils_, NAME) -#define THCTensorPtr TH_CONCAT_3(THC,Real,TensorPtr) -#define THCPStoragePtr TH_CONCAT_3(THCP,Real,StoragePtr) -#define THCPTensorPtr TH_CONCAT_3(THCP,Real,TensorPtr) +#define THCTensorPtr TH_CONCAT_3(THC, Real, TensorPtr) +#define THCPStoragePtr TH_CONCAT_3(THCP, Real, StoragePtr) +#define THCPTensorPtr TH_CONCAT_3(THCP, Real, TensorPtr) -#define THCSTensorPtr TH_CONCAT_3(THCS,Real,TensorPtr) -#define THCSPTensorPtr TH_CONCAT_3(THCSP,Real,TensorPtr) +#define THCSTensorPtr TH_CONCAT_3(THCS, Real, TensorPtr) +#define THCSPTensorPtr TH_CONCAT_3(THCSP, Real, TensorPtr) #include diff --git a/torch/csrc/distributed/autograd/autograd.cpp b/torch/csrc/distributed/autograd/autograd.cpp index 9c54c74d909f85..f5f0fe8153d7e7 100644 --- a/torch/csrc/distributed/autograd/autograd.cpp +++ b/torch/csrc/distributed/autograd/autograd.cpp @@ -1,5 +1,5 @@ -#include #include +#include namespace torch { namespace distributed { diff --git a/torch/csrc/distributed/autograd/context/container.cpp b/torch/csrc/distributed/autograd/context/container.cpp index 3b6ef844509f1b..fcfa90b1f13a8b 100644 --- a/torch/csrc/distributed/autograd/context/container.cpp +++ b/torch/csrc/distributed/autograd/context/container.cpp @@ -246,18 +246,17 @@ void DistAutogradContainer::sendReleaseContextRpc( CleanupAutogradContextReq(context_id).toMessage(), options); - cleanupFuture->addCallback( - [worker_id](rpc::JitFuture& future) { - if (future.hasError()) { - std::string errorMsg = c10::str( - "Could not release Dist Autograd Context on node ", - worker_id, - ": ", - future.tryRetrieveErrorMessage()); - LOG(ERROR) << errorMsg; - return; - } - }); + cleanupFuture->addCallback([worker_id](rpc::JitFuture& future) { + if (future.hasError()) { + std::string errorMsg = c10::str( + "Could not release Dist Autograd Context on node ", + worker_id, + ": ", + future.tryRetrieveErrorMessage()); + LOG(ERROR) << errorMsg; + return; + } + }); } catch (const std::exception& e) { LOG(INFO) << "Failed to send RPC to clear Dist Autograd context to worker id: " diff --git a/torch/csrc/distributed/autograd/context/context.h b/torch/csrc/distributed/autograd/context/context.h index 2282af71f97a16..bf8bb7cdef8c06 100644 --- a/torch/csrc/distributed/autograd/context/context.h +++ b/torch/csrc/distributed/autograd/context/context.h @@ -51,8 +51,7 @@ class TORCH_API DistAutogradContext { const; // Adds a future message recording an outstanding RPC. - void addOutstandingRpc( - const c10::intrusive_ptr& jitFuture); + void addOutstandingRpc(const c10::intrusive_ptr& jitFuture); // Returns all gradients. const c10::Dict getGradients() const; diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index e6522c33280a92..2da315644845cf 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -2,12 +2,12 @@ #include #include +#include #include #include #include #include #include -#include namespace torch { namespace distributed { @@ -98,7 +98,7 @@ void DistEngine::globalCpuThread( variables = InputBuffer::variables(std::move(task.inputs_))]() mutable { InputBuffer inputs(variables.size()); - for(const auto i : c10::irange(variables.size())) { + for (const auto i : c10::irange(variables.size())) { inputs.add(i, std::move(variables[i]), c10::nullopt, c10::nullopt); } execute_graph_task_until_ready_queue_empty( @@ -296,7 +296,8 @@ void DistEngine::computeDependencies( // Create a dummy GraphRoot and run init_to_execute with it. GraphRoot dummyRoot(edges, {}); - graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0); + graphTask->init_to_execute( + dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0); for (auto& mapEntry : graphTask->exec_info_) { auto& execInfo = mapEntry.second; if (!execInfo.captures_) { @@ -411,32 +412,31 @@ c10::intrusive_ptr DistEngine:: auto accumulateGradFuture = c10::make_intrusive(c10::NoneType::get()); - futureGrads->addCallback( - [autogradContext, outputEdges, accumulateGradFuture](c10::ivalue::Future& futureGrads) { - if (futureGrads.hasError()) { - // Don't accumulate gradients if we receive an error. - // We must add the node information here since DistEngine::execute - // waits on accumulateGradFuture and will throw an exception once we - // set the error below. - std::string errorMsg = c10::str( - "Error on Node ", - DistAutogradContainer::getInstance().getWorkerId(), - ": ", - futureGrads.tryRetrieveErrorMessage()); - accumulateGradFuture->setError(std::make_exception_ptr( - c10::ivalue::Future::FutureError(std::move(errorMsg)))); - return; - } + futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture]( + c10::ivalue::Future& futureGrads) { + if (futureGrads.hasError()) { + // Don't accumulate gradients if we receive an error. + // We must add the node information here since DistEngine::execute + // waits on accumulateGradFuture and will throw an exception once we + // set the error below. + std::string errorMsg = c10::str( + "Error on Node ", + DistAutogradContainer::getInstance().getWorkerId(), + ": ", + futureGrads.tryRetrieveErrorMessage()); + accumulateGradFuture->setError(std::make_exception_ptr( + c10::ivalue::Future::FutureError(std::move(errorMsg)))); + return; + } - try { - const variable_list& grads = - futureGrads.constValue().toTensorVector(); - TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); - accumulateGradFuture->markCompleted(c10::IValue()); - } catch (std::exception& e) { - accumulateGradFuture->setErrorIfNeeded(std::current_exception()); - } - }); + try { + const variable_list& grads = futureGrads.constValue().toTensorVector(); + TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); + accumulateGradFuture->markCompleted(c10::IValue()); + } catch (std::exception& e) { + accumulateGradFuture->setErrorIfNeeded(std::current_exception()); + } + }); return accumulateGradFuture; } @@ -445,22 +445,22 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( const ContextPtr& autogradContext, const std::shared_ptr& sendFunction, bool retainGraph) { - // Typically the local autograd engine ensures stream synchronizations between // nodes in the graph. However, for distributed autograd the sendFunction // inputs might have been retrieved over the wire on a separate stream and the // sendFunction itself runs on a different stream. As a result, we need to // manually synchronize those two streams here. - const auto& send_backward_stream = sendFunction->stream(c10::DeviceType::CUDA); + const auto& send_backward_stream = + sendFunction->stream(c10::DeviceType::CUDA); if (send_backward_stream) { for (const auto& grad : sendFunction->getGrads()) { - const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; - const auto default_stream = guard.getStream(grad.device()); - if (send_backward_stream != default_stream) { - auto event = c10::Event{c10::DeviceType::CUDA}; - event.record(default_stream); - send_backward_stream->wait(event); - } + const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; + const auto default_stream = guard.getStream(grad.device()); + if (send_backward_stream != default_stream) { + auto event = c10::Event{c10::DeviceType::CUDA}; + event.record(default_stream); + send_backward_stream->wait(event); + } } } @@ -490,42 +490,45 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( auto callbackFuture = c10::make_intrusive(c10::NoneType::get()); - accumulateGradFuture->addCallback([autogradContext, - callbackFuture](c10::ivalue::Future& accumulateGradFuture) { - try { - if (accumulateGradFuture.hasError()) { - // Perform cleanup at the end of the backward pass (before we mark - // the future as completed). - DistEngine::getInstance().cleanupBackwardPass(autogradContext); - - // Skip any further processing on errors. - callbackFuture->setError(accumulateGradFuture.exception_ptr()); - return; - } - - // Wait for all RPCs after the autograd engine is done. - auto rpcFuture = autogradContext->clearAndWaitForOutstandingRpcsAsync(); - rpcFuture->addCallback([callbackFuture, autogradContext](c10::ivalue::Future& rpcFuture) { + accumulateGradFuture->addCallback( + [autogradContext, + callbackFuture](c10::ivalue::Future& accumulateGradFuture) { try { - // Perform cleanup at the end of the backward pass (before - // we mark the future as completed). - DistEngine::getInstance().cleanupBackwardPass(autogradContext); + if (accumulateGradFuture.hasError()) { + // Perform cleanup at the end of the backward pass (before we mark + // the future as completed). + DistEngine::getInstance().cleanupBackwardPass(autogradContext); + + // Skip any further processing on errors. + callbackFuture->setError(accumulateGradFuture.exception_ptr()); + return; + } + + // Wait for all RPCs after the autograd engine is done. + auto rpcFuture = + autogradContext->clearAndWaitForOutstandingRpcsAsync(); + rpcFuture->addCallback([callbackFuture, autogradContext]( + c10::ivalue::Future& rpcFuture) { + try { + // Perform cleanup at the end of the backward pass (before + // we mark the future as completed). + DistEngine::getInstance().cleanupBackwardPass(autogradContext); + } catch (std::exception& e) { + callbackFuture->setErrorIfNeeded(std::current_exception()); + return; + } + + // Finally mark the 'uber' future as completed. + if (!rpcFuture.hasError()) { + callbackFuture->markCompleted(c10::IValue()); + } else { + callbackFuture->setError(rpcFuture.exception_ptr()); + } + }); } catch (std::exception& e) { callbackFuture->setErrorIfNeeded(std::current_exception()); - return; - } - - // Finally mark the 'uber' future as completed. - if (!rpcFuture.hasError()) { - callbackFuture->markCompleted(c10::IValue()); - } else { - callbackFuture->setError(rpcFuture.exception_ptr()); } }); - } catch (std::exception& e) { - callbackFuture->setErrorIfNeeded(std::current_exception()); - } - }); // Return the future which waits for all async processing to be done. return callbackFuture; diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.h b/torch/csrc/distributed/autograd/engine/dist_engine.h index ac1c66ec438028..f8102a796595b7 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.h +++ b/torch/csrc/distributed/autograd/engine/dist_engine.h @@ -60,6 +60,7 @@ class TORCH_API DistEngine { DistEngine& operator=(const DistEngine&) = delete; DistEngine(DistEngine&&) = delete; DistEngine& operator=(DistEngine&&) = delete; + private: // Make sure this is a singleton. DistEngine(); diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index a492d9847fb371..62cfe0a4f8474b 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -24,7 +24,7 @@ RecvRpcBackward::RecvRpcBackward( variable_list RecvRpcBackward::apply(variable_list&& grads) { std::vector outputGrads; - for(const auto i : c10::irange(grads.size())) { + for (const auto i : c10::irange(grads.size())) { const auto& grad = grads[i]; if (grad.defined()) { outputGrads.emplace_back(grad); diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp index 036f4329108a84..102b5cc9fc38b6 100644 --- a/torch/csrc/distributed/autograd/init.cpp +++ b/torch/csrc/distributed/autograd/init.cpp @@ -28,7 +28,8 @@ PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) { } auto torch_C_m = py::handle(torch_C_module).cast(); - auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings"); + auto m = torch_C_m.def_submodule( + "_distributed_autograd", "distributed autograd bindings"); auto module = py::handle(m).cast(); @@ -225,10 +226,7 @@ Example:: } // namespace static PyMethodDef methods[] = { // NOLINT - {"_dist_autograd_init", - dist_autograd_init, - METH_NOARGS, - nullptr}, + {"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp index 17d4a72dc22f04..bab28835e57c3b 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp @@ -4,7 +4,8 @@ namespace torch { namespace distributed { namespace autograd { -c10::intrusive_ptr CleanupAutogradContextResp::toMessageImpl() && { +c10::intrusive_ptr CleanupAutogradContextResp:: + toMessageImpl() && { std::vector tensors; std::vector payload; return c10::make_intrusive( diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp index a2a742ec5f7f97..aaf5df632d2cf5 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp @@ -73,12 +73,12 @@ std::unique_ptr PropagateGradientsReq::fromMessage( // Retrieve the gradient tensors. std::vector grads(tupleElements.size() - 3); - for(const auto i : c10::irange(tupleElements.size() - 3)) { + for (const auto i : c10::irange(tupleElements.size() - 3)) { grads[i] = tupleElements[i].toTensor(); } return std::make_unique( - autogradMetadata, grads, retainGraph); + autogradMetadata, grads, retainGraph); } const AutogradMetadata& PropagateGradientsReq::getAutogradMetadata() { diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index b8d28f7be7c2d9..31ab867d506a45 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include #include @@ -66,11 +66,12 @@ c10::intrusive_ptr RpcWithAutograd::toMessageImpl() && { deviceMap.insert(mapEntry.first.str(), mapEntry.second.str()); } - std::vector ivalues{wrappedMessageType, - autogradMetadata_.autogradContextId, - autogradMetadata_.autogradMessageId, - fromWorkerId_, - deviceMap}; + std::vector ivalues{ + wrappedMessageType, + autogradMetadata_.autogradContextId, + autogradMetadata_.autogradMessageId, + fromWorkerId_, + deviceMap}; // Now pickle using JIT pickler. std::vector tensorTable; @@ -109,7 +110,8 @@ std::unique_ptr RpcWithAutograd::fromMessage( AutogradMetadata autogradMetadata( tupleElements[1].toInt(), tupleElements[2].toInt()); worker_id_t workerId = tupleElements[3].toInt(); - auto c10DeviceMap = tupleElements[4].to>(); + auto c10DeviceMap = + tupleElements[4].to>(); // Convert to regular map. rpc::DeviceMap deviceMap; @@ -169,8 +171,7 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const { return fromWorkerId_; } -const rpc::DeviceMap& RpcWithAutograd:: - deviceMap() { +const rpc::DeviceMap& RpcWithAutograd::deviceMap() { return deviceMap_; } diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index eb9cc51d911b4b..f67b93bd12661c 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include #include @@ -160,27 +160,27 @@ c10::intrusive_ptr sendMessageWithAutograd( // tell the remote end to process this request with the profiler enabled. if (!forceDisableProfiling) { switch (torch::profiler::impl::profilerType()) { - case torch::profiler::impl::ActiveProfilerType::LEGACY: - { - auto profilerConfig = torch::autograd::profiler::getProfilerConfig(); - auto msgWithProfiling = getMessageWithProfiling( + case torch::profiler::impl::ActiveProfilerType::LEGACY: { + auto profilerConfig = torch::autograd::profiler::getProfilerConfig(); + auto msgWithProfiling = getMessageWithProfiling( std::move(msg), rpc::MessageType::RUN_WITH_PROFILING_REQ, // NOLINTNEXTLINE(performance-move-const-arg) std::move(profilerConfig)); - return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds); - } + return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds); + } case torch::profiler::impl::ActiveProfilerType::KINETO: TORCH_WARN_ONCE( - "Profiling a distributed call with the Kineto profiler will profile " - "the caller, but not the worker."); + "Profiling a distributed call with the Kineto profiler will profile " + "the caller, but not the worker."); break; default: break; } } - return agent.send(dst, std::move(msg), rpcTimeoutSeconds);; + return agent.send(dst, std::move(msg), rpcTimeoutSeconds); + ; } } // namespace autograd diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 94883ce6052695..8c77ae34e30360 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -44,12 +44,10 @@ TORCH_API c10::intrusive_ptr getMessageWithAutograd( c10::intrusive_ptr wrappedRpcMsg, rpc::MessageType msgType, bool forceGradRecording = false, - const rpc::DeviceMap& deviceMap = - {}); + const rpc::DeviceMap& deviceMap = {}); // Send message after autograd checking -TORCH_API c10::intrusive_ptr -sendMessageWithAutograd( +TORCH_API c10::intrusive_ptr sendMessageWithAutograd( rpc::RpcAgent& agent, const rpc::WorkerInfo& dst, c10::intrusive_ptr wrappedRpcMsg, diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index dd0192778c2df0..279ca57ab5072c 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -358,8 +358,13 @@ std::vector FileStore::get(const std::string& key) { const auto elapsed = std::chrono::duration_cast( std::chrono::steady_clock::now() - start); if (timeout_ != kNoTimeout && elapsed > timeout_) { - auto err = c10::str("Timeout waiting for key: ", key, " after ", timeout_.count(), " ms"); - TORCH_CHECK(false, err); + auto err = c10::str( + "Timeout waiting for key: ", + key, + " after ", + timeout_.count(), + " ms"); + TORCH_CHECK(false, err); } std::this_thread::sleep_for(std::chrono::milliseconds(10)); continue; diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index cb83a998385200..9c55d047084feb 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -67,7 +67,7 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice); #if GLOO_HAVE_TRANSPORT_TCP_TLS static std::string cstr_to_std_string(const char* chars) { - return std::string (chars != nullptr ? chars : ""); + return std::string(chars != nullptr ? chars : ""); } static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( @@ -84,11 +84,16 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( } else { attr.hostname = hostname; } - const auto pkey = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY")); - const auto cert = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT")); - const auto caFile = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE")); - const auto caPath = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH")); - return ::gloo::transport::tcp::tls::CreateDevice(attr, pkey, cert, caFile, caPath); + const auto pkey = + cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY")); + const auto cert = + cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT")); + const auto caFile = + cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE")); + const auto caPath = + cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH")); + return ::gloo::transport::tcp::tls::CreateDevice( + attr, pkey, cert, caFile, caPath); } C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice); @@ -121,9 +126,9 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice); #endif namespace { -std::shared_ptr<::gloo::transport::Device> -makeGlooDevice(const std::string& interfaceName, const std::string& hostName) -{ +std::shared_ptr<::gloo::transport::Device> makeGlooDevice( + const std::string& interfaceName, + const std::string& hostName) { static auto transportName = getenv("GLOO_DEVICE_TRANSPORT"); if (transportName) { return GlooDeviceRegistry()->Create(transportName, interfaceName, hostName); diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 568c23ef7a2059..c262cedc9da8b0 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -6,7 +6,6 @@ namespace c10d { - ncclComm_t NCCLComm::getNcclComm() { std::unique_lock lock(mutex_); if (aborted_) { @@ -37,11 +36,12 @@ std::string getNcclVersion() { versionString = "Unknown NCCL version"; } else { // NCCL changed version coding starting 2.9 - const int majorBase = version < 2900 ? 1000 : 10000; + const int majorBase = version < 2900 ? 1000 : 10000; const int minorBase = 100; auto ncclMajor = version / majorBase; auto ncclMinor = (version % majorBase) / minorBase; - auto ncclPatch = version % (ncclMajor * majorBase + ncclMinor * minorBase); + auto ncclPatch = + version % (ncclMajor * majorBase + ncclMinor * minorBase); versionString = std::to_string(ncclMajor) + "." + std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); } diff --git a/torch/csrc/distributed/c10d/ParamCommsUtils.cpp b/torch/csrc/distributed/c10d/ParamCommsUtils.cpp index 2a123db85dc552..eaa4649f7bb8b0 100644 --- a/torch/csrc/distributed/c10d/ParamCommsUtils.cpp +++ b/torch/csrc/distributed/c10d/ParamCommsUtils.cpp @@ -13,8 +13,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo( int outSize, at::ScalarType dType, std::vector inSplitSizes, - std::vector outSplitSizes) : - rank_(rank), + std::vector outSplitSizes) + : rank_(rank), columnName_(colName), inMessageSize_(inSize), outMessageSize_(outSize), diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index ff3bb7cf074845..f31a7d8c8153b1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -51,7 +51,8 @@ std::string opTypeToString(OpType opType) { } bool isP2POp(OpType opType, bool batchP2P /*= false*/) { - if (batchP2P) return false; + if (batchP2P) + return false; return opType == OpType::SEND || opType == OpType::RECV || opType == OpType::RECVANYSOURCE; } @@ -78,7 +79,9 @@ ProcessGroup::Work::Work( inputs.emplace_back(tensor); } } - recordingFunction->before(profilingTitle, c10::ArrayRef(inputs.data(), inputs.size())); + recordingFunction->before( + profilingTitle, + c10::ArrayRef(inputs.data(), inputs.size())); std::function end_handler = [recordingFunction]() { recordingFunction->end(); }; @@ -91,7 +94,7 @@ OpType ProcessGroup::Work::retrieveOpType() { return opType_; } -ProcessGroup::Work::~Work()=default; +ProcessGroup::Work::~Work() = default; bool ProcessGroup::Work::isCompleted() { std::lock_guard lock(mutex_); @@ -109,7 +112,8 @@ std::exception_ptr ProcessGroup::Work::exception() const { } int ProcessGroup::Work::sourceRank() const { - TORCH_CHECK(false, + TORCH_CHECK( + false, "sourceRank() may only be called on work objects " "that correspond to a recv or recv-from-any call."); } @@ -183,7 +187,8 @@ ProcessGroup::ProcessGroup(int rank, int size) ProcessGroup::~ProcessGroup() {} void ProcessGroup::init() { - C10_LOG_API_USAGE_ONCE(fmt::format("c10d.process_group_{}", getBackendName())); + C10_LOG_API_USAGE_ONCE( + fmt::format("c10d.process_group_{}", getBackendName())); } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index f2b553ba1cc8eb..1c1ac652f6ed2a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -43,59 +43,59 @@ #include #ifdef _WIN32 -#define GENERATE_ALL_TYPES(type, func, ...) \ - switch (type) { \ - case ::at::ScalarType::Float: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Double: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Half: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Char: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Byte: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Int: \ - func(__VA_ARGS__); \ - break; \ - case ::at::ScalarType::Long: \ - func(__VA_ARGS__); \ - break; \ - default: \ +#define GENERATE_ALL_TYPES(type, func, ...) \ + switch (type) { \ + case ::at::ScalarType::Float: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Double: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Half: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Char: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Byte: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Int: \ + func(__VA_ARGS__); \ + break; \ + case ::at::ScalarType::Long: \ + func(__VA_ARGS__); \ + break; \ + default: \ TORCH_CHECK(false, "Invalid scalar type"); \ } #define HOST_NAME_MAX 256 #else -#define GENERATE_ALL_TYPES(type, func, args...) \ - switch (type) { \ - case ::at::ScalarType::Float: \ - func(args); \ - break; \ - case ::at::ScalarType::Double: \ - func(args); \ - break; \ - case ::at::ScalarType::Half: \ - func(args); \ - break; \ - case ::at::ScalarType::Char: \ - func(args); \ - break; \ - case ::at::ScalarType::Byte: \ - func(args); \ - break; \ - case ::at::ScalarType::Int: \ - func(args); \ - break; \ - case ::at::ScalarType::Long: \ - func(args); \ - break; \ - default: \ +#define GENERATE_ALL_TYPES(type, func, args...) \ + switch (type) { \ + case ::at::ScalarType::Float: \ + func(args); \ + break; \ + case ::at::ScalarType::Double: \ + func(args); \ + break; \ + case ::at::ScalarType::Half: \ + func(args); \ + break; \ + case ::at::ScalarType::Char: \ + func(args); \ + break; \ + case ::at::ScalarType::Byte: \ + func(args); \ + break; \ + case ::at::ScalarType::Int: \ + func(args); \ + break; \ + case ::at::ScalarType::Long: \ + func(args); \ + break; \ + default: \ TORCH_CHECK(false, "Invalid scalar type"); \ } #endif @@ -179,20 +179,16 @@ ReduceFunc toFunction(const ReduceOp& r) { case ReduceOp::MAX: return ReduceFunc(&::gloo::max); case ReduceOp::BAND: - TORCH_CHECK(false, - "Cannot use ReduceOp.BAND with non-integral dtype"); + TORCH_CHECK(false, "Cannot use ReduceOp.BAND with non-integral dtype"); break; case ReduceOp::BOR: - TORCH_CHECK(false, - "Cannot use ReduceOp.BOR with non-integral dtype"); + TORCH_CHECK(false, "Cannot use ReduceOp.BOR with non-integral dtype"); break; case ReduceOp::BXOR: - TORCH_CHECK(false, - "Cannot use ReduceOp.BXOR with non-integral dtype"); + TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with non-integral dtype"); break; case ReduceOp::AVG: - TORCH_CHECK(false, - "Cannot use ReduceOp.AVG with Gloo"); + TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo"); break; case ReduceOp::UNUSED: break; @@ -209,7 +205,7 @@ void band(void* c, const void* a, const void* b, size_t n) { auto tc = static_cast(c); auto ta = static_cast(a); auto tb = static_cast(b); - for(const auto i : c10::irange(n)) { + for (const auto i : c10::irange(n)) { tc[i] = ta[i] & tb[i]; } } @@ -222,7 +218,7 @@ void bor(void* c, const void* a, const void* b, size_t n) { auto tc = static_cast(c); auto ta = static_cast(a); auto tb = static_cast(b); - for(const auto i : c10::irange(n)) { + for (const auto i : c10::irange(n)) { tc[i] = ta[i] | tb[i]; } } @@ -235,7 +231,7 @@ void bxor(void* c, const void* a, const void* b, size_t n) { auto tc = static_cast(c); auto ta = static_cast(a); auto tb = static_cast(b); - for(const auto i : c10::irange(n)) { + for (const auto i : c10::irange(n)) { tc[i] = ta[i] ^ tb[i]; } } @@ -260,8 +256,7 @@ ReduceFunc toFunction(const ReduceOp& r) { case ReduceOp::BXOR: return ReduceFunc(&bxor); case ReduceOp::AVG: - TORCH_CHECK(false, - "Cannot use ReduceOp.AVG with Gloo"); + TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo"); break; case ReduceOp::UNUSED: break; @@ -333,7 +328,7 @@ void initializeStreamsEvents( std::vector& events) { streams.reserve(tensors.size()); events.reserve(tensors.size()); - for(const auto i : c10::irange(tensors.size())) { + for (const auto i : c10::irange(tensors.size())) { c10::Device device = tensors[i].device(); c10::impl::VirtualGuardImpl impl(device.type()); // Record event on current stream @@ -342,7 +337,8 @@ void initializeStreamsEvents( // Get a non-default stream to execute asynchronous CUDA operations // on for this device. This ensures that the default stream used // by the caller is not occupied by c10d related operations. - streams.push_back(impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); + streams.push_back( + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); @@ -350,8 +346,10 @@ void initializeStreamsEvents( // new streams in this Work to prevent being freed before the Work finishes. if (tensors[i].is_sparse()) { if (tensors[i].is_coalesced()) { - impl.recordDataPtrOnStream(tensors[i].indices().storage().data_ptr(), streams[i]); - impl.recordDataPtrOnStream(tensors[i].values().storage().data_ptr(), streams[i]); + impl.recordDataPtrOnStream( + tensors[i].indices().storage().data_ptr(), streams[i]); + impl.recordDataPtrOnStream( + tensors[i].values().storage().data_ptr(), streams[i]); } else { // We will need to coalesce first, which means new tensors will // be allocated on the streams we just allocated, and there @@ -377,7 +375,8 @@ void initializeStreamsEvents( const auto device_id = tensorgroup[0].device().index(); for (const auto& tensor : tensorgroup) { if (tensor.device().index() != device_id) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "tensors in the nested tensor vectors need to " "be on the same device"); } @@ -386,7 +385,7 @@ void initializeStreamsEvents( streams.reserve(tensors.size()); events.reserve(tensors.size()); - for(const auto i : c10::irange(tensors.size())) { + for (const auto i : c10::irange(tensors.size())) { c10::Device device = tensors[i][0].device(); c10::impl::VirtualGuardImpl impl(device.type()); // Record event on current stream @@ -395,7 +394,8 @@ void initializeStreamsEvents( // Get a non-default stream to execute asynchronous CUDA operations // on for this output. This ensures that the default stream used // by the caller is not occupied by c10d related operations. - streams.push_back(impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); + streams.push_back( + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); @@ -414,7 +414,7 @@ const auto kLoopbackAddress = "127.0.0.1"; // static void ProcessGroupGloo::AsyncWork::execute(c10::intrusive_ptr work) { - if(work->recordFunctionBeforeCallback_){ + if (work->recordFunctionBeforeCallback_) { work->recordFunctionBeforeCallback_(); } try { @@ -481,17 +481,19 @@ inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo( if (recordingFunction->isActive()) { std::function before_handler = [inputTensors, profilingTitle, recordingFunction]() { - // The work will be started and completed by different threads. - recordingFunction->_setAsync(); - std::vector inputs; - if (inputTensors) { - inputs.reserve(inputTensors->size()); - for (const auto& tensor : *inputTensors) { - inputs.emplace_back(tensor); - } - } - recordingFunction->before(profilingTitle, c10::ArrayRef(inputs.data(), inputs.size())); - }; + // The work will be started and completed by different threads. + recordingFunction->_setAsync(); + std::vector inputs; + if (inputTensors) { + inputs.reserve(inputTensors->size()); + for (const auto& tensor : *inputTensors) { + inputs.emplace_back(tensor); + } + } + recordingFunction->before( + profilingTitle, + c10::ArrayRef(inputs.data(), inputs.size())); + }; recordFunctionBeforeCallback_ = at::wrapPropagateTLSState(before_handler); std::function end_handler = [recordingFunction]() { recordingFunction->end(); @@ -745,7 +747,7 @@ ProcessGroupGloo::ProcessGroupGloo( // by a single I/O thread. // contexts_.reserve(options->devices.size()); - for(const auto i : c10::irange(options->devices.size())) { + for (const auto i : c10::irange(options->devices.size())) { auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); context->setTimeout(options->timeout); @@ -760,7 +762,7 @@ ProcessGroupGloo::ProcessGroupGloo( workInProgress_.resize(options->threads); threads_.resize(options->threads); - for(const auto i : c10::irange(threads_.size())) { + for (const auto i : c10::irange(threads_.size())) { threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i); } @@ -867,7 +869,7 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { broadcast(inputs[rootTensor]); // Copy to non-root tensors - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { if (i == static_cast(rootTensor)) { continue; } @@ -907,7 +909,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); inputs[i].copy_(tmp, /* non_blocking */ true); events[i].record(streams[i]); @@ -916,9 +918,10 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { c10::Device device = inputs[i].device(); - events[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + events[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -1234,7 +1237,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector counts(context->size); int64_t totalSize = 0; - for(const auto i : c10::irange(metadata.size())) { + for (const auto i : c10::irange(metadata.size())) { counts[i] = metadata[i].nnz() * sparseDim; totalSize += counts[i]; } @@ -1279,7 +1282,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector counts(context->size); int64_t totalSize = 0; - for(const auto i : c10::irange(metadata.size())) { + for (const auto i : c10::irange(metadata.size())) { counts[i] = metadata[i].nnz() * denseNumel; totalSize += counts[i]; } @@ -1330,7 +1333,7 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { // Kick off copy from CUDA tensors to pinned CPU tensors. tmp.reserve(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } @@ -1338,7 +1341,7 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { streams[i].synchronize(); } @@ -1346,7 +1349,7 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { allreduce(tmp); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); inputs[i].copy_(tmp[i], /* non_blocking */ true); events[i].record(streams[i]); @@ -1355,9 +1358,10 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { c10::Device device = inputs[i].device(); - events[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + events[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -1380,7 +1384,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { // memory must be performed asynchronously, or we block the caller. tmp.reserve(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); tmp.push_back( inputs[i].coalesce().to(at::DeviceType::CPU, /*non_blocking=*/true)); @@ -1389,7 +1393,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { streams[i].synchronize(); } @@ -1398,7 +1402,7 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); inputs[i].copy_(output, /*non_blocking=*/true); events[i].record(streams[i]); @@ -1407,9 +1411,10 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { c10::Device device = inputs[i].device(); - events[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + events[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -1485,8 +1490,7 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - TORCH_CHECK(false, - "ProcessGroupGloo::allreduce_coalesced: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::allreduce_coalesced: " + msg); }; assertNonEmpty(invalidArgument, tensors); @@ -1611,7 +1615,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { // Kick off copy from CUDA tensors to pinned CPU tensors. tmp.reserve(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } @@ -1619,7 +1623,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { streams[i].synchronize(); } @@ -1628,7 +1632,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(streams[i]); inputs[i].copy_(tmp[i], /* non_blocking */ true); events[i].record(streams[i]); @@ -1637,9 +1641,10 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { c10::Device device = inputs[i].device(); - events[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + events[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -1767,15 +1772,15 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Kick off copy from CUDA tensors to pinned CPU tensors. tmpInputs.reserve(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(inputStreams[i]); tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } tmpOutputs.resize(outputs.size()); - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { tmpOutputs[i].reserve(outputs[i].size()); - for(const auto j : c10::irange(outputs[i].size())) { + for (const auto j : c10::irange(outputs[i].size())) { tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); } } @@ -1783,11 +1788,11 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { inputStreams[i].synchronize(); } - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { outputStreams[i].synchronize(); } @@ -1796,9 +1801,9 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { guard.reset_stream(outputStreams[i]); - for(const auto j : c10::irange(outputs[i].size())) { + for (const auto j : c10::irange(outputs[i].size())) { outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); } outputEvents[i].record(outputStreams[i]); @@ -1807,9 +1812,10 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { c10::Device device = outputs[i][0].device(); - outputEvents[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + outputEvents[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -1843,7 +1849,7 @@ c10::intrusive_ptr ProcessGroupGloo::allgather( "requires input/output tensor lists to have the same length"); } - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { const auto expected = inputs.size() * getSize(); const auto actual = outputs[i].size(); if (actual != expected) { @@ -1965,8 +1971,7 @@ c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector& input_list, const AllgatherOptions& /* unused */) { static auto invalidArgument = [](const std::string& msg) { - TORCH_CHECK(false, - "ProcessGroupGloo::allgather_coalesced: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg); }; if (input_list.empty()) { @@ -2020,8 +2025,7 @@ c10::intrusive_ptr ProcessGroupGloo::_allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { - TORCH_CHECK(false, - "no support for _allgather_base in Gloo process group"); + TORCH_CHECK(false, "no support for _allgather_base in Gloo process group"); } namespace { @@ -2069,7 +2073,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { // Unflatten into output tensors on root process. if (context->rank == root) { - for(const auto i : c10::irange(outputs[0].size())) { + for (const auto i : c10::irange(outputs[0].size())) { outputs[0][i].copy_(flatOutputTensor[i]); } } @@ -2100,15 +2104,15 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { // Kick off copy from CUDA tensors to pinned CPU tensors. tmpInputs.reserve(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(inputStreams[i]); tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); } tmpOutputs.resize(outputs.size()); - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { tmpOutputs[i].reserve(outputs[i].size()); - for(const auto j : c10::irange(outputs[i].size())) { + for (const auto j : c10::irange(outputs[i].size())) { tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); } } @@ -2116,11 +2120,11 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { inputStreams[i].synchronize(); } - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { outputStreams[i].synchronize(); } @@ -2129,9 +2133,9 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { guard.reset_stream(outputStreams[i]); - for(const auto j : c10::irange(outputs[i].size())) { + for (const auto j : c10::irange(outputs[i].size())) { outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); } outputEvents[i].record(outputStreams[i]); @@ -2140,9 +2144,10 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { c10::Device device = outputs[i][0].device(); - outputEvents[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + outputEvents[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -2286,10 +2291,10 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { // Kick off copy from CUDA tensors to pinned CPU tensors. tmpInputs.resize(inputs.size()); c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { guard.reset_stream(inputStreams[i]); tmpInputs[i].reserve(inputs[i].size()); - for(const auto j : c10::irange(inputs[i].size())) { + for (const auto j : c10::irange(inputs[i].size())) { tmpInputs[i].push_back( pinnedLike(inputs[i][j]).copy_(inputs[i][j], true)); } @@ -2303,10 +2308,10 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { void run() override { // Synchronize with copy operations. - for(const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange(inputs.size())) { inputStreams[i].synchronize(); } - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { outputStreams[i].synchronize(); } @@ -2315,7 +2320,7 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { // Kick off copy back to the CUDA tensors. c10::OptionalStreamGuard guard; - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { guard.reset_stream(outputStreams[i]); outputs[i].copy_(tmpOutputs[i], /* non_blocking */ true); outputEvents[i].record(outputStreams[i]); @@ -2324,9 +2329,10 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. - for(const auto i : c10::irange(outputs.size())) { + for (const auto i : c10::irange(outputs.size())) { c10::Device device = outputs[i].device(); - outputEvents[i].block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + outputEvents[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } } @@ -2519,7 +2525,8 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { void synchronize() override { // Synchronize with the copy back to CUDA tensors. c10::Device device = outputTensor.device(); - outputEvents.front().block(c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + outputEvents.front().block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); } at::Tensor cpuOutput; @@ -2837,8 +2844,8 @@ void ProcessGroupGloo::monitoredBarrier( waitLoop(sendWorkMap); using namespace std::chrono; - C10_UNUSED auto elapsedTime = duration_cast( - steady_clock::now() - startTime); + C10_UNUSED auto elapsedTime = + duration_cast(steady_clock::now() - startTime); } void ProcessGroupGloo::setSequenceNumberForGroup() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index 556ab138871235..e2f28507a8db69 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -2,9 +2,9 @@ #ifdef USE_C10D_MPI +#include #include #include -#include #include #include @@ -22,7 +22,7 @@ namespace c10d { std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \ std::to_string(__LINE__) + \ ", with error code: " + std::to_string(mpiStatus); \ - TORCH_CHECK(false, err); \ + TORCH_CHECK(false, err); \ } \ } while (0) @@ -70,7 +70,8 @@ void checkSingleTensorHelper(const at::Tensor& tensor) { TORCH_CHECK(false, "input tensor has to be dense"); } if (tensor.is_cuda() && !cudaAwareMpiCheck()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "CUDA tensor detected and the MPI used doesn't " "have CUDA-aware MPI support"); } @@ -78,8 +79,8 @@ void checkSingleTensorHelper(const at::Tensor& tensor) { void checkSingleTensor(const std::vector& tensors) { if (tensors.size() != 1) { - TORCH_CHECK(false, - "MPI process group does not support multi-GPU collectives"); + TORCH_CHECK( + false, "MPI process group does not support multi-GPU collectives"); } checkSingleTensorHelper(tensors[0]); } @@ -159,7 +160,8 @@ bool ProcessGroupMPI::AsyncWork::isCompleted() { bool ProcessGroupMPI::AsyncWork::isSuccess() const { if (request_ != MPI_REQUEST_NULL) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Invalid call to AsyncWork::isSuccess before work has completed"); } @@ -200,9 +202,8 @@ bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) { return true; } -void ProcessGroupMPI::AsyncWork::abort() { - TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.") -} +void ProcessGroupMPI::AsyncWork::abort(){ + TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.")} std::vector ProcessGroupMPI::AsyncWork::result() { return outputTensors_; @@ -233,7 +234,8 @@ void ProcessGroupMPI::initMPIOnce() { MPI_CHECK(MPI_Init_thread( nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_)); if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Used MPI implementation doesn't have the " "minimum level of threading support: " "MPI_THREAD_SERIALIZED. This is required by " @@ -372,7 +374,8 @@ c10::intrusive_ptr ProcessGroupMPI::enqueue( std::unique_ptr entry, const char* profilingTitle, const c10::optional>& inputTensors) { - auto work = c10::make_intrusive(entry->dst, profilingTitle, inputTensors); + auto work = + c10::make_intrusive(entry->dst, profilingTitle, inputTensors); std::unique_lock lock(pgMutex_); queue_.push_back(std::make_tuple(std::move(entry), work)); lock.unlock(); @@ -396,7 +399,8 @@ c10::intrusive_ptr ProcessGroupMPI::broadcast( opts.rootRank, pgComm_)); }; - auto entry = std::make_unique(&tensors, &tensors, std::move(runFunc)); + auto entry = + std::make_unique(&tensors, &tensors, std::move(runFunc)); return enqueue( std::move(entry), "mpi:broadcast", @@ -421,7 +425,8 @@ c10::intrusive_ptr ProcessGroupMPI::allreduce( mpiOp.at(opts.reduceOp), pgComm_)); }; - auto entry = std::make_unique(&tensors, &tensors, std::move(runFunc)); + auto entry = + std::make_unique(&tensors, &tensors, std::move(runFunc)); return enqueue( std::move(entry), "mpi:all_reduce", @@ -431,8 +436,7 @@ c10::intrusive_ptr ProcessGroupMPI::allreduce( c10::intrusive_ptr ProcessGroupMPI::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { - TORCH_CHECK(false, - "allreduce_coalesced is currently not supported with MPI"); + TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI"); } c10::intrusive_ptr ProcessGroupMPI::reduce( @@ -458,7 +462,8 @@ c10::intrusive_ptr ProcessGroupMPI::reduce( opts.rootRank, pgComm_)); }; - auto entry = std::make_unique(&tensors, &tensors, std::move(runFunc)); + auto entry = + std::make_unique(&tensors, &tensors, std::move(runFunc)); return enqueue( std::move(entry), "mpi:reduce", @@ -471,12 +476,14 @@ c10::intrusive_ptr ProcessGroupMPI::allgather( const AllgatherOptions& opts) { checkSingleTensor(inputTensors); if (outputTensors.size() != 1) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "MPI process group only supports a single " "tensor op"); } if (static_cast(size_) != outputTensors[0].size()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "All gather: number of output tensors should equal " "to the world size"); } @@ -516,8 +523,7 @@ c10::intrusive_ptr ProcessGroupMPI::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { - TORCH_CHECK(false, - "ProcessGroupMPI does not support allgather_coalesced"); + TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced"); } c10::intrusive_ptr ProcessGroupMPI::gather( @@ -528,7 +534,8 @@ c10::intrusive_ptr ProcessGroupMPI::gather( if (rank_ != opts.rootRank) { if (outputTensors.size() > 0) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Gather: number of output tensors should be 0 " "for non-root"); } @@ -537,7 +544,8 @@ c10::intrusive_ptr ProcessGroupMPI::gather( TORCH_CHECK(false, "Gather: multi-GPU collective is not supported"); } if (static_cast(size_) != outputTensors[0].size()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Gather: number of output tensors should equal " "to the world size"); } @@ -585,8 +593,8 @@ c10::intrusive_ptr ProcessGroupMPI::gather( "mpi:gather", c10::optional>(inputTensors)); } else { - auto entry = std::make_unique( - &inputTensors, nullptr, std::move(runFunc)); + auto entry = + std::make_unique(&inputTensors, nullptr, std::move(runFunc)); return enqueue( std::move(entry), "mpi:gather", @@ -602,17 +610,18 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( if (rank_ != opts.rootRank) { if (inputTensors.size() > 0) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Scatter: number of input tensors should be 0 " "for non-root"); } } else { if (inputTensors.size() != 1) { - TORCH_CHECK(false, - "Scatter: multi-GPU collective is not supported"); + TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported"); } if (static_cast(size_) != inputTensors[0].size()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Scatter: number of input tensors should equal " "to the world size"); } @@ -660,7 +669,7 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( : c10::nullopt); } else { auto entry = std::make_unique( - nullptr, &outputTensors, std::move(runFunc)); + nullptr, &outputTensors, std::move(runFunc)); return enqueue( std::move(entry), "mpi:scatter", @@ -790,7 +799,7 @@ c10::intrusive_ptr ProcessGroupMPI::alltoall( at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options()); auto srcFlatDataSplits = srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0); - for(const auto i : c10::irange(size_)) { + for (const auto i : c10::irange(size_)) { srcFlatDataSplits[i].copy_(srcdata[i].view({-1})); } c10::DeviceGuard guard1(srcdata[0].device()); @@ -808,7 +817,7 @@ c10::intrusive_ptr ProcessGroupMPI::alltoall( auto dstFlatDataSplits = dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0); - for(const auto i : c10::irange(size_)) { + for (const auto i : c10::irange(size_)) { dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]); } }; @@ -913,7 +922,8 @@ c10::intrusive_ptr ProcessGroupMPI::barrier( std::unique_lock globalLock(pgGlobalMutex_); MPI_CHECK(MPI_Barrier(pgComm_)); }; - auto entry = std::make_unique(nullptr, nullptr, std::move(runFunc)); + auto entry = + std::make_unique(nullptr, nullptr, std::move(runFunc)); return enqueue(std::move(entry), "mpi:barrier", c10::nullopt); } @@ -921,8 +931,7 @@ c10::intrusive_ptr ProcessGroupMPI::_allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { - TORCH_CHECK(false, - "no support for _allgather_base in MPI process group"); + TORCH_CHECK(false, "no support for _allgather_base in MPI process group"); } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3686309ff66b06..6aee5c0a9ca49a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -24,7 +24,6 @@ #include - namespace c10d { constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM"; @@ -49,8 +48,8 @@ struct AutoNcclGroup { } }; -#if defined(NCCL_MAJOR) && ((NCCL_MAJOR > 2) || \ - (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) +#if defined(NCCL_MAJOR) && \ + ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) #define NCCL_HAS_AVG 1 #endif @@ -109,9 +108,12 @@ ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) { } catch (const std::out_of_range& e) { switch (reduceOp) { case ReduceOp::AVG: - TORCH_CHECK(false, - "AVG requires NCCL 2.10+. The current version is ", - NCCL_MAJOR, ".", NCCL_MINOR); + TORCH_CHECK( + false, + "AVG requires NCCL 2.10+. The current version is ", + NCCL_MAJOR, + ".", + NCCL_MINOR); break; case ReduceOp::BAND: TORCH_CHECK(false, "Cannot use ReduceOp.BAND with NCCL"); @@ -157,7 +159,8 @@ std::vector getDeviceList(const std::vector& tensors) { for (auto& tensor : tensors) { // tensors must all be on the same device, or all on distinct devices. // The line below assumes that constraint has already been enforced - // (by check_gpu_tensors_same_device or check_gpu_tensors_different_devices). + // (by check_gpu_tensors_same_device or + // check_gpu_tensors_different_devices). if (res.size() == 0 || tensor.device() != res[0]) { res.push_back(tensor.device()); } @@ -203,7 +206,7 @@ void syncStreams( std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { const uint8_t* bytes = reinterpret_cast(&ncclID); std::ostringstream oss; - for(const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { + for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { oss << std::hex << static_cast(bytes[i]); } return oss.str(); @@ -229,11 +232,13 @@ std::string getExceptionMsgFromExceptionPtr( inline void errorIfCapturingNonCapturableNCCL() { auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); // parentheses avoid some compiler warnings - static const uint64_t min_version = (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t min_version = + (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); static const uint64_t cur_version = torch::cuda::nccl::version(); if (cur_version < min_version) { - TORCH_CHECK(status == c10::cuda::CaptureStatus::None, - "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + TORCH_CHECK( + status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); } } @@ -383,12 +388,12 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { } } } catch (const std::exception& e) { - if (std::string(e.what()).find("driver shutting down") == std::string::npos) { + if (std::string(e.what()).find("driver shutting down") == + std::string::npos) { throw; } LOG(INFO) << "[Rank " << rank_ - << "] Event query failed with exception: " - << e.what(); + << "] Event query failed with exception: " << e.what(); } return true; } @@ -453,7 +458,8 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( std::stringstream ss; ss << *this; - auto timeoutErrorMsg = c10::str("Work ", ss.str(), " timed out in call to wait()."); + auto timeoutErrorMsg = + c10::str("Work ", ss.str(), " timed out in call to wait()."); for (const auto& ncclComm : ncclComms_) { ncclComm->ncclCommAbort(timeoutErrorMsg); const auto& storeKey = getNcclAbortedCommStoreKey( @@ -504,13 +510,13 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( // Same as calling synchronize(). bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { RECORD_PARAM_COMMS( - rank_, // rank - "wait", // colName - 0, // inSize - 0, // outSize - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + rank_, // rank + "wait", // colName + 0, // inSize + 0, // outSize + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes synchronizeInternal(timeout); // Always return true, because abort API is not implemented. return true; @@ -595,20 +601,22 @@ ProcessGroupNCCL::ProcessGroupNCCL( #ifdef USE_NCCL_WITH_UCC static std::once_flag initialize_ucc_lib_flag; - std::call_once(initialize_ucc_lib_flag, [&]{ + std::call_once(initialize_ucc_lib_flag, [&] { uccLib_ = loadTorchUCC(); if (uccLib_ != nullptr) { - LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded"; + LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded"; } }); if (uccLib_ != nullptr) { - LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded"; - typedef c10::intrusive_ptr fn(const c10::intrusive_ptr& store, int rank, int size); - auto createProcessGroupUCC = reinterpret_cast(uccLib_->sym("createProcessGroupUCC")); + LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded"; + typedef c10::intrusive_ptr fn( + const c10::intrusive_ptr& store, int rank, int size); + auto createProcessGroupUCC = + reinterpret_cast(uccLib_->sym("createProcessGroupUCC")); if (createProcessGroupUCC != nullptr) { uccPG_ = createProcessGroupUCC(store, rank_, size_); - LOG(INFO) << "[Rank " << rank_ << "] ProcessGroupUCC created."; + LOG(INFO) << "[Rank " << rank_ << "] ProcessGroupUCC created."; } } #endif @@ -778,8 +786,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { } std::exception_ptr ncclErrorException = checkForNCCLErrors(ncclComms); if (ncclErrorException) { - auto exceptionMsg - = getExceptionMsgFromExceptionPtr(ncclErrorException); + auto exceptionMsg = + getExceptionMsgFromExceptionPtr(ncclErrorException); LOG(INFO) << "[Rank " << rank_ << "] Received NCCL errors for communicators in the cache: \n" @@ -856,11 +864,10 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() { auto val = store_->get(storeKey); std::string rank(reinterpret_cast(val.data()), val.size()); std::stringstream ss; - ss << "[Rank " << rank_ - << "] Found key in store: " << storeKey - << ", from rank: " << rank - << ". This means that rank has aborted its NCCL communicators previously and is not in a healthy state." - << ". Aborting appropriate communicators"; + ss << "[Rank " << rank_ << "] Found key in store: " << storeKey + << ", from rank: " << rank + << ". This means that rank has aborted its NCCL communicators previously and is not in a healthy state." + << ". Aborting appropriate communicators"; std::string abortReason = ss.str(); LOG(WARNING) << abortReason; @@ -936,7 +943,7 @@ void ProcessGroupNCCL::workCleanupLoop() { opTypeToString(work.opType_)); } if (!terminateProcessGroup_.load() && !storeError_) { - storeError_= !c10d::traceUpdate( + storeError_ = !c10d::traceUpdate( store_, traceKeyEnd_, work.seq_, @@ -979,14 +986,9 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( // commFailureReason is set. auto commFailureReason = ncclComm->getNcclCommFailureReason(); if (commFailureReason != c10::nullopt) { - return std::make_exception_ptr( - std::runtime_error( - c10::str( - "NCCL communicator encountered error set by ProcessGroupNCCL: ", - *commFailureReason - ) - ) - ); + return std::make_exception_ptr(std::runtime_error(c10::str( + "NCCL communicator encountered error set by ProcessGroupNCCL: ", + *commFailureReason))); } ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); if (ncclAsyncErr != ncclSuccess) { @@ -1059,7 +1061,6 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } } - void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { @@ -1091,7 +1092,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( bool isSendRecvSelf) { // Sanity check if (devicesKey.empty()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Not able to create/get the NCCL Communicator since " "the GPU devices are not known"); } @@ -1171,8 +1173,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( numRanks = 1; rank = 0; } else { - // For single point-to-point operation, there are only 2 processes involved so - // the GPU rank is either 0 or 1. + // For single point-to-point operation, there are only 2 processes + // involved so the GPU rank is either 0 or 1. numRanks = 2; rank = p2pRank; } @@ -1190,8 +1192,8 @@ std::vector>& ProcessGroupNCCL::getNCCLComm( // [Note 2 ] C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); - // At this point NCCL should have been initialized, hence we can accurately get - // the env value even if NCCL sets it by reading from nccl.conf file + // At this point NCCL should have been initialized, hence we can accurately + // get the env value even if NCCL sets it by reading from nccl.conf file if (getRank() == 0) { LOG(INFO) << "NCCL_DEBUG: " << parse_env("NCCL_DEBUG"); } @@ -1236,15 +1238,19 @@ void check_gpu_single_tensor(const at::Tensor& tensor) { } } -// Checks that all `tensors' have the same type and shape and reside on distinct GPUs. -// TODO: test_c10d_nccl.py should consider adding tests for the error conditions here, -// ie, that deliberately pass invalid tensors and check the right exception is thrown. -void check_gpu_tensors_different_devices(const std::vector& tensors) { +// Checks that all `tensors' have the same type and shape and reside on distinct +// GPUs. +// TODO: test_c10d_nccl.py should consider adding tests for the error conditions +// here, ie, that deliberately pass invalid tensors and check the right +// exception is thrown. +void check_gpu_tensors_different_devices( + const std::vector& tensors) { if (tensors.size() == 0) { TORCH_CHECK(false, "Tensor list must be nonempty"); } if (tensors.size() > static_cast(at::cuda::getNumGPUs())) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Tensor list mustn't be larger than the number of available GPUs"); } @@ -1277,11 +1283,13 @@ void check_gpu_tensors_different_devices(const std::vector& tensors) } } -// Checks that all `tensors' have the same type and shape and reside on the same GPU. -// TODO: test_c10d_nccl.py should consider adding tests for the error conditions here, -// ie, that deliberately pass invalid tensors and check the right exception is thrown. -// The "Expected list of tensors on the same device" condition may be a challenge -// because the test would need to pass tensors on different devices in the same process. +// Checks that all `tensors' have the same type and shape and reside on the same +// GPU. +// TODO: test_c10d_nccl.py should consider adding tests for the error conditions +// here, ie, that deliberately pass invalid tensors and check the right +// exception is thrown. The "Expected list of tensors on the same device" +// condition may be a challenge because the test would need to pass tensors on +// different devices in the same process. int64_t check_gpu_tensors_same_device(const std::vector& tensors) { if (tensors.size() == 0) { TORCH_CHECK(false, "Tensor list must be nonempty"); @@ -1304,8 +1312,9 @@ int64_t check_gpu_tensors_same_device(const std::vector& tensors) { // on a set of tensors with potentially different sizes and strides. // Therefore, we don't check for matching sizes and strides, // but we do double-check tensors are on the same device. - TORCH_CHECK(t.get_device() == tensors[0].get_device(), - "Expected list of tensors on the same device"); + TORCH_CHECK( + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); total_numel += t.numel(); } @@ -1319,7 +1328,8 @@ std::vector flatten_for_scatter_gather( std::vector& other, size_t world_size) { if (tensor_lists.size() != other.size()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Tensor list operands to scatter/gather must have the same length"); } const auto num_devices = tensor_lists.size(); @@ -1329,7 +1339,8 @@ std::vector flatten_for_scatter_gather( for (const auto i : c10::irange(size_t{}, num_devices)) { if (tensor_lists[i].size() != world_size * num_devices) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Tensor list input to scatter/gather must match number of collective" " participants"); } @@ -1337,14 +1348,16 @@ std::vector flatten_for_scatter_gather( // Only check device match for the first tensor in the list; the call to // newLikeFlat() below will check the rest. if (tensor_lists[i].front().get_device() != other[i].get_device()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "Corresponding input/output tensors to scatter/gather must all reside" " on the same device"); } for (const auto& t : tensor_lists[i]) { if (t.numel() != other[i].numel()) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "All tensor operands to scatter/gather must have the same number of elements"); } } @@ -1363,13 +1376,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( const char* profilingTitle, const c10::optional>& inputs) { return c10::make_intrusive( - devices, - rank, - opType, - seq_, - profilingTitle, - inputs, - desyncDebug_); + devices, rank, opType, seq_, profilingTitle, inputs, desyncDebug_); } std::vector ProcessGroupNCCL::WorkNCCL::result() { @@ -1406,21 +1413,23 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PostProcess post, OpType opType, const char* profilingTitle) { - errorIfCapturingNonCapturableNCCL(); // Bump collective counter seq_++; - // Currently, the API permits two scenarios where inputs.size() and outputs.size() are > 0. - // 1. If the call was a _coalesced call, all inputs must be on the same device. + // Currently, the API permits two scenarios where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. // The group of nccl calls applies the collective separately to each input, - // but the group as a whole should be efficient, and might even execute as a - // single fused kernel. - // 2. If the call was a _multigpu call, all inputs must be on different devices. - // The nccl group applies the collective across them (eg, if the collective is - // an allreduce, the output on each device contains contributions summed across - // `inputs' tensors). + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + // 2. If the call was a _multigpu call, all inputs must be on different + // devices. + // The nccl group applies the collective across them (eg, if the collective + // is an allreduce, the output on each device contains contributions summed + // across `inputs' tensors). const auto devices = getDeviceList(inputs); const bool inputs_same_dev = (devices.size() == 1); const auto key = getKeyFromDevices(devices); @@ -1476,11 +1485,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // See [Sync Streams]. c10::cuda::CUDACachingAllocator::recordStream( inputs[i].storage().data_ptr(), ncclStream); - C10D_NCCL_CHECK(fn(inputs[i], - outputs[i], - ncclComm->getNcclComm(), - ncclStream), - ncclComm->getNcclCommFailureReason()); + C10D_NCCL_CHECK( + fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream), + ncclComm->getNcclCommFailureReason()); } } @@ -1496,15 +1503,15 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( { c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams); work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), - devices); + c10::ListType::create(c10::TensorType::get()), devices); // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA // future blocks the stream this callback runs on the corresponding // ncclEndEvents_ ensuring appropriate synchronization. if (work->recordFunctionEndCallback_) { - work->future_->addCallback( - [work](at::ivalue::Future& /* unused */) { work->recordFunctionEndCallback_(); }); + work->future_->addCallback([work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }); } work->future_->markCompleted(at::IValue(*work->outputs_)); } @@ -1601,8 +1608,12 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( for (const auto i : c10::irange(tensors.size())) { gpuGuard.set_index(devices[i].index()); at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; - C10D_NCCL_CHECK(fn( - tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank), ncclComms[i]->getNcclCommFailureReason()); + C10D_NCCL_CHECK( + fn(tensors[i], + ncclComms[i]->getNcclComm(), + ncclStream, + p2pTargetRank), + ncclComms[i]->getNcclCommFailureReason()); } } @@ -1624,8 +1635,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( { c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), - devices); + c10::ListType::create(c10::TensorType::get()), devices); work->future_->markCompleted(at::IValue(*work->outputs_)); } @@ -1633,8 +1643,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // future blocks the stream this callback runs on the corresponding // ncclEndEvents_ ensuring appropriate synchronization. if (work->recordFunctionEndCallback_) { - work->future_->addCallback( - [work](at::ivalue::Future& /* unused */) { work->recordFunctionEndCallback_(); }); + work->future_->addCallback([work](at::ivalue::Future& /* unused */) { + work->recordFunctionEndCallback_(); + }); } return work; @@ -1743,13 +1754,13 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // @lint-ignore CLANGTIDY auto tensor = tensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "broadcast", // colName - tensor.numel(), // inSize - tensor.numel(), // outSize - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + rank_, // rank + "broadcast", // colName + tensor.numel(), // inSize + tensor.numel(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes return collective( tensors, @@ -1778,13 +1789,13 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( // @lint-ignore CLANGTIDY auto tensor = tensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "reduce", // colName - tensor.numel(), // inSize - tensor.numel(), // outSize - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + rank_, // rank + "reduce", // colName + tensor.numel(), // inSize + tensor.numel(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes return collective( tensors, @@ -1821,14 +1832,14 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( // @lint-ignore CLANGTIDY auto tensor = inputTensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "all_gather", // colName - tensor.numel(), // inSize - tensor.numel() * // outSize - this->getSize(), // dType + rank_, // rank + "all_gather", // colName + tensor.numel(), // inSize + tensor.numel() * // outSize + this->getSize(), // dType tensor.scalar_type(), - std::vector(), // inSplitSizes - std::vector()); // outSplitSize + std::vector(), // inSplitSizes + std::vector()); // outSplitSize return collective( inputTensors, @@ -1869,8 +1880,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_coalesced( std::vector>& /* unused */, std::vector& /* unused */, const AllgatherOptions& /* unused */) { - TORCH_CHECK(false, - "ProcessGroupNCCL does not support allgather_coalesced"); + TORCH_CHECK(false, "ProcessGroupNCCL does not support allgather_coalesced"); } c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( @@ -1882,14 +1892,14 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( // @lint-ignore CLANGTIDY auto tensor = outputTensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "reduce_scatter", // colName - tensor.numel() * // inSize - this->getSize(), // outSize - tensor.numel(), // dType + rank_, // rank + "reduce_scatter", // colName + tensor.numel() * // inSize + this->getSize(), // outSize + tensor.numel(), // dType tensor.scalar_type(), - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes auto inputFlattened = flatten_for_scatter_gather(inputTensors, outputTensors, size_); @@ -1935,29 +1945,31 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( at::Tensor& outputTensor, at::Tensor& inputTensor, const ReduceScatterOptions& opts) { - if (inputTensor.dtype() != outputTensor.dtype()) { - TORCH_CHECK(false, "input tensor must be the same type as the outut tensor."); + TORCH_CHECK( + false, "input tensor must be the same type as the outut tensor."); } if (inputTensor.numel() != outputTensor.numel() * size_) { - TORCH_CHECK(false, "input tensor must be the same size as output size times world size"); + TORCH_CHECK( + false, + "input tensor must be the same size as output size times world size"); } // @lint-ignore CLANGTIDY const auto& tensor = outputTensor; RECORD_PARAM_COMMS( - rank_, // rank - "_reduce_scatter_base", // colName - tensor.numel() * // inSize - this->getSize(), - tensor.numel(), // outSize - tensor.scalar_type(), // dtype - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes - - auto inputs = std::vector {inputTensor}; - auto outputs = std::vector {outputTensor}; + rank_, // rank + "_reduce_scatter_base", // colName + tensor.numel() * // inSize + this->getSize(), + tensor.numel(), // outSize + tensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes + + auto inputs = std::vector{inputTensor}; + auto outputs = std::vector{outputTensor}; return collective( inputs, @@ -1986,13 +1998,13 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( c10::intrusive_ptr ProcessGroupNCCL::barrier( const BarrierOptions& opts) { RECORD_PARAM_COMMS( - rank_, // rank - "barrier", // colName - 0, // inSize - 0, // outSize - at::kByte, // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + rank_, // rank + "barrier", // colName + 0, // inSize + 0, // outSize + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes std::vector devices; @@ -2010,14 +2022,13 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier( auto numGPUs = at::cuda::getNumGPUs(); int16_t deviceIdx = static_cast(rank_ % numGPUs); LOG(INFO) << c10::str( - "Rank ", - this->getRank(), - " using GPU ", - deviceIdx, - " to perform barrier as devices used by this process are currently unknown. ", + "Rank ", + this->getRank(), + " using GPU ", + deviceIdx, + " to perform barrier as devices used by this process are currently unknown. ", "This can potentially cause a hang if this rank to GPU mapping is incorrect.", - "Specify device_ids in barrier() to force use of a particular device." - ); + "Specify device_ids in barrier() to force use of a particular device."); devices.emplace_back(getDeviceForRank(rank_)); } else { for (auto usedDeviceIdx : usedDeviceIdxs_) { @@ -2061,13 +2072,13 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( std::vector outputTensors = {outputTensor}; RECORD_PARAM_COMMS( - rank_, // rank - "all_to_all", // colName - inputTensor.numel(), // inSize - outputTensor.numel(), // outSize - inputTensor.scalar_type(),// dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSizes + rank_, // rank + "all_to_all", // colName + inputTensor.numel(), // inSize + outputTensor.numel(), // outSize + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSizes return collective( inputTensors, @@ -2092,13 +2103,13 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( std::vector outputTensors = {outputTensor}; RECORD_PARAM_COMMS( - rank_, // rank - "all_to_allv", // colName - inputTensor.numel(), // inSize - outputTensor.numel(), // outSize - inputTensor.scalar_type(),// dType - inputSplitSizes, // inSplitSizes - outputSplitSizes); // outSplitSizes + rank_, // rank + "all_to_allv", // colName + inputTensor.numel(), // inSize + outputTensor.numel(), // outSize + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes); // outSplitSizes return collective( inputTensors, @@ -2141,7 +2152,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& inputTensors, const AllToAllOptions& /* unused */) { auto device = outputTensors[0].device(); - for(const auto r : c10::irange(outputTensors.size())) { + for (const auto r : c10::irange(outputTensors.size())) { check_gpu_single_tensor(outputTensors[r]); check_gpu_single_tensor(inputTensors[r]); TORCH_CHECK( @@ -2210,7 +2221,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } @@ -2218,7 +2230,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& /* unused */, std::vector& /* unused */, const AllToAllOptions& /* unused */) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0"); } @@ -2226,7 +2239,8 @@ c10::intrusive_ptr ProcessGroupNCCL::send( std::vector& /* unused */, int /* unused */, int /* unused */) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0"); } @@ -2234,7 +2248,8 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( std::vector& /* unused */, int /* unused */, int /* unused */) { - TORCH_CHECK(false, + TORCH_CHECK( + false, "ProcessGroupNCCL only supports recv for NCCL lib version >= 2.7.0"); } #endif @@ -2268,14 +2283,13 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( // @lint-ignore CLANGTIDY auto tensor = inputTensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "gather", // colName - tensor.numel(), // inSize - tensor.numel() * - this->getSize(), // outSize - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSize + rank_, // rank + "gather", // colName + tensor.numel(), // inSize + tensor.numel() * this->getSize(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSize std::vector outputs; @@ -2317,7 +2331,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( at::cuda::CUDAStream& stream) { const auto root = opts.rootRank; if (getRank() == root) { - for(auto output: outputs) { + for (auto output : outputs) { c10::cuda::CUDACachingAllocator::recordStream( output.storage().data_ptr(), stream); } @@ -2333,7 +2347,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { - static auto invalidArgument = [](const std::string& msg) { TORCH_CHECK(false, "ProcessGroupNCCL::scatter: " + msg); }; @@ -2345,14 +2358,13 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // @lint-ignore CLANGTIDY auto tensor = outputTensors.back(); RECORD_PARAM_COMMS( - rank_, // rank - "scatter", // colName - tensor.numel(), // inSize - tensor.numel() * - this->getSize(), // outSize - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector()); // outSplitSize + rank_, // rank + "scatter", // colName + tensor.numel(), // inSize + tensor.numel() * this->getSize(), // outSize + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector()); // outSplitSize std::vector inputs; @@ -2395,12 +2407,13 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::cuda::CUDAStream& stream) { const auto root = opts.rootRank; if (getRank() == root) { - for(auto input: inputs) { + for (auto input : inputs) { c10::cuda::CUDACachingAllocator::recordStream( input.storage().data_ptr(), stream); } } - torch::cuda::nccl::scatter(inputs, outputTensors[0], comm, stream, root); + torch::cuda::nccl::scatter( + inputs, outputTensors[0], comm, stream, root); return ncclSuccess; }, OpType::SCATTER, @@ -2425,12 +2438,14 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( } if (input_tensor.numel() * size_ != output_tensor.numel()) { - TORCH_CHECK(false, "output tensor size must be equal to world_size times input tensor size"); + TORCH_CHECK( + false, + "output tensor size must be equal to world_size times input tensor size"); } // just a wrapper to fit the collective interface - auto inputs = std::vector {input_tensor}; - auto outputs = std::vector {output_tensor}; + auto inputs = std::vector{input_tensor}; + auto outputs = std::vector{output_tensor}; return collective( inputs, diff --git a/torch/csrc/distributed/c10d/ProcessGroupRoundRobin.cpp b/torch/csrc/distributed/c10d/ProcessGroupRoundRobin.cpp index c439cf771a147f..4e017613200aad 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupRoundRobin.cpp @@ -29,9 +29,10 @@ c10::intrusive_ptr ProcessGroupRoundRobin::allreduce( return next()->allreduce(tensors, opts); } -c10::intrusive_ptr ProcessGroupRoundRobin::allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts) { +c10::intrusive_ptr ProcessGroupRoundRobin:: + allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { return next()->allreduce_coalesced(tensors, opts); } @@ -48,10 +49,11 @@ c10::intrusive_ptr ProcessGroupRoundRobin::allgather( return next()->allgather(outputs, inputs, opts); }; -c10::intrusive_ptr ProcessGroupRoundRobin::allgather_coalesced( - std::vector>& outputTensorLists, - std::vector& inputTensors, - const AllgatherOptions& opts) { +c10::intrusive_ptr ProcessGroupRoundRobin:: + allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts) { return next()->allgather(outputTensorLists, inputTensors, opts); } @@ -124,8 +126,8 @@ c10::intrusive_ptr ProcessGroupRoundRobin::_allgather_base( at::Tensor& /*unused */, at::Tensor& /*unused */, const AllgatherOptions& /*unused */) { - TORCH_CHECK(false, - "no support for _allgather_base in RoundRobin process group"); + TORCH_CHECK( + false, "no support for _allgather_base in RoundRobin process group"); } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp index 118ee3e19c3b35..43ac73073bd5e8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp @@ -73,7 +73,7 @@ struct CollectiveFingerPrint { std::vector outputs; outputs.reserve(pg->getSize()); for (const auto i : c10::irange(pg->getSize())) { - (void)i; // Suppress unused variable warning + (void)i; // Suppress unused variable warning outputs.emplace_back(at::zeros_like(tensor_shape)); } output_tensors.emplace_back(outputs); diff --git a/torch/csrc/distributed/c10d/Store.cpp b/torch/csrc/distributed/c10d/Store.cpp index ea7e6b1e15c943..ae8f1358b8699a 100644 --- a/torch/csrc/distributed/c10d/Store.cpp +++ b/torch/csrc/distributed/c10d/Store.cpp @@ -9,7 +9,7 @@ constexpr std::chrono::milliseconds Store::kNoTimeout; Store::~Store() {} const std::chrono::milliseconds& Store::getTimeout() const noexcept { - return timeout_; + return timeout_; } // Set timeout function diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 46dc29ec0f65f3..0e5628a21e6625 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -872,9 +872,8 @@ std::unique_ptr TCPClient::connect( const SocketAddress& addr, const TCPStoreOptions& opts) { auto timeout = std::chrono::duration_cast(opts.timeout); - Socket socket = Socket::connect(addr.host, - addr.port, - SocketOptions{}.connect_timeout(timeout)); + Socket socket = Socket::connect( + addr.host, addr.port, SocketOptions{}.connect_timeout(timeout)); return std::make_unique(std::move(socket)); } @@ -948,9 +947,8 @@ std::unique_ptr TCPCallbackClient::connect( const SocketAddress& addr, const TCPStoreOptions& opts) { auto timeout = std::chrono::duration_cast(opts.timeout); - Socket socket = Socket::connect(addr.host, - addr.port, - SocketOptions{}.connect_timeout(timeout)); + Socket socket = Socket::connect( + addr.host, addr.port, SocketOptions{}.connect_timeout(timeout)); int rawSocket = socket.handle(); diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index ea9da63b5988ff..101522a98b14bf 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -30,7 +30,7 @@ class BroadcastWork { auto output_tensors = torch::utils::unflatten_dense_tensors( flat_tensor_.front(), bucket_tensors_); TORCH_INTERNAL_ASSERT(output_tensors.size() == bucket_tensors_.size()); - for(const auto i : c10::irange(output_tensors.size())) { + for (const auto i : c10::irange(output_tensors.size())) { bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true); } } diff --git a/torch/csrc/distributed/c10d/debug.cpp b/torch/csrc/distributed/c10d/debug.cpp index a22f322576cdb2..15270719725a54 100644 --- a/torch/csrc/distributed/c10d/debug.cpp +++ b/torch/csrc/distributed/c10d/debug.cpp @@ -29,10 +29,11 @@ DebugLevel loadDebugLevelFromEnvironment() { std::string level_str{env_value}; - std::transform(level_str.begin(), level_str.end(), level_str.begin(), - [](unsigned char c) { - return toupper(c); - }); + std::transform( + level_str.begin(), + level_str.end(), + level_str.begin(), + [](unsigned char c) { return toupper(c); }); if (level_str == "OFF") { level = DebugLevel::Off; @@ -41,7 +42,8 @@ DebugLevel loadDebugLevelFromEnvironment() { } else if (level_str == "DETAIL") { level = DebugLevel::Detail; } else { - throw C10dError{"The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL."}; + throw C10dError{ + "The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL."}; } C10D_INFO("The debug level is set to {}.", level_str); diff --git a/torch/csrc/distributed/c10d/debug.h b/torch/csrc/distributed/c10d/debug.h index ecfb4944829570..97237811cc1bf4 100644 --- a/torch/csrc/distributed/c10d/debug.h +++ b/torch/csrc/distributed/c10d/debug.h @@ -10,11 +10,7 @@ namespace c10d { -enum class DebugLevel { - Off, - Info, - Detail -}; +enum class DebugLevel { Off, Info, Detail }; TORCH_API void setDebugLevel(DebugLevel level); diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.cpp b/torch/csrc/distributed/c10d/default_comm_hooks.cpp index 30bc96b16f7db9..e78e2d7a0fac9c 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.cpp +++ b/torch/csrc/distributed/c10d/default_comm_hooks.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -18,7 +18,6 @@ c10::intrusive_ptr AllReduceCommHook::runHook( c10::intrusive_ptr FP16CompressCommHook::runHook( GradBucket& bucket) { - auto compressed_tensor = bucket.getBufferRef().to(torch::kFloat16); // Apply the division first to avoid overflow. compressed_tensor /= state_->getSize(); @@ -34,10 +33,9 @@ c10::intrusive_ptr FP16CompressCommHook::runHook( auto reduce_tensor = result.toTensorVector()[0]; TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - reduce_tensor.scalar_type() == at::ScalarType::Half, - "Expected reduced tensor to be fp16 in FP16CompressHook, but got type ", - reduce_tensor.scalar_type() - ); + reduce_tensor.scalar_type() == at::ScalarType::Half, + "Expected reduced tensor to be fp16 in FP16CompressHook, but got type ", + reduce_tensor.scalar_type()); decompressed_tensor.copy_(reduce_tensor); return c10::IValue(decompressed_tensor); }; @@ -45,8 +43,8 @@ c10::intrusive_ptr FP16CompressCommHook::runHook( return allreduce_fut->then(decompress, allreduce_fut->elementType()); } -c10::intrusive_ptr _AllReduceBySumCommHook:: - runHook(GradBucket& bucket) { +c10::intrusive_ptr _AllReduceBySumCommHook::runHook( + GradBucket& bucket) { std::vector tensors = {bucket.getBufferRef()}; return state_->allreduce(tensors)->getFuture(); } diff --git a/torch/csrc/distributed/c10d/error.h b/torch/csrc/distributed/c10d/error.h index fe726d16954951..e277e88ee9ae39 100644 --- a/torch/csrc/distributed/c10d/error.h +++ b/torch/csrc/distributed/c10d/error.h @@ -37,7 +37,8 @@ struct formatter { template decltype(auto) format(const std::error_code& err, FormatContext& ctx) { - return format_to(ctx.out(), "({}: {} - {})", err.category(), err.value(), err.message()); + return format_to( + ctx.out(), "({}: {} - {})", err.category(), err.value(), err.message()); } }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 873b6b35f1682f..e846faee63291b 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -306,7 +306,7 @@ to apply layer-wise operations. "parameters", &::c10d::GradBucket::getParameters, py::call_guard(), - R"( + R"( Returns: A list of ``torch.Tensor``. Each tensor in the list corresponds to a model parameter. @@ -374,12 +374,17 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO }, py::call_guard()) .def("get_backward_stats", &::c10d::Reducer::get_backward_stats) - .def("_install_post_backward_futures", [](::c10d::Reducer& reducer, const std::vector>& futs) { - c10::List> futures(c10::FutureType::create(c10::TensorType::get())); - for (const auto & fut : futs) { + .def( + "_install_post_backward_futures", + [](::c10d::Reducer& reducer, + const std::vector>& + futs) { + c10::List> futures( + c10::FutureType::create(c10::TensorType::get())); + for (const auto& fut : futs) { futures.push_back(fut->fut); - } - reducer.install_futures(std::move(futures)); + } + reducer.install_futures(std::move(futures)); }, py::call_guard()) .def( @@ -401,8 +406,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO &::c10d::Reducer::set_forward_pass_work_handle, py::call_guard()) .def( - "_get_local_used_map", - &::c10d::Reducer::get_local_used_map_on_device) + "_get_local_used_map", &::c10d::Reducer::get_local_used_map_on_device) .def( "_set_ddp_runtime_logging_sample_rate", &::c10d::Reducer::set_ddp_runtime_logging_sample_rate, @@ -413,9 +417,9 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO &::c10d::Reducer::set_static_graph, py::call_guard()) .def( - "_ddp_graph_static", - &::c10d::Reducer::ddp_graph_static, - py::call_guard()) + "_ddp_graph_static", + &::c10d::Reducer::ddp_graph_static, + py::call_guard()) .def( "_delay_all_reduce", &::c10d::Reducer::delay_all_reduce, @@ -491,11 +495,17 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO .value("DETAIL", ::c10d::DebugLevel::Detail); module - .def("get_debug_level", ::c10d::debug_level, + .def( + "get_debug_level", + ::c10d::debug_level, R"(Gets the debug level of the torch.distributed package.)") - .def("set_debug_level", ::c10d::setDebugLevel, + .def( + "set_debug_level", + ::c10d::setDebugLevel, R"(Sets the debug level of the torch.distributed package.)") - .def("set_debug_level_from_env", ::c10d::setDebugLevelFromEnvironment, + .def( + "set_debug_level_from_env", + ::c10d::setDebugLevelFromEnvironment, R"(Sets the debug level of the torch.distributed package from the ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)"); @@ -965,9 +975,10 @@ that adds a prefix to each key inserted to the store. .def(py::init>()); auto processGroup = - py::class_<::c10d::ProcessGroup, - c10::intrusive_ptr<::c10d::ProcessGroup>, - ::c10d::PyProcessGroup>(module, "ProcessGroup") + py::class_< + ::c10d::ProcessGroup, + c10::intrusive_ptr<::c10d::ProcessGroup>, + ::c10d::PyProcessGroup>(module, "ProcessGroup") .def(py::init()) .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -1398,9 +1409,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("pg"), py::arg("gloo_pg"), py::call_guard()) - .def_property_readonly( - "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg - ); + .def_property_readonly( + "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg); #endif #ifdef USE_C10D_NCCL @@ -1481,9 +1491,10 @@ Example:: py::call_guard()); #endif - py::class_<::c10d::ProcessGroup::Work, - c10::intrusive_ptr<::c10d::ProcessGroup::Work>, - ::c10d::PyProcessGroup::PyWork>(module, "Work") + py::class_< + ::c10d::ProcessGroup::Work, + c10::intrusive_ptr<::c10d::ProcessGroup::Work>, + ::c10d::PyProcessGroup::PyWork>(module, "Work") .def(py::init<>()) .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted) .def( @@ -1579,16 +1590,26 @@ Example:: module.def( "_compute_bucket_assignment_by_size", [](const std::vector& tensors, - const std::vector& bucket_size_limits, - const std::vector& expect_sparse_gradient, - const std::vector& tensor_indices, - const c10::optional>& logger) { - if (logger.has_value()) { - std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); - return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {logger_weakref}); - } else { - return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {}); - } + const std::vector& bucket_size_limits, + const std::vector& expect_sparse_gradient, + const std::vector& tensor_indices, + const c10::optional>& logger) { + if (logger.has_value()) { + std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); + return ::c10d::compute_bucket_assignment_by_size( + tensors, + bucket_size_limits, + expect_sparse_gradient, + tensor_indices, + {logger_weakref}); + } else { + return ::c10d::compute_bucket_assignment_by_size( + tensors, + bucket_size_limits, + expect_sparse_gradient, + tensor_indices, + {}); + } }, py::arg("tensors"), py::arg("bucket_size"), @@ -1602,12 +1623,13 @@ Example:: [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, const std::vector& params, const c10::optional>& logger) { - if (logger.has_value()) { - std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); - verify_params_across_processes(process_group, params, {logger_weakref}); - } else { - verify_params_across_processes(process_group, params, {}); - } + if (logger.has_value()) { + std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); + verify_params_across_processes( + process_group, params, {logger_weakref}); + } else { + verify_params_across_processes(process_group, params, {}); + } }, py::arg("process_group"), py::arg("params"), diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 93e8d05f265581..1b4b5df3ce768c 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -89,12 +89,12 @@ void Logger::set_env_variables() { ddp_logging_data_->strs_map["gloo_device_transport"] = parse_env("GLOO_DEVICE_TRANSPORT"); - #ifdef USE_C10D_GLOO +#ifdef USE_C10D_GLOO auto gloo_pg = static_cast(reducer_->process_group_.get()); auto n_threads = gloo_pg->getNumThreads(); ddp_logging_data_->ints_map["gloo_num_threads"] = n_threads; - #endif +#endif } } @@ -232,7 +232,8 @@ void Logger::calculate_avg_time( Timer::Event start_event, Timer::Event end_event) { TORCH_CHECK(num_iterations_stats_recorded_ > 0); - c10::optional maybe_time_duration = timer.measureDifference(start_event, end_event); + c10::optional maybe_time_duration = + timer.measureDifference(start_event, end_event); if (!maybe_time_duration.has_value()) { return; } @@ -296,7 +297,7 @@ void Logger::set_runtime_stats_and_log() { per_bucket_variable_indices.push_back(c10::Join(" ", bucket_indices)); } ddp_logging_data_->strs_map["rebuilt_per_bucket_param_indices"] = - c10::Join(", ", per_bucket_variable_indices); + c10::Join(", ", per_bucket_variable_indices); } // Log gradient ready order if (!reducer_->grad_ready_order_indices_.empty()) { @@ -312,16 +313,14 @@ void Logger::set_runtime_stats_and_log() { // Cuda time stats are only collected for single device modules. if (reducer_->params_[0].is_cuda() && reducer_->is_multi_device_module_) { TORCH_WARN_ONCE( - "Cuda time stats are not collected for multi-device modules." - ); + "Cuda time stats are not collected for multi-device modules."); return; } if (!reducer_->params_[0].is_cuda() && !reducer_->params_[0].is_cpu()) { TORCH_WARN_ONCE( - "Time stats are currently only collected for CPU and CUDA devices. " - "Please refer to CpuTimer or CudaTimer for how to register timer " - "for other device type." - ); + "Time stats are currently only collected for CPU and CUDA devices. " + "Please refer to CpuTimer or CudaTimer for how to register timer " + "for other device type."); return; } TORCH_INTERNAL_ASSERT(reducer_->timer_); @@ -351,30 +350,25 @@ void Logger::set_runtime_stats_and_log() { Timer::Event::kBackwardComputeEnd); set_event_time( - ddp_logging_data_->ints_map["forward_compute_time_start"], - *reducer_->timer_, - Timer::Event::kForwardStart - ); + ddp_logging_data_->ints_map["forward_compute_time_start"], + *reducer_->timer_, + Timer::Event::kForwardStart); set_event_time( - ddp_logging_data_->ints_map["backward_compute_time_start"], - *reducer_->timer_, - Timer::Event::kBackwardComputeStart - ); + ddp_logging_data_->ints_map["backward_compute_time_start"], + *reducer_->timer_, + Timer::Event::kBackwardComputeStart); set_event_time( - ddp_logging_data_->ints_map["backward_comm_time_start"], - *reducer_->timer_, - Timer::Event::kBackwardCommStart - ); + ddp_logging_data_->ints_map["backward_comm_time_start"], + *reducer_->timer_, + Timer::Event::kBackwardCommStart); set_event_time( - ddp_logging_data_->ints_map["backward_compute_time_end"], - *reducer_->timer_, - Timer::Event::kBackwardComputeEnd - ); + ddp_logging_data_->ints_map["backward_compute_time_end"], + *reducer_->timer_, + Timer::Event::kBackwardComputeEnd); set_event_time( - ddp_logging_data_->ints_map["backward_comm_time_end"], - *reducer_->timer_, - Timer::Event::kBackwardCommEnd - ); + ddp_logging_data_->ints_map["backward_comm_time_end"], + *reducer_->timer_, + Timer::Event::kBackwardCommEnd); // Log runtime stats to stderr if TORCH_DISTRIBUTED_DEBUG=DETAIL is enabled. if (debug_level() == DebugLevel::Detail) { diff --git a/torch/csrc/distributed/c10d/logging.h b/torch/csrc/distributed/c10d/logging.h index 57ee974a0d3577..b4df02e773807e 100644 --- a/torch/csrc/distributed/c10d/logging.h +++ b/torch/csrc/distributed/c10d/logging.h @@ -15,13 +15,7 @@ namespace c10d { namespace detail { -enum class LogLevel { - Trace, - Debug, - Info, - Warning, - Error -}; +enum class LogLevel { Trace, Debug, Info, Warning, Error }; TORCH_API bool isLogLevelEnabled(LogLevel level) noexcept; @@ -33,22 +27,25 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) { } // namespace detail } // namespace c10d -#define C10D_ERROR(...)\ - LOG_IF(ERROR, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error))\ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_ERROR(...) \ + LOG_IF( \ + ERROR, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ + << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_WARNING(...)\ - LOG_IF(WARNING, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning))\ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_WARNING(...) \ + LOG_IF( \ + WARNING, \ + c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning)) \ + << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_INFO(...)\ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info))\ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_INFO(...) \ + LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info)) \ + << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_DEBUG(...)\ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug))\ - << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_DEBUG(...) \ + LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug)) \ + << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_TRACE(...)\ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace))\ - << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_TRACE(...) \ + LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ + << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) diff --git a/torch/csrc/distributed/c10d/python_comm_hook.h b/torch/csrc/distributed/c10d/python_comm_hook.h index ddfeb86984a006..5d30e63b5acca5 100644 --- a/torch/csrc/distributed/c10d/python_comm_hook.h +++ b/torch/csrc/distributed/c10d/python_comm_hook.h @@ -23,7 +23,7 @@ class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { c10::intrusive_ptr runHook(GradBucket& bucket) override; -at::Tensor parseHookResult(const c10::IValue& result) override; + at::Tensor parseHookResult(const c10::IValue& result) override; private: // Only needed for stateful communication. diff --git a/torch/csrc/distributed/c10d/quantization/quantization.cpp b/torch/csrc/distributed/c10d/quantization/quantization.cpp index 572c81d8aa5a3b..104aaab873c673 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.cpp +++ b/torch/csrc/distributed/c10d/quantization/quantization.cpp @@ -13,7 +13,7 @@ void FloatToBFloat16Quantized_ref( const float* const input, const size_t nrows, const size_t ncols, - uint16_t* const output){ + uint16_t* const output) { for (const auto row : c10::irange(nrows)) { const float* input_row = input + row * ncols; uint16_t* output_row = output + row * ncols; @@ -30,7 +30,7 @@ void BFloat16QuantizedToFloat_ref( const at::BFloat16* const input, const size_t nrows, const size_t ncols, - float* const output){ + float* const output) { const int32_t output_columns = ncols; for (const auto row : c10::irange(nrows)) { @@ -55,9 +55,8 @@ at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) { const int32_t nrows = input_sizes[0]; const int32_t ncols = input_sizes[1]; const int32_t output_columns = ncols; - auto output = at::empty( - {nrows, output_columns}, - input.options().dtype(at::kHalf)); + auto output = + at::empty({nrows, output_columns}, input.options().dtype(at::kHalf)); FloatToBFloat16Quantized_ref( input.data_ptr(), @@ -90,15 +89,14 @@ at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) { return output; } - TORCH_LIBRARY(quantization, m) { - m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); - m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor"); + m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); + m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor"); } TORCH_LIBRARY_IMPL(quantization, CPU, m) { - m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu); - m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu); + m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu); + m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu); } } // namespace quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization.h b/torch/csrc/distributed/c10d/quantization/quantization.h index 658fa754488d14..739991b92505f3 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.h +++ b/torch/csrc/distributed/c10d/quantization/quantization.h @@ -2,7 +2,6 @@ #pragma once - #include #include diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h index 2a0c8f8f8d39ce..c10ffd4c6a3903 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h @@ -2,7 +2,6 @@ #pragma once - #include #include diff --git a/torch/csrc/distributed/c10d/quantization/quantization_utils.h b/torch/csrc/distributed/c10d/quantization/quantization_utils.h index 0467ba2769f5be..add49c84ace275 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_utils.h +++ b/torch/csrc/distributed/c10d/quantization/quantization_utils.h @@ -20,7 +20,7 @@ inline std::string torch_tensor_device_name(const at::Tensor& ten) { #define TENSOR_ON_CPU(x) \ TORCH_CHECK( \ - !x.is_cuda(), \ + !x.is_cuda(), \ #x " must be a CPU tensor; it is currently on device ", \ torch_tensor_device_name(x)) diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 31d376b13a2473..921d7814265fab 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -24,7 +24,6 @@ namespace c10d { namespace { - constexpr int kUnsetDivFactor = -1; // Macro that wraps TORCH_CHECK with DDP logging. @@ -203,8 +202,7 @@ Reducer::Reducer( "Reducer tried to register duplicate grad accumulator for variable ", variable_index)); - grad_accumulators_[variable_index] = - std::move(grad_accumulator); + grad_accumulators_[variable_index] = std::move(grad_accumulator); } } @@ -283,15 +281,13 @@ void Reducer::initialize_local_used_map() { // Deliberately don't pin the memory even if local_used_map_dev_ will // be cuda. See Note [local_used_map_ -> local_used_map_dev copying] - local_used_map_ = - at::zeros({static_cast(variable_count)}, options); + local_used_map_ = at::zeros({static_cast(variable_count)}, options); // This tensor needs to be on the same device as the replica params because // backend such as NCCL may not support CPU tensors, and hence it might not // work if we always put it on CPU. options = options.device(params_[0].device()); - local_used_map_dev_ = - at::empty({static_cast(variable_count)}, options); + local_used_map_dev_ = at::empty({static_cast(variable_count)}, options); } void Reducer::check_grad_layout( @@ -401,8 +397,7 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) { logger_, "Encountered gradient which is undefined, but still allreduced by " "DDP reducer. This indicates a bug in DDP implementation, please " - "report a bug with a repro to PyTorch." - ); + "report a bug with a repro to PyTorch."); } bucket_view.zero_(); } @@ -708,8 +703,7 @@ void Reducer::checkAndRaiseMarkedTwiceError(size_t index) { auto param_name = param_names_.find(index); const bool found_param_name = param_name != param_names_.end(); TORCH_INTERNAL_ASSERT( - ddp_debug_level_ == c10d::DebugLevel::Off || - found_param_name, + ddp_debug_level_ == c10d::DebugLevel::Off || found_param_name, "Expected to find parameter name in debug mode."); std::string paramInfo = c10::str( "Parameter at index ", @@ -917,7 +911,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { } } -void Reducer::install_futures(c10::List> futs) { +void Reducer::install_futures( + c10::List> futs) { // Append instead of overwrite so that this method can be called multiple // times in one iteration. if (!installed_futures_) { @@ -927,7 +922,8 @@ void Reducer::install_futures(c10::List> } } -void Reducer::initialize_buckets(std::vector> bucket_indices) { +void Reducer::initialize_buckets( + std::vector> bucket_indices) { // If initialize_buckets is called inside DDP constructor, then // it does not matter rpc context ptr is nullptr or not, as grad // will not be mutated. @@ -970,8 +966,7 @@ void Reducer::initialize_buckets(std::vector> bucket_indices // Variables that expect sparse gradients must have their own bucket. if (bucket_indices[bucket_index].size() == 1) { const auto variable_index = bucket_indices[bucket_index].front(); - bucket.expect_sparse_gradient = - expect_sparse_gradients_[variable_index]; + bucket.expect_sparse_gradient = expect_sparse_gradients_[variable_index]; } else { for (const auto variable_index : bucket_indices[bucket_index]) { REDUCER_CHECK( @@ -1382,8 +1377,7 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) { // the point as below where we wait for the reduction work, make D2H // copy, and update global_unused with the real global consensus, i.e. // local_used_map_reduced_ is true. - global_unused = - local_used_map_[variable_index].item() == 0; + global_unused = local_used_map_[variable_index].item() == 0; if (global_unused && !local_used_map_reduced_) { // Wait for local_used_map reduction to complete. local_used_work_->wait(); @@ -1391,8 +1385,7 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) { // Blocking copy, if local_used_map_dev_ is cuda local_used_map_.copy_(local_used_map_dev_); - global_unused = - local_used_map_[variable_index].item() == 0; + global_unused = local_used_map_[variable_index].item() == 0; local_used_map_reduced_ = true; } } @@ -1403,8 +1396,7 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) { std::vector({variable})); copy_bucket_to_grad(variable, bucket, intra_bucket_index, global_unused); } else { - const auto& bucket_view_out = - bucket.bucket_views_out[intra_bucket_index]; + const auto& bucket_view_out = bucket.bucket_views_out[intra_bucket_index]; auto& bucket_view_in = bucket.bucket_views_in[intra_bucket_index]; // If a communication hook is registered, then `bucket_view_out` stores // the allreduced results in a newly allocated tensor, so we copy @@ -1483,8 +1475,8 @@ void Reducer::finalize_backward() { } if (installed_futures_ != c10::nullopt) { - c10::collectAll(*installed_futures_)->wait(); - installed_futures_ = c10::nullopt; + c10::collectAll(*installed_futures_)->wait(); + installed_futures_ = c10::nullopt; } // See Note [Skip allreducing local_used_maps_dev] @@ -1652,10 +1644,10 @@ bool Reducer::rebuild_buckets() { if (ddp_set_last_bucket_as_small) { // Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last - // bucket. We cannot simply pass in {bucket_bytes_cap_, first_bucket_bytes_cap} - // as the bucket order as we would immediately advance to the 2nd element - // after the first bucket, whereas we only want the last bucket to have - // a smaller size. + // bucket. We cannot simply pass in {bucket_bytes_cap_, + // first_bucket_bytes_cap} as the bucket order as we would immediately + // advance to the 2nd element after the first bucket, whereas we only want + // the last bucket to have a smaller size. std::reverse(rebuilt_params_.begin(), rebuilt_params_.end()); std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end()); } @@ -1679,8 +1671,7 @@ bool Reducer::rebuild_buckets() { rebuilt_bucket_indices.size() == per_bucket_size_limits.size()) LOG(INFO) << rebuilt_bucket_indices.size() << " buckets rebuilt with size limits: " - << c10::Join(", ", per_bucket_size_limits) - << " bytes."; + << c10::Join(", ", per_bucket_size_limits) << " bytes."; } // For rebuilt bucket indices, it needs to be synced across all ranks. @@ -1967,8 +1958,8 @@ compute_bucket_assignment_by_size( // be grouped together with other gradients and gets its own bucket. if (!expect_sparse_gradient.empty() && expect_sparse_gradient[tensor_index]) { - result.emplace_back(std::vector({tensor_index}), kNoSizeLimit); - continue; + result.emplace_back(std::vector({tensor_index}), kNoSizeLimit); + continue; } auto key = BucketKey(tensor.scalar_type(), tensor.device()); @@ -2040,13 +2031,12 @@ compute_bucket_assignment_by_size( return std::make_tuple(bucket_indices, per_bucket_size_limits); } -// Verifies corresponding params in the model replica have the same sizes/strides -// across processes. +// Verifies corresponding params in the model replica have the same +// sizes/strides across processes. void verify_params_across_processes( const c10::intrusive_ptr& process_group, const std::vector& params, const c10::optional>& logger) { - // First verify number of parameters to avoid inconsistent inputs into // broadcast which can cause a crash. // See https://github.com/pytorch/pytorch/issues/73547 @@ -2056,34 +2046,35 @@ void verify_params_across_processes( param_size_options = param_size_options.device(params[0].device()); // Note: Not using tensor building API because of // https://github.com/pytorch/pytorch/issues/74114 - at::Tensor param_size_tensor = at::tensor( - {static_cast(params.size())}, param_size_options); + at::Tensor param_size_tensor = + at::tensor({static_cast(params.size())}, param_size_options); // Allgather and verify parameter size. std::vector> param_size_output_tensors; param_size_output_tensors.emplace_back(std::vector{}); auto world_size = process_group->getSize(); - for (size_t i = 0 ; i < world_size ; ++i) { + for (size_t i = 0; i < world_size; ++i) { param_size_output_tensors.front().emplace_back( - at::empty_like(param_size_tensor) - ); + at::empty_like(param_size_tensor)); } std::vector param_size_vec{param_size_tensor}; process_group->allgather(param_size_output_tensors, param_size_vec)->wait(); auto result_size_tensors = param_size_output_tensors.front(); - for (size_t i = 0; i < world_size ; ++i ) { + for (size_t i = 0; i < world_size; ++i) { auto param_size_for_rank = result_size_tensors[i][0].item(); TORCH_CHECK( - param_size_for_rank == params.size(), - c10::str( - "DDP expects same model across all ranks, but Rank ", - process_group->getRank(), - " has ", params.size(), " params, while rank ", i, - " has inconsistent ", param_size_for_rank, - " params." - ) - ); + param_size_for_rank == params.size(), + c10::str( + "DDP expects same model across all ranks, but Rank ", + process_group->getRank(), + " has ", + params.size(), + " params, while rank ", + i, + " has inconsistent ", + param_size_for_rank, + " params.")); } // Continue with parameter shape verification. @@ -2121,23 +2112,29 @@ void verify_params_across_processes( for (const auto p : c10::irange(params.size())) { const auto& t = params[p]; for (const auto& sz : t.sizes()) { - auto msg = c10::str("[", process_group->getRank(), - "]: params[", p, "] in this process", - " with sizes ", - t.sizes(), - " appears not to match sizes of the same param in process 0."); + auto msg = c10::str( + "[", + process_group->getRank(), + "]: params[", + p, + "] in this process", + " with sizes ", + t.sizes(), + " appears not to match sizes of the same param in process 0."); if (logger.has_value()) { REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg) } else { TORCH_CHECK(sz == control_accessor[i++], msg) } - } for (const auto& str : t.strides()) { - auto msg = c10::str("params[", p, "] in this process", - " with sizes ", - t.sizes(), - " appears not to match strides of the same param in process 0."); + auto msg = c10::str( + "params[", + p, + "] in this process", + " with sizes ", + t.sizes(), + " appears not to match strides of the same param in process 0."); if (logger.has_value()) { REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg) } else { diff --git a/torch/csrc/distributed/c10d/reducer_cuda.cpp b/torch/csrc/distributed/c10d/reducer_cuda.cpp index a1c570da5d59ab..60aa450284f65b 100644 --- a/torch/csrc/distributed/c10d/reducer_cuda.cpp +++ b/torch/csrc/distributed/c10d/reducer_cuda.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include namespace c10d { namespace { @@ -19,8 +19,7 @@ class CudaTimer : public Timer { at::cuda::CUDAEvent(cudaEventDefault); at::cuda::CUDAEvent backward_comm_start = at::cuda::CUDAEvent(cudaEventDefault); - at::cuda::CUDAEvent backward_comm_end = - at::cuda::CUDAEvent(cudaEventDefault); + at::cuda::CUDAEvent backward_comm_end = at::cuda::CUDAEvent(cudaEventDefault); at::cuda::CUDAEvent& getEvent(Event event) { switch (event) { diff --git a/torch/csrc/distributed/c10d/sequence_num.cpp b/torch/csrc/distributed/c10d/sequence_num.cpp index 0caeaa4b701453..27b2e8012bd4df 100644 --- a/torch/csrc/distributed/c10d/sequence_num.cpp +++ b/torch/csrc/distributed/c10d/sequence_num.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index acd819ab631cda..b39d63a2eb90a8 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -46,12 +46,23 @@ const auto pollFd = ::WSAPoll; // Winsock's `getsockopt()` and `setsockopt()` functions expect option values to // be passed as `char*` instead of `void*`. We wrap them here to avoid redundant // casts in the source code. -int getSocketOption(SOCKET s, int level, int optname, void* optval, int* optlen) { +int getSocketOption( + SOCKET s, + int level, + int optname, + void* optval, + int* optlen) { return ::getsockopt(s, level, optname, static_cast(optval), optlen); } -int setSocketOption(SOCKET s, int level, int optname, const void* optval, int optlen) { - return ::setsockopt(s, level, optname, static_cast(optval), optlen); +int setSocketOption( + SOCKET s, + int level, + int optname, + const void* optval, + int optlen) { + return ::setsockopt( + s, level, optname, static_cast(optval), optlen); } // Winsock has its own error codes which differ from Berkeley's. Fortunately the @@ -123,8 +134,7 @@ class SocketImpl { static constexpr Handle invalid_socket = -1; #endif - explicit SocketImpl(Handle hnd) noexcept - : hnd_{hnd} {} + explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {} SocketImpl(const SocketImpl& other) = delete; @@ -185,19 +195,20 @@ struct formatter<::addrinfo> { decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) { char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT - int r = ::getnameinfo(addr.ai_addr, - addr.ai_addrlen, - host, - NI_MAXHOST, - port, - NI_MAXSERV, - NI_NUMERICSERV); + int r = ::getnameinfo( + addr.ai_addr, + addr.ai_addrlen, + host, + NI_MAXHOST, + port, + NI_MAXSERV, + NI_NUMERICSERV); if (r != 0) { return format_to(ctx.out(), "?UNKNOWN?"); } if (addr.ai_addr->sa_family == AF_INET) { - return format_to(ctx.out(), "{}:{}", host, port); + return format_to(ctx.out(), "{}:{}", host, port); } else { return format_to(ctx.out(), "[{}]:{}", host, port); } @@ -211,7 +222,9 @@ struct formatter { } template - decltype(auto) format(const c10d::detail::SocketImpl& socket, FormatContext& ctx) { + decltype(auto) format( + const c10d::detail::SocketImpl& socket, + FormatContext& ctx) { ::sockaddr_storage addr_s{}; auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s); @@ -259,9 +272,13 @@ std::unique_ptr SocketImpl::accept() const { std::string msg{}; if (err == std::errc::invalid_argument) { - msg = fmt::format("The server socket on {} is not listening for connections.", *this); + msg = fmt::format( + "The server socket on {} is not listening for connections.", *this); } else { - msg = fmt::format("The server socket on {} has failed to accept a connection {}.", *this, err); + msg = fmt::format( + "The server socket on {} has failed to accept a connection {}.", + *this, + err); } C10D_ERROR(msg); @@ -273,7 +290,10 @@ std::unique_ptr SocketImpl::accept() const { addr.ai_addr = addr_ptr; addr.ai_addrlen = addr_len; - C10D_DEBUG("The server socket on {} has accepted a connection from {}.", *this, addr); + C10D_DEBUG( + "The server socket on {} has accepted a connection from {}.", + *this, + addr); auto impl = std::make_unique(hnd); @@ -281,7 +301,9 @@ std::unique_ptr SocketImpl::accept() const { impl->closeOnExec(); if (!impl->enableNoDelay()) { - C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", addr); + C10D_WARNING( + "The no-delay option cannot be enabled for the client socket on {}.", + addr); } return impl; @@ -353,12 +375,13 @@ std::uint16_t SocketImpl::getPort() const { ::socklen_t addr_len = sizeof(addr_s); - if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != 0) { + if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != + 0) { throw SocketError{"The port number of the socket cannot be retrieved."}; } if (addr_s.ss_family == AF_INET) { - return ntohs(reinterpret_cast<::sockaddr_in*> (&addr_s)->sin_port); + return ntohs(reinterpret_cast<::sockaddr_in*>(&addr_s)->sin_port); } else { return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port); } @@ -424,13 +447,15 @@ std::unique_ptr SocketListenOp::run() { return std::move(socket_); } } else { - C10D_DEBUG("The server socket will attempt to listen on an IPv4 or IPv6 address."); + C10D_DEBUG( + "The server socket will attempt to listen on an IPv4 or IPv6 address."); if (tryListen(AF_UNSPEC)) { return std::move(socket_); } } - constexpr auto* msg = "The server socket has failed to listen on any local network address."; + constexpr auto* msg = + "The server socket has failed to listen on any local network address."; C10D_ERROR(msg); @@ -448,10 +473,13 @@ bool SocketListenOp::tryListen(int family) { if (r != 0) { const char* gai_err = ::gai_strerror(r); - recordError("The local {}network addresses cannot be retrieved (gai error: {} - {}).", - family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "", - r, - gai_err); + recordError( + "The local {}network addresses cannot be retrieved (gai error: {} - {}).", + family == AF_INET ? "IPv4 " + : family == AF_INET6 ? "IPv6 " + : "", + r, + gai_err); return false; } @@ -469,9 +497,13 @@ bool SocketListenOp::tryListen(int family) { } bool SocketListenOp::tryListen(const ::addrinfo& addr) { - SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); + SocketImpl::Handle hnd = + ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); if (hnd == SocketImpl::invalid_socket) { - recordError("The server socket cannot be initialized on {} {}.", addr, getSocketError()); + recordError( + "The server socket cannot be initialized on {} {}.", + addr, + getSocketError()); return false; } @@ -480,7 +512,9 @@ bool SocketListenOp::tryListen(const ::addrinfo& addr) { #ifndef _WIN32 if (!socket_->enableAddressReuse()) { - C10D_WARNING("The address reuse option cannot be enabled for the server socket on {}.", addr); + C10D_WARNING( + "The address reuse option cannot be enabled for the server socket on {}.", + addr); } #endif @@ -492,8 +526,9 @@ bool SocketListenOp::tryListen(const ::addrinfo& addr) { // Here we follow the recommendation of Microsoft and use the non-standard // SO_EXCLUSIVEADDRUSE flag instead. if (!socket_->enableExclusiveAddressUse()) { - C10D_WARNING("The exclusive address use option cannot be enabled for the server socket on {}.", - addr); + C10D_WARNING( + "The exclusive address use option cannot be enabled for the server socket on {}.", + addr); } #endif @@ -501,18 +536,25 @@ bool SocketListenOp::tryListen(const ::addrinfo& addr) { // wish to use our IPv6 socket for IPv4 communication as well, we explicitly // ask the system to enable it. if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) { - C10D_WARNING("The server socket does not support IPv4 communication on {}.", addr); + C10D_WARNING( + "The server socket does not support IPv4 communication on {}.", addr); } if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) { - recordError("The server socket has failed to bind to {} {}.", addr, getSocketError()); + recordError( + "The server socket has failed to bind to {} {}.", + addr, + getSocketError()); return false; } // NOLINTNEXTLINE(bugprone-argument-comment) if (::listen(socket_->handle(), /*backlog=*/2048) != 0) { - recordError("The server socket has failed to listen on {} {}.", addr, getSocketError()); + recordError( + "The server socket has failed to listen on {} {}.", + addr, + getSocketError()); return false; } @@ -531,14 +573,13 @@ class SocketConnectOp { static const std::chrono::seconds delay_duration_; - enum class ConnectResult { - Success, - Error, - Retry - }; + enum class ConnectResult { Success, Error, Retry }; public: - SocketConnectOp(const std::string& host, std::uint16_t port, const SocketOptions& opts); + SocketConnectOp( + const std::string& host, + std::uint16_t port, + const SocketOptions& opts); std::unique_ptr run(); @@ -570,32 +611,36 @@ class SocketConnectOp { const std::chrono::seconds SocketConnectOp::delay_duration_{1}; -SocketConnectOp::SocketConnectOp(const std::string& host, - std::uint16_t port, - const SocketOptions& opts) +SocketConnectOp::SocketConnectOp( + const std::string& host, + std::uint16_t port, + const SocketOptions& opts) : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {} std::unique_ptr SocketConnectOp::run() { if (opts_->prefer_ipv6()) { - C10D_DEBUG("The client socket will attempt to connect to an IPv6 address of ({}, {}).", - host_, - port_); + C10D_DEBUG( + "The client socket will attempt to connect to an IPv6 address of ({}, {}).", + host_, + port_); if (tryConnect(AF_INET6)) { return std::move(socket_); } - C10D_DEBUG("The client socket will attempt to connect to an IPv4 address of ({}, {}).", - host_, - port_); + C10D_DEBUG( + "The client socket will attempt to connect to an IPv4 address of ({}, {}).", + host_, + port_); if (tryConnect(AF_INET)) { return std::move(socket_); } } else { - C10D_DEBUG("The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).", - host_, - port_); + C10D_DEBUG( + "The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).", + host_, + port_); if (tryConnect(AF_UNSPEC)) { return std::move(socket_); @@ -628,23 +673,27 @@ bool SocketConnectOp::tryConnect(int family) { errors_.clear(); - ::addrinfo *naked_result = nullptr; + ::addrinfo* naked_result = nullptr; // patternlint-disable cpp-dns-deps int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result); if (r != 0) { const char* gai_err = ::gai_strerror(r); - recordError("The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).", - family == AF_INET ? "IPv4 " : family == AF_INET6 ? "IPv6 " : "", - host_, - port_, - r, - gai_err); + recordError( + "The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).", + family == AF_INET ? "IPv4 " + : family == AF_INET6 ? "IPv6 " + : "", + host_, + port_, + r, + gai_err); retry = true; } else { addrinfo_ptr result{naked_result}; - for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) { + for (::addrinfo* addr = naked_result; addr != nullptr; + addr = addr->ai_next) { C10D_TRACE("The client socket is attempting to connect to {}.", *addr); ConnectResult cr = tryConnect(*addr); @@ -662,7 +711,10 @@ bool SocketConnectOp::tryConnect(int family) { if (Clock::now() < deadline_ - delay_duration_) { // Prevent our log output to be too noisy, warn only every 30 seconds. if (retry_attempt == 30) { - C10D_INFO("No socket on ({}, {}) is listening yet, will retry.", host_, port_); + C10D_INFO( + "No socket on ({}, {}) is listening yet, will retry.", + host_, + port_); retry_attempt = 0; } @@ -680,16 +732,19 @@ bool SocketConnectOp::tryConnect(int family) { return false; } -SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(const ::addrinfo& addr) { +SocketConnectOp::ConnectResult SocketConnectOp::tryConnect( + const ::addrinfo& addr) { if (Clock::now() >= deadline_) { throwTimeoutError(); } - SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); + SocketImpl::Handle hnd = + ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); if (hnd == SocketImpl::invalid_socket) { - recordError("The client socket cannot be initialized to connect to {} {}.", - addr, - getSocketError()); + recordError( + "The client socket cannot be initialized to connect to {} {}.", + addr, + getSocketError()); return ConnectResult::Error; } @@ -706,12 +761,17 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(const ::addrinfo& add } // Retry if the server is not yet listening or if its backlog is exhausted. - if (err == std::errc::connection_refused || err == std::errc::connection_reset) { - C10D_TRACE("The server socket on {} is not yet listening {}, will retry.", addr, err); + if (err == std::errc::connection_refused || + err == std::errc::connection_reset) { + C10D_TRACE( + "The server socket on {} is not yet listening {}, will retry.", + addr, + err); return ConnectResult::Retry; } else { - recordError("The client socket has failed to connect to {} {}.", addr, err); + recordError( + "The client socket has failed to connect to {} {}.", addr, err); return ConnectResult::Error; } @@ -725,13 +785,16 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(const ::addrinfo& add C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_); if (!socket_->enableNoDelay()) { - C10D_WARNING("The no-delay option cannot be enabled for the client socket on {}.", *socket_); + C10D_WARNING( + "The no-delay option cannot be enabled for the client socket on {}.", + *socket_); } return ConnectResult::Success; } -SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(const ::addrinfo& addr) { +SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore( + const ::addrinfo& addr) { int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen); if (r == 0) { return ConnectResult::Success; @@ -742,7 +805,8 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(const ::addrinfo& return ConnectResult::Success; } - if (err != std::errc::operation_in_progress && err != std::errc::operation_would_block) { + if (err != std::errc::operation_in_progress && + err != std::errc::operation_would_block) { return ConnectResult::Error; } @@ -769,7 +833,8 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(const ::addrinfo& ::socklen_t err_len = sizeof(int); - r = getSocketOption(socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len); + r = getSocketOption( + socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len); if (r != 0) { return ConnectResult::Error; } @@ -818,7 +883,10 @@ Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) { return Socket{op.run()}; } -Socket Socket::connect(const std::string& host, std::uint16_t port, const SocketOptions& opts) { +Socket Socket::connect( + const std::string& host, + std::uint16_t port, + const SocketOptions& opts) { SocketConnectOp op{host, port, opts}; return Socket{op.run()}; diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h index c26900760fbe0e..2bbb5f55353745 100644 --- a/torch/csrc/distributed/c10d/socket.h +++ b/torch/csrc/distributed/c10d/socket.h @@ -39,7 +39,7 @@ class SocketOptions { return connect_timeout_; } -private: + private: bool prefer_ipv6_ = true; std::chrono::seconds connect_timeout_{30}; }; @@ -54,7 +54,10 @@ class Socket { static Socket listen(std::uint16_t port, const SocketOptions& opts = {}); - static Socket connect(const std::string& host, std::uint16_t port, const SocketOptions& opts = {}); + static Socket connect( + const std::string& host, + std::uint16_t port, + const SocketOptions& opts = {}); Socket() noexcept = default; diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 9285558cb275c4..451ec4b7598ce8 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -7,11 +7,11 @@ namespace rpc { RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() { // WorkerInfo needs to be registered exactly once. Since the op registration - // happens in libtorch_python we wrap the class registration in a helper to make - // sure that if there's multiple copies of Python such as used in torch::deploy - // we only ever register it once. + // happens in libtorch_python we wrap the class registration in a helper to + // make sure that if there's multiple copies of Python such as used in + // torch::deploy we only ever register it once. static auto workerInfo = torch::class_("dist_rpc", "WorkerInfo") - .def(torch::init()); + .def(torch::init()); } constexpr size_t WorkerInfo::MAX_NAME_LEN; diff --git a/torch/csrc/generic/utils.h b/torch/csrc/generic/utils.h index 7ec4d686013f9d..538b265a226e2f 100644 --- a/torch/csrc/generic/utils.h +++ b/torch/csrc/generic/utils.h @@ -11,20 +11,19 @@ struct THPStorage; struct THSPTensor; -typedef class THPPointer THPStoragePtr; +typedef class THPPointer THPStoragePtr; -#if (!defined(THC_GENERIC_FILE)) && \ - (!defined(THQUANTIZED)) -template<> +#if (!defined(THC_GENERIC_FILE)) && (!defined(THQUANTIZED)) +template <> struct THPUtils_typeTraits { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \ +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \ defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || \ defined(THC_REAL_IS_HALF) - static constexpr const char *python_type_str = "float"; + static constexpr const char* python_type_str = "float"; #elif defined(TH_REAL_IS_COMPLEX) || defined(THC_REAL_IS_COMPLEX) - static constexpr const char *python_type_str = "complex"; + static constexpr const char* python_type_str = "complex"; #else - static constexpr const char *python_type_str = "int"; + static constexpr const char* python_type_str = "int"; #endif }; #endif diff --git a/torch/csrc/lazy/backend/backend_data.h b/torch/csrc/lazy/backend/backend_data.h index 15efd7a71b0917..496ecbdbbc6c53 100644 --- a/torch/csrc/lazy/backend/backend_data.h +++ b/torch/csrc/lazy/backend/backend_data.h @@ -1,8 +1,8 @@ #pragma once -#include -#include #include +#include +#include namespace torch { namespace lazy { diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index 8445fd74c9740d..13a67bfe481aea 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -2,21 +2,22 @@ #include #include +#include #include -#include #include -#include +#include namespace torch { namespace lazy { // TODO(alanwaketan): Use the backend API to get the default device type. // In the future, we should also get the default device ordinal. -BackendDevice::BackendDevice() - : type_(std::make_shared()) {} +BackendDevice::BackendDevice() : type_(std::make_shared()) {} -BackendDevice::BackendDevice(std::shared_ptr&& type, int64_t ordinal) - : type_(std::move(type)), ordinal_(ordinal) {} +BackendDevice::BackendDevice( + std::shared_ptr&& type, + int64_t ordinal) + : type_(std::move(type)), ordinal_(ordinal) {} int8_t BackendDevice::type() const { TORCH_INTERNAL_ASSERT(type_); @@ -40,7 +41,8 @@ std::ostream& operator<<(std::ostream& os, const BackendDevice& device) { return os; } -// TODO(whc) refactor this: we need to support non-zero default ordinal for torch/XLA. +// TODO(whc) refactor this: we need to support non-zero default ordinal for +// torch/XLA. BackendDevice atenDeviceToBackendDevice(const c10::Device& device) { TORCH_CHECK(device.type() == at::kLazy, device); int64_t ordinal = device.has_index() ? device.index() : 0; @@ -53,7 +55,7 @@ c10::Device backendDeviceToAtenDevice(const BackendDevice& device) { } c10::optional GetBackendDevice(const at::TensorList tensors) { - for (auto& tensor: tensors) { + for (auto& tensor : tensors) { if (auto lt = TryGetLtcTensor(tensor)) { return lt->GetDevice(); } @@ -68,7 +70,8 @@ c10::optional GetBackendDevice(const at::Tensor& tensor) { return c10::nullopt; } -c10::optional GetBackendDevice(const c10::optional device) { +c10::optional GetBackendDevice( + const c10::optional device) { if (device) { return c10::make_optional(atenDeviceToBackendDevice(*device)); } @@ -79,5 +82,5 @@ c10::optional GetBackendDevice() { return c10::nullopt; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index 818d0f2c10db27..fc911a859fe177 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include @@ -18,14 +18,17 @@ namespace lazy { // Backend should extend it and define their own supported hardware types. struct TORCH_API BackendDeviceType { - int8_t type {(int8_t)at::kCPU}; - // Note: previous default value was '0', which actually maps to at::kCPU, at least now it is explicit, - // we may want to make default/undefined semantics more clear though - BackendDeviceType() :type((int8_t)at::kCPU) {} - BackendDeviceType(int8_t type) :type(type) {} + int8_t type{(int8_t)at::kCPU}; + // Note: previous default value was '0', which actually maps to at::kCPU, at + // least now it is explicit, we may want to make default/undefined semantics + // more clear though + BackendDeviceType() : type((int8_t)at::kCPU) {} + BackendDeviceType(int8_t type) : type(type) {} virtual ~BackendDeviceType() = default; - virtual std::string toString() const { return "Unknown"; } + virtual std::string toString() const { + return "Unknown"; + } }; class TORCH_API BackendDevice { @@ -36,11 +39,19 @@ class TORCH_API BackendDevice { BackendDevice(std::shared_ptr&& type, int64_t ordinal); int8_t type() const; - int64_t ordinal() const { return ordinal_; } - - bool operator==(const BackendDevice& other) const { return compare(other) == 0; } - bool operator!=(const BackendDevice& other) const { return compare(other) != 0; } - bool operator<(const BackendDevice& rhs) const { return compare(rhs) < 0; } + int64_t ordinal() const { + return ordinal_; + } + + bool operator==(const BackendDevice& other) const { + return compare(other) == 0; + } + bool operator!=(const BackendDevice& other) const { + return compare(other) != 0; + } + bool operator<(const BackendDevice& rhs) const { + return compare(rhs) < 0; + } std::string toString() const; @@ -49,32 +60,39 @@ class TORCH_API BackendDevice { // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied. std::shared_ptr type_; - int64_t ordinal_ {0}; + int64_t ordinal_{0}; }; -TORCH_API std::ostream& operator<<(std::ostream& os, const BackendDevice& device); +TORCH_API std::ostream& operator<<( + std::ostream& os, + const BackendDevice& device); // Helpers for converting a c10::Device to BackendDevice and vice versa. TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device); TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device); -// Tries to extract the backend device out of the lazy tensor. Returns nullopt if the -// input is not a lazy tensor. -TORCH_API c10::optional GetBackendDevice(const at::TensorList tensors); -TORCH_API c10::optional GetBackendDevice(const at::Tensor& tensor); -TORCH_API c10::optional GetBackendDevice(const c10::optional device); +// Tries to extract the backend device out of the lazy tensor. Returns nullopt +// if the input is not a lazy tensor. +TORCH_API c10::optional GetBackendDevice( + const at::TensorList tensors); +TORCH_API c10::optional GetBackendDevice( + const at::Tensor& tensor); +TORCH_API c10::optional GetBackendDevice( + const c10::optional device); // For variadic template. TORCH_API c10::optional GetBackendDevice(); -template -c10::optional GetBackendDevice(const T& tensor, const Args&... forward_tensors) { - auto optional_device = GetBackendDevice(tensor); - if (optional_device) { - return optional_device; - } - return GetBackendDevice(forward_tensors...); +template +c10::optional GetBackendDevice( + const T& tensor, + const Args&... forward_tensors) { + auto optional_device = GetBackendDevice(tensor); + if (optional_device) { + return optional_device; + } + return GetBackendDevice(forward_tensors...); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/backend/backend_interface.cpp b/torch/csrc/lazy/backend/backend_interface.cpp index d754bafefa4bbc..48e375051d046e 100644 --- a/torch/csrc/lazy/backend/backend_interface.cpp +++ b/torch/csrc/lazy/backend/backend_interface.cpp @@ -28,7 +28,6 @@ const IrBuilder* getIrBuilder() { return builder; } - at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, c10::optional logical_scalar_type) { @@ -50,5 +49,5 @@ std::unique_ptr LoweringContext::Create( return getBackend()->CreateLoweringContext(name, device); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/backend/backend_interface.h b/torch/csrc/lazy/backend/backend_interface.h index 097b01969457f0..23eaad3ffb5a66 100644 --- a/torch/csrc/lazy/backend/backend_interface.h +++ b/torch/csrc/lazy/backend/backend_interface.h @@ -1,11 +1,11 @@ #pragma once -#include #include #include #include #include #include +#include namespace torch { namespace lazy { @@ -45,15 +45,18 @@ class TORCH_API BackendImplInterface { * */ virtual BackendDataPtr MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, + const at::Tensor& tensor, + const Shape& shape, const BackendDevice& device) const = 0; virtual BackendDataPtr MakeComputationDataFromScalar( const at::Scalar& scalar, const torch::lazy::BackendDevice& device) const = 0; virtual BackendDataPtr CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const = 0; + const BackendDevice& device, + const Shape& shape) const = 0; - // Gets backend data if the node is a device data node. Otherwise returns nullptr + // Gets backend data if the node is a device data node. Otherwise returns + // nullptr virtual BackendDataPtr GetComputationDataFromNode(Node*) const = 0; virtual at::Tensor MakeTensorFromComputationData( @@ -65,22 +68,26 @@ class TORCH_API BackendImplInterface { * */ virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device, + const std::string& name, + BackendDevice device, c10::ArrayRef post_order, Util::EmissionMap emit_status) const = 0; virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device) const = 0; + const std::string& name, + BackendDevice device) const = 0; // TODO(whc) need to keep this? virtual std::vector GetCompilationDevices( - const std::string& device, c10::ArrayRef devices) const = 0; + const std::string& device, + c10::ArrayRef devices) const = 0; virtual std::vector Compile( std::vector instances) const = 0; virtual std::vector ExecuteComputation( - Computation& computation, c10::ArrayRef arguments, + Computation& computation, + c10::ArrayRef arguments, const BackendDevice& device) const = 0; /** @@ -91,15 +98,13 @@ class TORCH_API BackendImplInterface { // For backends used with virtual c10:: Devices, this configures what real // device type the backend should use, and matters if the backend supports // more than one type of real device. - virtual std::shared_ptr - GetDefaultDeviceType() const = 0; + virtual std::shared_ptr GetDefaultDeviceType() const = 0; virtual void SetDefaultDeviceType(std::string) = 0; // Specify which aten device should be used for eager fallback // may change depending on current 'Default' DeviceType virtual at::DeviceType EagerFallbackDeviceType() const = 0; - // Query all available backend devices virtual std::vector GetBackendDevices() const = 0; @@ -137,5 +142,5 @@ TORCH_API const BackendImplInterface* getBackend(); TORCH_API const IrBuilder* getIrBuilder(); -} // lazy -} // torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/backend/lowering_context.cpp b/torch/csrc/lazy/backend/lowering_context.cpp index c8c46697794009..64922a1b3e136f 100644 --- a/torch/csrc/lazy/backend/lowering_context.cpp +++ b/torch/csrc/lazy/backend/lowering_context.cpp @@ -6,14 +6,16 @@ namespace lazy { LoweringContext::LoweringContext(const std::string& name, BackendDevice device) : device_(std::move(device)) {} -LoweringContext::LoweringContext(const std::string& name, BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status) +LoweringContext::LoweringContext( + const std::string& name, + BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) : device_(std::move(device)), emit_status_(std::move(emit_status)) {} const std::vector& LoweringContext::GetParametersData() const { return parameters_; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/backend/lowering_context.h b/torch/csrc/lazy/backend/lowering_context.h index ca20eb696b015e..6f487aef7f741b 100644 --- a/torch/csrc/lazy/backend/lowering_context.h +++ b/torch/csrc/lazy/backend/lowering_context.h @@ -16,7 +16,7 @@ namespace lazy { class TORCH_API Computation { public: - virtual int parameters_size() const = 0; + virtual int parameters_size() const = 0; virtual const std::vector& parameter_shapes() const = 0; @@ -39,26 +39,31 @@ using ComputationPtr = std::shared_ptr; class TORCH_API LoweringContext { public: LoweringContext(const std::string& name, BackendDevice device); - LoweringContext(const std::string& name, BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status); + LoweringContext( + const std::string& name, + BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status); virtual ~LoweringContext() = default; static std::unique_ptr Create( - const std::string& name, BackendDevice device, + const std::string& name, + BackendDevice device, c10::ArrayRef post_order, Util::EmissionMap emit_status); - static std::unique_ptr Create(const std::string& name, - BackendDevice device); + static std::unique_ptr Create( + const std::string& name, + BackendDevice device); - const BackendDevice& device() const { return device_; }; + const BackendDevice& device() const { + return device_; + }; // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. - const std::vector& - GetParametersData() const; + const std::vector& GetParametersData() const; // Adds a new input/output alias. virtual void SetUpAlias( @@ -84,15 +89,19 @@ class TORCH_API LoweringContext { // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. - virtual void AddParameter(const torch::lazy::Output& output, size_t index, - const Shape& shape, - const std::string& name) = 0; + virtual void AddParameter( + const torch::lazy::Output& output, + size_t index, + const Shape& shape, + const std::string& name) = 0; // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). virtual ComputationPtr Build() = 0; - size_t GetEmittedNodeCount() const { return emit_status_.size(); } + size_t GetEmittedNodeCount() const { + return emit_status_.size(); + } protected: BackendDevice device_; @@ -101,5 +110,5 @@ class TORCH_API LoweringContext { Util::EmissionMap emit_status_; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/config.cpp b/torch/csrc/lazy/core/config.cpp index 0898fd1a3afe1c..b3488c4df79ac3 100644 --- a/torch/csrc/lazy/core/config.cpp +++ b/torch/csrc/lazy/core/config.cpp @@ -61,21 +61,21 @@ C10_DEFINE_int( 4096, "Set the size for the shape cache used for shape inference"); - namespace torch { namespace lazy { std::string& getLTCForceFallback() { - static std::string config; - static bool _ignore = [&]() { - char *envptr = std::getenv("LTC_FORCE_FALLBACK"); - if (envptr) { - config = std::string(envptr); - } - return true; - }(); - (void) _ignore; // avoid unused variables warning - return config; + static std::string config; + static bool _ignore = [&]() { + char* envptr = std::getenv("LTC_FORCE_FALLBACK"); + if (envptr) { + config = std::string(envptr); + } + return true; + }(); + (void)_ignore; // avoid unused variables warning + return config; } -} } +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/config.h b/torch/csrc/lazy/core/config.h index a73b123a9dd42d..fa481421c2aab9 100644 --- a/torch/csrc/lazy/core/config.h +++ b/torch/csrc/lazy/core/config.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include C10_DECLARE_bool(torch_lazy_ir_debug); C10_DECLARE_bool(torch_lazy_handle_special_scalars); @@ -22,4 +22,5 @@ C10_DECLARE_int(torch_lazy_shape_cache_size); namespace torch { namespace lazy { TORCH_API std::string& getLTCForceFallback(); -} } +} +} // namespace torch diff --git a/torch/csrc/lazy/core/debug_util.cpp b/torch/csrc/lazy/core/debug_util.cpp index f64fe476592015..50f42b718128e0 100644 --- a/torch/csrc/lazy/core/debug_util.cpp +++ b/torch/csrc/lazy/core/debug_util.cpp @@ -15,7 +15,7 @@ namespace torch { namespace lazy { -namespace { +namespace { std::string GetEnvString(const char* name, const std::string& defval) { const char* env = std::getenv(name); @@ -47,9 +47,9 @@ std::unordered_set* LoadExperiments() { return xset.release(); } -} // namespace +} // namespace -std::vector NoPythonFrames(){ +std::vector NoPythonFrames() { SourceLocation dummy_loc; dummy_loc.file = "No Python Frames"; return {dummy_loc}; @@ -66,7 +66,6 @@ DebugUtil::GraphFormat DebugUtil::GetDefaultGraphFormat() { } std::string GetFirstUserFrameInPython() { - std::string empty; if (!torch::lazy::GetPythonFramesFunction()) { return empty; @@ -78,18 +77,17 @@ std::string GetFirstUserFrameInPython() { auto& loc = frames[i - 1]; if (loc.file.find("site-packages") == std::string::npos) { std::stringstream ss; - ss << loc.file << " " - << loc.function << " " - << loc.line; + ss << loc.file << " " << loc.function << " " << loc.line; return ss.str(); } } return empty; } -std::string DebugUtil::GetTensorsGraphInfo(c10::ArrayRef tensors, - const std::vector* indices, - GraphFormat format) { +std::string DebugUtil::GetTensorsGraphInfo( + c10::ArrayRef tensors, + const std::vector* indices, + GraphFormat format) { std::vector root_nodes; std::vector root_values; std::vector root_hashes; @@ -117,7 +115,8 @@ std::string DebugUtil::GetTensorsGraphInfo(c10::ArrayRef frames = GetPythonFramesFunction()(); ss << "Python Stacktrace:\n"; for (auto& location : frames) { @@ -149,11 +148,13 @@ std::string DebugUtil::GetTensorsGraphInfo(c10::ArrayRef tensors, - const std::vector* indices, - GraphFormat format) { - static const std::string save_file = GetEnvString("LTC_SAVE_TENSORS_FILE", ""); +void DebugUtil::SaveTensorsGraphInfo( + const char* name, + c10::ArrayRef tensors, + const std::vector* indices, + GraphFormat format) { + static const std::string save_file = + GetEnvString("LTC_SAVE_TENSORS_FILE", ""); if (!save_file.empty()) { static std::mutex lock; std::string info = GetTensorsGraphInfo(tensors, indices, format); @@ -168,5 +169,5 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) { return xset->find(name) != xset->end(); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/debug_util.h b/torch/csrc/lazy/core/debug_util.h index 62cdcf98c641fc..50ba1dae8c9ed3 100644 --- a/torch/csrc/lazy/core/debug_util.h +++ b/torch/csrc/lazy/core/debug_util.h @@ -9,7 +9,8 @@ namespace torch { namespace lazy { -TORCH_API std::function()>& GetPythonFramesFunction(); +TORCH_API std::function()>& +GetPythonFramesFunction(); TORCH_API std::string GetFirstUserFrameInPython(); @@ -27,20 +28,21 @@ class TORCH_API DebugUtil { // values held at the tensors. If indices is not nullptr, it selects the // indices of the tensors whose graph will be emitted. static std::string GetTensorsGraphInfo( - c10::ArrayRef tensors, const std::vector* indices, + c10::ArrayRef tensors, + const std::vector* indices, GraphFormat format = GetDefaultGraphFormat()); // If the environment variable LTC_SAVE_TENSORS_FILE is set to the proper // output path, an instance of the report returned by GetTensorsGraphInfo() is // saved. static void SaveTensorsGraphInfo( - const char* name, c10::ArrayRef tensors, + const char* name, + c10::ArrayRef tensors, const std::vector* indices, GraphFormat format = GetDefaultGraphFormat()); static bool ExperimentEnabled(const std::string& name); }; - -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/hash.cpp b/torch/csrc/lazy/core/hash.cpp index 42dac8d09bf320..306c76e32a629b 100644 --- a/torch/csrc/lazy/core/hash.cpp +++ b/torch/csrc/lazy/core/hash.cpp @@ -23,8 +23,8 @@ hash_t LoadHash(const uint8_t** data, const uint8_t* top) { hash_t h; std::array b; #ifdef _MSC_VER - // MSVC (or some versions we use) doesn't support C99 union field init - // but it initializes the first member of the union. + // MSVC (or some versions we use) doesn't support C99 union field init + // but it initializes the first member of the union. } uval = {hash_t(0)}; #else } uval = {.h = hash_t(0)}; @@ -61,7 +61,8 @@ hash_t HashBlock(const void* data, size_t n, const hash_t& seed) { } hash_t DataHash(const void* data, size_t size) { - return HashBlock(data, size, hash_t(static_cast(0xc2b2ae3d27d4eb4f))); + return HashBlock( + data, size, hash_t(static_cast(0xc2b2ae3d27d4eb4f))); } size_t StdDataHash(const void* data, size_t size) { @@ -70,12 +71,13 @@ size_t StdDataHash(const void* data, size_t size) { size_t StdHashCombine(uintmax_t a, uintmax_t b) { return a ^ - (b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); + (b * 0x27d4eb2f165667c5 + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)); } hash_t HashCombine(const hash_t& a, const hash_t& b) { static const hash_t kb(101, 0x27d4eb2f165667c5); - return hash_t(a ^ (b * kb + (uint64_t)0x9e3779b97f4a7c15 + (a << 6) + (a >> 2))); + return hash_t( + a ^ (b * kb + (uint64_t)0x9e3779b97f4a7c15 + (a << 6) + (a >> 2))); } size_t HashReduce(const hash_t& a) { @@ -89,15 +91,15 @@ std::string HashToString(const hash_t& a) { return ss.str(); } -hash_t Hash(const std::vector& values){ +hash_t Hash(const std::vector& values) { // We can't assume a DataHash size/dataptr approach here bc // vector can be optimized as vector and storage details // are decoupled from actual size of 'bool' type hash_t h(static_cast(0xad2ed1983bbf2e28)); static const hash_t h_true(static_cast(0x74f6b5198daa2b2)); static const hash_t h_false(static_cast(0xe39f30789cab5382)); - for (const auto& b: values) { - if(b) { + for (const auto& b : values) { + if (b) { h = HashCombine(h, h_true); } else { h = HashCombine(h, h_false); @@ -106,6 +108,5 @@ hash_t Hash(const std::vector& values){ return h; } - } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/core/hash.h b/torch/csrc/lazy/core/hash.h index 0e871bec2d45f3..402600ea794b20 100644 --- a/torch/csrc/lazy/core/hash.h +++ b/torch/csrc/lazy/core/hash.h @@ -4,14 +4,14 @@ */ #pragma once -#include -#include -#include -#include #include #include #include #include +#include +#include +#include +#include namespace torch { namespace lazy { @@ -93,17 +93,17 @@ static inline hash_t Hash(const c10::Layout& value) { } static inline hash_t Hash(const c10::Scalar& value) { - switch(value.type()){ - case c10::ScalarType::ComplexDouble: - return Hash(value.toComplexDouble()); - case c10::ScalarType::Double: - return Hash(value.toDouble()); - case c10::ScalarType::Long: - return Hash(value.toLong()); - case c10::ScalarType::Bool: - return Hash(value.toBool()); - default: - TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type()); + switch (value.type()) { + case c10::ScalarType::ComplexDouble: + return Hash(value.toComplexDouble()); + case c10::ScalarType::Double: + return Hash(value.toDouble()); + case c10::ScalarType::Long: + return Hash(value.toLong()); + case c10::ScalarType::Bool: + return Hash(value.toBool()); + default: + TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type()); } } diff --git a/torch/csrc/lazy/core/ir.cpp b/torch/csrc/lazy/core/ir.cpp index 2a884b6a7c318c..62a0eeadf3b101 100644 --- a/torch/csrc/lazy/core/ir.cpp +++ b/torch/csrc/lazy/core/ir.cpp @@ -5,7 +5,10 @@ #include // Enables caching on for dynamic shapes (aka disable hash on shapes) -C10_DEFINE_bool(ltc_enable_dynamic_shapes, false, "Whether dynamic shape is enabled"); +C10_DEFINE_bool( + ltc_enable_dynamic_shapes, + false, + "Whether dynamic shape is enabled"); namespace torch { namespace lazy { @@ -31,8 +34,8 @@ std::string Output::ToString() const { bool Output::operator==(const Value& rhs) const { // Either side could be kNullValue which has node as nullptr - return (!node == !rhs.node) && (!node || - (node->hash() == rhs.node->hash() && index == rhs.index)); + return (!node == !rhs.node) && + (!node || (node->hash() == rhs.node->hash() && index == rhs.index)); } hash_t Value::hash() const { @@ -57,18 +60,19 @@ bool Node::enableDynamicShape() { } Node::Node(OpKind op, size_t num_outputs) - : op_(op), - num_outputs_(num_outputs), - metadata_(GetMetaDataIfDebugging()) {} + : op_(op), num_outputs_(num_outputs), metadata_(GetMetaDataIfDebugging()) {} -Node::Node(OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs) +Node::Node( + OpKind op, + OpList operands, + std::vector&& shapes, + size_t num_outputs) : Node(op, num_outputs) { - // Move shapes into node shapes_.insert( - shapes_.end(), - std::make_move_iterator(shapes.begin()), - std::make_move_iterator(shapes.end())); + shapes_.end(), + std::make_move_iterator(shapes.begin()), + std::make_move_iterator(shapes.end())); for (auto& operand : operands) { // Ideally, optional operands should be filtered by the leaf node classes, @@ -83,7 +87,11 @@ Node::Node(OpKind op, OpList operands, std::vector&& shapes, size_t num_o } } -Node::Node(OpKind op, OpList operands, const std::function& shape_fn, size_t num_outputs) +Node::Node( + OpKind op, + OpList operands, + const std::function& shape_fn, + size_t num_outputs) : Node(op, operands, std::vector{}, num_outputs) { addComputedShape(shape_fn); } @@ -91,15 +99,16 @@ Node::Node(OpKind op, OpList operands, const std::function& shape_fn, s Node::Node(OpKind op, OpList operands, size_t num_outputs) : Node(op, operands, std::vector{}, num_outputs) {} -Node::Node(OpKind op, Shape shape, size_t num_outputs) - : Node(op, num_outputs) { +Node::Node(OpKind op, Shape shape, size_t num_outputs) : Node(op, num_outputs) { shapes_.push_back(std::move(shape)); } Node::~Node() = default; // Retrieves the full shape of the IR Node. -c10::ArrayRef Node::shapes() const { return shapes_; } +c10::ArrayRef Node::shapes() const { + return shapes_; +} // Retrieves the shape of the output at a given index. const Shape& Node::shape(size_t output_index) const { @@ -152,6 +161,5 @@ void Node::AddOperand(NodePtr node, size_t index) { operands_as_outputs_.emplace_back(operands_.back().get(), index); } - } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/core/ir.h b/torch/csrc/lazy/core/ir.h index e583bfb2fe342f..2e3b6891fdb83f 100644 --- a/torch/csrc/lazy/core/ir.h +++ b/torch/csrc/lazy/core/ir.h @@ -13,10 +13,10 @@ #include #include +#include #include -#include #include -#include +#include C10_DECLARE_bool(ltc_enable_dynamic_shapes); @@ -67,7 +67,10 @@ inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) { using OpList = c10::ArrayRef; -hash_t OperandHashes(const OpList& operands, const hash_t& seed, bool bakeInSizes); +hash_t OperandHashes( + const OpList& operands, + const hash_t& seed, + bool bakeInSizes); // A node in the graph. Nodes for operations which require extra data to be // stored for lowering should inherit from this class and add an operation // specific member there. For example, a constant might create a new @@ -87,10 +90,18 @@ class TORCH_API Node { Node(OpKind op, size_t num_outputs); // Construct node with operands and shapes - Node(OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs = 1); + Node( + OpKind op, + OpList operands, + std::vector&& shapes, + size_t num_outputs = 1); // Construct node with operands and shape generated from a function - Node(OpKind op, OpList operands, const std::function& shape_fn, size_t num_outputs = 1); + Node( + OpKind op, + OpList operands, + const std::function& shape_fn, + size_t num_outputs = 1); // Construct node with operands and no shape Node(OpKind op, OpList operands, size_t num_outputs = 1); @@ -156,7 +167,7 @@ class TORCH_API Node { // from UserMetaData. std::shared_ptr user_metadata_; -protected: + protected: // Adds node's index output number as operand. void AddOperand(NodePtr node, size_t index = 0); @@ -168,8 +179,6 @@ class TORCH_API Node { std::vector operands_as_outputs_; }; - - inline std::ostream& operator<<(std::ostream& stream, const Node& node) { stream << node.ToString(); return stream; @@ -200,7 +209,6 @@ const T* NodeCast(const Node* node) { return dynamic_cast(node); } - // Represents a specific output produced by a node. Since the output of a node // can be composed by multiple outputs, the node+index coordinates fully qualify // each single output. @@ -250,8 +258,10 @@ using OutputMap = std::unordered_map; // Represents an input/operand for a Node object. struct TORCH_API Value { Value() = default; - /* implicit */ Value(NodePtr&& node, size_t index = 0) : node(std::move(node)), index(index) {} - /* implicit */ Value(const NodePtr& node, size_t index = 0) : node(node), index(index) {} + /* implicit */ Value(NodePtr&& node, size_t index = 0) + : node(std::move(node)), index(index) {} + /* implicit */ Value(const NodePtr& node, size_t index = 0) + : node(node), index(index) {} hash_t hash() const; hash_t shapeHash() const; @@ -280,6 +290,6 @@ struct TORCH_API Value { } // namespace torch namespace c10 { - // Explicit template instantiation to make ArrayRef work - template class at::ArrayRef; -} +// Explicit template instantiation to make ArrayRef work +template class at::ArrayRef; +} // namespace c10 diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 8cc6de02dba45c..2245b97b98fee9 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -8,8 +8,9 @@ #include #include -// This file is part of the backend interface. So, ops shouldn't be added or removed without due process -// The exception to this being the view ops which will be removed soon pending functionalization +// This file is part of the backend interface. So, ops shouldn't be added or +// removed without due process The exception to this being the view ops which +// will be removed soon pending functionalization namespace torch { namespace lazy { @@ -34,7 +35,8 @@ NodePtr MakeNode(Args&&... args) { return std::make_shared(std::forward(args)...); } -// op is passed in for a more efficient node casting, see the implementation of NodeCast +// op is passed in for a more efficient node casting, see the implementation of +// NodeCast template NodePtr ReuseOrMakeNode(Args&&... args) { NodePtr node = ReuseNode(std::forward(args)...); @@ -46,25 +48,80 @@ NodePtr ReuseOrMakeNode(Args&&... args) { } struct IrBuilder { - virtual NodePtr MakeDeviceData(const std::shared_ptr& data) const = 0; - virtual NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const = 0; - virtual NodePtr MakeExpand(const Value& input0, const std::vector& size, const bool& is_scalar_expand) const = 0; - virtual NodePtr MakeView(const Value& input0, const std::vector& output_size) const = 0; - virtual NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional& stype = c10::nullopt) const = 0; + virtual NodePtr MakeDeviceData( + const std::shared_ptr& data) const = 0; + virtual NodePtr MakeScalar( + const at::Scalar& value, + const at::ScalarType& type) const = 0; + virtual NodePtr MakeExpand( + const Value& input0, + const std::vector& size, + const bool& is_scalar_expand) const = 0; + virtual NodePtr MakeView( + const Value& input0, + const std::vector& output_size) const = 0; + virtual NodePtr MakeCast( + const Value& input0, + const at::ScalarType& dtype, + const c10::optional& stype = c10::nullopt) const = 0; virtual NodePtr MakeTensorList(const OpList& inputs) const = 0; - virtual NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const = 0; + virtual NodePtr MakeGeneric( + const OpKind& op, + const OpList& operands, + const Shape& shape, + const size_t& num_outputs = 1, + const hash_t& hash_seed = static_cast(0x5a2d296e9)) const = 0; // View op nodes - virtual NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const = 0; - virtual NodePtr MakeAsStrided(const Value& input0, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const = 0; - virtual NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const = 0; - virtual NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const = 0; - virtual NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector& base_indices) const = 0; - virtual NodePtr MakeNarrow(const Value& input0, const std::vector& base_indices, const std::vector& sizes) const = 0; - virtual NodePtr MakePermute(const Value& input0, const std::vector& dims) const = 0; - virtual NodePtr MakeResize(const Value& input0, const std::vector& size) const = 0; - virtual NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const = 0; - virtual NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const = 0; + virtual NodePtr MakeAsStridedViewUpdate( + const Value& input0, + const Value& input1, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) const = 0; + virtual NodePtr MakeAsStrided( + const Value& input0, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) const = 0; + virtual NodePtr MakeDiagonalViewUpdate( + const Value& input0, + const Value& input1, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) const = 0; + virtual NodePtr MakeDiagonal( + const Value& input0, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) const = 0; + virtual NodePtr MakeNarrowViewUpdate( + const Value& input0, + const Value& input1, + const std::vector& base_indices) const = 0; + virtual NodePtr MakeNarrow( + const Value& input0, + const std::vector& base_indices, + const std::vector& sizes) const = 0; + virtual NodePtr MakePermute( + const Value& input0, + const std::vector& dims) const = 0; + virtual NodePtr MakeResize( + const Value& input0, + const std::vector& size) const = 0; + virtual NodePtr MakeSelectViewUpdate( + const Value& input0, + const Value& input1, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) const = 0; + virtual NodePtr MakeSelect( + const Value& input0, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) const = 0; virtual NodePtr MakeSqueeze(const Value& input0, const int& dim) const = 0; virtual NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const = 0; @@ -76,63 +133,121 @@ struct IrBuilder { }; static inline NodePtr MakeDeviceData(const std::shared_ptr& data) { - return getIrBuilder()->MakeDeviceData(data); + return getIrBuilder()->MakeDeviceData(data); } -static inline NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) { - return getIrBuilder()->MakeScalar(value, type); +static inline NodePtr MakeScalar( + const at::Scalar& value, + const at::ScalarType& type) { + return getIrBuilder()->MakeScalar(value, type); } -static inline NodePtr MakeExpand(const Value& input0, const std::vector& size, const bool& is_scalar_expand) { - return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand); +static inline NodePtr MakeExpand( + const Value& input0, + const std::vector& size, + const bool& is_scalar_expand) { + return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand); } -static inline NodePtr MakeView(const Value& input0, const std::vector& output_size) { - return getIrBuilder()->MakeView(input0, output_size); +static inline NodePtr MakeView( + const Value& input0, + const std::vector& output_size) { + return getIrBuilder()->MakeView(input0, output_size); } -static inline NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional& stype = c10::nullopt) { - return getIrBuilder()->MakeCast(input0, dtype, stype); +static inline NodePtr MakeCast( + const Value& input0, + const at::ScalarType& dtype, + const c10::optional& stype = c10::nullopt) { + return getIrBuilder()->MakeCast(input0, dtype, stype); } static inline NodePtr MakeTensorList(const OpList& inputs) { - return getIrBuilder()->MakeTensorList(inputs); + return getIrBuilder()->MakeTensorList(inputs); } -static inline NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) { - return getIrBuilder()->MakeGeneric(op, operands, shape, num_outputs, hash_seed); +static inline NodePtr MakeGeneric( + const OpKind& op, + const OpList& operands, + const Shape& shape, + const size_t& num_outputs = 1, + const hash_t& hash_seed = static_cast(0x5a2d296e9)) { + return getIrBuilder()->MakeGeneric( + op, operands, shape, num_outputs, hash_seed); } // View op nodes -static inline NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { - return getIrBuilder()->MakeAsStridedViewUpdate(input0, input1, size, stride, storage_offset); +static inline NodePtr MakeAsStridedViewUpdate( + const Value& input0, + const Value& input1, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) { + return getIrBuilder()->MakeAsStridedViewUpdate( + input0, input1, size, stride, storage_offset); } -static inline NodePtr MakeAsStrided(const Value& input0, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { - return getIrBuilder()->MakeAsStrided(input0, size, stride, storage_offset); +static inline NodePtr MakeAsStrided( + const Value& input0, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) { + return getIrBuilder()->MakeAsStrided(input0, size, stride, storage_offset); } -static inline NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { - return getIrBuilder()->MakeDiagonalViewUpdate(input0, input1, offset, dim1, dim2); +static inline NodePtr MakeDiagonalViewUpdate( + const Value& input0, + const Value& input1, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) { + return getIrBuilder()->MakeDiagonalViewUpdate( + input0, input1, offset, dim1, dim2); } -static inline NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { - return getIrBuilder()->MakeDiagonal(input0, offset, dim1, dim2); +static inline NodePtr MakeDiagonal( + const Value& input0, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) { + return getIrBuilder()->MakeDiagonal(input0, offset, dim1, dim2); } -static inline NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector& base_indices) { - return getIrBuilder()->MakeNarrowViewUpdate(input0, input1, base_indices); +static inline NodePtr MakeNarrowViewUpdate( + const Value& input0, + const Value& input1, + const std::vector& base_indices) { + return getIrBuilder()->MakeNarrowViewUpdate(input0, input1, base_indices); } -static inline NodePtr MakeNarrow(const Value& input0, const std::vector& base_indices, const std::vector& sizes) { - return getIrBuilder()->MakeNarrow(input0, base_indices, sizes); +static inline NodePtr MakeNarrow( + const Value& input0, + const std::vector& base_indices, + const std::vector& sizes) { + return getIrBuilder()->MakeNarrow(input0, base_indices, sizes); } -static inline NodePtr MakePermute(const Value& input0, const std::vector& dims) { - return getIrBuilder()->MakePermute(input0, dims); +static inline NodePtr MakePermute( + const Value& input0, + const std::vector& dims) { + return getIrBuilder()->MakePermute(input0, dims); } -static inline NodePtr MakeResize(const Value& input0, const std::vector& size) { - return getIrBuilder()->MakeResize(input0, size); +static inline NodePtr MakeResize( + const Value& input0, + const std::vector& size) { + return getIrBuilder()->MakeResize(input0, size); } -static inline NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { - return getIrBuilder()->MakeSelectViewUpdate(input0, input1, dim, start, end, stride); +static inline NodePtr MakeSelectViewUpdate( + const Value& input0, + const Value& input1, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) { + return getIrBuilder()->MakeSelectViewUpdate( + input0, input1, dim, start, end, stride); } -static inline NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { - return getIrBuilder()->MakeSelect(input0, dim, start, end, stride); +static inline NodePtr MakeSelect( + const Value& input0, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) { + return getIrBuilder()->MakeSelect(input0, dim, start, end, stride); } static inline NodePtr MakeSqueeze(const Value& input0, const int& dim) { - return getIrBuilder()->MakeSqueeze(input0, dim); + return getIrBuilder()->MakeSqueeze(input0, dim); } static inline NodePtr MakeUnsqueeze(const Value& input0, const int& dim) { - return getIrBuilder()->MakeUnsqueeze(input0, dim); + return getIrBuilder()->MakeUnsqueeze(input0, dim); } // dynamic ir nodes diff --git a/torch/csrc/lazy/core/ir_metadata.cpp b/torch/csrc/lazy/core/ir_metadata.cpp index cea721bc139741..75bc0c3051658c 100644 --- a/torch/csrc/lazy/core/ir_metadata.cpp +++ b/torch/csrc/lazy/core/ir_metadata.cpp @@ -1,7 +1,7 @@ -#include #include #include #include +#include namespace torch { namespace lazy { diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index 569e0999a90cda..c312997c49de55 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -12,10 +12,10 @@ #include #include -#include -#include #include +#include #include +#include #include @@ -529,13 +529,16 @@ Value LazyGraphExecutor::GetDeviceDataIrValue( return MakeDeviceData(std::move(data)); } -Value LazyGraphExecutor::GetIrValueForScalarFromCodegen(const at::Scalar& value) { +Value LazyGraphExecutor::GetIrValueForScalarFromCodegen( + const at::Scalar& value) { if (IsSpecialScalar(value)) { return MakeScalar(value, value.type()); } auto cpu_device = getBackend()->GetBackendDevice(c10::Device(c10::kCPU, 0)); - BackendDataPtr data = getBackend()->MakeComputationDataFromScalar(value, cpu_device); - data->SetInfo(std::make_shared(/*tensor_id=*/-1, /*read_only=*/true)); + BackendDataPtr data = + getBackend()->MakeComputationDataFromScalar(value, cpu_device); + data->SetInfo( + std::make_shared(/*tensor_id=*/-1, /*read_only=*/true)); return MakeDeviceData(std::move(data)); } @@ -546,7 +549,8 @@ Value LazyGraphExecutor::GetIrValueForScalar( if (IsSpecialScalar(value)) { return MakeScalar(value, type); } - return GetDeviceDataIrValue(value, type, getBackend()->GetBackendDevice(c10::Device(c10::kCPU, 0))); + return GetDeviceDataIrValue( + value, type, getBackend()->GetBackendDevice(c10::Device(c10::kCPU, 0))); } Value LazyGraphExecutor::GetIrValueForScalar( @@ -556,15 +560,17 @@ Value LazyGraphExecutor::GetIrValueForScalar( } Value LazyGraphExecutor::GetIrValueForExpandedScalar( - const at::Scalar& value, const Shape& shape, + const at::Scalar& value, + const Shape& shape, const BackendDevice& device) { c10::ArrayRef dimensions = shape.sizes(); auto type = shape.scalar_type(); Value ir_value = GetIrValueForScalar(value, type, device); if (!dimensions.empty()) { - ir_value = MakeExpand( - ir_value, dimensions.vec(), - /*is_scalar_expand=*/true); + ir_value = MakeExpand( + ir_value, + dimensions.vec(), + /*is_scalar_expand=*/true); } return ir_value; } @@ -902,8 +908,8 @@ std::shared_ptr LazyGraphExecutor:: return nullptr; } PostOrderData po_data = RunPostOrder(*tensors, &coll); - DebugUtil::SaveTensorsGraphInfo("ScheduleSyncTensorsGraph", *tensors, - &coll.indices); + DebugUtil::SaveTensorsGraphInfo( + "ScheduleSyncTensorsGraph", *tensors, &coll.indices); coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence)); VLOG(4) << "Parameter sequence graph hash " << HashToString(coll.hash); std::shared_ptr async = TryRunCachedSync(tensors, &coll, &po_data); @@ -914,7 +920,8 @@ std::shared_ptr LazyGraphExecutor:: CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data); if (GRAPH_DUMP_ENABLED) { auto* comp = compile_result.computation.get(); - LOG(ERROR) << "Add a cached computation with hash " << coll.hash << std::endl; + LOG(ERROR) << "Add a cached computation with hash " << coll.hash + << std::endl; LOG(ERROR) << "Add a graph to cache: " << comp->to_string() << std::endl; } @@ -961,9 +968,12 @@ std::shared_ptr LazyGraphExecutor:: VLOG(3) << "Executing IR graph hash " << HashToString(hash) << " on device " << async->device << " done!"; - TORCH_CHECK(async->tensors_data.size() == results.size(), - "Expected number of outputs does not match TorchScript Stack size: ", - async->tensors_data.size(), " != ", results.size()); + TORCH_CHECK( + async->tensors_data.size() == results.size(), + "Expected number of outputs does not match TorchScript Stack size: ", + async->tensors_data.size(), + " != ", + results.size()); for (const auto i : c10::irange(results.size())) { if (async->tensors_data[i] != nullptr) { @@ -1055,7 +1065,8 @@ std::vector LazyGraphExecutor::FetchTensors( ++literals_index; ++sync_index; } else { - c10::optional tensor_data = (*tensors)[i]->CurrentTensorData(); + c10::optional tensor_data = + (*tensors)[i]->CurrentTensorData(); if (tensor_data) { results.push_back(*tensor_data); } else { @@ -1107,17 +1118,19 @@ void LazyGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) { return; } if (coll) { - VLOG(4) << "Waiting on device barrier for device " << coll->device << " ..."; + VLOG(4) << "Waiting on device barrier for device " << coll->device + << " ..."; { TORCH_LAZY_TIMED("DeviceLockWait"); - coll->unlocker = - DeviceLockerArena::Get()->LockDevices({coll->device }); + coll->unlocker = DeviceLockerArena::Get()->LockDevices({coll->device}); } - VLOG(4) << "Waiting on device barrier for device " << coll->device << " done!"; + VLOG(4) << "Waiting on device barrier for device " << coll->device + << " done!"; } } -hash_t LazyGraphExecutor::GetGraphHash(const std::vector& tensors) { +hash_t LazyGraphExecutor::GetGraphHash( + const std::vector& tensors) { SyncTensorsConfig config; config.sync_ltc_data = false; diff --git a/torch/csrc/lazy/core/lazy_view.cpp b/torch/csrc/lazy/core/lazy_view.cpp index b9766bc1698565..d52c0f62fb77ef 100644 --- a/torch/csrc/lazy/core/lazy_view.cpp +++ b/torch/csrc/lazy/core/lazy_view.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include +#include #include #include @@ -62,10 +62,10 @@ Value ApplyViewInfo(Value ir_value, const ViewInfo& view_info) { // a = torch.diagonal(b) // b.add_(1) # a should be updated as well. // -// Ideally we should all have a *ViewUpdate IR which updates the original tensor/view -// withe current value. See DiagonalViewUpdate and corresponding LowerDiagonalViewUpdate -// in ts_node_lowering.cpp. There are some "edge cases" here simply because they can -// smartly reuse some other ops to undo themselves. +// Ideally we should all have a *ViewUpdate IR which updates the original +// tensor/view withe current value. See DiagonalViewUpdate and corresponding +// LowerDiagonalViewUpdate in ts_node_lowering.cpp. There are some "edge cases" +// here simply because they can smartly reuse some other ops to undo themselves. Value ApplyUpdate(Value ir_value, const Alias::UpdateData& update_data) { // We first bring the source IR value forward, by reshaping and slicing. std::vector tmp_values({ir_value}); @@ -88,14 +88,13 @@ Value ApplyUpdate(Value ir_value, const Alias::UpdateData& update_data) { view_info.select->stride); break; case ViewInfo::Type::kNarrow: - result = MakeNarrowViewUpdate( - tmp_values[i - 1], result, view_info.indices); + result = + MakeNarrowViewUpdate(tmp_values[i - 1], result, view_info.indices); break; case ViewInfo::Type::kNoOp: break; case ViewInfo::Type::kPermute: - result = MakePermute( - result, InversePermutation(view_info.permutation)); + result = MakePermute(result, InversePermutation(view_info.permutation)); break; case ViewInfo::Type::kReshape: result = MakeView(result, view_info.source_shape.sizes().vec()); @@ -104,11 +103,11 @@ Value ApplyUpdate(Value ir_value, const Alias::UpdateData& update_data) { result = MakeResize(result, view_info.source_shape.sizes().vec()); break; case ViewInfo::Type::kSqueeze: - result = MakeUnsqueeze(ir_value, view_info.squeeze_index); - break; + result = MakeUnsqueeze(ir_value, view_info.squeeze_index); + break; case ViewInfo::Type::kUnsqueeze: - result = MakeSqueeze(ir_value, view_info.squeeze_index); - break; + result = MakeSqueeze(ir_value, view_info.squeeze_index); + break; case ViewInfo::Type::kAsStrided: result = MakeAsStridedViewUpdate( tmp_values[i - 1], @@ -145,8 +144,7 @@ ViewInfo::ViewInfo(Type view_type, Shape shape, Shape source_shape, int64_t sqi) : view_type(view_type), shape(std::move(shape)), source_shape(std::move(source_shape)), - squeeze_index(sqi) -{ + squeeze_index(sqi) { TORCH_CHECK(view_type == Type::kSqueeze); } diff --git a/torch/csrc/lazy/core/metrics.h b/torch/csrc/lazy/core/metrics.h index c1e125f4048cb2..43fb617c1ba166 100644 --- a/torch/csrc/lazy/core/metrics.h +++ b/torch/csrc/lazy/core/metrics.h @@ -204,8 +204,7 @@ class TORCH_API Counter { __counter->AddValue(value); \ } while (0) -#define TORCH_LAZY_FN_COUNTER(ns) \ - TORCH_LAZY_COUNTER(c10::str(ns, __func__), 1) +#define TORCH_LAZY_FN_COUNTER(ns) TORCH_LAZY_COUNTER(c10::str(ns, __func__), 1) #define TORCH_LAZY_VALUE_METRIC(name, value) \ do { \ diff --git a/torch/csrc/lazy/core/ops/arithmetic_ir_ops.cpp b/torch/csrc/lazy/core/ops/arithmetic_ir_ops.cpp index 9c416822015d92..933b6ab4e6d897 100644 --- a/torch/csrc/lazy/core/ops/arithmetic_ir_ops.cpp +++ b/torch/csrc/lazy/core/ops/arithmetic_ir_ops.cpp @@ -10,9 +10,10 @@ namespace torch { namespace lazy { // These operators were once widely used in nativefunction impls to perform -// convenient decompositions (partial lowerings) of aten operators into more primitive -// opererators. They should not be used for this purpose anymore, but still used in -// lazy_graph_executor for RNG math in one place. We could rewrite that. +// convenient decompositions (partial lowerings) of aten operators into more +// primitive opererators. They should not be used for this purpose anymore, but +// still used in lazy_graph_executor for RNG math in one place. We could +// rewrite that. NodePtr operator+(const Value& node1, const Value& node2) { return MakeGeneric( OpKind(at::aten::add), @@ -24,24 +25,21 @@ NodePtr operator-(const Value& node1, const Value& node2) { return MakeGeneric( OpKind(at::aten::sub), {node1, node2}, - GetPromotedBinaryOpShape( - node1.shape(), node2.shape())); + GetPromotedBinaryOpShape(node1.shape(), node2.shape())); } NodePtr operator*(const Value& node1, const Value& node2) { return MakeGeneric( OpKind(at::aten::mul), {node1, node2}, - GetPromotedBinaryOpShape( - node1.shape(), node2.shape())); + GetPromotedBinaryOpShape(node1.shape(), node2.shape())); } NodePtr operator/(const Value& node1, const Value& node2) { return MakeGeneric( OpKind(at::aten::div), {node1, node2}, - GetPromotedBinaryOpShape( - node1.shape(), node2.shape())); + GetPromotedBinaryOpShape(node1.shape(), node2.shape())); } } // namespace lazy diff --git a/torch/csrc/lazy/core/ops/utils.cpp b/torch/csrc/lazy/core/ops/utils.cpp index 65f7cbd1639da2..aa8cc62b8614de 100644 --- a/torch/csrc/lazy/core/ops/utils.cpp +++ b/torch/csrc/lazy/core/ops/utils.cpp @@ -1,9 +1,9 @@ #include -#include +#include #include +#include #include -#include namespace torch { namespace lazy { @@ -14,8 +14,7 @@ bool StrideIsSupported(c10::ArrayRef stride) { return stride.empty() || sorted_stride.front() == 1; } -std::vector GetArrayStridePermutation( - c10::ArrayRef stride) { +std::vector GetArrayStridePermutation(c10::ArrayRef stride) { std::vector permutation = Iota(stride.size()); std::sort(permutation.begin(), permutation.end(), [&](int64_t a, int64_t b) { return stride[a] > stride[b]; @@ -54,7 +53,6 @@ Shape MakePermuteShape( PermuteDimensions(permutation, source_shape.sizes())); } - Shape MakeSelectShape( const Shape& shape, int64_t dim, @@ -76,10 +74,11 @@ int64_t GetStride(int64_t start, int64_t end, int64_t stride) { return stride; } -// This is almost like at::inferSqueezeGeometry, but that requires a Tensor input -// and also computes new strides. This logic seems correct. -std::vector BuildSqueezedDimensions(c10::ArrayRef dimensions, - int64_t squeeze_dim) { +// This is almost like at::inferSqueezeGeometry, but that requires a Tensor +// input and also computes new strides. This logic seems correct. +std::vector BuildSqueezedDimensions( + c10::ArrayRef dimensions, + int64_t squeeze_dim) { std::vector output_dimensions; for (const auto i : c10::irange(dimensions.size())) { int64_t dim = dimensions[i]; diff --git a/torch/csrc/lazy/core/ops/utils.h b/torch/csrc/lazy/core/ops/utils.h index 15ba1d26428744..cc5d5bdbe25bc3 100644 --- a/torch/csrc/lazy/core/ops/utils.h +++ b/torch/csrc/lazy/core/ops/utils.h @@ -8,7 +8,8 @@ namespace lazy { TORCH_API bool StrideIsSupported(c10::ArrayRef stride); -TORCH_API std::vector GetArrayStridePermutation(c10::ArrayRef stride); +TORCH_API std::vector GetArrayStridePermutation( + c10::ArrayRef stride); TORCH_API Shape MakeDiagonalShape( const Shape& shape, @@ -16,9 +17,8 @@ TORCH_API Shape MakeDiagonalShape( int64_t dim1, int64_t dim2); -TORCH_API Shape MakePermuteShape( - const Shape& source_shape, - c10::ArrayRef permutation); +TORCH_API Shape +MakePermuteShape(const Shape& source_shape, c10::ArrayRef permutation); TORCH_API Shape MakeSelectShape( const Shape& shape, @@ -29,8 +29,9 @@ TORCH_API Shape MakeSelectShape( TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride); -TORCH_API std::vector BuildSqueezedDimensions(c10::ArrayRef dimensions, - int64_t squeeze_dim); +TORCH_API std::vector BuildSqueezedDimensions( + c10::ArrayRef dimensions, + int64_t squeeze_dim); TORCH_API std::vector BuildUnsqueezedDimensions( c10::ArrayRef dimensions, diff --git a/torch/csrc/lazy/core/permutation_util.h b/torch/csrc/lazy/core/permutation_util.h index fd5a862d9160a0..7a368301035e9f 100644 --- a/torch/csrc/lazy/core/permutation_util.h +++ b/torch/csrc/lazy/core/permutation_util.h @@ -24,7 +24,11 @@ std::vector PermuteDimensions( using T = typename Container::value_type; TORCH_CHECK( dimensions.size() == permutation.size(), - "Invalid permutation specified. dimensions.size() != permutation.size() (", dimensions.size(), " vs. ", permutation.size(), ")"); + "Invalid permutation specified. dimensions.size() != permutation.size() (", + dimensions.size(), + " vs. ", + permutation.size(), + ")"); TORCH_CHECK( IsPermutation(permutation), "Invalid permutation specified. Permutation is not permutation"); diff --git a/torch/csrc/lazy/core/shape.cpp b/torch/csrc/lazy/core/shape.cpp index ba90fbbcba78e3..2844cc818e2356 100644 --- a/torch/csrc/lazy/core/shape.cpp +++ b/torch/csrc/lazy/core/shape.cpp @@ -108,7 +108,7 @@ void applySymbolicShapesOnLT( } auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args); if (!res_symbolic) { - for (auto & result_shape : result_shapes) { + for (auto& result_shape : result_shapes) { result_shape = result_shape.with_symbolic_dims(c10::nullopt); } } else { diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index 4cc188e08f0f32..67303b93e5fb43 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -6,8 +6,9 @@ * where we do not yet have structured kernels in pytorch core. Ops * for which there _are_ structured kernels can use meta::op() to infer * shape/dtype, and codegen makes use of this. Ops for which there are not - * yet structured kernels can still be used with lazy_tensor codegen, but require - * manual intervention to implement compute_shape_{op} and compute_dtype_{op}. + * yet structured kernels can still be used with lazy_tensor codegen, but + * require manual intervention to implement compute_shape_{op} and + * compute_dtype_{op}. * * READ THIS! * @@ -17,23 +18,26 @@ * kernels instead, but it's a lot faster to write these so we're decoupling the * two efforts to move fast for adding support for codegenned Lazy Tensor ops. * - * Codegenned Lazy Tensor ops with handwritten shape formulae are still better than - * fully handwritten Lazy Tensor ops (which also have handwritten shape formulae). + * Codegenned Lazy Tensor ops with handwritten shape formulae are still better + * than fully handwritten Lazy Tensor ops (which also have handwritten shape + * formulae). * * 2. Structured Kernels For The Win * --------------------------------- - * Long term, more and more ops should be supported as 'structured kernels'. Consider - * doing your part and porting an op. As ops get ported over, the codegen will automatically - * notice and stop generating declarations for these shape formulae, so we'll need to - * manually clean up the unused functions in this file, or somehow automate that. + * Long term, more and more ops should be supported as 'structured kernels'. + * Consider doing your part and porting an op. As ops get ported over, the + * codegen will automatically notice and stop generating declarations for these + * shape formulae, so we'll need to manually clean up the unused functions in + * this file, or somehow automate that. * * https://dev-discuss.pytorch.org/t/slides-from-structured-kernel-presentation/179 * * 3. How to figure out the shape/dtype * ------------------------------------ - * Unfortunatley there isn't a one-stop-shop for learning the output shape formulae for all - * operators. This is partly because some operators are not part of our 'public' API, including - * backward operators which users don't directly invoke. + * Unfortunatley there isn't a one-stop-shop for learning the output shape + * formulae for all operators. This is partly because some operators are not + * part of our 'public' API, including backward operators which users don't + * directly invoke. * * Check our opinfo registry: * https://github.com/pytorch/pytorch/blob/13b859983183ea9938deb5030ac9a0747841f0a8/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -45,26 +49,27 @@ #include -#include -#include -#include -#include #include -#include #include +#include #include #include +#include +#include #include #include #include #include +#include +#include #include #include -namespace torch{ +namespace torch { namespace lazy { -// Copied from ATen/native/utils/ParamUtils.h, which aparently I can't include from here? +// Copied from ATen/native/utils/ParamUtils.h, which aparently I can't include +// from here? std::vector expand_param_if_needed( at::IntArrayRef list_param, const char* param_name, @@ -82,63 +87,83 @@ std::vector expand_param_if_needed( } } -// It seems more common to not use parameters than to use them, so disable unused-parameter warning +// It seems more common to not use parameters than to use them, so disable +// unused-parameter warning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" -TORCH_API std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { +TORCH_API std::vector compute_shape_arange_out( + const at::Scalar& start, + const at::Scalar& end, + const at::Scalar& step, + at::Tensor& out) { double size_d = 0; // shape inference code copied from RangeFactories.cpp arange_out function - // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct c++ scalar_t type depending on out tensor - - AT_DISPATCH_ALL_TYPES_AND(c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() { - // Note: acc_type further defines an accumulataion type depending on the scalar_t and whether its on cuda vs cpu. - using accscalar_t = at::acc_type; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); - - // we use double precision for (start - end) / step - // to compute size_d for consistency across devices. - // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, - // but double on cpu for the same, - // and the effective output size starts differing on CPU vs GPU because of precision issues, which - // we dont want. - // the corner-case we do want to take into account is int64_t, which has higher precision than double - // NOLINTNEXTLINE(bugprone-branch-clone) - if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } else { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } - - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); - - TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), - "invalid size, possible overflow?"); - }); + // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct + // c++ scalar_t type depending on out tensor + + AT_DISPATCH_ALL_TYPES_AND( + c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() { + // Note: acc_type further defines an accumulataion type depending on the + // scalar_t and whether its on cuda vs cpu. + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); + + // we use double precision for (start - end) / step + // to compute size_d for consistency across devices. + // The problem with using accscalar_t is that accscalar_t might be + // float32 on gpu for a float32 scalar_t, but double on cpu for the + // same, and the effective output size starts differing on CPU vs GPU + // because of precision issues, which we dont want. the corner-case we + // do want to take into account is int64_t, which has higher precision + // than double NOLINTNEXTLINE(bugprone-branch-clone) + if (std::is_same::value) { + size_d = std::ceil( + static_cast( + end.to() - start.to()) / + step.to()); + } else { + size_d = std::ceil( + static_cast(end.to() - start.to()) / + step.to()); + } + + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); + TORCH_CHECK( + std::isfinite(static_cast(xstart)) && + std::isfinite(static_cast(xend)), + "unsupported range: ", + xstart, + " -> ", + xend); + TORCH_CHECK( + ((xstep > 0) && (xend >= xstart)) || + ((xstep < 0) && (xend <= xstart)), + "upper bound and larger bound inconsistent with step sign"); + + TORCH_CHECK( + size_d >= 0 && + size_d <= + static_cast(std::numeric_limits::max()), + "invalid size, possible overflow?"); + }); int64_t size = static_cast(size_d); - // From torch.arange docs: // dtype (torch.dtype, optional) – the desired data type of returned tensor. - // Default: if None, uses a global default (see torch.set_default_tensor_type()). - // If dtype is not given, infer the data type from the other input arguments. - // If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype(). - // Otherwise, the dtype is inferred to be torch.int64. + // Default: if None, uses a global default (see + // torch.set_default_tensor_type()). If dtype is not given, infer the data + // type from the other input arguments. If any of start, end, or stop are + // floating-point, the dtype is inferred to be the default dtype, see + // get_default_dtype(). Otherwise, the dtype is inferred to be torch.int64. return {Shape(out.scalar_type(), {size})}; } -std::vector compute_shape_abs(const at::Tensor & self) { +std::vector compute_shape_abs(const at::Tensor& self) { if (self.is_complex()) { const auto float_type = c10::toRealValueType(self.scalar_type()); return {Shape(float_type, self.sizes().vec())}; @@ -146,107 +171,185 @@ std::vector compute_shape_abs(const at::Tensor & self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optional generator) { +std::vector compute_shape_bernoulli( + const at::Tensor& self, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_bernoulli(const at::Tensor & self, double p, c10::optional generator) { +std::vector compute_shape_bernoulli( + const at::Tensor& self, + double p, + c10::optional generator) { return compute_shape_bernoulli(self, generator); } -std::vector compute_shape_binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction) { - if(reduction == at::Reduction::None) { +std::vector compute_shape_binary_cross_entropy( + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction) { + if (reduction == at::Reduction::None) { return {Shape(self.scalar_type(), self.sizes().vec())}; } return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional & weight, int64_t reduction) { +std::vector compute_shape_binary_cross_entropy_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction) { return {Shape(self.scalar_type(), self.sizes().vec())}; } - -std::vector compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value) { +std::vector compute_shape_constant_pad_nd( + const at::Tensor& self, + at::IntArrayRef pad, + const at::Scalar& value) { // Based on aten/src/ATen/native/ConstantPadNd.cpp::constant_pad_nd - TORCH_CHECK(pad.size() % 2 == 0, "Length of pad must be even but instead it equals ", - pad.size()); + TORCH_CHECK( + pad.size() % 2 == 0, + "Length of pad must be even but instead it equals ", + pad.size()); auto input_sizes = self.sizes(); auto l_inp = self.dim(); auto l_pad = pad.size() / 2; auto l_diff = l_inp - l_pad; - TORCH_CHECK(l_inp >= (int64_t)l_pad, "Length of pad should be no more than twice the number of " - "dimensions of the input. Pad length is ", pad.size(), "while the input has ", - l_inp, "dimensions."); + TORCH_CHECK( + l_inp >= (int64_t)l_pad, + "Length of pad should be no more than twice the number of " + "dimensions of the input. Pad length is ", + pad.size(), + "while the input has ", + l_inp, + "dimensions."); std::vector new_shape; - for (size_t i = 0; i < (size_t)l_diff; i ++) { - new_shape.emplace_back(input_sizes[i]); + for (size_t i = 0; i < (size_t)l_diff; i++) { + new_shape.emplace_back(input_sizes[i]); } for (const auto i : c10::irange((size_t)l_pad)) { - auto pad_idx = pad.size() - ((i + 1) * 2); - auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; - TORCH_CHECK(new_dim > 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ", - pad[pad_idx], " and ", pad[pad_idx + 1], " resulted in a negative output size, " - "which is invalid. Check dimension ", l_diff + i, " of your input."); - new_shape.emplace_back(new_dim); + auto pad_idx = pad.size() - ((i + 1) * 2); + auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; + TORCH_CHECK( + new_dim > 0, + "The input size ", + input_sizes[l_diff + i], + ", plus negative padding ", + pad[pad_idx], + " and ", + pad[pad_idx + 1], + " resulted in a negative output size, " + "which is invalid. Check dimension ", + l_diff + i, + " of your input."); + new_shape.emplace_back(new_dim); } return {Shape(self.scalar_type(), new_shape)}; } -std::vector compute_shape_convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { +std::vector compute_shape_convolution_backward( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& weight, + at::OptionalIntArrayRef bias_sizes, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool transposed, + at::IntArrayRef output_padding, + int64_t groups, + ::std::array output_mask) { if (bias_sizes.has_value()) { - return {Shape(input.scalar_type(), input.sizes().vec()), - Shape(weight.scalar_type(), weight.sizes().vec()), - Shape(grad_output.scalar_type(), bias_sizes.value().vec())}; + return { + Shape(input.scalar_type(), input.sizes().vec()), + Shape(weight.scalar_type(), weight.sizes().vec()), + Shape(grad_output.scalar_type(), bias_sizes.value().vec())}; } else { - // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is empty - return {Shape(input.scalar_type(), input.sizes().vec()), - Shape(weight.scalar_type(), weight.sizes().vec())}; + // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is + // empty + return { + Shape(input.scalar_type(), input.sizes().vec()), + Shape(weight.scalar_type(), weight.sizes().vec())}; } } -std::vector compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { - +std::vector compute_shape_convolution( + const at::Tensor& input, + const at::Tensor& weight, + const c10::optional& bias, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool transposed, + at::IntArrayRef output_padding, + int64_t groups) { int64_t dim = weight.ndimension() - 2; TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); - // at::convolution performs parameter expansion before running kernels on expanded parameters - // we must do the same. Shape formulae access differnent dimensions of e.g. output_padding, but - // output_padding may be passed in as a scalar. Sadly, accessing output_padding[1] in this case - // gives incorrect results rather than indexing error + // at::convolution performs parameter expansion before running kernels on + // expanded parameters we must do the same. Shape formulae access differnent + // dimensions of e.g. output_padding, but output_padding may be passed in as a + // scalar. Sadly, accessing output_padding[1] in this case gives incorrect + // results rather than indexing error auto expanded_stride = expand_param_if_needed(stride, "stride", dim); auto expanded_padding = expand_param_if_needed(padding, "padding", dim); auto expanded_dilation = expand_param_if_needed(dilation, "dilation", dim); if (!transposed) { - return {Shape(input.scalar_type(), at::native::conv_output_size(input.sizes(), weight.sizes(), expanded_padding, expanded_stride, expanded_dilation))}; + return {Shape( + input.scalar_type(), + at::native::conv_output_size( + input.sizes(), + weight.sizes(), + expanded_padding, + expanded_stride, + expanded_dilation))}; } else { - auto expanded_output_padding = expand_param_if_needed(output_padding, "output_padding", dim); - auto out_shape = at::native::conv_input_size(input.sizes(), weight.sizes(), expanded_padding, expanded_output_padding, expanded_stride, expanded_dilation, groups); + auto expanded_output_padding = + expand_param_if_needed(output_padding, "output_padding", dim); + auto out_shape = at::native::conv_input_size( + input.sizes(), + weight.sizes(), + expanded_padding, + expanded_output_padding, + expanded_stride, + expanded_dilation, + groups); return {Shape(input.scalar_type(), out_shape)}; } } -std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { +std::vector compute_shape_masked_fill( + const at::Tensor& self, + const at::Tensor& mask, + const at::Scalar& value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { +std::vector compute_shape_masked_fill( + const at::Tensor& self, + const at::Tensor& mask, + const at::Tensor& value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_max(const at::Tensor & self) { - TORCH_CHECK(self.numel() > 0, - "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); +std::vector compute_shape_max(const at::Tensor& self) { + TORCH_CHECK( + self.numel() > 0, + "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_min(const at::Tensor & self){ - TORCH_CHECK(self.numel() > 0, - "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); - return {Shape(self.scalar_type(), {})}; +std::vector compute_shape_min(const at::Tensor& self) { + TORCH_CHECK( + self.numel() > 0, + "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); + return {Shape(self.scalar_type(), {})}; } std::vector compute_shape_nonzero(const at::Tensor& t, bool as_tuple) { @@ -268,36 +371,60 @@ std::vector compute_shape_nonzero(const at::Tensor& self) { return compute_shape_nonzero(self, false); } -std::vector compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse){ +std::vector compute_shape_embedding( + const at::Tensor& weight, + const at::Tensor& indices, + int64_t padding_idx, + bool scale_grad_by_freq, + bool sparse) { // Based on aten/src/ATen/native/Embedding.cpp::embedding. std::vector out_sizes = indices.sizes().vec(); out_sizes.emplace_back(weight.size(1)); return {Shape(weight.scalar_type(), out_sizes)}; } -std::vector compute_shape_std(const at::Tensor & self, bool unbiased){ +std::vector compute_shape_std(const at::Tensor& self, bool unbiased) { return compute_shape_std(self, c10::nullopt, c10::nullopt, false); } -std::vector compute_shape_std(const at::Tensor & self, at::IntArrayRef dim, bool unbiased, bool keepdim){ +std::vector compute_shape_std( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { return compute_shape_std(self, dim, c10::nullopt, keepdim); } -std::vector compute_shape_std(const at::Tensor & self, at::OptionalIntArrayRef dim, c10::optional correction, bool keepdim){ +std::vector compute_shape_std( + const at::Tensor& self, + at::OptionalIntArrayRef dim, + c10::optional correction, + bool keepdim) { if (dim.has_value()) { - auto shape = at::native::shape_from_dim_mask(self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim); - return {Shape(self.scalar_type(), std::vector(shape.begin(), shape.end()))}; + auto shape = at::native::shape_from_dim_mask( + self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim); + return {Shape( + self.scalar_type(), std::vector(shape.begin(), shape.end()))}; } return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_embedding_dense_backward(const at::Tensor& grad_output, const at::Tensor& indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { +std::vector compute_shape_embedding_dense_backward( + const at::Tensor& grad_output, + const at::Tensor& indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq) { // Based on aten/src/ATen/native/Embedding.cpp::embedding_dense_backward_cpu. - return {Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})}; + return { + Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})}; } -std::vector compute_shape_index_select(const at::Tensor & self, int64_t dim, - const at::Tensor & index) { - // Based on definition of https://pytorch.org/docs/stable/generated/torch.index_select.html. - // Promote Rank 0 index tensor to a 1 * 1 tensor. +std::vector compute_shape_index_select( + const at::Tensor& self, + int64_t dim, + const at::Tensor& index) { + // Based on definition of + // https://pytorch.org/docs/stable/generated/torch.index_select.html. Promote + // Rank 0 index tensor to a 1 * 1 tensor. dim = at::maybe_wrap_dim(dim, self); auto index_dim = index.dim() > 0 ? index.dim() : 1; auto index_size = index.dim() > 0 ? index.size(0) : 1; @@ -311,32 +438,43 @@ std::vector compute_shape_index_select(const at::Tensor & self, int64_t d return {Shape(self.scalar_type(), output_sizes)}; } -std::vector compute_shape_inverse(const at::Tensor & self) { +std::vector compute_shape_inverse(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_kl_div_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, int64_t reduction, bool log_target) { +std::vector compute_shape_kl_div_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + int64_t reduction, + bool log_target) { // Based on definition of aten/src/ATen/native/Loss.cpp::kl_div_backward_cpu. return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_cat(at::TensorList tensors, int64_t dim) { // TODO(whc) support cat in codegen and move this to compute_*_cat functions - std::vector out_shape(tensors[0].sizes().begin(), tensors[0].sizes().end()); + std::vector out_shape( + tensors[0].sizes().begin(), tensors[0].sizes().end()); dim = at::maybe_wrap_dim(dim, tensors); size_t extended_dim_shape = 0; - for (auto& tensor: tensors) { + for (auto& tensor : tensors) { extended_dim_shape += tensor.sizes()[dim]; } TORCH_CHECK(out_shape.size() > 0, "Scalar tensors are not supported in cat."); - TORCH_CHECK(extended_dim_shape <= std::numeric_limits::max(), "Size overflow"); + TORCH_CHECK( + extended_dim_shape <= std::numeric_limits::max(), + "Size overflow"); out_shape[dim] = extended_dim_shape; return {Shape(tensors[0].scalar_type(), out_shape)}; } -std::vector compute_shape_native_layer_norm(const at::Tensor & input, - at::IntArrayRef normalized_shape, const c10::optional & weight, const c10::optional & bias, +std::vector compute_shape_native_layer_norm( + const at::Tensor& input, + at::IntArrayRef normalized_shape, + const c10::optional& weight, + const c10::optional& bias, double eps) { // Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out. auto input_shape = input.sizes().vec(); @@ -352,52 +490,84 @@ std::vector compute_shape_native_layer_norm(const at::Tensor & input, stat_shape.emplace_back(1); } - return {Shape(input.scalar_type(), input_shape), - Shape(input.scalar_type(), stat_shape), - Shape(input.scalar_type(), stat_shape)}; -} - -std::vector compute_shape_native_layer_norm_backward(const at::Tensor& grad_out, - const at::Tensor& input, at::IntArrayRef normalized_shape, const at::Tensor& mean, const at::Tensor& rstd, - const c10::optional& weight, const c10::optional& bias, ::std::array output_mask) { + return { + Shape(input.scalar_type(), input_shape), + Shape(input.scalar_type(), stat_shape), + Shape(input.scalar_type(), stat_shape)}; +} + +std::vector compute_shape_native_layer_norm_backward( + const at::Tensor& grad_out, + const at::Tensor& input, + at::IntArrayRef normalized_shape, + const at::Tensor& mean, + const at::Tensor& rstd, + const c10::optional& weight, + const c10::optional& bias, + ::std::array output_mask) { std::vector shapes; - shapes.emplace_back(input.scalar_type(), - output_mask[0] ? input.sizes().vec() : std::vector{}); - shapes.emplace_back(weight && weight->defined() ? weight->scalar_type() : input.scalar_type(), - output_mask[1] && weight ? weight->sizes().vec() : std::vector{}); - shapes.emplace_back(bias && weight->defined() ? bias->scalar_type() : input.scalar_type(), - output_mask[2] && bias ? bias->sizes().vec() : std::vector{}); + shapes.emplace_back( + input.scalar_type(), + output_mask[0] ? input.sizes().vec() : std::vector{}); + shapes.emplace_back( + weight && weight->defined() ? weight->scalar_type() : input.scalar_type(), + output_mask[1] && weight ? weight->sizes().vec() + : std::vector{}); + shapes.emplace_back( + bias && weight->defined() ? bias->scalar_type() : input.scalar_type(), + output_mask[2] && bias ? bias->sizes().vec() : std::vector{}); return shapes; } -std::vector compute_shape_mean(const at::Tensor& self, c10::optional dtype) { +std::vector compute_shape_mean( + const at::Tensor& self, + c10::optional dtype) { if (dtype.has_value()) { return {Shape(dtype.value(), {})}; } return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_mv(const at::Tensor& self, const at::Tensor& vec) { +std::vector compute_shape_mv( + const at::Tensor& self, + const at::Tensor& vec) { return {Shape(self.scalar_type(), {self.size(0)})}; } -std::vector compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional train) { - return {Shape(input.scalar_type(), input.sizes().vec()), Shape(c10::ScalarType::Bool, input.sizes().vec())}; +std::vector compute_shape_native_dropout( + const at::Tensor& input, + double p, + c10::optional train) { + return { + Shape(input.scalar_type(), input.sizes().vec()), + Shape(c10::ScalarType::Bool, input.sizes().vec())}; } -std::vector compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale) { +std::vector compute_shape_native_dropout_backward( + const at::Tensor& grad_output, + const at::Tensor& mask, + double scale) { return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; } -std::vector compute_shape_random_functional(const at::Tensor & self, c10::optional generator) { +std::vector compute_shape_random_functional( + const at::Tensor& self, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_random_functional(const at::Tensor & self, int64_t to, c10::optional generator) { +std::vector compute_shape_random_functional( + const at::Tensor& self, + int64_t to, + c10::optional generator) { return compute_shape_random_functional(self, generator); } -std::vector compute_shape_random_functional(const at::Tensor & self, int64_t from, c10::optional to, c10::optional generator) { +std::vector compute_shape_random_functional( + const at::Tensor& self, + int64_t from, + c10::optional to, + c10::optional generator) { return compute_shape_random_functional(self, generator); } @@ -405,12 +575,15 @@ std::vector compute_shape_relu(const at::Tensor& self) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_bitwise_and(const at::Tensor& self, const at::Scalar& other) { +std::vector compute_shape_bitwise_and( + const at::Tensor& self, + const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_sum( - const at::Tensor& self, c10::optional dtype) { + const at::Tensor& self, + c10::optional dtype) { if (dtype.has_value()) { return {Shape(dtype.value(), {})}; } @@ -419,7 +592,8 @@ std::vector compute_shape_sum( if (isIntegralType(self.scalar_type(), /*includeBool*/ true)) { return {Shape(c10::ScalarType::Long, {})}; } - return {Shape(self.scalar_type(), {})};; + return {Shape(self.scalar_type(), {})}; + ; } std::vector compute_shape_zero_functional(const at::Tensor& self) { @@ -430,13 +604,19 @@ std::vector compute_shape_trace(const at::Tensor& self) { return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_sort(const at::Tensor & self, int64_t dim, bool descending) { - return {Shape(self.scalar_type(), self.sizes().vec()), - Shape(c10::ScalarType::Long, self.sizes().vec())}; +std::vector compute_shape_sort( + const at::Tensor& self, + int64_t dim, + bool descending) { + return { + Shape(self.scalar_type(), self.sizes().vec()), + Shape(c10::ScalarType::Long, self.sizes().vec())}; } std::vector compute_shape_smooth_l1_loss( - const at::Tensor& self, const at::Tensor& target, int64_t reduction, + const at::Tensor& self, + const at::Tensor& target, + int64_t reduction, double beta) { // Taken from definition of 'Output' shape here: // https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html @@ -448,39 +628,52 @@ std::vector compute_shape_smooth_l1_loss( } } -std::vector compute_shape_slogdet(const at::Tensor & self) { +std::vector compute_shape_slogdet(const at::Tensor& self) { // assumes self.shape is {*, n, n} and returns shape * TORCH_INTERNAL_ASSERT(self.dim() >= 2); std::vector out_sizes(self.sizes().begin(), self.sizes().end() - 2); // Doesn't check input dtype, but output dtype either matches it, // or the actual slogdet operation will throw if it's an unsupported type. // Sign and det outputs hold the same shape, dtype. - return {Shape(self.scalar_type(), out_sizes), - Shape(self.scalar_type(), out_sizes)}; + return { + Shape(self.scalar_type(), out_sizes), + Shape(self.scalar_type(), out_sizes)}; } -std::vector compute_shape_logical_and(at::Tensor & self, const at::Tensor & other) { +std::vector compute_shape_logical_and( + at::Tensor& self, + const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); - return {Shape(c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; + return {Shape( + c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; } -std::vector compute_shape_logical_not(at::Tensor & self) { +std::vector compute_shape_logical_not(at::Tensor& self) { return {Shape(c10::ScalarType::Bool, self.sizes().vec())}; } -std::vector compute_shape_logical_or(at::Tensor & self, const at::Tensor & other) { +std::vector compute_shape_logical_or( + at::Tensor& self, + const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); - return {Shape(c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; + return {Shape( + c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; } -std::vector compute_shape_logical_xor(at::Tensor & self, const at::Tensor & other) { +std::vector compute_shape_logical_xor( + at::Tensor& self, + const at::Tensor& other) { TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); - return {Shape(c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; + return {Shape( + c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; } std::vector compute_shape_smooth_l1_loss_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& target, int64_t reduction, double beta) { + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + int64_t reduction, + double beta) { // The `grad_output` tensor is really the input to this kernel, and while its // shape may vary following the logic of the forward output, the output of // this kernel should have fixed shapes matching the inputs to the forward @@ -488,7 +681,7 @@ std::vector compute_shape_smooth_l1_loss_backward( return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_logdet(const at::Tensor & self) { +std::vector compute_shape_logdet(const at::Tensor& self) { // assumes self.shape is {*, n, n} and returns shape * TORCH_INTERNAL_ASSERT(self.dim() >= 2); std::vector out_sizes(self.sizes().begin(), self.sizes().end() - 2); @@ -498,21 +691,30 @@ std::vector compute_shape_logdet(const at::Tensor & self) { } std::vector compute_shape_log_sigmoid_forward(const at::Tensor& self) { - // Based on definition of aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu. - return {Shape(self.scalar_type(), self.sizes().vec()), - Shape(self.scalar_type(), self.sizes().vec())}; -} - -std::vector compute_shape_log_sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& buffer) { - // Based on definition of aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*. + // Based on definition of + // aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu. + return { + Shape(self.scalar_type(), self.sizes().vec()), + Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_log_sigmoid_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& buffer) { + // Based on definition of + // aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*. return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; } std::vector compute_shape_nll_loss2d_forward( - const at::Tensor& self, const at::Tensor& target, - const c10::optional& weight, int64_t reduction, + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction, int64_t ignore_index) { - // Based on definition of aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu + // Based on definition of + // aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu auto sizes = (reduction == at::Reduction::Reduction::None ? target.sizes().vec() : std::vector{}); @@ -520,13 +722,22 @@ std::vector compute_shape_nll_loss2d_forward( } std::vector compute_shape_nll_loss2d_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& target, const c10::optional& weight, - int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight, + int64_t reduction, + int64_t ignore_index, + const at::Tensor& total_weight) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { +std::vector compute_shape_grid_sampler_2d( + const at::Tensor& input, + const at::Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { // from `aten/src/ATen/native/cpu/GridSamplerKernel.cpp int64_t N = input.size(0); int64_t C = input.size(1); @@ -535,36 +746,59 @@ std::vector compute_shape_grid_sampler_2d(const at::Tensor & input, const return {Shape(input.scalar_type(), {N, C, H, W})}; } -std::vector compute_shape_grid_sampler_2d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { +std::vector compute_shape_grid_sampler_2d_backward( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + ::std::array output_mask) { // from `aten/src/ATen/native/cpu/GridSamplerKernel.cpp auto grad_input_shape = Shape(input.scalar_type(), input.sizes().vec()); auto grad_grid_shape = Shape(grid.scalar_type(), grid.sizes().vec()); return {grad_input_shape, grad_grid_shape}; } -std::vector compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims) { +std::vector compute_shape_flip( + const at::Tensor& self, + at::IntArrayRef dims) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape__adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { +std::vector compute_shape__adaptive_avg_pool2d( + const at::Tensor& self, + at::IntArrayRef output_size) { // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp` - TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); + TORCH_CHECK( + output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); TORCH_CHECK( (output_size[0] >= 0 && output_size[1] >= 0), "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ", - "but received {", output_size[0], ", ", output_size[1], "}"); - int64_t ndim = self.ndimension(); - for (const auto i : c10::irange(1, ndim)) { - TORCH_CHECK(self.size(i) > 0, + "but received {", + output_size[0], + ", ", + output_size[1], + "}"); + int64_t ndim = self.ndimension(); + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK( + self.size(i) > 0, "adaptive_avg_pool2d(): Expected self to have non-zero size for non-batch dimensions, " - "but Tensor has sizes ", self.sizes(), " with dimension ", i, " being " + "but Tensor has sizes ", + self.sizes(), + " with dimension ", + i, + " being " "empty"); - } - TORCH_CHECK((ndim == 3 || ndim == 4), - "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ", self.sizes()); + } + TORCH_CHECK( + (ndim == 3 || ndim == 4), + "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ", + self.sizes()); - int64_t channels = self.size(-3); + int64_t channels = self.size(-3); int64_t output_height = output_size[0]; int64_t output_width = output_size[1]; @@ -572,48 +806,82 @@ std::vector compute_shape__adaptive_avg_pool2d(const at::Tensor & self, a return {Shape(self.scalar_type(), {channels, output_height, output_width})}; } else { int64_t nbatch = self.size(0); - return {Shape(self.scalar_type(), {nbatch, channels, output_height, output_width})}; + return {Shape( + self.scalar_type(), {nbatch, channels, output_height, output_width})}; } } -std::vector compute_shape__adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self) { - // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` - int64_t ndim = grad_output.ndimension(); +std::vector compute_shape__adaptive_avg_pool2d_backward( + const at::Tensor& grad_output, + const at::Tensor& self) { + // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` + int64_t ndim = grad_output.ndimension(); - for (const auto i : c10::irange(1, ndim)) { - TORCH_CHECK(grad_output.size(i) > 0, + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK( + grad_output.size(i) > 0, "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, " - "but grad_output has sizes ", grad_output.sizes(), " with dimension ", i, " being " + "but grad_output has sizes ", + grad_output.sizes(), + " with dimension ", + i, + " being " "empty"); - } + } - TORCH_CHECK((ndim == 3 || ndim == 4), - "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ", self.sizes()); - TORCH_CHECK(self.dtype() == grad_output.dtype(), - "expected dtype ", self.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); + TORCH_CHECK( + (ndim == 3 || ndim == 4), + "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ", + self.sizes()); + TORCH_CHECK( + self.dtype() == grad_output.dtype(), + "expected dtype ", + self.dtype(), + " for `grad_output` but got dtype ", + grad_output.dtype()); - return {Shape(self.scalar_type(), self.sizes().vec())}; + return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_glu_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { +std::vector compute_shape_glu_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + int64_t dim) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_glu_jvp(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { +std::vector compute_shape_glu_jvp( + const at::Tensor& glu, + const at::Tensor& x, + const at::Tensor& dx, + int64_t dim) { return {Shape(glu.scalar_type(), glu.sizes().vec())}; } -std::vector compute_shape_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { +std::vector compute_shape_l1_loss_backward( + const at::Tensor& grad_output, + const at::Tensor& self, + const at::Tensor& target, + int64_t reduction) { TORCH_INTERNAL_ASSERT(grad_output.scalar_type() == self.dtype()); return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_clamp_min(const at::Tensor & self, const at::Scalar & min) { +std::vector compute_shape_clamp_min( + const at::Tensor& self, + const at::Scalar& min) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape__to_copy(const at::Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format) { - if(dtype){ +std::vector compute_shape__to_copy( + const at::Tensor& self, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + bool non_blocking, + c10::optional memory_format) { + if (dtype) { return {Shape(*dtype, self.sizes().vec())}; } return {Shape(self.scalar_type(), self.sizes().vec())}; @@ -626,9 +894,14 @@ std::vector compute_shape_stack(at::TensorList tensors, int64_t dim) { // Copied from 'check_stack_inputs' in TensorShape.cpp at::IntArrayRef entry_shape = tensors[0].sizes(); for (const auto i : c10::irange(1, tensors.size())) { - TORCH_CHECK(tensors[i].sizes() == entry_shape, - "stack expects each tensor to be equal size, but got ", entry_shape, - " at entry 0 and ", tensors[i].sizes(), " at entry ", i); + TORCH_CHECK( + tensors[i].sizes() == entry_shape, + "stack expects each tensor to be equal size, but got ", + entry_shape, + " at entry 0 and ", + tensors[i].sizes(), + " at entry ", + i); } auto result_sizes = tensors[0].sizes().vec(); @@ -636,12 +909,14 @@ std::vector compute_shape_stack(at::TensorList tensors, int64_t dim) { return {Shape(tensors[0].scalar_type(), result_sizes)}; } -std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats) { +std::vector compute_shape_repeat( + const at::Tensor& self, + at::IntArrayRef repeats) { CHECK_GE(repeats.size(), self.dim()); int64_t num_new_dimensions = repeats.size() - self.dim(); std::vector padded_size(num_new_dimensions, 1); - padded_size.insert(padded_size.end(), self.sizes().begin(), - self.sizes().end()); + padded_size.insert( + padded_size.end(), self.sizes().begin(), self.sizes().end()); std::vector target_size(repeats.size()); for (const auto idx : c10::irange(repeats.size())) { target_size[idx] = padded_size[idx] * repeats[idx]; @@ -649,69 +924,126 @@ std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef return {Shape(self.scalar_type(), target_size)}; } -std::vector compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length) { +std::vector compute_shape_narrow_copy( + const at::Tensor& self, + int64_t dim, + int64_t start, + c10::SymInt length) { return {Shape(self.scalar_type(), self.sizes().vec())}; } - // Non-Native Ops -std::vector compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type) { - return { Shape(type, {}) }; -} -std::vector compute_shape_expand(const Output& input, const std::vector& size, const bool& is_scalar_expand) { - return { Shape(input.shape().scalar_type(), size) }; -} -std::vector compute_shape_view(const Output& input, const std::vector& output_sizes) { +std::vector compute_shape_scalar( + const at::Scalar& value, + const at::ScalarType& type) { + return {Shape(type, {})}; +} +std::vector compute_shape_expand( + const Output& input, + const std::vector& size, + const bool& is_scalar_expand) { + return {Shape(input.shape().scalar_type(), size)}; +} +std::vector compute_shape_view( + const Output& input, + const std::vector& output_sizes) { const Shape& input_shape = input.shape(); const auto complete_output_sizes = at::infer_size(output_sizes, input_shape.numel()); - return { Shape(input_shape.scalar_type(), complete_output_sizes) }; + return {Shape(input_shape.scalar_type(), complete_output_sizes)}; } -std::vector compute_shape_cast(const Output& input, const at::ScalarType& dtype, const c10::optional& stype) { +std::vector compute_shape_cast( + const Output& input, + const at::ScalarType& dtype, + const c10::optional& stype) { Shape shape = input.shape(); shape.set_scalar_type(dtype); - return { shape }; + return {shape}; } - // View Ops -std::vector compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { - return { Shape(target.shape().scalar_type(), size) }; -} -std::vector compute_shape_as_strided(const Output& input, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) { - return { Shape(input.shape().scalar_type(), size) }; -} -std::vector compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { - return { target.shape() }; -} -std::vector compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) { - return { MakeDiagonalShape(input.shape(), offset, dim1, dim2) }; -} -std::vector compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector& base_indices) { - return { input.shape() }; -} -std::vector compute_shape_narrow(const Output& input, const std::vector& base_indices, const std::vector& sizes) { - return { Shape(input.shape().scalar_type(), sizes) }; -} -std::vector compute_shape_permute(const Output& input, const std::vector& dims) { - return { MakePermuteShape(input.shape(), dims) }; -} -std::vector compute_shape_resize(const Output& input, const std::vector& size) { - return { Shape(input.shape().scalar_type(), size) }; -} -std::vector compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { - return { target.shape() }; -} -std::vector compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) { - return { MakeSelectShape(input.shape(), dim, start, end, stride) }; +std::vector compute_shape_as_strided_view_update( + const Output& target, + const Output& input, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) { + return {Shape(target.shape().scalar_type(), size)}; +} +std::vector compute_shape_as_strided( + const Output& input, + const std::vector& size, + const std::vector& stride, + const int64_t& storage_offset) { + return {Shape(input.shape().scalar_type(), size)}; +} +std::vector compute_shape_diagonal_view_update( + const Output& target, + const Output& input, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) { + return {target.shape()}; +} +std::vector compute_shape_diagonal( + const Output& input, + const int64_t& offset, + const int64_t& dim1, + const int64_t& dim2) { + return {MakeDiagonalShape(input.shape(), offset, dim1, dim2)}; +} +std::vector compute_shape_narrow_view_update( + const Output& input, + const Output& source, + const std::vector& base_indices) { + return {input.shape()}; +} +std::vector compute_shape_narrow( + const Output& input, + const std::vector& base_indices, + const std::vector& sizes) { + return {Shape(input.shape().scalar_type(), sizes)}; +} +std::vector compute_shape_permute( + const Output& input, + const std::vector& dims) { + return {MakePermuteShape(input.shape(), dims)}; +} +std::vector compute_shape_resize( + const Output& input, + const std::vector& size) { + return {Shape(input.shape().scalar_type(), size)}; +} +std::vector compute_shape_select_view_update( + const Output& target, + const Output& source, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) { + return {target.shape()}; +} +std::vector compute_shape_select( + const Output& input, + const int64_t& dim, + const int64_t& start, + const int64_t& end, + const int64_t& stride) { + return {MakeSelectShape(input.shape(), dim, start, end, stride)}; } std::vector compute_shape_squeeze(const Output& input, const int& dim) { const auto& input_shape = input.shape(); - return { torch::lazy::Shape(input_shape.scalar_type(), BuildSqueezedDimensions(input_shape.sizes(), dim)) }; + return {torch::lazy::Shape( + input_shape.scalar_type(), + BuildSqueezedDimensions(input_shape.sizes(), dim))}; } -std::vector compute_shape_unsqueeze(const Output& input, const int& dim) { +std::vector compute_shape_unsqueeze( + const Output& input, + const int& dim) { const auto& input_shape = input.shape(); - return { torch::lazy::Shape(input_shape.scalar_type(), BuildUnsqueezedDimensions(input_shape.sizes(), dim)) }; + return {torch::lazy::Shape( + input_shape.scalar_type(), + BuildUnsqueezedDimensions(input_shape.sizes(), dim))}; } // Restore unused-parameters warnings diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 41477b7be6048f..cd1978d7ab8350 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -9,7 +9,7 @@ #include #include -namespace torch{ +namespace torch { namespace lazy { // Turn clang-format off, as we rely on the whole signature being on one line // for codegen. diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index 1cc1ade8fe4aa1..0fb34f2fa699b9 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -3,26 +3,26 @@ #include #include +#include #include #include #include #include #include -#include - namespace torch { namespace lazy { namespace { -LazyTensorPtr GetOrCreateLtcTensor(const at::Tensor& tensor, - const BackendDevice& device) { +LazyTensorPtr GetOrCreateLtcTensor( + const at::Tensor& tensor, + const BackendDevice& device) { if (!tensor.defined()) { return torch::lazy::LazyTensorPtr(); } auto lazy_tensor = TryGetLtcTensor(tensor); return lazy_tensor ? lazy_tensor : LazyTensor::Create(tensor, device); } -} // namespace +} // namespace LazyTensor::Data::~Data() { LazyGraphExecutor::Get()->UnregisterTensor(this); @@ -32,13 +32,15 @@ LazyTensorPtr LazyTensor::Create( const at::Tensor& tensor, const BackendDevice& device) { TORCH_CHECK(tensor.device().type() != at::kLazy); - LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(tensor, device)); + LazyTensorPtr lazy_tensor = + c10::make_intrusive(LazyTensor(tensor, device)); LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); return lazy_tensor; } LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) { - LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(std::move(ir_value), device)); + LazyTensorPtr lazy_tensor = + c10::make_intrusive(LazyTensor(std::move(ir_value), device)); LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); return lazy_tensor; } @@ -46,13 +48,15 @@ LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) { LazyTensorPtr LazyTensor::Create( std::shared_ptr view, const BackendDevice& device) { - LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(std::move(view), device)); + LazyTensorPtr lazy_tensor = + c10::make_intrusive(LazyTensor(std::move(view), device)); LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); return lazy_tensor; } LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) { - LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(std::move(handle))); + LazyTensorPtr lazy_tensor = + c10::make_intrusive(LazyTensor(std::move(handle))); LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); return lazy_tensor; } @@ -78,8 +82,11 @@ LazyTensor::LazyTensor( : LazyTensor(std::make_shared(std::move(view), device)) {} LazyTensor::LazyTensor(std::shared_ptr data) - : data_(std::move(data)) - , storage_(c10::Storage({}, 0, c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {} + : data_(std::move(data)), + storage_(c10::Storage( + {}, + 0, + c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {} LazyTensor::Data* LazyTensor::data() const { TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor"); @@ -200,9 +207,9 @@ void LazyTensor::SetIrValue(Value ir_value) { void LazyTensor::SetInPlaceIrValue(Value ir_value) { auto tensor_shape = shape(); - if (tensor_shape.Get().scalar_type() != - ir_value.shape().scalar_type()) { - ir_value = MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt); + if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) { + ir_value = + MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt); } SetIrValue(std::move(ir_value)); } @@ -336,9 +343,7 @@ std::shared_ptr LazyTensor::CreateView(ViewInfo view_info) const { Value ir_value = GetIrValue(); std::shared_ptr alias = std::make_shared(ir_value); ViewInfo this_view_info( - ViewInfo::Type::kNoOp, - ir_value.shape(), - ir_value.shape()); + ViewInfo::Type::kNoOp, ir_value.shape(), ir_value.shape()); data()->view = std::make_shared( ir_value.shape(), alias, std::move(this_view_info)); AssignIrValue(Value()); @@ -448,7 +453,8 @@ void LazyTensor::ApplyPendingGraph() { // This method is called to ensure that the tensor data is available on // device, so that a call to CurrentDataHandle() returns a valid pointer. if (CurrentDataHandle() == nullptr) { - std::vector tensors({c10::make_intrusive(LazyTensor(*this))}); + std::vector tensors( + {c10::make_intrusive(LazyTensor(*this))}); LazyGraphExecutor::Get()->SyncTensorsGraph( &tensors, {}, @@ -464,10 +470,11 @@ int64_t LazyTensor::GetNextTensorId() { torch::lazy::Value GetTensorList(c10::ArrayRef tensors) { std::vector values; - for (const auto& t: tensors) { + for (const auto& t : tensors) { auto* impl = dynamic_cast(t.unsafeGetTensorImpl()); - TORCH_INTERNAL_ASSERT(impl, - "GetTensorList only supports lists of valid tensors, but optional support could be added"); + TORCH_INTERNAL_ASSERT( + impl, + "GetTensorList only supports lists of valid tensors, but optional support could be added"); values.push_back(impl->tensor()->GetIrValue()); } @@ -485,7 +492,8 @@ LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) { LazyTensorPtr GetLtcTensor(const at::Tensor& tensor) { auto lazy_tensor = TryGetLtcTensor(tensor); - CHECK(lazy_tensor) << "Input tensor is not a lazy tensor: " << tensor.toString(); + CHECK(lazy_tensor) << "Input tensor is not a lazy tensor: " + << tensor.toString(); return lazy_tensor; } @@ -498,18 +506,21 @@ std::vector GetLtcTensors(c10::ArrayRef tensors) { return ltc_tensors; } -LazyTensorPtr GetOrCreateLtcTensor(const c10::optional& tensor, - const BackendDevice& device) { +LazyTensorPtr GetOrCreateLtcTensor( + const c10::optional& tensor, + const BackendDevice& device) { return GetOrCreateLtcTensor(tensor.value_or(at::Tensor()), device); } -LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(const at::Tensor& tensor, const BackendDevice& device) { +LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber( + const at::Tensor& tensor, + const BackendDevice& device) { // TODO: There are places in core where a scalar is wrapped but not marked as // wrapped. return (tensor.unsafeGetTensorImpl()->is_wrapped_number() || (tensor.dim() == 0 && tensor.numel() == 1)) - ? GetOrCreateLtcTensor(tensor, device) - : GetLtcTensor(tensor); + ? GetOrCreateLtcTensor(tensor, device) + : GetLtcTensor(tensor); } at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor) { diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index 9401a43dc522b6..74c0c79ce3500f 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -1,20 +1,19 @@ #pragma once +#include #include #include #include #include #include #include -#include - namespace torch { namespace lazy { -class TORCH_API SymbolicIntNode: public c10::SymbolicIntNode { -public: - SymbolicIntNode(NodePtr ptr): node_(std::move(ptr)) {}; +class TORCH_API SymbolicIntNode : public c10::SymbolicIntNode { + public: + SymbolicIntNode(NodePtr ptr) : node_(std::move(ptr)){}; NodePtr node_; }; @@ -62,12 +61,12 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { static LazyTensorPtr Create(BackendDataPtr handle); static LazyTensorPtr Create(std::shared_ptr data); - // The default ctor previously created a null LazyTensor (one with no 'data' obj). - // Creating a null LazyTensor is no longer possible, since the same can be achieved by - // creating a null LazyTensorPtr and it is way too confusing to have to check both - // lazy_tensor_ptr && *lazy_tensor_ptr, - // so everywhere that used to rely on a LazyTensor obj with a null Data can now rely on - // a null LazyTensorPtr instead. + // The default ctor previously created a null LazyTensor (one with no 'data' + // obj). Creating a null LazyTensor is no longer possible, since the same can + // be achieved by creating a null LazyTensorPtr and it is way too confusing to + // have to check both lazy_tensor_ptr && *lazy_tensor_ptr, so everywhere that + // used to rely on a LazyTensor obj with a null Data can now rely on a null + // LazyTensorPtr instead. LazyTensor() = delete; size_t generation() const { @@ -141,10 +140,14 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { // Applies the queue of operations in preparation for using the data. void ApplyPendingGraph(); - const c10::Storage& Storage() const { return storage_; } + const c10::Storage& Storage() const { + return storage_; + } // This is currently only used by outlier view ops such as expand that // don't go through CreateViewTensor to support Tensor.is_alias_of. - void SetStorage(const c10::Storage& storage) { storage_ = storage; } + void SetStorage(const c10::Storage& storage) { + storage_ = storage; + } private: LazyTensor(const at::Tensor& tensor, const BackendDevice& device); @@ -201,9 +204,10 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { // Utils to convert at::Tensor to LazyTensor, and vice versa. // Section 0: c10::Tensorlist ==> lazy::TensorList -// note: GetTensorList is not totally parallel to GetLtcTensor; A TensorList skips -// the LazyTensor wrappers, assuming that the list of underlying IR nodes is -// actually more useful for downstream computations. TBD. +// note: GetTensorList is not totally parallel to GetLtcTensor; A TensorList +// skips +// the LazyTensor wrappers, assuming that the list of underlying IR nodes +// is actually more useful for downstream computations. TBD. TORCH_API torch::lazy::Value GetTensorList(c10::ArrayRef tensors); // Section 1: at::Tensor => LazyTensor. @@ -216,14 +220,18 @@ TORCH_API LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor); TORCH_API LazyTensorPtr GetLtcTensor(const at::Tensor& tensor); // Same as above, applied to a list of tensors. -TORCH_API std::vector GetLtcTensors(c10::ArrayRef tensors); +TORCH_API std::vector GetLtcTensors( + c10::ArrayRef tensors); // If tensor is a lazy tensor type, returns the LazyTensor embedded within it, // otherwise creates a new lazy tensor type with tensor as data. -TORCH_API LazyTensorPtr GetOrCreateLtcTensor(const c10::optional& tensor, - const BackendDevice& device); +TORCH_API LazyTensorPtr GetOrCreateLtcTensor( + const c10::optional& tensor, + const BackendDevice& device); -TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(const at::Tensor& tensor, const BackendDevice& device); +TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber( + const at::Tensor& tensor, + const BackendDevice& device); // Section 2: LazyTensor => at::Tensor. // Creates an ATen tensor from an LazyTensor. @@ -231,13 +239,15 @@ TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor); template -auto TupleAtenFromLtcTensorsImpl(const std::vector& tensors, std::index_sequence) { - return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...); +auto TupleAtenFromLtcTensorsImpl( + const std::vector& tensors, + std::index_sequence) { + return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...); } template auto TupleAtenFromLtcTensors(const std::vector& tensors) { - return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence{}); + return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence{}); } } // namespace lazy diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp index c7be506caec755..78398fb828d249 100644 --- a/torch/csrc/lazy/core/tensor_impl.cpp +++ b/torch/csrc/lazy/core/tensor_impl.cpp @@ -5,22 +5,25 @@ #include #include #include -#include #include +#include namespace torch { namespace lazy { namespace { -// LTCGuardImpl is used by CompositeExplicitAutograd ops or eager fallbacks to make sure that some particular tensors -// within the life scope of the guard are on the same device. For example, in RegisterCompositeExplicitAutograd.cpp, -// outputs of each op are examined if they are on same device as the supplied TensorOptions. For more information, -// see DeviceGuard.h. -// For ops that have LTC native function implementations, this guard is omitted. +// LTCGuardImpl is used by CompositeExplicitAutograd ops or eager fallbacks to +// make sure that some particular tensors within the life scope of the guard are +// on the same device. For example, in RegisterCompositeExplicitAutograd.cpp, +// outputs of each op are examined if they are on same device as the supplied +// TensorOptions. For more information, see DeviceGuard.h. For ops that have LTC +// native function implementations, this guard is omitted. thread_local c10::Device g_device(c10::DeviceType::Lazy); struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface { - at::DeviceType type() const override { return at::DeviceType::Lazy; } + at::DeviceType type() const override { + return at::DeviceType::Lazy; + } c10::Device exchangeDevice(c10::Device device) const override { TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy); @@ -65,7 +68,7 @@ struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface { C10_REGISTER_GUARD_IMPL(Lazy, LTCGuardImpl); -} // namespace +} // namespace // TODO(whc) when do we want to clone vs share? LTCTensorImpl::LTCTensorImpl(const LazyTensorPtr& tensor) @@ -75,10 +78,12 @@ LTCTensorImpl::LTCTensorImpl(const LazyTensor& tensor) : LTCTensorImpl(LazyTensor(tensor)) {} LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor) - : c10::TensorImpl(c10::DispatchKeySet{c10::DispatchKey::Lazy, - c10::DispatchKey::AutogradLazy}, - c10::scalarTypeToTypeMeta(tensor.dtype()), - backendDeviceToAtenDevice(tensor.GetDevice())), + : c10::TensorImpl( + c10::DispatchKeySet{ + c10::DispatchKey::Lazy, + c10::DispatchKey::AutogradLazy}, + c10::scalarTypeToTypeMeta(tensor.dtype()), + backendDeviceToAtenDevice(tensor.GetDevice())), tensor_(c10::make_intrusive(std::move(tensor))) { // This is a temporary fix for a PyTorch core issue, // according to https://github.com/pytorch/xla/pull/2682. @@ -87,8 +92,9 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor) auto rank = tensor_->shape().Get().sizes().size(); sym_sizes_.reserve(rank); - for (auto i: c10::irange(rank)) { - auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode(this->tensor_->GetIrValue(), i); + for (auto i : c10::irange(rank)) { + auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode( + this->tensor_->GetIrValue(), i); auto sn = std::make_shared(dim_node); sym_sizes_.push_back(sn->toSymInt()); } @@ -149,7 +155,8 @@ void LTCTensorImpl::setup_size_properties() { // We can't call refresh_numel() given we override sizes() too. numel_ = shape.Get().numel(); sizes_and_strides_.set_sizes(shape.Get().sizes()); - // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous) given we override sizes() too. + // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous) given + // we override sizes() too. std::vector updated_strides; updated_strides = ComputeArrayStrides(shape.Get().sizes()); for (const auto i : c10::irange(updated_strides.size())) { @@ -192,5 +199,5 @@ bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const { return true; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/tensor_impl.h b/torch/csrc/lazy/core/tensor_impl.h index 6a848a84cf1382..a2232fe47b1ef4 100644 --- a/torch/csrc/lazy/core/tensor_impl.h +++ b/torch/csrc/lazy/core/tensor_impl.h @@ -17,11 +17,15 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { explicit LTCTensorImpl(const LazyTensor& tensor); explicit LTCTensorImpl(LazyTensor&& tensor); - LazyTensorPtr tensor() { return tensor_; } + LazyTensorPtr tensor() { + return tensor_; + } void set_tensor(const LazyTensorPtr& lazy_tensor); - void force_refresh_sizes() { generation_ = 0; } + void force_refresh_sizes() { + generation_ = 0; + } c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, @@ -42,17 +46,21 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { virtual c10::SymIntArrayRef sym_sizes_custom() const override; #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY - const at::Storage& storage() const override { return tensor_->Storage(); } - bool has_storage() const override { return tensor_->Storage(); } -#endif // C10_DISABLE_TENSORIMPL_EXTENSIBILITY + const at::Storage& storage() const override { + return tensor_->Storage(); + } + bool has_storage() const override { + return tensor_->Storage(); + } +#endif // C10_DISABLE_TENSORIMPL_EXTENSIBILITY private: void setup_size_properties(); LazyTensorPtr tensor_; std::vector sym_sizes_; - size_t generation_ {0}; + size_t generation_{0}; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/core/tensor_util.cpp b/torch/csrc/lazy/core/tensor_util.cpp index 1d9c5c36dbaf9a..8eb7632c2d8a59 100644 --- a/torch/csrc/lazy/core/tensor_util.cpp +++ b/torch/csrc/lazy/core/tensor_util.cpp @@ -58,7 +58,8 @@ std::vector CreateTensorsData( } bool IsSpecialScalar(const at::Scalar& value) { - if (FLAGS_torch_lazy_handle_special_scalars && (value.isIntegral(false) || value.isFloatingPoint())) { + if (FLAGS_torch_lazy_handle_special_scalars && + (value.isIntegral(false) || value.isFloatingPoint())) { double scalar_value = value.toDouble(); return scalar_value == 0.0 || std::fabs(scalar_value) == 1.0; } diff --git a/torch/csrc/lazy/core/thread_pool.cpp b/torch/csrc/lazy/core/thread_pool.cpp index 5cad72042deacf..a328f7b42afacd 100644 --- a/torch/csrc/lazy/core/thread_pool.cpp +++ b/torch/csrc/lazy/core/thread_pool.cpp @@ -19,7 +19,7 @@ class ThreadPool { explicit ThreadPool(size_t num_threads) { threads_.reserve(num_threads); for (const auto i : c10::irange(num_threads)) { - (void)i; //Suppress unused variable warning + (void)i; // Suppress unused variable warning threads_.emplace_back([this]() { Worker(); }); } } diff --git a/torch/csrc/lazy/core/trie.cpp b/torch/csrc/lazy/core/trie.cpp index 6566f62f9df5d0..21f9b7cea2b1d3 100644 --- a/torch/csrc/lazy/core/trie.cpp +++ b/torch/csrc/lazy/core/trie.cpp @@ -17,7 +17,7 @@ void TraverseTrie(TrieNode* node, std::stringstream& ss) { } if (node->ir_node) { ss << node->unique_id << "[label=\"" << node->ir_node->op().ToString() - << ", " << node->hit_counter << " hits\"]\n"; + << ", " << node->hit_counter << " hits\"]\n"; } for (auto& successor : node->successors) { ss << node->unique_id << " -> " << successor->unique_id << "\n"; @@ -31,13 +31,15 @@ TrieCache* TrieCache::Get() { return trie; } -TrieCache::TrieCache() : root_(std::make_shared()), current_(root_.get()) {} +TrieCache::TrieCache() + : root_(std::make_shared()), current_(root_.get()) {} TrieNode* TrieCache::Current() const { return current_; } -void TrieCache::SetCurrent(std::list>::iterator& iter) { +void TrieCache::SetCurrent( + std::list>::iterator& iter) { auto& successors = current_->successors; // Update current_ before iter gets destroyed current_ = (*iter).get(); diff --git a/torch/csrc/lazy/core/trie.h b/torch/csrc/lazy/core/trie.h index cfe1142a9758aa..bfb026d963cc0f 100644 --- a/torch/csrc/lazy/core/trie.h +++ b/torch/csrc/lazy/core/trie.h @@ -23,7 +23,9 @@ struct TORCH_API TrieNode { TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {} explicit TrieNode(NodePtr node) - : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(std::move(node)) {} + : unique_id(GetNextUniqueId()), + hit_counter(0), + ir_node(std::move(node)) {} }; class TORCH_API TrieCache { @@ -61,8 +63,10 @@ NodePtr LookupNodeFromTrieCache(Args&&... args) { for (auto it = successors.begin(); it != successors.end(); it++) { NodePtr ir_node = (*it)->ir_node; const T* concrete_node = NodeCast(ir_node.get()); - if (concrete_node && concrete_node->CanBeReused(std::forward(args)...)) { - TORCH_LAZY_COUNTER("IrNodeReused_" + c10::demangle((typeid(T).name())), 1); + if (concrete_node && + concrete_node->CanBeReused(std::forward(args)...)) { + TORCH_LAZY_COUNTER( + "IrNodeReused_" + c10::demangle((typeid(T).name())), 1); (*it)->hit_counter++; TrieCache::Get()->SetCurrent(it); return ir_node; diff --git a/torch/csrc/lazy/core/util.h b/torch/csrc/lazy/core/util.h index e324099af8b165..e1df80fdec87aa 100644 --- a/torch/csrc/lazy/core/util.h +++ b/torch/csrc/lazy/core/util.h @@ -44,11 +44,17 @@ class Cleanup { return *this; } - void Release() { func_ = nullptr; } + void Release() { + func_ = nullptr; + } - void SetStatus(StatusType&& status) { status_ = std::move(status); } + void SetStatus(StatusType&& status) { + status_ = std::move(status); + } - const StatusType& GetStatus() const { return status_; } + const StatusType& GetStatus() const { + return status_; + } private: std::function func_; @@ -59,19 +65,28 @@ using ExceptionCleanup = Cleanup; // Allows APIs which might return const references and values, to not be forced // to return values in the signature. -// TODO(alanwaketan): This is clever, but is there really no std or c10 supports? -// Needs more investigations. +// TODO(alanwaketan): This is clever, but is there really no std or c10 +// supports? Needs more investigations. template class MaybeRef { public: /* implicit */ MaybeRef(const T& ref) : ref_(ref) {} - /* implicit */ MaybeRef(T&& value) : storage_(std::move(value)), ref_(*storage_) {} + /* implicit */ MaybeRef(T&& value) + : storage_(std::move(value)), ref_(*storage_) {} - const T& Get() const { return ref_; } - const T& operator*() const { return Get(); } - operator const T&() const { return Get(); } + const T& Get() const { + return ref_; + } + const T& operator*() const { + return Get(); + } + operator const T&() const { + return Get(); + } - bool IsStored() const { return storage_.has_value(); } + bool IsStored() const { + return storage_.has_value(); + } private: c10::optional storage_; @@ -93,8 +108,9 @@ std::vector ToVector(const S& input) { return std::vector(input.begin(), input.end()); } -template -c10::optional> ToOptionalVector(c10::OptionalArrayRef arrayRef) { +template +c10::optional> ToOptionalVector( + c10::OptionalArrayRef arrayRef) { if (arrayRef) { return arrayRef->vec(); } diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index d36592d10abc70..df331c081d6697 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -3,15 +3,15 @@ #include #include #include +#include +#include #include +#include +#include #include #include #include #include -#include -#include -#include -#include #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) #include #include @@ -22,9 +22,10 @@ namespace torch { namespace lazy { -// TODO(whc) backend 'device' related APIs are not very clear, this code could be -// simplified but it should probably be done together with designing/refactoring -// the overall approach to get/set of default eager/lazy device types +// TODO(whc) backend 'device' related APIs are not very clear, this code could +// be simplified but it should probably be done together with +// designing/refactoring the overall approach to get/set of default eager/lazy +// device types torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) { if (device_str.empty()) { getBackend()->GetDefaultDeviceType(); @@ -45,15 +46,17 @@ std::string GetTensorsDump( std::vector nodes; std::vector values; for (auto& tensor : tensors) { - torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor); + torch::lazy::LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor); values.push_back(lazy_tensor->GetIrValue()); nodes.push_back(values.back().node.get()); } return coverter(nodes); } -std::vector GetLtcTensors(const std::vector& tensors, - bool want_all) { +std::vector GetLtcTensors( + const std::vector& tensors, + bool want_all) { std::vector lazy_tensors; lazy_tensors.reserve(tensors.size()); if (want_all) { @@ -72,19 +75,24 @@ std::vector GetLtcTensors(const std::vector& tensors) { - std::vector lazy_tensors = GetLtcTensors(tensors, /*want_all=*/false); - return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation(lazy_tensors); + std::vector lazy_tensors = + GetLtcTensors(tensors, /*want_all=*/false); + return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation( + lazy_tensors); } -void SyncTensors(const std::vector& tensors, - const std::vector& devices, bool wait, - bool sync_ltc_data) { - std::vector lazy_tensors = GetLtcTensors(tensors, /*want_all=*/false); - torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph(&lazy_tensors, devices, wait, - sync_ltc_data); +void SyncTensors( + const std::vector& tensors, + const std::vector& devices, + bool wait, + bool sync_ltc_data) { + std::vector lazy_tensors = + GetLtcTensors(tensors, /*want_all=*/false); + torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph( + &lazy_tensors, devices, wait, sync_ltc_data); } -void initLazyBindings(PyObject* module){ +void initLazyBindings(PyObject* module) { auto m = py::handle(module).cast(); auto lazy = m.def_submodule("_lazy"); auto lazy_ts_backend = m.def_submodule("_lazy_ts_backend"); @@ -93,14 +101,18 @@ void initLazyBindings(PyObject* module){ "_mark_step", // TODO(whc) this API should probably change from vector to // vector but in a separate PR - [](const std::string& device_str, const std::vector& devices, + [](const std::string& device_str, + const std::vector& devices, bool wait) { pybind11::gil_scoped_release no_gil; auto backend_device = GetDeviceOrCurrent(device_str); - torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph(&backend_device, devices, wait); + torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph( + &backend_device, devices, wait); torch::lazy::LazyGraphExecutor::Get()->MarkStep(backend_device); }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + py::arg("device") = "", + py::arg("devices"), + py::arg("wait") = true); lazy.def( "_wait_device_ops", [](const std::vector& devices) { @@ -112,33 +124,38 @@ void initLazyBindings(PyObject* module){ torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({}); }, py::arg("devices")); - lazy.def("_reset_metrics", - []() { torch::lazy::MetricsArena::Get()->Reset(); }); + lazy.def( + "_reset_metrics", []() { torch::lazy::MetricsArena::Get()->Reset(); }); lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); }); lazy.def("_counter_value", [](const std::string& name) -> py::object { torch::lazy::CounterData* data = torch::lazy::GetCounter(name); return data != nullptr ? py::cast(data->Value()) : py::none(); }); - lazy.def("_get_tensor_id", [](const at::Tensor& tensor) { return GetTensorId(tensor); }); + lazy.def("_get_tensor_id", [](const at::Tensor& tensor) { + return GetTensorId(tensor); + }); - lazy.def("_get_tensors_text", - [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { - return torch::lazy::DumpUtil::ToText(nodes); - }; - return GetTensorsDump(tensors, coverter); - }); - lazy.def("_get_tensors_dot", - [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { - return torch::lazy::DumpUtil::ToDot(nodes); - }; - return GetTensorsDump(tensors, coverter); - }); - lazy.def("_get_tensors_backend", - [](const std::vector& tensors) -> std::string { - return GetTensorsBackendGraph(tensors); - }); + lazy.def( + "_get_tensors_text", + [](const std::vector& tensors) -> std::string { + auto coverter = [](c10::ArrayRef nodes) { + return torch::lazy::DumpUtil::ToText(nodes); + }; + return GetTensorsDump(tensors, coverter); + }); + lazy.def( + "_get_tensors_dot", + [](const std::vector& tensors) -> std::string { + auto coverter = [](c10::ArrayRef nodes) { + return torch::lazy::DumpUtil::ToDot(nodes); + }; + return GetTensorsDump(tensors, coverter); + }); + lazy.def( + "_get_tensors_backend", + [](const std::vector& tensors) -> std::string { + return GetTensorsBackendGraph(tensors); + }); lazy.def("_get_graph_hash", [](const std::vector& tensors) { std::vector xtensors; xtensors.reserve(tensors.size()); @@ -146,153 +163,151 @@ void initLazyBindings(PyObject* module){ xtensors.push_back(TryGetLtcTensor(tensor)); } auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors); - std::string bin((const char*) &hash, sizeof(hash)); + std::string bin((const char*)&hash, sizeof(hash)); return py::bytes(bin); }); lazy.def( "_sync_multi", [](const std::vector& tensors, - const std::vector& devices, bool wait, + const std::vector& devices, + bool wait, bool sync_ltc_data) { pybind11::gil_scoped_release no_gil; SyncTensors(tensors, devices, wait, sync_ltc_data); }, - py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, + py::arg("tensors"), + py::arg("devices"), + py::arg("wait") = true, py::arg("sync_ltc_data") = true); - lazy.def( - "_get_force_fallback", []() { - return torch::lazy::getLTCForceFallback(); - } - ); - lazy.def( - "_set_force_fallback", [](std::string newval) { - torch::lazy::getLTCForceFallback() = newval; - } - ); - lazy.def( - "_clear_ir_cache", []() { - TrieCache::Get()->Clear(); - } - ); - lazy.def( - "_dump_ir_cache", [](std::string filename) { - TrieCache::Get()->DumpToDotFile(filename); - } - ); - lazy.def( - "_set_reuse_ir", [](bool val) { - FLAGS_torch_lazy_reuse_ir = val; - } - ); - lazy.def( - "_set_symbolic_shape_mode", [](bool val) { - FLAGS_ltc_enable_symbolic_shapes = val; - } - ); - lazy.def( - "_get_symbolic_shape_mode", []() { - return FLAGS_ltc_enable_symbolic_shapes; - } - ); + lazy.def("_get_force_fallback", []() { + return torch::lazy::getLTCForceFallback(); + }); + lazy.def("_set_force_fallback", [](std::string newval) { + torch::lazy::getLTCForceFallback() = newval; + }); + lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); }); + lazy.def("_dump_ir_cache", [](std::string filename) { + TrieCache::Get()->DumpToDotFile(filename); + }); + lazy.def("_set_reuse_ir", [](bool val) { FLAGS_torch_lazy_reuse_ir = val; }); + lazy.def("_set_symbolic_shape_mode", [](bool val) { + FLAGS_ltc_enable_symbolic_shapes = val; + }); + lazy.def("_get_symbolic_shape_mode", []() { + return FLAGS_ltc_enable_symbolic_shapes; + }); - lazy_ts_backend.def( - "_init", - []() { + lazy_ts_backend.def("_init", []() { #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - torch::lazy::InitTorchScriptBackend(); + torch::lazy::InitTorchScriptBackend(); #else TORCH_CHECK(false, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds"); -#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - }); +#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) + }); /* * Return tensor ids and tensors for DeviceData nodes. * TODO(shunting) revisit this API for XLA */ - lazy_ts_backend.def("_get_tensors_ts_device_data_node", - [](const std::vector& tensors) -> std::pair, std::vector> { + lazy_ts_backend.def( + "_get_tensors_ts_device_data_node", + [](const std::vector& tensors) + -> std::pair, std::vector> { #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - std::vector roots; - for (auto& tensor : tensors) { - auto xtensor = TryGetLtcTensor(tensor); - roots.push_back(xtensor->GetIrValue().node.get()); - } - auto post_order = Util::ComputePostOrder(roots); - std::vector tensor_ids; - std::vector ivalues; + std::vector roots; + for (auto& tensor : tensors) { + auto xtensor = TryGetLtcTensor(tensor); + roots.push_back(xtensor->GetIrValue().node.get()); + } + auto post_order = Util::ComputePostOrder(roots); + std::vector tensor_ids; + std::vector ivalues; - std::unordered_set data_handles_; - for (auto nodeptr : post_order) { - if (nodeptr->op() == *torch::lazy::ltc_device_data) { - const auto backend_data = getBackend()->GetComputationDataFromNode(nodeptr); + std::unordered_set data_handles_; + for (auto nodeptr : post_order) { + if (nodeptr->op() == *torch::lazy::ltc_device_data) { + const auto backend_data = + getBackend()->GetComputationDataFromNode(nodeptr); - auto infoptr = backend_data->info(); - auto deviceDataInfoPtr = (torch::lazy::LazyGraphExecutor::DeviceDataInfo*) infoptr; - auto* tsDataPtr = (torch::lazy::TSData*) backend_data.get(); + auto infoptr = backend_data->info(); + auto deviceDataInfoPtr = + (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; + auto* tsDataPtr = (torch::lazy::TSData*)backend_data.get(); - // dedup DeviceData by handle - auto handle = tsDataPtr->GetHandle(); - if (!data_handles_.insert(handle).second) { - continue; - } - tensor_ids.push_back(deviceDataInfoPtr->tensor_id); - /* - * If the TSData contains a tensor, then the tensor id will uniquely identify the tensor. - * We use that tensor id to find the tensor in other places: e.g. in the python forward method parameters. - * - * If the TSData contains a scalar, the tensor id itself is not important. We reuse the scalar value in - * future calls. - */ - if (tsDataPtr->HasValue()) { - ivalues.emplace_back(tsDataPtr->data()); - } else { - CHECK(tsDataPtr->scalar.has_value()); - ivalues.emplace_back(tsDataPtr->scalar.value()); - } + // dedup DeviceData by handle + auto handle = tsDataPtr->GetHandle(); + if (!data_handles_.insert(handle).second) { + continue; + } + tensor_ids.push_back(deviceDataInfoPtr->tensor_id); + /* + * If the TSData contains a tensor, then the tensor id will uniquely + * identify the tensor. We use that tensor id to find the tensor in + * other places: e.g. in the python forward method parameters. + * + * If the TSData contains a scalar, the tensor id itself is not + * important. We reuse the scalar value in future calls. + */ + if (tsDataPtr->HasValue()) { + ivalues.emplace_back(tsDataPtr->data()); + } else { + CHECK(tsDataPtr->scalar.has_value()); + ivalues.emplace_back(tsDataPtr->scalar.value()); } } - return std::make_pair(tensor_ids, ivalues); + } + return std::make_pair(tensor_ids, ivalues); #else - TORCH_CHECK(false, "TorchScript backend not yet supported in FBCODE builds"); - return std::make_pair(std::vector(), std::vector()); -#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - }); + TORCH_CHECK( + false, "TorchScript backend not yet supported in FBCODE builds"); + return std::make_pair( + std::vector(), std::vector()); +#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) + }); // TODO(shunting) revisit this part for XLA - lazy_ts_backend.def("_run_cached_graph", [](const std::string& hash_str, const std::vector& graph_inputs) { - std::vector result; + lazy_ts_backend.def( + "_run_cached_graph", + [](const std::string& hash_str, + const std::vector& graph_inputs) { + std::vector result; #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - TORCH_CHECK(hash_str.size() == sizeof(hash_t)); - hash_t hash = *(hash_t*) (hash_str.c_str()); - auto cachedComputation = LazyGraphExecutor::Get()->GetComputationCache()->Get(hash); - TORCH_CHECK(cachedComputation, "Failed to get computation by hash. Maybe the entry get kicked out of the LRU cache"); // TODO implement a fallback mechanism, or make sure those entries never get kicked out - auto computationPtr = (torch::lazy::TSComputation*) cachedComputation->computation.get(); + TORCH_CHECK(hash_str.size() == sizeof(hash_t)); + hash_t hash = *(hash_t*)(hash_str.c_str()); + auto cachedComputation = + LazyGraphExecutor::Get()->GetComputationCache()->Get(hash); + TORCH_CHECK( + cachedComputation, + "Failed to get computation by hash. Maybe the entry get kicked out of the LRU cache"); // TODO implement a fallback mechanism, or make sure those entries never get kicked out + auto computationPtr = + (torch::lazy::TSComputation*)cachedComputation->computation.get(); - std::vector stack; - stack.reserve(graph_inputs.size()); - for (const auto& arg : graph_inputs) { - stack.emplace_back(arg); - } - computationPtr->graph_executor().run(stack); - result.reserve(stack.size()); - for (torch::jit::IValue elem : stack) { - result.push_back(elem.toTensor()); - } + std::vector stack; + stack.reserve(graph_inputs.size()); + for (const auto& arg : graph_inputs) { + stack.emplace_back(arg); + } + computationPtr->graph_executor().run(stack); + result.reserve(stack.size()); + for (torch::jit::IValue elem : stack) { + result.push_back(elem.toTensor()); + } #else - TORCH_CHECK(false, "TorchScript backend not yet supported in FBCODE builds"); -#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - return result; - }); + TORCH_CHECK( + false, "TorchScript backend not yet supported in FBCODE builds"); +#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) + return result; + }); #ifndef USE_DEPLOY // When libtorch_python is loaded, we register the python frame getter // otherwise, debug util simply omits python frames // TODO(whc) can we make this work inside torch deploy interpreter? - // it doesn't work as-is, possibly becuase GetPythonFrames resolves to external - // cpython rather than embedded cpython + // it doesn't work as-is, possibly becuase GetPythonFrames resolves to + // external cpython rather than embedded cpython GetPythonFramesFunction() = GetPythonFrames; #endif } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/python/init.h b/torch/csrc/lazy/python/init.h index 3049baa4421ea2..00df90966e4ef0 100644 --- a/torch/csrc/lazy/python/init.h +++ b/torch/csrc/lazy/python/init.h @@ -7,5 +7,5 @@ namespace lazy { TORCH_API void initLazyBindings(PyObject* module); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/python/python_util.cpp b/torch/csrc/lazy/python/python_util.cpp index 485d288c90b1fe..e594c752de6f4f 100644 --- a/torch/csrc/lazy/python/python_util.cpp +++ b/torch/csrc/lazy/python/python_util.cpp @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include namespace torch { namespace lazy { @@ -42,5 +42,5 @@ std::vector GetPythonFrames() { return frames; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/python/python_util.h b/torch/csrc/lazy/python/python_util.h index 7b00f8ac7cd239..23df3d192fe9aa 100644 --- a/torch/csrc/lazy/python/python_util.h +++ b/torch/csrc/lazy/python/python_util.h @@ -11,5 +11,5 @@ c10::optional TORCH_API GetPythonFrameTop(); std::vector TORCH_API GetPythonFrames(); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/dynamic_ir.cpp b/torch/csrc/lazy/ts_backend/dynamic_ir.cpp index 0c94aab7b4e99c..38ac4e4cc341a0 100644 --- a/torch/csrc/lazy/ts_backend/dynamic_ir.cpp +++ b/torch/csrc/lazy/ts_backend/dynamic_ir.cpp @@ -3,67 +3,81 @@ namespace torch { namespace lazy { -DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed): - TsNode(op, operands, /*num_outputs=*/1, - /* hash_seed */ HashCombine(op.hash(), hash_seed)){} +DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed) + : TsNode( + op, + operands, + /*num_outputs=*/1, + /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} std::string DimensionNode::ToString() const { return "DimensionNode"; } - TSOpVector SizeNode::Lower(std::shared_ptr function, - TSLoweringContext* loctx) const { +TSOpVector SizeNode::Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const { std::vector arguments; std::vector kwarguments; arguments.reserve(2); auto index = loctx->graph()->insertConstant(static_cast(this->dim_)); arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(index); - torch::lazy::TSOpVector size_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + torch::lazy::TSOpVector size_out = + torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); CHECK_EQ(size_out.size(), 1); return size_out; } -SizeNode::SizeNode(Value input, size_t dim): - DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, MHash(dim)), - dim_(dim) {}; +SizeNode::SizeNode(Value input, size_t dim) + : DimensionNode( + OpKind{c10::Symbol::fromQualString("aten::size")}, + {input}, + MHash(dim)), + dim_(dim){}; -int64_t SizeNode:: getStaticValue() const { - return dynamic_cast(operand(0).node)->shape(0).size(dim_); +int64_t SizeNode::getStaticValue() const { + return dynamic_cast(operand(0).node)->shape(0).size(dim_); } std::string SizeNode::ToString() const { return "SizeNode"; } -SizeAdd::SizeAdd(Value a, Value b): - DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}) {}; +SizeAdd::SizeAdd(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; int64_t SizeAdd::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() + dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node)->getStaticValue() + + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeAdd::ToString() const { return "SizeAdd"; } -SizeMul::SizeMul(Value a, Value b): - DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}) {}; +SizeMul::SizeMul(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; int64_t SizeMul::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() * dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node)->getStaticValue() * + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeMul::ToString() const { return "SizeMul"; } -SizeDiv::SizeDiv(Value a, Value b): - DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}) {}; +SizeDiv::SizeDiv(Value a, Value b) + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){}; int64_t SizeDiv::getStaticValue() const { - TORCH_CHECK(dynamic_cast(operand(1).node)->getStaticValue() != 0, "Can't divide a dimension by zero"); - return dynamic_cast(operand(0).node)->getStaticValue() / dynamic_cast(operand(1).node)->getStaticValue(); + TORCH_CHECK( + dynamic_cast(operand(1).node)->getStaticValue() != + 0, + "Can't divide a dimension by zero"); + return dynamic_cast(operand(0).node)->getStaticValue() / + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeDiv::ToString() const { diff --git a/torch/csrc/lazy/ts_backend/dynamic_ir.h b/torch/csrc/lazy/ts_backend/dynamic_ir.h index c35d9c28f46710..d45fca12417242 100644 --- a/torch/csrc/lazy/ts_backend/dynamic_ir.h +++ b/torch/csrc/lazy/ts_backend/dynamic_ir.h @@ -12,11 +12,11 @@ #include #include +#include #include #include #include #include -#include C10_DECLARE_bool(ltc_enable_dynamic_shapes); @@ -47,7 +47,7 @@ class TORCH_API DimensionNode : public lazy::TsNode { public: DimensionNode(OpKind op, OpList operands, hash_t hash_seed = kHashSeed); bool isDynamic() { - return false; + return false; } std::string ToString() const override; @@ -62,25 +62,26 @@ class TORCH_API SizeNode : public DimensionNode { int64_t getStaticValue() const override; std::string ToString() const override; size_t dim_ = 0; - virtual TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const override; + virtual TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const override; }; -class TORCH_API SizeAdd: public DimensionNode { +class TORCH_API SizeAdd : public DimensionNode { public: SizeAdd(Value a, Value b); int64_t getStaticValue() const override; std::string ToString() const override; }; -class TORCH_API SizeMul: public DimensionNode { +class TORCH_API SizeMul : public DimensionNode { public: SizeMul(Value a, Value b); int64_t getStaticValue() const override; std::string ToString() const override; }; -class TORCH_API SizeDiv: public DimensionNode { +class TORCH_API SizeDiv : public DimensionNode { public: SizeDiv(Value a, Value b); int64_t getStaticValue() const override; diff --git a/torch/csrc/lazy/ts_backend/ir_builder.h b/torch/csrc/lazy/ts_backend/ir_builder.h index f1b00f4dd67b04..600243b67f6221 100644 --- a/torch/csrc/lazy/ts_backend/ir_builder.h +++ b/torch/csrc/lazy/ts_backend/ir_builder.h @@ -1,50 +1,152 @@ #pragma once +#include #include #include -#include #include #include -#include #include -#include #include +#include +#include namespace torch { namespace lazy { - -struct TorchScriptIrBuilder: IrBuilder { - NodePtr MakeDeviceData(const std::shared_ptr& data) const override { return DeviceData::Create(data); } - // TODO: Scalar node is not currently used by ts_backend. Enable reusing Scalar node later if needed. - NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const override { return MakeNode(value, type); } - NodePtr MakeExpand(const Value& input0, const std::vector& size, const bool& is_scalar_expand) const override { return ReuseOrMakeNode(input0, size, is_scalar_expand); } - NodePtr MakeView(const Value& input0, const std::vector& output_size) const override { return ReuseOrMakeNode(input0, output_size); } - NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional& stype = c10::nullopt) const override { return ReuseOrMakeNode(input0, dtype, stype); } - NodePtr MakeTensorList(const OpList& inputs) const override { return ReuseOrMakeNode(inputs); } +struct TorchScriptIrBuilder : IrBuilder { + NodePtr MakeDeviceData( + const std::shared_ptr& data) const override { + return DeviceData::Create(data); + } + // TODO: Scalar node is not currently used by ts_backend. Enable reusing + // Scalar node later if needed. + NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) + const override { + return MakeNode(value, type); + } + NodePtr MakeExpand( + const Value& input0, + const std::vector& size, + const bool& is_scalar_expand) const override { + return ReuseOrMakeNode(input0, size, is_scalar_expand); + } + NodePtr MakeView(const Value& input0, const std::vector& output_size) + const override { + return ReuseOrMakeNode(input0, output_size); + } + NodePtr MakeCast( + const Value& input0, + const at::ScalarType& dtype, + const c10::optional& stype = + c10::nullopt) const override { + return ReuseOrMakeNode(input0, dtype, stype); + } + NodePtr MakeTensorList(const OpList& inputs) const override { + return ReuseOrMakeNode(inputs); + } // Generic needs cleanup - NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const override { return MakeNode(op, operands, shape, num_outputs, hash_seed); } + NodePtr MakeGeneric( + const OpKind& op, + const OpList& operands, + const Shape& shape, + const size_t& num_outputs = 1, + const hash_t& hash_seed = + static_cast(0x5a2d296e9)) const override { + return MakeNode(op, operands, shape, num_outputs, hash_seed); + } // View op nodes - NodePtr MakeAsStridedViewUpdate(const Value& input0, const Value& input1, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const override { return ReuseOrMakeNode(input0, input1, size, stride, storage_offset); } - NodePtr MakeAsStrided(const Value& input0, const std::vector& size, const std::vector& stride, const int64_t& storage_offset) const override { return ReuseOrMakeNode(input0, size, stride, storage_offset); } - NodePtr MakeDiagonalViewUpdate(const Value& input0, const Value& input1, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return ReuseOrMakeNode(input0, input1, offset, dim1, dim2); } - NodePtr MakeDiagonal(const Value& input0, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) const override { return ReuseOrMakeNode(input0, offset, dim1, dim2); } - NodePtr MakeNarrowViewUpdate(const Value& input0, const Value& input1, const std::vector& base_indices) const override { return ReuseOrMakeNode(input0, input1, base_indices); } - NodePtr MakeNarrow(const Value& input0, const std::vector& base_indices, const std::vector& sizes) const override { return ReuseOrMakeNode(input0, base_indices, sizes); } - NodePtr MakePermute(const Value& input0, const std::vector& dims) const override { return ReuseOrMakeNode(input0, dims); } - NodePtr MakeResize(const Value& input0, const std::vector& size) const override { return ReuseOrMakeNode(input0, size); } - NodePtr MakeSelectViewUpdate(const Value& input0, const Value& input1, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return ReuseOrMakeNode(input0, input1, dim, start, end, stride); } - NodePtr MakeSelect(const Value& input0, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) const override { return ReuseOrMakeNode(input0, dim, start, end, stride); + } + NodePtr MakeSqueeze(const Value& input0, const int& dim) const override { + return ReuseOrMakeNode(input0, dim); + } + NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const override { + return ReuseOrMakeNode(input0, dim); + } // dynamic ir nodes // TODO: verify if IR node reusing works for Dynamic shape ops - NodePtr MakeSizeNode(const Value& input, size_t dim) const override { return MakeNode(input, dim); } - NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { return MakeNode(a, b); } - NodePtr MakeSizeMul(const Value& a, const Value& b) const override { return MakeNode(a, b); } - NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { return MakeNode(a, b); } + NodePtr MakeSizeNode(const Value& input, size_t dim) const override { + return MakeNode(input, dim); + } + NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { + return MakeNode(a, b); + } + NodePtr MakeSizeMul(const Value& a, const Value& b) const override { + return MakeNode(a, b); + } + NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { + return MakeNode(a, b); + } }; } // namespace lazy diff --git a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp index d751e414a5539c..3b5ca967e4c1d3 100644 --- a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp +++ b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp @@ -1,43 +1,61 @@ -#include #include +#include namespace torch { namespace lazy { TSNativeBatchNormBackward::TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, const torch::lazy::Value& input, - const torch::lazy::Value& weight, const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, std::array output_mask) : torch::lazy::TsNode( torch::lazy::OpKind(at::aten::native_batch_norm_backward), - {grad_out, input, weight, running_mean, running_var, save_mean, + {grad_out, + input, + weight, + running_mean, + running_var, + save_mean, save_invstd}, - {input.shape(), - weight.shape(), - weight.shape()}, + {input.shape(), weight.shape(), weight.shape()}, /*num_outputs=*/3, - torch::lazy::MHash(training, eps, output_mask[0], output_mask[1], - output_mask[2])), + torch::lazy::MHash( + training, + eps, + output_mask[0], + output_mask[1], + output_mask[2])), training_(training), eps_(eps), output_mask_(output_mask) {} TSNativeBatchNormBackward::TSNativeBatchNormBackward( - const torch::lazy::Value& grad_out, const torch::lazy::Value& input, - const torch::lazy::Value& weight, const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, std::array output_mask) : torch::lazy::TsNode( torch::lazy::OpKind(at::aten::native_batch_norm_backward), {grad_out, input, weight, save_mean, save_invstd}, - {input.shape(), - weight.shape(), - weight.shape()}, + {input.shape(), weight.shape(), weight.shape()}, /*num_outputs=*/3, - torch::lazy::MHash(training, eps, output_mask[0], output_mask[1], - output_mask[2])), + torch::lazy::MHash( + training, + eps, + output_mask[0], + output_mask[1], + output_mask[2])), training_(training), eps_(eps), output_mask_(output_mask) {} @@ -50,17 +68,20 @@ std::string TSNativeBatchNormBackward::ToString() const { } TSNativeBatchNormForward::TSNativeBatchNormForward( - const torch::lazy::Value& input, const torch::lazy::Value& weight, - const torch::lazy::Value& bias, const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, bool training, double momentum, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& bias, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + bool training, + double momentum, double eps) - : torch::lazy::TsNode(torch::lazy::OpKind(at::aten::native_batch_norm), - {input, weight, bias, running_mean, running_var}, - {input.shape(), - running_mean.shape(), - running_var.shape()}, - /*num_outputs=*/3, - torch::lazy::MHash(training, momentum, eps)), + : torch::lazy::TsNode( + torch::lazy::OpKind(at::aten::native_batch_norm), + {input, weight, bias, running_mean, running_var}, + {input.shape(), running_mean.shape(), running_var.shape()}, + /*num_outputs=*/3, + torch::lazy::MHash(training, momentum, eps)), training_(training), momentum_(momentum), eps_(eps) {} @@ -72,5 +93,5 @@ std::string TSNativeBatchNormForward::ToString() const { return ss.str(); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h index 64655fa4ee5b15..515a75f0b29c2b 100644 --- a/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h +++ b/torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h @@ -12,54 +12,82 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode { return OpKind(at::aten::native_batch_norm_backward); } - TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input, - const torch::lazy::Value& weight, const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, - std::array output_mask); - - TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input, - const torch::lazy::Value& weight, const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, - std::array output_mask); - - bool CanBeReused(const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, const torch::lazy::Value& weight, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, - std::array output_mask) const { + TSNativeBatchNormBackward( + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, + std::array output_mask); + + TSNativeBatchNormBackward( + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, + std::array output_mask); + + bool CanBeReused( + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, + std::array output_mask) const { size_t i = 0; - return (operand(i++) == grad_out && operand(i++) == input && - operand(i++) == weight && operand(i++) == running_mean && - operand(i++) == running_var && operand(i++) == save_mean && - operand(i++) == save_invstd && training_ == training && - eps_ == eps && output_mask_ == output_mask); + return ( + operand(i++) == grad_out && operand(i++) == input && + operand(i++) == weight && operand(i++) == running_mean && + operand(i++) == running_var && operand(i++) == save_mean && + operand(i++) == save_invstd && training_ == training && eps_ == eps && + output_mask_ == output_mask); } - bool CanBeReused(const torch::lazy::Value& grad_out, - const torch::lazy::Value& input, const torch::lazy::Value& weight, - const torch::lazy::Value& save_mean, - const torch::lazy::Value& save_invstd, bool training, double eps, - std::array output_mask) const { + bool CanBeReused( + const torch::lazy::Value& grad_out, + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& save_mean, + const torch::lazy::Value& save_invstd, + bool training, + double eps, + std::array output_mask) const { size_t i = 0; - return (operand(i++) == grad_out && operand(i++) == input && - operand(i++) == weight && operand(i++) == save_mean && - operand(i++) == save_invstd && training_ == training && - eps_ == eps && output_mask_ == output_mask); + return ( + operand(i++) == grad_out && operand(i++) == input && + operand(i++) == weight && operand(i++) == save_mean && + operand(i++) == save_invstd && training_ == training && eps_ == eps && + output_mask_ == output_mask); } std::string ToString() const override; - bool training() const { return training_; } + bool training() const { + return training_; + } - double eps() const { return eps_; } + double eps() const { + return eps_; + } - const std::array& output_mask() const { return output_mask_; } + const std::array& output_mask() const { + return output_mask_; + } - TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const override; + TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const override; private: bool training_; @@ -73,33 +101,50 @@ class TSNativeBatchNormForward : public torch::lazy::TsNode { return OpKind(at::aten::native_batch_norm); } - TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight, - const torch::lazy::Value& bias, const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, bool training, - double momentum, double eps); - - bool CanBeReused(const torch::lazy::Value& input, const torch::lazy::Value& weight, - const torch::lazy::Value& bias, - const torch::lazy::Value& running_mean, - const torch::lazy::Value& running_var, bool training, - double momentum, double eps) const { + TSNativeBatchNormForward( + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& bias, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + bool training, + double momentum, + double eps); + + bool CanBeReused( + const torch::lazy::Value& input, + const torch::lazy::Value& weight, + const torch::lazy::Value& bias, + const torch::lazy::Value& running_mean, + const torch::lazy::Value& running_var, + bool training, + double momentum, + double eps) const { size_t i = 0; - return (operand(i++) == input && operand(i++) == weight && - operand(i++) == bias && operand(i++) == running_mean && - operand(i++) == running_var && training_ == training && - momentum_ == momentum && eps == eps_); + return ( + operand(i++) == input && operand(i++) == weight && + operand(i++) == bias && operand(i++) == running_mean && + operand(i++) == running_var && training_ == training && + momentum_ == momentum && eps == eps_); } std::string ToString() const override; - bool training() const { return training_; } + bool training() const { + return training_; + } - double momentum() const { return momentum_; } + double momentum() const { + return momentum_; + } - double eps() const { return eps_; } + double eps() const { + return eps_; + } - TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const override; + TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const override; private: bool training_; @@ -107,5 +152,5 @@ class TSNativeBatchNormForward : public torch::lazy::TsNode { double eps_; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/device_data.h b/torch/csrc/lazy/ts_backend/ops/device_data.h index 2880eb5e9c1d80..53e7814fc39a44 100644 --- a/torch/csrc/lazy/ts_backend/ops/device_data.h +++ b/torch/csrc/lazy/ts_backend/ops/device_data.h @@ -38,8 +38,9 @@ class TORCH_API DeviceData : public TsNode { // instead of calling the constructor directly. static NodePtr Create(std::shared_ptr data); - TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const override; + TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const override; private: std::shared_ptr data_; diff --git a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp b/torch/csrc/lazy/ts_backend/ops/random_ops.cpp index ff29f4ce8a32db..cb3708953cd582 100644 --- a/torch/csrc/lazy/ts_backend/ops/random_ops.cpp +++ b/torch/csrc/lazy/ts_backend/ops/random_ops.cpp @@ -1,16 +1,22 @@ -#include #include +#include namespace torch { namespace lazy { -Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector&& shapes) - : torch::lazy::TsNode(ClassOpKind(), - {self}, std::move(shapes), - /* num_outputs */ 1, - torch::lazy::MHash(mean, std)), - mean_(mean), - std_(std) {} +Normal::Normal( + const torch::lazy::Value& self, + const double& mean, + const double& std, + std::vector&& shapes) + : torch::lazy::TsNode( + ClassOpKind(), + {self}, + std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash(mean, std)), + mean_(mean), + std_(std) {} std::string Normal::ToString() const { std::stringstream ss; @@ -30,11 +36,12 @@ torch::lazy::TSOpVector Normal::Lower( arguments.emplace_back(loctx->GetOutputOp(operand(i++))); arguments.emplace_back("mean", mean_); arguments.emplace_back("std", std_); - torch::lazy::TSOpVector normal__out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + torch::lazy::TSOpVector normal__out = + torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); CHECK_EQ(normal__out.size(), 1); return normal__out; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/random_ops.h b/torch/csrc/lazy/ts_backend/ops/random_ops.h index f44c5cc261d930..eb095a6a9542a7 100644 --- a/torch/csrc/lazy/ts_backend/ops/random_ops.h +++ b/torch/csrc/lazy/ts_backend/ops/random_ops.h @@ -11,15 +11,20 @@ class Normal : public torch::lazy::TsNode { return OpKind::Get("aten::normal_"); } - Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector&& shapes); + Normal( + const torch::lazy::Value& self, + const double& mean, + const double& std, + std::vector&& shapes); std::string ToString() const override; - torch::lazy::TSOpVector Lower(std::shared_ptr function, - torch::lazy::TSLoweringContext* loctx) const override; + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override; double mean_; double std_; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ops/to_copy.h b/torch/csrc/lazy/ts_backend/ops/to_copy.h index 773516cbf7195a..fdc90f46e609d1 100644 --- a/torch/csrc/lazy/ts_backend/ops/to_copy.h +++ b/torch/csrc/lazy/ts_backend/ops/to_copy.h @@ -5,22 +5,38 @@ namespace torch { namespace lazy { - -// This IR was copied from code-generated output, but the entire _to_copy operator -// cannot be trivially code genereated since it is only desirable to capture IR for -// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke -// the aten/eager fallback necessitating directly implementing the right to(device) behavior +// This IR was copied from code-generated output, but the entire _to_copy +// operator cannot be trivially code genereated since it is only desirable to +// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the +// others it is difficult to even invoke the aten/eager fallback necessitating +// directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TsNode { public: static OpKind ClassOpKind() { return OpKind(at::aten::_to_copy); } - ToCopy(const torch::lazy::Value& self, const c10::optional& dtype, const c10::optional& layout, const c10::optional& device, const c10::optional& pin_memory, const bool& non_blocking, const c10::optional& memory_format, std::vector&& shapes) - : torch::lazy::TsNode(ClassOpKind(), - {self}, std::move(shapes), - /* num_outputs */ 1, - torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)), + ToCopy( + const torch::lazy::Value& self, + const c10::optional& dtype, + const c10::optional& layout, + const c10::optional& device, + const c10::optional& pin_memory, + const bool& non_blocking, + const c10::optional& memory_format, + std::vector&& shapes) + : torch::lazy::TsNode( + ClassOpKind(), + {self}, + std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash( + dtype, + layout, + device, + pin_memory, + non_blocking, + memory_format)), dtype(dtype), layout(layout), @@ -29,55 +45,58 @@ class ToCopy : public torch::lazy::TsNode { non_blocking(non_blocking), memory_format(memory_format) {} - bool CanBeReused(const torch::lazy::Value& self, - const c10::optional& dtype, - const c10::optional& layout, - const c10::optional& device, - const c10::optional& pin_memory, const bool& non_blocking, - const c10::optional& memory_format) const { + bool CanBeReused( + const torch::lazy::Value& self, + const c10::optional& dtype, + const c10::optional& layout, + const c10::optional& device, + const c10::optional& pin_memory, + const bool& non_blocking, + const c10::optional& memory_format) const { size_t i = 0; - return (operand(i++) == self && this->dtype == dtype && - this->layout == layout && this->device == device && - this->pin_memory == pin_memory && - this->non_blocking == non_blocking && - this->memory_format == memory_format); + return ( + operand(i++) == self && this->dtype == dtype && + this->layout == layout && this->device == device && + this->pin_memory == pin_memory && this->non_blocking == non_blocking && + this->memory_format == memory_format); } std::string ToString() const override { std::stringstream ss; ss << torch::lazy::TsNode::ToString(); if (dtype.has_value()) { - ss << ", dtype=" << dtype.value(); + ss << ", dtype=" << dtype.value(); } else { - ss << ", dtype=null"; + ss << ", dtype=null"; } if (layout.has_value()) { - ss << ", layout=" << layout.value(); + ss << ", layout=" << layout.value(); } else { - ss << ", layout=null"; + ss << ", layout=null"; } if (device.has_value()) { - ss << ", device=" << device.value(); + ss << ", device=" << device.value(); } else { - ss << ", device=null"; + ss << ", device=null"; } if (pin_memory.has_value()) { - ss << ", pin_memory=" << pin_memory.value(); + ss << ", pin_memory=" << pin_memory.value(); } else { - ss << ", pin_memory=null"; + ss << ", pin_memory=null"; } ss << ", non_blocking=" << non_blocking; if (memory_format.has_value()) { - ss << ", memory_format=" << memory_format.value(); + ss << ", memory_format=" << memory_format.value(); } else { - ss << ", memory_format=null"; + ss << ", memory_format=null"; } return ss.str(); } - torch::lazy::TSOpVector Lower(std::shared_ptr function, - torch::lazy::TSLoweringContext* loctx) const override { - std::vector arguments; + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override { + std::vector arguments; std::vector kwarguments; arguments.reserve(1); kwarguments.reserve(6); @@ -89,11 +108,11 @@ class ToCopy : public torch::lazy::TsNode { kwarguments.emplace_back("pin_memory", pin_memory); kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); - torch::lazy::TSOpVector _to_copy_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + torch::lazy::TSOpVector _to_copy_out = + torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; - } c10::optional dtype; @@ -104,5 +123,5 @@ class ToCopy : public torch::lazy::TsNode { c10::optional memory_format; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp index 423ee67c9dd60b..9d23c48fd94e8a 100644 --- a/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp +++ b/torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp @@ -4,17 +4,17 @@ #include #include #include -#include -#include -#include +#include #include #include #include +#include +#include #include #include -#include -#include #include +#include +#include #include #include @@ -25,18 +25,21 @@ namespace { // to enable operator+-*/ for Value using namespace torch::lazy; -torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, - const torch::lazy::Shape& target_shape) { +torch::lazy::Value MaybeExpand( + const torch::lazy::Value& input, + const torch::lazy::Shape& target_shape) { if (input.shape().sizes() == target_shape.sizes()) { return input; } return torch::lazy::MakeExpand( - input, target_shape.sizes().vec(), + input, + target_shape.sizes().vec(), /*is_scalar_expand=*/false); } -std::vector GetExpandDimensions(const torch::lazy::Shape& shape, - std::vector dimensions) { +std::vector GetExpandDimensions( + const torch::lazy::Shape& shape, + std::vector dimensions) { CHECK_GE(dimensions.size(), shape.dim()) << shape; int64_t base = dimensions.size() - shape.dim(); for (size_t i = 0; i < shape.dim(); ++i) { @@ -48,28 +51,31 @@ std::vector GetExpandDimensions(const torch::lazy::Shape& shape, } // Returns a 1-D shape for batch norm weight or bias based on the input shape. -torch::lazy::Shape BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) { +torch::lazy::Shape BatchNormFeaturesShape( + const torch::lazy::LazyTensorPtr& input) { CHECK(input); auto input_shape = input->shape().Get(); - return torch::lazy::Shape(input_shape.scalar_type(), - input_shape.sizes()[1]); + return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]); } // Returns the IR for the given input or the provided default value broadcasted // to the default shape, if the input is undefined. -torch::lazy::Value GetIrValueOrDefault(const torch::lazy::LazyTensorPtr& input, - const at::Scalar& default_value, - const torch::lazy::Shape& default_shape, - const torch::lazy::BackendDevice& device) { - return input ? input->GetIrValue() - : torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(default_value, - default_shape, - device); +torch::lazy::Value GetIrValueOrDefault( + const torch::lazy::LazyTensorPtr& input, + const at::Scalar& default_value, + const torch::lazy::Shape& default_shape, + const torch::lazy::BackendDevice& device) { + return input + ? input->GetIrValue() + : torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( + default_value, default_shape, device); } torch::lazy::ViewInfo CreateAsStridedViewInfo( - const torch::lazy::Shape& input_shape, std::vector size, - std::vector stride, c10::optional storage_offset) { + const torch::lazy::Shape& input_shape, + std::vector size, + std::vector stride, + c10::optional storage_offset) { torch::lazy::Shape result_shape = torch::lazy::Shape(input_shape.scalar_type(), size); torch::lazy::AsStridedInfo as_strided_info; @@ -77,30 +83,38 @@ torch::lazy::ViewInfo CreateAsStridedViewInfo( if (storage_offset) { as_strided_info.offset = *storage_offset; } - return torch::lazy::ViewInfo(torch::lazy::ViewInfo::Type::kAsStrided, - std::move(result_shape), input_shape, - std::move(as_strided_info)); + return torch::lazy::ViewInfo( + torch::lazy::ViewInfo::Type::kAsStrided, + std::move(result_shape), + input_shape, + std::move(as_strided_info)); } -} // namespace +} // namespace ////////////////////////////////////////////////////////////////////////////// // ATEN operators follows here, listed in alphabetical order. ////////////////////////////////////////////////////////////////////////////// -torch::lazy::LazyTensorPtr as_strided(const torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, - c10::optional storage_offset) { +torch::lazy::LazyTensorPtr as_strided( + const torch::lazy::LazyTensorPtr& input, + std::vector size, + std::vector stride, + c10::optional storage_offset) { auto input_shape = input->shape(); return input->CreateViewTensor(CreateAsStridedViewInfo( input_shape, std::move(size), std::move(stride), storage_offset)); } -void as_strided_(torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, - c10::optional storage_offset) { +void as_strided_( + torch::lazy::LazyTensorPtr& input, + std::vector size, + std::vector stride, + c10::optional storage_offset) { if (input->data()->view == nullptr) { input->SetIrValue(torch::lazy::MakeAsStrided( - input->GetIrValue(), std::move(size), std::move(stride), + input->GetIrValue(), + std::move(size), + std::move(stride), storage_offset.value_or(0))); } else { auto input_shape = input->shape(); @@ -109,24 +123,32 @@ void as_strided_(torch::lazy::LazyTensorPtr& input, std::vector size, } } -torch::lazy::LazyTensorPtr expand(const torch::lazy::LazyTensorPtr& input, std::vector size) { +torch::lazy::LazyTensorPtr expand( + const torch::lazy::LazyTensorPtr& input, + std::vector size) { auto input_shape = input->shape(); - auto output = torch::lazy::LazyTensor::Create(torch::lazy::MakeExpand( - input->GetIrValue(), - GetExpandDimensions(input_shape.Get(), std::move(size)), - /*is_scalar_expand=*/false), input->GetDevice()); + auto output = torch::lazy::LazyTensor::Create( + torch::lazy::MakeExpand( + input->GetIrValue(), + GetExpandDimensions(input_shape.Get(), std::move(size)), + /*is_scalar_expand=*/false), + input->GetDevice()); output->SetStorage(input->Storage()); return output; } void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) { - torch::lazy::Value constant = torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - value, input->shape(), input->GetDevice()); + torch::lazy::Value constant = + torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( + value, input->shape(), input->GetDevice()); input->SetInPlaceIrValue(std::move(constant)); } -torch::lazy::LazyTensorPtr narrow(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t length) { +torch::lazy::LazyTensorPtr narrow( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t start, + int64_t length) { auto input_shape = input->shape(); dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); torch::lazy::Shape narrow_shape = input_shape; @@ -134,19 +156,28 @@ torch::lazy::LazyTensorPtr narrow(const torch::lazy::LazyTensorPtr& input, int64 torch::lazy::ViewInfo::Type view_type = (input_shape.Get().numel() == narrow_shape.numel()) - ? torch::lazy::ViewInfo::Type::kReshape - : torch::lazy::ViewInfo::Type::kNarrow; - torch::lazy::ViewInfo view_info(view_type, std::move(narrow_shape), - input_shape); + ? torch::lazy::ViewInfo::Type::kReshape + : torch::lazy::ViewInfo::Type::kNarrow; + torch::lazy::ViewInfo view_info( + view_type, std::move(narrow_shape), input_shape); view_info.indices[dim] = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start); return input->CreateViewTensor(std::move(view_info)); } -std::tuple ts_native_batch_norm( - const torch::lazy::LazyTensorPtr& input, const torch::lazy::LazyTensorPtr& weight, const torch::lazy::LazyTensorPtr& bias, - torch::lazy::LazyTensorPtr& running_mean, torch::lazy::LazyTensorPtr& running_var, bool training, - double momentum, double eps) { +std::tuple< + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr> +ts_native_batch_norm( + const torch::lazy::LazyTensorPtr& input, + const torch::lazy::LazyTensorPtr& weight, + const torch::lazy::LazyTensorPtr& bias, + torch::lazy::LazyTensorPtr& running_mean, + torch::lazy::LazyTensorPtr& running_var, + bool training, + double momentum, + double eps) { torch::lazy::Shape features_shape = BatchNormFeaturesShape(input); torch::lazy::Value weight_value = GetIrValueOrDefault(weight, 1, features_shape, input->GetDevice()); @@ -156,23 +187,43 @@ std::tupleGetDevice()); torch::lazy::Value running_var_value = GetIrValueOrDefault(running_var, 0, features_shape, input->GetDevice()); - torch::lazy::NodePtr node = - ReuseOrMakeNode( - input->GetIrValue(), weight_value, bias_value, running_mean_value, - running_var_value, training, momentum, eps); - torch::lazy::LazyTensorPtr output = torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0), input->GetDevice()); + torch::lazy::NodePtr node = ReuseOrMakeNode( + input->GetIrValue(), + weight_value, + bias_value, + running_mean_value, + running_var_value, + training, + momentum, + eps); + torch::lazy::LazyTensorPtr output = torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 0), input->GetDevice()); torch::lazy::LazyTensorPtr running_mean_output = - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 1), input->GetDevice()); - torch::lazy::LazyTensorPtr running_var_output = torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 2), input->GetDevice()); - return std::make_tuple(std::move(output), std::move(running_mean_output), - std::move(running_var_output)); + torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 1), input->GetDevice()); + torch::lazy::LazyTensorPtr running_var_output = + torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 2), input->GetDevice()); + return std::make_tuple( + std::move(output), + std::move(running_mean_output), + std::move(running_var_output)); } -std::tuple ts_native_batch_norm_backward( - const torch::lazy::LazyTensorPtr& grad_out, const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, const torch::lazy::LazyTensorPtr& running_mean, - const torch::lazy::LazyTensorPtr& running_var, const torch::lazy::LazyTensorPtr& save_mean, - const torch::lazy::LazyTensorPtr& save_invstd, bool training, double eps, +std::tuple< + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr> +ts_native_batch_norm_backward( + const torch::lazy::LazyTensorPtr& grad_out, + const torch::lazy::LazyTensorPtr& input, + const torch::lazy::LazyTensorPtr& weight, + const torch::lazy::LazyTensorPtr& running_mean, + const torch::lazy::LazyTensorPtr& running_var, + const torch::lazy::LazyTensorPtr& save_mean, + const torch::lazy::LazyTensorPtr& save_invstd, + bool training, + double eps, c10::ArrayRef output_mask) { torch::lazy::Shape features_shape = BatchNormFeaturesShape(input); torch::lazy::Value weight_value = @@ -180,29 +231,46 @@ std::tuple( - grad_out->GetIrValue(), input->GetIrValue(), weight_value, - save_mean->GetIrValue(), save_invstd->GetIrValue(), training, eps, + grad_out->GetIrValue(), + input->GetIrValue(), + weight_value, + save_mean->GetIrValue(), + save_invstd->GetIrValue(), + training, + eps, std::array{output_mask[0], output_mask[1], output_mask[2]}); } else { CHECK(running_mean); CHECK(running_var); node = ReuseOrMakeNode( - grad_out->GetIrValue(), input->GetIrValue(), weight_value, - running_mean->GetIrValue(), running_var->GetIrValue(), - save_mean->GetIrValue(), save_invstd->GetIrValue(), training, eps, + grad_out->GetIrValue(), + input->GetIrValue(), + weight_value, + running_mean->GetIrValue(), + running_var->GetIrValue(), + save_mean->GetIrValue(), + save_invstd->GetIrValue(), + training, + eps, std::array{output_mask[0], output_mask[1], output_mask[2]}); } - torch::lazy::LazyTensorPtr grad_input = torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0), input->GetDevice()); - torch::lazy::LazyTensorPtr grad_weight = torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 1), input->GetDevice()); - torch::lazy::LazyTensorPtr grad_bias = torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 2), input->GetDevice()); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), - std::move(grad_bias)); + torch::lazy::LazyTensorPtr grad_input = torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 0), input->GetDevice()); + torch::lazy::LazyTensorPtr grad_weight = torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 1), input->GetDevice()); + torch::lazy::LazyTensorPtr grad_bias = torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, 2), input->GetDevice()); + return std::make_tuple( + std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); } -torch::lazy::LazyTensorPtr permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef dims) { +torch::lazy::LazyTensorPtr permute( + const torch::lazy::LazyTensorPtr& input, + c10::ArrayRef dims) { auto input_shape = input->shape(); torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kPermute, input_shape, + torch::lazy::ViewInfo::Type::kPermute, + input_shape, torch::lazy::GetCanonicalDimensionIndices(dims, input_shape.Get().dim())); return input->CreateViewTensor(std::move(view_info)); } @@ -227,7 +295,10 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { } } -torch::lazy::LazyTensorPtr select(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t index) { +torch::lazy::LazyTensorPtr select( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t index) { auto shape = input->shape(); dim = torch::lazy::GetCanonicalDimensionIndex(dim, shape.Get().dim()); torch::lazy::LazyTensorPtr result = narrow(input, dim, index, 1); @@ -235,8 +306,12 @@ torch::lazy::LazyTensorPtr select(const torch::lazy::LazyTensorPtr& input, int64 return view(result, new_dims); } -torch::lazy::LazyTensorPtr slice(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t end, int64_t step) { +torch::lazy::LazyTensorPtr slice( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t start, + int64_t end, + int64_t step) { auto input_shape = input->shape(); dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); start = @@ -249,19 +324,21 @@ torch::lazy::LazyTensorPtr slice(const torch::lazy::LazyTensorPtr& input, int64_ step = std::min(step, end - start); torch::lazy::SelectInfo select = {dim, start, end, step}; - torch::lazy::ViewInfo view_info(torch::lazy::ViewInfo::Type::kSelect, - input_shape, select); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kSelect, input_shape, select); return input->CreateViewTensor(std::move(view_info)); } torch::lazy::LazyTensorPtr squeeze(const torch::lazy::LazyTensorPtr& input) { auto input_shape = input->shape(); - auto output_dimensions = BuildSqueezedDimensions( - input_shape.Get().sizes(), /*squeeze_dim=*/-1); + auto output_dimensions = + BuildSqueezedDimensions(input_shape.Get().sizes(), /*squeeze_dim=*/-1); return view(input, output_dimensions); } -torch::lazy::LazyTensorPtr squeeze(const torch::lazy::LazyTensorPtr& input, int64_t dim) { +torch::lazy::LazyTensorPtr squeeze( + const torch::lazy::LazyTensorPtr& input, + int64_t dim) { auto input_shape = input->shape(); int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().Get().dim()); @@ -271,22 +348,25 @@ torch::lazy::LazyTensorPtr squeeze(const torch::lazy::LazyTensorPtr& input, int6 } void squeeze_(torch::lazy::LazyTensorPtr& input) { - input->SetIrValue( - torch::lazy::MakeSqueeze(input->GetIrValue(), -1)); + input->SetIrValue(torch::lazy::MakeSqueeze(input->GetIrValue(), -1)); } void squeeze_(torch::lazy::LazyTensorPtr& input, int64_t dim) { input->SetIrValue(torch::lazy::MakeSqueeze( input->GetIrValue(), - torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().Get().dim()))); + torch::lazy::GetCanonicalDimensionIndex( + dim, input->shape().Get().dim()))); } -torch::lazy::LazyTensorPtr transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) { +torch::lazy::LazyTensorPtr transpose( + const torch::lazy::LazyTensorPtr& input, + int64_t dim0, + int64_t dim1) { auto input_shape = input->shape(); auto permute_dims = torch::lazy::MakeTransposePermutation( /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim()); - torch::lazy::ViewInfo view_info(torch::lazy::ViewInfo::Type::kPermute, - input_shape, permute_dims); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims); return input->CreateViewTensor(std::move(view_info)); } @@ -294,12 +374,14 @@ void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) { auto input_shape = input->shape(); auto permute_dims = torch::lazy::MakeTransposePermutation( /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim()); - torch::lazy::ViewInfo view_info(torch::lazy::ViewInfo::Type::kPermute, - input_shape, permute_dims); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims); return input->ModifyCurrentView(std::move(view_info)); } -torch::lazy::LazyTensorPtr unsqueeze(const torch::lazy::LazyTensorPtr& input, int64_t dim) { +torch::lazy::LazyTensorPtr unsqueeze( + const torch::lazy::LazyTensorPtr& input, + int64_t dim) { auto input_shape = input->shape(); int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim() + 1); @@ -311,18 +393,21 @@ torch::lazy::LazyTensorPtr unsqueeze(const torch::lazy::LazyTensorPtr& input, in void unsqueeze_(torch::lazy::LazyTensorPtr& input, int64_t dim) { int squeeze_dim = torch::lazy::GetCanonicalDimensionIndex( dim, input->shape().Get().dim() + 1); - input->SetIrValue(torch::lazy::MakeUnsqueeze(input->GetIrValue(), - squeeze_dim)); + input->SetIrValue( + torch::lazy::MakeUnsqueeze(input->GetIrValue(), squeeze_dim)); } -torch::lazy::LazyTensorPtr view(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef output_size) { +torch::lazy::LazyTensorPtr view( + const torch::lazy::LazyTensorPtr& input, + c10::ArrayRef output_size) { auto input_shape = input->shape().Get(); torch::lazy::Shape shape = torch::lazy::Shape( - input_shape.scalar_type(), at::infer_size(output_size, input_shape.numel())); - torch::lazy::ViewInfo view_info(torch::lazy::ViewInfo::Type::kReshape, - std::move(shape), input_shape); + input_shape.scalar_type(), + at::infer_size(output_size, input_shape.numel())); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape); return input->CreateViewTensor(std::move(view_info)); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/tensor_aten_ops.h b/torch/csrc/lazy/ts_backend/tensor_aten_ops.h index e506bfd0255ace..d2104d5fcdde0b 100644 --- a/torch/csrc/lazy/ts_backend/tensor_aten_ops.h +++ b/torch/csrc/lazy/ts_backend/tensor_aten_ops.h @@ -10,81 +10,128 @@ namespace lazy { ////////////////////////////////////////////////////////////////////////////// // Takes a slice from the input as R1 at the specified offset and reshapes it // into the provided size. -torch::lazy::LazyTensorPtr as_strided(const torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, - c10::optional storage_offset); +torch::lazy::LazyTensorPtr as_strided( + const torch::lazy::LazyTensorPtr& input, + std::vector size, + std::vector stride, + c10::optional storage_offset); // In-place version of the method above. -void as_strided_(torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, - c10::optional storage_offset); +void as_strided_( + torch::lazy::LazyTensorPtr& input, + std::vector size, + std::vector stride, + c10::optional storage_offset); -torch::lazy::LazyTensorPtr expand(const torch::lazy::LazyTensorPtr& input, - std::vector size); +torch::lazy::LazyTensorPtr expand( + const torch::lazy::LazyTensorPtr& input, + std::vector size); // Fills the input with the given value. void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value); // Returns a new tensor that is a narrowed view of the input in the given // dimension. -torch::lazy::LazyTensorPtr narrow(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t length); - -std::tuple ts_native_batch_norm( - const torch::lazy::LazyTensorPtr& input, const torch::lazy::LazyTensorPtr& weight, const torch::lazy::LazyTensorPtr& bias, - torch::lazy::LazyTensorPtr& running_mean, torch::lazy::LazyTensorPtr& running_var, bool training, - double momentum, double eps); - -std::tuple ts_native_batch_norm_backward( - const torch::lazy::LazyTensorPtr& grad_out, const torch::lazy::LazyTensorPtr& input, - const torch::lazy::LazyTensorPtr& weight, const torch::lazy::LazyTensorPtr& running_mean, - const torch::lazy::LazyTensorPtr& running_var, const torch::lazy::LazyTensorPtr& save_mean, - const torch::lazy::LazyTensorPtr& save_invstd, bool training, double eps, +torch::lazy::LazyTensorPtr narrow( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t start, + int64_t length); + +std::tuple< + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr> +ts_native_batch_norm( + const torch::lazy::LazyTensorPtr& input, + const torch::lazy::LazyTensorPtr& weight, + const torch::lazy::LazyTensorPtr& bias, + torch::lazy::LazyTensorPtr& running_mean, + torch::lazy::LazyTensorPtr& running_var, + bool training, + double momentum, + double eps); + +std::tuple< + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr> +ts_native_batch_norm_backward( + const torch::lazy::LazyTensorPtr& grad_out, + const torch::lazy::LazyTensorPtr& input, + const torch::lazy::LazyTensorPtr& weight, + const torch::lazy::LazyTensorPtr& running_mean, + const torch::lazy::LazyTensorPtr& running_var, + const torch::lazy::LazyTensorPtr& save_mean, + const torch::lazy::LazyTensorPtr& save_invstd, + bool training, + double eps, c10::ArrayRef output_mask); // Permute the dimensions of this tensor according to the given permutation. -torch::lazy::LazyTensorPtr permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef dims); +torch::lazy::LazyTensorPtr permute( + const torch::lazy::LazyTensorPtr& input, + c10::ArrayRef dims); // Repeats the input tensor along each dimension by the given number of // repeats. -torch::lazy::LazyTensorPtr repeat(const torch::lazy::LazyTensorPtr& input, std::vector repeats); +torch::lazy::LazyTensorPtr repeat( + const torch::lazy::LazyTensorPtr& input, + std::vector repeats); void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src); -torch::lazy::LazyTensorPtr select(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t index); +torch::lazy::LazyTensorPtr select( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t index); -torch::lazy::LazyTensorPtr slice(const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t end, int64_t step); +torch::lazy::LazyTensorPtr slice( + const torch::lazy::LazyTensorPtr& input, + int64_t dim, + int64_t start, + int64_t end, + int64_t step); // Squeeze out all trivial (size 1) dimensions. torch::lazy::LazyTensorPtr squeeze(const torch::lazy::LazyTensorPtr& input); // Squeeze out the specified dimension index, if trivial (size 1). Returns // unchanged input otherwise. -torch::lazy::LazyTensorPtr squeeze(const torch::lazy::LazyTensorPtr& input, int64_t dim); +torch::lazy::LazyTensorPtr squeeze( + const torch::lazy::LazyTensorPtr& input, + int64_t dim); // In-place versions of the methods above. void squeeze_(torch::lazy::LazyTensorPtr& input); void squeeze_(torch::lazy::LazyTensorPtr& input, int64_t dim); - -std::tuple svd( - const torch::lazy::LazyTensorPtr& input, - bool some, bool compute_uv); +std::tuple< + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr, + torch::lazy::LazyTensorPtr> +svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv); // Swap given dimensions of the input. -torch::lazy::LazyTensorPtr transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1); +torch::lazy::LazyTensorPtr transpose( + const torch::lazy::LazyTensorPtr& input, + int64_t dim0, + int64_t dim1); // In-place version of the method above. void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1); // Insert a dimension of size one at the specified position. -torch::lazy::LazyTensorPtr unsqueeze(const torch::lazy::LazyTensorPtr& input, int64_t dim); +torch::lazy::LazyTensorPtr unsqueeze( + const torch::lazy::LazyTensorPtr& input, + int64_t dim); // In-place version of the method above. void unsqueeze_(torch::lazy::LazyTensorPtr& input, int64_t dim); // Like reshape, but it returns a view into the original tensor. -torch::lazy::LazyTensorPtr view(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef output_size); -} // namespace lazy -} // namespace torch +torch::lazy::LazyTensorPtr view( + const torch::lazy::LazyTensorPtr& input, + c10::ArrayRef output_size); +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp b/torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp index 1a1259f6441df0..4631f156c50c43 100644 --- a/torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp @@ -1,27 +1,27 @@ -#include #include #include +#include #include namespace torch { namespace lazy { at::Tensor MaxPool3dAutogradFunctionTS::forward( - torch::autograd::AutogradContext* ctx, at::Tensor self, - at::IntArrayRef kernel_size, at::IntArrayRef stride, - at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + torch::autograd::AutogradContext* ctx, + at::Tensor self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool ceil_mode) { ctx->saved_data["kernel_size"] = kernel_size; ctx->saved_data["stride"] = stride; ctx->saved_data["padding"] = padding; ctx->saved_data["dilation"] = dilation; ctx->saved_data["ceil_mode"] = ceil_mode; - auto results = at::native::call_fallback_fn< - <c_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::call(self, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + auto results = at::native:: + call_fallback_fn<<c_eager_fallback, ATEN_OP(max_pool3d_with_indices)>:: + call(self, kernel_size, stride, padding, dilation, ceil_mode); ctx->save_for_backward({self, std::get<1>(results)}); return std::get<0>(results); } @@ -40,16 +40,22 @@ torch::autograd::variable_list MaxPool3dAutogradFunctionTS::backward( auto indices = saved[1]; grad = at::native::call_fallback_fn< <c_eager_fallback, - ATEN_OP(max_pool3d_with_indices_backward)>::call(grad_output[0], self, - kernel_size, stride, - padding, dilation, - ceil_mode, indices); + ATEN_OP(max_pool3d_with_indices_backward)>:: + call( + grad_output[0], + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices); at::Tensor undef; - torch::autograd::variable_list grad_inputs = {grad, undef, undef, - undef, undef, undef}; + torch::autograd::variable_list grad_inputs = { + grad, undef, undef, undef, undef, undef}; return grad_inputs; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_autograd_functions.h b/torch/csrc/lazy/ts_backend/ts_autograd_functions.h index 3ed0587e2ad6f0..7e017244703844 100644 --- a/torch/csrc/lazy/ts_backend/ts_autograd_functions.h +++ b/torch/csrc/lazy/ts_backend/ts_autograd_functions.h @@ -7,16 +7,18 @@ namespace lazy { struct MaxPool3dAutogradFunctionTS : public torch::autograd::Function { - static at::Tensor forward(torch::autograd::AutogradContext* ctx, - at::Tensor self, - at::IntArrayRef kernel_size, - at::IntArrayRef stride, - at::IntArrayRef padding, - at::IntArrayRef dilation, bool ceil_mode); + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool ceil_mode); static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output); }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp index b7d3d9f50fa97d..6d33ebc5250328 100644 --- a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp +++ b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp @@ -4,18 +4,19 @@ #include #include #include -#include #include +#include #include namespace at { // This function is defined in the codegenerated RegisterDispatchKey.cpp file. -// For the TorchScript backend, we have a special case where the registration does not happen -// immediately (at static initialization time), so that if an external backend is loaded, -// it has a chance to register itself, and TorchScript only registers itself if explicitly initialized +// For the TorchScript backend, we have a special case where the registration +// does not happen immediately (at static initialization time), so that if an +// external backend is loaded, it has a chance to register itself, and +// TorchScript only registers itself if explicitly initialized extern TORCH_API void RegisterTorchScriptLazyNativeFunctions(); extern TORCH_API void RegisterTorchScriptAutogradLazyNativeFunctions(); -} +} // namespace at namespace torch { namespace lazy { @@ -23,7 +24,7 @@ namespace lazy { struct TSBackendDeviceType : public BackendDeviceType { TSBackendDeviceType() = delete; TSBackendDeviceType(c10::DeviceType deviceType) - :BackendDeviceType((int8_t)deviceType) { + : BackendDeviceType((int8_t)deviceType) { TORCH_CHECK(deviceType == at::kCPU || deviceType == at::kCUDA); } @@ -41,7 +42,8 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { TSBackendImpl() : default_device_type_(at::kCPU) { // TODO(whc) unify how all our flags are set and parsed as envs static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr; - auto type = (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU; + auto type = + (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU; default_device_type_ = TSBackendDeviceType(type); } @@ -89,8 +91,8 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { return std::make_shared( tensor.to(options, /*non_blocking=*/true), shape, device); } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) { - // calling .item() on singleton cpu tensor is fast, and using fill is a safe, - // async way to copy cpu to cuda for a single value + // calling .item() on singleton cpu tensor is fast, and using fill is a + // safe, async way to copy cpu to cuda for a single value auto device_tensor = at::full(tensor.sizes(), tensor.item(), options); return std::make_shared(device_tensor, shape, device); } else { @@ -182,7 +184,6 @@ torch::lazy::BackendDataPtr TSBackendImpl::CreateDataPlaceholder( std::vector TSBackendImpl::Compile( std::vector instances) const { - for (const auto& instance : instances) { auto ts_computation = static_cast(instance.get()); diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index 9951bb123b5789..f3aef2cd1a5686 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -1,13 +1,13 @@ #include -#include #include +#include #include #include #include +#include #include #include -#include #include #include #include @@ -153,7 +153,7 @@ bool force_eager_fallback(c10::Symbol op) { return true; } } - if(op == at::aten::nonzero){ + if (op == at::aten::nonzero) { // When symbolic shape mode is not enabled, the nonzero shape function // returns an incorrect result. return !symbolicShapeEnabled(); @@ -165,8 +165,8 @@ bool force_eager_fallback(c10::Symbol op) { void ltc_eager_fallback( const c10::OperatorHandle& op, torch::jit::Stack* stack) { - // TODO(whc) this FN_TRACK thing hasn't been used so far in LTC iirc but could land/re-enable it - // LTC_FN_TRACK(3);; + // TODO(whc) this FN_TRACK thing hasn't been used so far in LTC iirc but could + // land/re-enable it LTC_FN_TRACK(3);; const auto name = c10::toString(op.operator_name()); // Manually applying the TORCH_LAZY_COUNTER macro. @@ -183,7 +183,7 @@ void ltc_eager_fallback( auto arguments = torch::jit::last(stack, args.size()); // Log each tensor argument. - for (const auto & ivalue : arguments) { + for (const auto& ivalue : arguments) { if (ivalue.isTensor()) { VLOG(3) << ivalue.toTensor().toString(); } @@ -323,8 +323,8 @@ void ts_eager_fallback( "mutable alias: ", schema_returns[idx]); } else { - c10::optional tgt_device = - compute_target_device(tensor_args, tensorlist_args, opt_tensorlist_args); + c10::optional tgt_device = compute_target_device( + tensor_args, tensorlist_args, opt_tensorlist_args); if (alias_info != nullptr && !alias_info->isWrite()) { // immutable alias (view) case: Warn here, since we're copying and // not creating a view. diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.h b/torch/csrc/lazy/ts_backend/ts_eager_fallback.h index 7c805a63364af0..9f993d6f30290e 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.h +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.h @@ -1,9 +1,9 @@ #pragma once -#include #include #include #include +#include namespace torch { namespace lazy { diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp index 145066d432f903..1defb5993b37c1 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp @@ -11,7 +11,8 @@ TSLoweringContext::TSLoweringContext( BackendDevice device) : torch::lazy::LoweringContext(name, device), graph_(std::make_shared()), - function_(std::make_shared(name, graph_, nullptr)) {} + function_( + std::make_shared(name, graph_, nullptr)) {} TSLoweringContext::TSLoweringContext( const std::string& name, @@ -20,13 +21,13 @@ TSLoweringContext::TSLoweringContext( Util::EmissionMap emit_status) : torch::lazy::LoweringContext(name, device, post_order, emit_status), graph_(std::make_shared()), - function_(std::make_shared(name, graph_, nullptr)) { + function_( + std::make_shared(name, graph_, nullptr)) { for (auto node : post_order) { Lower(node); } } - void TSLoweringContext::Lower(const Node* node) { if (auto* tsnode = dynamic_cast(node)) { // First, we call the node lowering function, which exists for newly @@ -37,8 +38,7 @@ void TSLoweringContext::Lower(const Node* node) { for (size_t i = 0; i < ops.size(); ++i) { AssignOutputOp(torch::lazy::Output(node, i), ops[i]); } - } - else { + } else { throw std::runtime_error( "Expected torch::lazy::TsNode but could not dynamic cast"); } @@ -56,8 +56,7 @@ void TSLoweringContext::AssignOutputOp( } torch::jit::Value* TSLoweringContext::GetParameter(BackendDataPtr data) { - const auto ts_data = - std::static_pointer_cast(data); + const auto ts_data = std::static_pointer_cast(data); BackendData::Handle handle = ts_data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.h b/torch/csrc/lazy/ts_backend/ts_lowering_context.h index 714b63a114ed39..700f27d505fd37 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.h +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.h @@ -83,7 +83,6 @@ class TORCH_API TSLoweringContext : public LoweringContext { size_t index, const Shape& shape, const std::string& name) override { - TORCH_INTERNAL_ASSERT(false, "not implemented"); } diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 676f069b9b336c..a6796884cc86d0 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -1,37 +1,40 @@ -#include #include #include +#include #include #include #include -#include -#include -#include #include +#include #include +#include #include +#include #include #include -#include -#include -#include #include #include +#include +#include +#include #include namespace torch { namespace lazy { namespace { -at::Tensor CreateLtcTensor(const at::Tensor& tensor, - const c10::optional& device) { +at::Tensor CreateLtcTensor( + const at::Tensor& tensor, + const c10::optional& device) { if (tensor.defined() && device) { - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(tensor, *device)); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::LazyTensor::Create(tensor, *device)); } return tensor; } -c10::optional GetLtcDevice(const c10::optional& device) { +c10::optional GetLtcDevice( + const c10::optional& device) { if (!device) { return c10::nullopt; } @@ -41,7 +44,7 @@ c10::optional GetLtcDevice(const c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER("lazy::"); torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); auto xsize = torch::lazy::ToI64Vector(size); auto xstride = torch::lazy::ToI64Vector(stride); if (!torch::lazy::StrideIsSupported(xstride)) { - return at::native::call_fallback_fn< - <c_eager_fallback, ATEN_OP(as_strided)>::call(self, size, stride, - storage_offset); + return at::native:: + call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided)>::call( + self, size, stride, storage_offset); } return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::as_strided( self_tensor, std::move(xsize), std::move(xstride), storage_offset)); } const at::Tensor& LazyNativeFunctions::as_strided_( - const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); auto xsize = torch::lazy::ToI64Vector(size); auto xstride = torch::lazy::ToI64Vector(stride); if (!torch::lazy::StrideIsSupported(xstride)) { - return at::native::call_fallback_fn< - <c_eager_fallback, ATEN_OP(as_strided_)>::call(self, size, stride, - storage_offset); + return at::native:: + call_fallback_fn<<c_eager_fallback, ATEN_OP(as_strided_)>::call( + self, size, stride, storage_offset); } - torch::lazy::as_strided_(self_tensor, std::move(xsize), - std::move(xstride), storage_offset); + torch::lazy::as_strided_( + self_tensor, std::move(xsize), std::move(xstride), storage_offset); return self; } -at::Tensor LazyNativeFunctions::clone(const at::Tensor & self, c10::optional memory_format) { +at::Tensor LazyNativeFunctions::clone( + const at::Tensor& self, + c10::optional memory_format) { auto self_lt = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor(self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); + return torch::lazy::CreateAtenFromLtcTensor( + self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); } -at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor& self, - const at::Tensor& dst, - bool non_blocking) { +at::Tensor LazyNativeFunctions::_copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -98,9 +109,10 @@ at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor& self, CHECK(dst_tensor); dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); } else if (!dst_tensor) { - // materializing a lazy tensor (self) and copying its value into eager tensor (dst) - // detached=false lets us skip a copy in `ToTensor`, which should be safe - // because we are only going to use the tensor for dst.copy_() + // materializing a lazy tensor (self) and copying its value into eager + // tensor (dst) detached=false lets us skip a copy in `ToTensor`, which + // should be safe because we are only going to use the tensor for + // dst.copy_() CHECK(self_tensor); at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false); at::Tensor typed_tensor = @@ -115,26 +127,30 @@ at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor& self, CHECK(dst_tensor_data); auto src_tensor_data = self_tensor->CurrentTensorData(); if (src_tensor_data) { - // both src/dst are simply backed by at::Tensor data, no IR- do a straightforward copy + // both src/dst are simply backed by at::Tensor data, no IR- do a + // straightforward copy dst_tensor_data->copy_(*src_tensor_data); } else { - // src needs to be materialized before its result can be used for a copy into dst - // since we use the src tensor only for making a copy, we don't need to detach it - // note: it would be even more efficient if we could cause ToTensor to materialize the - // value directly into dst's buffer (that would need to be detached though). + // src needs to be materialized before its result can be used for a copy + // into dst since we use the src tensor only for making a copy, we don't + // need to detach it note: it would be even more efficient if we could + // cause ToTensor to materialize the value directly into dst's buffer + // (that would need to be detached though). dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false)); } } else { copy_(dst_tensor, self_tensor); - auto* impl = dynamic_cast(dst.unsafeGetTensorImpl()); + auto* impl = + dynamic_cast(dst.unsafeGetTensorImpl()); impl->set_tensor(dst_tensor); } } return dst; } -at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor& self, - const at::Tensor& dst) { +at::Tensor LazyNativeFunctions::_copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -157,123 +173,154 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor& self, return dst; } -at::Tensor LazyNativeFunctions::_to_copy(const at::Tensor & self, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - bool non_blocking, - c10::optional memory_format) { - - if (force_eager_fallback(at::aten::_to_copy)) { - TORCH_INTERNAL_ASSERT(false, +at::Tensor LazyNativeFunctions::_to_copy( + const at::Tensor& self, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + bool non_blocking, + c10::optional memory_format) { + if (force_eager_fallback(at::aten::_to_copy)) { + TORCH_INTERNAL_ASSERT( + false, "Fallback is currently impossible for _to_copy since the fallback helper itself reinvokes _to_copy"); - } - - auto options = self.options(); - if (dtype) { - // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... - // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it - options = options.dtype(dtype); - } - if (layout) { - options = options.layout(layout); - } - if (memory_format) { - options = options.memory_format(memory_format); - } - if (pin_memory) { - // TODO(whc) can we honor 'pin_memory' in some/all cases? - options = options.pinned_memory(pin_memory); - TORCH_WARN_ONCE("Pinned memory used in lazy _to_copy, check if the behavior is as intended"); - } + } - TORCH_LAZY_FN_COUNTER("lazy::"); - auto lazy_self = torch::lazy::TryGetLtcTensor(self); - if (!lazy_self && device && device->type() == c10::kLazy) { - // Case 1: eager->lazy (we create a new lazy tensor) - - auto eager_tensor = self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); - lazy_self = torch::lazy::GetOrCreateLtcTensor(eager_tensor, - torch::lazy::atenDeviceToBackendDevice(*device)); - return torch::lazy::CreateAtenFromLtcTensor(lazy_self); - } else if(device && device->type() != c10::kLazy) { - // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) - - TORCH_INTERNAL_ASSERT(lazy_self); - auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); - options = options.device(device); - auto moved_eager_tensor = eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); - return moved_eager_tensor; - } else if (device && - device->type() == c10::kLazy && - device->has_index() && - device->index() != self.device().index()) { - // Case 3: lazy:0 -> lazy:1 - - // TODO(whc) what do we actually want to do here? - // option 1: materialize, move eager tensor, create new lazy tensor - // - this should be our default, as it is what would happen before we implemented _to_copy - // - actually combines case 1 + case 2 - // option 2: support multiple devices inside one lazy/TS executor (case 4) - // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly - - TORCH_INTERNAL_ASSERT(lazy_self); - auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); - // we move the eager tensor to the 'eager' equivalent of our lazy device - // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use - auto eager_device = c10::Device(torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); - options = options.device(eager_device); - auto moved_eager_tensor = eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true); - lazy_self = torch::lazy::GetOrCreateLtcTensor(moved_eager_tensor, - torch::lazy::atenDeviceToBackendDevice(eager_device)); - return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + auto options = self.options(); + if (dtype) { + // I put each of these setters in a conditional instead of doing + // `self.options().dtype(dtype).layout(layout)... because calling + // .dtype(nullopt) on an options() that already has dtype appears to wipe it + options = options.dtype(dtype); + } + if (layout) { + options = options.layout(layout); + } + if (memory_format) { + options = options.memory_format(memory_format); + } + if (pin_memory) { + // TODO(whc) can we honor 'pin_memory' in some/all cases? + options = options.pinned_memory(pin_memory); + TORCH_WARN_ONCE( + "Pinned memory used in lazy _to_copy, check if the behavior is as intended"); + } - } else { - // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) - - // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. - // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to - // convert an eager tensor back to a lazy one inside the torchscript executor - // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument - device = c10::nullopt; - - torch::lazy::NodePtr node = torch::lazy::ReuseNode( - lazy_self->GetIrValue(), dtype, layout, device, pin_memory, - non_blocking, memory_format); - if (!node) { - auto shapes = torch::lazy::compute_shape__to_copy( - self, dtype, layout, device, pin_memory, non_blocking, - memory_format); - TORCH_INTERNAL_ASSERT(shapes.size() == 1); - node = torch::lazy::MakeNode( - lazy_self->GetIrValue(), dtype, layout, device, pin_memory, - non_blocking, memory_format, std::move(shapes)); - CacheNode(node); - } + TORCH_LAZY_FN_COUNTER("lazy::"); + auto lazy_self = torch::lazy::TryGetLtcTensor(self); + if (!lazy_self && device && device->type() == c10::kLazy) { + // Case 1: eager->lazy (we create a new lazy tensor) + + auto eager_tensor = + self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + lazy_self = torch::lazy::GetOrCreateLtcTensor( + eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device)); + return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + } else if (device && device->type() != c10::kLazy) { + // Case 2: lazy->eager (forces a graph break since we are materializing a + // tensor) + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + options = options.device(device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + return moved_eager_tensor; + } else if ( + device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { + // Case 3: lazy:0 -> lazy:1 + + // TODO(whc) what do we actually want to do here? + // option 1: materialize, move eager tensor, create new lazy tensor + // - this should be our default, as it is what would happen before we + // implemented _to_copy + // - actually combines case 1 + case 2 + // option 2: support multiple devices inside one lazy/TS executor (case 4) + // - but: we may have other assumptions that there is just one device + // per executor? so don't take this lightly + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + // we move the eager tensor to the 'eager' equivalent of our lazy device + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is + // what we use + auto eager_device = c10::Device( + torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); + options = options.device(eager_device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true); + lazy_self = torch::lazy::GetOrCreateLtcTensor( + moved_eager_tensor, + torch::lazy::atenDeviceToBackendDevice(eager_device)); + return torch::lazy::CreateAtenFromLtcTensor(lazy_self); - auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), lazy_self->GetDevice())); - return result; + } else { + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy + // graph) + + // Note: captured _to_copy will be executed with real eager tensors, not + // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this + // captured IR, or we will try to convert an eager tensor back to a lazy one + // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so + // we can safely drop the device argument + device = c10::nullopt; + + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), + dtype, + layout, + device, + pin_memory, + non_blocking, + memory_format); + if (!node) { + auto shapes = torch::lazy::compute_shape__to_copy( + self, dtype, layout, device, pin_memory, non_blocking, memory_format); + TORCH_INTERNAL_ASSERT(shapes.size() == 1); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), + dtype, + layout, + device, + pin_memory, + non_blocking, + memory_format, + std::move(shapes)); + CacheNode(node); } + + auto result = + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + std::move(node), lazy_self->GetDevice())); + return result; + } }; -at::Tensor LazyNativeFunctions::diagonal(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { +at::Tensor LazyNativeFunctions::diagonal( + const at::Tensor& self, + int64_t offset, + int64_t dim1, + int64_t dim2) { TORCH_LAZY_FN_COUNTER("lazy::"); auto input = GetLtcTensor(self); auto input_shape = input->shape(); dim1 = at::maybe_wrap_dim(dim1, self); dim2 = at::maybe_wrap_dim(dim2, self); - auto diagonal_info = DiagonalInfo {offset, dim1, dim2}; - auto view_info = ViewInfo(ViewInfo::Type::kDiagonal, input_shape, diagonal_info); + auto diagonal_info = DiagonalInfo{offset, dim1, dim2}; + auto view_info = + ViewInfo(ViewInfo::Type::kDiagonal, input_shape, diagonal_info); return CreateAtenFromLtcTensor(input->CreateViewTensor(std::move(view_info))); } at::Tensor LazyNativeFunctions::empty( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, + at::IntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, c10::optional pin_memory, c10::optional memory_format) { const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); @@ -287,24 +334,29 @@ at::Tensor LazyNativeFunctions::empty( } at::Tensor LazyNativeFunctions::empty_strided( - at::IntArrayRef size, at::IntArrayRef stride, - c10::optional dtype, c10::optional layout, - c10::optional device, c10::optional pin_memory) { + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { TORCH_LAZY_FN_COUNTER("lazy::"); at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt); - return LazyNativeFunctions::as_strided( - t, size, stride, /*storage_offset=*/0); + return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0); } -at::Tensor LazyNativeFunctions::expand(const at::Tensor& self, - at::IntArrayRef size, bool implicit) { +at::Tensor LazyNativeFunctions::expand( + const at::Tensor& self, + at::IntArrayRef size, + bool implicit) { TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::expand( - torch::lazy::TryGetLtcTensor(self), size.vec())); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec())); } -at::Tensor& LazyNativeFunctions::fill_(at::Tensor& self, - const at::Scalar& value) { +at::Tensor& LazyNativeFunctions::fill_( + at::Tensor& self, + const at::Scalar& value) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); torch::lazy::fill_(self_tensor, value); @@ -312,44 +364,58 @@ at::Tensor& LazyNativeFunctions::fill_(at::Tensor& self, } at::Tensor LazyNativeFunctions::max_pool3d( - const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, - at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + const at::Tensor& self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool ceil_mode) { return torch::lazy::MaxPool3dAutogradFunctionTS::apply( self, kernel_size, stride, padding, dilation, ceil_mode); } -std::tuple -LazyNativeFunctions::native_batch_norm( - const at::Tensor& input, const c10::optional& weight, - const c10::optional& bias, - const c10::optional& running_mean, - const c10::optional& running_var, bool training, - double momentum, double eps) { +std::tuple LazyNativeFunctions:: + native_batch_norm( + const at::Tensor& input, + const c10::optional& weight, + const c10::optional& bias, + const c10::optional& running_mean, + const c10::optional& running_var, + bool training, + double momentum, + double eps) { TORCH_LAZY_FN_COUNTER("lazy::"); auto input_tensor = torch::lazy::TryGetLtcTensor(input); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - auto running_mean_tensor = - GetOrCreateLtcTensor(running_mean, device); - auto running_var_tensor = - GetOrCreateLtcTensor(running_var, device); + auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device); + auto running_var_tensor = GetOrCreateLtcTensor(running_var, device); auto outputs = ts_native_batch_norm( - torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device), - GetOrCreateLtcTensor(bias, device), running_mean_tensor, - running_var_tensor, training, momentum, eps); - return std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)), - torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)), - torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs))); -} - -std::tuple -LazyNativeFunctions::native_batch_norm_backward( - const at::Tensor& grad_out, const at::Tensor& input, - const c10::optional& weight, - const c10::optional& running_mean, - const c10::optional& running_var, - const c10::optional& save_mean, - const c10::optional& save_invstd, bool train, double eps, - std::array output_mask) { + torch::lazy::TryGetLtcTensor(input), + GetOrCreateLtcTensor(weight, device), + GetOrCreateLtcTensor(bias, device), + running_mean_tensor, + running_var_tensor, + training, + momentum, + eps); + return std::make_tuple( + torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)), + torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)), + torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs))); +} + +std::tuple LazyNativeFunctions:: + native_batch_norm_backward( + const at::Tensor& grad_out, + const at::Tensor& input, + const c10::optional& weight, + const c10::optional& running_mean, + const c10::optional& running_var, + const c10::optional& save_mean, + const c10::optional& save_invstd, + bool train, + double eps, + std::array output_mask) { TORCH_LAZY_FN_COUNTER("lazy::"); auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out); const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); @@ -357,101 +423,132 @@ LazyNativeFunctions::native_batch_norm_backward( bool running_stats = running_mean && running_mean->defined(); CHECK_EQ(running_var && running_var->defined(), running_stats); auto gradients = ts_native_batch_norm_backward( - torch::lazy::TryGetLtcTensor(grad_out), torch::lazy::TryGetLtcTensor(input), + torch::lazy::TryGetLtcTensor(grad_out), + torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device), - running_stats ? GetOrCreateLtcTensor(running_mean, device) - : null_tensor, - running_stats ? GetOrCreateLtcTensor(running_var, device) - : null_tensor, + running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor, + running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor, GetOrCreateLtcTensor(save_mean, device), - GetOrCreateLtcTensor(save_invstd, device), train, eps, + GetOrCreateLtcTensor(save_invstd, device), + train, + eps, output_mask); at::Tensor undefined; return std::make_tuple( - output_mask[0] ? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients)) - : undefined, - output_mask[1] ? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients)) - : undefined, - output_mask[2] ? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients)) - : undefined); + output_mask[0] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients)) + : undefined, + output_mask[1] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients)) + : undefined, + output_mask[2] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients)) + : undefined); } // We need to explicitly override max pooling operators and just call the // fallback for them because we've customized the autograd function for them // (backward needs saved indices from forward). std::tuple LazyNativeFunctions::max_pool3d_with_indices( - const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, - at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - return at::native::call_fallback_fn< - <c_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::call(self, - kernel_size, - stride, - padding, - dilation, - ceil_mode); + const at::Tensor& self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool ceil_mode) { + return at::native:: + call_fallback_fn<<c_eager_fallback, ATEN_OP(max_pool3d_with_indices)>:: + call(self, kernel_size, stride, padding, dilation, ceil_mode); } at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward( - const at::Tensor& grad_output, const at::Tensor& self, - at::IntArrayRef kernel_size, at::IntArrayRef stride, - at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, + const at::Tensor& grad_output, + const at::Tensor& self, + at::IntArrayRef kernel_size, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + bool ceil_mode, const at::Tensor& indices) { return at::native::call_fallback_fn< <c_eager_fallback, - ATEN_OP(max_pool3d_with_indices_backward)>::call(grad_output, self, - kernel_size, stride, - padding, dilation, - ceil_mode, indices); -} - -at::Tensor LazyNativeFunctions::narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) { + ATEN_OP(max_pool3d_with_indices_backward)>:: + call( + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices); +} + +at::Tensor LazyNativeFunctions::narrow( + const at::Tensor& self, + int64_t dim, + int64_t start, + int64_t length) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::narrow( - self_tensor, dim, start, length)); -} - -at::Tensor & LazyNativeFunctions::normal_(at::Tensor & self, double mean, double std, c10::optional generator) { - // Unconditionally fall back. - // implementing normal_ via lazy tensor caused differences in results compared to eager. - return at::native::call_fallback_fn<<c_eager_fallback, ATEN_OP(normal_)>::call(self, mean, std, generator); - - // if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) { - // return at::native::call_fallback_fn<<c_eager_fallback, ATEN_OP(normal_)>::call(self, mean, std, generator); - // } - - // if (generator.has_value()) { - // return at::native::call_fallback_fn<<c_eager_fallback, ATEN_OP(normal_)>::call(self, mean, std, generator); - // } - - // TORCH_LAZY_FN_COUNTER("lazy::"); - // auto device = bridge::GetBackendDevice(self); - // LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device); - // std::vector shapes = {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())}; - // auto node = torch::lazy::MakeNode(lazy_self.GetIrValue(), mean, std, std::move(shapes)); - // lazy_self.SetInPlaceIrValue(node); - // return self; + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::narrow(self_tensor, dim, start, length)); +} + +at::Tensor& LazyNativeFunctions::normal_( + at::Tensor& self, + double mean, + double std, + c10::optional generator) { + // Unconditionally fall back. + // implementing normal_ via lazy tensor caused differences in results compared + // to eager. + return at::native::call_fallback_fn<<c_eager_fallback, ATEN_OP(normal_)>:: + call(self, mean, std, generator); + + // if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) { + // return at::native::call_fallback_fn<<c_eager_fallback, + // ATEN_OP(normal_)>::call(self, mean, std, generator); + // } + + // if (generator.has_value()) { + // return at::native::call_fallback_fn<<c_eager_fallback, + // ATEN_OP(normal_)>::call(self, mean, std, generator); + // } + + // TORCH_LAZY_FN_COUNTER("lazy::"); + // auto device = bridge::GetBackendDevice(self); + // LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device); + // std::vector shapes = + // {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())}; auto node = + // torch::lazy::MakeNode(lazy_self.GetIrValue(), mean, std, + // std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self; }; -at::Tensor LazyNativeFunctions::permute(const at::Tensor& self, - at::IntArrayRef dims) { +at::Tensor LazyNativeFunctions::permute( + const at::Tensor& self, + at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); - return torch::lazy::CreateAtenFromLtcTensor(torch::lazy::permute( - self_tensor, torch::lazy::ToI64Vector(dims))); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims))); } -at::Tensor LazyNativeFunctions::select(const at::Tensor& self, int64_t dim, - int64_t index) { +at::Tensor LazyNativeFunctions::select( + const at::Tensor& self, + int64_t dim, + int64_t index) { TORCH_LAZY_FN_COUNTER("lazy::"); return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::select(torch::lazy::TryGetLtcTensor(self), dim, index)); } -at::Tensor LazyNativeFunctions::slice(const at::Tensor& self, int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { +at::Tensor LazyNativeFunctions::slice( + const at::Tensor& self, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; TORCH_LAZY_FN_COUNTER("lazy::"); @@ -498,15 +595,19 @@ at::Tensor& LazyNativeFunctions::t_(at::Tensor& self) { return self; } -at::Tensor LazyNativeFunctions::transpose(const at::Tensor& self, int64_t dim0, - int64_t dim1) { +at::Tensor LazyNativeFunctions::transpose( + const at::Tensor& self, + int64_t dim0, + int64_t dim1) { TORCH_LAZY_FN_COUNTER("lazy::"); return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), dim0, dim1)); } -at::Tensor& LazyNativeFunctions::transpose_(at::Tensor& self, int64_t dim0, - int64_t dim1) { +at::Tensor& LazyNativeFunctions::transpose_( + at::Tensor& self, + int64_t dim0, + int64_t dim1) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); torch::lazy::transpose_(self_tensor, dim0, dim1); @@ -526,16 +627,18 @@ at::Tensor& LazyNativeFunctions::unsqueeze_(at::Tensor& self, int64_t dim) { return self; } -at::Tensor LazyNativeFunctions::view(const at::Tensor& self, - at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::view( + const at::Tensor& self, + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size))); } -at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor& self, - at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::_unsafe_view( + const at::Tensor& self, + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( @@ -544,5 +647,5 @@ at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor& self, void InitializeAtenBindings() {} -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node.cpp b/torch/csrc/lazy/ts_backend/ts_node.cpp index 8f3b4156ff5188..de9fdb40908af0 100644 --- a/torch/csrc/lazy/ts_backend/ts_node.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node.cpp @@ -1,25 +1,26 @@ - \ -#include #include +#include namespace { - std::string GetFirstUserFrameInPythonIfEnabled() { - static const auto LTC_ENABLE_SOURCE_INFO = std::getenv("LTC_ENABLE_SOURCE_INFO"); - if (!LTC_ENABLE_SOURCE_INFO) { - return {}; - } - - return torch::lazy::GetFirstUserFrameInPython(); +std::string GetFirstUserFrameInPythonIfEnabled() { + static const auto LTC_ENABLE_SOURCE_INFO = + std::getenv("LTC_ENABLE_SOURCE_INFO"); + if (!LTC_ENABLE_SOURCE_INFO) { + return {}; } + + return torch::lazy::GetFirstUserFrameInPython(); } +} // namespace namespace torch { namespace lazy { - -hash_t OperandHashes(const OpList& operands, - const c10::ArrayRef& shapes, - const hash_t& seed, bool bakeInSizes) { +hash_t OperandHashes( + const OpList& operands, + const c10::ArrayRef& shapes, + const hash_t& seed, + bool bakeInSizes) { hash_t hash = seed; for (auto& operand : operands) { if (!operand) { @@ -35,16 +36,27 @@ hash_t OperandHashes(const OpList& operands, return hash; } -TsNode::TsNode(OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, hash_t hash_seed) +TsNode::TsNode( + OpKind op, + OpList operands, + std::vector&& shapes, + size_t num_outputs, + hash_t hash_seed) : Node(op, operands, std::move(shapes), num_outputs) { hash_seed = HashCombine(op.hash(), hash_seed); shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); - dag_hash_ = (enableDynamicShape() ? OperandHashes(operands, this->shapes(), hash_seed, false) : shape_hash_); + dag_hash_ = + (enableDynamicShape() + ? OperandHashes(operands, this->shapes(), hash_seed, false) + : shape_hash_); } - -TsNode::TsNode(OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed) +TsNode::TsNode( + OpKind op, + OpList operands, + const std::function& shape_fn, + size_t num_outputs, + hash_t hash_seed) : TsNode(op, operands, std::vector{}, num_outputs, hash_seed) { addComputedShape(shape_fn); } @@ -55,35 +67,39 @@ TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} -hash_t TsNode::hash() const { return dag_hash_; } +hash_t TsNode::hash() const { + return dag_hash_; +} -hash_t TsNode::shapeHash() const { return shape_hash_; } +hash_t TsNode::shapeHash() const { + return shape_hash_; +} const std::string TsNode::getPythonStacktrace() const { return GetFirstUserFrameInPythonIfEnabled(); } TensorList::TensorList(OpList values) - : TsNode(/*op=*/ClassOpKind(), - /*operands=*/values, - /*shapes=*/std::vector(), - /*num_outputs=*/1, - /*hash_seed=*/kHashSeed) {} - -TSOpVector TensorList::Lower(std::shared_ptr function, - TSLoweringContext* loctx) const { - + : TsNode( + /*op=*/ClassOpKind(), + /*operands=*/values, + /*shapes=*/std::vector(), + /*num_outputs=*/1, + /*hash_seed=*/kHashSeed) {} + +TSOpVector TensorList::Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const { std::vector tensor_list; CHECK(!operands().empty()); for (const torch::lazy::Output& operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); - auto listnode = graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list)); + auto listnode = + graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list)); return {listnode->output()}; } - - -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node.h b/torch/csrc/lazy/ts_backend/ts_node.h index 9d43fb03f5477d..62cc9016f6ffa2 100644 --- a/torch/csrc/lazy/ts_backend/ts_node.h +++ b/torch/csrc/lazy/ts_backend/ts_node.h @@ -1,11 +1,11 @@ #pragma once #include +#include +#include #include -#include #include -#include -#include +#include #include namespace torch { @@ -15,15 +15,31 @@ using TSOpVector = std::vector; class TORCH_API TsNode : public lazy::Node { public: - TsNode(OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs, hash_t hash_seed = kHashSeed); - - TsNode(OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed = kHashSeed); - - TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed = kHashSeed); - - TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); + TsNode( + OpKind op, + OpList operands, + std::vector&& shapes, + size_t num_outputs, + hash_t hash_seed = kHashSeed); + + TsNode( + OpKind op, + OpList operands, + const std::function& shape_fn, + size_t num_outputs, + hash_t hash_seed = kHashSeed); + + TsNode( + OpKind op, + OpList operands, + size_t num_outputs, + hash_t hash_seed = kHashSeed); + + TsNode( + OpKind op, + Shape shape, + size_t num_outputs, + hash_t hash_seed = kHashSeed); ~TsNode() override = default; @@ -36,33 +52,39 @@ class TORCH_API TsNode : public lazy::Node { // Lower is a backend-specific method since it returns a backend specific // type. hence, it is convenient to define it differently per-backend rather // than at Node API - virtual TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const; + virtual TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const; private: // The hash of the dag WITH size info. Used for shape caching hash_t shape_hash_; // The hash of the dag used to look up the compiled graph by a hash - // in this case, we will use the dag hash WITHOUT size info if dynamic shape is enabled - // and use the dag hash WITH size info otherwise. + // in this case, we will use the dag hash WITHOUT size info if dynamic shape + // is enabled and use the dag hash WITH size info otherwise. hash_t dag_hash_; }; -// Note: this OpKind is separate from ltc_ops.h since it would be a circular import otherwise, I like leaving TensorList -// in this file, and I think most of ltc_ops special cases will be deleted anyway +// Note: this OpKind is separate from ltc_ops.h since it would be a circular +// import otherwise, I like leaving TensorList in this file, and I think most of +// ltc_ops special cases will be deleted anyway const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); -// TensorList represents an at::TensorList which is a vector[Tensor] but is also a first-class IValue -// and can be fed as a single input to a TS program. It is much easier to handle TensorLists in Lazy Tensor code -// if they are represented as a single Node so there can be more than one TensorList and more than one Tensor -// side-by-side as operands to an op. +// TensorList represents an at::TensorList which is a vector[Tensor] but is also +// a first-class IValue and can be fed as a single input to a TS program. It is +// much easier to handle TensorLists in Lazy Tensor code if they are represented +// as a single Node so there can be more than one TensorList and more than one +// Tensor side-by-side as operands to an op. // -// Note: shape is undefined for TensorList. We assert in some places that #shapes matches #outputs and this stems from -// the fact that currently all IR nodes represent tensors (there is no type system for this IR). Becuase of this, -// TensorList is a bit of a hack. +// Note: shape is undefined for TensorList. We assert in some places that +// #shapes matches #outputs and this stems from +// the fact that currently all IR nodes represent tensors (there is no +// type system for this IR). Becuase of this, TensorList is a bit of a +// hack. // -// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and then implement it as NotImplemented for -// TensorList, also fixing the assertion that would fail. +// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and +// then implement it as NotImplemented for TensorList, also fixing the assertion +// that would fail. struct TORCH_API TensorList : public TsNode { static OpKind ClassOpKind() { return tensor_list_opkind; @@ -75,9 +97,10 @@ struct TORCH_API TensorList : public TsNode { return operands() == std::vector(values.begin(), values.end()); } - TSOpVector Lower(std::shared_ptr function, - TSLoweringContext* loctx) const override; + TSOpVector Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const override; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp index 5c0489ed812788..1fabeb5c03462f 100644 --- a/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp @@ -7,9 +7,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -33,7 +33,8 @@ TSOpVector LowerBuiltin( } TSOpVector LowerTSBuiltin( - std::shared_ptr function, c10::Symbol sym, + std::shared_ptr function, + c10::Symbol sym, const std::vector& arguments, const std::vector& kwarguments) { auto builtin = @@ -55,10 +56,9 @@ TSOpVector LowerTSBuiltin( return {sv->getValue()}; } - torch::jit::Value* GenerateClone( - torch::jit::Value* val, - std::shared_ptr function) { + torch::jit::Value* val, + std::shared_ptr function) { std::vector clone_arguments; clone_arguments.emplace_back(val); TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments); @@ -67,9 +67,9 @@ torch::jit::Value* GenerateClone( } void GenerateCopy( - torch::jit::Value* destination, - torch::jit::Value* source, - std::shared_ptr function) { + torch::jit::Value* destination, + torch::jit::Value* source, + std::shared_ptr function) { std::vector arguments; arguments.emplace_back(destination); arguments.emplace_back(source); @@ -77,9 +77,12 @@ void GenerateCopy( } torch::jit::Value* GenerateSlice( - torch::jit::Value* base, int64_t dim, - int64_t start, int64_t end, int64_t step, - std::shared_ptr function) { + torch::jit::Value* base, + int64_t dim, + int64_t start, + int64_t end, + int64_t step, + std::shared_ptr function) { std::vector arguments; arguments.emplace_back(base); arguments.emplace_back(dim); @@ -94,8 +97,9 @@ torch::jit::Value* GenerateSlice( // Node Lowerings // Default node lowering -TSOpVector TsNode::Lower(std::shared_ptr function, - TSLoweringContext* loctx) const { +TSOpVector TsNode::Lower( + std::shared_ptr function, + TSLoweringContext* loctx) const { std::vector arguments; for (const torch::lazy::Output& output : operands()) { arguments.emplace_back(loctx->GetOutputOp(output)); @@ -138,9 +142,9 @@ TSOpVector TSNativeBatchNormBackward::Lower( return LowerBuiltin(this, function, arguments); } - // Non-native ops -torch::lazy::TSOpVector Cast::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Cast::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -148,17 +152,21 @@ torch::lazy::TSOpVector Cast::Lower(std::shared_ptr f return LowerBuiltin(at::aten::to, function, arguments); } -torch::lazy::TSOpVector DeviceData::Lower(std::shared_ptr function, +torch::lazy::TSOpVector DeviceData::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { auto infoptr = data_->info(); - auto deviceDataInfoPtr = (torch::lazy::LazyGraphExecutor::DeviceDataInfo*) infoptr; + auto deviceDataInfoPtr = + (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; if (GRAPH_DUMP_ENABLED) { - LOG(ERROR) << "Lowering device data node, tensor id " << deviceDataInfoPtr->tensor_id << std::endl; + LOG(ERROR) << "Lowering device data node, tensor id " + << deviceDataInfoPtr->tensor_id << std::endl; } return {loctx->GetParameter(data_)}; } -torch::lazy::TSOpVector Expand::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Expand::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -175,21 +183,21 @@ torch::lazy::TSOpVector Expand::Lower(std::shared_ptr return expand_out; } -torch::lazy::TSOpVector Scalar::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Scalar::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { auto options = at::TensorOptions() .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) .dtype(shape().scalar_type()); - return { - loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; + return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; } // View Ops -torch::lazy::TSOpVector AsStrided::Lower(std::shared_ptr function, +torch::lazy::TSOpVector AsStrided::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { - std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(size); @@ -200,9 +208,9 @@ torch::lazy::TSOpVector AsStrided::Lower(std::shared_ptr function, +torch::lazy::TSOpVector AsStridedViewUpdate::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { - torch::jit::Value* destination = GenerateClone(loctx->GetOutputOp(operand(0)), function); const torch::lazy::Output& input_op = operand(1); @@ -222,9 +230,9 @@ torch::lazy::TSOpVector AsStridedViewUpdate::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Diagonal::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { - std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(offset); @@ -233,13 +241,15 @@ torch::lazy::TSOpVector Diagonal::Lower(std::shared_ptr function, +torch::lazy::TSOpVector DiagonalViewUpdate::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { // Since we promise the backends that we never generate any aliased // inplace update IR, therefore we clone the target first and then // update the clone inplace instead. Since the clone is transient, // it will never be aliased, and therefore it's safe. - torch::jit::Value* destination = GenerateClone(loctx->GetOutputOp(operand(0)), function); + torch::jit::Value* destination = + GenerateClone(loctx->GetOutputOp(operand(0)), function); // Replay the diagonal. std::vector arguments; @@ -256,7 +266,8 @@ torch::lazy::TSOpVector DiagonalViewUpdate::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Narrow::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { const torch::lazy::Output& input = operand(0); torch::jit::Value* base = loctx->GetOutputOp(input); @@ -265,14 +276,19 @@ torch::lazy::TSOpVector Narrow::Lower(std::shared_ptr CHECK_EQ(input_shape.dim(), base_indices.size()); for (size_t dim = 0; dim < base_indices.size(); ++dim) { int64_t start = base_indices[dim]; - base = GenerateSlice(/*base=*/base, /*dim=*/dim, /*start=*/start, - /*end=*/start + sizes[dim], /*step=*/1, - /*function=*/function); + base = GenerateSlice( + /*base=*/base, + /*dim=*/dim, + /*start=*/start, + /*end=*/start + sizes[dim], + /*step=*/1, + /*function=*/function); } return {base}; } -torch::lazy::TSOpVector NarrowViewUpdate::Lower(std::shared_ptr function, +torch::lazy::TSOpVector NarrowViewUpdate::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { torch::jit::Value* dest = GenerateClone(loctx->GetOutputOp(operand(0)), function); @@ -282,15 +298,20 @@ torch::lazy::TSOpVector NarrowViewUpdate::Lower(std::shared_ptrGetOutputOp(source_argument), function); return {dest}; } -torch::lazy::TSOpVector Permute::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Permute::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -298,9 +319,9 @@ torch::lazy::TSOpVector Permute::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Resize::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { - std::vector arguments; for (const torch::lazy::Output& output : operands()) { arguments.emplace_back(loctx->GetOutputOp(output)); @@ -308,28 +329,39 @@ torch::lazy::TSOpVector Resize::Lower(std::shared_ptr return LowerBuiltin(this, function, arguments); } -torch::lazy::TSOpVector Select::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Select::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { int64_t step = torch::lazy::GetStride(start, end, stride); torch::jit::Value* base = loctx->GetOutputOp(operand(0)); - return {GenerateSlice(/*base=*/base, /*dim=*/dim, - /*start=*/start, /*end=*/end, - /*step=*/step, /*function=*/function)}; + return {GenerateSlice( + /*base=*/base, + /*dim=*/dim, + /*start=*/start, + /*end=*/end, + /*step=*/step, + /*function=*/function)}; } -torch::lazy::TSOpVector SelectViewUpdate::Lower(std::shared_ptr function, +torch::lazy::TSOpVector SelectViewUpdate::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { torch::jit::Value* dest = GenerateClone(loctx->GetOutputOp(operand(0)), function); int64_t step = torch::lazy::GetStride(start, end, stride); torch::jit::Value* selected = GenerateSlice( - /*base=*/dest, /*dim=*/dim, /*start=*/start, - /*end=*/end, /*step=*/step, /*function=*/function); + /*base=*/dest, + /*dim=*/dim, + /*start=*/start, + /*end=*/end, + /*step=*/step, + /*function=*/function); GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function); return {dest}; } -torch::lazy::TSOpVector Squeeze::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Squeeze::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -339,7 +371,8 @@ torch::lazy::TSOpVector Squeeze::Lower(std::shared_ptr function, +torch::lazy::TSOpVector Unsqueeze::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -347,7 +380,8 @@ torch::lazy::TSOpVector Unsqueeze::Lower(std::shared_ptr function, +torch::lazy::TSOpVector View::Lower( + std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); @@ -355,5 +389,5 @@ torch::lazy::TSOpVector View::Lower(std::shared_ptr f return LowerBuiltin(at::aten::reshape, function, arguments); } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_node_lowering.h b/torch/csrc/lazy/ts_backend/ts_node_lowering.h index fa3803d4e40582..cf46311ca24b33 100644 --- a/torch/csrc/lazy/ts_backend/ts_node_lowering.h +++ b/torch/csrc/lazy/ts_backend/ts_node_lowering.h @@ -8,9 +8,10 @@ namespace lazy { using TSOpVector = std::vector; TORCH_API TSOpVector LowerTSBuiltin( - std::shared_ptr function, - c10::Symbol sym, const std::vector& arguments, - const std::vector& kwarguments = {}); + std::shared_ptr function, + c10::Symbol sym, + const std::vector& arguments, + const std::vector& kwarguments = {}); } // namespace lazy } // namespace torch diff --git a/torch/csrc/monitor/events.h b/torch/csrc/monitor/events.h index 7d44f71d1bbb32..bb18e27ce528be 100644 --- a/torch/csrc/monitor/events.h +++ b/torch/csrc/monitor/events.h @@ -4,8 +4,8 @@ #include #include -#include #include +#include namespace torch { namespace monitor { diff --git a/torch/csrc/multiprocessing/init.cpp b/torch/csrc/multiprocessing/init.cpp index ceade98fbba913..098cf7f97d5363 100644 --- a/torch/csrc/multiprocessing/init.cpp +++ b/torch/csrc/multiprocessing/init.cpp @@ -18,7 +18,7 @@ namespace multiprocessing { namespace { -PyObject* multiprocessing_init(PyObject* _unused, PyObject *noargs) { +PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) { auto multiprocessing_module = THPObjectPtr(PyImport_ImportModule("torch.multiprocessing")); if (!multiprocessing_module) { diff --git a/torch/csrc/onnx/init.h b/torch/csrc/onnx/init.h index a88d251ea1366a..7436916c2b7c31 100644 --- a/torch/csrc/onnx/init.h +++ b/torch/csrc/onnx/init.h @@ -2,8 +2,10 @@ #include -namespace torch { namespace onnx { +namespace torch { +namespace onnx { void initONNXBindings(PyObject* module); -}} // namespace torch::onnx +} +} // namespace torch diff --git a/torch/csrc/profiler/api.cpp b/torch/csrc/profiler/api.cpp index 3be80715194cc8..e89955a039f8f0 100644 --- a/torch/csrc/profiler/api.cpp +++ b/torch/csrc/profiler/api.cpp @@ -49,9 +49,8 @@ bool profilerEnabled() { TORCH_API ActiveProfilerType profilerType() { auto state_ptr = ProfilerThreadLocalStateBase::getTLS(); - return state_ptr == nullptr - ? ActiveProfilerType::NONE - : state_ptr->profilerType(); + return state_ptr == nullptr ? ActiveProfilerType::NONE + : state_ptr->profilerType(); } torch::profiler::impl::ProfilerConfig getProfilerConfig() { diff --git a/torch/csrc/profiler/api.h b/torch/csrc/profiler/api.h index 1eb55d6f98bf06..53610e30d99add 100644 --- a/torch/csrc/profiler/api.h +++ b/torch/csrc/profiler/api.h @@ -29,19 +29,14 @@ enum class C10_API_ENUM ProfilerState { NUM_PROFILER_STATES, // must be the last one }; -enum class C10_API_ENUM ActiveProfilerType { - NONE = 0, - LEGACY, - KINETO, - NVTX -}; +enum class C10_API_ENUM ActiveProfilerType { NONE = 0, LEGACY, KINETO, NVTX }; struct TORCH_API ExperimentalConfig { explicit ExperimentalConfig( std::vector profiler_metrics = {}, bool profiler_measure_per_kernel = false) - : profiler_metrics(std::move(profiler_metrics)), - profiler_measure_per_kernel(profiler_measure_per_kernel) {} + : profiler_metrics(std::move(profiler_metrics)), + profiler_measure_per_kernel(profiler_measure_per_kernel) {} ~ExperimentalConfig() = default; std::vector profiler_metrics; bool profiler_measure_per_kernel = false; @@ -168,10 +163,10 @@ namespace torch { namespace autograd { namespace profiler { using torch::profiler::impl::ActivityType; +using torch::profiler::impl::getProfilerConfig; using torch::profiler::impl::ProfilerConfig; -using torch::profiler::impl::ProfilerState; using torch::profiler::impl::profilerEnabled; -using torch::profiler::impl::getProfilerConfig; +using torch::profiler::impl::ProfilerState; } // namespace profiler } // namespace autograd } // namespace torch diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 9eec2c9453536e..c3204feae343f9 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -38,14 +38,14 @@ void InputOutputEncoder::push(const at::Tensor& t) { const auto& sizes = t.sizes(); const auto dim = sizes.size(); TORCH_CHECK( - dim <= std::numeric_limits::max(), - "Cannot profile Tensors of size > uint32 max. Got dim: ", dim); + dim <= std::numeric_limits::max(), + "Cannot profile Tensors of size > uint32 max. Got dim: ", + dim); tensor_metadata_.emplace_back( - /*ptr_=*/(void*)t.unsafeGetTensorImpl(), - /*dtype_=*/t.scalar_type(), - /*dim_=*/(uint32_t)dim - ); + /*ptr_=*/(void*)t.unsafeGetTensorImpl(), + /*dtype_=*/t.scalar_type(), + /*dim_=*/(uint32_t)dim); for (const auto i : sizes) { tensor_sizes_.emplace_back(i); @@ -57,7 +57,8 @@ void InputOutputEncoder::push(const at::Tensor& t) { // This is a custom-iterator-like getter to obtain input shapes and dtypes. auto InputOutputEncoder::getNextShapesAndDtypes() { - return [this, tag_it = tags_.begin(), + return [this, + tag_it = tags_.begin(), tensor_metadata_it = tensor_metadata_.begin(), tensor_size_it = tensor_sizes_.begin()]() mutable { struct Inputs out; @@ -65,21 +66,19 @@ auto InputOutputEncoder::getNextShapesAndDtypes() { while (!terminate && tag_it != tags_.end()) { out.shapes_.emplace_back(); switch (*tag_it) { - case Tag::Tensor: - { - const auto& md = *tensor_metadata_it++; - for (const auto _ : c10::irange(md.dim_)) { - (void)_; // Suppress unused variable warning - out.shapes_.back().push_back(*tensor_size_it++); - } - out.dtypes_.emplace_back(scalarTypeToTypeMeta(md.dtype_).name()); + case Tag::Tensor: { + const auto& md = *tensor_metadata_it++; + for (const auto _ : c10::irange(md.dim_)) { + (void)_; // Suppress unused variable warning + out.shapes_.back().push_back(*tensor_size_it++); } - break; + out.dtypes_.emplace_back(scalarTypeToTypeMeta(md.dtype_).name()); + } break; case Tag::TensorListBegin: - while (*(++tag_it) != Tag::TERMINATOR) { - // TODO: Skip TensorLists for now. - } + while (*(++tag_it) != Tag::TERMINATOR) { + // TODO: Skip TensorLists for now. + } out.dtypes_.emplace_back("TensorList"); break; @@ -343,8 +342,9 @@ ThreadLocalSubqueue* RecordQueue::getSubqueue() { std::lock_guard guard(sub_queue_mutex_); auto it = sub_queues_.find(tid); if (it == sub_queues_.end()) { - it = - sub_queues_.emplace(tid, std::make_unique(tid, config_)).first; + it = sub_queues_ + .emplace(tid, std::make_unique(tid, config_)) + .first; } sub_queue_cache_ = SubQueueThreadCache{id_, it->second.get()}; diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 9509c659940ef6..734cd7caf1a56c 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -270,7 +270,8 @@ class InputOutputEncoder final { void push(const at::Tensor& t); AppendOnlyList tags_; - AppendOnlyList tensor_metadata_; + AppendOnlyList + tensor_metadata_; AppendOnlyList tensor_sizes_; }; @@ -322,7 +323,9 @@ class TORCH_API ThreadLocalSubqueue { public: ThreadLocalSubqueue(const uint64_t tid, const ProfilerConfig& config); - std::unique_ptr begin_op(const at::RecordFunction& fn, uint64_t correlation_id); + std::unique_ptr begin_op( + const at::RecordFunction& fn, + uint64_t correlation_id); template void emplace_backend_event(Args&&... args) { @@ -362,7 +365,8 @@ class TORCH_API ThreadLocalSubqueue { // with_stack AppendOnlyList jit_stack_; - AppendOnlyList, BlockSize> py_calls_; + AppendOnlyList, BlockSize> + py_calls_; // with_modules AppendOnlyList jit_modules_; @@ -396,7 +400,8 @@ class TORCH_API RecordQueue { uint32_t id_; ProfilerConfig config_; std::set activities_; - ska::flat_hash_map> sub_queues_; + ska::flat_hash_map> + sub_queues_; std::mutex sub_queue_mutex_; }; diff --git a/torch/csrc/profiler/containers.h b/torch/csrc/profiler/containers.h index 782ed90248210f..db0f3b32b9815c 100644 --- a/torch/csrc/profiler/containers.h +++ b/torch/csrc/profiler/containers.h @@ -24,11 +24,13 @@ namespace impl { // indirection), and this class of operation is sufficiently important to the // profiling hot path to warrant specializing: // https://godbolt.org/z/rTjozf1c4 -// https://quick-bench.com/q/mmfuu71ogwaiULDCJyHdKnHZms4 (Prototype #1, int) -// https://quick-bench.com/q/5vWDW6jjdXVdoffev2zst8D09no (Prototype #1, int pair) -// https://quick-bench.com/q/IfEkfAQMeJSNBA52xtMP6Agcl-Q (Prototype #2, int pair) -// https://quick-bench.com/q/wJV2lKmuXL4XyGJzcI5hs4gEHFg (Prototype #3, int pair) -// https://quick-bench.com/q/xiO8ZaBEkYRYUA9dFrMuPLlW9fo (Full impl, int pair) +// https://quick-bench.com/q/mmfuu71ogwaiULDCJyHdKnHZms4 (Prototype #1, +// int) https://quick-bench.com/q/5vWDW6jjdXVdoffev2zst8D09no (Prototype +// #1, int pair) https://quick-bench.com/q/IfEkfAQMeJSNBA52xtMP6Agcl-Q +// (Prototype #2, int pair) +// https://quick-bench.com/q/wJV2lKmuXL4XyGJzcI5hs4gEHFg (Prototype #3, int +// pair) https://quick-bench.com/q/xiO8ZaBEkYRYUA9dFrMuPLlW9fo (Full impl, +// int pair) // AppendOnlyList has 2x lower emplace overhead compared to more generic STL // containers. // @@ -68,13 +70,13 @@ class AppendOnlyList { struct Iterator { using iterator_category = std::forward_iterator_tag; - using difference_type = std::ptrdiff_t; - using value_type = T; - using pointer = T*; - using reference = T&; + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = T&; Iterator(std::forward_list& buffer, const size_t size) - : block_{buffer.begin()}, size_{size} {} + : block_{buffer.begin()}, size_{size} {} // End iterator. Iterator() = default; @@ -83,8 +85,12 @@ class AppendOnlyList { return current_ >= size_; } - reference operator*() const { return *current_ptr(/*checked=*/true); } - pointer operator->() { return current_ptr(/*checked=*/true); } + reference operator*() const { + return *current_ptr(/*checked=*/true); + } + pointer operator->() { + return current_ptr(/*checked=*/true); + } // Prefix increment Iterator& operator++() { @@ -95,7 +101,11 @@ class AppendOnlyList { } // Postfix increment - Iterator operator++(int) { Iterator tmp = *this; ++(*this); return tmp; } + Iterator operator++(int) { + Iterator tmp = *this; + ++(*this); + return tmp; + } friend bool operator==(const Iterator& a, const Iterator& b) { return a.current_ptr() == b.current_ptr(); @@ -105,7 +115,7 @@ class AppendOnlyList { } std::pair address() const { - if (current_ >= size_){ + if (current_ >= size_) { return {nullptr, 0}; } return {&(*block_), current_ % ChunkSize}; @@ -122,15 +132,19 @@ class AppendOnlyList { } typename std::forward_list::iterator block_; - size_t current_ {0}; - size_t size_ {0}; + size_t current_{0}; + size_t size_{0}; }; - Iterator begin() { return Iterator(buffer_, size()); } - Iterator end() { return Iterator(); } + Iterator begin() { + return Iterator(buffer_, size()); + } + Iterator end() { + return Iterator(); + } // TODO: cbegin and cend() -// TODO: make private + // TODO: make private protected: void maybe_grow() { if (C10_UNLIKELY(next_ == end_)) { @@ -146,9 +160,9 @@ class AppendOnlyList { // We maintain a pointer to the last element of `buffer_` so that we can // insert at the end in O(1) time. typename std::forward_list::iterator buffer_last_; - size_t n_blocks_ {0}; - T* next_ {nullptr}; - T* end_ {nullptr}; + size_t n_blocks_{0}; + T* next_{nullptr}; + T* end_{nullptr}; }; } // namespace impl diff --git a/torch/csrc/profiler/cuda.cpp b/torch/csrc/profiler/cuda.cpp index 5af415fd14a744..9116516f4ae0ba 100644 --- a/torch/csrc/profiler/cuda.cpp +++ b/torch/csrc/profiler/cuda.cpp @@ -1,7 +1,7 @@ -#include #include #include #include +#include #include @@ -10,8 +10,8 @@ namespace profiler { namespace impl { namespace { -static inline void cudaCheck(cudaError_t result, const char * file, int line) { - if(result != cudaSuccess) { +static inline void cudaCheck(cudaError_t result, const char* file, int line) { + if (result != cudaSuccess) { std::stringstream ss; ss << file << ":" << line << ": "; if (result == cudaErrorInitializationError) { @@ -31,10 +31,11 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) { throw std::runtime_error(ss.str()); } } -#define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__); +#define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__); struct CUDAMethods : public CUDAStubs { - void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const override { + void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) + const override { if (device) { TORCH_CUDA_CHECK(cudaGetDevice(device)); } @@ -51,14 +52,15 @@ struct CUDAMethods : public CUDAStubs { TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); } - float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const override{ + float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) + const override { TORCH_CUDA_CHECK(cudaEventSynchronize(event->get())); TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get())); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) float ms; TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get())); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - return ms*1000.0; + return ms * 1000.0; } void nvtxMarkA(const char* name) const override { @@ -79,7 +81,7 @@ struct CUDAMethods : public CUDAStubs { at::cuda::OptionalCUDAGuard device_guard; // NOLINTNEXTLINE(bugprone-signed-char-misuse) int count = at::cuda::device_count(); - for(const auto i : c10::irange(count)) { + for (const auto i : c10::irange(count)) { device_guard.set_index(i); op(i); } diff --git a/torch/csrc/profiler/kineto_client_interface.cpp b/torch/csrc/profiler/kineto_client_interface.cpp index 515e41810a3d96..e6842718e0474a 100644 --- a/torch/csrc/profiler/kineto_client_interface.cpp +++ b/torch/csrc/profiler/kineto_client_interface.cpp @@ -34,13 +34,13 @@ class LibKinetoClient : public libkineto::ClientInterface { void stop() override { (void)disableProfiler(); } + private: bool reportInputShapes_{true}; }; } // namespace - } // namespace impl } // namespace profiler diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index 38fb008d12f0a3..92a456f98c62e9 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -92,8 +92,8 @@ void TraceWrapper::addCPUActivity( #ifdef USE_KINETO TORCH_CHECK((bool)(*this), "Cannot add event to non-existent trace."); auto type = toActivityType(kineto_type); - cpu_trace_->activities.emplace_back(libkineto::GenericTraceActivity( - cpu_trace_->span, type, name)); + cpu_trace_->activities.emplace_back( + libkineto::GenericTraceActivity(cpu_trace_->span, type, name)); auto& act = cpu_trace_->activities.back(); act.device = device_and_resource.device; act.resource = device_and_resource.resource; @@ -154,18 +154,18 @@ class ExperimentalConfigWrapper { public: explicit ExperimentalConfigWrapper( const torch::profiler::impl::ExperimentalConfig& config) - : config_(config) {} + : config_(config) {} - bool assertValid( - const ActivitySet& activities) { + bool assertValid(const ActivitySet& activities) { // Kineto supports reading performance events per kernel/iteration // using CUPTI Range based profiler API. In this mode however we // do not trace CPU or GPU events. bool cupti_range_profiler = config_.profiler_metrics.size() > 0; - if (cupti_range_profiler && activities.count( - torch::autograd::profiler::ActivityType::CPU)) { - LOG(WARNING) << "Cannot run range profiler with CPU activities, please only" - << " use CUDA activity type"; + if (cupti_range_profiler && + activities.count(torch::autograd::profiler::ActivityType::CPU)) { + LOG(WARNING) + << "Cannot run range profiler with CPU activities, please only" + << " use CUDA activity type"; return false; } return cupti_range_profiler; @@ -173,30 +173,30 @@ class ExperimentalConfigWrapper { void prepareTraceWithExperimentalOptions() { #ifdef USE_KINETO - std::set k_activities{ - libkineto::ActivityType::CUDA_PROFILER_RANGE}; + std::set k_activities{ + libkineto::ActivityType::CUDA_PROFILER_RANGE}; - const size_t num_metrics = config_.profiler_metrics.size(); - std::stringstream configss; + const size_t num_metrics = config_.profiler_metrics.size(); + std::stringstream configss; - LOG(INFO) << "CUPTI profiler metrics size = " << num_metrics; + LOG(INFO) << "CUPTI profiler metrics size = " << num_metrics; - configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n" - << "CUPTI_PROFILER_METRICS="; + configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n" + << "CUPTI_PROFILER_METRICS="; - for (int i = 0; i < num_metrics; i++) { - configss << config_.profiler_metrics[i]; - if (num_metrics > 1 && i < (num_metrics-1)) { - configss << ","; + for (int i = 0; i < num_metrics; i++) { + configss << config_.profiler_metrics[i]; + if (num_metrics > 1 && i < (num_metrics - 1)) { + configss << ","; + } } - } - configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" - << (config_.profiler_measure_per_kernel ? "true" : "false") - << "\n"; - LOG(INFO) << "Generated config = " << configss.str(); + configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" + << (config_.profiler_measure_per_kernel ? "true" : "false") + << "\n"; + LOG(INFO) << "Generated config = " << configss.str(); - libkineto::api().activityProfiler().prepareTrace( - k_activities, configss.str()); + libkineto::api().activityProfiler().prepareTrace( + k_activities, configss.str()); #endif // USE_KINETO } diff --git a/torch/csrc/profiler/kineto_shim.h b/torch/csrc/profiler/kineto_shim.h index ce3b33d0307bba..885e2a6755a626 100644 --- a/torch/csrc/profiler/kineto_shim.h +++ b/torch/csrc/profiler/kineto_shim.h @@ -21,16 +21,16 @@ namespace libkineto { enum class ActivityType; struct CpuTraceBuffer; class ActivityTraceInterface; -} +} // namespace libkineto #endif namespace torch { namespace profiler { #ifdef USE_KINETO -constexpr bool kKinetoAvailable {true}; +constexpr bool kKinetoAvailable{true}; #else -constexpr bool kKinetoAvailable {false}; +constexpr bool kKinetoAvailable{false}; #endif namespace impl { @@ -89,7 +89,7 @@ struct TraceWrapper { explicit operator bool() const; std::unique_ptr& get() { - return cpu_trace_; + return cpu_trace_; } private: @@ -106,7 +106,7 @@ struct ActivityTraceWrapper { void save(const std::string& path); const std::unique_ptr& get() { - return trace_; + return trace_; } private: @@ -116,7 +116,8 @@ struct ActivityTraceWrapper { using ActivitySet = std::set; void prepareTrace( - const bool cpuOnly, const ActivitySet& activities, + const bool cpuOnly, + const ActivitySet& activities, const torch::profiler::impl::ExperimentalConfig& config); void startTrace(); ActivityTraceWrapper stopTrace(); diff --git a/torch/csrc/profiler/nvtx_observer.cpp b/torch/csrc/profiler/nvtx_observer.cpp index 2f460a38f8da29..ddc677d9786b58 100644 --- a/torch/csrc/profiler/nvtx_observer.cpp +++ b/torch/csrc/profiler/nvtx_observer.cpp @@ -30,10 +30,15 @@ struct NVTXThreadLocalState : ProfilerThreadLocalStateBase { tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX); return static_cast(tls); } - std::pair getOpIdFromInput(const at::Tensor& tensor); - - void setProducerTensorMap(at::TensorImpl *tensor, at::RecordFunctionHandle op_id, int output_nr){ - producer_tensor_map_[(void*)tensor] = std::pair {op_id, output_nr}; + std::pair getOpIdFromInput( + const at::Tensor& tensor); + + void setProducerTensorMap( + at::TensorImpl* tensor, + at::RecordFunctionHandle op_id, + int output_nr) { + producer_tensor_map_[(void*)tensor] = + std::pair{op_id, output_nr}; } protected: @@ -42,22 +47,26 @@ struct NVTXThreadLocalState : ProfilerThreadLocalStateBase { // at::TensorImpl* is the actual type of the key, but using void* // to indicate the pointer is just being used as a key // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unordered_map> producer_tensor_map_; + std::unordered_map> + producer_tensor_map_; }; -std::pair NVTXThreadLocalState::getOpIdFromInput(const at::Tensor& tensor) { +std::pair NVTXThreadLocalState::getOpIdFromInput( + const at::Tensor& tensor) { std::pair producer_op_pair(0, -1); if (tensor.defined()) { - at::TensorImpl *ten_addr = tensor.unsafeGetTensorImpl(); + at::TensorImpl* ten_addr = tensor.unsafeGetTensorImpl(); // See if Address is in the map already if (producer_tensor_map_.count((void*)ten_addr) > 0) { - producer_op_pair = producer_tensor_map_[(void*)ten_addr]; + producer_op_pair = producer_tensor_map_[(void*)ten_addr]; } } return producer_op_pair; } -std::list> flattenOpIdList(c10::List list, std::string fn_name) { +std::list> flattenOpIdList( + c10::List list, + std::string fn_name) { std::list> input_op_id_list; auto state_ptr = NVTXThreadLocalState::getTLS(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); @@ -71,19 +80,21 @@ std::list> flattenOpIdList(c10::List> getInputTensorOpIds(const at::RecordFunction& fn) { - std::pair undefined_op_pair(0,-1); +std::list> getInputTensorOpIds( + const at::RecordFunction& fn) { + std::pair undefined_op_pair(0, -1); std::list> input_producer_ops_; auto state_ptr = NVTXThreadLocalState::getTLS(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); for (const c10::IValue& input_item : fn.inputs()) { - if(input_item.isTensor()) { + if (input_item.isTensor()) { const at::Tensor& tensor = input_item.toTensor(); auto producer_pair = state_ptr->getOpIdFromInput(tensor); input_producer_ops_.push_back(producer_pair); } else { if (input_item.isList()) { - std::list> tmp_op_ids = flattenOpIdList(input_item.toList(), std::string(fn.name())); + std::list> tmp_op_ids = + flattenOpIdList(input_item.toList(), std::string(fn.name())); // Extend the current sizes array by the array returned from input sizes if (!tmp_op_ids.empty()) { input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids); @@ -91,7 +102,7 @@ std::list> getInputTensorOpIds(const at input_producer_ops_.emplace_back(undefined_op_pair); } } else { - input_producer_ops_.emplace_back(undefined_op_pair); + input_producer_ops_.emplace_back(undefined_op_pair); } } } @@ -102,11 +113,11 @@ void updateOutputTensorTracker(const at::RecordFunction& fn) { int output_nr = 0; auto state_ptr = NVTXThreadLocalState::getTLS(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); - for (const c10::IValue& s_tensor : fn.outputs()){ - if(s_tensor.isTensor()) { + for (const c10::IValue& s_tensor : fn.outputs()) { + if (s_tensor.isTensor()) { const at::Tensor& tensor = s_tensor.toTensor(); if (tensor.defined()) { - auto ten_addr = tensor.unsafeGetTensorImpl(); + auto ten_addr = tensor.unsafeGetTensorImpl(); state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr); } } @@ -125,8 +136,9 @@ std::unique_ptr enterNVTX(const at::RecordFunction& fn) { report_input_shapes ? torch::profiler::impl::inputSizes(fn, true) : std::vector>(), fn.handle(), - report_input_shapes ? input_op_ids - : std::list>()) + report_input_shapes + ? input_op_ids + : std::list>()) .c_str()); } return nullptr; @@ -151,7 +163,7 @@ void pushNVTXCallbacks( state_ptr->config().report_input_shapes ? &enterNVTX : &enterNVTX, - [](const at::RecordFunction& fn, at::ObserverContext *ctx) { + [](const at::RecordFunction& fn, at::ObserverContext* ctx) { torch::profiler::impl::cudaStubs()->nvtxRangePop(); updateOutputTensorTracker(fn); }) diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 4415b7fa78a900..6ae2e745806f32 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -1,10 +1,10 @@ -#include #include #include +#include #include -#include #include +#include #ifdef USE_KINETO #include @@ -15,7 +15,7 @@ namespace profiler { namespace impl { ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter() - : start_times_(measurePairs()) {} + : start_times_(measurePairs()) {} ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair ApproximateClockToUnixTimeConverter::measurePair() { @@ -33,7 +33,7 @@ ApproximateClockToUnixTimeConverter::measurePair() { } ApproximateClockToUnixTimeConverter::time_pairs - ApproximateClockToUnixTimeConverter::measurePairs() { +ApproximateClockToUnixTimeConverter::measurePairs() { static constexpr auto n_warmup = 5; for (C10_UNUSED const auto _ : c10::irange(n_warmup)) { getApproximateTime(); @@ -47,8 +47,8 @@ ApproximateClockToUnixTimeConverter::time_pairs return out; } -std::function -ApproximateClockToUnixTimeConverter::makeConverter() { +std::function ApproximateClockToUnixTimeConverter:: + makeConverter() { auto end_times = measurePairs(); // Compute the real time that passes for each tick of the approximate clock. @@ -72,8 +72,9 @@ ApproximateClockToUnixTimeConverter::makeConverter() { auto t0_approx = start_times_[0].approx_t_; std::array t0_correction{}; for (const auto i : c10::irange(replicates)) { - auto dt = start_times_[i].t_ - t0; - auto dt_approx = (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor; + auto dt = start_times_[i].t_ - t0; + auto dt_approx = + (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor; t0_correction[i] = dt - (time_t)dt_approx; } t0 += t0_correction[t0_correction.size() / 2 + 1]; @@ -114,7 +115,8 @@ std::string getNvtxStr( // Include the op ids of the input edges so // you can build the network graph if (input_op_ids.size() > 0) { - str = fmt::format("{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids)); + str = fmt::format( + "{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids)); } return str; } else { @@ -174,7 +176,9 @@ std::string stacksToStr( return "\"" + rc + "\""; } -std::vector> flattenList(c10::List list, std::string fn_name) { +std::vector> flattenList( + c10::List list, + std::string fn_name) { std::vector> tensor_dims; for (const c10::IValue input : list) { if (input.isTensor()) { @@ -187,7 +191,9 @@ std::vector> flattenList(c10::List list, std:: return tensor_dims; } -std::vector> inputSizes(const at::RecordFunction& fn, bool flatten_list_enabled) { +std::vector> inputSizes( + const at::RecordFunction& fn, + bool flatten_list_enabled) { std::vector> sizes; sizes.reserve(fn.inputs().size()); for (const c10::IValue& input : fn.inputs()) { @@ -235,7 +241,8 @@ std::string shapesToStr(const std::vector>& shapes) { return str; } -std::string inputOpIdsToStr(const std::list>& input_op_ids) { +std::string inputOpIdsToStr( + const std::list>& input_op_ids) { std::string str("["); int idx = 0; @@ -244,7 +251,8 @@ std::string inputOpIdsToStr(const std::list #include +#include #include #include #include -#include -#include #include +#include #include #include @@ -95,7 +95,7 @@ inline auto getApproximateTime() { using approx_time_t = decltype(getApproximateTime()); static_assert( std::is_same::value || - std::is_same::value, + std::is_same::value, "Expected either int64_t (`getTime`) or uint64_t (some TSC reads)."); // Convert `getCount` results to Nanoseconds since unix epoch. @@ -123,7 +123,8 @@ std::string getNvtxStr( int64_t sequence_nr, const std::vector>& shapes, at::RecordFunctionHandle op_id = 0, - const std::list>& input_op_ids = {}); + const std::list>& input_op_ids = + {}); struct TORCH_API FileLineFunc { std::string filename; @@ -140,11 +141,12 @@ TORCH_API std::string stacksToStr( const char* delim); TORCH_API std::vector> inputSizes( const at::RecordFunction& fn, - const bool flatten_list_enabled=false); + const bool flatten_list_enabled = false); TORCH_API std::string shapesToStr( const std::vector>& shapes); TORCH_API std::string dtypesToStr(const std::vector& types); -TORCH_API std::string inputOpIdsToStr(const std::list>& input_op_ids); +TORCH_API std::string inputOpIdsToStr( + const std::list>& input_op_ids); TORCH_API std::vector inputTypes(const at::RecordFunction& fn); std::unordered_map TORCH_API @@ -161,8 +163,8 @@ uint64_t TORCH_API computeFlops( namespace torch { namespace autograd { namespace profiler { -using torch::profiler::impl::getTime; using torch::profiler::impl::computeFlops; +using torch::profiler::impl::getTime; } // namespace profiler } // namespace autograd } // namespace torch diff --git a/torch/csrc/python_dimname.cpp b/torch/csrc/python_dimname.cpp index 2b65555239ac60..a403375261afdf 100644 --- a/torch/csrc/python_dimname.cpp +++ b/torch/csrc/python_dimname.cpp @@ -1,30 +1,32 @@ -#include +#include #include +#include #include -#include namespace torch { struct InternedStringsTable { InternedStringsTable() = default; ~InternedStringsTable(); - InternedStringsTable(const InternedStringsTable &) = delete; - InternedStringsTable& operator =(InternedStringsTable const&) = delete; + InternedStringsTable(const InternedStringsTable&) = delete; + InternedStringsTable& operator=(InternedStringsTable const&) = delete; InternedStringsTable(InternedStringsTable&&) = delete; InternedStringsTable& operator=(InternedStringsTable&&) = delete; at::optional lookup(PyObject* obj); // Precondition: obj is an interned python string. void addMapping(PyObject* obj, at::Dimname dimname); + private: - ska::flat_hash_map py_interned_string_to_dimname_; + ska::flat_hash_map py_interned_string_to_dimname_; }; InternedStringsTable kPyInternedStringToDimname; InternedStringsTable::~InternedStringsTable() { for (auto it = py_interned_string_to_dimname_.begin(); - it != py_interned_string_to_dimname_.end(); ++it) { + it != py_interned_string_to_dimname_.end(); + ++it) { // See Note [References to python interned strings] Py_DECREF(it->first); } @@ -38,7 +40,6 @@ at::optional InternedStringsTable::lookup(PyObject* obj) { return it->second; } - void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) { // Note [References to python interned strings] // If a Python interned string has no references to it, then it gets @@ -66,7 +67,8 @@ bool THPUtils_checkDimnameList(PyObject* obj) { if (size == 0) { return true; } - PyObject* first_elt = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); + PyObject* first_elt = + tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); return THPUtils_checkDimname(first_elt); } @@ -76,13 +78,16 @@ at::Dimname THPDimname_parse(PyObject* obj) { } if (!THPUtils_checkString(obj)) { - throw torch::TypeError("expected None or string for Dimname but got %s", Py_TYPE(obj)->tp_name); + throw torch::TypeError( + "expected None or string for Dimname but got %s", + Py_TYPE(obj)->tp_name); } if (!THPUtils_isInterned(obj)) { // internStringInPlace decrefs obj and increfs the result. Because we're // not actually returning the result to the user, we need to undo these. - // See https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace + // See + // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace Py_INCREF(obj); THPUtils_internStringInPlace(&obj); Py_DECREF(obj); diff --git a/torch/csrc/python_dimname.h b/torch/csrc/python_dimname.h index 2bca0f07344f27..01a6007e9f8e82 100644 --- a/torch/csrc/python_dimname.h +++ b/torch/csrc/python_dimname.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include at::Dimname THPDimname_parse(PyObject* obj); bool THPUtils_checkDimname(PyObject* obj); diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index 7c9af576e8e4cd..46f3a04f355b41 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -1,9 +1,9 @@ #include #include +#include #include #include -#include template Py_ssize_t doPartialRead(io fildes, void* buf, size_t nbytes); @@ -11,9 +11,18 @@ Py_ssize_t doPartialRead(io fildes, void* buf, size_t nbytes); template Py_ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes); -static Py_ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes); -static Py_ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes); -static Py_ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes); +static Py_ssize_t doPartialPythonReadBuffered( + PyObject* fildes, + void* buf, + size_t nbytes); +static Py_ssize_t doPartialPythonReadInto( + PyObject* fildes, + void* buf, + size_t nbytes); +static Py_ssize_t doPartialPythonWrite( + PyObject* fildes, + void* buf, + size_t nbytes); template <> Py_ssize_t doPartialRead(int fildes, void* buf, size_t nbytes) { @@ -21,7 +30,10 @@ Py_ssize_t doPartialRead(int fildes, void* buf, size_t nbytes) { } template <> -Py_ssize_t doPartialRead(PyObject* fildes, void* buf, size_t nbytes) { +Py_ssize_t doPartialRead( + PyObject* fildes, + void* buf, + size_t nbytes) { // Try to use fildes.readinto() instead of fildes.read() // because it is more memory efficient. // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop @@ -38,20 +50,28 @@ Py_ssize_t doPartialWrite(int fildes, void* buf, size_t nbytes) { } template <> -Py_ssize_t doPartialWrite(PyObject* fildes, void* buf, size_t nbytes) { +Py_ssize_t doPartialWrite( + PyObject* fildes, + void* buf, + size_t nbytes) { return doPartialPythonWrite(fildes, buf, nbytes); } static inline bool isUnsupportedOperation() { THPObjectPtr io(PyImport_ImportModule("io")); - if (!io) throw python_error(); + if (!io) + throw python_error(); THPObjectPtr exception(PyObject_GetAttrString(io, "UnsupportedOperation")); - if (!exception) throw python_error(); + if (!exception) + throw python_error(); return PyErr_ExceptionMatches(exception.get()); } // Call Python fildes.read(nbytes) and copy it to buf. -static inline Py_ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t raw_nbytes) { +static inline Py_ssize_t doPartialPythonReadBuffered( + PyObject* fildes, + void* buf, + size_t raw_nbytes) { // If we request a large amount of data, f.read() will internally try to // allocate a buffer of that size. This is counterproductive, because // it's not the buffer we ultimately want to write the data into. Read @@ -60,7 +80,8 @@ static inline Py_ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf const size_t nbytes = std::min(raw_nbytes, 262144u); // 2^18 (~260 KB) THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes)); - if (!r) throw python_error(); + if (!r) + throw python_error(); auto size = PyBytes_GET_SIZE(r.get()); const void* py_buf = PyBytes_AsString(r.get()); @@ -77,22 +98,29 @@ static inline Py_ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf } // Either does fildes.readinto(buf) or fildes.write(buf) -static inline Py_ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) { +static inline Py_ssize_t doPartialPythonIO( + PyObject* fildes, + void* buf, + size_t nbytes, + bool is_read) { auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ; - THPObjectPtr memview(PyMemoryView_FromMemory( - reinterpret_cast(buf), nbytes, rw_flag)); - if (!memview) throw python_error(); + THPObjectPtr memview( + PyMemoryView_FromMemory(reinterpret_cast(buf), nbytes, rw_flag)); + if (!memview) + throw python_error(); std::string method = "write"; if (is_read) { method = "readinto"; } - THPObjectPtr r(PyObject_CallMethod(fildes, method.c_str(), "O", memview.get())); + THPObjectPtr r( + PyObject_CallMethod(fildes, method.c_str(), "O", memview.get())); if (r) { return PyLong_AsSsize_t(r.get()); } - // fildes.readinto can return UnsupportedOperation so fall back to fildes.read. + // fildes.readinto can return UnsupportedOperation so fall back to + // fildes.read. if (is_read && isUnsupportedOperation()) { PyErr_Clear(); return doPartialPythonReadBuffered(fildes, buf, nbytes); @@ -101,12 +129,18 @@ static inline Py_ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t n } // Call Python fildes.readinto(buf) -static Py_ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) { +static Py_ssize_t doPartialPythonReadInto( + PyObject* fildes, + void* buf, + size_t nbytes) { return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true); } // Call Python fildes.write(buf) -static Py_ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes) { +static Py_ssize_t doPartialPythonWrite( + PyObject* fildes, + void* buf, + size_t nbytes) { return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false); } @@ -118,12 +152,17 @@ void doRead(io fildes, void* raw_buf, size_t nbytes) { errno = 0; // doPartialRead may not set errno // we read in 1GB blocks to avoid bugs on Mac OS X Lion // see https://github.com/pytorch/pytorch/issues/1031 for more details - Py_ssize_t r = doPartialRead(fildes, buf, std::min(nbytes, 1073741824)); + Py_ssize_t r = + doPartialRead(fildes, buf, std::min(nbytes, 1073741824)); if (r < 0) { int err = errno; - TORCH_INTERNAL_ASSERT(err != 0, "read(): impossible! r < 0, but no errno was set"); - TORCH_INTERNAL_ASSERT(err != EAGAIN, "read(): non-blocking fd ", fildes, - " read EAGAIN; cowardly refusing to spin-wait"); + TORCH_INTERNAL_ASSERT( + err != 0, "read(): impossible! r < 0, but no errno was set"); + TORCH_INTERNAL_ASSERT( + err != EAGAIN, + "read(): non-blocking fd ", + fildes, + " read EAGAIN; cowardly refusing to spin-wait"); if (err == EINTR) { continue; } else { @@ -139,7 +178,10 @@ void doRead(io fildes, void* raw_buf, size_t nbytes) { nbytes -= r; } if (nbytes != 0) { - AT_ERROR("unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted."); + AT_ERROR( + "unexpected EOF, expected ", + nbytes, + " more bytes. The file might be corrupted."); } } @@ -150,12 +192,17 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) { errno = 0; // doPartialWrite may not set errno // we write in 1GB blocks to avoid bugs on Mac OS X Lion // see https://github.com/pytorch/pytorch/issues/1031 for more details - Py_ssize_t r = doPartialWrite(fildes, buf, std::min(nbytes, 1073741824)); + Py_ssize_t r = + doPartialWrite(fildes, buf, std::min(nbytes, 1073741824)); if (r < 0) { int err = errno; - TORCH_INTERNAL_ASSERT(err != 0, "write(): impossible! r < 0, but no errno was set"); - TORCH_INTERNAL_ASSERT(err != EAGAIN, "write(): non-blocking fd ", fildes, - " read EAGAIN; cowardly refusing to spin-wait"); + TORCH_INTERNAL_ASSERT( + err != 0, "write(): impossible! r < 0, but no errno was set"); + TORCH_INTERNAL_ASSERT( + err != EAGAIN, + "write(): non-blocking fd ", + fildes, + " read EAGAIN; cowardly refusing to spin-wait"); if (err == EINTR) { continue; } else { @@ -172,11 +219,14 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) { // [size + data], but the v1.5 eager format removes this since size is saved in // the filesize. template -void THPStorage_writeFileRaw(c10::StorageImpl *self, io fd, bool save_size, uint64_t element_size) -{ +void THPStorage_writeFileRaw( + c10::StorageImpl* self, + io fd, + bool save_size, + uint64_t element_size) { c10::DeviceGuard guard(self->device()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint8_t *data; + uint8_t* data; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::unique_ptr cpu_data; int64_t size_bytes = self->nbytes(); @@ -188,13 +238,11 @@ void THPStorage_writeFileRaw(c10::StorageImpl *self, io fd, bool save_size, uint cpu_data = std::unique_ptr(new char[size_bytes]); data = (uint8_t*)cpu_data.get(); C10_CUDA_CHECK(cudaMemcpy( - data, - self->data(), - size_bytes, - cudaMemcpyDeviceToHost)); + data, self->data(), size_bytes, cudaMemcpyDeviceToHost)); #endif } else { - TORCH_CHECK(false, "writeFileRaw: Device not recognized: ", self->device_type()); + TORCH_CHECK( + false, "writeFileRaw: Device not recognized: ", self->device_type()); } if (save_size) { if (torch::utils::THP_nativeByteOrder() == @@ -220,7 +268,8 @@ void THPStorage_writeFileRaw(c10::StorageImpl *self, io fd, bool save_size, uint int64_t buffer_size = std::min(numel, (int64_t)5000); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::unique_ptr le_buffer(new uint8_t[buffer_size * element_size]); + std::unique_ptr le_buffer( + new uint8_t[buffer_size * element_size]); for (int64_t i = 0; i < numel; i += buffer_size) { size_t to_convert = std::min(numel - i, buffer_size); // NOLINTNEXTLINE(bugprone-branch-clone) @@ -248,19 +297,28 @@ void THPStorage_writeFileRaw(c10::StorageImpl *self, io fd, bool save_size, uint } } -template void THPStorage_writeFileRaw(c10::StorageImpl *self, int fd, bool save_size, uint64_t element_size); -template void THPStorage_writeFileRaw(c10::StorageImpl *self, PyObject* fd, bool save_size, uint64_t element_size); +template void THPStorage_writeFileRaw( + c10::StorageImpl* self, + int fd, + bool save_size, + uint64_t element_size); +template void THPStorage_writeFileRaw( + c10::StorageImpl* self, + PyObject* fd, + bool save_size, + uint64_t element_size); template c10::intrusive_ptr THPStorage_readFileRaw( - io file, c10::intrusive_ptr storage, uint64_t element_size) -{ + io file, + c10::intrusive_ptr storage, + uint64_t element_size) { c10::OptionalDeviceGuard guard; if (storage.defined()) { guard.reset_device(storage->device()); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint8_t *data; + uint8_t* data; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t size; doRead(file, &size, sizeof(int64_t)); @@ -271,14 +329,17 @@ c10::intrusive_ptr THPStorage_readFileRaw( int64_t nsize; // convert little endian storage to big endian cpu nsize = nbytes; torch::utils::THP_decodeInt64Buffer( - &nbytes, (const uint8_t*)&nsize, torch::utils::THP_nativeByteOrder(), 1); + &nbytes, + (const uint8_t*)&nsize, + torch::utils::THP_nativeByteOrder(), + 1); } if (!storage.defined()) { storage = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - nbytes, - c10::GetDefaultCPUAllocator(), - /*resizable=*/true); + c10::StorageImpl::use_byte_size_t(), + nbytes, + c10::GetDefaultCPUAllocator(), + /*resizable=*/true); } else { int64_t _storage_nbytes = storage->nbytes(); TORCH_CHECK( @@ -307,7 +368,8 @@ c10::intrusive_ptr THPStorage_readFileRaw( int64_t buffer_size = std::min(size, (int64_t)5000); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::unique_ptr le_buffer(new uint8_t[buffer_size * element_size]); + std::unique_ptr le_buffer( + new uint8_t[buffer_size * element_size]); for (int64_t i = 0; i < size; i += buffer_size) { size_t to_convert = std::min(size - i, buffer_size); @@ -338,13 +400,18 @@ c10::intrusive_ptr THPStorage_readFileRaw( #ifdef USE_CUDA if (storage->device_type() == at::kCUDA) { - C10_CUDA_CHECK(cudaMemcpy(storage->data(), data, nbytes, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + storage->data(), data, nbytes, cudaMemcpyHostToDevice)); } #endif return storage; } template c10::intrusive_ptr THPStorage_readFileRaw( - int fd, c10::intrusive_ptr storage, uint64_t element_size); + int fd, + c10::intrusive_ptr storage, + uint64_t element_size); template c10::intrusive_ptr THPStorage_readFileRaw( - PyObject* fd, c10::intrusive_ptr storage, uint64_t element_size); + PyObject* fd, + c10::intrusive_ptr storage, + uint64_t element_size); diff --git a/torch/csrc/serialization.h b/torch/csrc/serialization.h index f0c10a974e379a..6fe774ae1231fa 100644 --- a/torch/csrc/serialization.h +++ b/torch/csrc/serialization.h @@ -8,10 +8,16 @@ template void doWrite(io fildes, void* buf, size_t nbytes); template -void THPStorage_writeFileRaw(c10::StorageImpl *self, io fd, bool save_size, uint64_t element_size); +void THPStorage_writeFileRaw( + c10::StorageImpl* self, + io fd, + bool save_size, + uint64_t element_size); template c10::intrusive_ptr THPStorage_readFileRaw( - io fd, c10::intrusive_ptr storage, uint64_t element_size); + io fd, + c10::intrusive_ptr storage, + uint64_t element_size); #endif diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 81e65b0f5c00d0..13c3d63675266e 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -1,16 +1,16 @@ #include -#include #include +#include #include #include #include #include -#include -#include #include +#include #include +#include #include #include #include @@ -24,7 +24,8 @@ #include #include -namespace torch { namespace tensors { +namespace torch { +namespace tensors { using namespace at; using namespace torch::autograd; @@ -52,23 +53,35 @@ struct PyTensorType { } }; -static_assert(std::is_standard_layout::value, "PyTensorType must be standard layout"); +static_assert( + std::is_standard_layout::value, + "PyTensorType must be standard layout"); static Backend default_backend = Backend::CPU; -static void py_bind_tensor_types(const std::vector& tensor_types); +static void py_bind_tensor_types( + const std::vector& tensor_types); static TypeError unavailable_type(const PyTensorType& type) { - return TypeError("type %s not available. Torch not compiled with CUDA enabled.", type.name); + return TypeError( + "type %s not available. Torch not compiled with CUDA enabled.", + type.name); } -static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { +static PyObject* Tensor_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS auto& tensor_type = *((PyTensorType*)type); if (tensor_type.is_cuda && !torch::utils::cuda_enabled()) { throw unavailable_type(tensor_type); } - return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(tensor_type.get_dispatch_key(), tensor_type.get_scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::legacy_tensor_ctor( + tensor_type.get_dispatch_key(), + tensor_type.get_scalar_type(), + args, + kwargs)); END_HANDLE_TH_ERRORS } @@ -99,15 +112,15 @@ static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) { END_HANDLE_TH_ERRORS } -PyObject *Tensor_dtype(PyTensorType* self, void *unused) { +PyObject* Tensor_dtype(PyTensorType* self, void* unused) { return torch::autograd::utils::wrap(self->dtype); } -PyObject *Tensor_layout(PyTensorType* self, void *unused) { +PyObject* Tensor_layout(PyTensorType* self, void* unused) { return torch::autograd::utils::wrap(self->layout); } -PyObject *Tensor_is_cuda(PyTensorType* self, void *unused) { +PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) { if (self->is_cuda) { Py_RETURN_TRUE; } else { @@ -115,7 +128,7 @@ PyObject *Tensor_is_cuda(PyTensorType* self, void *unused) { } } -PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) { +PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) { if (self->layout->layout == at::Layout::Strided) { Py_RETURN_FALSE; } else { @@ -123,7 +136,7 @@ PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) { } } -PyObject *Tensor_is_sparse_csr(PyTensorType *self, void *unused) { +PyObject* Tensor_is_sparse_csr(PyTensorType* self, void* unused) { if (self->layout->layout == at::Layout::SparseCsr) { Py_RETURN_TRUE; } else { @@ -133,26 +146,23 @@ PyObject *Tensor_is_sparse_csr(PyTensorType *self, void *unused) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyMethodDef metaclass_methods[] = { - {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr}, - {nullptr} -}; + {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr}, + {nullptr}}; -typedef PyObject *(*getter)(PyObject *, void *); +typedef PyObject* (*getter)(PyObject*, void*); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyGetSetDef metaclass_properties[] = { - {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr}, - {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr}, - {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr}, - {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr}, - {"is_sparse_csr",(getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr}, - {nullptr} -}; + {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr}, + {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr}, + {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr}, + {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr}, + {"is_sparse_csr", (getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr}, + {nullptr}}; static PyTypeObject metaclass = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch.tensortype", /* tp_name */ - sizeof(PyTypeObject) /* tp_basicsize */ + PyVarObject_HEAD_INIT(nullptr, 0) "torch.tensortype", /* tp_name */ + sizeof(PyTypeObject) /* tp_basicsize */ }; static void py_initialize_metaclass(PyTypeObject& metaclass) { @@ -166,12 +176,14 @@ static void py_initialize_metaclass(PyTypeObject& metaclass) { } static PyTypeObject tensor_type_prototype = { - PyVarObject_HEAD_INIT(&metaclass, 0) - nullptr, /* tp_name */ - sizeof(PyTensorType) /* tp_basicsize */ + PyVarObject_HEAD_INIT(&metaclass, 0) nullptr, /* tp_name */ + sizeof(PyTensorType) /* tp_basicsize */ }; -static void py_initialize_tensor_type(PyTypeObject& type, const char* name, PyObject* tp_dict) { +static void py_initialize_tensor_type( + PyTypeObject& type, + const char* name, + PyObject* tp_dict) { // NOTE: we don't use the typical static declaration of PyTypeObject because // we need to initialize as many types as there are VariableType instances. // We copy the basic object fields from a prototype definition and initialize @@ -192,11 +204,16 @@ static void py_initialize_tensor_type(PyTypeObject& type, const char* name, PyOb static const char* get_module(Backend backend) { switch (backend) { - case Backend::CPU: return "torch"; - case Backend::CUDA: return "torch.cuda"; - case Backend::SparseCPU: return "torch.sparse"; - case Backend::SparseCUDA: return "torch.cuda.sparse"; - default: AT_ERROR("invalid backend: ", toString(backend)); + case Backend::CPU: + return "torch"; + case Backend::CUDA: + return "torch.cuda"; + case Backend::SparseCPU: + return "torch.sparse"; + case Backend::SparseCUDA: + return "torch.cuda.sparse"; + default: + AT_ERROR("invalid backend: ", toString(backend)); } } @@ -209,23 +226,29 @@ static std::string get_name(Backend backend, ScalarType scalarType) { static THPObjectPtr get_storage_obj(Backend backend, ScalarType dtype) { auto module_name = get_module(backend); auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name)); - if (!module_obj) throw python_error(); + if (!module_obj) + throw python_error(); auto storage_name = std::string(toString(dtype)) + "Storage"; - THPObjectPtr storage(PyObject_GetAttrString(module_obj.get(), storage_name.c_str())); + THPObjectPtr storage( + PyObject_GetAttrString(module_obj.get(), storage_name.c_str())); if (!storage.get()) { throw TypeError("couldn't find storage object %s", storage_name.c_str()); } return storage; } -static void set_type(PyTensorType& type_obj, Backend backend, ScalarType scalarType) { +static void set_type( + PyTensorType& type_obj, + Backend backend, + ScalarType scalarType) { // This field is lazily initialized from backend and scalar_type type_obj.backend = static_cast(backend); type_obj.scalar_type = static_cast(scalarType); type_obj.layout = torch::getTHPLayout(layout_from_backend(backend)); type_obj.dtype = torch::getTHPDtype(scalarType); - type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA); + type_obj.is_cuda = + (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA); } static void set_name(PyTensorType& type_obj, const std::string& name) { @@ -236,16 +259,19 @@ static void set_name(PyTensorType& type_obj, const std::string& name) { static THPObjectPtr get_tensor_dict() { auto torch = THPObjectPtr(PyImport_ImportModule("torch")); - if (!torch) throw python_error(); + if (!torch) + throw python_error(); auto tensor_class = THPObjectPtr(PyObject_GetAttrString(torch, "Tensor")); - if (!tensor_class) throw python_error(); + if (!tensor_class) + throw python_error(); auto tensor_type = (PyTypeObject*)tensor_class.get(); TORCH_CHECK(tensor_type->tp_base, "missing base type for Tensor"); auto res = THPObjectPtr(PyDict_New()); - if (!res) throw python_error(); + if (!res) + throw python_error(); if (PyDict_Merge(res.get(), tensor_type->tp_dict, 0) < 0) { throw python_error(); @@ -262,8 +288,8 @@ static THPObjectPtr get_tensor_dict() { // dynamically at init time, because their exact number depends on // torch::utils::all_declared_types(). The memory for each PyTensorType is // allocated by initialize_aten_types() and never freed: technically it's a -// leak, but it's not a problem since we want them to be alive for the whole time -// of the process anyway. +// leak, but it's not a problem since we want them to be alive for the whole +// time of the process anyway. // // An alternative is to use a std::vector instead, and let // std::vector to manage the lifetime of its items. This is problematic @@ -279,24 +305,32 @@ void set_default_storage_type(Backend backend, ScalarType dtype) { THPObjectPtr storage = get_storage_obj(backend, dtype); auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); - if (!torch_module) throw python_error(); + if (!torch_module) + throw python_error(); if (PyObject_SetAttrString(torch_module.get(), "Storage", storage) != 0) { throw python_error(); } } -void set_default_tensor_type(c10::optional backend, c10::optional dtype) { +void set_default_tensor_type( + c10::optional backend, + c10::optional dtype) { if (backend.has_value()) { - TORCH_CHECK_TYPE(*backend != Backend::Undefined, "default type cannot be undefined"); - TORCH_CHECK_TYPE(!isSparse(*backend), "only dense types are supported as the default type"); + TORCH_CHECK_TYPE( + *backend != Backend::Undefined, "default type cannot be undefined"); + TORCH_CHECK_TYPE( + !isSparse(*backend), + "only dense types are supported as the default type"); } if (dtype.has_value()) { - TORCH_CHECK_TYPE(at::isFloatingType(*dtype), - "only floating-point types are supported as the default type"); + TORCH_CHECK_TYPE( + at::isFloatingType(*dtype), + "only floating-point types are supported as the default type"); } - // Try setting default storage in python first as it's the only operation that can fail + // Try setting default storage in python first as it's the only operation that + // can fail set_default_storage_type( backend.value_or(default_backend), dtype.value_or(at::get_default_dtype_as_scalartype())); @@ -341,9 +375,11 @@ void initialize_python_bindings() { // `torch.FloatTensor.add`. auto tensor_dict = get_tensor_dict(); - // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor, etc. + // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor, + // etc. for (auto& tensor_type : tensor_types) { - py_initialize_tensor_type(tensor_type->py_type, tensor_type->name, tensor_dict.get()); + py_initialize_tensor_type( + tensor_type->py_type, tensor_type->name, tensor_dict.get()); } // Add the type objects to their corresponding modules. e.g. torch.FloatTensor @@ -352,12 +388,16 @@ void initialize_python_bindings() { py_bind_tensor_types(tensor_types); } -static void py_bind_tensor_types(const std::vector& tensor_types) { +static void py_bind_tensor_types( + const std::vector& tensor_types) { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); - if (!torch_module) throw python_error(); + if (!torch_module) + throw python_error(); - auto tensor_classes = THPObjectPtr(PyObject_GetAttrString(torch_module.get(), "_tensor_classes")); - if (!tensor_classes) throw python_error(); + auto tensor_classes = THPObjectPtr( + PyObject_GetAttrString(torch_module.get(), "_tensor_classes")); + if (!tensor_classes) + throw python_error(); for (auto& tensor_type : tensor_types) { auto name = std::string(tensor_type->name); @@ -366,7 +406,8 @@ static void py_bind_tensor_types(const std::vector& tensor_types) auto module_name = name.substr(0, idx); auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name.c_str())); - if (!module_obj) throw python_error(); + if (!module_obj) + throw python_error(); PyObject* type_obj = (PyObject*)tensor_type; Py_INCREF(type_obj); @@ -380,17 +421,17 @@ static void py_bind_tensor_types(const std::vector& tensor_types) } static bool PyTensorType_Check(PyObject* obj) { - auto it = std::find_if(tensor_types.begin(), tensor_types.end(), - [obj](PyTensorType *x) { - return (PyObject*)x == obj; - }); + auto it = std::find_if( + tensor_types.begin(), tensor_types.end(), [obj](PyTensorType* x) { + return (PyObject*)x == obj; + }); return it != tensor_types.end(); } void py_set_default_tensor_type(PyObject* obj) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) TORCH_CHECK_TYPE(PyTensorType_Check(obj), "invalid type object"); - PyTensorType *type = (PyTensorType*)obj; + PyTensorType* type = (PyTensorType*)obj; if (type->is_cuda && !torch::utils::cuda_enabled()) { throw unavailable_type(*type); } @@ -411,4 +452,5 @@ ScalarType get_default_scalar_type() { return get_default_dtype_as_scalartype(); } -}} // namespace torch::tensors +} // namespace tensors +} // namespace torch diff --git a/torch/csrc/tensor/python_tensor.h b/torch/csrc/tensor/python_tensor.h index acf73b9e029386..87a66012b49551 100644 --- a/torch/csrc/tensor/python_tensor.h +++ b/torch/csrc/tensor/python_tensor.h @@ -1,8 +1,8 @@ #pragma once -#include -#include #include +#include +#include namespace c10 { struct Device; @@ -12,7 +12,8 @@ namespace at { class Tensor; } // namespace at -namespace torch { namespace tensors { +namespace torch { +namespace tensors { // Initializes the Python tensor type objects: torch.FloatTensor, // torch.DoubleTensor, etc. and binds them in their containing modules. @@ -33,4 +34,5 @@ c10::DispatchKey get_default_dispatch_key(); // Gets the ScalarType for the default tensor type. at::ScalarType get_default_scalar_type(); -}} // namespace torch::tensors +} // namespace tensors +} // namespace torch diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 08de528a427fa4..990e094f82c3f4 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -1,9 +1,9 @@ -#include +#include #include -#include -#include #include -#include +#include +#include +#include #include @@ -15,14 +15,14 @@ #include #include -int THPUtils_getCallable(PyObject *arg, PyObject **result) { +int THPUtils_getCallable(PyObject* arg, PyObject** result) { if (!PyCallable_Check(arg)) return 0; *result = arg; return 1; } -std::vector THPUtils_unpackLongs(PyObject *arg) { +std::vector THPUtils_unpackLongs(PyObject* arg) { bool tuple = PyTuple_Check(arg); bool list = PyList_Check(arg); if (tuple || list) { @@ -30,10 +30,12 @@ std::vector THPUtils_unpackLongs(PyObject *arg) { const auto nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); std::vector sizes(nDim); for (int i = 0; i != nDim; ++i) { - PyObject* item = tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i); + PyObject* item = + tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i); if (!THPUtils_checkLong(item)) { std::ostringstream oss; - oss << "expected int at position " << i << ", but got: " << THPUtils_typename(item); + oss << "expected int at position " << i + << ", but got: " << THPUtils_typename(item); throw std::runtime_error(oss.str()); } sizes[i] = THPUtils_unpackLong(item); @@ -43,8 +45,7 @@ std::vector THPUtils_unpackLongs(PyObject *arg) { throw std::runtime_error("Expected tuple or list"); } -bool THPUtils_checkIntTuple(PyObject *arg) -{ +bool THPUtils_checkIntTuple(PyObject* arg) { if (!PyTuple_Check(arg)) { return false; } @@ -56,8 +57,7 @@ bool THPUtils_checkIntTuple(PyObject *arg) return true; } -std::vector THPUtils_unpackIntTuple(PyObject *arg) -{ +std::vector THPUtils_unpackIntTuple(PyObject* arg) { if (!THPUtils_checkIntTuple(arg)) { throw std::runtime_error("Couldn't unpack int tuple"); } @@ -68,8 +68,7 @@ std::vector THPUtils_unpackIntTuple(PyObject *arg) return values; } -void THPUtils_setError(const char *format, ...) -{ +void THPUtils_setError(const char* format, ...) { static const size_t ERROR_BUFFER_SIZE = 1000; // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char buffer[ERROR_BUFFER_SIZE]; @@ -81,8 +80,9 @@ void THPUtils_setError(const char *format, ...) PyErr_SetString(PyExc_RuntimeError, buffer); } -void THPUtils_addPyMethodDefs(std::vector& vector, PyMethodDef* methods) -{ +void THPUtils_addPyMethodDefs( + std::vector& vector, + PyMethodDef* methods) { if (!vector.empty()) { // remove nullptr terminator vector.pop_back(); @@ -103,10 +103,13 @@ static const char* classOrTypename(PyObject* obj) { return Py_TYPE(obj)->tp_name; } -PyObject * THPUtils_dispatchStateless( - PyObject *tensor, const char *name, PyObject *args, PyObject *kwargs) -{ - THPObjectPtr methods(PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME)); +PyObject* THPUtils_dispatchStateless( + PyObject* tensor, + const char* name, + PyObject* args, + PyObject* kwargs) { + THPObjectPtr methods( + PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME)); if (!methods) { return PyErr_Format( PyExc_TypeError, @@ -124,21 +127,29 @@ PyObject * THPUtils_dispatchStateless( return PyObject_Call(method.get(), args, kwargs); } -void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs, - const char *function_name, size_t num_options, ...) { +void THPUtils_invalidArguments( + PyObject* given_args, + PyObject* given_kwargs, + const char* function_name, + size_t num_options, + ...) { std::vector option_strings; va_list option_list; va_start(option_list, num_options); - std::generate_n(std::back_inserter(option_strings), num_options, [&option_list] { - return va_arg(option_list, const char*); - }); + std::generate_n( + std::back_inserter(option_strings), num_options, [&option_list] { + return va_arg(option_list, const char*); + }); va_end(option_list); - PyErr_SetString(PyExc_TypeError, torch::format_invalid_args( - given_args, given_kwargs, function_name, option_strings).c_str()); + PyErr_SetString( + PyExc_TypeError, + torch::format_invalid_args( + given_args, given_kwargs, function_name, option_strings) + .c_str()); } -template<> +template <> void THPPointer::free() { if (ptr) Py_DECREF(ptr); @@ -166,17 +177,17 @@ bool getBackCompatKeepdimWarn() { return backCompatKeepdimWarn; } -bool maybeThrowBackCompatKeepdimWarn(char *func) { - if(getBackCompatKeepdimWarn()) { - std::ostringstream ss; - ss << "backwards compatibility: call to \"" << func - << "\" uses default value for keepdim which has changed default to False. Consider passing as kwarg.", - PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1); +bool maybeThrowBackCompatKeepdimWarn(char* func) { + if (getBackCompatKeepdimWarn()) { + std::ostringstream ss; + ss << "backwards compatibility: call to \"" << func + << "\" uses default value for keepdim which has changed default to False. Consider passing as kwarg.", + PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1); } return true; } -template<> +template <> void THPPointer::free() { if (ptr) Py_DECREF(ptr); @@ -198,14 +209,18 @@ void storage_fill(at::Storage self, uint8_t value) { } void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value) { - TORCH_CHECK((idx >= 0) && (idx < static_cast(self.nbytes())), "out of bounds"); + TORCH_CHECK( + (idx >= 0) && (idx < static_cast(self.nbytes())), + "out of bounds"); auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte); auto self_t = at::empty({0}, {}, options).set_(self); self_t[idx].fill_(value); } uint8_t storage_get(at::Storage self, ptrdiff_t idx) { - TORCH_CHECK((idx >= 0) && (idx < static_cast(self.nbytes())), "out of bounds"); + TORCH_CHECK( + (idx >= 0) && (idx < static_cast(self.nbytes())), + "out of bounds"); auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte); auto self_t = at::empty({0}, {}, options).set_(self); return self_t[idx].item(); @@ -213,7 +228,8 @@ uint8_t storage_get(at::Storage self, ptrdiff_t idx) { template class THPPointer; -namespace torch { namespace gdb { +namespace torch { +namespace gdb { /* ~~~ misc debugging utilities ~~~ * * torch::gdb::* functions are NOT meant to be called by general pytorch code, @@ -228,14 +244,14 @@ namespace torch { namespace gdb { // call free than delete[] from withing gdb. // Currently the code for computing the repr of a tensor is written in Python, // so we need to wrap the Tensor into a Python object first. -char *tensor_repr(at::Tensor tensor) { +char* tensor_repr(at::Tensor tensor) { PyGILState_STATE gil = PyGILState_Ensure(); - PyObject *pytensor = nullptr; - PyObject *repr = nullptr; + PyObject* pytensor = nullptr; + PyObject* repr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Py_ssize_t bufsize; - const char *buf = nullptr; - char *result = nullptr; + const char* buf = nullptr; + char* result = nullptr; pytensor = THPVariable_Wrap(at::Tensor(tensor)); if (!pytensor) @@ -250,7 +266,8 @@ char *tensor_repr(at::Tensor tensor) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) goto error; // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - result = static_cast(malloc(bufsize + 1)); // account for the trailing \0 + result = + static_cast(malloc(bufsize + 1)); // account for the trailing \0 if (!result) { fprintf(stderr, "cannot allocate memory for the result\n"); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) @@ -275,4 +292,5 @@ char *tensor_repr(at::Tensor tensor) { return nullptr; } -}} // namespace torch::gdb +} // namespace gdb +} // namespace torch diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index e014a563372ab1..b45120b18a3d8a 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -1,20 +1,20 @@ #ifndef THP_UTILS_H #define THP_UTILS_H -#include -#include -#include #include #include #include -#include #include +#include +#include +#include +#include #ifdef USE_CUDA #include #endif -#define THPUtils_(NAME) TH_CONCAT_4(THP,Real,Utils_,NAME) +#define THPUtils_(NAME) TH_CONCAT_4(THP, Real, Utils_, NAME) #define THPUtils_typename(obj) (Py_TYPE(obj)->tp_name) @@ -24,124 +24,150 @@ #define THP_EXPECT(x, y) (x) #endif -#define THPUtils_checkReal_FLOAT(object) \ - (PyFloat_Check(object) || PyLong_Check(object)) +#define THPUtils_checkReal_FLOAT(object) \ + (PyFloat_Check(object) || PyLong_Check(object)) -#define THPUtils_unpackReal_FLOAT(object) \ - (PyFloat_Check(object) ? PyFloat_AsDouble(object) : \ - PyLong_Check(object) ? PyLong_AsLongLong(object) : \ - (throw std::runtime_error("Could not parse real"), 0)) +#define THPUtils_unpackReal_FLOAT(object) \ + (PyFloat_Check(object) ? PyFloat_AsDouble(object) \ + : PyLong_Check(object) \ + ? PyLong_AsLongLong(object) \ + : (throw std::runtime_error("Could not parse real"), 0)) -#define THPUtils_checkReal_INT(object) \ - PyLong_Check(object) +#define THPUtils_checkReal_INT(object) PyLong_Check(object) -#define THPUtils_unpackReal_INT(object) \ - (PyLong_Check(object) ? PyLong_AsLongLong(object) : \ - (throw std::runtime_error("Could not parse real"), 0)) +#define THPUtils_unpackReal_INT(object) \ + (PyLong_Check(object) \ + ? PyLong_AsLongLong(object) \ + : (throw std::runtime_error("Could not parse real"), 0)) -#define THPUtils_unpackReal_BOOL(object) \ - (PyBool_Check(object) ? object : \ - (throw std::runtime_error("Could not parse real"), Py_False)) +#define THPUtils_unpackReal_BOOL(object) \ + (PyBool_Check(object) \ + ? object \ + : (throw std::runtime_error("Could not parse real"), Py_False)) -#define THPUtils_unpackReal_COMPLEX(object) \ - (PyComplex_Check(object) ? \ - (c10::complex(PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) : \ - PyFloat_Check(object) ? (c10::complex(PyFloat_AsDouble(object), 0)) : \ - PyLong_Check(object) ? (c10::complex(PyLong_AsLongLong(object), 0)) : \ - (throw std::runtime_error("Could not parse real"), c10::complex(0,0))) \ +#define THPUtils_unpackReal_COMPLEX(object) \ + (PyComplex_Check(object) \ + ? (c10::complex( \ + PyComplex_RealAsDouble(object), PyComplex_ImagAsDouble(object))) \ + : PyFloat_Check(object) \ + ? (c10::complex(PyFloat_AsDouble(object), 0)) \ + : PyLong_Check(object) \ + ? (c10::complex(PyLong_AsLongLong(object), 0)) \ + : (throw std::runtime_error("Could not parse real"), \ + c10::complex(0, 0))) -#define THPUtils_checkReal_BOOL(object) \ - PyBool_Check(object) +#define THPUtils_checkReal_BOOL(object) PyBool_Check(object) -#define THPUtils_checkReal_COMPLEX(object) \ - PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object) || PyInt_Check(object) +#define THPUtils_checkReal_COMPLEX(object) \ + PyComplex_Check(object) || PyFloat_Check(object) || PyLong_Check(object) || \ + PyInt_Check(object) #define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value) #define THPUtils_newReal_INT(value) PyInt_FromLong(value) #define THPUtils_newReal_BOOL(value) PyBool_FromLong(value) -#define THPUtils_newReal_COMPLEX(value) PyComplex_FromDoubles(value.real(), value.imag()) - -#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) -#define THPDoubleUtils_unpackReal(object) (double)THPUtils_unpackReal_FLOAT(object) -#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value) -#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) -#define THPFloatUtils_unpackReal(object) (float)THPUtils_unpackReal_FLOAT(object) -#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value) -#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) -#define THPHalfUtils_unpackReal(object) (at::Half)THPUtils_unpackReal_FLOAT(object) -#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value) -#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value) -#define THPComplexDoubleUtils_checkReal(object) THPUtils_checkReal_COMPLEX(object) -#define THPComplexDoubleUtils_unpackReal(object) THPUtils_unpackReal_COMPLEX(object) -#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value) -#define THPComplexFloatUtils_checkReal(object) THPUtils_checkReal_COMPLEX(object) -#define THPComplexFloatUtils_unpackReal(object) (c10::complex)THPUtils_unpackReal_COMPLEX(object) -#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value) -#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object) -#define THPBFloat16Utils_unpackReal(object) (at::BFloat16)THPUtils_unpackReal_FLOAT(object) -#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value) -#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value) - -#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object) -#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object) -#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value) -#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object) -#define THPBoolUtils_unpackAccreal(object) (int64_t)THPUtils_unpackReal_BOOL(object) -#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value) -#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPLongUtils_unpackReal(object) (int64_t)THPUtils_unpackReal_INT(object) -#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value) -#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPIntUtils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value) -#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPShortUtils_unpackReal(object) (short)THPUtils_unpackReal_INT(object) -#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value) -#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPCharUtils_unpackReal(object) (char)THPUtils_unpackReal_INT(object) -#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value) -#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPByteUtils_unpackReal(object) (unsigned char)THPUtils_unpackReal_INT(object) -#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value) +#define THPUtils_newReal_COMPLEX(value) \ + PyComplex_FromDoubles(value.real(), value.imag()) + +#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) +#define THPDoubleUtils_unpackReal(object) \ + (double)THPUtils_unpackReal_FLOAT(object) +#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value) +#define THPFloatUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) +#define THPFloatUtils_unpackReal(object) \ + (float)THPUtils_unpackReal_FLOAT(object) +#define THPFloatUtils_newReal(value) THPUtils_newReal_FLOAT(value) +#define THPHalfUtils_checkReal(object) THPUtils_checkReal_FLOAT(object) +#define THPHalfUtils_unpackReal(object) \ + (at::Half) THPUtils_unpackReal_FLOAT(object) +#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value) +#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value) +#define THPComplexDoubleUtils_checkReal(object) \ + THPUtils_checkReal_COMPLEX(object) +#define THPComplexDoubleUtils_unpackReal(object) \ + THPUtils_unpackReal_COMPLEX(object) +#define THPComplexDoubleUtils_newReal(value) THPUtils_newReal_COMPLEX(value) +#define THPComplexFloatUtils_checkReal(object) \ + THPUtils_checkReal_COMPLEX(object) +#define THPComplexFloatUtils_unpackReal(object) \ + (c10::complex)THPUtils_unpackReal_COMPLEX(object) +#define THPComplexFloatUtils_newReal(value) THPUtils_newReal_COMPLEX(value) +#define THPBFloat16Utils_checkReal(object) THPUtils_checkReal_FLOAT(object) +#define THPBFloat16Utils_unpackReal(object) \ + (at::BFloat16) THPUtils_unpackReal_FLOAT(object) +#define THPBFloat16Utils_newReal(value) PyFloat_FromDouble(value) +#define THPBFloat16Utils_newAccreal(value) THPUtils_newReal_FLOAT(value) + +#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object) +#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object) +#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value) +#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object) +#define THPBoolUtils_unpackAccreal(object) \ + (int64_t) THPUtils_unpackReal_BOOL(object) +#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value) +#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPLongUtils_unpackReal(object) \ + (int64_t) THPUtils_unpackReal_INT(object) +#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value) +#define THPIntUtils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPIntUtils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPIntUtils_newReal(value) THPUtils_newReal_INT(value) +#define THPShortUtils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPShortUtils_unpackReal(object) (short)THPUtils_unpackReal_INT(object) +#define THPShortUtils_newReal(value) THPUtils_newReal_INT(value) +#define THPCharUtils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPCharUtils_unpackReal(object) (char)THPUtils_unpackReal_INT(object) +#define THPCharUtils_newReal(value) THPUtils_newReal_INT(value) +#define THPByteUtils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPByteUtils_unpackReal(object) \ + (unsigned char)THPUtils_unpackReal_INT(object) +#define THPByteUtils_newReal(value) THPUtils_newReal_INT(value) // quantized types -#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPQUInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value) -#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPQInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value) -#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPQInt32Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value) -#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPQUInt4x2Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value) -#define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object) -#define THPQUInt2x4Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) -#define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value) - - -#define THPUtils_assert(cond, ...) THPUtils_assertRet(nullptr, cond, __VA_ARGS__) -#define THPUtils_assertRet(value, cond, ...) \ -if (THP_EXPECT(!(cond), 0)) { THPUtils_setError(__VA_ARGS__); return value; } -TORCH_PYTHON_API void THPUtils_setError(const char *format, ...); +#define THPQUInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQUInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQUInt8Utils_newReal(value) THPUtils_newReal_INT(value) +#define THPQInt8Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQInt8Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQInt8Utils_newReal(value) THPUtils_newReal_INT(value) +#define THPQInt32Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQInt32Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQInt32Utils_newReal(value) THPUtils_newReal_INT(value) +#define THPQUInt4x2Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQUInt4x2Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQUInt4x2Utils_newReal(value) THPUtils_newReal_INT(value) +#define THPQUInt2x4Utils_checkReal(object) THPUtils_checkReal_INT(object) +#define THPQUInt2x4Utils_unpackReal(object) (int)THPUtils_unpackReal_INT(object) +#define THPQUInt2x4Utils_newReal(value) THPUtils_newReal_INT(value) + +#define THPUtils_assert(cond, ...) \ + THPUtils_assertRet(nullptr, cond, __VA_ARGS__) +#define THPUtils_assertRet(value, cond, ...) \ + if (THP_EXPECT(!(cond), 0)) { \ + THPUtils_setError(__VA_ARGS__); \ + return value; \ + } +TORCH_PYTHON_API void THPUtils_setError(const char* format, ...); TORCH_PYTHON_API void THPUtils_invalidArguments( - PyObject *given_args, PyObject *given_kwargs, - const char *function_name, size_t num_options, ...); + PyObject* given_args, + PyObject* given_kwargs, + const char* function_name, + size_t num_options, + ...); -bool THPUtils_checkIntTuple(PyObject *arg); -std::vector THPUtils_unpackIntTuple(PyObject *arg); +bool THPUtils_checkIntTuple(PyObject* arg); +std::vector THPUtils_unpackIntTuple(PyObject* arg); -void THPUtils_addPyMethodDefs(std::vector& vector, PyMethodDef* methods); +void THPUtils_addPyMethodDefs( + std::vector& vector, + PyMethodDef* methods); -int THPUtils_getCallable(PyObject *arg, PyObject **result); +int THPUtils_getCallable(PyObject* arg, PyObject** result); -#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr) -#define THPStoragePtr TH_CONCAT_2(THP,StoragePtr) -#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr) -#define THSPTensorPtr TH_CONCAT_3(THSP,Real,TensorPtr) +#define THWTensorPtr TH_CONCAT_3(TH, Real, TensorPtr) +#define THPStoragePtr TH_CONCAT_2(THP, StoragePtr) +#define THPTensorPtr TH_CONCAT_3(THP, Real, TensorPtr) +#define THSPTensorPtr TH_CONCAT_3(THSP, Real, TensorPtr) typedef THPPointer THPGeneratorPtr; @@ -155,20 +181,32 @@ struct THPUtils_typeTraits {}; #include // clang-format on -std::vector THPUtils_unpackLongs(PyObject *arg); -PyObject * THPUtils_dispatchStateless(PyObject *tensor, const char *name, PyObject *args, PyObject *kwargs); +std::vector THPUtils_unpackLongs(PyObject* arg); +PyObject* THPUtils_dispatchStateless( + PyObject* tensor, + const char* name, + PyObject* args, + PyObject* kwargs); -template +template struct mod_traits {}; -template -struct mod_traits<_real, typename std::enable_if::value>::type> { - static _real mod(_real a, _real b) { return fmod(a, b); } +template +struct mod_traits< + _real, + typename std::enable_if::value>::type> { + static _real mod(_real a, _real b) { + return fmod(a, b); + } }; -template -struct mod_traits<_real, typename std::enable_if::value>::type> { - static _real mod(_real a, _real b) { return a % b; } +template +struct mod_traits< + _real, + typename std::enable_if::value>::type> { + static _real mod(_real a, _real b) { + return a % b; + } }; void setBackCompatBroadcastWarn(bool warn); @@ -176,14 +214,15 @@ bool getBackCompatBroadcastWarn(); void setBackCompatKeepdimWarn(bool warn); bool getBackCompatKeepdimWarn(); -bool maybeThrowBackCompatKeepdimWarn(char *func); +bool maybeThrowBackCompatKeepdimWarn(char* func); // NB: This is in torch/csrc/cuda/utils.cpp, for whatever reason #ifdef USE_CUDA -std::vector> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj); +std::vector> +THPUtils_PySequence_to_CUDAStreamList(PyObject* obj); #endif -void storage_copy(at::Storage dst, at::Storage src, bool non_blocking=false); +void storage_copy(at::Storage dst, at::Storage src, bool non_blocking = false); void storage_fill(at::Storage self, uint8_t value); void storage_set(at::Storage self, ptrdiff_t idx, uint8_t value); uint8_t storage_get(at::Storage self, ptrdiff_t idx); diff --git a/torch/csrc/utils/auto_gil.h b/torch/csrc/utils/auto_gil.h index 6100f586e2879b..9b6ee12bd5b871 100644 --- a/torch/csrc/utils/auto_gil.h +++ b/torch/csrc/utils/auto_gil.h @@ -10,7 +10,8 @@ // Acquires the GIL on construction struct /* C10_DEPRECATED_MESSAGE( - "Use pybind11::gil_scoped_acquire instead") */ AutoGIL { + "Use pybind11::gil_scoped_acquire instead") */ + AutoGIL { AutoGIL() : gstate(PyGILState_Ensure()) {} ~AutoGIL() { PyGILState_Release(gstate); @@ -21,7 +22,8 @@ struct /* C10_DEPRECATED_MESSAGE( // Releases the GIL on construction struct /* C10_DEPRECATED_MESSAGE( - "Use pybind11::gil_scoped_release instead") */ AutoNoGIL { + "Use pybind11::gil_scoped_release instead") */ + AutoNoGIL { AutoNoGIL() : save(PyEval_SaveThread()) {} ~AutoNoGIL() { PyEval_RestoreThread(save); diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index d54c81dc15c740..c4039de579938b 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -11,8 +11,7 @@ namespace { -static inline void swapBytes16(void *ptr) -{ +static inline void swapBytes16(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint16_t output; memcpy(&output, ptr, sizeof(uint16_t)); @@ -28,8 +27,7 @@ static inline void swapBytes16(void *ptr) memcpy(ptr, &output, sizeof(uint16_t)); } -static inline void swapBytes32(void *ptr) -{ +static inline void swapBytes32(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t output; memcpy(&output, ptr, sizeof(uint32_t)); @@ -38,17 +36,16 @@ static inline void swapBytes32(void *ptr) #elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC) output = __builtin_bswap32(output); #else - uint32_t Byte0 = output & 0x000000FF; - uint32_t Byte1 = output & 0x0000FF00; - uint32_t Byte2 = output & 0x00FF0000; - uint32_t Byte3 = output & 0xFF000000; - output = (Byte0 << 24) | (Byte1 << 8) | (Byte2 >> 8) | (Byte3 >> 24); + uint32_t Byte0 = output & 0x000000FF; + uint32_t Byte1 = output & 0x0000FF00; + uint32_t Byte2 = output & 0x00FF0000; + uint32_t Byte3 = output & 0xFF000000; + output = (Byte0 << 24) | (Byte1 << 8) | (Byte2 >> 8) | (Byte3 >> 24); #endif memcpy(ptr, &output, sizeof(uint32_t)); } -static inline void swapBytes64(void *ptr) -{ +static inline void swapBytes64(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t output; memcpy(&output, ptr, sizeof(uint64_t)); @@ -65,46 +62,47 @@ static inline void swapBytes64(void *ptr) uint64_t Byte5 = output & 0x0000FF0000000000; uint64_t Byte6 = output & 0x00FF000000000000; uint64_t Byte7 = output & 0xFF00000000000000; - output = (Byte0 << (7*8)) | (Byte1 << (5*8)) | (Byte2 << (3*8)) | (Byte3 << (1*8)) | - (Byte7 >> (7*8)) | (Byte6 >> (5*8)) | (Byte5 >> (3*8)) | (Byte4 >> (1*8)); + output = (Byte0 << (7 * 8)) | (Byte1 << (5 * 8)) | (Byte2 << (3 * 8)) | + (Byte3 << (1 * 8)) | (Byte7 >> (7 * 8)) | (Byte6 >> (5 * 8)) | + (Byte5 >> (3 * 8)) | (Byte4 >> (1 * 8)); #endif memcpy(ptr, &output, sizeof(uint64_t)); } -static inline uint16_t decodeUInt16LE(const uint8_t *data) { +static inline uint16_t decodeUInt16LE(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint16_t output; memcpy(&output, data, sizeof(uint16_t)); return output; } -static inline uint16_t decodeUInt16BE(const uint8_t *data) { +static inline uint16_t decodeUInt16BE(const uint8_t* data) { uint16_t output = decodeUInt16LE(data); swapBytes16(&output); return output; } -static inline uint32_t decodeUInt32LE(const uint8_t *data) { +static inline uint32_t decodeUInt32LE(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t output; memcpy(&output, data, sizeof(uint32_t)); return output; } -static inline uint32_t decodeUInt32BE(const uint8_t *data) { +static inline uint32_t decodeUInt32BE(const uint8_t* data) { uint32_t output = decodeUInt32LE(data); swapBytes32(&output); return output; } -static inline uint64_t decodeUInt64LE(const uint8_t *data) { +static inline uint64_t decodeUInt64LE(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t output; memcpy(&output, data, sizeof(uint64_t)); return output; } -static inline uint64_t decodeUInt64BE(const uint8_t *data) { +static inline uint64_t decodeUInt64BE(const uint8_t* data) { uint64_t output = decodeUInt64LE(data); swapBytes64(&output); return output; @@ -115,53 +113,70 @@ static inline uint64_t decodeUInt64BE(const uint8_t *data) { namespace torch { namespace utils { -THPByteOrder THP_nativeByteOrder() -{ +THPByteOrder THP_nativeByteOrder() { uint32_t x = 1; return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN; } -void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { - dst[i] = (int16_t)( - order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); +void THP_decodeInt16Buffer( + int16_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { + dst[i] = + (int16_t)(order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); src += sizeof(int16_t); } } -void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { - dst[i] = (int32_t)( - order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); +void THP_decodeInt32Buffer( + int32_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { + dst[i] = + (int32_t)(order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); src += sizeof(int32_t); } } -void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { - dst[i] = (int64_t)( - order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); +void THP_decodeInt64Buffer( + int64_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { + dst[i] = + (int64_t)(order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); src += sizeof(int64_t); } } -void THP_decodeHalfBuffer(c10::Half* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeHalfBuffer( + c10::Half* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint16_t x; c10::Half f; }; + union { + uint16_t x; + c10::Half f; + }; x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); dst[i] = f; src += sizeof(uint16_t); } } -void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeBFloat16Buffer( + at::BFloat16* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { uint16_t x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); std::memcpy(&dst[i], &x, sizeof(dst[i])); @@ -169,42 +184,66 @@ void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrde } } -void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeBoolBuffer( + bool* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { dst[i] = (int)src[i] != 0 ? true : false; } } -void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeFloatBuffer( + float* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint32_t x; float f; }; + union { + uint32_t x; + float f; + }; x = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); dst[i] = f; src += sizeof(float); } } -void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeDoubleBuffer( + double* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint64_t x; double d; }; + union { + uint64_t x; + double d; + }; x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); dst[i] = d; src += sizeof(double); } } -void THP_decodeComplexFloatBuffer(c10::complex* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeComplexFloatBuffer( + c10::complex* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint32_t x; float re; }; + union { + uint32_t x; + float re; + }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint32_t y; float im; }; + union { + uint32_t y; + float im; + }; x = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); src += sizeof(float); @@ -215,13 +254,22 @@ void THP_decodeComplexFloatBuffer(c10::complex* dst, const uint8_t* src, } } -void THP_decodeComplexDoubleBuffer(c10::complex* dst, const uint8_t* src, THPByteOrder order, size_t len) -{ - for(const auto i : c10::irange(len)) { +void THP_decodeComplexDoubleBuffer( + c10::complex* dst, + const uint8_t* src, + THPByteOrder order, + size_t len) { + for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint32_t x; double re; }; + union { + uint32_t x; + double re; + }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { uint32_t y; double im; }; + union { + uint32_t y; + double im; + }; x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); src += sizeof(double); @@ -232,11 +280,14 @@ void THP_decodeComplexDoubleBuffer(c10::complex* dst, const uint8_t* src } } -void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len) -{ +void THP_encodeInt16Buffer( + uint8_t* dst, + const int16_t* src, + THPByteOrder order, + size_t len) { memcpy(dst, src, sizeof(int16_t) * len); if (order != THP_nativeByteOrder()) { - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { (void)i; swapBytes16(dst); dst += sizeof(int16_t); @@ -244,11 +295,14 @@ void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, } } -void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len) -{ +void THP_encodeInt32Buffer( + uint8_t* dst, + const int32_t* src, + THPByteOrder order, + size_t len) { memcpy(dst, src, sizeof(int32_t) * len); if (order != THP_nativeByteOrder()) { - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { (void)i; swapBytes32(dst); dst += sizeof(int32_t); @@ -256,11 +310,14 @@ void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, } } -void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len) -{ +void THP_encodeInt64Buffer( + uint8_t* dst, + const int64_t* src, + THPByteOrder order, + size_t len) { memcpy(dst, src, sizeof(int64_t) * len); if (order != THP_nativeByteOrder()) { - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { (void)i; swapBytes64(dst); dst += sizeof(int64_t); @@ -268,11 +325,14 @@ void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, } } -void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len) -{ +void THP_encodeFloatBuffer( + uint8_t* dst, + const float* src, + THPByteOrder order, + size_t len) { memcpy(dst, src, sizeof(float) * len); if (order != THP_nativeByteOrder()) { - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { (void)i; swapBytes32(dst); dst += sizeof(float); @@ -280,11 +340,14 @@ void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, s } } -void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len) -{ +void THP_encodeDoubleBuffer( + uint8_t* dst, + const double* src, + THPByteOrder order, + size_t len) { memcpy(dst, src, sizeof(double) * len); if (order != THP_nativeByteOrder()) { - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { (void)i; swapBytes64(dst); dst += sizeof(double); @@ -296,7 +359,7 @@ template std::vector complex_to_float(const c10::complex* src, size_t len) { std::vector new_src; new_src.reserve(2 * len); - for(const auto i : c10::irange(len)) { + for (const auto i : c10::irange(len)) { auto elem = src[i]; new_src.emplace_back(elem.real()); new_src.emplace_back(elem.imag()); @@ -304,8 +367,11 @@ std::vector complex_to_float(const c10::complex* src, size_t len) { return new_src; } -void THP_encodeComplexFloatBuffer(uint8_t* dst, const c10::complex* src, THPByteOrder order, size_t len) -{ +void THP_encodeComplexFloatBuffer( + uint8_t* dst, + const c10::complex* src, + THPByteOrder order, + size_t len) { auto new_src = complex_to_float(src, len); memcpy(dst, static_cast(&new_src), 2 * sizeof(float) * len); if (order != THP_nativeByteOrder()) { @@ -317,8 +383,11 @@ void THP_encodeComplexFloatBuffer(uint8_t* dst, const c10::complex* src, } } -void THP_encodeCompelxDoubleBuffer(uint8_t* dst, const c10::complex* src, THPByteOrder order, size_t len) -{ +void THP_encodeCompelxDoubleBuffer( + uint8_t* dst, + const c10::complex* src, + THPByteOrder order, + size_t len) { auto new_src = complex_to_float(src, len); memcpy(dst, static_cast(&new_src), 2 * sizeof(double) * len); if (order != THP_nativeByteOrder()) { diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h index 78c38ad9ede31d..b4c3c32eccc330 100644 --- a/torch/csrc/utils/byte_order.h +++ b/torch/csrc/utils/byte_order.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include #include @@ -9,10 +9,7 @@ namespace torch { namespace utils { -enum THPByteOrder { - THP_LITTLE_ENDIAN = 0, - THP_BIG_ENDIAN = 1 -}; +enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; TORCH_API THPByteOrder THP_nativeByteOrder(); diff --git a/torch/csrc/utils/cpp_stacktraces.cpp b/torch/csrc/utils/cpp_stacktraces.cpp index 6543ca986f5bcb..ac9897e8ad9644 100644 --- a/torch/csrc/utils/cpp_stacktraces.cpp +++ b/torch/csrc/utils/cpp_stacktraces.cpp @@ -16,8 +16,10 @@ bool compute_cpp_stack_traces_enabled() { if (strcmp(envar, "1") == 0) { return true; } - TORCH_WARN("ignoring invalid value for TORCH_SHOW_CPP_STACKTRACES: ", envar, - " valid values are 0 or 1."); + TORCH_WARN( + "ignoring invalid value for TORCH_SHOW_CPP_STACKTRACES: ", + envar, + " valid values are 0 or 1."); } return false; } diff --git a/torch/csrc/utils/cuda_enabled.h b/torch/csrc/utils/cuda_enabled.h index 84353e07e43747..e27c168a8ef46a 100644 --- a/torch/csrc/utils/cuda_enabled.h +++ b/torch/csrc/utils/cuda_enabled.h @@ -11,5 +11,5 @@ static inline bool cuda_enabled() { #endif } -} -} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/cuda_lazy_init.cpp b/torch/csrc/utils/cuda_lazy_init.cpp index b5b97e89eb63fe..37c682bdbd1f29 100644 --- a/torch/csrc/utils/cuda_lazy_init.cpp +++ b/torch/csrc/utils/cuda_lazy_init.cpp @@ -18,9 +18,12 @@ void cuda_lazy_init() { // have taken a lock. if (!run_yet) { auto module = THPObjectPtr(PyImport_ImportModule("torch.cuda")); - if (!module) throw python_error(); - auto res = THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", "")); - if (!res) throw python_error(); + if (!module) + throw python_error(); + auto res = + THPObjectPtr(PyObject_CallMethod(module.get(), "_lazy_init", "")); + if (!res) + throw python_error(); run_yet = true; } } @@ -29,5 +32,5 @@ void set_run_yet_variable_to_false() { run_yet = false; } -} -} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/cuda_lazy_init.h b/torch/csrc/utils/cuda_lazy_init.h index a4a10b54c3de52..e6efc1f648e521 100644 --- a/torch/csrc/utils/cuda_lazy_init.h +++ b/torch/csrc/utils/cuda_lazy_init.h @@ -29,5 +29,5 @@ static void maybe_initialize_cuda(const at::TensorOptions& options) { } } -} -} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 1bf5bfb0ddbfe6..ba7d3eb162918e 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -1,119 +1,121 @@ +#include #include #include -#include #include #include namespace torch { - PyObject* disabled_torch_function = nullptr; - PyObject* disabled_torch_dispatch = nullptr; +PyObject* disabled_torch_function = nullptr; +PyObject* disabled_torch_dispatch = nullptr; - bool torch_function_enabled() { - return !at::impl::PythonTorchFunctionTLS::is_disabled(); - } +bool torch_function_enabled() { + return !at::impl::PythonTorchFunctionTLS::is_disabled(); +} - PyObject* disabled_torch_function_impl() { - return disabled_torch_function; - } +PyObject* disabled_torch_function_impl() { + return disabled_torch_function; +} - void set_disabled_torch_function_impl(PyObject* value) { - disabled_torch_function = value; - } +void set_disabled_torch_function_impl(PyObject* value) { + disabled_torch_function = value; +} - PyObject* disabled_torch_dispatch_impl() { - return disabled_torch_dispatch; - } +PyObject* disabled_torch_dispatch_impl() { + return disabled_torch_dispatch; +} - void set_disabled_torch_dispatch_impl(PyObject* value) { - disabled_torch_dispatch = value; - } +void set_disabled_torch_dispatch_impl(PyObject* value) { + disabled_torch_dispatch = value; } +} // namespace torch typedef struct { - PyObject_HEAD - /* Type-specific fields go here. */ - bool old_state; + PyObject_HEAD + /* Type-specific fields go here. */ + bool old_state; } DisableTorchFunction; -PyObject* DisableTorchFunction__enter(PyObject* self, PyObject *unused) { - ((DisableTorchFunction*)self)->old_state = at::impl::PythonTorchFunctionTLS::is_disabled(); - at::impl::PythonTorchFunctionTLS::set_disabled(true); - Py_RETURN_NONE; +PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { + ((DisableTorchFunction*)self)->old_state = + at::impl::PythonTorchFunctionTLS::is_disabled(); + at::impl::PythonTorchFunctionTLS::set_disabled(true); + Py_RETURN_NONE; } -PyObject* DisableTorchFunction__exit(PyObject* self, PyObject *unused) { - at::impl::PythonTorchFunctionTLS::set_disabled(((DisableTorchFunction*)self)->old_state); - Py_RETURN_NONE; +PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { + at::impl::PythonTorchFunctionTLS::set_disabled( + ((DisableTorchFunction*)self)->old_state); + Py_RETURN_NONE; } -PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused) { - if (torch::torch_function_enabled()) { - Py_RETURN_TRUE; - } else - { - Py_RETURN_FALSE; - } +PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) { + if (torch::torch_function_enabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } } static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT - {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; + {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; PyTypeObject DisableTorchFunctionType = { - PyVarObject_HEAD_INIT(nullptr, 0) - "torch._C.DisableTorchFunction", /* tp_name */ - sizeof(DisableTorchFunction), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - DisableTorchFunction_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - PyType_GenericAlloc, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ + PyVarObject_HEAD_INIT( + nullptr, + 0) "torch._C.DisableTorchFunction", /* tp_name */ + sizeof(DisableTorchFunction), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + DisableTorchFunction_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + PyType_GenericAlloc, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ }; PyObject* THPModule_DisableTorchFunctionType() { if (PyType_Ready(&DisableTorchFunctionType) < 0) { - return nullptr; + return nullptr; } - return (PyObject *)(&DisableTorchFunctionType); + return (PyObject*)(&DisableTorchFunctionType); } -PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { +PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS - PyObject *func=nullptr, *types=nullptr, *args=nullptr, *kwargs=nullptr; + PyObject *func = nullptr, *types = nullptr, *args = nullptr, + *kwargs = nullptr; if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) { return nullptr; } @@ -125,7 +127,8 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { - throw torch::TypeError("expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); + throw torch::TypeError( + "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); } // These are all C-API calls so no exceptions will be raised @@ -134,15 +137,16 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { bool old_value = at::impl::PythonTorchFunctionTLS::is_disabled(); at::impl::PythonTorchFunctionTLS::set_disabled(true); // kwargs can safely be nullptr here. - PyObject *result = PyObject_Call(func, py_args.ptr(), kwargs); + PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs); at::impl::PythonTorchFunctionTLS::set_disabled(old_value); return result; END_HANDLE_TH_ERRORS } -PyObject* THPModule_disable_torch_dispatch(PyObject *self, PyObject *a) { +PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS - PyObject *func=nullptr, *types=nullptr, *args=nullptr, *kwargs=nullptr; + PyObject *func = nullptr, *types = nullptr, *args = nullptr, + *kwargs = nullptr; if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) { return nullptr; } @@ -154,7 +158,8 @@ PyObject* THPModule_disable_torch_dispatch(PyObject *self, PyObject *a) { } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { - throw torch::TypeError("expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); + throw torch::TypeError( + "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); } // This implementation is not completely correct. The moral @@ -171,68 +176,56 @@ PyObject* THPModule_disable_torch_dispatch(PyObject *self, PyObject *a) { c10::impl::ExcludeDispatchKeyGuard guard_( // TODO: add constructor for this specifically c10::DispatchKeySet(c10::DispatchKeySet::FULL) - - c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) + c10::DispatchKeySet( + c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) // NB: off by one hazard here, but it works out: python key is not // included in AFTER, so it is included in the negation (and that's // correct: we want to exclude Python key and everything BEFORE it.) ); auto r = PyObject_Call(func, py_args.ptr(), kwargs); - if (r == nullptr) throw python_error(); + if (r == nullptr) + throw python_error(); return r; END_HANDLE_TH_ERRORS } // Makes sure that we don't check for __torch_function__ on basic Python types -static bool is_basic_python_type(PyTypeObject *tp) -{ +static bool is_basic_python_type(PyTypeObject* tp) { return ( - /* Basic number types */ - tp == &PyBool_Type || - - tp == &PyLong_Type || - tp == &PyFloat_Type || - tp == &PyComplex_Type || - - /* Basic sequence types */ - tp == &PyList_Type || - tp == &PyTuple_Type || - tp == &PyDict_Type || - tp == &PySet_Type || - tp == &PyFrozenSet_Type || - tp == &PyUnicode_Type || - tp == &PyBytes_Type || - - /* other builtins */ - tp == &PySlice_Type || - tp == Py_TYPE(Py_None) || - tp == Py_TYPE(Py_Ellipsis) || - tp == Py_TYPE(Py_NotImplemented) || - - PyModule_Check(tp) || - /* sentinel to swallow trailing || */ - false - ); + /* Basic number types */ + tp == &PyBool_Type || + + tp == &PyLong_Type || tp == &PyFloat_Type || tp == &PyComplex_Type || + + /* Basic sequence types */ + tp == &PyList_Type || tp == &PyTuple_Type || tp == &PyDict_Type || + tp == &PySet_Type || tp == &PyFrozenSet_Type || tp == &PyUnicode_Type || + tp == &PyBytes_Type || + + /* other builtins */ + tp == &PySlice_Type || tp == Py_TYPE(Py_None) || + tp == Py_TYPE(Py_Ellipsis) || tp == Py_TYPE(Py_NotImplemented) || + + PyModule_Check(tp) || + /* sentinel to swallow trailing || */ + false); } inline bool has_torch_function_attr(PyObject* obj) { // NOLINTNEXTLINE(clang-diagnostic-writable-strings) auto attr = PyObject_FastGetAttrString(obj, "__torch_function__"); return ( - attr.ptr() != nullptr && - attr.ptr() != torch::disabled_torch_function); + attr.ptr() != nullptr && attr.ptr() != torch::disabled_torch_function); } namespace torch { auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool { if (!ignore_mode && at::impl::PythonTorchFunctionTLS::get_mode()) return true; - PyTypeObject *tp = Py_TYPE(obj); + PyTypeObject* tp = Py_TYPE(obj); return ( - !THPVariable_CheckTypeExact(tp) && - !is_basic_python_type(tp) && - torch::torch_function_enabled() && - has_torch_function_attr(obj) - ); + !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) && + torch::torch_function_enabled() && has_torch_function_attr(obj)); } } // namespace torch @@ -248,7 +241,7 @@ inline bool sequence_has_torch_function(PyObject* args) { return false; } -inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) { +inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) { for (Py_ssize_t i = 0; i < nargs; i++) { if (torch::check_has_torch_function(args[i])) { return true; @@ -257,8 +250,8 @@ inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) { return false; } -PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) { - bool result; // NOLINT(cppcoreguidelines-init-variables) +PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) { + bool result; // NOLINT(cppcoreguidelines-init-variables) if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) { // Fast path: // If we know that we have a tuple or list, we can skip an INCREF and @@ -268,7 +261,7 @@ PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) { result = sequence_has_torch_function(arg); } else { auto args = py::reinterpret_steal( - PySequence_Fast(arg, "expected a sequence")); + PySequence_Fast(arg, "expected a sequence")); result = sequence_has_torch_function(args.ptr()); } @@ -278,7 +271,7 @@ PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) { Py_RETURN_FALSE; } -PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) { +PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj) { // Special case `THPModule_has_torch_function` for the single arg case. if (torch::check_has_torch_function(obj)) { Py_RETURN_TRUE; @@ -286,7 +279,10 @@ PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) { Py_RETURN_FALSE; } -PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) { +PyObject* THPModule_has_torch_function_variadic( + PyObject*, + PyObject* const* args, + Py_ssize_t nargs) { if (array_has_torch_function(args, nargs)) { Py_RETURN_TRUE; } diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 18ed7f47bef60f..8200eeb88e954d 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -1,35 +1,39 @@ #pragma once -#include #include #include +#include namespace torch { - // Sometimes we don't want infinite recursion for subclasses, - // Or a way to achieve the old behaviour. +// Sometimes we don't want infinite recursion for subclasses, +// Or a way to achieve the old behaviour. - // This is an internal utility, not exposed to users. - bool torch_function_enabled(); - PyObject* disabled_torch_function_impl(); - PyObject* disabled_torch_dispatch_impl(); - void set_disabled_torch_function_impl(PyObject* value); - void set_disabled_torch_dispatch_impl(PyObject* value); - // Set ignore_mode to true if you're trying to collect overloaded arguments; - // using mode here will improperly cause you to add ALL objects to the - // overloaded list even if they don't actually have __torch_function__ - bool check_has_torch_function(PyObject* obj, bool ignore_mode = false); +// This is an internal utility, not exposed to users. +bool torch_function_enabled(); +PyObject* disabled_torch_function_impl(); +PyObject* disabled_torch_dispatch_impl(); +void set_disabled_torch_function_impl(PyObject* value); +void set_disabled_torch_dispatch_impl(PyObject* value); +// Set ignore_mode to true if you're trying to collect overloaded arguments; +// using mode here will improperly cause you to add ALL objects to the +// overloaded list even if they don't actually have __torch_function__ +bool check_has_torch_function(PyObject* obj, bool ignore_mode = false); - struct DisableTorchDispatch { - DisableTorchDispatch() : guard_(c10::DispatchKey::Python), - guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {} - c10::impl::ExcludeDispatchKeyGuard guard_; - c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_; - }; -} +struct DisableTorchDispatch { + DisableTorchDispatch() + : guard_(c10::DispatchKey::Python), + guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {} + c10::impl::ExcludeDispatchKeyGuard guard_; + c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_; +}; +} // namespace torch -PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused); +PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); PyObject* THPModule_DisableTorchFunctionType(); -PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args); -PyObject* THPModule_disable_torch_dispatch(PyObject *self, PyObject *args); -PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg); -PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj); -PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs); +PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); +PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); +PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg); +PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj); +PyObject* THPModule_has_torch_function_variadic( + PyObject*, + PyObject* const* args, + Py_ssize_t nargs); diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 04a5d1b3e028d5..584860fbc229b6 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -17,7 +17,8 @@ void initThroughputBenchmarkBindings(PyObject* module) { .def_readwrite("num_worker_threads", &BenchmarkConfig::num_worker_threads) .def_readwrite("num_warmup_iters", &BenchmarkConfig::num_warmup_iters) .def_readwrite("num_iters", &BenchmarkConfig::num_iters) - .def_readwrite("profiler_output_path", &BenchmarkConfig::profiler_output_path); + .def_readwrite( + "profiler_output_path", &BenchmarkConfig::profiler_output_path); py::class_(m, "BenchmarkExecutionStats") .def_readonly("latency_avg_ms", &BenchmarkExecutionStats::latency_avg_ms) @@ -45,8 +46,6 @@ void initThroughputBenchmarkBindings(PyObject* module) { pybind11::gil_scoped_release no_gil_guard; return self.benchmark(config); }); - - } } // namespace throughput_benchmark diff --git a/torch/csrc/utils/invalid_arguments.cpp b/torch/csrc/utils/invalid_arguments.cpp index 2b955125dff37a..e76b9cf22ff509 100644 --- a/torch/csrc/utils/invalid_arguments.cpp +++ b/torch/csrc/utils/invalid_arguments.cpp @@ -6,37 +6,37 @@ #include #include -#include #include +#include namespace torch { namespace { -std::string py_typename(PyObject *object) { +std::string py_typename(PyObject* object) { return Py_TYPE(object)->tp_name; } struct Type { - virtual bool is_matching(PyObject *object) = 0; + virtual bool is_matching(PyObject* object) = 0; virtual ~Type() = default; }; -struct SimpleType: public Type { - SimpleType(std::string& name): name(name) {}; +struct SimpleType : public Type { + SimpleType(std::string& name) : name(name){}; - bool is_matching(PyObject *object) override { + bool is_matching(PyObject* object) override { return py_typename(object) == name; } std::string name; }; -struct MultiType: public Type { - MultiType(std::initializer_list accepted_types): - types(accepted_types) {}; +struct MultiType : public Type { + MultiType(std::initializer_list accepted_types) + : types(accepted_types){}; - bool is_matching(PyObject *object) override { + bool is_matching(PyObject* object) override { auto it = std::find(types.begin(), types.end(), py_typename(object)); return it != types.end(); } @@ -44,25 +44,27 @@ struct MultiType: public Type { std::vector types; }; -struct NullableType: public Type { - NullableType(std::unique_ptr type): type(std::move(type)) {}; +struct NullableType : public Type { + NullableType(std::unique_ptr type) : type(std::move(type)){}; - bool is_matching(PyObject *object) override { + bool is_matching(PyObject* object) override { return object == Py_None || type->is_matching(object); } std::unique_ptr type; }; -struct TupleType: public Type { - TupleType(std::vector> types): - types(std::move(types)) {}; +struct TupleType : public Type { + TupleType(std::vector> types) + : types(std::move(types)){}; - bool is_matching(PyObject *object) override { - if (!PyTuple_Check(object)) return false; + bool is_matching(PyObject* object) override { + if (!PyTuple_Check(object)) + return false; auto num_elements = PyTuple_GET_SIZE(object); - if (num_elements != (long)types.size()) return false; - for(const auto i : c10::irange(num_elements)) { + if (num_elements != (long)types.size()) + return false; + for (const auto i : c10::irange(num_elements)) { if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i))) return false; } @@ -72,14 +74,14 @@ struct TupleType: public Type { std::vector> types; }; -struct SequenceType: public Type { - SequenceType(std::unique_ptr type): - type(std::move(type)) {}; +struct SequenceType : public Type { + SequenceType(std::unique_ptr type) : type(std::move(type)){}; - bool is_matching(PyObject *object) override { - if (!PySequence_Check(object)) return false; + bool is_matching(PyObject* object) override { + if (!PySequence_Check(object)) + return false; auto num_elements = PySequence_Length(object); - for(const auto i : c10::irange(num_elements)) { + for (const auto i : c10::irange(num_elements)) { if (!type->is_matching(PySequence_GetItem(object, i))) return false; } @@ -90,35 +92,40 @@ struct SequenceType: public Type { }; struct Argument { - Argument(std::string name, std::unique_ptr type): - name(std::move(name)), type(std::move(type)) {}; + Argument(std::string name, std::unique_ptr type) + : name(std::move(name)), type(std::move(type)){}; std::string name; std::unique_ptr type; }; struct Option { - Option(std::vector arguments, bool is_variadic, bool has_out): - arguments(std::move(arguments)), is_variadic(is_variadic), has_out(has_out) {}; - Option(bool is_variadic, bool has_out): - arguments(), is_variadic(is_variadic), has_out(has_out) {}; + Option(std::vector arguments, bool is_variadic, bool has_out) + : arguments(std::move(arguments)), + is_variadic(is_variadic), + has_out(has_out){}; + Option(bool is_variadic, bool has_out) + : arguments(), is_variadic(is_variadic), has_out(has_out){}; Option(const Option&) = delete; - Option(Option&& other): - arguments(std::move(other.arguments)), is_variadic(other.is_variadic), - has_out(other.has_out) {}; + Option(Option&& other) + : arguments(std::move(other.arguments)), + is_variadic(other.is_variadic), + has_out(other.has_out){}; std::vector arguments; bool is_variadic; bool has_out; }; -std::vector _splitString(const std::string &s, const std::string& delim) { +std::vector _splitString( + const std::string& s, + const std::string& delim) { std::vector tokens; size_t start = 0; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) size_t end; - while((end = s.find(delim, start)) != std::string::npos) { - tokens.push_back(s.substr(start, end-start)); + while ((end = s.find(delim, start)) != std::string::npos) { + tokens.push_back(s.substr(start, end - start)); start = end + delim.length(); } tokens.push_back(s.substr(start)); @@ -135,7 +142,7 @@ std::unique_ptr _buildType(std::string type_name, bool is_nullable) { auto type_list = type_name.substr(6); type_list.pop_back(); std::vector> types; - for (auto& type: _splitString(type_list, ",")) + for (auto& type : _splitString(type_list, ",")) types.emplace_back(_buildType(type, false)); result = torch::make_unique(std::move(types)); } else if (type_name.find("sequence[") == 0) { @@ -150,26 +157,26 @@ std::unique_ptr _buildType(std::string type_name, bool is_nullable) { return result; } -std::pair _parseOption(const std::string& _option_str, - const std::unordered_map& kwargs) -{ +std::pair _parseOption( + const std::string& _option_str, + const std::unordered_map& kwargs) { if (_option_str == "no arguments") return std::pair(Option(false, false), _option_str); bool has_out = false; std::vector arguments; std::string printable_option = _option_str; - std::string option_str = _option_str.substr(1, _option_str.length()-2); + std::string option_str = _option_str.substr(1, _option_str.length() - 2); /// XXX: this is a hack only for the out arg in TensorMethods auto out_pos = printable_option.find('#'); if (out_pos != std::string::npos) { if (kwargs.count("out") > 0) { - std::string kwonly_part = printable_option.substr(out_pos+1); + std::string kwonly_part = printable_option.substr(out_pos + 1); printable_option.erase(out_pos); printable_option += "*, "; printable_option += kwonly_part; } else if (out_pos >= 2) { - printable_option.erase(out_pos-2); + printable_option.erase(out_pos - 2); printable_option += ")"; } else { printable_option.erase(out_pos); @@ -178,7 +185,7 @@ std::pair _parseOption(const std::string& _option_str, has_out = true; } - for (auto& arg: _splitString(option_str, ", ")) { + for (auto& arg : _splitString(option_str, ", ")) { bool is_nullable = false; auto type_start_idx = 0; if (arg[type_start_idx] == '#') { @@ -197,42 +204,38 @@ std::pair _parseOption(const std::string& _option_str, // ^ ^ auto dots_idx = arg.find("..."); if (dots_idx != std::string::npos) - type_end_idx -= 4; + type_end_idx -= 4; std::string type_name = - arg.substr(type_start_idx, type_end_idx-type_start_idx); - std::string name = - arg.substr(name_start_idx); + arg.substr(type_start_idx, type_end_idx - type_start_idx); + std::string name = arg.substr(name_start_idx); arguments.emplace_back(name, _buildType(type_name, is_nullable)); } bool is_variadic = option_str.find("...") != std::string::npos; return std::pair( - Option(std::move(arguments), is_variadic, has_out), - std::move(printable_option) - ); + Option(std::move(arguments), is_variadic, has_out), + std::move(printable_option)); } bool _argcountMatch( const Option& option, const std::vector& arguments, - const std::unordered_map& kwargs) -{ + const std::unordered_map& kwargs) { auto num_expected = option.arguments.size(); auto num_got = arguments.size() + kwargs.size(); // Note: variadic functions don't accept kwargs, so it's ok if (option.has_out && kwargs.count("out") == 0) num_expected--; return num_got == num_expected || - (option.is_variadic && num_got > num_expected); + (option.is_variadic && num_got > num_expected); } std::string _formattedArgDesc( const Option& option, const std::vector& arguments, - const std::unordered_map& kwargs) -{ + const std::unordered_map& kwargs) { std::string red; std::string reset_red; std::string green; @@ -251,9 +254,10 @@ std::string _formattedArgDesc( auto num_args = arguments.size() + kwargs.size(); std::string result = "("; - for(const auto i : c10::irange(num_args)) { + for (const auto i : c10::irange(num_args)) { bool is_kwarg = i >= arguments.size(); - PyObject *arg = is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i]; + PyObject* arg = + is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i]; bool is_matching = false; if (i < option.arguments.size()) { @@ -266,35 +270,37 @@ std::string _formattedArgDesc( result += green; else result += red; - if (is_kwarg) result += option.arguments[i].name + "="; + if (is_kwarg) + result += option.arguments[i].name + "="; result += py_typename(arg); if (is_matching) - result += reset_green; + result += reset_green; else - result += reset_red; + result += reset_red; result += ", "; } if (arguments.size() > 0) - result.erase(result.length()-2); + result.erase(result.length() - 2); result += ")"; return result; } -std::string _argDesc(const std::vector& arguments, - const std::unordered_map& kwargs) -{ +std::string _argDesc( + const std::vector& arguments, + const std::unordered_map& kwargs) { std::string result = "("; - for (auto& arg: arguments) + for (auto& arg : arguments) result += std::string(py_typename(arg)) + ", "; - for (auto& kwarg: kwargs) + for (auto& kwarg : kwargs) result += kwarg.first + "=" + py_typename(kwarg.second) + ", "; if (arguments.size() > 0) - result.erase(result.length()-2); + result.erase(result.length() - 2); result += ")"; return result; } -std::vector _tryMatchKwargs(const Option& option, +std::vector _tryMatchKwargs( + const Option& option, const std::unordered_map& kwargs) { std::vector unmatched; // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) @@ -303,7 +309,7 @@ std::vector _tryMatchKwargs(const Option& option, start_idx--; if (start_idx < 0) start_idx = 0; - for (auto& entry: kwargs) { + for (auto& entry : kwargs) { bool found = false; for (unsigned int i = start_idx; i < option.arguments.size(); i++) { if (option.arguments[i].name == entry.first) { @@ -320,19 +326,20 @@ std::vector _tryMatchKwargs(const Option& option, } // anonymous namespace std::string format_invalid_args( - PyObject *given_args, PyObject *given_kwargs, const std::string& function_name, - const std::vector& options) -{ - std::vector args; - std::unordered_map kwargs; + PyObject* given_args, + PyObject* given_kwargs, + const std::string& function_name, + const std::vector& options) { + std::vector args; + std::unordered_map kwargs; std::string error_msg; error_msg.reserve(2000); error_msg += function_name; error_msg += " received an invalid combination of arguments - "; Py_ssize_t num_args = PyTuple_Size(given_args); - for(const auto i : c10::irange(num_args)) { - PyObject *arg = PyTuple_GET_ITEM(given_args, i); + for (const auto i : c10::irange(num_args)) { + PyObject* arg = PyTuple_GET_ITEM(given_args, i); args.push_back(arg); } @@ -356,9 +363,9 @@ std::string format_invalid_args( unmatched_kwargs = _tryMatchKwargs(option, kwargs); if (unmatched_kwargs.size()) { error_msg += "got unrecognized keyword arguments: "; - for (auto& kwarg: unmatched_kwargs) + for (auto& kwarg : unmatched_kwargs) error_msg += kwarg + ", "; - error_msg.erase(error_msg.length()-2); + error_msg.erase(error_msg.length() - 2); } else { error_msg += "got "; if (_argcountMatch(option, args, kwargs)) { @@ -373,7 +380,7 @@ std::string format_invalid_args( error_msg += "got "; error_msg += _argDesc(args, kwargs); error_msg += ", but expected one of:\n"; - for (auto &option_str: options) { + for (auto& option_str : options) { auto pair = _parseOption(option_str, kwargs); auto& option = pair.first; auto& printable_option_str = pair.second; @@ -385,13 +392,15 @@ std::string format_invalid_args( if (has_kwargs) unmatched_kwargs = _tryMatchKwargs(option, kwargs); if (unmatched_kwargs.size() > 0) { - error_msg += " didn't match because some of the keywords were incorrect: "; - for (auto& kwarg: unmatched_kwargs) + error_msg += + " didn't match because some of the keywords were incorrect: "; + for (auto& kwarg : unmatched_kwargs) error_msg += kwarg + ", "; - error_msg.erase(error_msg.length()-2); + error_msg.erase(error_msg.length() - 2); error_msg += "\n"; } else { - error_msg += " didn't match because some of the arguments have invalid types: "; + error_msg += + " didn't match because some of the arguments have invalid types: "; error_msg += _formattedArgDesc(option, args, kwargs); error_msg += "\n"; } @@ -401,5 +410,4 @@ std::string format_invalid_args( return error_msg; } - } // namespace torch diff --git a/torch/csrc/utils/object_ptr.cpp b/torch/csrc/utils/object_ptr.cpp index c475be1d6db668..7faba811a1fa1a 100644 --- a/torch/csrc/utils/object_ptr.cpp +++ b/torch/csrc/utils/object_ptr.cpp @@ -2,7 +2,7 @@ #include -template<> +template <> void THPPointer::free() { if (ptr) Py_DECREF(ptr); diff --git a/torch/csrc/utils/object_ptr.h b/torch/csrc/utils/object_ptr.h index c27f35ad997cdb..5b4c1bc404457c 100644 --- a/torch/csrc/utils/object_ptr.h +++ b/torch/csrc/utils/object_ptr.h @@ -2,26 +2,55 @@ #include -template +template class THPPointer { -public: - THPPointer(): ptr(nullptr) {}; - explicit THPPointer(T *ptr) noexcept : ptr(ptr) {}; - THPPointer(THPPointer &&p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; }; + public: + THPPointer() : ptr(nullptr){}; + explicit THPPointer(T* ptr) noexcept : ptr(ptr){}; + THPPointer(THPPointer&& p) noexcept { + free(); + ptr = p.ptr; + p.ptr = nullptr; + }; - ~THPPointer() { free(); }; - T * get() { return ptr; } - const T * get() const { return ptr; } - T * release() { T *tmp = ptr; ptr = nullptr; return tmp; } - operator T*() { return ptr; } - THPPointer& operator =(T *new_ptr) noexcept { free(); ptr = new_ptr; return *this; } - THPPointer& operator =(THPPointer &&p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; return *this; } - T * operator ->() { return ptr; } - explicit operator bool() const { return ptr != nullptr; } + ~THPPointer() { + free(); + }; + T* get() { + return ptr; + } + const T* get() const { + return ptr; + } + T* release() { + T* tmp = ptr; + ptr = nullptr; + return tmp; + } + operator T*() { + return ptr; + } + THPPointer& operator=(T* new_ptr) noexcept { + free(); + ptr = new_ptr; + return *this; + } + THPPointer& operator=(THPPointer&& p) noexcept { + free(); + ptr = p.ptr; + p.ptr = nullptr; + return *this; + } + T* operator->() { + return ptr; + } + explicit operator bool() const { + return ptr != nullptr; + } -private: + private: void free(); - T *ptr = nullptr; + T* ptr = nullptr; }; /** diff --git a/torch/csrc/utils/out_types.cpp b/torch/csrc/utils/out_types.cpp index 0ceeb43bd1f814..4eead8119d660b 100644 --- a/torch/csrc/utils/out_types.cpp +++ b/torch/csrc/utils/out_types.cpp @@ -3,37 +3,54 @@ namespace torch { namespace utils { -// Used by python binding codegen to ensure any TensorOptions arguments are consistent -// with the out tensor's options -void check_out_type_matches(const at::Tensor& result, - at::ScalarType scalarType, bool scalarType_is_none, - c10::optional layout, - const at::Device& device, bool device_is_none) { - if (scalarType_is_none && !layout && device_is_none) { // common case +// Used by python binding codegen to ensure any TensorOptions arguments are +// consistent with the out tensor's options +void check_out_type_matches( + const at::Tensor& result, + at::ScalarType scalarType, + bool scalarType_is_none, + c10::optional layout, + const at::Device& device, + bool device_is_none) { + if (scalarType_is_none && !layout && device_is_none) { // common case return; } if (!scalarType_is_none && result.scalar_type() != scalarType) { AT_ERROR( - "dtype ", scalarType, - " does not match dtype of out parameter (", result.scalar_type(), ")"); + "dtype ", + scalarType, + " does not match dtype of out parameter (", + result.scalar_type(), + ")"); } auto scalarType_arg = scalarType_is_none ? result.scalar_type() : scalarType; - auto device_type_arg = device_is_none ? result.device().type() : device.type(); + auto device_type_arg = + device_is_none ? result.device().type() : device.type(); if (result.scalar_type() != scalarType_arg) { AT_ERROR( - "scalar type ", scalarType_arg, - " does not match scalar type of out parameter (", result.scalar_type(), ")"); + "scalar type ", + scalarType_arg, + " does not match scalar type of out parameter (", + result.scalar_type(), + ")"); } if (layout && result.layout() != *layout) { AT_ERROR( - "layout ", *layout, - " does not match layout of out parameter (", result.layout(), ")"); + "layout ", + *layout, + " does not match layout of out parameter (", + result.layout(), + ")"); } if (result.device().type() != device_type_arg) { AT_ERROR( - "device type ", device_type_arg, - " does not match device type of out parameter (", result.device().type(), ")"); + "device type ", + device_type_arg, + " does not match device type of out parameter (", + result.device().type(), + ")"); } } -}} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/out_types.h b/torch/csrc/utils/out_types.h index 47b588924d77a7..b3cf97f95bc181 100644 --- a/torch/csrc/utils/out_types.h +++ b/torch/csrc/utils/out_types.h @@ -7,8 +7,11 @@ namespace utils { TORCH_API void check_out_type_matches( const at::Tensor& result, - at::ScalarType scalarType, bool scalarType_is_none, + at::ScalarType scalarType, + bool scalarType_is_none, c10::optional layout, - const at::Device& device, bool device_is_none); + const at::Device& device, + bool device_is_none); -}} +} +} // namespace torch diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index d11cc98380a18d..335acb5e132201 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -9,10 +9,10 @@ #include #include +#include #include -#include #include -#include +#include #include #include @@ -27,7 +27,8 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr, true); PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr); PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr, true); -namespace pybind11 { namespace detail { +namespace pybind11 { +namespace detail { // torch.Tensor <-> at::Tensor conversions (without unwrapping) template <> @@ -45,8 +46,10 @@ struct type_caster { return false; } - static handle - cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) { + static handle cast( + const at::Tensor& src, + return_value_policy /* policy */, + handle /* parent */) { return handle(THPVariable_Wrap(src)); } }; @@ -67,8 +70,10 @@ struct type_caster { return false; } - static handle - cast(const at::Storage& src, return_value_policy /* policy */, handle /* parent */) { + static handle cast( + const at::Storage& src, + return_value_policy /* policy */, + handle /* parent */) { return handle(torch::createPyObject(src)); } }; @@ -88,30 +93,36 @@ struct type_caster { return false; } - static handle - cast(const at::Generator& src, return_value_policy /* policy */, handle /* parent */) { + static handle cast( + const at::Generator& src, + return_value_policy /* policy */, + handle /* parent */) { return handle(THPGenerator_Wrap(src)); } }; -template<> struct type_caster { -public: +template <> +struct type_caster { + public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef")); bool load(handle src, bool) { - PyObject *source = src.ptr(); + PyObject* source = src.ptr(); auto tuple = PyTuple_Check(source); if (tuple || PyList_Check(source)) { // NOLINTNEXTLINE(bugprone-branch-clone) - const auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); + const auto size = + tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); v_value.resize(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) + : PyList_GET_ITEM(source, idx); if (THPVariable_Check(obj)) { v_value[idx] = THPVariable_Unpack(obj).item(); } else if (PyLong_Check(obj)) { - // use THPUtils_unpackLong after it is safe to include python_numbers.h + // use THPUtils_unpackLong after it is safe to include + // python_numbers.h v_value[idx] = THPUtils_unpackLong(obj); } else { return false; @@ -122,10 +133,14 @@ template<> struct type_caster { } return false; } - static handle cast(at::IntArrayRef src, return_value_policy /* policy */, handle /* parent */) { + static handle cast( + at::IntArrayRef src, + return_value_policy /* policy */, + handle /* parent */) { return handle(THPUtils_packInt64Array(src.size(), src.data())); } -private: + + private: std::vector v_value; }; @@ -150,8 +165,10 @@ struct type_caster { return false; } - static handle - cast(const at::Device& src, return_value_policy /* policy */, handle /* parent */) { + static handle cast( + const at::Device& src, + return_value_policy /* policy */, + handle /* parent */) { return handle(THPDevice_New(src)); } }; @@ -160,7 +177,8 @@ struct type_caster { // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers template struct type_caster> : optional_caster> {}; -}} // namespace pybind11::detail +} // namespace detail +} // namespace pybind11 namespace torch { namespace impl { @@ -192,8 +210,10 @@ namespace impl { // violation (why does code that is ostensibly Python agnostic calling into the // GIL). // -// Adapted from https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510 -template inline void destroy_without_gil(T *ptr) { +// Adapted from +// https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510 +template +inline void destroy_without_gil(T* ptr) { // Because the ownership of a shared_ptr is diffuse, it's not possible to // necessarily predict whether or not the last reference to an object will // be destructed from Python or C++. This means that in the destructor here, diff --git a/torch/csrc/utils/pycfunction_helpers.h b/torch/csrc/utils/pycfunction_helpers.h index 32cfeb5999d5dc..fa1016f9a7e786 100644 --- a/torch/csrc/utils/pycfunction_helpers.h +++ b/torch/csrc/utils/pycfunction_helpers.h @@ -4,5 +4,5 @@ inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) { // NOLINTNEXTLINE(modernize-redundant-void-arg) - return (PyCFunction)(void(*)(void))func; + return (PyCFunction)(void (*)(void))func; } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 6d9fe5a5c9a255..92ac9608bda1de 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -23,32 +23,32 @@ namespace torch { static std::unordered_map type_map = { - {"Tensor", ParameterType::TENSOR}, - {"Scalar", ParameterType::SCALAR}, - {"int64_t", ParameterType::INT64}, - {"double", ParameterType::DOUBLE}, - {"complex", ParameterType::COMPLEX}, - {"TensorList", ParameterType::TENSOR_LIST}, - {"c10::List>", ParameterType::TENSOR_LIST}, - {"IntArrayRef", ParameterType::INT_LIST}, - {"ArrayRef", ParameterType::FLOAT_LIST}, - {"Generator", ParameterType::GENERATOR}, - {"bool", ParameterType::BOOL}, - {"Storage", ParameterType::STORAGE}, - {"PyObject*", ParameterType::PYOBJECT}, - {"ScalarType", ParameterType::SCALARTYPE}, - {"Layout", ParameterType::LAYOUT}, - {"MemoryFormat", ParameterType::MEMORY_FORMAT}, - {"QScheme", ParameterType::QSCHEME}, - {"Device", ParameterType::DEVICE}, - {"Stream", ParameterType::STREAM}, - {"std::string", ParameterType::STRING}, - {"c10::string_view", ParameterType::STRING}, - {"SymInt", ParameterType::SYM_INT}, - {"Dimname", ParameterType::DIMNAME}, - {"SymIntArrayRef", ParameterType::SYM_INT_LIST}, - {"DimnameList", ParameterType::DIMNAME_LIST}, - {"ScalarList", ParameterType::SCALAR_LIST}, + {"Tensor", ParameterType::TENSOR}, + {"Scalar", ParameterType::SCALAR}, + {"int64_t", ParameterType::INT64}, + {"double", ParameterType::DOUBLE}, + {"complex", ParameterType::COMPLEX}, + {"TensorList", ParameterType::TENSOR_LIST}, + {"c10::List>", ParameterType::TENSOR_LIST}, + {"IntArrayRef", ParameterType::INT_LIST}, + {"ArrayRef", ParameterType::FLOAT_LIST}, + {"Generator", ParameterType::GENERATOR}, + {"bool", ParameterType::BOOL}, + {"Storage", ParameterType::STORAGE}, + {"PyObject*", ParameterType::PYOBJECT}, + {"ScalarType", ParameterType::SCALARTYPE}, + {"Layout", ParameterType::LAYOUT}, + {"MemoryFormat", ParameterType::MEMORY_FORMAT}, + {"QScheme", ParameterType::QSCHEME}, + {"Device", ParameterType::DEVICE}, + {"Stream", ParameterType::STREAM}, + {"std::string", ParameterType::STRING}, + {"c10::string_view", ParameterType::STRING}, + {"SymInt", ParameterType::SYM_INT}, + {"Dimname", ParameterType::DIMNAME}, + {"SymIntArrayRef", ParameterType::SYM_INT_LIST}, + {"DimnameList", ParameterType::DIMNAME_LIST}, + {"ScalarList", ParameterType::SCALAR_LIST}, }; // Default arg name translations for compatibility with NumPy. @@ -60,17 +60,18 @@ static std::unordered_map type_map = { // ``` // // A vector is necessary, because we might need to try multiple values. -// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input tensor. -// Rather than annotate each function separately with whether it should take "x" or "a", -// just try both. +// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input +// tensor. Rather than annotate each function separately with whether it should +// take "x" or "a", just try both. // // TODO: Allow individual functions to specify non-default translations: // For example, `torch.pow` should translate "exponent" to "x2". -static const std::unordered_map> numpy_compatibility_arg_names = { - {"dim", {"axis"}}, - {"keepdim", {"keepdims"}}, - {"input", {"x", "a", "x1"}}, - {"other", {"x2"}}, +static const std::unordered_map> + numpy_compatibility_arg_names = { + {"dim", {"axis"}}, + {"keepdim", {"keepdims"}}, + {"input", {"x", "a", "x1"}}, + {"other", {"x2"}}, }; // TODO: remove this. This is a temporary list of functions that allow Python @@ -83,27 +84,25 @@ static const std::unordered_map> numpy_com // functions.) static bool should_allow_numbers_as_tensors(const std::string& name) { static std::unordered_set allowed = { - "add", "add_", "add_out", - "div", "div_", "div_out", - "divide", "divide_", "divide_out", // alias of div - "mul", "mul_", "mul_out", - "multiply", "multiply_", "multiply_out", // alias of mul - "sub", "sub_", "sub_out", - "subtract", "subtract_", "subtract_out", // alias of sub - "true_divide", "true_divide_", "true_divide_out", - "floor_divide", "floor_divide_", "floor_divide_out" - }; + "add", "add_", "add_out", + "div", "div_", "div_out", + "divide", "divide_", "divide_out", // alias of div + "mul", "mul_", "mul_out", + "multiply", "multiply_", "multiply_out", // alias of mul + "sub", "sub_", "sub_out", + "subtract", "subtract_", "subtract_out", // alias of sub + "true_divide", "true_divide_", "true_divide_out", + "floor_divide", "floor_divide_", "floor_divide_out"}; return allowed.find(name) != allowed.end(); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) - : optional(false) - , allow_none(false) - , keyword_only(keyword_only) - , size(0) - , default_scalar(0) -{ + : optional(false), + allow_none(false), + keyword_only(keyword_only), + size(0), + default_scalar(0) { auto space = fmt.find(' '); if (space == std::string::npos) { throw std::runtime_error("FunctionParameter(): missing type: " + fmt); @@ -120,7 +119,8 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) // Parse and remove brackets from type_str auto bracket = type_str.find('['); if (bracket != std::string::npos) { - auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2); + auto size_str = + type_str.substr(bracket + 1, type_str.length() - bracket - 2); size = atoi(size_str.c_str()); type_str = type_str.substr(0, bracket); } @@ -128,7 +128,8 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) auto name_str = fmt.substr(space + 1); auto it = type_map.find(type_str); if (it == type_map.end()) { - throw std::runtime_error("FunctionParameter(): invalid type string: " + type_str); + throw std::runtime_error( + "FunctionParameter(): invalid type string: " + type_str); } type_ = it->second; @@ -143,38 +144,60 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) python_name = THPUtils_internString(name); auto np_compat_it = numpy_compatibility_arg_names.find(name); if (np_compat_it != numpy_compatibility_arg_names.end()) { - for (const auto& str: np_compat_it->second) { + for (const auto& str : np_compat_it->second) { numpy_python_names.push_back(THPUtils_internString(str)); } } } -auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { - py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); +auto handle_torch_function_getter( + THPVariable* self, + const std::string& property_name) -> PyObject* { + py::object torch_api = PyObject_FastGetAttrString( + THPVariableClass, (char*)property_name.c_str()); std::string module_name = "torch.Tensor." + property_name; - return handle_torch_function((PyObject *)self, "__get__", nullptr, nullptr, torch_api.ptr(), module_name); + return handle_torch_function( + (PyObject*)self, + "__get__", + nullptr, + nullptr, + torch_api.ptr(), + module_name); } -auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { - py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); +auto handle_torch_function_setter( + THPVariable* self, + const std::string& property_name, + PyObject* value) -> int { + py::object torch_api = PyObject_FastGetAttrString( + THPVariableClass, (char*)property_name.c_str()); std::string module_name = "torch.Tensor." + property_name; - if (value != nullptr) - { + if (value != nullptr) { py::tuple args_ = py::make_tuple(py::handle(value)); - handle_torch_function((PyObject *)self, "__set__", args_.ptr(), nullptr, torch_api.ptr(), module_name); - } - else { - handle_torch_function((PyObject *)self, "__delete__", nullptr, nullptr, torch_api.ptr(), module_name); + handle_torch_function( + (PyObject*)self, + "__set__", + args_.ptr(), + nullptr, + torch_api.ptr(), + module_name); + } else { + handle_torch_function( + (PyObject*)self, + "__delete__", + nullptr, + nullptr, + torch_api.ptr(), + module_name); } return 0; } // Combines self and args into one tuple. -auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { +auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple { if (args == nullptr) { return py::make_tuple(py::handle(self)); - } - else if (self == nullptr) { + } else if (self == nullptr) { return py::reinterpret_borrow(args); } @@ -182,8 +205,8 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { size_t n = py_args.size(); auto args_ = py::tuple(n + 1); args_[0] = py::handle(self); - for(const auto i : c10::irange(n)) { - args_[i+1] = py_args[i]; + for (const auto i : c10::irange(n)) { + args_[i + 1] = py_args[i]; } return args_; } @@ -196,11 +219,26 @@ auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { // be passed around in this way). const char* torch_function_mode_name = "__torch_function__"; -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* kwargs, PyObject* torch_api, const std::string& module_name) -> PyObject* { - py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); - TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); +auto handle_torch_function( + PyObject* self, + const std::string& func_name, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const std::string& module_name) -> PyObject* { + py::object torch_api_function = + PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); + TORCH_INTERNAL_ASSERT( + torch_api_function.ptr() != nullptr, "torch API function must exist"); py::tuple args_ = combine_self_args(self, args); - return handle_torch_function_no_python_arg_parser({py::handle(self)}, args_.ptr(), kwargs, func_name.c_str(), torch_api_function.ptr(), module_name.c_str(), TorchFunctionName::TorchFunction); + return handle_torch_function_no_python_arg_parser( + {py::handle(self)}, + args_.ptr(), + kwargs, + func_name.c_str(), + torch_api_function.ptr(), + module_name.c_str(), + TorchFunctionName::TorchFunction); } // Note: [Overloaded args] @@ -243,14 +281,18 @@ auto handle_torch_function_no_python_arg_parser( // necessarily types std::vector overloaded_types; overloaded_types.reserve(overloaded_args.size()); - for (auto &arg : overloaded_args) { - overloaded_types.push_back(py::reinterpret_borrow(get_type_of_overloaded_arg(arg.ptr()))); + for (auto& arg : overloaded_args) { + overloaded_types.push_back(py::reinterpret_borrow( + get_type_of_overloaded_arg(arg.ptr()))); } py::tuple py_types = py::cast(overloaded_types); py::object ret; PyObject* mode_obj = nullptr; - const bool is_torch_function = torch_function_name == TorchFunctionName::TorchFunction; - const auto& maybe_mode = is_torch_function ? at::impl::PythonTorchFunctionTLS::get_mode() : at::impl::TorchDispatchModeTLS::get_state(); + const bool is_torch_function = + torch_function_name == TorchFunctionName::TorchFunction; + const auto& maybe_mode = is_torch_function + ? at::impl::PythonTorchFunctionTLS::get_mode() + : at::impl::TorchDispatchModeTLS::get_state(); if (maybe_mode) { mode_obj = maybe_mode->ptr(getPyInterpreter()); TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr); @@ -261,14 +303,19 @@ auto handle_torch_function_no_python_arg_parser( at::optional td_g; if (is_torch_function) { tf_g.emplace(); - } else{ + } else { td_g.emplace(); } // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below // because the nullptr terminates the argument list ick ick ick. if (kwargs == nullptr) { ret = py::reinterpret_steal(PyObject_CallMethod( - mode_obj, torch_function_name_str, "OOO", torch_api_function, py_types.ptr(), args)); + mode_obj, + torch_function_name_str, + "OOO", + torch_api_function, + py_types.ptr(), + args)); } else { ret = py::reinterpret_steal(PyObject_CallMethod( mode_obj, @@ -290,10 +337,14 @@ auto handle_torch_function_no_python_arg_parser( PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str); // See https://github.com/pytorch/pytorch/issues/63767 - if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(arg) && + if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__") + .is(arg) && torch_function.ptr() != torch::disabled_torch_function_impl()) { - TORCH_WARN("Defining your `", torch_function_name_str, "` as a plain method is deprecated ", - "and will be an error in future, please define it as a classmethod."); + TORCH_WARN( + "Defining your `", + torch_function_name_str, + "` as a plain method is deprecated ", + "and will be an error in future, please define it as a classmethod."); } ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( @@ -315,14 +366,13 @@ auto handle_torch_function_no_python_arg_parser( // if an exception occurred in a user's implementation of // __torch_function__, throw it throw python_error(); - } - else if (ret.ptr() == Py_NotImplemented) { + } else if (ret.ptr() == Py_NotImplemented) { // all __torch_function__ implementations in overloaded_args // returned NotImplemented, so we raise a TypeError. std::stringstream ss; ss << "no implementation found for '" << module_name << "." << func_name << "' on types that implement " << torch_function_name_str << ": ["; - for (auto &arg : overloaded_args) { + for (auto& arg : overloaded_args) { ss << py::repr(get_type_of_overloaded_arg(arg.ptr())); if (!arg.is(overloaded_args.back())) { ss << ", "; @@ -347,53 +397,82 @@ auto handle_torch_function_no_python_arg_parser( return ret.release().ptr(); } -auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name, const char* func_name_override) -> PyObject* { +auto handle_torch_function( + PythonArgs& r, + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString( - torch_api, - (char*)( - func_name_override ? func_name_override : r.get_func_name().c_str() - ) - ); - TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); + torch_api, + (char*)(func_name_override ? func_name_override : r.get_func_name().c_str())); + TORCH_INTERNAL_ASSERT( + torch_api_function.ptr() != nullptr, "torch API function must exist"); py::object ret; py::tuple args_ = combine_self_args(self, args); // overloaded_args already all have unique types std::vector overloaded_types; overloaded_types.reserve(r.signature.overloaded_args.size()); - for (auto &arg : r.signature.overloaded_args) { - overloaded_types.push_back(py::reinterpret_borrow((PyObject *) Py_TYPE(arg.ptr()))); + for (auto& arg : r.signature.overloaded_args) { + overloaded_types.push_back( + py::reinterpret_borrow((PyObject*)Py_TYPE(arg.ptr()))); } py::tuple py_types = py::cast(overloaded_types); - return handle_torch_function_no_python_arg_parser(r.signature.overloaded_args, args_.ptr(), kwargs, r.get_func_name().c_str(), torch_api_function.ptr(), module_name); + return handle_torch_function_no_python_arg_parser( + r.signature.overloaded_args, + args_.ptr(), + kwargs, + r.get_func_name().c_str(), + torch_api_function.ptr(), + module_name); } -auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name, const char* func_name_override) -> PyObject* -{ - return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name, func_name_override); +auto handle_torch_function( + PythonArgs& r, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override) -> PyObject* { + return handle_torch_function( + r, nullptr, args, kwargs, torch_api, module_name, func_name_override); } -auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val) -> PyObject* { - const char *func_name = (val == nullptr) ? "__getitem__" : "__setitem__"; +auto handle_torch_function_indexing( + PyObject* self, + PyObject* index, + PyObject* val) -> PyObject* { + const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__"; py::object index_tup; if (PyTuple_Check(index)) { index_tup = py::reinterpret_borrow(index); - } - else { + } else { index_tup = py::make_tuple(py::handle(index)); } std::vector overridable_args; is_tensor_and_append_overloaded(self, &overridable_args); - auto size = PyTuple_GET_SIZE(index_tup.ptr()); + auto size = PyTuple_GET_SIZE(index_tup.ptr()); for (auto i : c10::irange(size)) { - auto *obj = PyTuple_GetItem(index_tup.ptr(), i); + auto* obj = PyTuple_GetItem(index_tup.ptr(), i); is_tensor_and_append_overloaded(obj, &overridable_args); } if (val != nullptr) { is_tensor_and_append_overloaded(val, &overridable_args); } - py::object func = PyObject_FastGetAttrString(THPVariableClass, (char *)func_name); - py::object args = (val == nullptr) ? py::make_tuple(py::handle(self), py::handle(index)) : py::make_tuple(py::handle(self), py::handle(index), py::handle(val)); - return handle_torch_function_no_python_arg_parser(overridable_args, args.ptr(), nullptr, func_name, func.ptr(), "torch.Tensor"); + py::object func = + PyObject_FastGetAttrString(THPVariableClass, (char*)func_name); + py::object args = (val == nullptr) + ? py::make_tuple(py::handle(self), py::handle(index)) + : py::make_tuple(py::handle(self), py::handle(index), py::handle(val)); + return handle_torch_function_no_python_arg_parser( + overridable_args, + args.ptr(), + nullptr, + func_name, + func.ptr(), + "torch.Tensor"); } /* @@ -436,10 +515,13 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v * */ -static void append_overloaded_arg(std::vector* overloaded_args, PyObject* obj, bool obj_is_type) { +static void append_overloaded_arg( + std::vector* overloaded_args, + PyObject* obj, + bool obj_is_type) { bool class_not_seen_yet = true; PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj); - for (auto &arg : *overloaded_args) { + for (auto& arg : *overloaded_args) { if (obj_type == get_type_of_overloaded_arg(arg.ptr())) { // obj is the same type as another parameter we've seen in a prior // iteration of the loop over parameters so we already have an entry @@ -451,8 +533,11 @@ static void append_overloaded_arg(std::vector* overloaded_args, PyOb } if (class_not_seen_yet) { int arg_index = overloaded_args->size(); - for(const auto j : c10::irange(arg_index)) { - if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) { + for (const auto j : c10::irange(arg_index)) { + if (PyObject_IsSubclass( + obj_type, + (PyObject*)(get_type_of_overloaded_arg( + (*overloaded_args)[j].ptr())))) { // obj is a subclass of another object we've seen already so its // __torch_function__ should be called first, therefore we // insert it into overloaded_args before the superclass @@ -467,15 +552,21 @@ static void append_overloaded_arg(std::vector* overloaded_args, PyOb } } -void append_overloaded_tensor(std::vector* overloaded_args, PyObject* obj) { - append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false); +void append_overloaded_tensor( + std::vector* overloaded_args, + PyObject* obj) { + append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false); } -void append_overloaded_type(std::vector* overloaded_args, PyObject* obj) { - append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true); +void append_overloaded_type( + std::vector* overloaded_args, + PyObject* obj) { + append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true); } -bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* overloaded_args) { +bool is_tensor_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args) { if (THPVariable_CheckExact(obj)) { // torch.Tensor instances (not subclasses, except for Parameter) return true; @@ -501,7 +592,8 @@ bool is_scalar_list(PyObject* obj) { // NOLINTNEXTLINE(bugprone-branch-clone) const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); for (const auto idx : c10::irange(size)) { - PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + PyObject* iobj = + tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); if (!THPUtils_checkScalar(iobj)) { return false; } @@ -509,19 +601,27 @@ bool is_scalar_list(PyObject* obj) { return true; } -bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error) { +bool is_tensor_list_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args, + int argnum, + bool throw_error) { auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; } // NOLINTNEXTLINE(bugprone-branch-clone) -const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); + const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); for (long idx = 0; idx < size; idx++) { - PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + PyObject* iobj = + tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) { if (throw_error) { - throw TypeError("expected Tensor as element %d in argument %d, but got %s", - static_cast(idx), argnum, Py_TYPE(iobj)->tp_name); + throw TypeError( + "expected Tensor as element %d in argument %d, but got %s", + static_cast(idx), + argnum, + Py_TYPE(iobj)->tp_name); } return false; } @@ -552,19 +652,18 @@ static bool is_int_list_(PyObject* obj, int broadcast_size) { if (PySequence_Size(obj) == 0) { return true; } - auto item = py::reinterpret_steal( - PySequence_GetItem(obj, 0)); + auto item = py::reinterpret_steal(PySequence_GetItem(obj, 0)); if (THPUtils_checkIndex(item.ptr())) { return true; } // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints // in an intlist argument. Even float or complex scalar tensors. return ( - jit::tracer::isTracing() && - THPVariable_Check(item.ptr()) && + jit::tracer::isTracing() && THPVariable_Check(item.ptr()) && THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{}); } - // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int + // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single + // int return broadcast_size > 0 && THPUtils_checkLong(obj); } @@ -574,7 +673,7 @@ static bool is_int_list(PyObject* obj, int broadcast_size) { static bool is_int_or_symint(PyObject* obj) { if (THPUtils_checkLong(obj)) { - return true; + return true; } // TODO: test if it's the Python binding for SymbolicIntNode @@ -587,8 +686,10 @@ static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) { } // argnum is needed for raising the TypeError, it's used in the error message. -auto FunctionParameter::check(PyObject* obj, std::vector &overloaded_args, int argnum) -> bool -{ +auto FunctionParameter::check( + PyObject* obj, + std::vector& overloaded_args, + int argnum) -> bool { switch (type_) { case ParameterType::TENSOR: { if (is_tensor_and_append_overloaded(obj, &overloaded_args)) { @@ -618,37 +719,54 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded } if (THPVariable_Check(obj)) { const auto& var = THPVariable_Unpack(obj); - return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0; + return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && + !var.requires_grad() && var.dim() == 0; } return false; } - case ParameterType::DIMNAME: return THPUtils_checkDimname(obj); + case ParameterType::DIMNAME: + return THPUtils_checkDimname(obj); case ParameterType::DIMNAME_LIST: { if (THPUtils_checkDimnameList(obj)) { return true; } - // if a size is specified (e.g. DimnameList[1]) we also allow passing a single Dimname + // if a size is specified (e.g. DimnameList[1]) we also allow passing a + // single Dimname return size == 1 && THPUtils_checkDimname(obj); } case ParameterType::TENSOR_LIST: { - return is_tensor_list_and_append_overloaded(obj, &overloaded_args, argnum, true /* throw_error */); + return is_tensor_list_and_append_overloaded( + obj, &overloaded_args, argnum, true /* throw_error */); } - case ParameterType::INT_LIST: return is_int_list(obj, size); - case ParameterType::FLOAT_LIST: return is_float_or_complex_list(obj); - case ParameterType::GENERATOR: return THPGenerator_Check(obj); - case ParameterType::BOOL: return PyBool_Check(obj); - case ParameterType::STORAGE: return isStorage(obj); - case ParameterType::PYOBJECT: return true; - case ParameterType::SCALARTYPE: return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); - case ParameterType::LAYOUT: return THPLayout_Check(obj); - case ParameterType::MEMORY_FORMAT: return THPMemoryFormat_Check(obj); - case ParameterType::QSCHEME: return THPQScheme_Check(obj); + case ParameterType::INT_LIST: + return is_int_list(obj, size); + case ParameterType::FLOAT_LIST: + return is_float_or_complex_list(obj); + case ParameterType::GENERATOR: + return THPGenerator_Check(obj); + case ParameterType::BOOL: + return PyBool_Check(obj); + case ParameterType::STORAGE: + return isStorage(obj); + case ParameterType::PYOBJECT: + return true; + case ParameterType::SCALARTYPE: + return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); + case ParameterType::LAYOUT: + return THPLayout_Check(obj); + case ParameterType::MEMORY_FORMAT: + return THPMemoryFormat_Check(obj); + case ParameterType::QSCHEME: + return THPQScheme_Check(obj); case ParameterType::DEVICE: - return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj); + return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || + THPDevice_Check(obj); case ParameterType::STREAM: return THPStream_Check(obj); - case ParameterType::STRING: return THPUtils_checkString(obj); - default: throw std::runtime_error("unknown parameter type"); + case ParameterType::STRING: + return THPUtils_checkString(obj); + default: + throw std::runtime_error("unknown parameter type"); case ParameterType::SCALAR_LIST: { return is_scalar_list(obj); } @@ -663,30 +781,54 @@ auto FunctionParameter::check(PyObject* obj, std::vector &overloaded std::string FunctionParameter::type_name() const { switch (type_) { - case ParameterType::TENSOR: return "Tensor"; - case ParameterType::SCALAR: return "Number"; - case ParameterType::INT64: return "int"; - case ParameterType::SYM_INT: return "SymInt"; - case ParameterType::DOUBLE: return "float"; - case ParameterType::COMPLEX: return "complex"; - case ParameterType::TENSOR_LIST: return "tuple of Tensors"; - case ParameterType::INT_LIST: return "tuple of ints"; - case ParameterType::FLOAT_LIST: return "tuple of floats"; - case ParameterType::GENERATOR: return "torch.Generator"; - case ParameterType::BOOL: return "bool"; - case ParameterType::STORAGE: return "torch.Storage"; - case ParameterType::PYOBJECT: return "object"; - case ParameterType::SCALARTYPE: return "torch.dtype"; - case ParameterType::LAYOUT: return "torch.layout"; - case ParameterType::MEMORY_FORMAT: return "torch.memory_format"; - case ParameterType::QSCHEME: return "torch.qscheme"; - case ParameterType::DEVICE: return "torch.device"; - case ParameterType::STRING: return "str"; - case ParameterType::DIMNAME: return "name"; - case ParameterType::DIMNAME_LIST: return "tuple of names"; - case ParameterType::SCALAR_LIST: return "tuple of Scalars"; - case ParameterType::SYM_INT_LIST: return "tuple of SymInts"; - default: throw std::runtime_error("unknown parameter type"); + case ParameterType::TENSOR: + return "Tensor"; + case ParameterType::SCALAR: + return "Number"; + case ParameterType::INT64: + return "int"; + case ParameterType::SYM_INT: + return "SymInt"; + case ParameterType::DOUBLE: + return "float"; + case ParameterType::COMPLEX: + return "complex"; + case ParameterType::TENSOR_LIST: + return "tuple of Tensors"; + case ParameterType::INT_LIST: + return "tuple of ints"; + case ParameterType::FLOAT_LIST: + return "tuple of floats"; + case ParameterType::GENERATOR: + return "torch.Generator"; + case ParameterType::BOOL: + return "bool"; + case ParameterType::STORAGE: + return "torch.Storage"; + case ParameterType::PYOBJECT: + return "object"; + case ParameterType::SCALARTYPE: + return "torch.dtype"; + case ParameterType::LAYOUT: + return "torch.layout"; + case ParameterType::MEMORY_FORMAT: + return "torch.memory_format"; + case ParameterType::QSCHEME: + return "torch.qscheme"; + case ParameterType::DEVICE: + return "torch.device"; + case ParameterType::STRING: + return "str"; + case ParameterType::DIMNAME: + return "name"; + case ParameterType::DIMNAME_LIST: + return "tuple of names"; + case ParameterType::SCALAR_LIST: + return "tuple of Scalars"; + case ParameterType::SYM_INT_LIST: + return "tuple of SymInts"; + default: + throw std::runtime_error("unknown parameter type"); } } @@ -694,7 +836,7 @@ static inline c10::optional parse_as_integer(const std::string& s) { if (s.empty()) return c10::nullopt; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char *str_end; + char* str_end; long ans = strtol(s.c_str(), &str_end, 0); // *str_end == 0 if the entire string was parsed as an integer. return (*str_end == 0) ? c10::optional(ans) : c10::nullopt; @@ -705,12 +847,16 @@ Parse default value of IntArrayRef declared at native_functions.yaml There are two kinds of default values: 1. IntArrayRef[2] x=1 (where size=2, value={1,1} -2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args) +2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be +space after comma since native_parse.py uses ', ' to split args) */ -static inline std::vector parse_intlist_args(const std::string& s, int64_t size) { +static inline std::vector parse_intlist_args( + const std::string& s, + int64_t size) { size_t n = s.size(); - if (s.empty()) return std::vector(); + if (s.empty()) + return std::vector(); // case 1. s is an int (e.g., s=2) if (s[0] != '{') { @@ -719,14 +865,18 @@ static inline std::vector parse_intlist_args(const std::string& s, int6 // case 2. s is a list of dims (e.g., s={1,2}) - // since already checked left brace '{' above, here only checks right brace '}' - TORCH_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]); + // since already checked left brace '{' above, here only checks right brace + // '}' + TORCH_CHECK( + s[n - 1] == '}', + "Default value of IntArrayRef is missing right brace '}', found ", + s[n - 1]); auto args = std::vector(); std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}' std::string tok; - while(std::getline(ss, tok, ',')) { + while (std::getline(ss, tok, ',')) { args.emplace_back(std::stol(tok)); } return args; @@ -737,11 +887,13 @@ static std::string parse_string_literal(c10::string_view str) { TORCH_CHECK(str.length() >= 2, "String defaults must be quoted"); if (str.front() == '"') { - TORCH_CHECK(str.back() == '"', - "Mismatched quotes in string default: ", str); + TORCH_CHECK( + str.back() == '"', "Mismatched quotes in string default: ", str); } else { - TORCH_CHECK(str.front() == '\'' && str.back() == '\'', - "Invalid quotes in string default: ", str) + TORCH_CHECK( + str.front() == '\'' && str.back() == '\'', + "Invalid quotes in string default: ", + str) } std::string parsed; @@ -754,7 +906,8 @@ static std::string parse_string_literal(c10::string_view str) { } // Handle escape sequences - TORCH_CHECK(i < str.size() - 2, "String ends with escaped final quote: ", str) + TORCH_CHECK( + i < str.size() - 2, "String ends with escaped final quote: ", str) char c = str[i + 1]; switch (c) { case '\\': @@ -780,7 +933,10 @@ static std::string parse_string_literal(c10::string_view str) { c = '\t'; break; default: - TORCH_CHECK(false, "Unsupported escape sequence in string default: \\", str[i + 1]); + TORCH_CHECK( + false, + "Unsupported escape sequence in string default: \\", + str[i + 1]); } parsed.push_back(c); i += 2; @@ -794,7 +950,8 @@ void FunctionParameter::set_default_str(const std::string& str) { } if (type_ == ParameterType::TENSOR) { if (str != "None") { - throw std::runtime_error("default value for Tensor must be none, got: " + str); + throw std::runtime_error( + "default value for Tensor must be none, got: " + str); } } else if (type_ == ParameterType::INT64) { default_int = atol(str.c_str()); @@ -809,8 +966,8 @@ void FunctionParameter::set_default_str(const std::string& str) { if (str != "None") { // we sometimes rely on integer-vs-float values, e.g. with arange. const auto as_integer = parse_as_integer(str); - default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) : - at::Scalar(atof(str.c_str())); + default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) + : at::Scalar(atof(str.c_str())); } } else if (type_ == ParameterType::INT_LIST) { if (str != "None") { @@ -855,13 +1012,12 @@ void FunctionParameter::set_default_str(const std::string& str) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FunctionSignature::FunctionSignature(const std::string& fmt, int index) - : min_args(0) - , max_args(0) - , max_pos_args(0) - , index(index) - , hidden(false) - , deprecated(false) -{ + : min_args(0), + max_args(0), + max_pos_args(0), + index(index), + hidden(false), + deprecated(false) { auto open_paren = fmt.find('('); if (open_paren == std::string::npos) { throw std::runtime_error("missing opening parenthesis: " + fmt); @@ -880,7 +1036,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index) if (offset == std::string::npos) { offset = fmt.find(')', last_offset); done = true; - next_offset = offset+ 1; + next_offset = offset + 1; // this 'if' happens for an empty parameter list, i.e. fn(). if (offset == last_offset) { last_offset = next_offset; @@ -928,7 +1084,8 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index) } std::string FunctionSignature::toString() const { - // TODO: consider printing more proper schema strings with defaults, optionals, etc. + // TODO: consider printing more proper schema strings with defaults, + // optionals, etc. std::ostringstream ss; bool keyword_already = false; ss << "("; @@ -948,23 +1105,32 @@ std::string FunctionSignature::toString() const { return ss.str(); } -[[noreturn]] -static void extra_args(const FunctionSignature& signature, Py_ssize_t nargs) { +[[noreturn]] static void extra_args( + const FunctionSignature& signature, + Py_ssize_t nargs) { const long max_pos_args = signature.max_pos_args; const long min_args = signature.min_args; const long nargs_ = nargs; if (min_args != max_pos_args) { - throw TypeError("%s() takes from %ld to %ld positional arguments but %ld were given", - signature.name.c_str(), min_args, max_pos_args, nargs_); - } - throw TypeError("%s() takes %ld positional argument%s but %ld %s given", + throw TypeError( + "%s() takes from %ld to %ld positional arguments but %ld were given", + signature.name.c_str(), + min_args, + max_pos_args, + nargs_); + } + throw TypeError( + "%s() takes %ld positional argument%s but %ld %s given", signature.name.c_str(), - max_pos_args, max_pos_args == 1 ? "" : "s", - nargs_, nargs == 1 ? "was" : "were"); + max_pos_args, + max_pos_args == 1 ? "" : "s", + nargs_, + nargs == 1 ? "was" : "were"); } -[[noreturn]] -static void missing_args(const FunctionSignature& signature, int idx) { +[[noreturn]] static void missing_args( + const FunctionSignature& signature, + int idx) { int num_missing = 0; std::stringstream ss; @@ -979,7 +1145,8 @@ static void missing_args(const FunctionSignature& signature, int idx) { } } - throw TypeError("%s() missing %d required positional argument%s: %s", + throw TypeError( + "%s() missing %d required positional argument%s: %s", signature.name.c_str(), num_missing, num_missing == 1 ? "s" : "", @@ -1000,10 +1167,12 @@ static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) { return -1; } -[[noreturn]] -static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, Py_ssize_t num_pos_args) { - PyObject *key = nullptr; - PyObject *value = nullptr; +[[noreturn]] static void extra_kwargs( + FunctionSignature& signature, + PyObject* kwargs, + Py_ssize_t num_pos_args) { + PyObject* key = nullptr; + PyObject* value = nullptr; Py_ssize_t pos = 0; while (PyDict_Next(kwargs, &pos, &key, &value)) { @@ -1013,13 +1182,17 @@ static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, Py_ssiz auto param_idx = find_param(signature, key); if (param_idx < 0) { - throw TypeError("%s() got an unexpected keyword argument '%s'", - signature.name.c_str(), THPUtils_unpackString(key).c_str()); + throw TypeError( + "%s() got an unexpected keyword argument '%s'", + signature.name.c_str(), + THPUtils_unpackString(key).c_str()); } if (param_idx < num_pos_args) { - throw TypeError("%s() got multiple values for argument '%s'", - signature.name.c_str(), THPUtils_unpackString(key).c_str()); + throw TypeError( + "%s() got multiple values for argument '%s'", + signature.name.c_str(), + THPUtils_unpackString(key).c_str()); } } @@ -1027,15 +1200,20 @@ static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, Py_ssiz throw TypeError("invalid keyword arguments"); } -bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], // NOLINT - bool raise_exception) { +bool FunctionSignature::parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* dst[], // NOLINT + bool raise_exception) { size_t nargs = args ? PyTuple_GET_SIZE(args) : 0; auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0; size_t arg_pos = 0; bool allow_varargs_intlist = false; - // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...), - // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3)) + // if there is a single positional IntArrayRef argument, i.e. expand(..), + // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as + // expand((5,3)) if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) { allow_varargs_intlist = true; } @@ -1070,7 +1248,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, obj = PyTuple_GET_ITEM(args, arg_pos); } else if (kwargs) { obj = PyDict_GetItem(kwargs, param.python_name); - for (PyObject *numpy_name: param.numpy_python_names) { + for (PyObject* numpy_name : param.numpy_python_names) { if (obj) { break; } @@ -1089,11 +1267,12 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, return false; } else if (param.check(obj, this->overloaded_args, i)) { dst[i++] = obj; - // XXX: the Variable check is necessary because sizes become tensors when - // tracer is enabled. This behavior easily leads to ambiguities, and we - // should avoid having complex signatures that make use of it... - } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd && - THPUtils_checkIndex(obj)) { + // XXX: the Variable check is necessary because sizes become tensors when + // tracer is enabled. This behavior easily leads to ambiguities, and we + // should avoid having complex signatures that make use of it... + } else if ( + allow_varargs_intlist && arg_pos == 0 && !is_kwd && + THPUtils_checkIndex(obj)) { // take all positional arguments as this parameter // e.g. permute(1, 2, 3) -> permute((1, 2, 3)) dst[i++] = args; @@ -1102,14 +1281,21 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, } else if (raise_exception) { if (is_kwd) { // foo(): argument 'other' must be str, not int - throw TypeError("%s(): argument '%s' must be %s, not %s", - name.c_str(), param.name.c_str(), param.type_name().c_str(), + throw TypeError( + "%s(): argument '%s' must be %s, not %s", + name.c_str(), + param.name.c_str(), + param.type_name().c_str(), Py_TYPE(obj)->tp_name); } else { // foo(): argument 'other' (position 2) must be str, not int - throw TypeError("%s(): argument '%s' (position %ld) must be %s, not %s", - name.c_str(), param.name.c_str(), static_cast(arg_pos + 1), - param.type_name().c_str(), Py_TYPE(obj)->tp_name); + throw TypeError( + "%s(): argument '%s' (position %ld) must be %s, not %s", + name.c_str(), + param.name.c_str(), + static_cast(arg_pos + 1), + param.type_name().c_str(), + Py_TYPE(obj)->tp_name); } } else { return false; @@ -1133,9 +1319,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, } PythonArgParser::PythonArgParser(std::vector fmts, bool traceable) - : max_args(0) - , traceable(traceable) -{ + : max_args(0), traceable(traceable) { int index = 0; for (auto& fmt : fmts) { signatures_.emplace_back(fmt, index); @@ -1151,21 +1335,24 @@ PythonArgParser::PythonArgParser(std::vector fmts, bool traceable) } // Check deprecated signatures last - std::stable_partition(signatures_.begin(), signatures_.end(), - [](const FunctionSignature & sig) { - return !sig.deprecated; - }); + std::stable_partition( + signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) { + return !sig.deprecated; + }); } -void PythonArgParser::check_deprecated(const FunctionSignature & signature) { +void PythonArgParser::check_deprecated(const FunctionSignature& signature) { if (signature.deprecated) { auto msg = c10::str( - "This overload of ", signature.name, " is deprecated:\n\t", - signature.name, signature.toString()); + "This overload of ", + signature.name, + " is deprecated:\n\t", + signature.name, + signature.toString()); auto signatures = get_signatures(); if (!signatures.empty()) { msg += "\nConsider using one of the following signatures instead:"; - for (const auto & sig : signatures) { + for (const auto& sig : signatures) { msg += "\n\t"; msg += signature.name; msg += sig; @@ -1175,7 +1362,11 @@ void PythonArgParser::check_deprecated(const FunctionSignature & signature) { } } -PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT +PythonArgs PythonArgParser::raw_parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* parsed_args[]) { // NOLINT if (signatures_.size() == 1) { auto& signature = signatures_[0]; signature.parse(self, args, kwargs, parsed_args, true); @@ -1193,14 +1384,18 @@ PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* print_error(self, args, kwargs, parsed_args); } - -void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT +void PythonArgParser::print_error( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* parsed_args[]) { // NOLINT // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) size_t num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0); std::vector plausible_idxs; unsigned i = 0; for (auto& signature : signatures_) { - if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) { + if (num_args >= signature.min_args && num_args <= signature.max_args && + !signature.hidden) { plausible_idxs.push_back(i); } i++; @@ -1212,7 +1407,8 @@ void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwar } auto options = get_signatures(); - auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options); + auto msg = + torch::format_invalid_args(args, kwargs, function_name + "()", options); throw TypeError("%s", msg.c_str()); } @@ -1250,10 +1446,10 @@ at::Tensor PythonArgs::tensor_slow(int i) { // a test for Py_None here; instead, you need to mark the argument // as *allowing none*; you can do this by writing 'Tensor?' instead // of 'Tensor' in the ATen metadata. - throw TypeError("expected Tensor as argument %d, but got %s", i, - Py_TYPE(obj)->tp_name); + throw TypeError( + "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name); } - at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove + at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; at::Tensor tensor = scalar_to_tensor(scalar); @@ -1272,8 +1468,8 @@ at::Scalar PythonArgs::scalar_slow(int i) { } at::Scalar PythonArgs::scalar_slow(PyObject* arg) { - // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently - // handle most NumPy scalar types except np.float64. + // Zero-dim tensors are converted to Scalars as-is. Note this doesn't + // currently handle most NumPy scalar types except np.float64. if (THPVariable_Check(arg)) { return THPVariable_Unpack(arg).item(); } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 6e502de14c9ed1..647359811cb8af 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -39,29 +39,28 @@ // Scalar and Tensor, UNLESS they require grad (in which case // they only bind to Tensor). - #include -#include #include #include #include #include #include +#include #include #include -#include +#include #include +#include #include #include #include +#include #include #include #include #include -#include #include -#include #include #include @@ -78,9 +77,30 @@ namespace torch { enum class ParameterType { - TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR, - BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING, - DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST + TENSOR, + SCALAR, + INT64, + SYM_INT, + DOUBLE, + COMPLEX, + TENSOR_LIST, + INT_LIST, + GENERATOR, + BOOL, + STORAGE, + PYOBJECT, + SCALARTYPE, + LAYOUT, + MEMORY_FORMAT, + DEVICE, + STREAM, + STRING, + DIMNAME, + DIMNAME_LIST, + QSCHEME, + FLOAT_LIST, + SCALAR_LIST, + SYM_INT_LIST }; struct FunctionParameter; @@ -88,21 +108,27 @@ struct FunctionSignature; struct PythonArgs; // Contains bound Python arguments in declaration order -template +template struct ParsedArgs { - ParsedArgs() : args() { } + ParsedArgs() : args() {} // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) PyObject* args[N]; }; struct PythonArgParser { - explicit PythonArgParser(std::vector fmts, bool traceable=false); + explicit PythonArgParser( + std::vector fmts, + bool traceable = false); // meant only for `torch` functions. - template - inline PythonArgs parse(PyObject* self, PyObject* args, PyObject* kwargs, ParsedArgs& dst); - - template + template + inline PythonArgs parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst); + + template inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst); inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst); @@ -110,13 +136,22 @@ struct PythonArgParser { // Formatted strings of non-hidden signatures std::vector get_signatures() const; -private: + private: [[noreturn]] // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - void print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); - void check_deprecated(const FunctionSignature & signature); + void + print_error( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* parsed_args[]); + void check_deprecated(const FunctionSignature& signature); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - PythonArgs raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); + PythonArgs raw_parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* parsed_args[]); std::vector signatures_; std::string function_name; @@ -128,7 +163,12 @@ struct PYBIND11_EXPORT FunctionSignature { explicit FunctionSignature(const std::string& fmt, int index); // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - bool parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); + bool parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* dst[], + bool raise_exception); std::string toString() const; @@ -145,11 +185,14 @@ struct PYBIND11_EXPORT FunctionSignature { }; struct PythonArgs { - PythonArgs(bool traceable, const FunctionSignature& signature, PyObject** args) - : idx(signature.index) - , traceable(traceable) - , signature(signature) - , args(args) {} + PythonArgs( + bool traceable, + const FunctionSignature& signature, + PyObject** args) + : idx(signature.index), + traceable(traceable), + signature(signature), + args(args) {} int idx; bool traceable; @@ -165,18 +208,25 @@ struct PythonArgs { inline std::vector scalarlist(int i); inline std::vector tensorlist(int i); inline torch::List> list_of_optional_tensors(int i); - template + template inline std::array tensorlist_n(int i); inline std::vector intlist(int i); inline std::vector symintlist(int i); inline c10::OptionalArray intlistOptional(int i); - inline std::vector intlistWithDefault(int i, std::vector default_intlist); + inline std::vector intlistWithDefault( + int i, + std::vector default_intlist); inline c10::optional generator(int i); inline at::Storage storage(int i); - inline at::Storage storage(int i, at::ScalarType& storage_scalar_type, bool& is_typed_storage); + inline at::Storage storage( + int i, + at::ScalarType& storage_scalar_type, + bool& is_typed_storage); inline c10::Stream stream(int i); inline at::ScalarType scalartype(int i); - inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype); + inline at::ScalarType scalartypeWithDefault( + int i, + at::ScalarType default_scalartype); inline c10::optional scalartypeOptional(int i); inline c10::optional scalarOptional(int i); inline c10::optional toInt64Optional(int i); @@ -201,7 +251,9 @@ struct PythonArgs { inline std::string stringWithDefault(int i, const std::string& default_str); inline c10::optional stringOptional(int i); inline c10::string_view stringView(int i); - inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str); + inline c10::string_view stringViewWithDefault( + int i, + const c10::string_view default_str); inline c10::optional stringViewOptional(int i); inline PyObject* pyobject(int i); inline int64_t toInt64(int i); @@ -210,12 +262,14 @@ struct PythonArgs { inline double toDouble(int i); inline double toDoubleWithDefault(int i, double default_double); inline c10::complex toComplex(int i); - inline c10::complex toComplexWithDefault(int i, c10::complex default_complex); + inline c10::complex toComplexWithDefault( + int i, + c10::complex default_complex); inline bool toBool(int i); inline bool toBoolWithDefault(int i, bool default_bool); inline bool isNone(int i); -private: + private: at::Tensor tensor_slow(int i); at::Scalar scalar_slow(int i); at::Scalar scalar_slow(PyObject* arg); @@ -224,7 +278,10 @@ struct PythonArgs { struct FunctionParameter { FunctionParameter(const std::string& fmt, bool keyword_only); - bool check(PyObject* obj, std::vector &overloaded_args, int argnum); + bool check( + PyObject* obj, + std::vector& overloaded_args, + int argnum); void set_default_str(const std::string& str); std::string type_name() const; @@ -236,11 +293,12 @@ struct FunctionParameter { bool allow_numbers_as_tensors = false; int size; std::string name; - // having this as a raw PyObject * will presumably leak it, but these are only held by static objects - // anyway, and Py_Finalize can already be called when this is destructed. - PyObject *python_name; + // having this as a raw PyObject * will presumably leak it, but these are only + // held by static objects anyway, and Py_Finalize can already be called when + // this is destructed. + PyObject* python_name; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - at::SmallVector numpy_python_names; + at::SmallVector numpy_python_names; at::Scalar default_scalar; std::vector default_intlist; std::string default_string; @@ -255,17 +313,26 @@ struct FunctionParameter { }; }; -template -inline PythonArgs PythonArgParser::parse(PyObject* self, PyObject* args, PyObject* kwargs, ParsedArgs& dst) { +template +inline PythonArgs PythonArgParser::parse( + PyObject* self, + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst) { if (N < max_args) { - throw ValueError("PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", - (int)max_args, N); + throw ValueError( + "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", + (int)max_args, + N); } return raw_parse(self, args, kwargs, dst.args); } -template -inline PythonArgs PythonArgParser::parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst) { +template +inline PythonArgs PythonArgParser::parse( + PyObject* args, + PyObject* kwargs, + ParsedArgs& dst) { return parse(nullptr, args, kwargs, dst); } @@ -273,12 +340,12 @@ inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) { return parse(self, nullptr, nullptr, dst); } -inline bool PythonArgs::has_torch_function(){ +inline bool PythonArgs::has_torch_function() { return !this->signature.overloaded_args.empty() || - at::impl::PythonTorchFunctionTLS::get_mode(); + at::impl::PythonTorchFunctionTLS::get_mode(); } -inline std::string PythonArgs::get_func_name(){ +inline std::string PythonArgs::get_func_name() { return signature.name; } @@ -301,54 +368,65 @@ inline c10::optional PythonArgs::optionalTensor(int i) { } inline at::Scalar PythonArgs::scalar(int i) { - if (!args[i]) return signature.params[i].default_scalar; + if (!args[i]) + return signature.params[i].default_scalar; return scalar_slow(i); } inline std::vector PythonArgs::scalarlist(int i) { - if (!args[i]) return std::vector(); + if (!args[i]) + return std::vector(); auto tuple = six::isTuple(args[i]); THPObjectPtr arg = six::maybeAsTuple(args[i]); // NOLINTNEXTLINE(bugprone-branch-clone) auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector res(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); res[idx] = scalar_slow(obj); } return res; } -inline at::Scalar PythonArgs::scalarWithDefault(int i, const at::Scalar& default_scalar) { - if (!args[i]) return default_scalar; +inline at::Scalar PythonArgs::scalarWithDefault( + int i, + const at::Scalar& default_scalar) { + if (!args[i]) + return default_scalar; return scalar_slow(i); } inline c10::optional PythonArgs::scalarOptional(int i) { - if (!args[i]) return c10::nullopt; + if (!args[i]) + return c10::nullopt; return scalar_slow(i); } inline std::vector PythonArgs::tensorlist(int i) { - if (!args[i]) return std::vector(); + if (!args[i]) + return std::vector(); auto tuple = six::isTuple(args[i]); THPObjectPtr arg = six::maybeAsTuple(args[i]); // NOLINTNEXTLINE(bugprone-branch-clone) auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector res(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); - // This is checked by the argument parser so it's safe to cast without checking - // if this is a tensor first + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first res[idx] = THPVariable_Unpack(obj); } return res; } -inline torch::List> PythonArgs::list_of_optional_tensors(int i) { - if (!args[i]) return torch::List>(); +inline torch::List> PythonArgs:: + list_of_optional_tensors(int i) { + if (!args[i]) + return torch::List>(); auto tuple = six::isTuple(args[i]); THPObjectPtr arg = six::maybeAsTuple(args[i]); // NOLINTNEXTLINE(bugprone-branch-clone) @@ -356,19 +434,21 @@ inline torch::List> PythonArgs::list_of_optional_tenso // NOLINTNEXTLINE(cppcoreguidelines-init-variables) torch::List> res; res.reserve(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); - // This is checked by the argument parser so it's safe to cast without checking - // if this is a tensor first + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first res.push_back(THPVariable_Unpack(obj)); } return res; } -template +template inline std::array PythonArgs::tensorlist_n(int i) { auto res = std::array(); - if (!args[i]) return res; + if (!args[i]) + return res; auto tuple = six::isTuple(args[i]); THPObjectPtr arg = six::maybeAsTuple(args[i]); // NOLINTNEXTLINE(bugprone-branch-clone) @@ -376,10 +456,11 @@ inline std::array PythonArgs::tensorlist_n(int i) { if (size != N) { throw TypeError("expected tuple of %d elements but got %d", N, (int)size); } - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx); - // This is checked by the argument parser so it's safe to cast without checking - // if this is a tensor first + for (const auto idx : c10::irange(size)) { + PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) + : PyList_GET_ITEM(arg.get(), idx); + // This is checked by the argument parser so it's safe to cast without + // checking if this is a tensor first res[idx] = THPVariable_Unpack(obj); } return res; @@ -391,11 +472,14 @@ inline std::vector PythonArgs::intlist(int i) { inline std::vector PythonArgs::symintlist(int i) { auto intlist = intlistWithDefault(i, signature.params[i].default_intlist); - return c10::fmap(intlist, [](int64_t n) {return c10::SymInt(n); }); + return c10::fmap(intlist, [](int64_t n) { return c10::SymInt(n); }); } -inline std::vector PythonArgs::intlistWithDefault(int i, std::vector default_intlist) { - if (!args[i]) return default_intlist; +inline std::vector PythonArgs::intlistWithDefault( + int i, + std::vector default_intlist) { + if (!args[i]) + return default_intlist; PyObject* arg = args[i]; const auto size1 = signature.params[i].size; if (size1 > 0 && THPUtils_checkLong(arg)) { @@ -406,13 +490,14 @@ inline std::vector PythonArgs::intlistWithDefault(int i, std::vector res(size2); - for(const auto idx : c10::irange(size2)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + for (const auto idx : c10::irange(size2)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); try { - // Elements of torch.Size are tensors during tracing, and we need to record extra - // information before they are turned into an IntArrayRef + // Elements of torch.Size are tensors during tracing, and we need to + // record extra information before they are turned into an IntArrayRef if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) { - auto & var = THPVariable_Unpack(obj); + auto& var = THPVariable_Unpack(obj); jit::tracer::ArgumentStash::stashIntArrayRefElem( signature.params[i].name, size2, idx, var); res[idx] = var.item(); @@ -420,10 +505,14 @@ inline std::vector PythonArgs::intlistWithDefault(int i, std::vectortp_name, idx + 1); + } catch (const std::exception& e) { + throw TypeError( + "%s(): argument '%s' must be %s, but found element of type %s at pos %ld", + signature.name.c_str(), + signature.params[i].name.c_str(), + signature.params[i].type_name().c_str(), + Py_TYPE(obj)->tp_name, + idx + 1); } } return res; @@ -443,14 +532,19 @@ inline std::vector PythonArgs::getDoublelist(int i) { auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector res(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); try { res[idx] = THPUtils_unpackDouble(obj); - } catch (const std::exception &e) { - throw TypeError("%s(): argument '%s' must be %s, but found element of type %s at pos %ld", - signature.name.c_str(), signature.params[i].name.c_str(), - signature.params[i].type_name().c_str(), Py_TYPE(obj)->tp_name, idx + 1); + } catch (const std::exception& e) { + throw TypeError( + "%s(): argument '%s' must be %s, but found element of type %s at pos %ld", + signature.name.c_str(), + signature.params[i].name.c_str(), + signature.params[i].type_name().c_str(), + Py_TYPE(obj)->tp_name, + idx + 1); } } return res; @@ -470,18 +564,22 @@ inline std::vector PythonArgs::doublelist(int i) { return this->getDoublelist(i); } -inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) { - if (!args[i]) return default_scalartype; +inline at::ScalarType PythonArgs::scalartypeWithDefault( + int i, + at::ScalarType default_scalartype) { + if (!args[i]) + return default_scalartype; return scalartype(i); } inline at::ScalarType PythonArgs::scalartype(int i) { if (!args[i]) { auto scalartype = signature.params[i].default_scalartype; - return (scalartype == at::ScalarType::Undefined) ? - torch::tensors::get_default_scalar_type() : scalartype; + return (scalartype == at::ScalarType::Undefined) + ? torch::tensors::get_default_scalar_type() + : scalartype; } - PyObject *obj = args[i]; + PyObject* obj = args[i]; if (obj == (PyObject*)&PyFloat_Type) { return at::ScalarType::Double; } @@ -501,21 +599,26 @@ inline c10::optional PythonArgs::scalartypeOptional(int i) { } inline at::Layout PythonArgs::layout(int i) { - if (!args[i]) return signature.params[i].default_layout; + if (!args[i]) + return signature.params[i].default_layout; return reinterpret_cast(args[i])->layout; } -inline at::Layout PythonArgs::layoutWithDefault(int i, at::Layout default_layout) { - if (!args[i]) return default_layout; +inline at::Layout PythonArgs::layoutWithDefault( + int i, + at::Layout default_layout) { + if (!args[i]) + return default_layout; return layout(i); } inline c10::optional PythonArgs::layoutOptional(int i) { - if (!args[i]) return c10::nullopt; + if (!args[i]) + return c10::nullopt; return layout(i); } -inline at::Device toDevice(PyObject * obj) { +inline at::Device toDevice(PyObject* obj) { if (THPDevice_Check(obj)) { const auto device = reinterpret_cast(obj); return device->device; @@ -525,19 +628,23 @@ inline at::Device toDevice(PyObject * obj) { TORCH_CHECK(device_index >= 0, "Device index must not be negative"); return at::Device(DeviceType::CUDA, device_index); } - const std::string &device_str = THPUtils_unpackString(obj); + const std::string& device_str = THPUtils_unpackString(obj); return at::Device(device_str); } inline at::Device PythonArgs::device(int i) { if (!args[i]) { - return at::Device(backendToDeviceType(dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()))); + return at::Device(backendToDeviceType( + dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()))); } return toDevice(args[i]); } -inline at::Device PythonArgs::deviceWithDefault(int i, const at::Device& default_device) { - if (!args[i]) return default_device; +inline at::Device PythonArgs::deviceWithDefault( + int i, + const at::Device& default_device) { + if (!args[i]) + return default_device; return device(i); } @@ -559,15 +666,18 @@ inline std::vector parseDimnameList(PyObject* arg) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector res; res.reserve(size); - for(const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx); res.push_back(THPDimname_parse(obj)); } return res; } -inline c10::optional> PythonArgs::toDimnameListOptional(int i) { - if (!args[i]) return c10::nullopt; +inline c10::optional> PythonArgs:: + toDimnameListOptional(int i) { + if (!args[i]) + return c10::nullopt; return parseDimnameList(args[i]); } @@ -577,14 +687,17 @@ inline std::vector PythonArgs::dimnamelist(int i) { auto size = signature.params[i].size; TORCH_INTERNAL_ASSERT(size == 0 || size == 1); if (size == 1 && THPUtils_checkDimname(arg)) { - return { THPDimname_parse(arg) }; + return {THPDimname_parse(arg)}; } return parseDimnameList(arg); } inline at::MemoryFormat PythonArgs::memoryformat(int i) { - if (!args[i]) return at::MemoryFormat::Contiguous; - TORCH_CHECK(THPMemoryFormat_Check(args[i]), "memory_format arg must be an instance of the torch.memory_format"); + if (!args[i]) + return at::MemoryFormat::Contiguous; + TORCH_CHECK( + THPMemoryFormat_Check(args[i]), + "memory_format arg must be an instance of the torch.memory_format"); const auto memory_format = reinterpret_cast(args[i]); return memory_format->memory_format; } @@ -596,8 +709,11 @@ inline c10::optional PythonArgs::memoryformatOptional(int i) { } inline at::QScheme PythonArgs::toQScheme(int i) { - if (!args[i]) return at::kPerTensorAffine; - TORCH_CHECK(THPQScheme_Check(args[i]), "qscheme arg must be an instance of the torch.qscheme"); + if (!args[i]) + return at::kPerTensorAffine; + TORCH_CHECK( + THPQScheme_Check(args[i]), + "qscheme arg must be an instance of the torch.qscheme"); const auto qscheme = reinterpret_cast(args[i]); return qscheme->qscheme; } @@ -606,13 +722,17 @@ inline std::string PythonArgs::string(int i) { return stringWithDefault(i, signature.params[i].default_string); } -inline std::string PythonArgs::stringWithDefault(int i, const std::string& default_str) { - if (!args[i]) return default_str; +inline std::string PythonArgs::stringWithDefault( + int i, + const std::string& default_str) { + if (!args[i]) + return default_str; return THPUtils_unpackString(args[i]); } inline c10::optional PythonArgs::stringOptional(int i) { - if (!args[i]) return c10::nullopt; + if (!args[i]) + return c10::nullopt; return THPUtils_unpackString(args[i]); } @@ -620,20 +740,25 @@ inline c10::string_view PythonArgs::stringView(int i) { return stringViewWithDefault(i, signature.params[i].default_string); } -inline c10::string_view PythonArgs::stringViewWithDefault(int i, const c10::string_view default_str) { - if (!args[i]) return default_str; +inline c10::string_view PythonArgs::stringViewWithDefault( + int i, + const c10::string_view default_str) { + if (!args[i]) + return default_str; return THPUtils_unpackStringView(args[i]); } inline c10::optional PythonArgs::stringViewOptional(int i) { - if (!args[i]) return c10::nullopt; + if (!args[i]) + return c10::nullopt; return THPUtils_unpackStringView(args[i]); } inline int64_t PythonArgs::toInt64(int i) { - if (!args[i]) return signature.params[i].default_int; + if (!args[i]) + return signature.params[i].default_int; if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { - auto & var = THPVariable_Unpack(args[i]); + auto& var = THPVariable_Unpack(args[i]); jit::tracer::ArgumentStash::stashValue( signature.params[i].name, idx, var, c10::IntType::get()); } @@ -645,7 +770,7 @@ inline c10::SymInt PythonArgs::toSymInt(int i) { return c10::SymInt(signature.params[i].default_int); } if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) { - auto & var = THPVariable_Unpack(args[i]); + auto& var = THPVariable_Unpack(args[i]); jit::tracer::ArgumentStash::stashValue( signature.params[i].name, idx, var, c10::IntType::get()); } @@ -653,7 +778,8 @@ inline c10::SymInt PythonArgs::toSymInt(int i) { } inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) { - if (!args[i]) return default_int; + if (!args[i]) + return default_int; return toInt64(i); } @@ -678,35 +804,44 @@ inline c10::optional PythonArgs::toDoubleOptional(int i) { } inline double PythonArgs::toDouble(int i) { - if (!args[i]) return signature.params[i].default_double; + if (!args[i]) + return signature.params[i].default_double; return THPUtils_unpackDouble(args[i]); } inline double PythonArgs::toDoubleWithDefault(int i, double default_double) { - if (!args[i]) return default_double; + if (!args[i]) + return default_double; return toDouble(i); } inline c10::complex PythonArgs::toComplex(int i) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - c10::complex default_value = *const_cast *>( - reinterpret_cast *>(signature.params[i].default_complex)); - if (!args[i]) return default_value; + c10::complex default_value = *const_cast*>( + reinterpret_cast*>( + signature.params[i].default_complex)); + if (!args[i]) + return default_value; return THPUtils_unpackComplexDouble(args[i]); } -inline c10::complex PythonArgs::toComplexWithDefault(int i, c10::complex default_value) { - if (!args[i]) return default_value; +inline c10::complex PythonArgs::toComplexWithDefault( + int i, + c10::complex default_value) { + if (!args[i]) + return default_value; return toComplex(i); } inline bool PythonArgs::toBool(int i) { - if (!args[i]) return signature.params[i].default_bool; + if (!args[i]) + return signature.params[i].default_bool; return args[i] == Py_True; } inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) { - if (!args[i]) return default_bool; + if (!args[i]) + return default_bool; return toBool(i); } @@ -715,37 +850,47 @@ inline bool PythonArgs::isNone(int i) { } inline c10::optional PythonArgs::generator(int i) { - if (!args[i]) return c10::nullopt; + if (!args[i]) + return c10::nullopt; return reinterpret_cast(args[i])->cdata; } inline at::Storage PythonArgs::storage(int i) { - if (!args[i]) return at::Storage(); + if (!args[i]) + return at::Storage(); return createStorage(args[i]); } -inline at::Storage PythonArgs::storage(int i, at::ScalarType& storage_scalar_type, bool& is_typed_storage) { +inline at::Storage PythonArgs::storage( + int i, + at::ScalarType& storage_scalar_type, + bool& is_typed_storage) { at::Storage storage; if (!args[i]) { storage = at::Storage(); is_typed_storage = false; storage_scalar_type = at::ScalarType::Undefined; } else { - storage = createStorageGetType(args[i], storage_scalar_type, is_typed_storage); + storage = + createStorageGetType(args[i], storage_scalar_type, is_typed_storage); } return storage; } inline c10::Stream PythonArgs::stream(int i) { - if (!args[i]) return c10::Stream(c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1)); + if (!args[i]) + return c10::Stream( + c10::Stream::Default::DEFAULT, c10::Device(DeviceType::CPU, -1)); if (!THPStream_Check(args[i])) { - throw TypeError("expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name); + throw TypeError( + "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name); } return c10::Stream::unpack(((THPStream*)args[i])->cdata); } inline PyObject* PythonArgs::pyobject(int i) { - if (!args[i]) return Py_None; + if (!args[i]) + return Py_None; return args[i]; } @@ -802,15 +947,35 @@ inline PyObject* PythonArgs::pyobject(int i) { * */ // Used for Tensor methods with arguments. -auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name, const char* func_name_override = nullptr) -> PyObject*; +auto handle_torch_function( + PythonArgs& r, + PyObject* self, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override = nullptr) -> PyObject*; // Used for functions which needs to parse python args. -auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name, const char* func_name_override = nullptr) -> PyObject*; +auto handle_torch_function( + PythonArgs& r, + PyObject* args, + PyObject* kwargs, + PyObject* torch_api, + const char* module_name, + const char* func_name_override = nullptr) -> PyObject*; // Used for functions that have no argument parsing. -auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; - -// Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args. +auto handle_torch_function( + PyObject* self, + const std::string& func_name, + PyObject* args = nullptr, + PyObject* kwargs = nullptr, + PyObject* torch_api = THPVariableClass, + const std::string& module_name = "torch.Tensor") -> PyObject*; + +// Used for functions created in C++, e.g., C++ custom op, which doesn't use +// PythonArgParser to get overloaded_args. enum class TorchFunctionName { TorchFunction, TorchDispatch }; auto TORCH_API handle_torch_function_no_python_arg_parser( @@ -824,13 +989,21 @@ auto TORCH_API handle_torch_function_no_python_arg_parser( -> PyObject*; // Used for getters of Tensor properties -auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject*; +auto handle_torch_function_getter( + THPVariable* self, + const std::string& property_name) -> PyObject*; // Used for setters of Tensor properties. -auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int; +auto handle_torch_function_setter( + THPVariable* self, + const std::string& property_name, + PyObject* value) -> int; // Used for __getitem__ and __setitem__ -auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* val=nullptr) -> PyObject*; +auto handle_torch_function_indexing( + PyObject* self, + PyObject* index, + PyObject* val = nullptr) -> PyObject*; /* * Check if the input obj is Tensor type, including its subclass, or overloaded @@ -841,7 +1014,9 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v * 'obj': the input argument to be checked * 'overloaded_args': the vector to append the overloaded args. */ -bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* overloaded_args); +bool is_tensor_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args); /* * Check if the input obj is Tensor List or Tensor Tuple type. First check @@ -855,7 +1030,11 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove * 'throw_error': whether throw error if any element in the list or tuple is * not tensor type or overloaded. */ -bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error); +bool is_tensor_list_and_append_overloaded( + PyObject* obj, + std::vector* overloaded_args, + int argnum, + bool throw_error); /* Given an argument that is definitely a tensor and is definitely overloaded, * append it to the overloaded arguments list. Use this instead of @@ -865,15 +1044,20 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector * 'overloaded_args': the vector to append the overloaded args * 'obj': the input tensor that is overloaded */ -void append_overloaded_tensor(std::vector* overloaded_args, PyObject* obj); +void append_overloaded_tensor( + std::vector* overloaded_args, + PyObject* obj); /* Given an argument that is definitely a type and is definitely overloaded, - * append it to the overloaded arguments list. Use this only with __torch_dispatch__, - * where we operate on classes that have a __torch_dispatch__ classmethod. + * append it to the overloaded arguments list. Use this only with + * __torch_dispatch__, where we operate on classes that have a + * __torch_dispatch__ classmethod. * * 'overloaded_args': the vector to append the overloaded type * 'obj': the input class that has a __torch_dispatch__ classmethod. */ -void append_overloaded_type(std::vector* overloaded_args, PyObject* obj); +void append_overloaded_type( + std::vector* overloaded_args, + PyObject* obj); } // namespace torch diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index c435d411513a0c..fb0eec3afb8024 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -3,7 +3,7 @@ #include #define MAYBE_METH_FASTCALL METH_FASTCALL -#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f +#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void (*)(void)) f // PyPy 3.6 does not yet have PySlice_Unpack #if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION) @@ -13,50 +13,50 @@ // https://docs.python.org/3/c-api/slice.html#c.PySlice_Unpack // https://github.com/python/cpython/blob/master/Objects/sliceobject.c#L196 -inline int -__PySlice_Unpack(PyObject *_r, - Py_ssize_t *start, Py_ssize_t *stop, Py_ssize_t *step) -{ - PySliceObject *r = (PySliceObject*)_r; - /* this is harder to get right than you might think */ +inline int __PySlice_Unpack( + PyObject* _r, + Py_ssize_t* start, + Py_ssize_t* stop, + Py_ssize_t* step) { + PySliceObject* r = (PySliceObject*)_r; + /* this is harder to get right than you might think */ - // Py_BUILD_ASSERT replaced because it is not available in all versions - static_assert(PY_SSIZE_T_MIN + 1 <= -PY_SSIZE_T_MAX, "Build failed"); + // Py_BUILD_ASSERT replaced because it is not available in all versions + static_assert(PY_SSIZE_T_MIN + 1 <= -PY_SSIZE_T_MAX, "Build failed"); - if (r->step == Py_None) { - *step = 1; - } - else { - if (!_PyEval_SliceIndex(r->step, step)) return -1; - if (*step == 0) { - PyErr_SetString(PyExc_ValueError, - "slice step cannot be zero"); - return -1; - } - /* Here *step might be -PY_SSIZE_T_MAX-1; in this case we replace it - * with -PY_SSIZE_T_MAX. This doesn't affect the semantics, and it - * guards against later undefined behaviour resulting from code that - * does "step = -step" as part of a slice reversal. - */ - if (*step < -PY_SSIZE_T_MAX) - *step = -PY_SSIZE_T_MAX; + if (r->step == Py_None) { + *step = 1; + } else { + if (!_PyEval_SliceIndex(r->step, step)) + return -1; + if (*step == 0) { + PyErr_SetString(PyExc_ValueError, "slice step cannot be zero"); + return -1; } + /* Here *step might be -PY_SSIZE_T_MAX-1; in this case we replace it + * with -PY_SSIZE_T_MAX. This doesn't affect the semantics, and it + * guards against later undefined behaviour resulting from code that + * does "step = -step" as part of a slice reversal. + */ + if (*step < -PY_SSIZE_T_MAX) + *step = -PY_SSIZE_T_MAX; + } - if (r->start == Py_None) { - *start = *step < 0 ? PY_SSIZE_T_MAX : 0; - } - else { - if (!_PyEval_SliceIndex(r->start, start)) return -1; - } + if (r->start == Py_None) { + *start = *step < 0 ? PY_SSIZE_T_MAX : 0; + } else { + if (!_PyEval_SliceIndex(r->start, start)) + return -1; + } - if (r->stop == Py_None) { - *stop = *step < 0 ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX; - } - else { - if (!_PyEval_SliceIndex(r->stop, stop)) return -1; - } + if (r->stop == Py_None) { + *stop = *step < 0 ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX; + } else { + if (!_PyEval_SliceIndex(r->stop, stop)) + return -1; + } - return 0; + return 0; } #define THPUtils_unpackSlice(SLICE, START, STOP, STEP) \ diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 3ec220c8edb39f..c3607f75cbf31d 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -1,14 +1,14 @@ -#include #include +#include -#include #include -#include #include +#include +#include #include -#include #include +#include #include #include @@ -23,9 +23,9 @@ namespace dispatch { torch::Library::Kind parseKind(const std::string& k) { static std::unordered_map kind_map = { - {"DEF", torch::Library::DEF}, - {"IMPL", torch::Library::IMPL}, - {"FRAGMENT", torch::Library::FRAGMENT}, + {"DEF", torch::Library::DEF}, + {"IMPL", torch::Library::IMPL}, + {"FRAGMENT", torch::Library::FRAGMENT}, }; auto it = kind_map.find(k); TORCH_CHECK(it != kind_map.end(), "could not parse ", k); @@ -33,20 +33,21 @@ torch::Library::Kind parseKind(const std::string& k) { } c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { static std::unordered_map key_map = { - {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE}, - {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA}, - {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION}, - {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default + {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE}, + {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA}, + {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION}, + {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default }; auto it = key_map.find(k); TORCH_CHECK(it != key_map.end(), "could not parse ", k); return it->second; } - template inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { - auto mb_key = std::string(key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(key)); + auto mb_key = std::string(key) == "" + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(key)); if (mb_key) { return torch::dispatch(*mb_key, std::forward(raw_f)); } else { @@ -57,15 +58,25 @@ inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { class PythonKernelHolder : public c10::OperatorKernel { c10::SafePyObject func_; -public: - PythonKernelHolder(py::object func) : func_(func.release().ptr(), getPyInterpreter()) {} - void operator()(const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) { + public: + PythonKernelHolder(py::object func) + : func_(func.release().ptr(), getPyInterpreter()) {} + + void operator()( + const c10::OperatorHandle& op, + c10::DispatchKeySet keyset, + torch::jit::Stack* stack) { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); - auto obj = py::reinterpret_steal(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr())); - if (!obj) { throw python_error(); } + auto obj = py::reinterpret_steal(PyObject_Call( + func_.ptr(getPyInterpreter()), + args_kwargs.first.ptr(), + args_kwargs.second.ptr())); + if (!obj) { + throw python_error(); + } pushPyOutToStack(op, stack, obj, "PythonKernelHolder"); } }; @@ -74,105 +85,175 @@ void initDispatchBindings(PyObject* module) { auto m = py::handle(module).cast(); py::class_(m, "_DispatchOperatorHandle") - .def("schema", &c10::OperatorHandle::schema); + .def("schema", &c10::OperatorHandle::schema); // TODO: figure out how to do chaining py::class_(m, "_DispatchModule") - .def("def_", [](py::object self, const char* schema, const char* alias) { - self.cast().def(torch::schema(schema, parseAliasAnalysisKind(alias))); - return self; - }, "", py::arg("schema"), py::arg("alias") = "") - // Simulated "legacy" def where alias analysis kind is not set. - // Ordinarily this can only be exercised from RegisterOperators() API - // but I am not going to bind that here - .def("def_legacy", [](py::object self, const char* schema) { - self.cast().def(torch::jit::parseSchema(schema)); - return self; - }, "", py::arg("schema")) - // We can't conveniently turn Python functions into valid functions - // in the dispatcher. So instead we provide a bunch of precanned - // functions for testing purposes. You're NOT intended to actually - // call these functions; they're just here so we can actually register - // something - // - // Mangling scheme: args_rets. One character per. - // t = Tensor - .def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().def( - name, - dispatch_str(dispatch, [](const at::Tensor& a) { - return a; - }).debug(debug) - ); - return self; - }, "", py::arg("name"), - py::arg("dispatch") = "", - py::arg("debug") = "default_def_name_t_t") - .def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) { - self.cast().def( - torch::schema(schema, parseAliasAnalysisKind(alias)), - dispatch_str(dispatch, [](const at::Tensor& a) { - return a; - }).debug(debug) - ); - return self; - }, "", py::arg("name"), - py::arg("dispatch") = "", - py::arg("alias") = "", - py::arg("debug") = "default_def_schema_t_t") - // TODO: maybe consider deduplicating the definitions here, it's getting - // pretty long - .def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().impl( - name, - dispatch_str(dispatch, [](const at::Tensor& a) { - return a; - }).debug(debug) - ); - return self; - }, "", py::arg("name"), - py::arg("dispatch") = "", - py::arg("debug") = "impl_t_t") - .def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().impl( - name, - dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) { - return a; - }).debug(debug) - ); - return self; - }, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "") - .def("impl", [](py::object self, const char* name, const char* dispatch, py::object func) { - HANDLE_TH_ERRORS - self.cast().impl( - name, - dispatch_str(dispatch, CppFunction::makeFromBoxedFunctor(std::make_unique(std::move(func)))) - ); - END_HANDLE_TH_ERRORS_PYBIND - }, "", py::arg("name"), py::arg("dispatch"), py::arg("func")) - .def("define", [](py::object self, const char* schema, const char* alias_analysis) { - self.cast().def(torch::schema(schema, parseAliasAnalysisKind(alias_analysis))); - return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)).name(); - }, "", py::arg("schema"), py::arg("alias_analysis") = "") - .def("fallback_fallthrough", [](py::object self, const char* dispatch) { - self.cast().fallback( - dispatch_str(dispatch, CppFunction::makeFallthrough()) - ); - return self; - }, "", py::arg("dispatch") = "") - ; + .def( + "def_", + [](py::object self, const char* schema, const char* alias) { + self.cast().def( + torch::schema(schema, parseAliasAnalysisKind(alias))); + return self; + }, + "", + py::arg("schema"), + py::arg("alias") = "") + // Simulated "legacy" def where alias analysis kind is not set. + // Ordinarily this can only be exercised from RegisterOperators() API + // but I am not going to bind that here + .def( + "def_legacy", + [](py::object self, const char* schema) { + self.cast().def(torch::jit::parseSchema(schema)); + return self; + }, + "", + py::arg("schema")) + // We can't conveniently turn Python functions into valid functions + // in the dispatcher. So instead we provide a bunch of precanned + // functions for testing purposes. You're NOT intended to actually + // call these functions; they're just here so we can actually register + // something + // + // Mangling scheme: args_rets. One character per. + // t = Tensor + .def( + "def_name_t_t", + [](py::object self, + const char* name, + const char* dispatch, + const char* debug) { + self.cast().def( + name, dispatch_str(dispatch, [](const at::Tensor& a) { + return a; + }).debug(debug)); + return self; + }, + "", + py::arg("name"), + py::arg("dispatch") = "", + py::arg("debug") = "default_def_name_t_t") + .def( + "def_schema_t_t", + [](py::object self, + const char* schema, + const char* dispatch, + const char* alias, + const char* debug) { + self.cast().def( + torch::schema(schema, parseAliasAnalysisKind(alias)), + dispatch_str(dispatch, [](const at::Tensor& a) { + return a; + }).debug(debug)); + return self; + }, + "", + py::arg("name"), + py::arg("dispatch") = "", + py::arg("alias") = "", + py::arg("debug") = "default_def_schema_t_t") + // TODO: maybe consider deduplicating the definitions here, it's getting + // pretty long + .def( + "impl_t_t", + [](py::object self, + const char* name, + const char* dispatch, + const char* debug) { + self.cast().impl( + name, dispatch_str(dispatch, [](const at::Tensor& a) { + return a; + }).debug(debug)); + return self; + }, + "", + py::arg("name"), + py::arg("dispatch") = "", + py::arg("debug") = "impl_t_t") + .def( + "impl_tt_t", + [](py::object self, + const char* name, + const char* dispatch, + const char* debug) { + self.cast().impl( + name, + dispatch_str( + dispatch, + [](const at::Tensor& a, const at::Tensor& b) { return a; }) + .debug(debug)); + return self; + }, + "", + py::arg("name"), + py::arg("dispatch") = "", + py::arg("debug") = "") + .def( + "impl", + [](py::object self, + const char* name, + const char* dispatch, + py::object func) { + HANDLE_TH_ERRORS + self.cast().impl( + name, + dispatch_str( + dispatch, + CppFunction::makeFromBoxedFunctor( + std::make_unique( + std::move(func))))); + END_HANDLE_TH_ERRORS_PYBIND + }, + "", + py::arg("name"), + py::arg("dispatch"), + py::arg("func")) + .def( + "define", + [](py::object self, const char* schema, const char* alias_analysis) { + self.cast().def( + torch::schema(schema, parseAliasAnalysisKind(alias_analysis))); + return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)) + .name(); + }, + "", + py::arg("schema"), + py::arg("alias_analysis") = "") + .def( + "fallback_fallthrough", + [](py::object self, const char* dispatch) { + self.cast().fallback( + dispatch_str(dispatch, CppFunction::makeFallthrough())); + return self; + }, + "", + py::arg("dispatch") = ""); - m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch, const char* file, uint32_t linenum) { - HANDLE_TH_ERRORS - return std::make_unique( - parseKind(kind), - std::move(name), - std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)), - "/dev/null", // temporary workaround - linenum); - END_HANDLE_TH_ERRORS_PYBIND - }, "", py::arg("kind"), py::arg("name"), py::arg("dispatch"), py::arg("file")="/dev/null", py::arg("linenum")=0) - ; + m.def( + "_dispatch_library", + [](const char* kind, + std::string name, + const char* dispatch, + const char* file, + uint32_t linenum) { + HANDLE_TH_ERRORS + return std::make_unique( + parseKind(kind), + std::move(name), + std::string(dispatch) == "" + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch)), + "/dev/null", // temporary workaround + linenum); + END_HANDLE_TH_ERRORS_PYBIND + }, + "", + py::arg("kind"), + py::arg("name"), + py::arg("dispatch"), + py::arg("file") = "/dev/null", + py::arg("linenum") = 0); m.def("_dispatch_dump", [](const char* name) -> std::string { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); @@ -209,14 +290,17 @@ void initDispatchBindings(PyObject* module) { return static_cast(op); }); - m.def("_dispatch_has_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool { - auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); - TORCH_CHECK(op, "operator ", name, " does not exist"); - return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch)); - }); + m.def( + "_dispatch_has_kernel_for_dispatch_key", + [](const char* name, const char* dispatch) -> bool { + auto op = + c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); + TORCH_CHECK(op, "operator ", name, " does not exist"); + return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch)); + }); m.def("_dispatch_find_dangling_impls", []() -> std::vector { - auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); + auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); std::vector states; states.reserve(danglingImpls.size()); @@ -227,43 +311,62 @@ void initDispatchBindings(PyObject* module) { return states; }); - m.def("_dispatch_tls_set_dispatch_key_excluded", [](const char* dispatch_key, bool desired_state) { - c10::impl::tls_set_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key), desired_state); - }); + m.def( + "_dispatch_tls_set_dispatch_key_excluded", + [](const char* dispatch_key, bool desired_state) { + c10::impl::tls_set_dispatch_key_excluded( + c10::parseDispatchKey(dispatch_key), desired_state); + }); m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) { - return c10::impl::tls_is_dispatch_key_excluded(c10::parseDispatchKey(dispatch_key)); + return c10::impl::tls_is_dispatch_key_excluded( + c10::parseDispatchKey(dispatch_key)); }); m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) { return at::isTensorSubclassLike(tensor); }); - py::class_(m, "_AutoDispatchBelowAutograd") - .def(py::init<>()); + .def(py::init<>()); - // Prints out the name of every operator that has a kernel registered to the Dispatcher - // under [dispatch_key]. - // If no arguments are specified, it'll print out the name of every operator that the Dispatcher knows of. - // This can be useful to answer questions like "list all operators that do not have a CPU kernel". - m.def("_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") { - auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key)); - auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); - for (auto& op : op_names) { - std::cout << op << std::endl; - } - }, py::arg("dispatch_key") = static_cast("")); + // Prints out the name of every operator that has a kernel registered to the + // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print + // out the name of every operator that the Dispatcher knows of. This can be + // useful to answer questions like "list all operators that do not have a CPU + // kernel". + m.def( + "_dispatch_print_registrations_for_dispatch_key", + [](const char* dispatch_key = "") { + auto k = std::string(dispatch_key) == "" + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch_key)); + auto op_names = + c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); + for (auto& op : op_names) { + std::cout << op << std::endl; + } + }, + py::arg("dispatch_key") = static_cast("")); - m.def("_dispatch_get_registrations_for_dispatch_key", [](const char* dispatch_key = "") { - auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key)); - auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); - std::vector names; - names.reserve(op_names.size()); - for (auto& op : op_names) { - names.push_back(op.name + (op.overload_name == "" ? "" : "." + op.overload_name)); - } - return names; - }, py::arg("dispatch_key") = static_cast("")); + m.def( + "_dispatch_get_registrations_for_dispatch_key", + [](const char* dispatch_key = "") { + auto k = std::string(dispatch_key) == "" + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch_key)); + auto op_names = + c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); + std::vector names; + names.reserve(op_names.size()); + for (auto& op : op_names) { + names.push_back( + op.name + (op.overload_name == "" ? "" : "." + op.overload_name)); + } + return names; + }, + py::arg("dispatch_key") = static_cast("")); } -}}} // namespace torch::impl::dispatch +} // namespace dispatch +} // namespace impl +} // namespace torch diff --git a/torch/csrc/utils/python_dispatch.h b/torch/csrc/utils/python_dispatch.h index cf5b92d1d41f5d..e1abbca27758b1 100644 --- a/torch/csrc/utils/python_dispatch.h +++ b/torch/csrc/utils/python_dispatch.h @@ -6,4 +6,6 @@ namespace dispatch { void initDispatchBindings(PyObject* module); -}}} // namespace torch::impl::dispatch +} +} // namespace impl +} // namespace torch diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index a4d23ace8e9af6..f312073bec2434 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -91,7 +91,7 @@ inline uint64_t THPUtils_unpackUInt64(PyObject* obj) { return (uint64_t)value; } -inline bool THPUtils_checkIndex(PyObject *obj) { +inline bool THPUtils_checkIndex(PyObject* obj) { if (PyBool_Check(obj)) { return false; } @@ -159,7 +159,7 @@ inline double THPUtils_unpackDouble(PyObject* obj) { return value; } -inline c10::complex THPUtils_unpackComplexDouble(PyObject *obj) { +inline c10::complex THPUtils_unpackComplexDouble(PyObject* obj) { Py_complex value = PyComplex_AsCComplex(obj); if (value.real == -1.0 && PyErr_Occurred()) { throw python_error(); diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index 7f454bdff8262a..0117dc21b12dff 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -3,44 +3,82 @@ #include #include -#include #include +#include -namespace torch { namespace utils { +namespace torch { +namespace utils { inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { switch (scalarType) { - case at::kByte: *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); break; - case at::kChar: *(int8_t*)data = (int8_t)THPUtils_unpackLong(obj); break; - case at::kShort: *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); break; - case at::kInt: *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); break; - case at::kLong: *(int64_t*)data = THPUtils_unpackLong(obj); break; + case at::kByte: + *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); + break; + case at::kChar: + *(int8_t*)data = (int8_t)THPUtils_unpackLong(obj); + break; + case at::kShort: + *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); + break; + case at::kInt: + *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); + break; + case at::kLong: + *(int64_t*)data = THPUtils_unpackLong(obj); + break; case at::kHalf: - *(at::Half*)data = at::convert(THPUtils_unpackDouble(obj)); - break; - case at::kFloat: *(float*)data = (float)THPUtils_unpackDouble(obj); break; - case at::kDouble: *(double*)data = THPUtils_unpackDouble(obj); break; - case at::kComplexHalf: *(c10::complex*)data = (c10::complex)static_cast>(THPUtils_unpackComplexDouble(obj)); break; - case at::kComplexFloat: *(c10::complex*)data = (c10::complex)THPUtils_unpackComplexDouble(obj); break; - case at::kComplexDouble: *(c10::complex*)data = THPUtils_unpackComplexDouble(obj); break; - case at::kBool: *(bool*)data = THPUtils_unpackNumberAsBool(obj); break; + *(at::Half*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; + case at::kFloat: + *(float*)data = (float)THPUtils_unpackDouble(obj); + break; + case at::kDouble: + *(double*)data = THPUtils_unpackDouble(obj); + break; + case at::kComplexHalf: + *(c10::complex*)data = + (c10::complex)static_cast>( + THPUtils_unpackComplexDouble(obj)); + break; + case at::kComplexFloat: + *(c10::complex*)data = + (c10::complex)THPUtils_unpackComplexDouble(obj); + break; + case at::kComplexDouble: + *(c10::complex*)data = THPUtils_unpackComplexDouble(obj); + break; + case at::kBool: + *(bool*)data = THPUtils_unpackNumberAsBool(obj); + break; case at::kBFloat16: - *(at::BFloat16*)data = at::convert(THPUtils_unpackDouble(obj)); + *(at::BFloat16*)data = + at::convert(THPUtils_unpackDouble(obj)); break; - default: throw std::runtime_error("invalid type"); + default: + throw std::runtime_error("invalid type"); } } inline PyObject* load_scalar(void* data, at::ScalarType scalarType) { switch (scalarType) { - case at::kByte: return THPUtils_packInt64(*(uint8_t*)data); - case at::kChar: return THPUtils_packInt64(*(int8_t*)data); - case at::kShort: return THPUtils_packInt64(*(int16_t*)data); - case at::kInt: return THPUtils_packInt64(*(int32_t*)data); - case at::kLong: return THPUtils_packInt64(*(int64_t*)data); - case at::kHalf: return PyFloat_FromDouble(at::convert(*(at::Half*)data)); - case at::kFloat: return PyFloat_FromDouble(*(float*)data); - case at::kDouble: return PyFloat_FromDouble(*(double*)data); + case at::kByte: + return THPUtils_packInt64(*(uint8_t*)data); + case at::kChar: + return THPUtils_packInt64(*(int8_t*)data); + case at::kShort: + return THPUtils_packInt64(*(int16_t*)data); + case at::kInt: + return THPUtils_packInt64(*(int32_t*)data); + case at::kLong: + return THPUtils_packInt64(*(int64_t*)data); + case at::kHalf: + return PyFloat_FromDouble( + at::convert(*(at::Half*)data)); + case at::kFloat: + return PyFloat_FromDouble(*(float*)data); + case at::kDouble: + return PyFloat_FromDouble(*(double*)data); case at::kComplexHalf: { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); @@ -49,11 +87,18 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); } - case at::kComplexDouble: return PyComplex_FromCComplex(*reinterpret_cast((c10::complex*)data)); - case at::kBool: return PyBool_FromLong(*(bool*)data); - case at::kBFloat16: return PyFloat_FromDouble(at::convert(*(at::BFloat16*)data)); - default: throw std::runtime_error("invalid type"); + case at::kComplexDouble: + return PyComplex_FromCComplex( + *reinterpret_cast((c10::complex*)data)); + case at::kBool: + return PyBool_FromLong(*(bool*)data); + case at::kBFloat16: + return PyFloat_FromDouble( + at::convert(*(at::BFloat16*)data)); + default: + throw std::runtime_error("invalid type"); } } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index 7621db32403fc7..fe2466acd9420b 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -1,10 +1,10 @@ #pragma once #include -#include -#include #include #include +#include +#include // Utilities for handling Python strings. Note that PyString, when defined, is // the same as PyBytes. @@ -83,7 +83,8 @@ inline void THPUtils_internStringInPlace(PyObject** obj) { } /* - * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 + * Reference: + * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 * * Stripped down version of PyObject_GetAttrString, * avoids lookups for None, tuple, and List objects, @@ -96,37 +97,35 @@ inline void THPUtils_internStringInPlace(PyObject** obj) { * * 'name' is the attribute to search for. * - * Returns a py::object wrapping the return value. If the attribute lookup failed - * the value will be NULL. + * Returns a py::object wrapping the return value. If the attribute lookup + * failed the value will be NULL. * */ // NOLINTNEXTLINE(clang-diagnostic-unused-function) -static py::object PyObject_FastGetAttrString(PyObject *obj, const char *name) -{ - PyTypeObject *tp = Py_TYPE(obj); - PyObject *res = (PyObject *)nullptr; +static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { + PyTypeObject* tp = Py_TYPE(obj); + PyObject* res = (PyObject*)nullptr; - /* Attribute referenced by (char *)name */ - if (tp->tp_getattr != nullptr) { - // This is OK per https://bugs.python.org/issue39620 - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - res = (*tp->tp_getattr)(obj, const_cast(name)); - if (res == nullptr) { - PyErr_Clear(); - } + /* Attribute referenced by (char *)name */ + if (tp->tp_getattr != nullptr) { + // This is OK per https://bugs.python.org/issue39620 + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + res = (*tp->tp_getattr)(obj, const_cast(name)); + if (res == nullptr) { + PyErr_Clear(); } - /* Attribute referenced by (PyObject *)name */ - else if (tp->tp_getattro != nullptr) { - auto w = py::reinterpret_steal( - THPUtils_internString(name)); - if (w.ptr() == nullptr) { - return py::object(); - } - res = (*tp->tp_getattro)(obj, w.ptr()); - if (res == nullptr) { - PyErr_Clear(); - } + } + /* Attribute referenced by (PyObject *)name */ + else if (tp->tp_getattro != nullptr) { + auto w = py::reinterpret_steal(THPUtils_internString(name)); + if (w.ptr() == nullptr) { + return py::object(); + } + res = (*tp->tp_getattro)(obj, w.ptr()); + if (res == nullptr) { + PyErr_Clear(); } - return py::reinterpret_steal(res); + } + return py::reinterpret_steal(res); } diff --git a/torch/csrc/utils/python_torch_function_mode.h b/torch/csrc/utils/python_torch_function_mode.h index 5a27c9b13a88cf..1ab703dc5032f1 100644 --- a/torch/csrc/utils/python_torch_function_mode.h +++ b/torch/csrc/utils/python_torch_function_mode.h @@ -11,9 +11,14 @@ namespace overrides { // From C++ side, there is no such thing as compositional modes, there is one // mode and of course you should be able to clear it. struct StashTorchFunctionModeGuard { - StashTorchFunctionModeGuard() { at::impl::PythonTorchFunctionTLS::swap_mode(old_mode_); } - ~StashTorchFunctionModeGuard() { at::impl::PythonTorchFunctionTLS::set_mode(std::move(old_mode_)); } -private: + StashTorchFunctionModeGuard() { + at::impl::PythonTorchFunctionTLS::swap_mode(old_mode_); + } + ~StashTorchFunctionModeGuard() { + at::impl::PythonTorchFunctionTLS::set_mode(std::move(old_mode_)); + } + + private: std::shared_ptr old_mode_ = nullptr; }; diff --git a/torch/csrc/utils/python_tuples.h b/torch/csrc/utils/python_tuples.h index fa9a41b536e68c..ab71ccbd444118 100644 --- a/torch/csrc/utils/python_tuples.h +++ b/torch/csrc/utils/python_tuples.h @@ -1,13 +1,16 @@ #pragma once -#include #include +#include #include #include -inline void THPUtils_packInt64Array(PyObject *tuple, size_t size, const int64_t *sizes) { +inline void THPUtils_packInt64Array( + PyObject* tuple, + size_t size, + const int64_t* sizes) { for (size_t i = 0; i != size; ++i) { - PyObject *i64 = THPUtils_packInt64(sizes[i]); + PyObject* i64 = THPUtils_packInt64(sizes[i]); if (!i64) { throw python_error(); } @@ -15,9 +18,10 @@ inline void THPUtils_packInt64Array(PyObject *tuple, size_t size, const int64_t } } -inline PyObject* THPUtils_packInt64Array(size_t size, const int64_t *sizes) { +inline PyObject* THPUtils_packInt64Array(size_t size, const int64_t* sizes) { THPObjectPtr tuple(PyTuple_New(size)); - if (!tuple) throw python_error(); + if (!tuple) + throw python_error(); THPUtils_packInt64Array(tuple.get(), size, sizes); return tuple.release(); } diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h index b83e60c77cf3c4..d6c17a6516f400 100644 --- a/torch/csrc/utils/six.h +++ b/torch/csrc/utils/six.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include namespace six { @@ -12,7 +12,8 @@ namespace six { // by a pytorch operator. inline bool isStructSeq(pybind11::handle input) { - return pybind11::cast(input.get_type().attr("__module__")) == "torch.return_types"; + return pybind11::cast(input.get_type().attr("__module__")) == + "torch.return_types"; } inline bool isStructSeq(PyObject* obj) { @@ -32,19 +33,19 @@ inline bool isTuple(PyObject* obj) { // maybeAsTuple: if the input is a structseq, then convert it to a tuple // -// On Python 3, structseq is a subtype of tuple, so these APIs could be used directly. -// But on Python 2, structseq is not a subtype of tuple, so we need to manually create a -// new tuple object from structseq. -inline THPObjectPtr maybeAsTuple(PyStructSequence *obj) { +// On Python 3, structseq is a subtype of tuple, so these APIs could be used +// directly. But on Python 2, structseq is not a subtype of tuple, so we need to +// manually create a new tuple object from structseq. +inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) { Py_INCREF(obj); - return THPObjectPtr((PyObject *)obj); + return THPObjectPtr((PyObject*)obj); } -inline THPObjectPtr maybeAsTuple(PyObject *obj) { +inline THPObjectPtr maybeAsTuple(PyObject* obj) { if (isStructSeq(obj)) - return maybeAsTuple((PyStructSequence *)obj); + return maybeAsTuple((PyStructSequence*)obj); Py_INCREF(obj); return THPObjectPtr(obj); } -} // namespace six +} // namespace six diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp index b410bbff4299b0..5533516193cce3 100644 --- a/torch/csrc/utils/structseq.cpp +++ b/torch/csrc/utils/structseq.cpp @@ -12,8 +12,8 @@ * https://github.com/python/cpython#copyright-and-license-information */ -#include #include +#include #include #include @@ -24,49 +24,53 @@ namespace utils { // NOTE: The built-in repr method from PyStructSequence was updated in // https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79 // so this function might not be required in Python 3.8+. -PyObject *returned_structseq_repr(PyStructSequence *obj) { - PyTypeObject *typ = Py_TYPE(obj); - THPObjectPtr tup = six::maybeAsTuple(obj); - if (tup == nullptr) { - return nullptr; - } +PyObject* returned_structseq_repr(PyStructSequence* obj) { + PyTypeObject* typ = Py_TYPE(obj); + THPObjectPtr tup = six::maybeAsTuple(obj); + if (tup == nullptr) { + return nullptr; + } - std::stringstream ss; - ss << typ->tp_name << "(\n"; - Py_ssize_t num_elements = Py_SIZE(obj); + std::stringstream ss; + ss << typ->tp_name << "(\n"; + Py_ssize_t num_elements = Py_SIZE(obj); - for (Py_ssize_t i = 0; i < num_elements; i++) { - const char *cname = typ->tp_members[i].name; - if (cname == nullptr) { - PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %zd name is nullptr" - " for type %.500s", i, typ->tp_name); - return nullptr; - } + for (Py_ssize_t i = 0; i < num_elements; i++) { + const char* cname = typ->tp_members[i].name; + if (cname == nullptr) { + PyErr_Format( + PyExc_SystemError, + "In structseq_repr(), member %zd name is nullptr" + " for type %.500s", + i, + typ->tp_name); + return nullptr; + } - PyObject* val = PyTuple_GetItem(tup.get(), i); - if (val == nullptr) { - return nullptr; - } + PyObject* val = PyTuple_GetItem(tup.get(), i); + if (val == nullptr) { + return nullptr; + } - auto repr = THPObjectPtr(PyObject_Repr(val)); - if (repr == nullptr) { - return nullptr; - } + auto repr = THPObjectPtr(PyObject_Repr(val)); + if (repr == nullptr) { + return nullptr; + } - const char* crepr = PyUnicode_AsUTF8(repr); - if (crepr == nullptr) { - return nullptr; - } + const char* crepr = PyUnicode_AsUTF8(repr); + if (crepr == nullptr) { + return nullptr; + } - ss << cname << '=' << crepr; - if (i < num_elements - 1) { - ss << ",\n"; - } + ss << cname << '=' << crepr; + if (i < num_elements - 1) { + ss << ",\n"; } - ss << ")"; + } + ss << ")"; - return PyUnicode_FromString(ss.str().c_str()); + return PyUnicode_FromString(ss.str().c_str()); } -} -} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/structseq.h b/torch/csrc/utils/structseq.h index ae1420030f142d..0d91d39d34be61 100644 --- a/torch/csrc/utils/structseq.h +++ b/torch/csrc/utils/structseq.h @@ -2,8 +2,10 @@ #include -namespace torch { namespace utils { +namespace torch { +namespace utils { -PyObject *returned_structseq_repr(PyStructSequence *obj); +PyObject* returned_structseq_repr(PyStructSequence* obj); -}} +} +} // namespace torch diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index 0a0f65b6bfc033..7632c6511ea47d 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include @@ -10,13 +10,14 @@ using namespace at; -namespace torch { namespace utils { +namespace torch { +namespace utils { struct StridedData { - StridedData(const Tensor & tensor) - : data(tensor.data_ptr()) - , strides(tensor.strides()) - , elementSize(tensor.element_size()) {} + StridedData(const Tensor& tensor) + : data(tensor.data_ptr()), + strides(tensor.strides()), + elementSize(tensor.element_size()) {} void* data; IntArrayRef strides; @@ -27,26 +28,33 @@ struct StridedData { } }; -template -static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t dim, - PyObject* fn, std::array strided_data) { +template +static void recursive_apply( + IntArrayRef sizes, + ScalarType scalarType, + int64_t dim, + PyObject* fn, + std::array strided_data) { int64_t ndim = sizes.size(); if (dim == ndim) { auto args = THPObjectPtr(PyTuple_New(N)); - if (!args) throw python_error(); - for(const auto i : c10::irange(N)) { + if (!args) + throw python_error(); + for (const auto i : c10::irange(N)) { PyObject* arg = load_scalar(strided_data[i].data, scalarType); - if (!arg) throw python_error(); + if (!arg) + throw python_error(); PyTuple_SET_ITEM(args.get(), i, arg); } auto ret = THPObjectPtr(PyObject_CallObject(fn, args.get())); - if (!ret) throw python_error(); + if (!ret) + throw python_error(); store_scalar(strided_data[0].data, scalarType, ret.get()); return; } auto n = sizes[dim]; - for(const auto i : c10::irange(n)) { + for (const auto i : c10::irange(n)) { (void)i; // Suppress unused variable warning recursive_apply(sizes, scalarType, dim + 1, fn, strided_data); for (auto& td : strided_data) { @@ -55,54 +63,71 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di } } -const Tensor & apply_(const Tensor & self, PyObject* fn) { +const Tensor& apply_(const Tensor& self, PyObject* fn) { if (self.is_meta()) { - return self; // Just skip + return self; // Just skip } if (!self.device().is_cpu()) { throw TypeError("apply_ is only implemented on CPU tensors"); } auto scalarType = self.scalar_type(); - recursive_apply<1>(self.sizes(), scalarType, 0, fn, {{ self }}); + recursive_apply<1>(self.sizes(), scalarType, 0, fn, {{self}}); return self; } -const Tensor & map_(const Tensor & self, const Tensor & other_, PyObject* fn) { +const Tensor& map_(const Tensor& self, const Tensor& other_, PyObject* fn) { if (!other_.options().type_equal(self.options())) { - throw TypeError("map_: expected %s for 'other' (got %s)", - self.toString().c_str(), other_.toString().c_str()); + throw TypeError( + "map_: expected %s for 'other' (got %s)", + self.toString().c_str(), + other_.toString().c_str()); } if (self.is_meta()) { - return self; // Just skip + return self; // Just skip } if (!self.device().is_cpu()) { throw TypeError("map_ is only implemented on CPU tensors"); } c10::MaybeOwned other = expand_inplace(self, other_, "map_"); auto scalarType = self.scalar_type(); - recursive_apply<2>(self.sizes(), scalarType, 0, fn, {{ self, *other }}); + recursive_apply<2>(self.sizes(), scalarType, 0, fn, {{self, *other}}); return self; } -const Tensor & map2_(const Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) { +const Tensor& map2_( + const Tensor& self, + const Tensor& x_, + const Tensor& y_, + PyObject* fn) { if (!x_.options().type_equal(self.options())) { - throw TypeError("map2_: expected %s for argument 'x' (got %s)", - self.toString().c_str(), x_.toString().c_str()); + throw TypeError( + "map2_: expected %s for argument 'x' (got %s)", + self.toString().c_str(), + x_.toString().c_str()); } if (!y_.options().type_equal(self.options())) { - throw TypeError("map2_: expected %s for argument 'y' (got %s)", - self.toString().c_str(), y_.toString().c_str()); + throw TypeError( + "map2_: expected %s for argument 'y' (got %s)", + self.toString().c_str(), + y_.toString().c_str()); } if (self.is_meta()) { - return self; // Just skip + return self; // Just skip } - if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) { + if (!self.device().is_cpu() || !x_.device().is_cpu() || + !y_.device().is_cpu()) { throw TypeError("map2_ is only implemented on CPU tensors"); } auto others = expand_inplace(self, x_, y_, "map2_"); auto scalarType = self.scalar_type(); - recursive_apply<3>(self.sizes(), scalarType, 0, fn, {{ self, *std::get<0>(others), *std::get<1>(others) }}); + recursive_apply<3>( + self.sizes(), + scalarType, + 0, + fn, + {{self, *std::get<0>(others), *std::get<1>(others)}}); return self; } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_apply.h b/torch/csrc/utils/tensor_apply.h index e460f29eb2b3a9..bd06e0f3e30b48 100644 --- a/torch/csrc/utils/tensor_apply.h +++ b/torch/csrc/utils/tensor_apply.h @@ -1,13 +1,21 @@ #pragma once -#include #include +#include -namespace torch { namespace utils { +namespace torch { +namespace utils { -const at::Tensor & apply_(const at::Tensor & self, PyObject* fn); -const at::Tensor & map_(const at::Tensor & self, const at::Tensor & other_, PyObject* fn); -const at::Tensor & map2_(const at::Tensor & self, const at::Tensor & x_, - const at::Tensor & y_, PyObject* fn); +const at::Tensor& apply_(const at::Tensor& self, PyObject* fn); +const at::Tensor& map_( + const at::Tensor& self, + const at::Tensor& other_, + PyObject* fn); +const at::Tensor& map2_( + const at::Tensor& self, + const at::Tensor& x_, + const at::Tensor& y_, + PyObject* fn); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 11a64b3bdf804c..3e0e3acf38c294 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -1,17 +1,16 @@ -#include #include #include #include #include #include #include +#include #include namespace torch { namespace utils { -std::pair getDtypeNames( - at::ScalarType scalarType) { +std::pair getDtypeNames(at::ScalarType scalarType) { switch (scalarType) { case at::ScalarType::Byte: // no "byte" because byte is signed in numpy and we overload @@ -72,7 +71,7 @@ void initializeDtypes() { for (at::ScalarType scalarType : all_scalar_types) { std::string primary_name, legacy_name; std::tie(primary_name, legacy_name) = getDtypeNames(scalarType); - PyObject *dtype = THPDtype_New(scalarType, primary_name); + PyObject* dtype = THPDtype_New(scalarType, primary_name); torch::registerDtypeObject((THPDtype*)dtype, scalarType); Py_INCREF(dtype); if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != diff --git a/torch/csrc/utils/tensor_dtypes.h b/torch/csrc/utils/tensor_dtypes.h index 3dbe2fef7f045f..32b769971d03fa 100644 --- a/torch/csrc/utils/tensor_dtypes.h +++ b/torch/csrc/utils/tensor_dtypes.h @@ -1,13 +1,15 @@ #pragma once -#include -#include #include +#include +#include -namespace torch { namespace utils { +namespace torch { +namespace utils { std::pair getDtypeNames(at::ScalarType scalarType); void initializeDtypes(); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_flatten.cpp b/torch/csrc/utils/tensor_flatten.cpp index aff08537ba10f1..6b0f6388b276d2 100644 --- a/torch/csrc/utils/tensor_flatten.cpp +++ b/torch/csrc/utils/tensor_flatten.cpp @@ -3,11 +3,11 @@ #include #include -namespace torch { namespace utils { +namespace torch { +namespace utils { using namespace at; - std::vector take_tensors( TensorList tensors, size_t size_limit, @@ -18,13 +18,13 @@ std::vector take_tensors( std::map groups; size_t cur_group_size = 0; - for (const auto & tensor : tensors) { + for (const auto& tensor : tensors) { size_t tensor_size = 0; if (tensor.is_sparse()) { const auto& indices = tensor._indices(); const auto& values = tensor._values(); tensor_size = indices.numel() * indices.element_size() + - values.numel() * indices.element_size(); + values.numel() * indices.element_size(); } else { tensor_size = tensor.numel() * tensor.element_size(); } @@ -72,10 +72,10 @@ void reorder_tensors_like(std::vector& tensors, TensorList order) { std::unordered_map type_id_to_type_used; std::vector ordered_tensors; ordered_tensors.reserve(tensors.size()); - for (auto & tmpl_tensor : order) { + for (auto& tmpl_tensor : order) { size_t tmpl_type_id = type_id(tmpl_tensor); - auto & indices = type_id_to_indices[tmpl_type_id]; - auto & used = type_id_to_type_used[tmpl_type_id]; + auto& indices = type_id_to_indices[tmpl_type_id]; + auto& used = type_id_to_type_used[tmpl_type_id]; ordered_tensors.push_back(tensors[indices[used++]]); } std::swap(tensors, ordered_tensors); @@ -91,31 +91,37 @@ at::Tensor get_values(const at::Tensor& t) { return t._values(); } -} +} // namespace -std::pair flatten_sparse_tensors(at::TensorList tensors) { +std::pair flatten_sparse_tensors( + at::TensorList tensors) { auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices)); auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values)); return std::make_pair(flat_indices, flat_values); } std::vector unflatten_sparse_tensors( - const at::Tensor& flat_indices, const at::Tensor& flat_values, - at::TensorList tensors) { - if (tensors.size() == 0) return {}; + const at::Tensor& flat_indices, + const at::Tensor& flat_values, + at::TensorList tensors) { + if (tensors.size() == 0) + return {}; - auto indices = utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices)); - auto values = utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values)); + auto indices = + utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices)); + auto values = + utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values)); std::vector outputs; outputs.reserve(tensors.size()); for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) { - auto &ref_t = tensors[i]; - auto t = at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes()); + auto& ref_t = tensors[i]; + auto t = + at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes()); outputs.emplace_back(t._coalesced_(ref_t.is_coalesced())); } return outputs; } - -}} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_flatten.h b/torch/csrc/utils/tensor_flatten.h index cc39f4129bf7f5..77b506a92adae1 100644 --- a/torch/csrc/utils/tensor_flatten.h +++ b/torch/csrc/utils/tensor_flatten.h @@ -1,12 +1,13 @@ #pragma once +#include #include +#include #include -#include #include -#include -namespace torch { namespace utils { +namespace torch { +namespace utils { /// Generate an ID for a combination of tensor backend + scalar type to be used /// when ordering tensors ('like' tensors are grouped by pulling out their @@ -21,7 +22,9 @@ inline at::Tensor flatten_dense_tensors(at::TensorList tensors) { return at::flatten_dense_tensors(tensors); } -inline std::vector unflatten_dense_tensors(const at::Tensor& flat, at::TensorList tensors) { +inline std::vector unflatten_dense_tensors( + const at::Tensor& flat, + at::TensorList tensors) { return at::unflatten_dense_tensors(flat, tensors); } @@ -68,13 +71,17 @@ TORCH_API std::vector take_tensors( size_t size_limit, bool fine_grained = false); -TORCH_API void reorder_tensors_like(std::vector& tensors, at::TensorList order); +TORCH_API void reorder_tensors_like( + std::vector& tensors, + at::TensorList order); -TORCH_API std::pair flatten_sparse_tensors(at::TensorList tensors); +TORCH_API std::pair flatten_sparse_tensors( + at::TensorList tensors); TORCH_API std::vector unflatten_sparse_tensors( const at::Tensor& flat_indices, const at::Tensor& flat_values, at::TensorList tensors); -}} +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp index 8c87c0cc7788f5..debce6ff662734 100644 --- a/torch/csrc/utils/tensor_layouts.cpp +++ b/torch/csrc/utils/tensor_layouts.cpp @@ -7,29 +7,33 @@ #include #include -namespace torch { namespace utils { - -#define REGISTER_LAYOUT(layout, LAYOUT) \ - PyObject* layout##_layout = \ - THPLayout_New(at::Layout::LAYOUT, "torch." # layout); \ - Py_INCREF(layout##_layout); \ - if (PyModule_AddObject(torch_module, "" # layout, layout##_layout) != 0) { \ - throw python_error(); \ - } \ - registerLayoutObject((THPLayout*)layout##_layout, at::Layout::LAYOUT); +namespace torch { +namespace utils { + +#define REGISTER_LAYOUT(layout, LAYOUT) \ + PyObject* layout##_layout = \ + THPLayout_New(at::Layout::LAYOUT, "torch." #layout); \ + Py_INCREF(layout##_layout); \ + if (PyModule_AddObject(torch_module, "" #layout, layout##_layout) != 0) { \ + throw python_error(); \ + } \ + registerLayoutObject((THPLayout*)layout##_layout, at::Layout::LAYOUT); void initializeLayouts() { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); - if (!torch_module) throw python_error(); + if (!torch_module) + throw python_error(); - PyObject* strided_layout = THPLayout_New(at::Layout::Strided, "torch.strided"); + PyObject* strided_layout = + THPLayout_New(at::Layout::Strided, "torch.strided"); Py_INCREF(strided_layout); if (PyModule_AddObject(torch_module, "strided", strided_layout) != 0) { throw python_error(); } registerLayoutObject((THPLayout*)strided_layout, at::Layout::Strided); - PyObject* sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo"); + PyObject* sparse_coo_layout = + THPLayout_New(at::Layout::Sparse, "torch.sparse_coo"); Py_INCREF(sparse_coo_layout); if (PyModule_AddObject(torch_module, "sparse_coo", sparse_coo_layout) != 0) { throw python_error(); @@ -49,4 +53,5 @@ void initializeLayouts() { registerLayoutObject((THPLayout*)mkldnn_layout, at::Layout::Mkldnn); } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_layouts.h b/torch/csrc/utils/tensor_layouts.h index 57142d54ada7bb..33e32b516b1215 100644 --- a/torch/csrc/utils/tensor_layouts.h +++ b/torch/csrc/utils/tensor_layouts.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace utils { +namespace torch { +namespace utils { void initializeLayouts(); -}} // namespace torch::utils +} +} // namespace torch diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index 63cffd438afa1d..492500992e5411 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -7,22 +7,29 @@ using namespace at; -namespace torch { namespace utils { +namespace torch { +namespace utils { static PyObject* recursive_to_list( - char* data, IntArrayRef sizes, IntArrayRef strides, int64_t dim, - ScalarType scalarType, int64_t elementSize) -{ + char* data, + IntArrayRef sizes, + IntArrayRef strides, + int64_t dim, + ScalarType scalarType, + int64_t elementSize) { int64_t ndim = sizes.size(); if (dim == ndim) { return torch::utils::load_scalar(data, scalarType); } auto n = sizes[dim]; auto list = THPObjectPtr(PyList_New(n)); - if (!list) throw python_error(); - for(const auto i : c10::irange(n)) { - PyObject* obj = recursive_to_list(data, sizes, strides, dim + 1, scalarType, elementSize); - if (!obj) throw python_error(); + if (!list) + throw python_error(); + for (const auto i : c10::irange(n)) { + PyObject* obj = recursive_to_list( + data, sizes, strides, dim + 1, scalarType, elementSize); + if (!obj) + throw python_error(); PyList_SET_ITEM(list.get(), i, obj); auto advance_data_ptr = strides[dim] * elementSize; TORCH_INTERNAL_ASSERT(data || (advance_data_ptr == 0)); @@ -32,16 +39,25 @@ static PyObject* recursive_to_list( } PyObject* tensor_to_list(const Tensor& tensor) { - TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses."); + TORCH_CHECK( + !tensor.unsafeGetTensorImpl()->is_python_dispatch(), + ".tolist() is not supported for tensor subclasses."); Tensor data = tensor.resolve_conj().resolve_neg(); if (!data.device().is_cpu()) { pybind11::gil_scoped_release no_gil; data = data.toBackend(Backend::CPU); } - TORCH_CHECK(tensor.numel() == 0 || data.data_ptr(), "tolist() shouldn't be called on a tensor with unallocated storage"); + TORCH_CHECK( + tensor.numel() == 0 || data.data_ptr(), + "tolist() shouldn't be called on a tensor with unallocated storage"); return recursive_to_list( - (char*)data.data_ptr(), data.sizes(), data.strides(), 0, - data.scalar_type(), tensor.numel() == 0 ? 0 : data.dtype().itemsize()); + (char*)data.data_ptr(), + data.sizes(), + data.strides(), + 0, + data.scalar_type(), + tensor.numel() == 0 ? 0 : data.dtype().itemsize()); } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_list.h b/torch/csrc/utils/tensor_list.h index 7f3de6c23fa03e..8ae77df4700af5 100644 --- a/torch/csrc/utils/tensor_list.h +++ b/torch/csrc/utils/tensor_list.h @@ -6,8 +6,10 @@ namespace at { class Tensor; } -namespace torch { namespace utils { +namespace torch { +namespace utils { PyObject* tensor_to_list(const at::Tensor& tensor); -}} // namespace torch::utils +} +} // namespace torch diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp index 07c480e7ee8202..1faecd5e62346a 100644 --- a/torch/csrc/utils/tensor_memoryformats.cpp +++ b/torch/csrc/utils/tensor_memoryformats.cpp @@ -1,24 +1,25 @@ #include +#include #include #include #include -#include #include #include - namespace torch { namespace utils { namespace { // Intentionally leaked -std::array(at::MemoryFormat::NumOptions)> memory_format_registry = {}; +std::array(at::MemoryFormat::NumOptions)> + memory_format_registry = {}; } // anonymous namespace py::object getTHPMemoryFormat(at::MemoryFormat memory_format) { - return py::reinterpret_borrow(memory_format_registry[static_cast(memory_format)]); + return py::reinterpret_borrow( + memory_format_registry[static_cast(memory_format)]); } void initializeMemoryFormats() { @@ -43,7 +44,6 @@ void initializeMemoryFormats() { add_memory_format(at::MemoryFormat::Contiguous, "contiguous_format"); add_memory_format(at::MemoryFormat::ChannelsLast, "channels_last"); add_memory_format(at::MemoryFormat::ChannelsLast3d, "channels_last_3d"); - } } // namespace utils diff --git a/torch/csrc/utils/tensor_memoryformats.h b/torch/csrc/utils/tensor_memoryformats.h index e078d540bd96af..e9aa1af1f64787 100644 --- a/torch/csrc/utils/tensor_memoryformats.h +++ b/torch/csrc/utils/tensor_memoryformats.h @@ -1,11 +1,13 @@ #pragma once -#include #include +#include -namespace torch { namespace utils { +namespace torch { +namespace utils { void initializeMemoryFormats(); py::object getTHPMemoryFormat(c10::MemoryFormat); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 2e19a5e6ba5002..9f93191c7f3136 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -13,21 +14,20 @@ #include #include #include -#include #include #include -#include #include #include #include #include +#include #include #include #include #include -#include #include +#include #include #include @@ -37,8 +37,8 @@ using at::Device; using at::IntArrayRef; using at::kCPU; using at::kCUDA; -using at::kLong; using at::kInt; +using at::kLong; using at::Scalar; using at::ScalarType; using at::Storage; @@ -47,11 +47,15 @@ using at::TensorOptions; using at::Type; using c10::optional; -namespace torch { namespace utils { +namespace torch { +namespace utils { namespace { const int MAX_DIMS = 128; -TensorOptions build_options(c10::TensorOptions options, at::ScalarType scalar_type, const c10::optional& device=c10::nullopt) { +TensorOptions build_options( + c10::TensorOptions options, + at::ScalarType scalar_type, + const c10::optional& device = c10::nullopt) { options = options.dtype(scalar_type); if (device.has_value()) { return options.device(device); @@ -65,18 +69,25 @@ void maybe_initialize_cuda(const Device device) { } } -// NB: It appears there is some consistency invariant between options and device, where -// if device is non-empty, its type must be consistent with the device type in -// options. +// NB: It appears there is some consistency invariant between options and +// device, where if device is non-empty, its type must be consistent with the +// device type in options. // TODO: Refactor this so we just pass everything in via options -Tensor new_with_sizes(c10::TensorOptions options, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { +Tensor new_with_sizes( + c10::TensorOptions options, + at::ScalarType scalar_type, + const optional& device, + IntArrayRef sizes) { maybe_initialize_cuda(options.device()); pybind11::gil_scoped_release no_gil; return torch::empty(sizes, build_options(options, scalar_type, device)); } -Tensor new_with_storage(c10::TensorOptions options, at::ScalarType scalar_type, Storage storage) { +Tensor new_with_storage( + c10::TensorOptions options, + at::ScalarType scalar_type, + Storage storage) { auto tensor = at::empty({}, build_options(options, scalar_type)); tensor.set_(std::move(storage)); return tensor; @@ -88,7 +99,8 @@ std::vector compute_sizes(PyObject* seq, ScalarType scalar_type) { THPObjectPtr handle; while (PySequence_Check(seq)) { auto length = PySequence_Length(seq); - if (length < 0) throw python_error(); + if (length < 0) + throw python_error(); if (is_storage) { length /= elementSize(scalar_type); } @@ -96,10 +108,13 @@ std::vector compute_sizes(PyObject* seq, ScalarType scalar_type) { if (sizes.size() > MAX_DIMS) { throw ValueError("too many dimensions '%s'", Py_TYPE(seq)->tp_name); } - if (length == 0) break; + if (length == 0) + break; handle = THPObjectPtr(PySequence_GetItem(seq, 0)); if (!handle) { - throw ValueError("could not determine the shape of object type '%s'", Py_TYPE(seq)->tp_name); + throw ValueError( + "could not determine the shape of object type '%s'", + Py_TYPE(seq)->tp_name); } seq = handle.get(); } @@ -107,7 +122,7 @@ std::vector compute_sizes(PyObject* seq, ScalarType scalar_type) { return sizes; } -ScalarType infer_scalar_type(PyObject *obj) { +ScalarType infer_scalar_type(PyObject* obj) { #ifdef USE_NUMPY if (is_numpy_available()) { if (PyArray_Check(obj)) { @@ -115,13 +130,14 @@ ScalarType infer_scalar_type(PyObject *obj) { } if (PyArray_CheckScalar(obj)) { THPObjectPtr arr(PyArray_FromScalar(obj, nullptr)); - return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*) arr.get())); + return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)arr.get())); } } #endif if (PyFloat_Check(obj)) { // this is always guaranteed to be a floating-point type, and makes it more - // convenient to write e.g. torch.tensor(0.) than torch.tensor(0., dtype=torch.Tensor.dtype). + // convenient to write e.g. torch.tensor(0.) than torch.tensor(0., + // dtype=torch.Tensor.dtype). return torch::tensors::get_default_scalar_type(); } if (THPUtils_checkLong(obj)) { @@ -132,9 +148,12 @@ ScalarType infer_scalar_type(PyObject *obj) { } if (PyComplex_Check(obj)) { switch (torch::tensors::get_default_scalar_type()) { - case ScalarType::Float: return ScalarType::ComplexFloat; - case ScalarType::Double: return ScalarType::ComplexDouble; - default: TORCH_CHECK(false, "invalid default scalar type for complex"); + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + default: + TORCH_CHECK(false, "invalid default scalar type for complex"); } } if (THPVariable_Check(obj)) { @@ -147,19 +166,24 @@ ScalarType infer_scalar_type(PyObject *obj) { if (PySequence_Check(obj)) { c10::optional scalarType; auto length = PySequence_Length(obj); - if (length < 0) throw python_error(); + if (length < 0) + throw python_error(); // match NumPy semantics, except use default tensor type instead of double. - if (length == 0) return torch::tensors::get_default_scalar_type(); + if (length == 0) + return torch::tensors::get_default_scalar_type(); for (const auto i : c10::irange(length)) { THPObjectPtr handle(PySequence_GetItem(obj, i)); - if (!handle) throw python_error(); + if (!handle) + throw python_error(); auto cur_item = handle.get(); - if (cur_item == obj) throw TypeError("new(): self-referential lists are incompatible"); + if (cur_item == obj) + throw TypeError("new(): self-referential lists are incompatible"); ScalarType item_scalarType = infer_scalar_type(cur_item); - scalarType = (scalarType) ? - at::promoteTypes(*scalarType, item_scalarType) : item_scalarType; + scalarType = (scalarType) ? at::promoteTypes(*scalarType, item_scalarType) + : item_scalarType; if (scalarType == ScalarType::ComplexDouble) { - // this won't change (unless we hit undefined, but that will fail later). + // this won't change (unless we hit undefined, but that will fail + // later). return *scalarType; } } @@ -168,8 +192,14 @@ ScalarType infer_scalar_type(PyObject *obj) { AT_ERROR("Could not infer dtype of ", Py_TYPE(obj)->tp_name); } -void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t dim, - ScalarType scalarType, int elementSize, PyObject* obj) { +void recursive_store( + char* data, + IntArrayRef sizes, + IntArrayRef strides, + int64_t dim, + ScalarType scalarType, + int elementSize, + PyObject* obj) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr); int64_t ndim = sizes.size(); @@ -180,25 +210,30 @@ void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t auto n = sizes[dim]; auto seq = THPObjectPtr(PySequence_Fast(obj, "not a sequence")); - if (!seq) throw python_error(); + if (!seq) + throw python_error(); // NOLINTNEXTLINE(bugprone-branch-clone) auto seq_size = PySequence_Fast_GET_SIZE(seq.get()); if (seq_size != n) { - throw ValueError("expected sequence of length %lld at dim %lld (got %lld)", - (long long)n, (long long)dim, (long long)seq_size); + throw ValueError( + "expected sequence of length %lld at dim %lld (got %lld)", + (long long)n, + (long long)dim, + (long long)seq_size); } PyObject** items = PySequence_Fast_ITEMS(seq.get()); - for(const auto i : c10::irange(n)) { + for (const auto i : c10::irange(n)) { #ifdef USE_NUMPY if (is_numpy_available() && PyArray_Check(items[i])) { TORCH_WARN_ONCE( - "Creating a tensor from a list of numpy.ndarrays is extremely slow. " - "Please consider converting the list to a single numpy.ndarray with " - "numpy.array() before converting to a tensor."); + "Creating a tensor from a list of numpy.ndarrays is extremely slow. " + "Please consider converting the list to a single numpy.ndarray with " + "numpy.array() before converting to a tensor."); } #endif - recursive_store(data, sizes, strides, dim + 1, scalarType, elementSize, items[i]); + recursive_store( + data, sizes, strides, dim + 1, scalarType, elementSize, items[i]); data += strides[dim] * elementSize; } } @@ -223,34 +258,53 @@ Tensor internal_new_from_data( if (copy_variables) { var = var.detach(); } - // infer the scalar type and device type; it's not expected to infer the layout since these constructors - // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor). - const auto& inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type; + // infer the scalar type and device type; it's not expected to infer the + // layout since these constructors are defined per-layout-type (e.g. tensor + // vs sparse_coo_tensor). + const auto& inferred_scalar_type = + type_inference ? var.scalar_type() : scalar_type; auto device = device_opt.has_value() ? *device_opt : var.device(); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); - return var.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables); + return var.to( + device, + inferred_scalar_type, + /*non_blocking=*/false, + /*copy=*/copy_variables); } #ifdef USE_NUMPY if (PyObject_HasAttrString(data, "__cuda_array_interface__")) { - TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from __cuda_array_interface__"); + TORCH_CHECK( + !pin_memory, + "Can't pin tensor constructed from __cuda_array_interface__"); auto tensor = tensor_from_cuda_array_interface(data); - const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; + const auto& inferred_scalar_type = + type_inference ? tensor.scalar_type() : scalar_type; auto device = device_opt.has_value() ? *device_opt : options.device(); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); - return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy); + return tensor.to( + device, + inferred_scalar_type, + /*non_blocking=*/false, + /*copy=*/copy_numpy); } if (is_numpy_available() && PyArray_Check(data)) { TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy"); - auto tensor = tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy); - const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; + auto tensor = + tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy); + const auto& inferred_scalar_type = + type_inference ? tensor.scalar_type() : scalar_type; auto device = device_opt.has_value() ? *device_opt : options.device(); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); - return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy); + return tensor.to( + device, + inferred_scalar_type, + /*non_blocking=*/false, + /*copy=*/copy_numpy); } #endif @@ -258,57 +312,84 @@ Tensor internal_new_from_data( auto sizes = compute_sizes(data, scalar_type); - ScalarType inferred_scalar_type = type_inference ? infer_scalar_type(data) : scalar_type; + ScalarType inferred_scalar_type = + type_inference ? infer_scalar_type(data) : scalar_type; // This exists to prevent us from tracing the call to empty(). The actual // autograd code doesn't really matter, because requires_grad is always false // here. // What are the semantics of tensor_new()? - // We manually construct a tensor and place on it on the correct device with empty() and to(). - // We then have to "lift" the newly constructed tensor in some cases, like when we're performing - // a functorch transform or running functionalization. - // The exclude guards are all to ensure that extra logic doesn't run when we're constructing the raw tensor. + // We manually construct a tensor and place on it on the correct device with + // empty() and to(). We then have to "lift" the newly constructed tensor in + // some cases, like when we're performing a functorch transform or running + // functionalization. The exclude guards are all to ensure that extra logic + // doesn't run when we're constructing the raw tensor. Tensor tensor; { at::AutoDispatchBelowADInplaceOrView guard; - c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_guard(c10::DispatchKey::Python); - c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_snapshot_guard(c10::DispatchKey::PythonTLSSnapshot); + c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_guard( + c10::DispatchKey::Python); + c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_snapshot_guard( + c10::DispatchKey::PythonTLSSnapshot); // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all // tensors returned from operators in special TensorWrapper tensor extension - c10::impl::ExcludeDispatchKeyGuard functorch_front_guard(c10::DispatchKey::FuncTorchDynamicLayerFrontMode); - c10::impl::ExcludeDispatchKeyGuard functorch_back_guard(c10::DispatchKey::FuncTorchDynamicLayerBackMode); + c10::impl::ExcludeDispatchKeyGuard functorch_front_guard( + c10::DispatchKey::FuncTorchDynamicLayerFrontMode); + c10::impl::ExcludeDispatchKeyGuard functorch_back_guard( + c10::DispatchKey::FuncTorchDynamicLayerBackMode); // We disable DeferredInit handler for similar reasons as functorch. - c10::impl::ExcludeDispatchKeyGuard deferred_init_guard(c10::DispatchKey::DeferredInit); + c10::impl::ExcludeDispatchKeyGuard deferred_init_guard( + c10::DispatchKey::DeferredInit); // Note [Functionalization <> torch.Tensor constructor] - // Functionalization "lifts" the newly constructed tensor into a wrapper using aten::lift(). - c10::impl::ExcludeDispatchKeyGuard functionalize_guard(c10::DispatchKey::Functionalize); + // Functionalization "lifts" the newly constructed tensor into a wrapper + // using aten::lift(). + c10::impl::ExcludeDispatchKeyGuard functionalize_guard( + c10::DispatchKey::Functionalize); { - // Tracing should probably also use the "lift" operator to add the tensor to a trace, - // but it's technically BC-breaking to do that, since we currently trace .to() calls. + // Tracing should probably also use the "lift" operator to add the tensor + // to a trace, but it's technically BC-breaking to do that, since we + // currently trace .to() calls. at::tracer::impl::NoTracerDispatchMode tracer_guard; if (isStorage(data)) { ScalarType storage_scalar_type; bool is_typed_storage = false; - Storage storage = createStorageGetType(data, storage_scalar_type, is_typed_storage); - TORCH_CHECK(!is_typed_storage || storage_scalar_type == scalar_type, - "Expected a Storage of type ", scalar_type, - " or an _UntypedStorage, but got ", storage_scalar_type); - tensor = at::empty(sizes, at::initialTensorOptions().dtype(is_typed_storage ? storage_scalar_type : inferred_scalar_type).pinned_memory(pin_memory).device(storage.device())); + Storage storage = + createStorageGetType(data, storage_scalar_type, is_typed_storage); + TORCH_CHECK( + !is_typed_storage || storage_scalar_type == scalar_type, + "Expected a Storage of type ", + scalar_type, + " or an _UntypedStorage, but got ", + storage_scalar_type); + tensor = at::empty( + sizes, + at::initialTensorOptions() + .dtype( + is_typed_storage ? storage_scalar_type + : inferred_scalar_type) + .pinned_memory(pin_memory) + .device(storage.device())); tensor.set_(storage); } else { - TensorOptions opts = at::initialTensorOptions().dtype(inferred_scalar_type); + TensorOptions opts = + at::initialTensorOptions().dtype(inferred_scalar_type); - // If the device is Meta, take the shortcut. We don't want to allocate an - // empty CPU tensor which would break our contract for meta tensors. + // If the device is Meta, take the shortcut. We don't want to allocate + // an empty CPU tensor which would break our contract for meta tensors. if (device == at::kMeta) { return at::empty(sizes, opts.device(device)); } tensor = at::empty(sizes, opts.pinned_memory(pin_memory)); if (c10::multiply_integers(tensor.sizes()) != 0) { recursive_store( - (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0, - inferred_scalar_type, tensor.dtype().itemsize(), data); + (char*)tensor.data_ptr(), + tensor.sizes(), + tensor.strides(), + 0, + inferred_scalar_type, + tensor.dtype().itemsize(), + data); } } } @@ -322,13 +403,15 @@ Tensor internal_new_from_data( // "no observable data dependence". In an ideal world, we wouldn't trace // a to() call but I need to think harder about what exactly we should trace // in this case. - tensor = tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false); + tensor = tensor.to( + device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false); } - // torch.jit.trace will continue to trace out `.to()` instead of `.lift()`, since - // changing it is BC-breaking. + // torch.jit.trace will continue to trace out `.to()` instead of `.lift()`, + // since changing it is BC-breaking. at::tracer::impl::NoTracerDispatchMode tracer_guard; - // lift has no autograd implementation, so we need to make sure we don't try to dispatch to it. + // lift has no autograd implementation, so we need to make sure we don't try + // to dispatch to it. at::AutoDispatchBelowADInplaceOrView guard; return tensor.lift(); } @@ -338,9 +421,14 @@ Tensor new_from_data_copy( at::ScalarType scalar_type, c10::optional device, PyObject* data) { - return internal_new_from_data(options, scalar_type, device, data, - /*copy_variables=*/true, /*copy_numpy=*/true, - /*type_inference=*/false); + return internal_new_from_data( + options, + scalar_type, + device, + data, + /*copy_variables=*/true, + /*copy_numpy=*/true, + /*type_inference=*/false); } Tensor legacy_new_from_sequence( @@ -349,17 +437,25 @@ Tensor legacy_new_from_sequence( c10::optional device, PyObject* data) { if (!PySequence_Check(data)) { - throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name); + throw TypeError( + "new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name); } - return internal_new_from_data(options, scalar_type, device, data, - /*copy_variables=*/false, /*copy_numpy=*/false, - /*type_inference=*/false); + return internal_new_from_data( + options, + scalar_type, + device, + data, + /*copy_variables=*/false, + /*copy_numpy=*/false, + /*type_inference=*/false); } -// "base" here refers to the Tensor type on which the function was invoked, e.g.: -// in x.new(y), 'x' is the base. +// "base" here refers to the Tensor type on which the function was invoked, +// e.g.: in x.new(y), 'x' is the base. // TODO: Rewrite this using dispatchKeyToTensorOptions -void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_layout) { +void check_base_legacy_new( + c10::DispatchKey dispatch_key, + at::Layout expected_layout) { if (expected_layout == c10::kStrided) { constexpr c10::DispatchKeySet expected_key_set({ c10::DispatchKey::CPU, @@ -372,12 +468,13 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la c10::DispatchKey::HPU, c10::DispatchKey::MPS, }); - TORCH_CHECK(expected_key_set.has(dispatch_key), + TORCH_CHECK( + expected_key_set.has(dispatch_key), "new(): expected key in ", expected_key_set, " but got: ", dispatch_key); - } else if(expected_layout == c10::kSparse) { + } else if (expected_layout == c10::kSparse) { // NOTE: no sparse XLA or Lazy constexpr c10::DispatchKeySet expected_key_set({ c10::DispatchKey::SparseCPU, @@ -385,7 +482,8 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la c10::DispatchKey::SparseHIP, c10::DispatchKey::SparseXPU, }); - TORCH_CHECK(expected_key_set.has(dispatch_key), + TORCH_CHECK( + expected_key_set.has(dispatch_key), "new(): expected key in ", expected_key_set, " but got: ", @@ -396,11 +494,17 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la } // TODO: Make this accept options instead of dispatch key -void check_legacy_ctor_device(c10::DispatchKey dispatch_key, c10::optional device) { +void check_legacy_ctor_device( + c10::DispatchKey dispatch_key, + c10::optional device) { if (device.has_value()) { - TORCH_CHECK(dispatchKeyToDeviceType(dispatch_key) == device.value().type(), - "legacy constructor expects device type: ", dispatchKeyToDeviceType(dispatch_key), - " but device type: ", device.value().type(), " was passed"); + TORCH_CHECK( + dispatchKeyToDeviceType(dispatch_key) == device.value().type(), + "legacy constructor expects device type: ", + dispatchKeyToDeviceType(dispatch_key), + " but device type: ", + device.value().type(), + " was passed"); } } @@ -410,16 +514,22 @@ enum class CtorOrNew { NEW, }; -Tensor legacy_sparse_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs, CtorOrNew ctor_or_new) { +Tensor legacy_sparse_tensor_generic_ctor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs, + CtorOrNew ctor_or_new) { auto options = dispatchKeyToTensorOptions(dispatch_key); static PythonArgParser parser({ - "new(*, Device? device=None)", - "new(*, int64_t cdata)|hidden", - "new(Tensor indices, Tensor values, *, Device? device=None)", - "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)", - "new(IntArrayRef size, *, Device? device=None)", + "new(*, Device? device=None)", + "new(*, int64_t cdata)|hidden", + "new(Tensor indices, Tensor values, *, Device? device=None)", + "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)", + "new(IntArrayRef size, *, Device? device=None)", }); - if (ctor_or_new == CtorOrNew::NEW) check_base_legacy_new(dispatch_key, c10::kSparse); + if (ctor_or_new == CtorOrNew::NEW) + check_base_legacy_new(dispatch_key, c10::kSparse); ParsedArgs<4> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { @@ -430,15 +540,15 @@ Tensor legacy_sparse_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at:: auto cdata = reinterpret_cast(r.toInt64(0)); return at::unsafeTensorFromTH(cdata, true); } else if (r.idx == 2) { - // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't - // have a device (we should infer it). + // Note: this signature doesn't have a dtype, even though it has a device; + // it probably shouldn't have a device (we should infer it). auto deviceOptional = r.deviceOptional(2); check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); return at::sparse_coo_tensor(r.tensor(0), r.tensor(1)); } else if (r.idx == 3) { - // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't - // have a device (we should infer it). + // Note: this signature doesn't have a dtype, even though it has a device; + // it probably shouldn't have a device (we should infer it). auto deviceOptional = r.deviceOptional(3); check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); @@ -447,24 +557,31 @@ Tensor legacy_sparse_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at:: PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(dispatch_key, deviceOptional); - if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { + if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && + arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size if (ctor_or_new == CtorOrNew::CTOR) { - throw TypeError("torch.SparseTensor(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " \ - "or construct a strided tensor and convert it to sparse via to_sparse."); + throw TypeError( + "torch.SparseTensor(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " + "or construct a strided tensor and convert it to sparse via to_sparse."); } else { - throw TypeError("SparseTensor.new(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " \ - "or construct a strided tensor and convert it to sparse via to_sparse."); + throw TypeError( + "SparseTensor.new(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " + "or construct a strided tensor and convert it to sparse via to_sparse."); } } - return new_with_sizes(options, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes( + options, scalar_type, r.deviceOptional(1), r.intlist(0)); } throw std::runtime_error("new(): invalid arguments"); } // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs -c10::TensorOptions typeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::DispatchKey dispatch_key) { +c10::TensorOptions typeIdWithDefault( + PythonArgs& r, + int64_t device_idx, + c10::DispatchKey dispatch_key) { auto options = dispatchKeyToTensorOptions(dispatch_key); if (!r.isNone(device_idx)) { // TODO: This line doesn't seem to be exercised at all in tests @@ -475,25 +592,35 @@ c10::TensorOptions typeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::Dis } // namespace -Tensor legacy_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs, CtorOrNew ctor_or_new) { +Tensor legacy_tensor_generic_ctor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs, + CtorOrNew ctor_or_new) { auto options = dispatchKeyToTensorOptions(dispatch_key); static PythonArgParser parser({ - "new(*, Device? device=None)", - "new(Storage storage)", - "new(*, int64_t cdata)|hidden", - // This constructor is no longer legacy, it will also be usable for - // subclass initialization - "new(Tensor other)", - "new(Tensor other, *, Device? device=None)|hidden", // prevent Tensor matching with IntArrayRef, PyObject* - "new(IntArrayRef size, *, Device? device=None)", - "new(PyObject* data, *, Device? device=None)", + "new(*, Device? device=None)", + "new(Storage storage)", + "new(*, int64_t cdata)|hidden", + // This constructor is no longer legacy, it will also be usable for + // subclass initialization + "new(Tensor other)", + "new(Tensor other, *, Device? device=None)|hidden", // prevent Tensor + // matching with + // IntArrayRef, + // PyObject* + "new(IntArrayRef size, *, Device? device=None)", + "new(PyObject* data, *, Device? device=None)", }); if (isSparse(dispatchKeyToBackend(dispatch_key))) { - return legacy_sparse_tensor_generic_ctor_new(dispatch_key, scalar_type, args, kwargs, ctor_or_new); + return legacy_sparse_tensor_generic_ctor_new( + dispatch_key, scalar_type, args, kwargs, ctor_or_new); } - if (ctor_or_new == CtorOrNew::NEW) check_base_legacy_new(dispatch_key, c10::kStrided); + if (ctor_or_new == CtorOrNew::NEW) + check_base_legacy_new(dispatch_key, c10::kStrided); ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -508,10 +635,12 @@ Tensor legacy_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at::ScalarT at::Storage storage = r.storage(0, storage_scalar_type, is_typed_storage); if (storage_scalar_type != at::ScalarType::Undefined && is_typed_storage) { TORCH_CHECK( - storage_scalar_type == scalar_type, - "Expected a Storage of type ", scalar_type, - " or an _UntypedStorage, but got type ", storage_scalar_type, - " for argument 1 'storage'"); + storage_scalar_type == scalar_type, + "Expected a Storage of type ", + scalar_type, + " or an _UntypedStorage, but got type ", + storage_scalar_type, + " for argument 1 'storage'"); } return new_with_storage(options, scalar_type, storage); } else if (r.idx == 2) { @@ -523,32 +652,45 @@ Tensor legacy_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at::ScalarT // dtype; previously it was "float" biased if (ctor_or_new != CtorOrNew::BASE_CTOR) { options = options.dtype(scalar_type); - TORCH_CHECK_TYPE(other.options().type_equal(options), "expected ", - options, " (got ", other.options(), ")"); + TORCH_CHECK_TYPE( + other.options().type_equal(options), + "expected ", + options, + " (got ", + other.options(), + ")"); } return other.alias(); } else if (r.idx == 4) { if (ctor_or_new == CtorOrNew::CTOR || ctor_or_new == CtorOrNew::BASE_CTOR) { - TORCH_CHECK(false, "Legacy tensor constructor of the form torch.Tensor(tensor, device=device) " \ - "is not supported. Use torch.tensor(...) or torch.as_tensor(...) instead."); + TORCH_CHECK( + false, + "Legacy tensor constructor of the form torch.Tensor(tensor, device=device) " + "is not supported. Use torch.tensor(...) or torch.as_tensor(...) instead."); } else { - TORCH_CHECK(false, "Legacy tensor new of the form tensor.new(tensor, device=device) " \ - "is not supported. Use torch.as_tensor(...) instead."); + TORCH_CHECK( + false, + "Legacy tensor new of the form tensor.new(tensor, device=device) " + "is not supported. Use torch.as_tensor(...) instead."); } } else if (r.idx == 5) { PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(dispatch_key, deviceOptional); - if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { + if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && + arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size - return legacy_new_from_sequence(options, scalar_type, deviceOptional, r.pyobject(0)); + return legacy_new_from_sequence( + options, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(options, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes( + options, scalar_type, r.deviceOptional(1), r.intlist(0)); } else if (r.idx == 6) { auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(dispatch_key, deviceOptional); - return legacy_new_from_sequence(options, scalar_type, deviceOptional, r.pyobject(0)); + return legacy_new_from_sequence( + options, scalar_type, deviceOptional, r.pyobject(0)); } throw std::runtime_error("new(): invalid arguments"); } @@ -559,21 +701,32 @@ Tensor legacy_tensor_generic_ctor_new(c10::DispatchKey dispatch_key, at::ScalarT // match the default) Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs) { return legacy_tensor_generic_ctor_new( - torch::tensors::get_default_dispatch_key(), - torch::tensors::get_default_scalar_type(), - args, kwargs, CtorOrNew::BASE_CTOR - ); + torch::tensors::get_default_dispatch_key(), + torch::tensors::get_default_scalar_type(), + args, + kwargs, + CtorOrNew::BASE_CTOR); } // Handles calls like torch.DoubleTensor, torch.cuda.FloatTensor, // torch.sparse.FloatTensor, etc. -Tensor legacy_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - return legacy_tensor_generic_ctor_new(dispatch_key, scalar_type, args, kwargs, CtorOrNew::CTOR); +Tensor legacy_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + return legacy_tensor_generic_ctor_new( + dispatch_key, scalar_type, args, kwargs, CtorOrNew::CTOR); } // Handles tensor.new(...) -Tensor legacy_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - return legacy_tensor_generic_ctor_new(dispatch_key, scalar_type, args, kwargs, CtorOrNew::NEW); +Tensor legacy_tensor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + return legacy_tensor_generic_ctor_new( + dispatch_key, scalar_type, args, kwargs, CtorOrNew::NEW); } Tensor indexing_tensor_from_data( @@ -584,43 +737,60 @@ Tensor indexing_tensor_from_data( // Specific to tensor indexing, converts an indexing list to an // indexing tensor (type Byte or Long) ScalarType inferred_scalar_type = infer_scalar_type(data); - if (inferred_scalar_type == ScalarType::Byte || inferred_scalar_type == ScalarType::Bool) { - return internal_new_from_data(options, inferred_scalar_type, device, data, - /*copy_variables=*/false, /*copy_numpy=*/false, - /*type_inference=*/false); + if (inferred_scalar_type == ScalarType::Byte || + inferred_scalar_type == ScalarType::Bool) { + return internal_new_from_data( + options, + inferred_scalar_type, + device, + data, + /*copy_variables=*/false, + /*copy_numpy=*/false, + /*type_inference=*/false); } else { - return internal_new_from_data(options, scalar_type, device, data, - /*copy_variables=*/false, /*copy_numpy=*/false, - /*type_inference=*/false); + return internal_new_from_data( + options, + scalar_type, + device, + data, + /*copy_variables=*/false, + /*copy_numpy=*/false, + /*type_inference=*/false); } } -Tensor sparse_compressed_tensor_ctor_worker(std::string name, c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r, c10::optional required_layout) { +Tensor sparse_compressed_tensor_ctor_worker( + std::string name, + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r, + c10::optional required_layout) { TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); enum { - ARG_COMPRESSED_INDICES = 0, - ARG_PLAIN_INDICES, - ARG_VALUES, - ARG_SIZE, - ARG_TYPE, - ARG_LAYOUT, - ARG_DEVICE, - ARG_PIN_MEMORY, - ARG_REQUIRES_GRAD, - ARGS_COUNT + ARG_COMPRESSED_INDICES = 0, + ARG_PLAIN_INDICES, + ARG_VALUES, + ARG_SIZE, + ARG_TYPE, + ARG_LAYOUT, + ARG_DEVICE, + ARG_PIN_MEMORY, + ARG_REQUIRES_GRAD, + ARGS_COUNT }; enum { - ARG_VALUES1 = ARG_VALUES, - ARG_TYPE1, - ARG_LAYOUT1, - ARG_DEVICE1, - ARG_PIN_MEMORY1, - ARG_REQUIRES_GRAD1, - ARGS_COUNT1 + ARG_VALUES1 = ARG_VALUES, + ARG_TYPE1, + ARG_LAYOUT1, + ARG_DEVICE1, + ARG_PIN_MEMORY1, + ARG_REQUIRES_GRAD1, + ARGS_COUNT1 }; - auto safe_get_attr_string = [](PyObject *o, const char *attr_name) -> PyObject* { + auto safe_get_attr_string = [](PyObject* o, + const char* attr_name) -> PyObject* { // Clear error indicator if attribute does not exists. // Otherwise subsequent Python C API calls might return bogus values. // See https://github.com/pytorch/pytorch/issues/58520 for more details @@ -634,167 +804,318 @@ Tensor sparse_compressed_tensor_ctor_worker(std::string name, c10::DispatchKey d } return rc; }; - THPObjectPtr compressed_indices_dtype_attr(safe_get_attr_string(r.pyobject(ARG_COMPRESSED_INDICES), "dtype")); - THPObjectPtr plain_indices_dtype_attr(safe_get_attr_string(r.pyobject(ARG_PLAIN_INDICES), "dtype")); - at::ScalarType compressed_indices_scalar_type = compressed_indices_dtype_attr ? reinterpret_cast( - compressed_indices_dtype_attr.get())->scalar_type : kInt; - at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr ? reinterpret_cast( - plain_indices_dtype_attr.get())->scalar_type : kInt; + THPObjectPtr compressed_indices_dtype_attr( + safe_get_attr_string(r.pyobject(ARG_COMPRESSED_INDICES), "dtype")); + THPObjectPtr plain_indices_dtype_attr( + safe_get_attr_string(r.pyobject(ARG_PLAIN_INDICES), "dtype")); + at::ScalarType compressed_indices_scalar_type = compressed_indices_dtype_attr + ? reinterpret_cast(compressed_indices_dtype_attr.get()) + ->scalar_type + : kInt; + at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr + ? reinterpret_cast(plain_indices_dtype_attr.get())->scalar_type + : kInt; if (r.idx == 0) { bool type_inference = r.isNone(ARG_TYPE); - const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); + const auto inferred_options = + typeIdWithDefault(r, ARG_DEVICE, dispatch_key); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), - r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); - Tensor compressed_indices = internal_new_from_data(values.options(), - compressed_indices_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - Tensor plain_indices = internal_new_from_data(values.options(), - plain_indices_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - c10::optional layout = (required_layout ? r.layoutWithDefault(ARG_LAYOUT, required_layout.value()) : r.layoutOptional(ARG_LAYOUT)); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); + Tensor compressed_indices = internal_new_from_data( + values.options(), + compressed_indices_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + Tensor plain_indices = internal_new_from_data( + values.options(), + plain_indices_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + c10::optional layout = + (required_layout + ? r.layoutWithDefault(ARG_LAYOUT, required_layout.value()) + : r.layoutOptional(ARG_LAYOUT)); if (required_layout && layout) { - TORCH_CHECK(layout.value() == required_layout.value(), name, ": layout must be ", required_layout.value(), " but got ", layout.value()); + TORCH_CHECK( + layout.value() == required_layout.value(), + name, + ": layout must be ", + required_layout.value(), + " but got ", + layout.value()); } - return at::sparse_compressed_tensor(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), - values.options().layout(layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); + return at::sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + r.intlist(ARG_SIZE), + values.options().layout(layout)) + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); } else if (r.idx == 1) { bool type_inference = r.isNone(ARG_TYPE1); - const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE1, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE1, scalar_type); + const auto inferred_options = + typeIdWithDefault(r, ARG_DEVICE1, dispatch_key); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE1, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE1), - r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); - Tensor compressed_indices = internal_new_from_data(values.options(), - compressed_indices_scalar_type, r.deviceOptional(ARG_DEVICE1), - r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - Tensor plain_indices = internal_new_from_data(values.options(), plain_indices_scalar_type, r.deviceOptional(ARG_DEVICE1), - r.pyobject(ARG_PLAIN_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - c10::optional layout = (required_layout ? r.layoutWithDefault(ARG_LAYOUT1, required_layout.value()) : r.layoutOptional(ARG_LAYOUT1)); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(ARG_DEVICE1), + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); + Tensor compressed_indices = internal_new_from_data( + values.options(), + compressed_indices_scalar_type, + r.deviceOptional(ARG_DEVICE1), + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + Tensor plain_indices = internal_new_from_data( + values.options(), + plain_indices_scalar_type, + r.deviceOptional(ARG_DEVICE1), + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + c10::optional layout = + (required_layout + ? r.layoutWithDefault(ARG_LAYOUT1, required_layout.value()) + : r.layoutOptional(ARG_LAYOUT1)); if (required_layout && layout) { - TORCH_CHECK(layout.value() == required_layout.value(), name, ": layout must be ", required_layout.value(), " but got ", layout.value()); + TORCH_CHECK( + layout.value() == required_layout.value(), + name, + ": layout must be ", + required_layout.value(), + " but got ", + layout.value()); } - return at::sparse_compressed_tensor(compressed_indices, plain_indices, values, - values.options().layout(layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1)); + return at::sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + values.options().layout(layout)) + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1)); } throw std::runtime_error(name + ": invalid arguments"); } -Tensor sparse_compressed_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { - c10::optional required_layout {}; - return sparse_compressed_tensor_ctor_worker("sparse_compressed_tensor", dispatch_key, scalar_type, r, required_layout); +Tensor sparse_compressed_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { + c10::optional required_layout{}; + return sparse_compressed_tensor_ctor_worker( + "sparse_compressed_tensor", + dispatch_key, + scalar_type, + r, + required_layout); } -Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor sparse_csr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { c10::optional required_layout(c10::Layout::SparseCsr); - return sparse_compressed_tensor_ctor_worker("sparse_csr_tensor", dispatch_key, scalar_type, r, required_layout); + return sparse_compressed_tensor_ctor_worker( + "sparse_csr_tensor", dispatch_key, scalar_type, r, required_layout); } -Tensor sparse_csc_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor sparse_csc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { c10::optional required_layout(c10::Layout::SparseCsc); - return sparse_compressed_tensor_ctor_worker("sparse_csc_tensor", dispatch_key, scalar_type, r, required_layout); + return sparse_compressed_tensor_ctor_worker( + "sparse_csc_tensor", dispatch_key, scalar_type, r, required_layout); } -Tensor sparse_bsr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor sparse_bsr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { c10::optional required_layout(c10::Layout::SparseBsr); - return sparse_compressed_tensor_ctor_worker("sparse_bsr_tensor", dispatch_key, scalar_type, r, required_layout); + return sparse_compressed_tensor_ctor_worker( + "sparse_bsr_tensor", dispatch_key, scalar_type, r, required_layout); } -Tensor sparse_bsc_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor sparse_bsc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { c10::optional required_layout(c10::Layout::SparseBsc); - return sparse_compressed_tensor_ctor_worker("sparse_bsc_tensor", dispatch_key, scalar_type, r, required_layout); + return sparse_compressed_tensor_ctor_worker( + "sparse_bsc_tensor", dispatch_key, scalar_type, r, required_layout); } -Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor _sparse_compressed_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); enum { - ARG_COMPRESSED_INDICES = 0, - ARG_PLAIN_INDICES, - ARG_VALUES, - ARG_SIZE, - ARG_TYPE, - ARG_LAYOUT, - ARG_DEVICE, - ARG_REQUIRES_GRAD, - ARGS_COUNT + ARG_COMPRESSED_INDICES = 0, + ARG_PLAIN_INDICES, + ARG_VALUES, + ARG_SIZE, + ARG_TYPE, + ARG_LAYOUT, + ARG_DEVICE, + ARG_REQUIRES_GRAD, + ARGS_COUNT }; bool type_inference = r.isNone(ARG_TYPE); const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); - - Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - - Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), - values.options().layout(r.layoutOptional(ARG_LAYOUT))).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); + + Tensor compressed_indices = internal_new_from_data( + values.options(), + kInt, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + + Tensor plain_indices = internal_new_from_data( + values.options(), + kInt, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + return at::_sparse_compressed_tensor_unsafe( + compressed_indices, + plain_indices, + values, + r.intlist(ARG_SIZE), + values.options().layout(r.layoutOptional(ARG_LAYOUT))) + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); } template -Tensor _sparse_compressed_tensor_unsafe_ctor_template(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { +Tensor _sparse_compressed_tensor_unsafe_ctor_template( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); enum { - ARG_COMPRESSED_INDICES = 0, - ARG_PLAIN_INDICES, - ARG_VALUES, - ARG_SIZE, - ARG_TYPE, - ARG_DEVICE, - ARG_REQUIRES_GRAD, - ARGS_COUNT + ARG_COMPRESSED_INDICES = 0, + ARG_PLAIN_INDICES, + ARG_VALUES, + ARG_SIZE, + ARG_TYPE, + ARG_DEVICE, + ARG_REQUIRES_GRAD, + ARGS_COUNT }; bool type_inference = r.isNone(ARG_TYPE); const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); - - Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - - Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/true); - return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), - values.options().layout(required_layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); + + Tensor compressed_indices = internal_new_from_data( + values.options(), + kInt, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + + Tensor plain_indices = internal_new_from_data( + values.options(), + kInt, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + return at::_sparse_compressed_tensor_unsafe( + compressed_indices, + plain_indices, + values, + r.intlist(ARG_SIZE), + values.options().layout(required_layout)) + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); } -Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { - return _sparse_compressed_tensor_unsafe_ctor_template(dispatch_key, scalar_type, r); +Tensor _sparse_csr_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { + return _sparse_compressed_tensor_unsafe_ctor_template( + dispatch_key, scalar_type, r); } -Tensor _sparse_csc_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { - return _sparse_compressed_tensor_unsafe_ctor_template(dispatch_key, scalar_type, r); +Tensor _sparse_csc_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { + return _sparse_compressed_tensor_unsafe_ctor_template( + dispatch_key, scalar_type, r); } -Tensor _sparse_bsr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { - return _sparse_compressed_tensor_unsafe_ctor_template(dispatch_key, scalar_type, r); +Tensor _sparse_bsr_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { + return _sparse_compressed_tensor_unsafe_ctor_template( + dispatch_key, scalar_type, r); } -Tensor _sparse_bsc_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { - return _sparse_compressed_tensor_unsafe_ctor_template(dispatch_key, scalar_type, r); +Tensor _sparse_bsc_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r) { + return _sparse_compressed_tensor_unsafe_ctor_template( + dispatch_key, scalar_type, r); } // Note [Ensuring sparse values and indices match devices] @@ -830,32 +1151,62 @@ Tensor sparse_coo_tensor_ctor( const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(3)); // if no dtype provided, infer type based on value type. - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(3), r.pyobject(1), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(3), + r.pyobject(1), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); // See Note [Ensuring sparse values and indices match devices] - Tensor indices = internal_new_from_data(values.options(), kLong, r.deviceOptional(3), r.pyobject(0), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/false); - return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4)); + Tensor indices = internal_new_from_data( + values.options(), + kLong, + r.deviceOptional(3), + r.pyobject(0), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/false); + return at::sparse_coo_tensor( + indices, values, values.options().layout(at::kSparse)) + .set_requires_grad(r.toBool(4)); } else if (r.idx == 1) { bool type_inference = r.isNone(3); const auto inferred_options = typeIdWithDefault(r, 4, dispatch_key); const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(4)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(4), r.pyobject(1), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(4), + r.pyobject(1), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); // See Note [Ensuring sparse values and indices match devices] - Tensor indices = internal_new_from_data(values.options(), kLong, r.deviceOptional(4), r.pyobject(0), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/false); - return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5)); + Tensor indices = internal_new_from_data( + values.options(), + kLong, + r.deviceOptional(4), + r.pyobject(0), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/false); + return at::sparse_coo_tensor( + indices, + values, + r.intlist(2), + values.options().layout(at::kSparse)) + .set_requires_grad(r.toBool(5)); } else if (r.idx == 2) { const auto inferred_options = typeIdWithDefault(r, 2, dispatch_key); const auto inferred_scalar_type = r.scalartypeWithDefault(1, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(2)); - return at::sparse_coo_tensor(r.intlist(0), inferred_options.dtype(inferred_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3)); + return at::sparse_coo_tensor( + r.intlist(0), + inferred_options.dtype(inferred_scalar_type).layout(at::kSparse)) + .set_requires_grad(r.toBool(3)); } throw std::runtime_error("sparse_coo_tensor(): invalid arguments"); } @@ -877,37 +1228,71 @@ Tensor _sparse_coo_tensor_unsafe_ctor( }; bool type_inference = r.isNone(ARG_TYPE); const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); - Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/type_inference); + Tensor values = internal_new_from_data( + inferred_options, + inferred_scalar_type, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/type_inference); // See Note [Ensuring sparse values and indices match devices] - Tensor indices = internal_new_from_data(values.options(), kLong, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, - /*type_inference=*/false); - return at::_sparse_coo_tensor_unsafe(indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); + Tensor indices = internal_new_from_data( + values.options(), + kLong, + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/false); + return at::_sparse_coo_tensor_unsafe( + indices, + values, + r.intlist(ARG_SIZE), + values.options().layout(at::kSparse)) + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); } -void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +void _validate_sparse_coo_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { auto options = dispatchKeyToTensorOptions(dispatch_key); static PythonArgParser parser({ - "_validate_sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size)", + "_validate_sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size)", }); ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); Tensor values = internal_new_from_data( - options, scalar_type, c10::nullopt, r.pyobject(1), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + options, + scalar_type, + c10::nullopt, + r.pyobject(1), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); // See Note [Ensuring sparse values and indices match devices] Tensor indices = internal_new_from_data( - values.options(), kLong, c10::nullopt, r.pyobject(0), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/false); + values.options(), + kLong, + c10::nullopt, + r.pyobject(0), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/false); at::native::_validate_sparse_coo_tensor_args(indices, values, r.intlist(2)); } -void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +void _validate_sparse_compressed_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { auto options = dispatchKeyToTensorOptions(dispatch_key); enum { ARG_COMPRESSED_INDICES = 0, @@ -918,26 +1303,51 @@ void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at:: ARGS_COUNT }; - const std::string signature = "_validate_sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, Layout layout)"; + const std::string signature = + "_validate_sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, Layout layout)"; static PythonArgParser parser({signature}); ParsedArgs parsed_args; auto r = parser.parse(args, kwargs, parsed_args); Tensor values = internal_new_from_data( - options, scalar_type, c10::nullopt, r.pyobject(ARG_VALUES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + options, + scalar_type, + c10::nullopt, + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); // See Note [Ensuring sparse values and indices match devices] Tensor compressed_indices = internal_new_from_data( - values.options(), kInt, c10::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + values.options(), + kInt, + c10::nullopt, + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); Tensor plain_indices = internal_new_from_data( - values.options(), kInt, c10::nullopt, r.pyobject(ARG_PLAIN_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); - at::native::_validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), r.layout(ARG_LAYOUT)); + values.options(), + kInt, + c10::nullopt, + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); + at::native::_validate_sparse_compressed_tensor_args( + compressed_indices, + plain_indices, + values, + r.intlist(ARG_SIZE), + r.layout(ARG_LAYOUT)); } template -void _validate_sparse_compressed_tensor_args_template(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +void _validate_sparse_compressed_tensor_args_template( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { auto options = dispatchKeyToTensorOptions(dispatch_key); enum { ARG_COMPRESSED_INDICES = 0, @@ -948,53 +1358,92 @@ void _validate_sparse_compressed_tensor_args_template(c10::DispatchKey dispatch_ }; static std::string sig; switch (required_layout) { - case c10::Layout::SparseCsr: - sig = "_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseCsc: - sig = "_validate_sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseBsr: - sig = "_validate_sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseBsc: - sig = "_validate_sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; - break; - default: - ; + case c10::Layout::SparseCsr: + sig = + "_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; + break; + case c10::Layout::SparseCsc: + sig = + "_validate_sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; + break; + case c10::Layout::SparseBsr: + sig = + "_validate_sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; + break; + case c10::Layout::SparseBsc: + sig = + "_validate_sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; + break; + default:; } static PythonArgParser parser({sig}); ParsedArgs parsed_args; auto r = parser.parse(args, kwargs, parsed_args); Tensor values = internal_new_from_data( - options, scalar_type, c10::nullopt, r.pyobject(ARG_VALUES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + options, + scalar_type, + c10::nullopt, + r.pyobject(ARG_VALUES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); // See Note [Ensuring sparse values and indices match devices] Tensor compressed_indices = internal_new_from_data( - values.options(), kInt, c10::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + values.options(), + kInt, + c10::nullopt, + r.pyobject(ARG_COMPRESSED_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); Tensor plain_indices = internal_new_from_data( - values.options(), kInt, c10::nullopt, r.pyobject(ARG_PLAIN_INDICES), - /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); + values.options(), + kInt, + c10::nullopt, + r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, + /*copy_numpy=*/true, + /*type_inference=*/true); - at::native::_validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, r.intlist(3), required_layout); + at::native::_validate_sparse_compressed_tensor_args( + compressed_indices, plain_indices, values, r.intlist(3), required_layout); } -void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - _validate_sparse_compressed_tensor_args_template(dispatch_key, scalar_type, args, kwargs); +void _validate_sparse_csr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + _validate_sparse_compressed_tensor_args_template( + dispatch_key, scalar_type, args, kwargs); } -void _validate_sparse_csc_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - _validate_sparse_compressed_tensor_args_template(dispatch_key, scalar_type, args, kwargs); +void _validate_sparse_csc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + _validate_sparse_compressed_tensor_args_template( + dispatch_key, scalar_type, args, kwargs); } -void _validate_sparse_bsr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - _validate_sparse_compressed_tensor_args_template(dispatch_key, scalar_type, args, kwargs); +void _validate_sparse_bsr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + _validate_sparse_compressed_tensor_args_template( + dispatch_key, scalar_type, args, kwargs); } -void _validate_sparse_bsc_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - _validate_sparse_compressed_tensor_args_template(dispatch_key, scalar_type, args, kwargs); +void _validate_sparse_bsc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { + _validate_sparse_compressed_tensor_args_template( + dispatch_key, scalar_type, args, kwargs); } Tensor tensor_ctor( @@ -1004,27 +1453,31 @@ Tensor tensor_ctor( if (r.idx == 0) { PyObject* data = r.pyobject(0); if (THPVariable_Check(data)) { - auto ret = PyErr_WarnEx(PyExc_UserWarning, - "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " - "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", 1); - if (ret != 0) throw python_error(); + auto ret = PyErr_WarnEx( + PyExc_UserWarning, + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " + "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", + 1); + if (ret != 0) + throw python_error(); } bool type_inference = r.isNone(1); bool pin_memory = r.toBool(3); bool args_requires_grad = r.toBool(4); auto new_tensor = internal_new_from_data( - typeIdWithDefault(r, 2, dispatch_key), - r.scalartypeWithDefault(1, scalar_type), - r.deviceOptional(2), - data, - /*copy_variables=*/true, - /*copy_numpy=*/true, - /*type_inference=*/type_inference, - pin_memory); + typeIdWithDefault(r, 2, dispatch_key), + r.scalartypeWithDefault(1, scalar_type), + r.deviceOptional(2), + data, + /*copy_variables=*/true, + /*copy_numpy=*/true, + /*type_inference=*/type_inference, + pin_memory); auto names = r.toDimnameListOptional(5); if (names) { - at::namedinference::propagate_names(new_tensor, *names, /*validate_names=*/true); + at::namedinference::propagate_names( + new_tensor, *names, /*validate_names=*/true); } new_tensor.detach_(); // ensure new_tensor a leaf node new_tensor.set_requires_grad(args_requires_grad); @@ -1052,9 +1505,13 @@ Tensor as_tensor( throw std::runtime_error("tensor(): invalid arguments"); } -Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor new_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs) { static PythonArgParser parser({ - "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", + "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", }); ParsedArgs<4> parsed_args; @@ -1062,18 +1519,21 @@ Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyO if (r.idx == 0) { PyObject* data = r.pyobject(0); if (THPVariable_Check(data)) { - auto ret = PyErr_WarnEx(PyExc_UserWarning, - "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " - "or sourceTensor.clone().detach().requires_grad_(True), rather than tensor.new_tensor(sourceTensor).", 1); - if (ret != 0) throw python_error(); + auto ret = PyErr_WarnEx( + PyExc_UserWarning, + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " + "or sourceTensor.clone().detach().requires_grad_(True), rather than tensor.new_tensor(sourceTensor).", + 1); + if (ret != 0) + throw python_error(); } bool args_requires_grad = r.toBool(3); auto new_tensor = new_from_data_copy( - typeIdWithDefault(r, 2, dispatch_key), - r.scalartypeWithDefault(1, scalar_type), - r.deviceOptional(2), - data); + typeIdWithDefault(r, 2, dispatch_key), + r.scalartypeWithDefault(1, scalar_type), + r.deviceOptional(2), + data); new_tensor.detach_(); // ensure new_tensor a leaf node new_tensor.set_requires_grad(args_requires_grad); return new_tensor; @@ -1081,7 +1541,12 @@ Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyO throw std::runtime_error("new_tensor(): invalid arguments"); } -Tensor tensor_frombuffer(PyObject* buffer, ScalarType dtype, int64_t count, int64_t offset, bool requires_grad) { +Tensor tensor_frombuffer( + PyObject* buffer, + ScalarType dtype, + int64_t count, + int64_t offset, + bool requires_grad) { auto elsize = at::elementSize(dtype); size_t actual_count = 0; @@ -1109,15 +1574,29 @@ Tensor tensor_frombuffer(PyObject* buffer, ScalarType dtype, int64_t count, int6 TORCH_CHECK_VALUE( len > 0 && count != 0, - "both buffer length (", len, ") and count (", count, ") must not be 0"); + "both buffer length (", + len, + ") and count (", + count, + ") must not be 0"); TORCH_CHECK_VALUE( offset >= 0 && offset < len, - "offset (", offset, " bytes) must be non-negative and no greater than " - "buffer length (", len, " bytes) minus 1"); + "offset (", + offset, + " bytes) must be non-negative and no greater than " + "buffer length (", + len, + " bytes) minus 1"); TORCH_CHECK_VALUE( count > 0 || (len - offset) % elsize == 0, - "buffer length (", len - offset, " bytes) after offset (", offset, " bytes) " - "must be a multiple of element size (", elsize, ")"); + "buffer length (", + len - offset, + " bytes) after offset (", + offset, + " bytes) " + "must be a multiple of element size (", + elsize, + ")"); if (count < 0) { actual_count = (len - offset) / elsize; @@ -1126,15 +1605,22 @@ Tensor tensor_frombuffer(PyObject* buffer, ScalarType dtype, int64_t count, int6 } TORCH_CHECK_VALUE( - static_cast(offset) + actual_count * elsize <= static_cast(len), - "requested buffer length (", actual_count, " * ", elsize, " bytes) " - "after offset (", offset, " bytes) must not be greater than actual " - "buffer length (", len, " bytes)"); + static_cast(offset) + actual_count * elsize <= + static_cast(len), + "requested buffer length (", + actual_count, + " * ", + elsize, + " bytes) " + "after offset (", + offset, + " bytes) must not be greater than actual " + "buffer length (", + len, + " bytes)"); auto offset_buf = static_cast(buf) + offset; - auto options = TensorOptions() - .dtype(dtype) - .device(c10::kCPU); + auto options = TensorOptions().dtype(dtype).device(c10::kCPU); auto tensor = at::for_blob(offset_buf, static_cast(actual_count)) .options(options) @@ -1147,16 +1633,19 @@ Tensor tensor_frombuffer(PyObject* buffer, ScalarType dtype, int64_t count, int6 return tensor; } -Tensor tensor_fromDLPack(PyObject *data) { - DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor"); - TORCH_CHECK(dlMTensor, - "from_dlpack received an invalid capsule. " - "Note that DLTensor capsules can be consumed only once, " - "so you might have already constructed a tensor from it once."); +Tensor tensor_fromDLPack(PyObject* data) { + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + TORCH_CHECK( + dlMTensor, + "from_dlpack received an invalid capsule. " + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once."); // atensor steals the ownership of the underlying storage. It also passes a // destructor function that will be called when the underlying storage goes - // out of scope. When the destructor is called, the dlMTensor is destructed too. + // out of scope. When the destructor is called, the dlMTensor is destructed + // too. auto atensor = at::fromDLPack(dlMTensor); // Make sure this capsule will never be used again. @@ -1168,7 +1657,7 @@ Tensor tensor_fromDLPack(PyObject *data) { // because cuda ATen types have not been registered in Python yet. // so if we have a cuda tensor, then we need to make sure // we have called _lazy_init here - if(atensor.is_cuda()) { + if (atensor.is_cuda()) { py::module::import("torch.cuda").attr("init")(); } return atensor; @@ -1198,7 +1687,7 @@ Tensor asarray( // Check whether 'obj' is a NumPy Array if (is_numpy_available() && PyArray_Check(obj)) { tensor = tensor_from_numpy(obj, /*warn_if_not_writeable=*/false); - should_warn_numpy_not_writable = !PyArray_ISWRITEABLE((PyArrayObject*) obj); + should_warn_numpy_not_writable = !PyArray_ISWRITEABLE((PyArrayObject*)obj); } #endif @@ -1234,12 +1723,18 @@ Tensor asarray( // in the right device, with the right dtype. TORCH_CHECK_VALUE( !wrong_device, - "can't alias tensor from device '", tensor.device(), - "' to '", device.value(), "'."); + "can't alias tensor from device '", + tensor.device(), + "' to '", + device.value(), + "'."); TORCH_CHECK_VALUE( !wrong_dtype, - "can't alias tensor with dtype '", tensor.scalar_type(), - "' into dtype '", dtype.value(), "'."); + "can't alias tensor with dtype '", + tensor.scalar_type(), + "' into dtype '", + dtype.value(), + "'."); // If tensor is a NumPy Array view, we warn the user about non-writeable // arrays if this is the case. if (should_warn_numpy_not_writable) { @@ -1266,12 +1761,18 @@ Tensor asarray( // Type inference is activated only if the dtype has not been specified. // Otherwise, we force the unwrapped dtype. tensor = internal_new_from_data( - TensorOptions(), dtype_unwrapped, device, obj, - /* copy_variables = */ false, /* copy_numpy = */ false, /* type_inference = */ !dtype.has_value()); + TensorOptions(), + dtype_unwrapped, + device, + obj, + /* copy_variables = */ false, + /* copy_numpy = */ false, + /* type_inference = */ !dtype.has_value()); tensor.set_requires_grad(requires_grad); } return tensor; } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index da2829feb16022..b06840864b6d8c 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -5,11 +5,20 @@ #include -namespace torch { namespace utils { +namespace torch { +namespace utils { at::Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs); -at::Tensor legacy_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor legacy_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor legacy_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor legacy_tensor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); at::Tensor indexing_tensor_from_data( c10::TensorOptions options, at::ScalarType scalar_type, @@ -23,25 +32,79 @@ at::Tensor _sparse_coo_tensor_unsafe_ctor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +void _validate_sparse_coo_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); -at::Tensor sparse_compressed_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor sparse_csc_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor sparse_bsr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor sparse_bsc_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); +at::Tensor sparse_compressed_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); -at::Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor _sparse_csc_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor _sparse_bsr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor _sparse_bsc_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); +at::Tensor _sparse_compressed_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor _sparse_csr_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor _sparse_csc_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor _sparse_bsr_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor _sparse_bsc_tensor_unsafe_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); -void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -void _validate_sparse_csc_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -void _validate_sparse_bsr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -void _validate_sparse_bsc_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +void _validate_sparse_compressed_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_bsr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_bsc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); at::Tensor tensor_ctor( c10::DispatchKey dispatch_key, @@ -51,9 +114,28 @@ at::Tensor as_tensor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -at::Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor new_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor tensor_frombuffer(PyObject* buffer, at::ScalarType dtype, int64_t count, int64_t offset, bool requires_grad); -at::Tensor tensor_fromDLPack(PyObject *data); -at::Tensor asarray(PyObject* obj, c10::optional dtype, c10::optional device, c10::optional copy, bool requires_grad); -}} // namespace torch::utils +at::Tensor new_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor new_ones( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor tensor_frombuffer( + PyObject* buffer, + at::ScalarType dtype, + int64_t count, + int64_t offset, + bool requires_grad); +at::Tensor tensor_fromDLPack(PyObject* data); +at::Tensor asarray( + PyObject* obj, + c10::optional dtype, + c10::optional device, + c10::optional copy, + bool requires_grad); +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 1192b391ef79ed..05db981d5dafdb 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -1,15 +1,18 @@ #include #include #define WITH_NUMPY_IMPORT_ARRAY -#include #include +#include #ifndef USE_NUMPY -namespace torch { namespace utils { +namespace torch { +namespace utils { PyObject* tensor_to_numpy(const at::Tensor& tensor) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { +at::Tensor tensor_from_numpy( + PyObject* obj, + bool warn_if_not_writeable /*=true*/) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -26,7 +29,8 @@ bool is_numpy_scalar(PyObject* obj) { at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -}} +} // namespace utils +} // namespace torch #else #include @@ -43,7 +47,8 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { using namespace at; using namespace torch::autograd; -namespace torch { namespace utils { +namespace torch { +namespace utils { bool is_numpy_available() { static bool available = []() { @@ -74,7 +79,7 @@ static std::vector to_numpy_shape(IntArrayRef x) { // shape and stride conversion from int64_t to npy_intp auto nelem = x.size(); auto result = std::vector(nelem); - for(const auto i : c10::irange(nelem)) { + for (const auto i : c10::irange(nelem)) { result[i] = static_cast(x[i]); } return result; @@ -83,56 +88,66 @@ static std::vector to_numpy_shape(IntArrayRef x) { static std::vector to_aten_shape(int ndim, npy_intp* values) { // shape and stride conversion from npy_intp to int64_t auto result = std::vector(ndim); - for(const auto i : c10::irange(ndim)) { + for (const auto i : c10::irange(ndim)) { result[i] = static_cast(values[i]); } return result; } -static std::vector seq_to_aten_shape(PyObject *py_seq) { +static std::vector seq_to_aten_shape(PyObject* py_seq) { int ndim = PySequence_Length(py_seq); if (ndim == -1) { throw TypeError("shape and strides must be sequences"); } auto result = std::vector(ndim); - for(const auto i : c10::irange(ndim)) { + for (const auto i : c10::irange(ndim)) { auto item = THPObjectPtr(PySequence_GetItem(py_seq, i)); - if (!item) throw python_error(); + if (!item) + throw python_error(); result[i] = PyLong_AsLongLong(item); - if (result[i] == -1 && PyErr_Occurred()) throw python_error(); + if (result[i] == -1 && PyErr_Occurred()) + throw python_error(); } return result; } -PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force/*=false*/) { +PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force /*=false*/) { TORCH_CHECK(is_numpy_available(), "Numpy is not available"); - TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), - ".numpy() is not supported for tensor subclasses."); + TORCH_CHECK( + !tensor.unsafeGetTensorImpl()->is_python_dispatch(), + ".numpy() is not supported for tensor subclasses."); - TORCH_CHECK_TYPE(tensor.layout() == Layout::Strided, - "can't convert ", c10::str(tensor.layout()).c_str(), + TORCH_CHECK_TYPE( + tensor.layout() == Layout::Strided, + "can't convert ", + c10::str(tensor.layout()).c_str(), " layout tensor to numpy. ", "Use Tensor.dense() first."); - if (!force){ - TORCH_CHECK_TYPE(tensor.device().type() == DeviceType::CPU, - "can't convert ", tensor.device().str().c_str(), - " device type tensor to numpy. Use Tensor.cpu() to ", - "copy the tensor to host memory first."); - - TORCH_CHECK(!(at::GradMode::is_enabled() && tensor.requires_grad()), - "Can't call numpy() on Tensor that requires grad. " - "Use tensor.detach().numpy() instead."); - - TORCH_CHECK(!tensor.is_conj(), - "Can't call numpy() on Tensor that has conjugate bit set. ", - "Use tensor.resolve_conj().numpy() instead."); - - TORCH_CHECK(!tensor.is_neg(), - "Can't call numpy() on Tensor that has negative bit set. " - "Use tensor.resolve_neg().numpy() instead."); + if (!force) { + TORCH_CHECK_TYPE( + tensor.device().type() == DeviceType::CPU, + "can't convert ", + tensor.device().str().c_str(), + " device type tensor to numpy. Use Tensor.cpu() to ", + "copy the tensor to host memory first."); + + TORCH_CHECK( + !(at::GradMode::is_enabled() && tensor.requires_grad()), + "Can't call numpy() on Tensor that requires grad. " + "Use tensor.detach().numpy() instead."); + + TORCH_CHECK( + !tensor.is_conj(), + "Can't call numpy() on Tensor that has conjugate bit set. ", + "Use tensor.resolve_conj().numpy() instead."); + + TORCH_CHECK( + !tensor.is_neg(), + "Can't call numpy() on Tensor that has negative bit set. " + "Use tensor.resolve_neg().numpy() instead."); } auto prepared_tensor = tensor.detach().cpu().resolve_conj().resolve_neg(); @@ -157,14 +172,16 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force/*=false*/) { 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE, nullptr)); - if (!array) return nullptr; + if (!array) + return nullptr; // TODO: This attempts to keep the underlying memory alive by setting the base // object of the ndarray to the tensor and disabling resizes on the storage. // This is not sufficient. For example, the tensor's storage may be changed // via Tensor.set_, which can free the underlying memory. PyObject* py_tensor = THPVariable_Wrap(prepared_tensor); - if (!py_tensor) throw python_error(); + if (!py_tensor) + throw python_error(); if (PyArray_SetBaseObject((PyArrayObject*)array.get(), py_tensor) == -1) { return nullptr; } @@ -176,15 +193,17 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force/*=false*/) { void warn_numpy_not_writeable() { TORCH_WARN_ONCE( - "The given NumPy array is not writable, and PyTorch does " - "not support non-writable tensors. This means writing to this tensor " - "will result in undefined behavior. " - "You may want to copy the array to protect its data or make it writable " - "before converting it to a tensor. This type of warning will be " - "suppressed for the rest of this program."); + "The given NumPy array is not writable, and PyTorch does " + "not support non-writable tensors. This means writing to this tensor " + "will result in undefined behavior. " + "You may want to copy the array to protect its data or make it writable " + "before converting it to a tensor. This type of warning will be " + "suppressed for the rest of this program."); } -at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) { +at::Tensor tensor_from_numpy( + PyObject* obj, + bool warn_if_not_writeable /*=true*/) { if (!is_numpy_available()) { throw std::runtime_error("Numpy is not available"); } @@ -205,16 +224,16 @@ at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) // NumPy strides use bytes. Torch strides use element counts. auto element_size_in_bytes = PyArray_ITEMSIZE(array); for (auto& stride : strides) { - if (stride%element_size_in_bytes != 0) { + if (stride % element_size_in_bytes != 0) { throw ValueError( - "given numpy array strides not a multiple of the element byte size. " - "Copy the numpy array to reallocate the memory."); + "given numpy array strides not a multiple of the element byte size. " + "Copy the numpy array to reallocate the memory."); } stride /= element_size_in_bytes; } size_t storage_size = 1; - for(const auto i : c10::irange(ndim)) { + for (const auto i : c10::irange(ndim)) { if (strides[i] < 0) { throw ValueError( "At least one stride in the given numpy array is negative, " @@ -241,23 +260,33 @@ at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable/*=true*/) pybind11::gil_scoped_acquire gil; Py_DECREF(obj); }, - at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array))) - ); + at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array)))); } int aten_to_numpy_dtype(const ScalarType scalar_type) { switch (scalar_type) { - case kDouble: return NPY_DOUBLE; - case kFloat: return NPY_FLOAT; - case kHalf: return NPY_HALF; - case kComplexDouble: return NPY_COMPLEX128; - case kComplexFloat: return NPY_COMPLEX64; - case kLong: return NPY_INT64; - case kInt: return NPY_INT32; - case kShort: return NPY_INT16; - case kChar: return NPY_INT8; - case kByte: return NPY_UINT8; - case kBool: return NPY_BOOL; + case kDouble: + return NPY_DOUBLE; + case kFloat: + return NPY_FLOAT; + case kHalf: + return NPY_HALF; + case kComplexDouble: + return NPY_COMPLEX128; + case kComplexFloat: + return NPY_COMPLEX64; + case kLong: + return NPY_INT64; + case kInt: + return NPY_INT32; + case kShort: + return NPY_INT16; + case kChar: + return NPY_INT8; + case kByte: + return NPY_UINT8; + case kBool: + return NPY_BOOL; default: throw TypeError("Got unsupported ScalarType %s", toString(scalar_type)); } @@ -265,17 +294,27 @@ int aten_to_numpy_dtype(const ScalarType scalar_type) { ScalarType numpy_dtype_to_aten(int dtype) { switch (dtype) { - case NPY_DOUBLE: return kDouble; - case NPY_FLOAT: return kFloat; - case NPY_HALF: return kHalf; - case NPY_COMPLEX64: return kComplexFloat; - case NPY_COMPLEX128: return kComplexDouble; - case NPY_INT16: return kShort; - case NPY_INT8: return kChar; - case NPY_UINT8: return kByte; - case NPY_BOOL: return kBool; + case NPY_DOUBLE: + return kDouble; + case NPY_FLOAT: + return kFloat; + case NPY_HALF: + return kHalf; + case NPY_COMPLEX64: + return kComplexFloat; + case NPY_COMPLEX128: + return kComplexDouble; + case NPY_INT16: + return kShort; + case NPY_INT8: + return kChar; + case NPY_UINT8: + return kByte; + case NPY_BOOL: + return kBool; default: - // Workaround: MSVC does not support two switch cases that have the same value + // Workaround: MSVC does not support two switch cases that have the same + // value if (dtype == NPY_INT || dtype == NPY_INT32) { // To cover all cases we must use NPY_INT because // NPY_INT32 is an alias which maybe equal to: @@ -288,11 +327,13 @@ ScalarType numpy_dtype_to_aten(int dtype) { // - NPY_LONGLONG, when sizeof(long) = 4 and sizeof(long long) = 8 return kLong; } else { - break; // break as if this is one of the cases above because this is only a workaround + break; // break as if this is one of the cases above because this is + // only a workaround } } auto pytype = THPObjectPtr(PyArray_TypeObjectFromType(dtype)); - if (!pytype) throw python_error(); + if (!pytype) + throw python_error(); throw TypeError( "can't convert np.ndarray of type %s. The only supported types are: " "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", @@ -304,15 +345,18 @@ bool is_numpy_int(PyObject* obj) { } bool is_numpy_scalar(PyObject* obj) { - return is_numpy_available() && (is_numpy_int(obj) || PyArray_IsScalar(obj, Bool) || - PyArray_IsScalar(obj, Floating) || PyArray_IsScalar(obj, ComplexFloating)); + return is_numpy_available() && + (is_numpy_int(obj) || PyArray_IsScalar(obj, Bool) || + PyArray_IsScalar(obj, Floating) || + PyArray_IsScalar(obj, ComplexFloating)); } at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { if (!is_numpy_available()) { throw std::runtime_error("Numpy is not available"); } - auto cuda_dict = THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__")); + auto cuda_dict = + THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__")); TORCH_INTERNAL_ASSERT(cuda_dict); if (!PyDict_Check(cuda_dict)) { @@ -322,7 +366,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { // Extract the `obj.__cuda_array_interface__['shape']` attribute std::vector sizes; { - PyObject *py_shape = PyDict_GetItemString(cuda_dict, "shape"); + PyObject* py_shape = PyDict_GetItemString(cuda_dict, "shape"); if (py_shape == nullptr) { throw TypeError("attribute `shape` must exist"); } @@ -334,13 +378,13 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int dtype_size_in_bytes; { - PyObject *py_typestr = PyDict_GetItemString(cuda_dict, "typestr"); + PyObject* py_typestr = PyDict_GetItemString(cuda_dict, "typestr"); if (py_typestr == nullptr) { throw TypeError("attribute `typestr` must exist"); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - PyArray_Descr *descr; - if(!PyArray_DescrConverter(py_typestr, &descr)) { + PyArray_Descr* descr; + if (!PyArray_DescrConverter(py_typestr, &descr)) { throw ValueError("cannot parse `typestr`"); } dtype = numpy_dtype_to_aten(descr->type_num); @@ -350,13 +394,13 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { // Extract the `obj.__cuda_array_interface__['data']` attribute // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void *data_ptr; + void* data_ptr; { - PyObject *py_data = PyDict_GetItemString(cuda_dict, "data"); + PyObject* py_data = PyDict_GetItemString(cuda_dict, "data"); if (py_data == nullptr) { throw TypeError("attribute `shape` data exist"); } - if(!PyTuple_Check(py_data) || PyTuple_GET_SIZE(py_data) != 2) { + if (!PyTuple_Check(py_data) || PyTuple_GET_SIZE(py_data) != 2) { throw TypeError("`data` must be a 2-tuple of (int, bool)"); } data_ptr = PyLong_AsVoidPtr(PyTuple_GET_ITEM(py_data, 0)); @@ -368,27 +412,31 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { throw python_error(); } if (read_only) { - throw TypeError("the read only flag is not supported, should always be False"); + throw TypeError( + "the read only flag is not supported, should always be False"); } } // Extract the `obj.__cuda_array_interface__['strides']` attribute std::vector strides; { - PyObject *py_strides = PyDict_GetItemString(cuda_dict, "strides"); + PyObject* py_strides = PyDict_GetItemString(cuda_dict, "strides"); if (py_strides != nullptr && py_strides != Py_None) { - if (PySequence_Length(py_strides) == -1 || static_cast(PySequence_Length(py_strides)) != sizes.size()) { - throw TypeError("strides must be a sequence of the same length as shape"); + if (PySequence_Length(py_strides) == -1 || + static_cast(PySequence_Length(py_strides)) != sizes.size()) { + throw TypeError( + "strides must be a sequence of the same length as shape"); } strides = seq_to_aten_shape(py_strides); - // __cuda_array_interface__ strides use bytes. Torch strides use element counts. + // __cuda_array_interface__ strides use bytes. Torch strides use element + // counts. for (auto& stride : strides) { - if (stride%dtype_size_in_bytes != 0) { + if (stride % dtype_size_in_bytes != 0) { throw ValueError( "given array strides not a multiple of the element byte size. " "Make a copy of the array to reallocate the memory."); - } + } stride /= dtype_size_in_bytes; } } else { @@ -405,9 +453,9 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { pybind11::gil_scoped_acquire gil; Py_DECREF(obj); }, - at::device(kCUDA).dtype(dtype) - ); + at::device(kCUDA).dtype(dtype)); } -}} // namespace torch::utils +} // namespace utils +} // namespace torch -#endif // USE_NUMPY +#endif // USE_NUMPY diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index 3eb5e42598a67d..a0397f9ff71902 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -1,12 +1,13 @@ #pragma once -#include #include +#include -namespace torch { namespace utils { +namespace torch { +namespace utils { -PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force=false); -at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable=true); +PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force = false); +at::Tensor tensor_from_numpy(PyObject* obj, bool warn_if_not_writeable = true); int aten_to_numpy_dtype(const at::ScalarType scalar_type); at::ScalarType numpy_dtype_to_aten(int dtype); @@ -18,4 +19,5 @@ bool is_numpy_scalar(PyObject* obj); void warn_numpy_not_writeable(); at::Tensor tensor_from_cuda_array_interface(PyObject* obj); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_qschemes.cpp b/torch/csrc/utils/tensor_qschemes.cpp index a06694cae0e4aa..9e9d6dbdcfce7b 100644 --- a/torch/csrc/utils/tensor_qschemes.cpp +++ b/torch/csrc/utils/tensor_qschemes.cpp @@ -1,15 +1,14 @@ #include +#include +#include #include #include #include -#include -#include #include #include - namespace torch { namespace utils { diff --git a/torch/csrc/utils/tensor_qschemes.h b/torch/csrc/utils/tensor_qschemes.h index 19d9a373dd1682..71e65479047b8a 100644 --- a/torch/csrc/utils/tensor_qschemes.h +++ b/torch/csrc/utils/tensor_qschemes.h @@ -1,9 +1,11 @@ #pragma once #include -namespace torch { namespace utils { +namespace torch { +namespace utils { PyObject* getTHPQScheme(at::QScheme qscheme); void initializeQSchemes(); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index c52f10d1722ffa..53566da8390045 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -2,46 +2,61 @@ #include -#include -#include -#include #include #include +#include +#include +#include +#include #include #include -#include using namespace at; -namespace torch { namespace utils { +namespace torch { +namespace utils { static const char* backend_to_string(const at::Backend& backend) { switch (backend) { - case at::Backend::CPU: return "torch"; - case at::Backend::CUDA: return "torch.cuda"; - case at::Backend::XPU: return "torch.xpu"; - case at::Backend::IPU: return "torch.ipu"; - case at::Backend::SparseCPU: return "torch.sparse"; - case at::Backend::SparseCUDA: return "torch.cuda.sparse"; - case at::Backend::SparseXPU: return "torch.xpu.sparse"; - case at::Backend::QuantizedCPU: return "torch.quantized"; - case at::Backend::HPU: return "torch.hpu"; - case at::Backend::MPS: return "torch.mps"; - case at::Backend::PrivateUse1: return "torch.privateuseone"; - default: AT_ERROR("Unimplemented backend ", backend); + case at::Backend::CPU: + return "torch"; + case at::Backend::CUDA: + return "torch.cuda"; + case at::Backend::XPU: + return "torch.xpu"; + case at::Backend::IPU: + return "torch.ipu"; + case at::Backend::SparseCPU: + return "torch.sparse"; + case at::Backend::SparseCUDA: + return "torch.cuda.sparse"; + case at::Backend::SparseXPU: + return "torch.xpu.sparse"; + case at::Backend::QuantizedCPU: + return "torch.quantized"; + case at::Backend::HPU: + return "torch.hpu"; + case at::Backend::MPS: + return "torch.mps"; + case at::Backend::PrivateUse1: + return "torch.privateuseone"; + default: + AT_ERROR("Unimplemented backend ", backend); } } std::string options_to_string(const at::TensorOptions options) { std::ostringstream ss; - ss << backend_to_string(options.backend()) << "." << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor"; + ss << backend_to_string(options.backend()) << "." + << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor"; return ss.str(); } std::string type_to_string(const at::DeprecatedTypeProperties& type) { std::ostringstream ss; - ss << backend_to_string(type.backend()) << "." << toString(type.scalarType()) << "Tensor"; + ss << backend_to_string(type.backend()) << "." << toString(type.scalarType()) + << "Tensor"; return ss.str(); } @@ -50,17 +65,21 @@ at::TensorOptions options_from_string(const std::string& str) { static std::once_flag cpu_once; static std::once_flag cuda_once; static std::unordered_map cpu_map; - static std::unordered_map cuda_map; + static std::unordered_map + cuda_map; - const std::unordered_map* map = nullptr; + const std::unordered_map* map = + nullptr; if (str == "torch.Tensor") { - auto backend = dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()); + auto backend = + dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()); auto scalar_type = torch::tensors::get_default_scalar_type(); return getDeprecatedTypeProperties(backend, scalar_type).options(); } - if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()).first == cuda_prefix.end()) { + if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()) + .first == cuda_prefix.end()) { // torch.cuda. is prefix of str std::call_once(cuda_once, []() { for (auto type : autograd::VariableType::allCUDATypes()) { @@ -90,16 +109,25 @@ std::vector> all_declared_types() { // NOTE: Do not add more types here. This list controls the creation // of legacy tensor types e.g. torch.cuda.FloatTensor which are // maintained for backwards-compatibility only. - std::vector backends = { Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA }; + std::vector backends = { + Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA}; std::vector scalar_types = { - ScalarType::Byte, ScalarType::Char, ScalarType::Double, ScalarType::Float, - ScalarType::Int, ScalarType::Long, ScalarType::Short, ScalarType::Half, - ScalarType::Bool, ScalarType::BFloat16}; + ScalarType::Byte, + ScalarType::Char, + ScalarType::Double, + ScalarType::Float, + ScalarType::Int, + ScalarType::Long, + ScalarType::Short, + ScalarType::Half, + ScalarType::Bool, + ScalarType::BFloat16}; for (auto& backend : backends) { for (auto& scalar_type : scalar_types) { // there is no sparse bool type. - if (scalar_type == ScalarType::Bool && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) { + if (scalar_type == ScalarType::Bool && + (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) { continue; } ret.emplace_back(std::make_pair(backend, scalar_type)); @@ -109,4 +137,5 @@ std::vector> all_declared_types() { return ret; } -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_types.h b/torch/csrc/utils/tensor_types.h index d89e3543ccf8c8..b45c284c9c1ef5 100644 --- a/torch/csrc/utils/tensor_types.h +++ b/torch/csrc/utils/tensor_types.h @@ -5,7 +5,8 @@ #include #include -namespace torch { namespace utils { +namespace torch { +namespace utils { std::string options_to_string(const at::TensorOptions options); std::string type_to_string(const at::DeprecatedTypeProperties& type); @@ -14,4 +15,5 @@ at::TensorOptions options_from_string(const std::string& str); // return a vector of all "declared" types, even those that weren't compiled std::vector> all_declared_types(); -}} // namespace torch::utils +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/throughput_benchmark.cpp b/torch/csrc/utils/throughput_benchmark.cpp index dd4abdc9044cf4..65bc190b8b3c83 100644 --- a/torch/csrc/utils/throughput_benchmark.cpp +++ b/torch/csrc/utils/throughput_benchmark.cpp @@ -6,9 +6,11 @@ namespace torch { namespace throughput_benchmark { -std::ostream& operator<<(std::ostream& os, const BenchmarkExecutionStats& value) { - return os << "Average latency / iter (ms): " << value.latency_avg_ms - << "\n Total number of iters: " << value.num_iters; +std::ostream& operator<<( + std::ostream& os, + const BenchmarkExecutionStats& value) { + return os << "Average latency / iter (ms): " << value.latency_avg_ms + << "\n Total number of iters: " << value.num_iters; } void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) { @@ -21,7 +23,7 @@ void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) { } } -py::object ThroughputBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs) { +py::object ThroughputBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs) { CHECK(script_module_.initialized() ^ module_.initialized()); if (script_module_.initialized()) { c10::IValue result; @@ -36,12 +38,10 @@ py::object ThroughputBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs) { } } -ThroughputBenchmark::ThroughputBenchmark( - jit::Module script_module) +ThroughputBenchmark::ThroughputBenchmark(jit::Module script_module) : script_module_(script_module) {} -ThroughputBenchmark::ThroughputBenchmark( - py::object module) +ThroughputBenchmark::ThroughputBenchmark(py::object module) : module_(std::move(module)) {} BenchmarkExecutionStats ThroughputBenchmark::benchmark( @@ -54,9 +54,10 @@ BenchmarkExecutionStats ThroughputBenchmark::benchmark( return script_module_.benchmark(config); } else { CHECK(module_.initialized()); - TORCH_WARN("Starting benchmark on an nn.Module. This can be slow due " - "to Python GIL.For proper inference simulation you might want to switch to " - "a ScriptModule instead"); + TORCH_WARN( + "Starting benchmark on an nn.Module. This can be slow due " + "to Python GIL.For proper inference simulation you might want to switch to " + "a ScriptModule instead"); return module_.benchmark(config); } } @@ -139,4 +140,4 @@ ScriptModuleInput cloneInput( } // namespace detail } // namespace throughput_benchmark -} // namepsace torch +} // namespace torch diff --git a/torch/csrc/utils/throughput_benchmark.h b/torch/csrc/utils/throughput_benchmark.h index 09e25d38629f7d..f0582e400847d8 100644 --- a/torch/csrc/utils/throughput_benchmark.h +++ b/torch/csrc/utils/throughput_benchmark.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include @@ -25,7 +25,9 @@ struct BenchmarkExecutionStats { int64_t num_iters{-1}; }; -std::ostream& operator<<(std::ostream& os, const BenchmarkExecutionStats& value); +std::ostream& operator<<( + std::ostream& os, + const BenchmarkExecutionStats& value); /** * Use this struct in order to configure a throughput benchmark run. @@ -53,8 +55,8 @@ struct BenchmarkConfig { // Number of iterations the benchmark should run with. This number is separate // from the warmup iterations int64_t num_iters{100}; - // If set autograd profiler will be enabled. I.e. this variable would be created - // before the main benchmark loop (but after the warmup): + // If set autograd profiler will be enabled. I.e. this variable would be + // created before the main benchmark loop (but after the warmup): // RecordProfile guard(profiler_output_path); std::string profiler_output_path{""}; }; @@ -66,10 +68,10 @@ namespace detail { */ template class BenchmarkHelper { -public: + public: BenchmarkHelper(); // NOLINTNEXTLINE(modernize-pass-by-value) - explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {} + explicit BenchmarkHelper(Model model) : model_(model), initialized_(true) {} // This method to be used in benchmark() method // Note that there is no result. This way we don't have to call this under GIL @@ -84,7 +86,9 @@ class BenchmarkHelper { void addInput(Input&&); BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const; - bool initialized() const { return initialized_; } + bool initialized() const { + return initialized_; + } // Destructor doesn't require the GIL because it is going to be executed on // the PyThon thread @@ -110,26 +114,23 @@ typedef py::object ModuleOutput; typedef std::vector ScriptModuleInput; typedef at::IValue ScriptModuleOutput; -template +template Input cloneInput(const Input& input); -typedef BenchmarkHelper< - ScriptModuleInput, - at::IValue, - jit::Module> +typedef BenchmarkHelper ScriptModuleBenchmark; template <> -inline BenchmarkHelper::BenchmarkHelper() - : model_("Module", std::make_shared()), - initialized_(false) {} +inline BenchmarkHelper:: + BenchmarkHelper() + : model_("Module", std::make_shared()), + initialized_(false) {} typedef BenchmarkHelper ModuleBenchmark; template <> inline BenchmarkHelper::BenchmarkHelper() - : initialized_(false) {} + : initialized_(false) {} template <> -void ScriptModuleBenchmark::runOnce( - ScriptModuleInput&& input) const; +void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const; template <> ScriptModuleOutput ScriptModuleBenchmark::runOnce( @@ -148,7 +149,6 @@ void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs); template <> void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input); - template <> void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs); @@ -192,7 +192,7 @@ class C10_HIDDEN ThroughputBenchmark { detail::ScriptModuleBenchmark script_module_; detail::ModuleBenchmark module_; }; -} // namespace throughput benchmark -} // namepsace torch +} // namespace throughput_benchmark +} // namespace torch #include diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index 487b444b883f64..3dd73b7422d77a 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -6,7 +6,7 @@ namespace torch { namespace torch_dispatch_mode { struct StashTorchDispatchModeGuard { -public: + public: StashTorchDispatchModeGuard() { saved_ = at::impl::TorchDispatchModeTLS::get_state(); at::impl::TorchDispatchModeTLS::set_state(nullptr); @@ -15,7 +15,8 @@ struct StashTorchDispatchModeGuard { ~StashTorchDispatchModeGuard() { at::impl::TorchDispatchModeTLS::set_state(saved_); } -private: + + private: std::shared_ptr saved_; }; diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h index 6c6bb79ffa6486..74ace7be77fd83 100644 --- a/torch/csrc/utils/variadic.h +++ b/torch/csrc/utils/variadic.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include @@ -127,7 +127,11 @@ void apply(Function function, Ts&&... ts) { (void)_; } -template +template < + typename ReturnType, + typename... Ts, + typename Function, + typename Accessor> ReturnType unpack(Function function, Accessor accessor) { return ReturnType(unpack( std::move(function), @@ -135,7 +139,12 @@ ReturnType unpack(Function function, Accessor accessor) { typename MakeIndices::indices())); } -template +template < + typename ReturnType, + typename... Ts, + typename Function, + typename Accessor, + size_t... Is> ReturnType unpack(Function function, Accessor accessor, Indices) { return ReturnType(function(accessor.template operator()(Is)...)); }