Skip to content

Commit

Permalink
Algorithms : transform
Browse files Browse the repository at this point in the history
  • Loading branch information
SylvainJoube committed Sep 11, 2024
1 parent a1e2ec3 commit c839e5c
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 108 deletions.
24 changes: 16 additions & 8 deletions include/kwk/algorithm/algos/transform.hpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
//==================================================================================================
//======================================================================================================================
/**
KIWAKU - Containers Well Made
Copyright : KIWAKU Project Contributors
SPDX-License-Identifier: BSL-1.0
**/
//==================================================================================================
//======================================================================================================================
#pragma once

#include <kwk/concepts/container.hpp>
#include <kwk/algorithm/algos/for_each.hpp>
#include <kwk/detail/abi.hpp>
#include <cstddef>
#include <utility>
#include <kwk/context/context.hpp>

namespace kwk
{
// Transform is not a required part of contexts anymore
template< typename Context, typename Func, concepts::container Out
, concepts::container C0, concepts::container... Cs
>
constexpr void transform(Context& ctx,Func&& f, Out& out, C0 const& c0, Cs const&... cs)
{
ctx.map ( [&](auto& o, auto const& i0, auto const&... in) { o = KWK_FWD(f)(i0, in...); }
, ctx.out(out), ctx.in(c0), ctx.in(cs)...
);
}

template< typename Func, concepts::container Out
, concepts::container C0, concepts::container... Cs
>
constexpr auto transform(Func f, Out& out, C0&& c0, Cs&&... cs)
constexpr void transform(Func&& f, Out& out, C0&& c0, Cs&&... cs)
{
kwk::for_each([&](auto... is) { out(is...) = f(KWK_FWD(c0)(is...), KWK_FWD(cs)(is...)...); }, out.shape() );
kwk::transform(cpu, KWK_FWD(f), out, KWK_FWD(c0), KWK_FWD(cs)...);
}
}
34 changes: 34 additions & 0 deletions test/algorithm/algos/context/cpu/transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//======================================================================================================================
/*
KIWAKU - Containers Well Made
Copyright : KIWAKU Contributors & Maintainers
SPDX-License-Identifier: BSL-1.0
*/
//======================================================================================================================
#include <kwk/algorithm/algos/for_each.hpp>
#include <kwk/algorithm/algos/transform.hpp>
#include <kwk/container.hpp>
#include "test.hpp"
#include "../generic/transform.hpp"

// TODO: update these tests

TTS_CASE("Check for kwk::transform(value, new_value) 1D - CPU context")
{
kwk::test::transform_value_new_value_1D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 2D - CPU context")
{
kwk::test::transform_value_new_value_2D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 3D - CPU context")
{
kwk::test::transform_value_new_value_3D(kwk::cpu);
};

TTS_CASE("Check for kwk::transform(value, new_value) 4D - CPU context")
{
kwk::test::transform_value_new_value_4D(kwk::cpu);
};
255 changes: 155 additions & 100 deletions test/algorithm/algos/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,117 +9,172 @@
#include <kwk/algorithm/algos/transform.hpp>
#include <kwk/container.hpp>
#include "test.hpp"
#include <algorithm>

TTS_CASE("Check for kwk::transform(value, new_value) 1D")
{
int data[2];
double res[2];
double vdata[2] = {1, 0.5};

fill_data(data, kwk::of_size(2), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2)};

int count = 0;
transform(
[&](auto e)
{
count++;
return 1.0/(1.0+e);
}, v, d);

TTS_ALL_EQUAL(res, vdata);

TTS_EQUAL(count, d.numel());
using data_type = int;
const std::size_t d0 = 87;
const std::size_t input_size = d0;
std::array<data_type, input_size> input1, input2, output, check;

for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; }

auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)};
auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)};
auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)};

kwk::transform( [&](auto const e1, auto const e2) {
return e1 + e2;
}
, view_out
, view_in1
, view_in2
);

std::transform( input1.begin(), input1.end()
, input2.begin()
, check.begin()
, [](auto const e1, auto const e2) {
return e1 + e2;
}
);

TTS_ALL_EQUAL(output, check);
};

TTS_CASE("Check for kwk::transform(value, new_value) 2D")
TTS_CASE("Check for kwk::transform(value, new_value) 1D with float")
{
int data[2*3];
double res[2*3];
double vdata[2*3];

fill_data(data, kwk::of_size(2,3), true);
fill_data(vdata, kwk::of_size(2,3), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
vdata[i*3+j] = 1.0/(1.0+vdata[i*3+j]);

int count = 0;
transform(
[&](auto e)
{
count++;
return 1.0/(1.0+e);
}, v, d);

TTS_ALL_EQUAL(res, vdata);

TTS_EQUAL(count, d.numel());
using data_type = float;
const std::size_t d0 = 87;
const std::size_t input_size = d0;
std::array<data_type, input_size> input1, input2, output, check;

for (std::size_t i = 0; i < input_size; ++i)
{
input1[i] = i * static_cast<data_type>(3.88);
input2[i] = i * static_cast<data_type>(2.87);
}

auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)};
auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)};
auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)};

kwk::transform( [&](auto const e1, auto const e2) {
return e1 + e2;
}
, view_out
, view_in1
, view_in2
);

std::transform( input1.begin(), input1.end()
, input2.begin()
, check.begin()
, [](auto const e1, auto const e2) {
return e1 + e2;
}
);

TTS_ALL_EQUAL(output, check);
};

TTS_CASE("Check for kwk::transform(value, new_value) 3D")
TTS_CASE("Check for kwk::transform(value, new_value) 1D with std::uint64_t")
{
using data_type = std::uint64_t;
const std::size_t d0 = 87;
const std::size_t input_size = d0;
std::array<data_type, input_size> input1, input2, output, check;

for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; }

auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0)};
auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0)};
auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0)};

kwk::transform( [&](auto const e1, auto const e2) {
return e1 + e2;
}
, view_out
, view_in1
, view_in2
);

std::transform( input1.begin(), input1.end()
, input2.begin()
, check.begin()
, [](auto const e1, auto const e2) {
return e1 + e2;
}
);

TTS_ALL_EQUAL(output, check);
};

TTS_CASE("Check for kwk::transform(value, new_value) 2D")
{
int data[2*3*4];
double res[2*3*4];
double vdata[2*3*4];

fill_data(data, kwk::of_size(2,3,4), true);
fill_data(vdata, kwk::of_size(2,3,4), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
for(int k = 0; k<4; k++)
vdata[i*4*3+j*4+k] = 1.0/(1.0+vdata[i*4*3+j*4+k]);

int count = 0;
transform(
[&](auto e)
{
count++;
return 1.0/(1.0+e);
}, v, d);

TTS_ALL_EQUAL(res, vdata);

TTS_EQUAL(count, d.numel());
using data_type = int;
const std::size_t d0 = 87;
const std::size_t d1 = 18;
const std::size_t input_size = d0 * d1;
std::array<data_type, input_size> input1, input2, output, check;

for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; }

auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0, d1)};
auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0, d1)};
auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0, d1)};

kwk::transform( [&](auto const e1, auto const e2) {
return e1 + e2;
}
, view_out
, view_in1
, view_in2
);

std::transform( input1.begin(), input1.end()
, input2.begin()
, check.begin()
, [](auto const e1, auto const e2) {
return e1 + e2;
}
);

TTS_ALL_EQUAL(output, check);
};

TTS_CASE("Check for kwk::transform(value, new_value) 4D")
{
int data[2*3*4*5];
double res[2*3*4*5];
double vdata[2*3*4*5];

fill_data(data, kwk::of_size(2,3,4,5), true);
fill_data(vdata, kwk::of_size(2,3,4,5), true);

auto d = kwk::view{kwk::source = data, kwk::of_size(2,3,4,5)};
auto v = kwk::view{kwk::source = res, kwk::of_size(2,3,4,5)};

for(int i = 0; i<2; i++)
for(int j = 0; j<3; j++)
for(int k = 0; k<4; k++)
for(int l = 0; l<5; l++)
vdata[i*5*4*3+j*5*4+k*5+l] = 1.0/(1.0+vdata[i*5*4*3+j*5*4+k*5+l]);

int count = 0;
transform(
[&](auto e)
{
count++;
return 1.0/(1.0+e);
}, v, d);

TTS_ALL_EQUAL(res, vdata);

TTS_EQUAL(count, d.numel());
using data_type = int;
const std::size_t d0 = 87;
const std::size_t d1 = 18;
const std::size_t d2 = 41;
const std::size_t d3 = 8;
const std::size_t input_size = d0 * d1 * d2 * d3;
std::array<data_type, input_size> input1, input2, output, check;

for (std::size_t i = 0; i < input_size; ++i) { input1[i] = i * 3; input2[i] = i * 2; }

auto view_in1 = kwk::view{kwk::source = input1 , kwk::of_size(d0, d1, d2, d3)};
auto view_in2 = kwk::view{kwk::source = input2 , kwk::of_size(d0, d1, d2, d3)};
auto view_out = kwk::view{kwk::source = output , kwk::of_size(d0, d1, d2, d3)};

kwk::transform( [&](auto const e1, auto const e2) {
return e1 + e2;
}
, view_out
, view_in1
, view_in2
);

std::transform( input1.begin(), input1.end()
, input2.begin()
, check.begin()
, [](auto const e1, auto const e2) {
return e1 + e2;
}
);

TTS_ALL_EQUAL(output, check);
};

0 comments on commit c839e5c

Please sign in to comment.