[TT-Train] Added tests for sum and mean #16152

Merged: 12 commits, Dec 19, 2024
38 changes: 38 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -7,6 +7,7 @@
#include <core/ttnn_all_includes.hpp>
#include <ttnn/operations/moreh/moreh_sum/moreh_sum.hpp>

#include "autograd/auto_context.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/tt_tensor_utils.hpp"

@@ -54,4 +55,41 @@ tt::tt_metal::Tensor divide(const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b) {
return ttnn::multiply(a, inv_b);
}

tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    // Mean over `dim` using the moreh reduction kernel with the precise compute kernel config.
    return ttnn::moreh_mean(
        t,
        dim,
        keep_dim,
        std::nullopt,
        std::nullopt,
        std::nullopt,
        /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
}

tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    // Mean over `dim` using the generic ttnn reduction with the precise compute kernel config.
    return ttnn::mean(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    // Sum over `dim` using the moreh reduction kernel with the precise compute kernel config.
    return ttnn::moreh_sum(
        t,
        dim,
        keep_dim,
        std::nullopt,
        std::nullopt,
        /* device_compute_kernel_config */ core::ComputeKernelConfig::precise());
}

tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim) {
    // Sum over `dim` using the generic ttnn reduction with the precise compute kernel config.
    return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}

} // namespace ttml::ttnn_fixed
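
For orientation, here is a minimal usage sketch of the four wrappers added above. It is illustrative only (not part of the diff) and assumes a device has already been opened via ttml::autograd::ctx() and that `t` is already a device tensor, as in the tests further down.

```cpp
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

// Illustrative sketch, not part of this PR: reduce over the last dimension
// (dim = 3) of a [2, 1, 64, 64] tensor, keeping the reduced dim as size 1.
void reduce_example(const tt::tt_metal::Tensor& t) {
    auto mean_a = ttml::ttnn_fixed::mean_ttnn(t, /*dim=*/3, /*keep_dim=*/true);   // [2, 1, 64, 1]
    auto mean_b = ttml::ttnn_fixed::mean_moreh(t, /*dim=*/3, /*keep_dim=*/true);  // [2, 1, 64, 1]
    auto sum_a = ttml::ttnn_fixed::sum_ttnn(t, /*dim=*/3, /*keep_dim=*/true);
    auto sum_b = ttml::ttnn_fixed::sum_moreh(t, /*dim=*/3, /*keep_dim=*/true);
    // The two backends should agree closely; the tests below compare them
    // with rtol=1e-4, atol=1e-3.
    (void)mean_a; (void)mean_b; (void)sum_a; (void)sum_b;
}
```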
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -5,6 +5,8 @@
#pragma once
#include <core/ttnn_all_includes.hpp>

#include "core/tt_tensor_utils.hpp"

namespace ttml::ttnn_fixed {

tt::tt_metal::Tensor sum_over_dim(const tt::tt_metal::Tensor& t, uint32_t dim);
@@ -13,4 +15,9 @@ tt::tt_metal::Tensor log_softmax(const tt::tt_metal::Tensor& t, int dim);
tt::tt_metal::Tensor softmax(const tt::tt_metal::Tensor& t, int dim);
tt::tt_metal::Tensor divide(const tt::tt_metal::Tensor& a, const tt::tt_metal::Tensor& b);

tt::tt_metal::Tensor mean_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
} // namespace ttml::ttnn_fixed
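
As a side note, here is a hypothetical sketch of how the pre-existing sum_over_dim declared above could be routed through the new sum_ttnn helper. The PR does not change sum_over_dim, and the keep_dim choice below is an assumption, not taken from the actual implementation.

```cpp
// Hypothetical refactor sketch (not taken from this PR): express sum_over_dim
// via the new sum_ttnn wrapper. keep_dim = true is an assumption.
tt::tt_metal::Tensor sum_over_dim_sketch(const tt::tt_metal::Tensor& t, uint32_t dim) {
    return ttml::ttnn_fixed::sum_ttnn(t, static_cast<int>(dim), /*keep_dim=*/true);
}
```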
73 changes: 73 additions & 0 deletions tt-train/tests/ttnn_fixed/reduce_ops_test.cpp
@@ -0,0 +1,73 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include <core/ttnn_all_includes.hpp>
#include <memory>
#include <ttnn/operations/core/compute_kernel/compute_kernel_config.hpp>
#include <ttnn/operations/reduction/generic/generic_reductions.hpp>
#include <vector>

#include "autograd/auto_context.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/device.hpp"
#include "core/tt_tensor_utils.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

class ReduceOpTest : public ::testing::Test {
protected:
void SetUp() override {
ttml::autograd::ctx().open_device();
}

void TearDown() override {
ttml::autograd::ctx().close_device();
}
};

TEST_F(ReduceOpTest, TestMean) {
xt::random::seed(42);
auto* device = &ttml::autograd::ctx().get_device();
xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.5, 0.5).reshape({2, 1, 64, 64});

auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

auto ttnn_mean_dim3 = ttml::ttnn_fixed::mean_ttnn(xtensor_a_tensor, 3, true);
auto moreh_mean_dim3 = ttml::ttnn_fixed::mean_moreh(xtensor_a_tensor, 3, true);

xt::xarray<float> mean_xtensor = xt::mean(xtensor_a, {3}, xt::evaluation_strategy::immediate);
mean_xtensor.reshape({2, 1, 64, 1});

auto mean_ttnn = ttml::core::to_xtensor(ttnn_mean_dim3);
auto mean_moreh = ttml::core::to_xtensor(moreh_mean_dim3);

EXPECT_TRUE(xt::allclose(mean_ttnn, mean_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
EXPECT_TRUE(xt::allclose(mean_xtensor, mean_ttnn, /*rtol=*/1e-3, /*atol=*/1e-2));
EXPECT_TRUE(xt::allclose(mean_xtensor, mean_moreh, /*rtol=*/1e-3, /*atol=*/1e-2));
}

TEST_F(ReduceOpTest, TestSum) {
xt::random::seed(42);
auto* device = &ttml::autograd::ctx().get_device();
xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.1, 0.1).reshape({2, 1, 64, 64});

auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

auto ttnn_sum_dim3 = ttml::ttnn_fixed::sum_ttnn(xtensor_a_tensor, 3, true);
auto moreh_sum_dim3 = ttml::ttnn_fixed::sum_moreh(xtensor_a_tensor, 3, true);

xt::xarray<float> sum_xtensor = xt::sum(xtensor_a, {3}, xt::evaluation_strategy::immediate);
sum_xtensor.reshape({2, 1, 64, 1});

auto sum_ttnn = ttml::core::to_xtensor(ttnn_sum_dim3);
auto sum_moreh = ttml::core::to_xtensor(moreh_sum_dim3);

EXPECT_TRUE(xt::allclose(sum_ttnn, sum_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
EXPECT_TRUE(xt::allclose(sum_xtensor, sum_ttnn, /*rtol=*/1e-2, /*atol=*/1e-2));
EXPECT_TRUE(xt::allclose(sum_xtensor, sum_moreh, /*rtol=*/1e-2, /*atol=*/1e-2));
}
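
For completeness, a hypothetical follow-up test showing how the same ttnn-vs-moreh-vs-xtensor comparison generalizes to another dimension. This is a sketch only and is not part of the PR.

```cpp
// Hypothetical sketch, not part of this PR: the same comparison pattern over dim 2.
TEST_F(ReduceOpTest, SumDim2Sketch) {
    xt::random::seed(42);
    auto* device = &ttml::autograd::ctx().get_device();
    xt::xarray<float> xtensor_a = xt::random::rand({128 * 64}, -0.1, 0.1).reshape({2, 1, 64, 64});

    auto xtensor_a_tensor = ttml::core::from_xtensor(xtensor_a, device);

    auto ttnn_sum_dim2 = ttml::ttnn_fixed::sum_ttnn(xtensor_a_tensor, 2, true);
    auto moreh_sum_dim2 = ttml::ttnn_fixed::sum_moreh(xtensor_a_tensor, 2, true);

    // Reducing a [2, 1, 64, 64] tensor over dim 2 with keep_dim gives [2, 1, 1, 64].
    xt::xarray<float> sum_xtensor = xt::sum(xtensor_a, {2}, xt::evaluation_strategy::immediate);
    sum_xtensor.reshape({2, 1, 1, 64});

    auto sum_ttnn = ttml::core::to_xtensor(ttnn_sum_dim2);
    auto sum_moreh = ttml::core::to_xtensor(moreh_sum_dim2);

    EXPECT_TRUE(xt::allclose(sum_ttnn, sum_moreh, /*rtol=*/1e-4, /*atol=*/1e-3));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_ttnn, /*rtol=*/1e-2, /*atol=*/1e-2));
    EXPECT_TRUE(xt::allclose(sum_xtensor, sum_moreh, /*rtol=*/1e-2, /*atol=*/1e-2));
}
```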