-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpprf_test.cc
83 lines (63 loc) · 2.35 KB
/
pprf_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <array>
#include <future>
#include <memory>
#include <vector>

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "network/mem_channel.h"
#include "psi/ot/tools/silentpprf.h"
using primihub::crypto::PprfOutputFormat;
using primihub::crypto::SilentMultiPprfReceiver;
using primihub::crypto::SilentMultiPprfSender;
using primihub::link::Channel;
using primihub::link::MemoryChannel;
using primihub::link::Status;
using ChannelRole = MemoryChannel::ChannelRole;
// Round-trip test for SilentMultiPprfSender / SilentMultiPprfReceiver.
// Both parties are configured for `numPoints` points over a domain of size
// 2^depth. They are given correlated base OTs derived from one PRNG, then run
// expand() concurrently over a pair of in-memory channels. The receiver's
// output must equal the sender's output everywhere except at the receiver's
// point for each column (as reported by getPoints()). There it must differ by
// exactly CCBlock, the value the sender passes into expand().
TEST(silentpprf, base_test) {
  const u64 depth = 3;
  const u64 domain = 1ull << depth;
  const auto threads = 3;
  const u64 numPoints = 8;

  auto channel_impl1 = std::make_shared<MemoryChannel>(ChannelRole::CLIENT);
  auto channel1 = std::make_shared<Channel>(channel_impl1, "base_test");
  auto channel_impl2 = std::make_shared<MemoryChannel>(ChannelRole::SERVER);
  auto channel2 = std::make_shared<Channel>(channel_impl2, "base_test");

  PRNG prng(ZeroBlock);
  const auto format = PprfOutputFormat::Plain;

  SilentMultiPprfSender sender;
  SilentMultiPprfReceiver recver;
  sender.configure(domain, numPoints);
  recver.configure(domain, numPoints);

  // Build correlated base OTs locally: the receiver gets the sender's message
  // selected by its sampled choice bit.
  const auto numOTs = sender.baseOtCount();
  std::vector<std::array<block, 2>> sendOTs(numOTs);
  std::vector<block> recvOTs(numOTs);
  BitVector recvBits = recver.sampleChoiceBits(domain, format, prng);
  prng.get(sendOTs.data(), sendOTs.size());
  for (u64 i = 0; i < numOTs; ++i) {
    recvOTs[i] = sendOTs[i][recvBits[i]];
  }
  sender.setBase(sendOTs);
  recver.setBase(recvOTs);

  Matrix<block> sOut(domain, numPoints);
  Matrix<block> rOut(domain, numPoints);
  std::vector<u64> points(numPoints);
  recver.getPoints(points, format);

  // Each party gets its own PRNG. The two expand() calls below run on
  // different threads, so sharing one PRNG object between them would be a
  // data race.
  PRNG sender_prng(prng.get<block>());
  PRNG recver_prng(prng.get<block>());

  auto sender_fn = [channel1, &sender, &sender_prng, &sOut, format, threads]() {
    sender.expand(channel1, {&CCBlock, 1}, sender_prng, sOut, format, true,
                  threads);
  };
  auto recver_fn = [channel2, &recver, &recver_prng, &rOut, format, threads]() {
    recver.expand(channel2, recver_prng, rOut, format, true, threads);
  };

  // std::launch::async forces both parties onto real threads. With the
  // default policy a task may be deferred until get(). The sender could then
  // run alone on this thread and block waiting on a receiver that never
  // started.
  std::future<void> recver_fut = std::async(std::launch::async, recver_fn);
  std::future<void> sender_fut = std::async(std::launch::async, sender_fn);
  sender_fut.get();
  recver_fut.get();

  for (u64 j = 0; j < numPoints; ++j) {
    for (u64 i = 0; i < domain; ++i) {
      auto expected = sOut(i, j);
      // At the receiver's point for column j, the outputs should differ by
      // the sender-supplied value CCBlock.
      if (points[j] == i) expected = expected ^ CCBlock;
      EXPECT_FALSE(neq(expected, rOut(i, j)))
          << "output mismatch at point " << j << ", index " << i;
    }
  }
}