[TT-Train] Added tests for sum and mean (#16152)
### Problem description
We need to verify that the ttnn reduce ops produce results that closely match their moreh counterparts.

### What's changed
* Added sum test
* Added mean test
* Updated `sum_over_dim` to delegate to the new `sum_ttnn` helper (see the sketch after the screenshot)

<img width="305" alt="Screenshot 2024-12-18 at 1 56 19 PM"
src="https://github.com/user-attachments/assets/58faaccb-83be-47db-a4f2-8578ae5a68db"
/>
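
For reference, here is a minimal sketch of the comparison the new tests perform. It assumes the tt-train test environment (a device that can be opened via `autograd::ctx()`) and the helpers added in `trivial_ttnn_ops.hpp`; the free function name is illustrative, not part of this change.

```cpp
// Illustrative sketch (not part of this commit): run the same reduction
// through the ttnn and moreh paths and check that they agree.
#include <core/ttnn_all_includes.hpp>

#include "autograd/auto_context.hpp"
#include "core/tt_tensor_utils.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

bool sum_paths_agree() {
    ttml::autograd::ctx().open_device();
    auto* device = &ttml::autograd::ctx().get_device();

    // Same input shape and value range as the added sum tests.
    xt::xarray<float> input = xt::random::rand({128 * 64}, -0.1, 0.1).reshape({2, 1, 64, 64});
    auto tensor = ttml::core::from_xtensor(input, device);

    // Sum over dim 0 with both backends, keeping the reduced dimension.
    auto ttnn_result = ttml::ttnn_fixed::sum_ttnn(tensor, /*dim=*/0, /*keep_dim=*/true);
    auto moreh_result = ttml::ttnn_fixed::sum_moreh(tensor, /*dim=*/0, /*keep_dim=*/true);

    // Pull both results back to host and compare with the tolerances the tests use.
    bool agree = xt::allclose(
        ttml::core::to_xtensor(ttnn_result),
        ttml::core::to_xtensor(moreh_result),
        /*rtol=*/1e-4,
        /*atol=*/1e-3);

    ttml::autograd::ctx().close_device();
    return agree;
}
```

The tests below apply this pattern to both `sum` and `mean` over dims 0 and 3, plus a larger `{2, 1, 512, 1024}` mean case, and also compare each backend against an xtensor reference.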


### Checklist
- [x] Post commit CI passes
- [x] Blackhole Post commit (if applicable)
- [x] Model regression CI testing passes (if applicable)
- [x] Device performance regression CI testing passes (if applicable)
- [x] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [x] New/Existing tests provide coverage for changes

https://github.com/tenstorrent/tt-metal/actions/runs/12405523060

---------

Co-authored-by: Roman Furko <[email protected]>
dmakoviichuk-tt and rfurko-tt authored Dec 19, 2024
1 parent ed9964a commit 5c91e97
Showing 3 changed files with 171 additions and 7 deletions.
38 changes: 31 additions & 7 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -7,19 +7,14 @@
#include <core/ttnn_all_includes.hpp>
#include <ttnn/operations/moreh/moreh_sum/moreh_sum.hpp>

#include "autograd/auto_context.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/tt_tensor_utils.hpp"

namespace ttml::ttnn_fixed {

tt::tt_metal::Tensor sum_over_dim(const tt::tt_metal::Tensor& t, uint32_t dim) {
    return ttnn::moreh_sum(
        t,
        /* dim */ dim,
        /* keep_dim */ true,
        /* output */ std::nullopt,
        /* output_mem_config */ std::nullopt,
        /*compute_kernel_config */ core::ComputeKernelConfig::precise());
    return sum_ttnn(t, dim, /* keepdim */ true);
}

tt::tt_metal::Tensor sum_over_batch(const tt::tt_metal::Tensor& t) {
@@ -54,4 +49,33 @@ tt::tt_metal::Tensor divide(const tt::tt_metal::Tensor& a, const tt::tt_metal::T
    return ttnn::multiply(a, inv_b);
}

tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    auto res = ttnn::moreh_mean(
        t,
        dim,
        keep_dim,
        std::nullopt,
        std::nullopt,
        std::nullopt,
        /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
    return res;
}

tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    return ttnn::mean(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    auto res = ttnn::moreh_sum(
        t,
        dim,
        keep_dim,
        std::nullopt,
        std::nullopt,
        /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
    return res;
}

tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}

} // namespace ttml::ttnn_fixed
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -5,6 +5,8 @@
#pragma once
#include <core/ttnn_all_includes.hpp>

#include "core/tt_tensor_utils.hpp"

namespace ttml::ttnn_fixed {

tt::tt_metal::Tensor sum_over_dim(const tt::tt_metal::Tensor& t, uint32_t dim);
@@ -13,4 +15,9 @@ tt::tt_metal::Tensor log_softmax(const tt::tt_metal::Tensor& t, int dim);
tt::tt_metal::Tensor softmax(const tt::tt_metal::Tensor& t, int dim);
tt::tt_metal::Tensor divide(const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b);

tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
} // namespace ttml::ttnn_fixed
133 changes: 133 additions & 0 deletions tt-train/tests/ttnn_fixed/reduce_ops_test.cpp
@@ -0,0 +1,133 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include <core/ttnn_all_includes.hpp>
#include <memory>
#include <ttnn/operations/core/compute_kernel/compute_kernel_config.hpp>
#include <ttnn/operations/reduction/generic/generic_reductions.hpp>
#include <vector>

#include "autograd/auto_context.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/device.hpp"
#include "core/tt_tensor_utils.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

class ReduceOpTest : public ::testing::Test {
protected:
    void SetUp() override {
        ttml::autograd::ctx().open_device();
    }

    void TearDown() override {
        ttml::autograd::ctx().close_device();
    }
};

TEST_F(ReduceOpTest, TestMeanDim0) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.5, 0.5).reshape({2, 1, 64, 64});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_mean_dim0 = ttml::ttnn_fixed::mean_ttnn(xtensor_a_tensor, 0, true);
    auto moreh_mean_dim0 = ttml::ttnn_fixed::mean_moreh(xtensor_a_tensor, 0, true);

    xt::xarray<float> mean_xtensor = xt::mean(xtensor_a, {0}, xt::evaluation_strategy::immediate);
    mean_xtensor.reshape({1, 1, 64, 64});

    auto mean_ttnn = ttml::core::to_xtensor(ttnn_mean_dim0);
    auto mean_moreh = ttml::core::to_xtensor(moreh_mean_dim0);

    EXPECT_TRUE(xt::allclose(mean_ttnn, mean_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_ttnn, /*rtol=*/1e-3, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_moreh, /*rtol=*/1e-3, /*atol=*/1e-2));
}

TEST_F(ReduceOpTest, TestSumDim0) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.1, 0.1).reshape({2, 1, 64, 64});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_sum_dim0 = ttml::ttnn_fixed::sum_ttnn(xtensor_a_tensor, 0, true);
    auto moreh_sum_dim0 = ttml::ttnn_fixed::sum_moreh(xtensor_a_tensor, 0, true);

    xt::xarray<float> sum_xtensor = xt::sum(xtensor_a, {0}, xt::evaluation_strategy::immediate);
    sum_xtensor.reshape({1, 1, 64, 64});

    auto sum_ttnn = ttml::core::to_xtensor(ttnn_sum_dim0);
    auto sum_moreh = ttml::core::to_xtensor(moreh_sum_dim0);

    EXPECT_TRUE(xt::allclose(sum_ttnn, sum_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_ttnn, /*rtol=*/1e-2, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_moreh, /*rtol=*/1e-2, /*atol=*/1e-2));
}

TEST_F(ReduceOpTest, TestMeanDim3) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.5, 0.5).reshape({2, 1, 64, 64});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_mean_dim3 = ttml::ttnn_fixed::mean_ttnn(xtensor_a_tensor, 3, true);
    auto moreh_mean_dim3 = ttml::ttnn_fixed::mean_moreh(xtensor_a_tensor, 3, true);

    xt::xarray<float> mean_xtensor = xt::mean(xtensor_a, {3}, xt::evaluation_strategy::immediate);
    mean_xtensor.reshape({2, 1, 64, 1});

    auto mean_ttnn = ttml::core::to_xtensor(ttnn_mean_dim3);
    auto mean_moreh = ttml::core::to_xtensor(moreh_mean_dim3);

    EXPECT_TRUE(xt::allclose(mean_ttnn, mean_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_ttnn, /*rtol=*/1e-3, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_moreh, /*rtol=*/1e-3, /*atol=*/1e-2));
}

TEST_F(ReduceOpTest, TestSumDim3) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.1, 0.1).reshape({2, 1, 64, 64});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_sum_dim3 = ttml::ttnn_fixed::sum_ttnn(xtensor_a_tensor, 3, true);
    auto moreh_sum_dim3 = ttml::ttnn_fixed::sum_moreh(xtensor_a_tensor, 3, true);

    xt::xarray<float> sum_xtensor = xt::sum(xtensor_a, {3}, xt::evaluation_strategy::immediate);
    sum_xtensor.reshape({2, 1, 64, 1});

    auto sum_ttnn = ttml::core::to_xtensor(ttnn_sum_dim3);
    auto sum_moreh = ttml::core::to_xtensor(moreh_sum_dim3);

    EXPECT_TRUE(xt::allclose(sum_ttnn, sum_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_ttnn, /*rtol=*/1e-2, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_moreh, /*rtol=*/1e-2, /*atol=*/1e-2));
}

TEST_F(ReduceOpTest, TestMeanLargeDim3) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({1024 * 1024}, -0.5, 0.5).reshape({2, 1, 512, 1024});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_mean_dim3 = ttml::ttnn_fixed::mean_ttnn(xtensor_a_tensor, 3, true);
    auto moreh_mean_dim3 = ttml::ttnn_fixed::mean_moreh(xtensor_a_tensor, 3, true);

    xt::xarray<float> mean_xtensor = xt::mean(xtensor_a, {3}, xt::evaluation_strategy::immediate);
    mean_xtensor.reshape({2, 1, 512, 1});

    auto mean_ttnn = ttml::core::to_xtensor(ttnn_mean_dim3);
    auto mean_moreh = ttml::core::to_xtensor(moreh_mean_dim3);

    EXPECT_TRUE(xt::allclose(mean_ttnn, mean_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_ttnn, /*rtol=*/1e-3, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(mean_xtensor, mean_moreh, /*rtol=*/1e-3, /*atol=*/1e-2));
}
