diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index 56999d464d5..2cd2a82cdd5 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -123,8 +123,21 @@ operation::ProgramWithCallbacks AllGatherAsync::create_program( } const operation::Hash AllGatherAsync::compute_program_hash(const std::vector& input_tensors) const { + auto input_shape = input_tensors[0].get_padded_shape(); + auto input_memory_layout = input_tensors[0].get_layout(); + auto input_dtype = input_tensors[0].get_dtype(); + auto input_memory_config = input_tensors[0].memory_config(); return operation::hash_operation( - this->dim, this->num_links, this->ring_size, this->ring_index, this->output_mem_config, this->topology); + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->output_mem_config, + this->topology, + input_shape, + input_memory_layout, + input_dtype, + input_memory_config); } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp index cdd9aff9758..e6c804523ad 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -159,8 +159,20 @@ operation::ProgramWithCallbacks ReduceScatterAsync::create_program( } operation::Hash ReduceScatterAsync::compute_program_hash(const std::vector& input_tensors) const { + auto input_shape = input_tensors[0].get_padded_shape(); + auto input_memory_layout = input_tensors[0].get_layout(); + auto input_dtype = input_tensors[0].get_dtype(); + auto input_memory_config = input_tensors[0].memory_config(); return operation::hash_operation( - this->binary_op_type, this->scatter_dim, this->ring_size, this->ring_index, this->topology); + this->binary_op_type, + this->scatter_dim, + this->ring_size, + this->ring_index, + this->topology, + input_shape, + input_memory_layout, + input_dtype, + input_memory_config); } namespace {