Skip to content

Commit

Permalink
Algorithms : transform
Browse files Browse the repository at this point in the history
  • Loading branch information
SylvainJoube committed Sep 11, 2024
1 parent 01f99f9 commit 0094ffe
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 108 deletions.
24 changes: 16 additions & 8 deletions include/kwk/algorithm/algos/transform.hpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
//==================================================================================================
//======================================================================================================================
/**
KIWAKU - Containers Well Made
Copyright : KIWAKU Project Contributors
SPDX-License-Identifier: BSL-1.0
**/
//==================================================================================================
//======================================================================================================================
#pragma once

#include <kwk/concepts/container.hpp>
#include <kwk/algorithm/algos/for_each.hpp>
#include <kwk/detail/abi.hpp>
#include <cstddef>
#include <utility>
#include <kwk/context/context.hpp>

namespace kwk
{
// Transform is not a required part of contexts anymore
template< typename Context, typename Func, concepts::container Out
, concepts::container C0, concepts::container... Cs
>
constexpr void transform(Context& ctx,Func&& f, Out& out, C0 const& c0, Cs const&... cs)
{
ctx.map ( [&](auto& o, auto const& i0, auto const&... in) { o = KWK_FWD(f)(i0, in...); }
, ctx.out(out), ctx.in(c0), ctx.in(cs)...
);
}

template< typename Func, concepts::container Out
, concepts::container C0, concepts::container... Cs
>
constexpr auto transform(Func f, Out& out, C0&& c0, Cs&&... cs)
constexpr void transform(Func&& f, Out& out, C0&& c0, Cs&&... cs)
{
kwk::for_each([&](auto... is) { out(is...) = f(KWK_FWD(c0)(is...), KWK_FWD(cs)(is...)...); }, out.shape() );
kwk::transform(cpu, KWK_FWD(f), out, KWK_FWD(c0), KWK_FWD(cs)...);
}
}
34 changes: 34 additions & 0 deletions test/algorithm/algos/context/cpu/transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//======================================================================================================================
/*
KIWAKU - Containers Well Made
Copyright : KIWAKU Contributors & Maintainers
SPDX-License-Identifier: BSL-1.0
*/
//======================================================================================================================
#include <kwk/algorithm/algos/for_each.hpp>
#include <kwk/algorithm/algos/transform.hpp>
#include <kwk/container.hpp>
#include "test.hpp"
#include "../generic/transform.hpp"

// TODO: update these tests

TTS_CASE("Check for kwk::transform(value, new_value) 1D - CPU context")
{
kwk::test::transform_value_new_value_1D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 2D - CPU context")
{
kwk::test::transform_value_new_value_2D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 3D - CPU context")
{
kwk::test::transform_value_new_value_3D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 4D - CPU context")
{
kwk::test::transform_value_new_value_4D(kwk::cpu);
};
104 changes: 104 additions & 0 deletions test/algorithm/algos/context/generic/transform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//======================================================================================================================
/*
KIWAKU - Containers Well Made
Copyright : KIWAKU Contributors & Maintainers
SPDX-License-Identifier: BSL-1.0
*/
//======================================================================================================================
#pragma once

#include <kwk/algorithm/algos/transform.hpp>
#include <kwk/container.hpp>
#include "test.hpp"

// TODO: update these tests

namespace kwk::test
{
template<typename Context>
void transform_value_new_value_1D(Context&& ctx)
{
int data[2];
double res[2];
double vdata[2] = {1, 0.5};

fill_data(data, kwk::of_size(2), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2)};

::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d);

TTS_ALL_EQUAL(res, vdata);
}

template<typename Context>
void transform_value_new_value_2D(Context&& ctx)
{
int data[2*3];
double res[2*3];
double vdata[2*3];

fill_data(data, kwk::of_size(2,3), true);
fill_data(vdata, kwk::of_size(2,3), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
vdata[i*3+j] = 1.0/(1.0+vdata[i*3+j]);

::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d);

TTS_ALL_EQUAL(res, vdata);
};

template<typename Context>
void transform_value_new_value_3D(Context&& ctx)
{
int data[2*3*4];
double res[2*3*4];
double vdata[2*3*4];

fill_data(data, kwk::of_size(2,3,4), true);
fill_data(vdata, kwk::of_size(2,3,4), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
for(int k = 0; k<4; k++)
vdata[i*4*3+j*4+k] = 1.0/(1.0+vdata[i*4*3+j*4+k]);

::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d);

TTS_ALL_EQUAL(res, vdata);
};

template<typename Context>
void transform_value_new_value_4D(Context&& ctx)
{
int data[2*3*4*5];
double res[2*3*4*5];
double vdata[2*3*4*5];

fill_data(data, kwk::of_size(2,3,4,5), true);
fill_data(vdata, kwk::of_size(2,3,4,5), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4,5)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4,5)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
for(int k = 0; k<4; k++)
for(int l = 0; l<5; l++)
vdata[i*5*4*3+j*5*4+k*5+l] = 1.0/(1.0+vdata[i*5*4*3+j*5*4+k*5+l]);

::kwk::transform(ctx, [&](auto e) { return 1.0/(1.0+e); }, v, d);

TTS_ALL_EQUAL(res, vdata);
};

}
Loading

0 comments on commit 0094ffe

Please sign in to comment.