Commit

Modify param names
Aswinmcw committed Nov 22, 2024
1 parent 027ae63 commit c44fa2a
Showing 10 changed files with 56 additions and 56 deletions.
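In effect, the public CCL entry points trade the 16-bit, operation-specific gather_dim/scatter_dim parameters for a uniform 32-bit dim, while the old names move to the internal normalized index. A condensed before/after sketch of one changed overload (parameter tails elided; the full signatures are in the diffs below):

    // Before: operation-specific, 16-bit dimension parameter.
    ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const int16_t gather_dim /*, ... */);

    // After: uniform, 32-bit name shared by all_gather and reduce_scatter.
    ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const int32_t dim /*, ... */);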
8 changes: 4 additions & 4 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
@@ -9,19 +9,19 @@
 namespace ttnn::operations::ccl {

 ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor,
-    const int16_t gather_dim,
+    const int32_t dim,
     const uint32_t num_links,
     const std::optional<ttnn::MemoryConfig>& memory_config,
     const std::optional<size_t> num_workers,
     const std::optional<size_t> num_buffers_per_channel,
     const ttnn::ccl::Topology topology) {
     return ttnn::operations::ccl::all_gather(
-        input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+        input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
 }

 ttnn::Tensor ExecuteAllGather::invoke(
     const ttnn::Tensor& input_tensor,
-    const int16_t gather_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     const uint32_t num_links,
@@ -30,7 +30,7 @@ ttnn::Tensor ExecuteAllGather::invoke(
     const std::optional<size_t> num_buffers_per_channel,
     const ttnn::ccl::Topology topology) {
     return ttnn::operations::ccl::all_gather(
-        input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+        input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
 }

 } // namespace ttnn::operations::ccl
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp
@@ -14,7 +14,7 @@ namespace ccl {
 struct ExecuteAllGather {
     static ttnn::Tensor invoke(
         const ttnn::Tensor& input_tensor,
-        const int16_t gather_dim,
+        const int32_t dim,
         const uint32_t num_links = 1,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
         const std::optional<size_t> num_workers = std::nullopt,
@@ -23,7 +23,7 @@ struct ExecuteAllGather {

     static ttnn::Tensor invoke(
         const ttnn::Tensor& input_tensor,
-        const int16_t gather_dim,
+        const int32_t dim,
         const uint32_t cluster_axis,
         const MeshDevice& mesh_device,
         const uint32_t num_links = 1,
16 changes: 8 additions & 8 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
@@ -29,16 +29,16 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation,
         ttnn::pybind_overload_t{
             [](const ccl_operation_t& self,
                const ttnn::Tensor& input_tensor,
-               const int16_t gather_dim,
+               const int32_t dim,
                const uint32_t num_links,
                const std::optional<ttnn::MemoryConfig>& memory_config,
                const std::optional<size_t> num_workers,
                const std::optional<size_t> num_buffers_per_channel,
                const ttnn::ccl::Topology topology) -> ttnn::Tensor {
-                return self(input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+                return self(input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
             },
             py::arg("input_tensor"),
-            py::arg("gather_dim"),
+            py::arg("dim"),
             py::kw_only(),
             py::arg("num_links") = 1,
             py::arg("memory_config") = std::nullopt,
@@ -49,18 +49,18 @@
         ttnn::pybind_overload_t{
             [](const ccl_operation_t& self,
                const ttnn::Tensor& input_tensor,
-               const int16_t gather_dim,
+               const int32_t dim,
                const uint32_t cluster_axis,
                const MeshDevice& mesh_device,
                const uint32_t num_links,
                const std::optional<ttnn::MemoryConfig>& memory_config,
                const std::optional<size_t> num_workers,
                const std::optional<size_t> num_buffers_per_channel,
                const ttnn::ccl::Topology topology) -> ttnn::Tensor {
-                return self(input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+                return self(input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
             },
             py::arg("input_tensor"),
-            py::arg("gather_dim"),
+            py::arg("dim"),
             py::arg("cluster_axis"),
             py::arg("mesh_device"),
             py::kw_only(),
@@ -84,7 +84,7 @@ void py_bind_all_gather(pybind11::module& module) {
         Args:
             input_tensor (ttnn.Tensor): multi-device tensor.
-            gather_dim (int): Dimension to perform operation.
+            dim (int): Dimension to perform operation.
             cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on.
             mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on.
             * cluster_axis and mesh_device parameters are applicable only for Linear Topology.
@@ -113,7 +113,7 @@ void py_bind_all_gather(pybind11::module& module) {
                             memory_config=mem_config,
                             mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2)))
             >>> ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
-            >>> output = ttnn.all_gather(ttnn_tensor, gather_dim=0, topology=ttnn.Topology.Ring)
+            >>> output = ttnn.all_gather(ttnn_tensor, dim=0, topology=ttnn.Topology.Ring)
     )doc");
 }
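A detail worth keeping in mind when reading the bindings above: arguments bound before py::kw_only() (here input_tensor and dim) can be passed positionally from Python, while everything bound after it is keyword-only. A minimal, self-contained pybind11 sketch of that split (the module and stub function are illustrative stand-ins, not part of ttnn):

    // Minimal pybind11 sketch of the py::kw_only() split used in the binding above.
    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    static int all_gather_stub(int dim, int num_links) {
        // Stand-in body; the real op gathers tensor shards across devices.
        return dim + num_links;
    }

    PYBIND11_MODULE(ccl_example, m) {
        m.def("all_gather", &all_gather_stub,
              py::arg("dim"),            // positional-or-keyword
              py::kw_only(),             // everything after this is keyword-only
              py::arg("num_links") = 1);
    }

    // From Python:
    //   ccl_example.all_gather(0)                   # ok: dim passed positionally
    //   ccl_example.all_gather(dim=0, num_links=2)  # ok
    //   ccl_example.all_gather(0, 2)                # TypeError: num_links is keyword-only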
24 changes: 12 additions & 12 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -175,7 +175,7 @@ namespace operations {
 namespace ccl {

 Tensor all_gather(
-    const Tensor& input_tensor, const int16_t gather_dim, const uint32_t num_links, const std::optional<MemoryConfig>& memory_config, const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) {
+    const Tensor& input_tensor, const int32_t dim, const uint32_t num_links, const std::optional<MemoryConfig>& memory_config, const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) {

     TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch");
     auto devices = input_tensor.get_workers();
@@ -187,23 +187,23 @@ Tensor all_gather(
         ccl_topology = ttnn::ccl::Topology::Linear;
     }

-    int16_t rank = input_tensor.get_logical_shape().rank();
+    int32_t rank = input_tensor.get_logical_shape().rank();

-    int16_t dim = (gather_dim < 0) ? rank + gather_dim : gather_dim;
+    int32_t gather_dim = (dim < 0) ? rank + dim : dim;

-    TT_FATAL(dim >= -rank && dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);
+    TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);

     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_op(
-        [dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology](
+        [gather_dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

             const auto& input_tensor = input_tensors.at(0);

             return operation::run(
-                ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology),
+                ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, gather_dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology),
                 {input_tensor});
         },
         {input_tensor},
@@ -213,7 +213,7 @@

 Tensor all_gather(
     const Tensor& input_tensor,
-    const int16_t gather_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     const uint32_t num_links,
@@ -226,16 +226,16 @@ Tensor all_gather(
     const auto mesh_view = mesh_device.get_view();
     std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

-    int16_t rank = input_tensor.get_logical_shape().rank();
+    int32_t rank = input_tensor.get_logical_shape().rank();

-    int16_t dim = (gather_dim < 0) ? rank + gather_dim : gather_dim;
+    int32_t gather_dim = (dim < 0) ? rank + dim : dim;

-    TT_FATAL(dim >= -rank && dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);
+    TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);

     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

     operation::launch_op(
-        [dim, num_links, memory_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
+        [gather_dim, num_links, memory_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -263,7 +263,7 @@ Tensor all_gather(
             return operation::run(
                 ttnn::AllGather{
-                    dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology},
+                    gather_dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology},
                 {input_device_tensor});
         },
         {input_tensor},
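Both overloads above normalize a possibly negative dim into a non-negative gather_dim before launching the op. A standalone sketch of that convention (the bounds check here validates the raw user value, which is what the TT_FATAL above is intended to enforce; rank-4 values chosen for illustration):

    // Standalone sketch of the negative-dimension convention used above.
    #include <cassert>
    #include <cstdint>

    int32_t normalize_dim(int32_t dim, int32_t rank) {
        // Valid inputs lie in [-rank, rank - 1]; negative values count from the back.
        assert(dim >= -rank && dim <= rank - 1 && "dim out of range");
        return (dim < 0) ? rank + dim : dim;
    }

    int main() {
        assert(normalize_dim(-1, 4) == 3);  // last dimension of a rank-4 tensor
        assert(normalize_dim(0, 4) == 0);   // first dimension
        assert(normalize_dim(-4, 4) == 0);  // most-negative valid value
        return 0;
    }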
(additional changed file; path not shown in this view)
@@ -200,7 +200,7 @@ namespace ccl {

 Tensor all_gather(
     const Tensor& input_tensor,
-    const int16_t gather_dim,
+    const int32_t dim,
     const uint32_t num_links = 1,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const std::optional<size_t> user_defined_num_workers = std::nullopt,
@@ -209,7 +209,7 @@

 Tensor all_gather(
     const Tensor& input_tensor,
-    const int16_t gather_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     const uint32_t num_links = 1,
(additional changed file; path not shown in this view)
@@ -107,7 +107,7 @@ namespace operations{
 namespace ccl{
 Tensor reduce_scatter(
     const Tensor& input_tensor,
-    const int16_t scatter_dim,
+    const int32_t dim,
     ttnn::operations::reduction::ReduceType math_op,
     const uint32_t num_links,
     const MemoryConfig& output_mem_config,
@@ -128,13 +128,13 @@ Tensor reduce_scatter(

     int16_t rank = input_tensor.get_logical_shape().rank();

-    int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim;
+    int16_t scatter_dim = (dim < 0) ? rank + dim : dim;

-    TT_FATAL(dim >= -rank && dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);
+    TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);

     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_op(
-        [binary_op_type, dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
+        [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -145,7 +145,7 @@ Tensor reduce_scatter(
                 ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct(
                     input_tensor,
                     binary_op_type,
-                    dim,
+                    scatter_dim,
                     num_links,
                     output_mem_config,
                     user_defined_num_workers,
@@ -164,7 +164,7 @@

 Tensor reduce_scatter(
     const Tensor &input_tensor,
-    const int16_t scatter_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     ttnn::operations::reduction::ReduceType reduce_op,
@@ -182,14 +182,14 @@ Tensor reduce_scatter(

     int16_t rank = input_tensor.get_logical_shape().rank();

-    int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim;
+    int16_t scatter_dim = (dim < 0) ? rank + dim : dim;

-    TT_FATAL(dim >= -rank && dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);
+    TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim);

     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

     operation::launch_op(
-        [dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
+        [scatter_dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
             const std::vector<Tensor>& input_tensors,
             const std::vector<std::optional<const Tensor>>& optional_input_tensors,
             const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -218,7 +218,7 @@ Tensor reduce_scatter(
             return operation::run(
                 ttnn::ReduceScatter{
                     binary_op_type,
-                    dim,
+                    scatter_dim,
                     num_links,
                     num_devices,
                     device_index,
(additional changed file; path not shown in this view)
@@ -69,7 +69,7 @@ namespace ccl{
 Tensor reduce_scatter(
     const Tensor &input_tensor,
-    const int16_t scatter_split_dim,
+    const int32_t dim,
     ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
     const uint32_t num_links = 1,
     const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
@@ -79,7 +79,7 @@ Tensor reduce_scatter(

 Tensor reduce_scatter(
     const ttnn::Tensor &input_tensor,
-    const int16_t scatter_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
(additional changed file; path not shown in this view)
@@ -10,7 +10,7 @@ namespace ttnn::operations::ccl {

 ttnn::Tensor ExecuteReduceScatter::invoke(
     const ttnn::Tensor& input_tensor,
-    const int16_t scatter_dim,
+    const int32_t dim,
     ttnn::operations::reduction::ReduceType math_op,
     const uint32_t num_links,
     const std::optional<ttnn::MemoryConfig>& memory_config,
@@ -19,11 +19,11 @@ ttnn::Tensor ExecuteReduceScatter::invoke(
     const std::optional<size_t> num_buffers_per_channel) {

     MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config());
-    return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel);
+    return ttnn::operations::ccl::reduce_scatter(input_tensor, dim, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel);
 }
 ttnn::Tensor ExecuteReduceScatter::invoke(
     const ttnn::Tensor& input_tensor,
-    const int16_t scatter_dim,
+    const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
     ttnn::operations::reduction::ReduceType math_op,
@@ -34,7 +34,7 @@ ttnn::Tensor ExecuteReduceScatter::invoke(
     const std::optional<size_t> num_buffers_per_channel) {

     MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config());
-    return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, cluster_axis, mesh_device, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel);
+    return ttnn::operations::ccl::reduce_scatter(input_tensor, dim, cluster_axis, mesh_device, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel);
 }

 } // namespace ttnn::operations::ccl
(additional changed file; path not shown in this view)
@@ -17,7 +17,7 @@ namespace ccl {
 struct ExecuteReduceScatter {
     static ttnn::Tensor invoke(
         const Tensor &input_tensor,
-        const int16_t scatter_dim,
+        const int32_t dim,
         const uint32_t cluster_axis,
         const MeshDevice& mesh_device,
         ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
@@ -29,7 +29,7 @@ struct ExecuteReduceScatter {

     static ttnn::Tensor invoke(
         const ttnn::Tensor& input_tensor,
-        const int16_t scatter_dim,
+        const int32_t dim,
         ttnn::operations::reduction::ReduceType math_op,
         const uint32_t num_links = 1,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,