diff --git a/include/kwk/algorithm/algos/transform.hpp b/include/kwk/algorithm/algos/transform.hpp index 406d3a9f..937353ac 100644 --- a/include/kwk/algorithm/algos/transform.hpp +++ b/include/kwk/algorithm/algos/transform.hpp @@ -1,25 +1,33 @@ -//================================================================================================== +//====================================================================================================================== /** KIWAKU - Containers Well Made Copyright : KIWAKU Project Contributors SPDX-License-Identifier: BSL-1.0 **/ -//================================================================================================== +//====================================================================================================================== #pragma once #include -#include -#include -#include -#include +#include namespace kwk { + // Transform is not a required part of contexts anymore + template< typename Context, typename Func, concepts::container Out + , concepts::container C0, concepts::container... Cs + > + constexpr void transform(Context& ctx,Func&& f, Out& out, C0 const& c0, Cs const&... cs) + { + ctx.map ( [&](auto& o, auto const& i0, auto const&... in) { o = KWK_FWD(f)(i0, in...); } + , ctx.out(out), ctx.in(c0), ctx.in(cs)... + ); + } + template< typename Func, concepts::container Out , concepts::container C0, concepts::container... Cs > - constexpr auto transform(Func f, Out& out, C0&& c0, Cs&&... cs) + constexpr void transform(Func&& f, Out& out, C0&& c0, Cs&&... cs) { - kwk::for_each([&](auto... is) { out(is...) = f(KWK_FWD(c0)(is...), KWK_FWD(cs)(is...)...); }, out.shape() ); + kwk::transform(cpu, KWK_FWD(f), out, KWK_FWD(c0), KWK_FWD(cs)...); } } diff --git a/test/algorithm/algos/context/cpu/transform.cpp b/test/algorithm/algos/context/cpu/transform.cpp new file mode 100644 index 00000000..d612dbf7 --- /dev/null +++ b/test/algorithm/algos/context/cpu/transform.cpp @@ -0,0 +1,34 @@ +//====================================================================================================================== +/* + KIWAKU - Containers Well Made + Copyright : KIWAKU Contributors & Maintainers + SPDX-License-Identifier: BSL-1.0 +*/ +//====================================================================================================================== +#include +#include +#include +#include "test.hpp" +#include "../generic/transform.hpp" + +// TODO: update these tests + +TTS_CASE("Check for kwk::transform(value, new_value) 1D - CPU context") +{ + kwk::test::transform_value_new_value_1D(kwk::cpu); +}; + +TTS_CASE("Check for kwk::transform(value, new_value) 2D - CPU context") +{ + kwk::test::transform_value_new_value_2D(kwk::cpu); +}; + +TTS_CASE("Check for kwk::transform(value, new_value) 3D - CPU context") +{ + kwk::test::transform_value_new_value_3D(kwk::cpu); +}; + +TTS_CASE("Check for kwk::transform(value, new_value) 4D - CPU context") +{ + kwk::test::transform_value_new_value_4D(kwk::cpu); +}; diff --git a/test/algorithm/algos/context/generic/transform.hpp b/test/algorithm/algos/context/generic/transform.hpp new file mode 100644 index 00000000..cf81a852 --- /dev/null +++ b/test/algorithm/algos/context/generic/transform.hpp @@ -0,0 +1,104 @@ +//====================================================================================================================== +/* + KIWAKU - Containers Well Made + Copyright : KIWAKU Contributors & Maintainers + SPDX-License-Identifier: BSL-1.0 +*/ +//====================================================================================================================== +#pragma once + +#include +#include +#include "test.hpp" + +// TODO: update these tests + +namespace kwk::test +{ + template + void transform_value_new_value_1D(Context&& ctx) + { + int data[2]; + double res[2]; + double vdata[2] = {1, 0.5}; + + fill_data(data, kwk::of_size(2), true); + + auto d = kwk::view{kwk::source = data, kwk::of_size(2)}; + auto v = kwk::view{kwk::source = res, kwk::of_size(2)}; + + ::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d); + + TTS_ALL_EQUAL(res, vdata); + } + + template + void transform_value_new_value_2D(Context&& ctx) + { + int data[2*3]; + double res[2*3]; + double vdata[2*3]; + + fill_data(data, kwk::of_size(2,3), true); + fill_data(vdata, kwk::of_size(2,3), true); + + auto d = kwk::view{kwk::source = data, kwk::of_size(2,3)}; + auto v = kwk::view{kwk::source = res, kwk::of_size(2,3)}; + + for(int i = 0; i<2; i++) + for(int j = 0; j<3; j++) + vdata[i*3+j] = 1.0/(1.0+vdata[i*3+j]); + + ::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d); + + TTS_ALL_EQUAL(res, vdata); + }; + + template + void transform_value_new_value_3D(Context&& ctx) + { + int data[2*3*4]; + double res[2*3*4]; + double vdata[2*3*4]; + + fill_data(data, kwk::of_size(2,3,4), true); + fill_data(vdata, kwk::of_size(2,3,4), true); + + auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4)}; + auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4)}; + + for(int i = 0; i<2; i++) + for(int j = 0; j<3; j++) + for(int k = 0; k<4; k++) + vdata[i*4*3+j*4+k] = 1.0/(1.0+vdata[i*4*3+j*4+k]); + + ::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d); + + TTS_ALL_EQUAL(res, vdata); + }; + + template + void transform_value_new_value_4D(Context&& ctx) + { + int data[2*3*4*5]; + double res[2*3*4*5]; + double vdata[2*3*4*5]; + + fill_data(data, kwk::of_size(2,3,4,5), true); + fill_data(vdata, kwk::of_size(2,3,4,5), true); + + auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4,5)}; + auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4,5)}; + + for(int i = 0; i<2; i++) + for(int j = 0; j<3; j++) + for(int k = 0; k<4; k++) + for(int l = 0; l<5; l++) + vdata[i*5*4*3+j*5*4+k*5+l] = 1.0/(1.0+vdata[i*5*4*3+j*5*4+k*5+l]); + + ::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d); + + TTS_ALL_EQUAL(res, vdata); + }; + +} \ No newline at end of file diff --git a/test/algorithm/algos/transform.cpp b/test/algorithm/algos/transform.cpp index 66c209b1..f9e219a1 100644 --- a/test/algorithm/algos/transform.cpp +++ b/test/algorithm/algos/transform.cpp @@ -9,117 +9,172 @@ #include #include #include "test.hpp" +#include TTS_CASE("Check for kwk::transform(value, new_value) 1D") { - int data[2]; - double res[2]; - double vdata[2] = {1, 0.5}; - - fill_data(data, kwk::of_size(2), true); - - auto d = kwk::view{kwk::source = data, kwk::of_size(2)}; - auto v = kwk::view{kwk::source = res, kwk::of_size(2)}; - - int count = 0; - transform( - [&](auto e) - { - count++; - return 1.0/(1.0+e); - }, v, d); - - TTS_ALL_EQUAL(res, vdata); - - TTS_EQUAL(count, d.numel()); + using data_type = int; + const std::size_t d0 = 87; + const std::size_t input_size = d0; + std::array input1, input2, output, check; + + for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; } + + auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)}; + auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)}; + auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)}; + + kwk::transform( [&](auto const e1, auto const e2) { + return e1 + e2; + } + , view_out + , view_in1 + , view_in2 + ); + + std::transform( input1.begin(), input1.end() + , input2.begin() + , check.begin() + , [](auto const e1, auto const e2) { + return e1 + e2; + } + ); + + TTS_ALL_EQUAL(output, check); }; -TTS_CASE("Check for kwk::transform(value, new_value) 2D") +TTS_CASE("Check for kwk::transform(value, new_value) 1D with float") { - int data[2*3]; - double res[2*3]; - double vdata[2*3]; - - fill_data(data, kwk::of_size(2,3), true); - fill_data(vdata, kwk::of_size(2,3), true); - - auto d = kwk::view{kwk::source = data, kwk::of_size(2,3)}; - auto v = kwk::view{kwk::source = res, kwk::of_size(2,3)}; - - for(int i = 0; i<2; i++) - for(int j = 0; j<3; j++) - vdata[i*3+j] = 1.0/(1.0+vdata[i*3+j]); - - int count = 0; - transform( - [&](auto e) - { - count++; - return 1.0/(1.0+e); - }, v, d); - - TTS_ALL_EQUAL(res, vdata); - - TTS_EQUAL(count, d.numel()); + using data_type = float; + const std::size_t d0 = 87; + const std::size_t input_size = d0; + std::array input1, input2, output, check; + + for (std::size_t i = 0; i < input_size; ++i) + { + input1[i] = i * static_cast(3.88); + input2[i] = i * static_cast(2.87); + } + + auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)}; + auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)}; + auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)}; + + kwk::transform( [&](auto const e1, auto const e2) { + return e1 + e2; + } + , view_out + , view_in1 + , view_in2 + ); + + std::transform( input1.begin(), input1.end() + , input2.begin() + , check.begin() + , [](auto const e1, auto const e2) { + return e1 + e2; + } + ); + + TTS_ALL_EQUAL(output, check); }; -TTS_CASE("Check for kwk::transform(value, new_value) 3D") +TTS_CASE("Check for kwk::transform(value, new_value) 1D with std::uint64_t") +{ + using data_type = std::uint64_t; + const std::size_t d0 = 87; + const std::size_t input_size = d0; + std::array input1, input2, output, check; + + for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; } + + auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)}; + auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)}; + auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)}; + + kwk::transform( [&](auto const e1, auto const e2) { + return e1 + e2; + } + , view_out + , view_in1 + , view_in2 + ); + + std::transform( input1.begin(), input1.end() + , input2.begin() + , check.begin() + , [](auto const e1, auto const e2) { + return e1 + e2; + } + ); + + TTS_ALL_EQUAL(output, check); +}; + +TTS_CASE("Check for kwk::transform(value, new_value) 2D") { - int data[2*3*4]; - double res[2*3*4]; - double vdata[2*3*4]; - - fill_data(data, kwk::of_size(2,3,4), true); - fill_data(vdata, kwk::of_size(2,3,4), true); - - auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4)}; - auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4)}; - - for(int i = 0; i<2; i++) - for(int j = 0; j<3; j++) - for(int k = 0; k<4; k++) - vdata[i*4*3+j*4+k] = 1.0/(1.0+vdata[i*4*3+j*4+k]); - - int count = 0; - transform( - [&](auto e) - { - count++; - return 1.0/(1.0+e); - }, v, d); - - TTS_ALL_EQUAL(res, vdata); - - TTS_EQUAL(count, d.numel()); + using data_type = int; + const std::size_t d0 = 87; + const std::size_t d1 = 18; + const std::size_t input_size = d0 * d1; + std::array input1, input2, output, check; + + for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; } + + auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0, d1)}; + auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0, d1)}; + auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0, d1)}; + + kwk::transform( [&](auto const e1, auto const e2) { + return e1 + e2; + } + , view_out + , view_in1 + , view_in2 + ); + + std::transform( input1.begin(), input1.end() + , input2.begin() + , check.begin() + , [](auto const e1, auto const e2) { + return e1 + e2; + } + ); + + TTS_ALL_EQUAL(output, check); }; TTS_CASE("Check for kwk::transform(value, new_value) 4D") { - int data[2*3*4*5]; - double res[2*3*4*5]; - double vdata[2*3*4*5]; - - fill_data(data, kwk::of_size(2,3,4,5), true); - fill_data(vdata, kwk::of_size(2,3,4,5), true); - - auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4,5)}; - auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4,5)}; - - for(int i = 0; i<2; i++) - for(int j = 0; j<3; j++) - for(int k = 0; k<4; k++) - for(int l = 0; l<5; l++) - vdata[i*5*4*3+j*5*4+k*5+l] = 1.0/(1.0+vdata[i*5*4*3+j*5*4+k*5+l]); - - int count = 0; - transform( - [&](auto e) - { - count++; - return 1.0/(1.0+e); - }, v, d); - - TTS_ALL_EQUAL(res, vdata); - - TTS_EQUAL(count, d.numel()); + using data_type = int; + const std::size_t d0 = 87; + const std::size_t d1 = 18; + const std::size_t d2 = 41; + const std::size_t d3 = 8; + const std::size_t input_size = d0 * d1 * d2 * d3; + std::array input1, input2, output, check; + + for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; } + + auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0, d1, d2, d3)}; + auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0, d1, d2, d3)}; + auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0, d1, d2, d3)}; + + kwk::transform( [&](auto const e1, auto const e2) { + return e1 + e2; + } + , view_out + , view_in1 + , view_in2 + ); + + std::transform( input1.begin(), input1.end() + , input2.begin() + , check.begin() + , [](auto const e1, auto const e2) { + return e1 + e2; + } + ); + + TTS_ALL_EQUAL(output, check); }; \ No newline at end of file