Skip to content

Commit

Permalink
May be nussbaumer is working?
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Feb 11, 2024
1 parent 165c43f commit b464c7a
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
113 changes: 113 additions & 0 deletions include/nussbaumer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once
#include <span>

namespace Nussbaumer{

template <typename T, uint rbit>
inline void PolynomialMulByXai(const std::span<T,1ull<<rbit> res, const size_t a)
{
if (a == 0)
return;
else{
constexpr size_t r = 1ull<<rbit;
std::array<T,r> temp;
std::copy(res.begin(),res.end(),temp.begin());
if (a < r) {
for (int i = 0; i < a; i++) res[i] = -temp[i - a + r];
for (int i = a; i < r; i++) res[i] = temp[i - a];
}
else {
const size_t aa = a - r;
for (int i = 0; i < aa; i++) res[i] = temp[i - aa + r];
for (int i = aa; i < r; i++) res[i] = -temp[i - aa];
}
}
}

template<typename T, uint mbit, uint rbit>
void NussbaumerButterfly(const std::span<T,(1u<<(rbit+mbit))> res){
constexpr size_t m = 1ull<<mbit;
constexpr size_t r = 1ull<<rbit;
for(int i = 0; i < m/2; i++)
for(int j = 0; j < r; j++){
const T temp = res[i*r+j];
res[i*r+j] += res[(i+m/2)*r+j];
res[(i+m/2)*r+j] = temp - res[(i+m/2)*r+j];
}
if constexpr(mbit!=1){
constexpr size_t stride = 1ull<<(rbit-mbit);
for(int i = 1; i < m/2; i++) PolynomialMulByXai<T, rbit>(static_cast<std::span<T,r>>(res.subspan((i+m/2)*r,r)),i*stride);
NussbaumerButterfly<T,mbit-1,rbit>(res.template subspan<0,m*r/2>());
NussbaumerButterfly<T,mbit-1,rbit>(res.template subspan<m*r/2,m*r/2>());
}
}

template<typename T, uint Nbit>
void NussbaumerTransform(std::span<T,(1ull<<Nbit)> res){
if constexpr(Nbit == 1){
const T temp = res[0];
res[0] += res[1];
res[1] = temp - res[1];
return;
}else{
//initialize
constexpr uint mbit = Nbit/2;
constexpr size_t m = 1ull<<mbit;
constexpr uint rbit = Nbit-mbit;
constexpr size_t r = 1ull<<rbit;
std::array<T,(1ull<<Nbit)> temp;
std::copy(res.begin(),res.end(),temp.begin());
//reorder
for(int i = 0; i < m; i++){
for(int j = 0; j < r; j++)
res[i*r+j] = temp[m*j+i];
}
NussbaumerButterfly<T,mbit,rbit>(res);
for(int i = 0; i < m; i++)
NussbaumerTransform<T,rbit>(static_cast<std::span<T,r>>(res.subspan(i*r,r)));
}
}

template<typename T, uint mbit, uint rbit>
void InverseNussbaumerButterfly(const std::span<T,(1u<<(rbit+mbit))> res){
constexpr size_t m = 1ull<<mbit;
constexpr size_t r = 1ull<<rbit;
if constexpr(mbit!=1){
constexpr size_t stride = 1ull<<(rbit-mbit);
InverseNussbaumerButterfly<T,mbit-1,rbit>(res.template subspan<0,m*r/2>());
InverseNussbaumerButterfly<T,mbit-1,rbit>(res.template subspan<m*r/2,m*r/2>());
for(int i = 1; i < m/2; i++) PolynomialMulByXai<T, rbit>(static_cast<std::span<T,r>>(res.subspan((i+m/2)*r,r)),2*r-i*stride);
}
for(int i = 0; i < m/2; i++)
for(int j = 0; j < r; j++){
const T temp = res[i*r+j];
res[i*r+j] += res[(i+m/2)*r+j];
res[(i+m/2)*r+j] = temp - res[(i+m/2)*r+j];
}
}

template<typename T, uint Nbit>
void InverseNussbaumerTransform(std::span<T,(1ull<<Nbit)> res){
if constexpr(Nbit == 1){
const T temp = res[0];
res[0] += res[1];
res[1] = temp - res[1];
return;
}else{
//initialize
constexpr uint mbit = Nbit/2;
constexpr size_t m = 1ull<<mbit;
constexpr uint rbit = Nbit-mbit;
constexpr size_t r = 1ull<<rbit;
for(int i = 0; i < m; i++)
InverseNussbaumerTransform<T,rbit>(static_cast<std::span<T,r>>(res.subspan(i*r,r)));
InverseNussbaumerButterfly<T,mbit,rbit>(res);
std::array<T,(1ull<<Nbit)> temp;
std::copy(res.begin(),res.end(),temp.begin());
//reorder
for(int i = 0; i < m; i++)
for(int j = 0; j < r; j++)
res[m*j+i] = temp[i*r+j];
}
}
}
1 change: 1 addition & 0 deletions include/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "cuhe++.hpp"
#include "raintt.hpp"
#include "nussbaumer.hpp"

namespace TFHEpp {

Expand Down
69 changes: 69 additions & 0 deletions test/nussbaumer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <random>
#include <tfhe++.hpp>

int main()
{
constexpr uint32_t num_test = 1000;
std::random_device seed_gen;
std::default_random_engine engine(seed_gen());
std::uniform_int_distribution<uint32_t> Bgdist(0, TFHEpp::lvl1param::Bg);
std::uniform_int_distribution<uint32_t> Torus32dist(0, UINT32_MAX);

// std::cout << "Start LVL1 test." << std::endl;
for (int test = 0; test < num_test; test++) {
using T = uint64_t;
std::array<T,TFHEpp::lvl1param::n> a,res;
for (T &i : a) i = Torus32dist(engine);
res = a;
Nussbaumer::NussbaumerTransform<T,TFHEpp::lvl1param::nbit>(std::span{res});
Nussbaumer::InverseNussbaumerTransform<T,TFHEpp::lvl1param::nbit>(std::span{res});
for (int i = 0; i < TFHEpp::lvl1param::n; i++)
assert(abs(static_cast<int32_t>(a[i] - res[i]/TFHEpp::lvl1param::n) <= 1));
}
std::cout << "Id Passed" << std::endl;

// for (int test = 0; test < num_test; test++) {
// array<typename TFHEpp::lvl1param::T, lvl1param::n> a;
// for (int i = 0; i < lvl1param::n; i++)
// a[i] = Bgdist(engine) - lvl1param::Bg / 2;
// for (typename TFHEpp::lvl1param::T &i : a)
// i = Bgdist(engine) - lvl1param::Bg / 2;
// array<typename TFHEpp::lvl1param::T, lvl1param::n> b;
// for (typename TFHEpp::lvl1param::T &i : b) i = Torus32dist(engine);

// Polynomial<lvl1param> polymul;
// TFHEpp::PolyMul<lvl1param>(polymul, a, b);
// Polynomial<lvl1param> naieve = {};
// for (int i = 0; i < lvl1param::n; i++) {
// for (int j = 0; j <= i; j++)
// naieve[i] += static_cast<int32_t>(a[j]) * b[i - j];
// for (int j = i + 1; j < lvl1param::n; j++)
// naieve[i] -=
// static_cast<int32_t>(a[j]) * b[lvl1param::n + i - j];
// }
// for (int i = 0; i < lvl1param::n; i++)
// assert(abs(static_cast<int32_t>(naieve[i] - polymul[i])) <= 1);
// }
// cout << "PolyMul Passed" << endl;

// uniform_int_distribution<uint64_t> Bgbardist(0, lvl2param::Bg);
// uniform_int_distribution<uint64_t> Torus64dist(0, UINT64_MAX);

// cout << "Start LVL2 test." << endl;
// for (int test = 0; test < num_test; test++) {
// Polynomial<lvl2param> a;
// for (uint64_t &i : a) i = Torus64dist(engine);
// PolynomialInFD<lvl2param> resfft;
// TFHEpp::TwistIFFT<lvl2param>(resfft, a);
// Polynomial<lvl2param> res;
// TFHEpp::TwistFFT<lvl2param>(res, resfft);
// for (int i = 0; i < lvl2param::n; i++)
// assert(abs(static_cast<int64_t>(a[i] - res[i])) <= (1 << 14));
// }
// cout << "FFT Passed" << endl;

return 0;
}

0 comments on commit b464c7a

Please sign in to comment.