diff --git a/BMR/AuthValue.cpp b/BMR/AuthValue.cpp deleted file mode 100644 index 376f02118..000000000 --- a/BMR/AuthValue.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * AuthValue.cpp - * - */ - -#include "GC/Secret.h" - -namespace GC -{ - -void AuthValue::assign(const word& value, const int128& mac_key, bool not_first_player) -{ - if (not_first_player) - share = 0; - else - share = value; -#ifdef __PCLMUL__ - mac = _mm_clmulepi64_si128(_mm_cvtsi64_si128(mac_key.get_lower()), _mm_cvtsi64_si128(value), 0); -#else - (void) mac_key; - throw runtime_error("need to compile with PCLMUL support"); -#endif -} - -ostream& operator<<(ostream& o, const AuthValue& auth_value) -{ - o << hex << auth_value.share << " " << auth_value.mac; - return o; -} - -} diff --git a/BMR/CommonParty.h b/BMR/CommonParty.h index 5d6887921..81f1d12b1 100644 --- a/BMR/CommonParty.h +++ b/BMR/CommonParty.h @@ -58,11 +58,11 @@ class CommonParty LocalBuffer wires; ReceivedMsgStore wire_storage; - template - GC::BreakType first_phase(GC::Program& program, GC::Processor& processor, + template + GC::BreakType first_phase(GC::Program& program, GC::Processor& processor, GC::Machine& machine); template - GC::BreakType second_phase(GC::Program& program, GC::Processor& processor, + GC::BreakType second_phase(GC::Program& program, GC::Processor& processor, GC::Machine& machine, U& dynamic_memory); public: diff --git a/BMR/CommonParty.hpp b/BMR/CommonParty.hpp index e8f0964c3..037ecd5f6 100644 --- a/BMR/CommonParty.hpp +++ b/BMR/CommonParty.hpp @@ -8,8 +8,8 @@ #include "CommonParty.h" -template -GC::BreakType CommonParty::first_phase(GC::Program& program, +template +GC::BreakType CommonParty::first_phase(GC::Program& program, GC::Processor& processor, GC::Machine& machine) { (void)machine; @@ -20,7 +20,7 @@ GC::BreakType CommonParty::first_phase(GC::Program& program, GC::BreakType next; try { - next = (reinterpret_cast*>(&program))->execute(processor, dynamic_memory); + next = program.execute(processor, dynamic_memory); } catch (needs_cleaning& e) { @@ -44,7 +44,7 @@ GC::BreakType CommonParty::first_phase(GC::Program& program, } template -GC::BreakType CommonParty::second_phase(GC::Program& program, +GC::BreakType CommonParty::second_phase(GC::Program& program, GC::Processor& processor, GC::Machine& machine, U& dynamic_memory) { diff --git a/BMR/Key.h b/BMR/Key.h index b07964f10..f3e1fb01a 100644 --- a/BMR/Key.h +++ b/BMR/Key.h @@ -37,7 +37,7 @@ class Key { void serialize(SendBuffer& output) const { output.serialize(r); } void serialize_no_allocate(SendBuffer& output) const { output.serialize_no_allocate(r); } - bool get_signal() const { return _mm_cvtsi128_si64(r) & 1; } + bool get_signal() const { return _mm_cvtsi128_si32(r) & 1; } void set_signal(bool signal); Key doubling(int i) const; diff --git a/BMR/Machine.cpp b/BMR/Machine.cpp deleted file mode 100644 index 9c917412d..000000000 --- a/BMR/Machine.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Secret.cpp - * - */ - -#include "BMR/CommonParty.h" -#include "BMR/Register_inline.h" - -#include "BMR/Register.hpp" -#include "GC/Machine.hpp" -#include "GC/Processor.hpp" -#include "GC/Secret.hpp" -#include "GC/Thread.hpp" -#include "GC/ThreadMaster.hpp" -#include "GC/Program.hpp" -#include "GC/Instruction.hpp" -#include "Processor/Instruction.hpp" - -namespace GC -{ - -template -Secret Secret::reconstruct(const int128& x, int length) -{ - Secret res; - for (int i = 0; i < CommonParty::singleton->get_n_parties(); i++) - { - Secret tmp = res; - Secret share = input(i + 1, x, length); - res = share + tmp; -#ifdef DEBUG_DYNAMIC - int128 a,b,c; - tmp.reveal(a); - share.reveal(b); - res.reveal(c); - cout << hex << c << "(" << dec << res.size() << ") = " << hex << a - << "(" << dec << tmp.size() << ")" << " ^ " << hex << b << "(" - << dec << share.size() << ") (" << dec << x << ", " << dec - << length << ")" << endl; -#endif - } - return res; - if ((size_t)length != res.registers.size()) - { - cout << length << " " << res.registers.size() << endl; - throw runtime_error("wrong bit length in reconstruct()"); - } -} - -template -void Secret::store(Memory& mem, size_t address) -{ - AuthValue& dest = mem[address]; - Secret mac_key = reconstruct(CommonParty::s().get_mac_key().get(), default_length); - Secret mac, mask, mac_mask; - mac = carryless_mult(*this, mac_key); - GC::Mask mask_share; - int length = registers.size(); - int mac_length = mac.registers.size(); - T::get_dyn_mask(mask_share, length, mac_length); - mask.random(length, mask_share.share); - mac_mask.random(mac_length, mask_share.mac); - word masked; - int128 masked_mac; - (*this + mask).reveal(length, masked); - (mac + mac_mask).reveal(mac_length, masked_mac); -#ifdef DEBUG_DYNAMIC - word a,b; - int128 c,d; - reveal(a); - mask.reveal(b); - mac.reveal(c); - mac_mask.reveal(d); - cout << masked << " = " << a << " ^ " << b << endl; - cout << masked_mac << " = " << c << " ^ " << d << endl; -#endif - T::unmask(dest, mask_share.share, mask_share.mac, masked, masked_mac); -} - -template -void Secret::load(int n, const Memory& mem, size_t address) -{ - (void)n; - const AuthValue& x = mem[address]; - *this = reconstruct(x.share, default_length); - Secret mac, check_mac, mac_key; - mac = reconstruct(x.mac, 2 * default_length); - mac_key = reconstruct(CommonParty::s().get_mac_key().get(), default_length); - check_mac = carryless_mult(*this, mac_key); - int128 result; - (mac + check_mac).reveal(2 * default_length, result); -#ifdef DEBUG_DYNAMIC - cout << "loading " << hex << x.share << " " << x.mac << endl; - int128 a; - mac.reveal(a); - word b; - reveal(b); - cout << "stored value " << hex << b << " mac " << a << endl; -#endif - T::check(result, x.share, x.mac); -} - -} diff --git a/BMR/Party.cpp b/BMR/Party.cpp index f54b88787..6cdbc83de 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -382,9 +382,6 @@ void FakeProgramParty::receive_spdz_wires(ReceivedMsg& msg) void ProgramParty::store_wire(const Register& reg) { wires.serialize(reg.key(get_id(), 0)); -#ifndef FREE_XOR - wires.serialize(reg.key(get_id(), 1)); -#endif #ifdef DEBUG cout << "storing wire" << endl; reg.print(); @@ -394,11 +391,7 @@ void ProgramParty::store_wire(const Register& reg) void ProgramParty::load_wire(Register& reg) { wires.unserialize(reg.key(get_id(), 0)); -#ifdef FREE_XOR reg.key(get_id(), 1) = reg.key(get_id(), 0) ^ get_delta(); -#else - wires.unserialize(reg.key(get_id(), 1)); -#endif #ifdef DEBUG cout << "loading wire" << endl; reg.print(); diff --git a/BMR/Party.h b/BMR/Party.h index 4f561c009..78d95e4e3 100644 --- a/BMR/Party.h +++ b/BMR/Party.h @@ -99,7 +99,7 @@ class ProgramParty : virtual public CommonParty, virtual public PartyProperties, GC::Machine< GC::Secret > machine; GC::Processor > processor; - GC::Program > program; + GC::Program program; GC::Machine< GC::Secret > prf_machine; GC::Processor > prf_processor; @@ -170,11 +170,7 @@ class ProgramPartySpec : public ProgramParty void get_spdz_wire(SpdzOp op, DualWire& spdz_wire); }; -#ifdef SPDZ_AUTH typedef ProgramPartySpec> FakeProgramPartySuper; -#else -typedef ProgramPartySpec> FakeProgramPartySuper; -#endif class FakeProgramParty : virtual public BaseParty, virtual public FakeProgramPartySuper { diff --git a/BMR/ProgramParty.hpp b/BMR/ProgramParty.hpp index 7a136118f..1490dc7b0 100644 --- a/BMR/ProgramParty.hpp +++ b/BMR/ProgramParty.hpp @@ -31,8 +31,8 @@ void ProgramPartySpec::load(string progname) program.parse(progname + "-0"); machine.reset(program, dynamic_memory); processor.reset(program); - prf_machine.reset(*reinterpret_cast >* >(&program)); - prf_processor.reset(*reinterpret_cast >* >(&program)); + prf_machine.reset(program); + prf_processor.reset(program); } template diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 76b6b4e25..630d2ec22 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -78,9 +78,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : P = new CryptoPlayer(N, 0); delta = prng.get_doubleword(); -#ifdef KEY_SIGNAL delta.set_signal(1); -#endif #ifdef VERBOSE cerr << "delta: " << delta << endl; #endif @@ -201,16 +199,11 @@ RealProgramParty::~RealProgramParty() template void RealProgramParty::receive_keys(Register& reg) { -#ifndef FREE_XOR -#error not implemented -#endif auto& _id = this->_id; auto& _N = this->_N; reg.init(_N); reg.keys[0][_id - 1] = this->prng.get_doubleword(); -#ifdef KEY_SIGNAL reg.keys[0][_id - 1].set_signal(0); -#endif reg.keys[1][_id - 1] = reg.keys[0][_id - 1] ^ this->get_delta(); } diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 84ddb53b6..0be981fbe 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -48,12 +48,10 @@ void Register::init(int rfd, int n_parties) { mask = mask>0 ? 1 : 0; keys.init(n_parties); keys.randomize(); -#ifdef KEY_SIGNAL for (int i = 0; i < 2; i++) for (size_t j = 0; j < keys[i].size(); j++) if (keys[i][j].get_signal() != i) keys[i][j] ^= Key(1); -#endif } void Register::set_eval_keys() @@ -284,21 +282,7 @@ void Register::eval(const Register& left, const Register& right, GarbledGate& ga // } // std::cout << std::endl; -#ifdef KEY_SIGNAL external = garbled_entry[my_id - 1].get_signal(); -#else - if(garbled_entry[my_id-1] == key(my_id, 0)) { - external = 0; - } else if (garbled_entry[my_id-1] == key(my_id, 1)) { - external = 1; - } else { - printf("\nERROR!!!\n"); - cout << "got key: " << garbled_entry[my_id - 1] << endl; - cout << "possibilities: " << key(my_id, 0) << " " << key(my_id, 1) << endl; - throw std::invalid_argument("result key doesn't fit any of my keys"); -// return NO_SIGNAL; - } -#endif #ifdef DEBUG_MASK cout << "output signal: " << (int)external << endl; @@ -680,9 +664,7 @@ void RandomRegister::randomize() party.random_timer.start(); init(party.randomfd, party._N); party.random_timer.stop(); -#ifdef FREE_XOR keys[1] = keys[0] ^ party.get_deltas(); -#endif party.add_keys(*this); } @@ -764,16 +746,13 @@ void EvalRegister::output() ProgramParty& party = ProgramParty::s(); party.load_wire(*this); set_mask(party.output_masks.pop_front()); -#ifdef KEY_SIGNAL #ifdef DEBUG_REGS cout << "check " << get_id() << endl; #endif check_signal_key(party.get_id(), garbled_entry); -#endif party.taint(); } -#ifdef FREE_XOR void RandomRegister::XOR(const Register& left, const Register& right) { mask = left.get_mask() ^ right.get_mask(); @@ -824,46 +803,6 @@ void EvalRegister::XOR(const Register& left, const Register& right) << " ^ " << right.get_garbled_entry()[i] << endl; #endif } -#endif - -void EvalRegister::check(const int128& value, word share, int128 mac) -{ -#ifdef DEBUG_DYNAMIC - cout << "check result " << value << endl; -#endif - if (value != 0) - { - cout << "MAC check: " << value << " " << share<< " " << mac << endl; - throw runtime_error("MAC check failed"); - } -} - -void EvalRegister::get_dyn_mask(GC::Mask& mask, int length, int mac_length) -{ - mask.share = CommonParty::s().prng.get_word() & ((1ULL << length) - 1); - mask.mac = int128(CommonParty::s().prng.get_doubleword()) - & int128::ones(mac_length); -#ifdef DEBUG_DYNAMIC - cout << "mask " << hex << mask.share << " " << mask.mac << " "; - cout << ((1ULL << length) - 1) << " " << int128::ones(mac_length) << endl; -#endif -} - -void EvalRegister::unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share, - word masked, int128 masked_mac) -{ - dest.share = mask_share; - dest.mac = mac_mask_share; - if (ProgramParty::s()._id == 1) - { - dest.share ^= masked; - dest.mac ^= masked_mac; - } -#ifdef DEBUG_DYNAMIC - cout << dest.share << " ?= " << mask_share << " ^ " << masked << endl; - cout << dest.mac << " ?= " << mac_mask_share << " ^ " << masked_mac << endl; -#endif -} template <> void RandomRegister::store(NoMemory& mem, diff --git a/BMR/Register.h b/BMR/Register.h index fb8c89111..a55f029fa 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -21,10 +21,6 @@ using namespace std; #include "Math/gf2n.h" #include "Tools/FlexBuffer.h" -#ifndef FREE_XOR -#warning not using free XOR has not been tested in a while -#endif - //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -119,9 +115,6 @@ class KeyTuple { namespace GC { -class AuthValue; -class Mask; -class SpdzShare; template class Secret; template @@ -208,16 +201,9 @@ class Phase typedef BlackHole out_type; static BlackHole out; - static void check(const int128& value, word share, int128 mac) - { (void)value; (void)share; (void)mac; } - static void get_dyn_mask(GC::Mask& mask, int length, int mac_length) - { (void)mask; (void)length; (void)mac_length; } template static void store_clear_in_dynamic(T& mem, const vector& accesses) { (void)mem; (void)accesses; } - static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share, - word masked, int128 masked_mac) - { (void)dest; (void)mask_share; (void)mac_mask_share; (void)masked; (void)masked_mac; } template static void store(NoMemory& dest, @@ -260,9 +246,6 @@ class ProgramRegister : public Phase, public Register template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } - template - static void load(vector >& accesses, - const GC::Memory& source) { (void)accesses; (void)source; } // most BMR phases don't need actual input template @@ -301,11 +284,6 @@ class EvalRegister : public ProgramRegister typedef ostream& out_type; static ostream& out; - static void check(const int128& value, word share, int128 mac); - static void get_dyn_mask(GC::Mask& mask, int length, int mac_length); - static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share, - word masked, int128 masked_mac); - template static void store(GC::Memory& dest, const vector >& accesses); diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 7364efbc6..9a7bfee23 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -56,8 +56,8 @@ TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : program.parse(string(argv[1]) + "-0"); processor.reset(program); machine.reset(program); - random_processor.reset(program.cast< GC::Secret >()); - random_machine.reset(program.cast< GC::Secret >()); + random_processor.reset(program); + random_machine.reset(program); if (singleton) throw runtime_error("there can only be one"); singleton = this; @@ -65,7 +65,6 @@ TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : init(argv[2], 0); else init("LOOPBACK", 0); -#ifdef FREE_XOR deltas.resize(_N); for (size_t i = 0; i < _N; i++) { @@ -73,13 +72,10 @@ TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : #ifdef DEBUG deltas[i] = Key(i + 1, 0); #endif -#ifdef KEY_SIGNAL if (deltas[i].get_signal() == 0) deltas[i] ^= Key(1); -#endif cout << "Delta " << i << ": " << deltas[i] << endl; } -#endif } TrustedProgramParty::~TrustedProgramParty() @@ -240,14 +236,12 @@ void BaseTrustedParty::Start() void TrustedProgramParty::NodeReady() { -#ifdef FREE_XOR for (int i = 0; i < get_n_parties(); i++) { SendBuffer& buffer = get_buffer(TYPE_DELTA); buffer.serialize(deltas[i]); _node->Send(i + 1, buffer); } -#endif this->BaseTrustedParty::NodeReady(); } diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index d79f00ed7..24e8120de 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -70,10 +70,8 @@ class TrustedProgramParty : public BaseTrustedParty { void store_wire(const Register& reg); void load_wire(Register& reg); -#ifdef FREE_XOR const Key& delta(int i) { return deltas[i]; } const KeyVector& get_deltas() { return deltas; } -#endif private: friend class GarbleRegister; @@ -84,14 +82,12 @@ class TrustedProgramParty : public BaseTrustedParty { GC::Machine< GC::Secret > machine; GC::Processor< GC::Secret > processor; - GC::Program< GC::Secret > program; + GC::Program program; GC::Machine< GC::Secret > random_machine; GC::Processor< GC::Secret > random_processor; -#ifdef FREE_XOR KeyVector deltas; -#endif vector spdz_wires[SPDZ_OP_N]; vector< Share > mask_shares; diff --git a/BMR/config.h b/BMR/config.h index 2d76c14b4..619de618c 100644 --- a/BMR/config.h +++ b/BMR/config.h @@ -10,10 +10,6 @@ //#define N_PARTIES 2 #define MAX_N_PARTIES 3 -#define FREE_XOR -#define KEY_SIGNAL -#define SPDZ_AUTH -#define NO_INPUT #define MAX_INLINE //#define SIGNAL_CHECK diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d5a92363..51db3bb17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.8 (June 15, 2020) + +- Half-gate garbling +- Native 2D convolution +- Inference with some TensorFlow graphs +- MASCOT with several MACs to increase security + ## 0.1.7 (May 8, 2020) - Possibility of using global keyword in loops instead of MemValue diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 0b4fc5cd5..974c0828d 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -459,7 +459,41 @@ def to_sint(self, n_bits): class sbitvec(_vec): @classmethod def get_type(cls, n): - return cls + class sbitvecn(cls): + @staticmethod + def malloc(size): + return sbits.malloc(size * n) + @staticmethod + def n_elements(): + return n + @classmethod + def get_input_from(cls, player): + return cls.from_vec( + sbits.get_input_from(player, n).bit_decompose(n)) + get_raw_input_from = get_input_from + def __init__(self, other=None): + if other is not None: + self.v = sbits(other, n=n).bit_decompose(n) + @classmethod + def load_mem(cls, address): + try: + assert len(address) == n + return cls.from_vec(sbit.load_mem(x) for x in address) + except: + return cls.from_vec(sbit.load_mem(address + i) + for i in range(n)) + def store_in_mem(self, address): + assert self.v[0].n == 1 + try: + assert len(address) == n + for x, y in zip(self.v, address): + x.store_in_mem(y) + except: + for i in range(n): + self.v[i].store_in_mem(address + i) + def reveal(self): + return self.elements()[0].reveal() + return sbitvecn @classmethod def from_vec(cls, vector): res = cls() @@ -640,6 +674,7 @@ class sbitint(_bitint, _number, sbits): n_bits = None bin_type = None types = {} + vector_mul = True @classmethod def get_type(cls, n, other=None): if isinstance(other, sbitvec): @@ -727,8 +762,24 @@ def round(self, k, m, kappa=None, nearest=None, signed=None): bits = self.bit_decompose() res_bits = self.bit_adder(bits[m:k], [bits[m-1]]) return self.get_type(k - m).compose(res_bits) + @classmethod + def get_bit_matrix(cls, self_bits, other): + n = len(self_bits) + assert n == other.n + res = [] + for i, bit in enumerate(self_bits): + if util.is_zero(bit): + res.append([0] * (n - i)) + else: + if cls.vector_mul: + x = sbits.get_type(n - i)() + inst.andrs(n - i, x, other, bit) + res.append(x.bit_decompose(n - i)) + else: + res.append([(x & bit) for x in other.bit_decompose(n - i)]) + return res -class sbitintvec(sbitvec): +class sbitintvec(sbitvec, _number): def __add__(self, other): if util.is_zero(other): return self @@ -740,10 +791,11 @@ def less_than(self, other, *args, **kwargs): assert(len(self.v) == len(other.v)) return self.from_vec(sbitint.bit_less_than(self.v, other.v)) def __mul__(self, other): - assert isinstance(other, sbitint) - matrix = [[x * b for x in self.v] for b in other.bit_decompose()] + matrix = [] + for i, b in enumerate(other.bit_decompose()): + matrix.append([x * b for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) - return self.from_vec(v) + return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ reduce_after_mul = lambda x: x diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 6d70c449d..77575c23d 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -267,7 +267,7 @@ def dependency_graph(self, merge_classes): last_print_str = None last = defaultdict(lambda: defaultdict(lambda: None)) last_open = deque() - last_text_input = [None, None] + last_input = defaultdict(lambda: [None, None]) depths = [0] * len(block.instructions) self.depths = depths @@ -331,6 +331,16 @@ def keep_order(instr, n, t, arg_index=None): add_edge(last[t][player], n) last[t][player] = n + def keep_merged_order(instr, n, t): + if last_input[t][0] is not None: + if instr.merge_id() != \ + block.instructions[last_input[t][0]].merge_id(): + add_edge(last_input[t][0], n) + last_input[t][1] = last_input[t][0] + elif last_input[t][1] is not None: + add_edge(last_input[t][1], n) + last_input[t][0] = n + for n,instr in enumerate(block.instructions): outputs,inputs = instr.get_def(), instr.get_used() @@ -355,14 +365,9 @@ def keep_order(instr, n, t, arg_index=None): # will be merged if isinstance(instr, TextInputInstruction): - if last_text_input[0] is not None: - if instr.merge_id() != \ - block.instructions[last_text_input[0]].merge_id(): - add_edge(last_text_input[0], n) - last_text_input[1] = last_text_input[0] - elif last_text_input[1] is not None: - add_edge(last_text_input[1], n) - last_text_input[0] = n + keep_merged_order(instr, n, TextInputInstruction) + elif isinstance(instr, RawInputInstruction): + keep_merged_order(instr, n, RawInputInstruction) if isinstance(instr, merge_classes): open_nodes.add(n) @@ -413,12 +418,10 @@ def keep_order(instr, n, t, arg_index=None): last_print_str = n elif isinstance(instr, PublicFileIOInstruction): keep_order(instr, n, instr.__class__) - elif isinstance(instr, RawInputInstruction): - keep_order(instr, n, instr.__class__) elif isinstance(instr, startprivateoutput_class): keep_order(instr, n, startprivateoutput_class, 2) elif isinstance(instr, stopprivateoutput_class): - keep_order(instr, n, stopprivateoutput_class, 1) + keep_order(instr, n, stopprivateoutput_class, 2) elif isinstance(instr, prep_class): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 46cc4c2fe..ac4f85501 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -898,6 +898,7 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', 0), float('inf')) @base.gf2n +@base.vectorize class rawinput(base.RawInputInstruction, base.Mergeable): r""" Receive inputs from player $p$. """ __slots__ = [] @@ -941,7 +942,7 @@ class print_reg_plain(base.IOInstruction): class cond_print_plain(base.IOInstruction): r""" Conditionally print the value of a register. """ code = base.opcodes['CONDPRINTPLAIN'] - arg_format = ['c', 'c'] + arg_format = ['c', 'c', 'c'] class print_int(base.IOInstruction): r""" Print only the value of register \verb|ci| to stdout. """ @@ -1142,7 +1143,7 @@ class stopprivateoutput(base.Instruction): r""" Previously iniated private output to $n$ via $c_i$. """ __slots__ = [] code = base.opcodes['STOPPRIVATEOUTPUT'] - arg_format = ['c','p'] + arg_format = ['cw','c','p'] @base.vectorize class rand(base.Instruction): @@ -1458,6 +1459,24 @@ def __init__(self, *args, **kwargs): for i in range(2): assert args[8 + i].size == args[4 + i] +class conv2ds(base.DataInstruction): + """ Secret 2D convolution """ + code = base.opcodes['CONV2DS'] + arg_format = ['sw','s','s','int','int','int','int','int','int','int','int', + 'int','int','int'] + data_type = 'triple' + is_vec = lambda self: True + + def __init__(self, *args, **kwargs): + super(conv2ds, self).__init__(*args, **kwargs) + assert args[0].size == args[3] * args[4] + assert args[1].size == args[5] * args[6] * args[11] + assert args[2].size == args[7] * args[8] * args[11] + + def get_repeat(self): + return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \ + self.args[11] + @base.vectorize class trunc_pr(base.VarArgsInstruction): """ Probalistic truncation for semi-honest computation """ diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 6b828c8fa..0b7f4ef95 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -97,6 +97,7 @@ TRUNC_PR = 0xA9, MATMULS = 0xAA, MATMULSM = 0xAB, + CONV2DS = 0xAC, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/library.py b/Compiler/library.py index 3a6c6d0b1..7f055091a 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -3,7 +3,7 @@ in particularly providing flow control and output. """ -from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint +from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc from Compiler.instructions import * from Compiler.util import tuplify,untuplify,is_zero from Compiler import instructions,instructions_base,comparison,program,util @@ -35,6 +35,7 @@ def vectorized_function(*args, **kwargs): res = function(*args, **kwargs) return res vectorized_function.__name__ = function.__name__ + copy_doc(vectorized_function, function) return vectorized_function def set_instruction_type(function): @@ -59,7 +60,7 @@ def print_str(s, *args): def print_plain_str(ss): """ Print a plain string (no custom formatting options) """ i = 1 - while 4*i < len(ss): + while 4*i <= len(ss): print_char4(ss[4*(i-1):4*i]) i += 1 i = 4*(i-1) @@ -110,8 +111,7 @@ def print_ln(s='', *args): print_ln('a is %s.', a.reveal()) """ - print_str(s, *args) - print_char('\n') + print_str(s + '\n', *args) def print_ln_if(cond, ss, *args): """ Print line if :py:obj:`cond` is true. The further arguments @@ -138,16 +138,36 @@ def print_ln_if(cond, ss, *args): cond = cint.conv(cond) for i, s in enumerate(subs): if i != 0: - cond_print_plain(cond, cint.conv(args[i - 1])) - if i < len(args): - s += ' ' * ((-len(s)) % 4) - else: - s += ' ' * ((-len(s) + 3) % 4) + args[i - 1].output_if(cond) + if i == len(args): s += '\n' + s += '\0' * ((-len(s)) % 4) while s: cond.print_if(s[:4]) s = s[4:] +def print_ln_to(player, ss, *args): + """ Print line at :py:obj:`player` only. Note that printing is + disabled by default except at player 0. + + :param player: int + :param ss: Python string + :param args: list of values known to :py:obj:`player` + + Example:: + + print_ln_to(player, 'output for %s: %s', player, x.reveal_to(player)) + """ + cond = player == get_player_id() + new_args = [] + for arg in args: + if isinstance(arg, personal): + assert arg.player == player + new_args.append(arg._v) + else: + new_args.append(arg) + print_ln_if(cond, ss, *new_args) + def print_float_precision(n): """ Set the precision for floating-point printing. @@ -1041,7 +1061,7 @@ def f(base, size): def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ thread_mem_req={}, looping=True): assert(n_threads != 0) - if isinstance(n_loops, list): + if isinstance(n_loops, (list, tuple)): split = n_loops n_loops = reduce(operator.mul, n_loops) def decorator(loop_body): diff --git a/Compiler/ml.py b/Compiler/ml.py index e2d264c8e..561e613d5 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1,7 +1,8 @@ """ This module contains machine learning functionality. It is work in -progress, so you must expect things to change. The most tested -functionality is logistic regression. It can be run as follows:: +progress, so you must expect things to change. The only tested +functionality for training is logistic regression. It can be run as +follows:: sgd = ml.SGD([ml.Dense(n_examples, n_features, 1), ml.Output(n_examples, approx=True)], n_epochs, @@ -22,11 +23,27 @@ data.input_from(0) res = sgd.eval(data) print_ln('Results: %s', [x.reveal() for x in res]) + +For inference/classification, this module offers the layers necessary +for neural networks such as DenseNet, ResNet, and SqueezeNet. A +minimal example using input from player 0 and model from player 1 +looks as follows:: + + graph = Optimizer() + graph.layers = layers + layers[0].X.input_from(0) + for layer in layers: + layer.input_from(1) + graph.forward(1) + res = layers[-1].Y + +See the `readme `_ for +an example of how to run MP-SPDZ on TensorFlow graphs. """ import math -from Compiler import mpc_math +from Compiler import mpc_math, util from Compiler.types import * from Compiler.types import _unreduced_squant from Compiler.library import * @@ -51,17 +68,29 @@ def sanitize(x, raw, lower, upper): return (x < -limit).if_else(lower, res) def sigmoid(x): + """ Sigmoid function. + + :param x: sfix """ return sigmoid_from_e_x(x, exp(-x)) def sigmoid_from_e_x(x, e_x): return sanitize(x, 1 / (1 + e_x), 0, 1) def sigmoid_prime(x): + """ Sigmoid derivative. + + :param x: sfix """ sx = sigmoid(x) return sx * (1 - sx) @vectorize def approx_sigmoid(x, n=3): + """ Piece-wise approximate sigmoid as in + `Dahl et al. `_ + + :param x: input + :param n: number of pieces, 3 (default) or 5 + """ if n == 5: cuts = [-5, -2.5, 2.5, 5] le = [0] + [x <= cut for cut in cuts] + [1] @@ -84,11 +113,24 @@ def lse_0(x): return lse_0_from_e_x(x, exp(x)) def relu_prime(x): + """ ReLU derivative. """ return (0 <= x) def relu(x): + """ ReLU function (maximum of input and zero). """ return (0 < x).if_else(x, 0) +def argmax(x): + """ Compute index of maximum element. + + :param x: iterable + :returns: sint + """ + def op(a, b): + comp = (a[1] > b[1]) + return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1]) + return tree_reduce(op, enumerate(x))[0] + def progress(x): return print_ln(x) @@ -98,10 +140,47 @@ def set_n_threads(n_threads): Layer.n_threads = n_threads Optimizer.n_threads = n_threads +class Tensor(MultiArray): + def __init__(self, *args, **kwargs): + kwargs['alloc'] = False + super(Tensor, self).__init__(*args, **kwargs) + class Layer: n_threads = 1 + inputs = [] + input_bias = True + + @property + def shape(self): + return list(self._Y.sizes) + + @property + def X(self): + self._X.alloc() + return self._X + + @X.setter + def X(self, value): + self._X = value + + @property + def Y(self): + self._Y.alloc() + return self._Y + + @Y.setter + def Y(self, value): + self._Y = value + +class NoVariableLayer(Layer): + input_from = lambda *args, **kwargs: None class Output(Layer): + """ Fixed-point logistic regression output layer. + + :param N: number of examples + :param approx: :py:obj:`False` (default) or parameter for :py:obj:`approx_sigmoid` + """ def __init__(self, N, debug=False, approx=False): self.N = N self.X = sfix.Array(N) @@ -220,6 +299,12 @@ def _(i): progress('nabla W/b') class Dense(DenseBase): + """ Fixed-point dense (matrix multiplication) layer. + + :param N: number of examples + :param d_in: input dimension + :param d_out: output dimension + """ def __init__(self, N, d_in, d_out, d=1, activation='id'): self.activation = activation if activation == 'id': @@ -241,12 +326,9 @@ def __init__(self, N, d_in, d_out, d=1, activation='id'): self.W = sfix.Matrix(d_in, d_out) self.b = sfix.Array(d_out) - self.reset() - self.nabla_Y = MultiArray([N, d, d_out], sfix) self.nabla_X = MultiArray([N, d, d_in], sfix) self.nabla_W = sfix.Matrix(d_in, d_out) - self.nabla_W.assign_all(0) self.nabla_b = sfix.Array(d_out) self.f_input = MultiArray([N, d, d_out], sfix) @@ -262,11 +344,18 @@ def _(j): self.W[i][j] = sfix.get_random(-r, r) self.b.assign_all(0) + def input_from(self, player, raw=False): + self.W.input_from(player, raw=raw) + if self.input_bias: + self.b.input_from(player, raw=raw) + def compute_f_input(self, batch): N = len(batch) - prod = MultiArray([N, self.d, self.d_out], sfix) assert self.d == 1 - assert self.d_out == 1 + if self.input_bias: + prod = MultiArray([N, self.d, self.d_out], sfix) + else: + prod = self.f_input @multithread(self.n_threads, N) def _(base, size): X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) @@ -277,10 +366,17 @@ def _(base, size): regint.inc(self.d_out))), base) - @multithread(self.n_threads, N) - def _(base, size): - v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) - self.f_input.assign_vector(v, base) + if self.input_bias: + if self.d_out == 1: + @multithread(self.n_threads, N) + def _(base, size): + v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) + self.f_input.assign_vector(v, base) + else: + @for_range_opt_multithread(self.n_threads, N) + def _(i): + v = prod[i].get_vector() + self.b.get_vector() + self.f_input[i].assign_vector(v) progress('f input') def forward(self, batch=None): @@ -406,32 +502,243 @@ def _(k): def backward(self): self.nabla_X = self.nabla_Y.schur(self.B) +class Relu(NoVariableLayer): + """ Fixed-point ReLU layer. + + :param shape: input/output shape (tuple/list of int) + """ + def __init__(self, shape, inputs=None): + self.X = Tensor(shape, sfix) + self.Y = Tensor(shape, sfix) + self.inputs = inputs + + def forward(self, batch=[0]): + assert len(batch) == 1 + @multithread(self.n_threads, self.X[batch[0]].total_size()) + def _(base, size): + tmp = relu(self.X[batch[0]].get_vector(base, size)) + self.Y[batch[0]].assign_vector(tmp, base) + +class Square(NoVariableLayer): + """ Fixed-point square layer. + + :param shape: input/output shape (tuple/list of int) + """ + def __init__(self, shape): + self.X = MultiArray(shape, sfix) + self.Y = MultiArray(shape, sfix) + + def forward(self, batch=[0]): + assert len(batch) == 1 + self.Y.assign_vector(self.X.get_part_vector(batch[0]) ** 2) + +class MaxPool(NoVariableLayer): + """ Fixed-point MaxPool layer. + + :param shape: input shape (tuple/list of four int) + :param strides: strides (tuple/list of four int, first and last must be 1) + :param ksize: kernel size (tuple/list of four int, first and last must be 1) + :param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'` + """ + def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), + padding='VALID'): + assert len(shape) == 4 + for x in strides, ksize: + for i in 0, 3: + assert x[i] == 1 + self.X = MultiArray(shape, sfix) + if padding == 'SAME': + output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)] + else: + output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)] + self.Y = MultiArray(output_shape, sfix) + self.strides = strides + self.ksize = ksize + + def forward(self, batch=[0]): + assert len(batch) == 1 + bi = MemValue(batch[0]) + need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] > + self.X.sizes[i] for i in range(4)] + @for_range_opt_multithread(self.n_threads, self.X.sizes[3]) + def _(k): + @for_range_opt(self.Y.sizes[1]) + def _(i): + h_base = self.strides[1] * i + @for_range_opt(self.Y.sizes[2]) + def _(j): + w_base = self.strides[2] * j + pool = [] + for ii in range(self.ksize[1]): + h = h_base + ii + if need_padding[1]: + h_in = h < self.X.sizes[1] + else: + h_in = True + for jj in range(self.ksize[2]): + w = w_base + jj + if need_padding[2]: + w_in = w < self.X.sizes[2] + else: + w_in = True + if not is_zero(h_in * w_in): + pool.append(h_in * w_in * self.X[bi][h_in * h] + [w_in * w][k]) + self.Y[bi][i][j][k] = util.tree_reduce( + lambda a, b: a.max(b), pool) + +class Argmax(NoVariableLayer): + """ Fixed-point Argmax layer. + + :param shape: input shape (tuple/list of two int) + """ + def __init__(self, shape): + assert len(shape) == 2 + self.X = MultiArray(shape, sfix) + self.Y = Array(shape[0], sint) + + def forward(self, batch=[0]): + assert len(batch) == 1 + self.Y[batch[0]] = argmax(self.X[batch[0]]) + +class Concat(NoVariableLayer): + """ Fixed-point concatentation layer. + + :param inputs: two input layers (tuple/list) + :param dimension: dimension for concatenation (must be 3) + """ + def __init__(self, inputs, dimension): + self.inputs = inputs + self.dimension = dimension + shapes = [inp.shape for inp in inputs] + assert dimension == 3 + assert len(shapes) == 2 + assert len(shapes[0]) == len(shapes[1]) + shape = [] + for i in range(len(shapes[0])): + if i == dimension: + shape.append(shapes[0][i] + shapes[1][i]) + else: + assert shapes[0][i] == shapes[1][i] + shape.append(shapes[0][i]) + self.Y = Tensor(shape, sfix) + + def forward(self, batch=[0]): + assert len(batch) == 1 + @for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3]) + def _(i, j): + X = [x.Y[batch[0]] for x in self.inputs] + self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector()) + self.Y[batch[0]][i][j].assign_part_vector( + X[1][i][j].get_vector(), + len(X[0][i][j])) + +class Add(NoVariableLayer): + """ Fixed-point addition layer. + + :param inputs: two input layers with same shape (tuple/list) + """ + def __init__(self, inputs): + assert len(inputs) > 1 + shape = inputs[0].shape + for inp in inputs: + assert inp.shape == shape + self.Y = Tensor(shape, sfix) + self.inputs = inputs + + def forward(self, batch=[0]): + assert len(batch) == 1 + @multithread(self.n_threads, self.Y[0].total_size()) + def _(base, size): + tmp = sum(inp.Y[batch[0]].get_vector(base, size) + for inp in self.inputs) + self.Y[batch[0]].assign_vector(tmp, base) + +class FusedBatchNorm(Layer): + """ Fixed-point fused batch normalization layer. + + :param shape: input/output shape (tuple/list of four int) + """ + def __init__(self, shape, inputs=None): + assert len(shape) == 4 + self.X = Tensor(shape, sfix) + self.Y = Tensor(shape, sfix) + self.weights = sfix.Array(shape[3]) + self.bias = sfix.Array(shape[3]) + self.inputs = inputs + + def input_from(self, player, raw=False): + self.weights.input_from(player, raw=raw) + self.bias.input_from(player, raw=raw) + tmp = sfix.Array(len(self.bias)) + tmp.input_from(player, raw=raw) + tmp.input_from(player, raw=raw) + + def forward(self, batch=[0]): + assert len(batch) == 1 + @for_range_opt_multithread(self.n_threads, self.X.sizes[1:3]) + def _(i, j): + self.Y[batch[0]][i][j].assign_vector( + self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() + + self.bias.get_vector()) + class QuantBase(object): - n_threads = 1 + bias_before_reduction = True @staticmethod def new_squant(): class _(squant): @classmethod + def get_params_from(cls, player): + cls.set_params(sfloat.get_input_from(player), + sint.get_input_from(player)) + @classmethod def get_input_from(cls, player, size=None): return cls._new(sint.get_input_from(player, size=size)) return _ - def __init__(self, input_shape, output_shape): + def const_div(self, acc, n): + logn = int(math.log(n, 2)) + acc = (acc + n // 2) + if 2 ** logn == n: + acc = acc.round(self.output_squant.params.k + logn, logn, nearest=True) + else: + acc = acc.int_div(sint(n), self.output_squant.params.k + logn) + return acc + +class FixBase: + bias_before_reduction = False + + @staticmethod + def new_squant(): + class _(sfix): + params = None + return _ + + def input_params_from(self, player): + pass + + def const_div(self, acc, n): + return (sfix._new(acc) * self.output_squant(1 / n)).v + +class BaseLayer(Layer): + def __init__(self, input_shape, output_shape, inputs=None): self.input_shape = input_shape self.output_shape = output_shape self.input_squant = self.new_squant() self.output_squant = self.new_squant() - self.X = MultiArray(input_shape, self.input_squant) - self.Y = MultiArray(output_shape, self.output_squant) + self.X = Tensor(input_shape, self.input_squant) + self.Y = Tensor(output_shape, self.output_squant) + self.inputs = inputs def temp_shape(self): return [0] -class QuantConvBase(QuantBase): +class ConvBase(BaseLayer): fewer_rounds = True + use_conv2ds = False temp_weights = None temp_inputs = None @@ -443,12 +750,32 @@ def init_temp(cls, layers): cls.temp_weights = sfix.Array(size) cls.temp_inputs = sfix.Array(size) - def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride): - super(QuantConvBase, self).__init__(input_shape, output_shape) + def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, + padding='SAME', tf_weight_format=False, inputs=None): + super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs) self.weight_shape = weight_shape self.bias_shape = bias_shape self.stride = stride + self.tf_weight_format = tf_weight_format + if padding == 'SAME': + # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn + self.padding = [] + for i in 1, 2: + s = stride[i - 1] + if tf_weight_format: + w = weight_shape[i - 1] + else: + w = weight_shape[i] + if (input_shape[i] % stride[1] == 0): + pad_total = max(w - s, 0) + else: + pad_total = max(w - (input_shape[i] % s), 0) + self.padding.append(pad_total // 2) + elif padding == 'VALID': + self.padding = [0, 0] + else: + self.padding = padding self.weight_squant = self.new_squant() self.bias_squant = self.new_squant() @@ -456,28 +783,28 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride): self.weights = MultiArray(weight_shape, self.weight_squant) self.bias = Array(output_shape[-1], self.bias_squant) - self.unreduced = MultiArray(self.output_shape, sint, - address=self.Y.address) + self.unreduced = Tensor(self.output_shape, sint) - assert(weight_shape[-1] == input_shape[-1]) + if tf_weight_format: + weight_in = weight_shape[2] + else: + weight_in = weight_shape[3] + assert(weight_in == input_shape[-1]) assert(bias_shape[0] == output_shape[-1]) assert(len(bias_shape) == 1) assert(len(input_shape) == 4) assert(len(output_shape) == 4) assert(len(weight_shape) == 4) - def input_from(self, player): - for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant: - s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) - self.weights.input_from(player, budget=100000) - self.bias.input_from(player) - print('WARNING: assuming that bias quantization parameters are correct') - - self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params) + def input_from(self, player, raw=False): + self.input_params_from(player) + self.weights.input_from(player, budget=100000, raw=raw) + if self.input_bias: + self.bias.input_from(player, raw=raw) def dot_product(self, iv, wv, out_y, out_x, out_c): bias = self.bias[out_c] - acc = squant.unreduced_dot_product(iv, wv) + acc = self.output_squant.unreduced_dot_product(iv, wv) acc.v += bias.v acc.res_params = self.output_squant.params #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() @@ -488,26 +815,17 @@ def reduction(self): n_summands = self.n_summands() start_timer(2) n_outputs = reduce(operator.mul, self.output_shape) - if n_outputs % self.n_threads == 0: - n_per_thread = n_outputs // self.n_threads - @for_range_opt_multithread(self.n_threads, self.n_threads) - def _(i): - res = _unreduced_squant( - sint.load_mem(unreduced.address + i * n_per_thread, - size=n_per_thread), - (self.input_squant.params, self.weight_squant.params), - self.output_squant.params, - n_summands).reduce_after_mul() - res.store_in_mem(self.Y.address + i * n_per_thread) - else: - @for_range_opt_multithread(self.n_threads, self.output_shape[1]) - def _(out_y): - self.Y[0][out_y].assign_vector(_unreduced_squant( - unreduced[0][out_y].get_vector(), - (self.input_squant.params, self.weight_squant.params), - self.output_squant.params, - n_summands).reduce_after_mul()) + @multithread(self.n_threads, n_outputs) + def _(base, n_per_thread): + res = self.input_squant().unreduced( + sint.load_mem(unreduced.address + base, + size=n_per_thread), + self.weight_squant(), + self.output_squant.params, + n_summands).reduce_after_mul() + res.store_in_mem(self.Y.address + base) stop_timer(2) + unreduced.delete() def temp_shape(self): return list(self.output_shape[1:]) + [self.n_summands()] @@ -520,7 +838,7 @@ def prepare_temp(self): address=self.temp_weights) return inputs, weights -class QuantConv2d(QuantConvBase): +class Conv2d(ConvBase): def n_summands(self): _, weights_h, weights_w, _ = self.weight_shape _, inputs_h, inputs_w, n_channels_in = self.input_shape @@ -528,17 +846,50 @@ def n_summands(self): def forward(self, batch=[None]): assert len(batch) == 1 - assert(self.weight_shape[0] == self.output_shape[-1]) - _, weights_h, weights_w, _ = self.weight_shape + if self.tf_weight_format: + assert(self.weight_shape[3] == self.output_shape[-1]) + weights_h, weights_w, _, _ = self.weight_shape + else: + assert(self.weight_shape[0] == self.output_shape[-1]) + _, weights_h, weights_w, _ = self.weight_shape _, inputs_h, inputs_w, n_channels_in = self.input_shape _, output_h, output_w, n_channels_out = self.output_shape stride_h, stride_w = self.stride - padding_h, padding_w = (weights_h // 2, weights_w // 2) + padding_h, padding_w = self.padding - if self.fewer_rounds: - inputs, weights = self.prepare_temp() + self.unreduced.alloc() + + if self.use_conv2ds: + @for_range_opt_multithread(self.n_threads, n_channels_out) + def _(j): + inputs = self.X.get_part_vector(0) + if self.tf_weight_format: + weights = self.weights.get_vector_by_indices(None, None, None, j) + else: + weights = self.weights.get_part_vector(j) + inputs = inputs.pre_mul() + weights = weights.pre_mul() + res = sint(size = output_h * output_w) + conv2ds(res, inputs, weights, output_h, output_w, + inputs_h, inputs_w, weights_h, weights_w, + stride_h, stride_w, n_channels_in, padding_h, padding_w) + if self.bias_before_reduction: + res += self.bias.expand_to_vector(j, res.size).v + self.unreduced.assign_vector_by_indices(res, 0, None, None, j) + self.reduction() + if not self.bias_before_reduction: + @for_range_multithread(self.n_threads, 1, + [self.output_shape[1], + self.output_shape[2]]) + def _(i, j): + self.Y[0][i][j].assign_vector(self.Y[0][i][j].get_vector() + + self.bias.get_vector()) + return + else: + if self.fewer_rounds: + inputs, weights = self.prepare_temp() @for_range_opt_multithread(self.n_threads, [output_h, output_w, n_channels_out]) @@ -577,7 +928,29 @@ def _(out_y, out_x, out_c): self.reduction() -class QuantDepthwiseConv2d(QuantConvBase): +class QuantConvBase(QuantBase): + def input_params_from(self, player): + for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant: + s.get_params_from(player) + print('WARNING: assuming that bias quantization parameters are correct') + self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params) + +class QuantConv2d(QuantConvBase, Conv2d): + pass + +class FixConv2d(Conv2d, FixBase): + """ Fixed-point 2D convolution layer. + + :param input_shape: input shape (tuple/list of four int) + :param weight_shape: weight shape (tuple/list of four int) + :param bias_shape: bias shape (tuple/list of one int) + :param output_shape: output shape (tuple/list of four int) + :param stride: stride (tuple/list of two int) + :param padding: :py:obj:`'SAME'` (default), :py:obj:`'VALID'`, or tuple/list of two int + :param tf_weight_format: weight shape format is (height, width, input channels, output channels) instead of the default (output channels, height, widght, input channels) + """ + +class QuantDepthwiseConv2d(QuantConvBase, Conv2d): def n_summands(self): _, weights_h, weights_w, _ = self.weight_shape return weights_h * weights_w @@ -592,12 +965,34 @@ def forward(self, batch): _, output_h, output_w, n_channels_out = self.output_shape stride_h, stride_w = self.stride - padding_h, padding_w = (weights_h // 2, weights_w // 2) + padding_h, padding_w = self.padding depth_multiplier = 1 - if self.fewer_rounds: - inputs, weights = self.prepare_temp() + self.unreduced.alloc() + + if self.use_conv2ds: + assert depth_multiplier == 1 + assert self.weight_shape[0] == 1 + @for_range_opt_multithread(self.n_threads, n_channels_in) + def _(j): + inputs = self.X.get_vector_by_indices(0, None, None, j) + assert not self.tf_weight_format + weights = self.weights.get_vector_by_indices(0, None, None, + j) + inputs = inputs.pre_mul() + weights = weights.pre_mul() + res = sint(size = output_h * output_w) + conv2ds(res, inputs, weights, output_h, output_w, + inputs_h, inputs_w, weights_h, weights_w, + stride_h, stride_w, 1, padding_h, padding_w) + res += self.bias.expand_to_vector(j, res.size).v + self.unreduced.assign_vector_by_indices(res, 0, None, None, j) + self.reduction() + return + else: + if self.fewer_rounds: + inputs, weights = self.prepare_temp() @for_range_opt_multithread(self.n_threads, [output_h, output_w, n_channels_in]) @@ -635,66 +1030,74 @@ def _(out_y, out_x, out_c): self.reduction() -class QuantAveragePool2d(QuantBase): - def __init__(self, input_shape, output_shape, filter_size): - super(QuantAveragePool2d, self).__init__(input_shape, output_shape) +class AveragePool2d(BaseLayer): + def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1)): + super(AveragePool2d, self).__init__(input_shape, output_shape) self.filter_size = filter_size + self.strides = strides + for i in (0, 1): + if strides[i] == 1: + assert output_shape[1+i] == 1 + assert filter_size[i] == input_shape[1+i] + else: + assert strides[i] == filter_size[i] + assert output_shape[1+i] * strides[i] == input_shape[1+i] - def input_from(self, player): - print('WARNING: assuming that input and output quantization parameters are the same') - for s in self.input_squant, self.output_squant: - s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) + def input_from(self, player, raw=False): + self.input_params_from(player) - def forward(self, batch): + def forward(self, batch=[0]): assert len(batch) == 1 _, input_h, input_w, n_channels_in = self.input_shape _, output_h, output_w, n_channels_out = self.output_shape - n = input_h * input_w - print('divisor: ', n) - - assert output_h == output_w == 1 assert n_channels_in == n_channels_out padding_h, padding_w = (0, 0) - stride_h, stride_w = (2, 2) + stride_h, stride_w = self.strides filter_h, filter_w = self.filter_size + n = filter_h * filter_w + print('divisor: ', n) - @for_range_opt(output_h) - def _(out_y): - @for_range_opt(output_w) - def _(out_x): - @for_range_opt(n_channels_in) - def _(c): - in_x_origin = (out_x * stride_w) - padding_w - in_y_origin = (out_y * stride_h) - padding_h - fxs = (-in_x_origin).max(0) - #fxe = min(filter_w, input_w - in_x_origin) - fys = (-in_y_origin).max(0) - #fye = min(filter_h, input_h - in_y_origin) - acc = 0 - #fc = 0 - for i in range(filter_h): - filter_y = fys + i - for j in range(filter_w): - filter_x = fxs + j - in_x = in_x_origin + filter_x - in_y = in_y_origin + filter_y - acc += self.X[0][in_y][in_x][c].v - #fc += 1 - logn = int(math.log(n, 2)) - acc = (acc + n // 2) - if 2 ** logn == n: - acc = acc.round(self.output_squant.params.k + logn, - logn, nearest=True) - else: - acc = acc.int_div(sint(n), - self.output_squant.params.k + logn) - #acc = min(255, max(0, acc)) - self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) + @for_range_opt_multithread(self.n_threads, + [output_h, output_w, n_channels_in]) + def _(out_y, out_x, c): + in_x_origin = (out_x * stride_w) - padding_w + in_y_origin = (out_y * stride_h) - padding_h + fxs = util.max(-in_x_origin, 0) + #fxe = min(filter_w, input_w - in_x_origin) + fys = util.max(-in_y_origin, 0) + #fye = min(filter_h, input_h - in_y_origin) + acc = 0 + #fc = 0 + for i in range(filter_h): + filter_y = fys + i + for j in range(filter_w): + filter_x = fxs + j + in_x = in_x_origin + filter_x + in_y = in_y_origin + filter_y + acc += self.X[0][in_y][in_x][c].v + #fc += 1 + acc = self.const_div(acc, n) + self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) + +class QuantAveragePool2d(QuantBase, AveragePool2d): + def input_params_from(self, player): + print('WARNING: assuming that input and output quantization parameters are the same') + for s in self.input_squant, self.output_squant: + s.get_params_from(player) + +class FixAveragePool2d(FixBase, AveragePool2d): + """ Fixed-point 2D AvgPool layer. + + :param input_shape: input shape (tuple/list of four int) + :param output_shape: output shape (tuple/list of four int) + :param filter_size: filter size (tuple/list of two int) + :param strides: strides (tuple/list of two int) + """ -class QuantReshape(QuantBase): +class QuantReshape(QuantBase, BaseLayer): def __init__(self, input_shape, _, output_shape): super(QuantReshape, self).__init__(input_shape, output_shape) @@ -711,7 +1114,7 @@ def forward(self, batch): # reshaping is implicit self.Y.assign(self.X) -class QuantSoftmax(QuantBase): +class QuantSoftmax(QuantBase, BaseLayer): def input_from(self, player): print('WARNING: assuming that input and output quantization parameters are the same') for s in self.input_squant, self.output_squant: @@ -729,32 +1132,76 @@ def comp(left, right): print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal()) class Optimizer: + """ Base class for graphs of layers. """ n_threads = Layer.n_threads - def forward(self, N=None, batch=None): + @property + def layers(self): + """ Get all layers. """ + return self._layers + + @layers.setter + def layers(self, layers): + """ Construct linear graph from list of layers. """ + self._layers = layers + prev = None + for layer in layers: + if not layer.inputs and prev is not None: + layer.inputs = [prev] + prev = layer + + def set_layers_with_inputs(self, layers): + """ Construct graph from :py:obj:`inputs` members of list of layers. """ + self._layers = layers + used = set([None]) + for layer in reversed(layers): + layer.last_used = list(filter(lambda x: x not in used, layer.inputs)) + used.update(layer.inputs) + + def forward(self, N=None, batch=None, keep_intermediate=True): + """ Compute graph. + + :param N: batch size (used if batch not given) + :param batch: indices for computation (:py:class:`Compiler.types.Array`. or list) + :param keep_intermediate: do not free memory of intermediate results after use + """ if batch is None: batch = regint.Array(N) batch.assign(regint.inc(N)) - for j in range(len(self.layers) - 1): - self.layers[j].forward(batch=batch) - tmp = self.layers[j].Y.get_part_vector(0, len(batch)) - self.layers[j + 1].X.assign_vector(tmp) - self.layers[-1].forward(batch=batch) + for layer in self.layers: + if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None: + layer._X.address = layer.inputs[0].Y.address + layer.Y.alloc() + break_point() + layer.forward(batch=batch) + break_point() + if not keep_intermediate: + for l in layer.last_used: + l.Y.delete() def eval(self, data): + """ Compute evaluation after training. """ N = len(data) self.layers[0].X.assign(data) self.forward(N) return self.layers[-1].eval(N) def backward(self, batch): - for j in range(1, len(self.layers)): - self.layers[-j].backward(batch=batch) - self.layers[-j - 1].nabla_Y.assign_vector( - self.layers[-j].nabla_X.get_part_vector(0, len(batch))) - self.layers[0].backward(compute_nabla_X=False, batch=batch) + """ Compute backward propagation. """ + for layer in reversed(self.layers): + if len(layer.inputs) == 0: + layer.backward(compute_nabla_X=False, batch=batch) + else: + layer.backward(batch=batch) + if len(layer.inputs) == 1: + layer.inputs[0].nabla_Y.assign_vector( + layer.nabla_X.get_part_vector(0, len(batch))) def run(self, batch_size=None): + """ Run training. + + :param batch_size: batch size (defaults to example size of first layer) + """ if batch_size is not None: N = batch_size else: @@ -841,6 +1288,12 @@ def _(k): mpc_math.sqrt(vhat) + self.epsilon class SGD(Optimizer): + """ Stochastic gradient descent. + + :param layers: layers of linear graph + :param n_epochs: number of epochs for training + :param report_loss: disclose and print loss + """ def __init__(self, layers, n_epochs, debug=False, report_loss=False): self.momentum = 0.9 self.layers = layers @@ -860,6 +1313,10 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=False): self.X_by_label = None def reset(self, X_by_label=None): + """ Reset layer parameters. + + :param X_by_label: if given, set training data by public labels for balancing + """ self.X_by_label = X_by_label if X_by_label is not None: for label, X in enumerate(X_by_label): diff --git a/Compiler/program.py b/Compiler/program.py index f9da6fe18..bf756efd0 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -14,6 +14,7 @@ import itertools import math from functools import reduce +import re data_types = dict( @@ -66,8 +67,9 @@ def __init__(self, args, options, param=-1): self._curr_tape = None self.DEBUG = False self.allocated_mem = RegType.create_dict(lambda: USER_MEM) - self.free_mem_blocks = defaultdict(set) + self.free_mem_blocks = defaultdict(lambda: defaultdict(set)) self.allocated_mem_blocks = {} + self.saved = 0 self.req_num = None self.tape_stack = [] self.n_threads = 1 @@ -142,7 +144,8 @@ def init_names(self, args): else: self.name = progname if len(args) > 1: - self.name += '-' + '-'.join(args[1:]) + self.name += '-' + '-'.join(re.sub('/', '_', arg) + for arg in args[1:]) self.progname = progname def new_tape(self, function, args=[], name=None, single_thread=False): @@ -248,9 +251,20 @@ def malloc(self, size, mem_type, reg_type=None): mem_type = mem_type.reg_type elif reg_type is not None: self.types[mem_type] = reg_type - key = size, mem_type - if self.free_mem_blocks[key]: - addr = self.free_mem_blocks[key].pop() + block_size = 0 + blocks = self.free_mem_blocks[mem_type] + if len(blocks[size]) > 0: + block_size = size + else: + for block_size, addresses in blocks.items(): + if block_size >= size and len(addresses) > 0: + break + else: + block_size = 0 + if block_size >= size: + addr = self.free_mem_blocks[mem_type][block_size].pop() + self.free_mem_blocks[mem_type][block_size - size].add(addr + size) + self.saved += size else: addr = self.allocated_mem[mem_type] self.allocated_mem[mem_type] += size @@ -265,7 +279,7 @@ def free(self, addr, mem_type): is not self.curr_tape.basicblocks[0].alloc_pool: raise CompilerError('Cannot free memory within function block') size = self.allocated_mem_blocks.pop((addr,mem_type)) - self.free_mem_blocks[size,mem_type].add(addr) + self.free_mem_blocks[mem_type][size].add(addr) def finalize_memory(self): from . import library @@ -280,6 +294,9 @@ def finalize_memory(self): else: from Compiler.types import _get_type _get_type(mem_type).load_mem(size - 1, mem_type) + if self.verbose: + if self.saved: + print('Saved %s memory units through reallocation' % self.saved) def public_input(self, x): self.public_input_file.write('%s\n' % str(x)) diff --git a/Compiler/types.py b/Compiler/types.py index 89029142c..b94522d68 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -304,12 +304,15 @@ def dot_product(cls, a, b): from Compiler.library import for_range_opt_multithread res = MemValue(cls(0)) l = min(len(a), len(b)) - aa, bb = [Array(l, cls) for i in range(2)] - aa.assign(a) - bb.assign(b) + xx = [a, b] + for i, x in enumerate((a, b)): + if not isinstance(x, Array): + xx[i] = Array(l, cls) + xx[i].assign(x) + aa, bb = xx @for_range_opt_multithread(None, l) def _(i): - res.iadd(aa[i] * bb[i]) + res.iadd(res.value_type.conv(aa[i] * bb[i])) return res.read() class _int(object): @@ -488,11 +491,16 @@ def hard_conv(cls, val): @vectorized_classmethod @set_instruction_type def _load_mem(cls, address, direct_inst, indirect_inst): - res = cls() if isinstance(address, _register): + if address.size > 1: + size = address.size + else: + size = get_global_vector_size() + res = cls(size=size) indirect_inst(res, cls._expand_address(address, get_global_vector_size())) else: + res = cls() direct_inst(res, address) return res @@ -531,6 +539,10 @@ def malloc(cls, size): :param size: compile-time (int) """ return program.malloc(size, cls) + @classmethod + def free(cls, addr): + program.free(addr, cls.reg_type) + @set_instruction_type def __init__(self, reg_type, val, size): if isinstance(val, (tuple, list)): @@ -964,7 +976,8 @@ def print_if(self, string): :param string: Python string """ cond_print_str(self, string) - + def output_if(self, cond): + cond_print_plain(cond, self, cint(0)) class cgf2n(_clear, _gf2n): @@ -1399,6 +1412,9 @@ def print_if(self, string): :param string: Python string """ cint(self).print_if(string) + def output_if(self, cond): + cint(self).output_if(cond) + class localint(object): """ Local integer that must prevented from leaking into the secure computation. Uses regint internally. """ @@ -1420,6 +1436,11 @@ def output(self): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) +class personal(object): + def __init__(self, player, value): + self.player = player + self._v = value + class _secret(_register): __slots__ = [] @@ -1666,13 +1687,17 @@ def reveal(self): @set_instruction_type def reveal_to(self, player): - """ Reveal secret value to player. + """ Reveal secret value to :py:obj:`player`. Result written to ``Player-Data/Private-Output-P`` - :param player: int """ + :param player: int + :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + """ masked = self.__class__() + res = personal(player, self.clear_type()) startprivateoutput(masked, self, player) - stopprivateoutput(masked.reveal(), player) + stopprivateoutput(res._v, masked.reveal(), player) + return res class sint(_secret, _int): @@ -2365,8 +2390,7 @@ def mul(self, other): other = self.bin_type(other) except CompilerError: return NotImplemented - products = [x * other for x in self_bits] - bit_matrix = [util.bit_decompose(x, self.n_bits) for x in products] + bit_matrix = self.get_bit_matrix(self_bits, other) return self.compose(self.wallace_tree_from_matrix(bit_matrix, False)) @classmethod @@ -2537,6 +2561,11 @@ def compose(cls, bits): gmovs(res, sum(b << i for i,b in enumerate(bits))) return res + @staticmethod + def get_bit_matrix(self_bits, other): + products = [x * other for x in self_bits] + return [util.bit_decompose(x, len(self_bits)) for x in products] + def load_other(self, other): if isinstance(other, sgf2nint): gmovs(self, self.compose(other.bit_decompose(self.n_bits))) @@ -2945,6 +2974,9 @@ def print_plain(self): print_float_plain(cint(abs_v), cint(-self.f), \ cint(0), cint(sign), cint(0)) + def output_if(self, cond): + cond_print_plain(cond, self.v, cint(-self.f)) + class _single(_number, _structure): """ Representation as single integer preserving the order """ """ E.g. fixed-point numbers """ @@ -2952,11 +2984,6 @@ class _single(_number, _structure): kappa = 40 round_nearest = False - @property - @classmethod - def reg_type(cls): - return cls.int_type.reg_type - @classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of n values input by a client. @@ -2989,6 +3016,10 @@ def coerce(cls, other): def malloc(cls, size): return program.malloc(size, cls.int_type) + @classmethod + def free(cls, addr): + return cls.int_type.free(addr) + @staticmethod def n_elements(): return 1 @@ -3299,6 +3330,10 @@ def get_input_from(cls, player): inputmixed('fix', v, cls.f, player) return cls._new(v) + @vectorized_classmethod + def get_raw_input_from(cls, player): + return cls._new(cls.int_type.get_raw_input_from(player)) + @vectorized_classmethod def get_random(cls, lower, upper): """ Uniform secret random number around centre of bounds. @@ -3340,6 +3375,16 @@ def unreduced(self, v, other=None, res_params=None, n_summands=1): def multipliable(v, k, f): return cfix(cint.conv(v), k, f) + def reveal_to(self, player): + """ Reveal secret value to :py:obj:`player`. + Raw representation written to ``Player-Data/Private-Output-P`` + + :param player: int + :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + """ + return personal(player, cfix(self.v.reveal_to(player)._v, + self.k, self.f)) + class unreduced_sfix(_single): int_type = sint @@ -4040,7 +4085,7 @@ def create_from(cls, l): res.assign(tmp) return res - def __init__(self, length, value_type, address=None, debug=None): + def __init__(self, length, value_type, address=None, debug=None, alloc=True): """ :param length: compile-time integer (int) or :py:obj:`None` for unknown length :param value_type: basic type @@ -4049,17 +4094,19 @@ def __init__(self, length, value_type, address=None, debug=None): self.address = address self.length = length self.value_type = value_type - if address is None: - self.address = self._malloc() + self.address = address self.address_cache = {} self.debug = debug + if alloc: + self.alloc() - def _malloc(self): - return self.value_type.malloc(self.length) + def alloc(self): + if self.address is None: + self.address = self.value_type.malloc(self.length) def delete(self): - if program: - program.free(self.address, self.value_type.reg_type) + self.value_type.free(self.address) + self.address = None def get_address(self, index): key = str(index) @@ -4074,8 +4121,9 @@ def get_address(self, index): if n == 1: # length can be None for single-element arrays length = 0 + base = self.address + index self.address_cache[program.curr_block, key] = \ - util.untuplify([self.address + index + i * length \ + util.untuplify([base + i * length \ for i in range(n)]) if self.debug: library.print_ln_if(index >= self.length, 'OF:' + self.debug) @@ -4136,6 +4184,9 @@ def _store(self, value, address): def __len__(self): return self.length + def total_size(self): + return len(self) * self.value_type.n_elements() + def __iter__(self): for i in range(self.length): yield self[i] @@ -4153,11 +4204,17 @@ def assign(self, other, base=0): if len(self) != None and util.is_constant(base): assert len(self) >= other.size + base except AttributeError: - for i,j in enumerate(other): - self[i] = j + if isinstance(other, Array): + @library.for_range_opt(len(other)) + def _(i): + self[i] = other[i] + else: + for i,j in enumerate(other): + self[i] = j return self assign_vector = assign + assign_part_vector = assign def assign_all(self, value, use_threads=True, conv=True): """ Assign the same value to all entries. @@ -4197,11 +4254,20 @@ def expand_to_vector(self, index, size): def get_mem_value(self, index): return MemValue(self[index], self.get_address(index)) - def input_from(self, player, budget=None): + def input_from(self, player, budget=None, raw=False): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ - self.assign(self.value_type.get_input_from(player, size=len(self))) + if raw: + input_from = self.value_type.get_raw_input_from + else: + input_from = self.value_type.get_input_from + try: + self.assign(input_from(player, size=len(self))) + except: + @library.for_range_opt(len(self), budget=budget) + def _(i): + self[i] = input_from(player) def __add__(self, other): """ Vector addition. @@ -4260,9 +4326,12 @@ class SubMultiArray(object): """ Multidimensional array functionality. """ def __init__(self, sizes, value_type, address, index, debug=None): """ Do not call this, use :py:class:`MultiArray` instead. """ - self.sizes = sizes + self.sizes = tuple(sizes) self.value_type = _get_type(value_type) - self.address = address + index * self.total_size() + if address is not None: + self.address = address + index * self.total_size() + else: + self.address = None self.sub_cache = {} self.debug = debug if debug: @@ -4344,22 +4413,58 @@ def assign(self, other): def get_part_vector(self, base=0, size=None): assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) - size = (size or len(self)) * part_size + size = (size or 1) * part_size assert size <= self.total_size() return self.value_type.load_mem(self.address + base * part_size, size=size) + def get_addresses(self, *indices): + assert self.value_type.n_elements() == 1 + assert len(indices) == len(self.sizes) + size = 1 + base = 0 + has_glob = False + last_was_glob = False + for i, x in enumerate(indices): + part_size = reduce(operator.mul, (1,) + self.sizes[i + 1:]) + if x is None: + assert not has_glob or last_was_glob + has_glob = True + size *= self.sizes[i] + skip = part_size + else: + base += x * part_size + last_was_glob = x is None + res = regint.inc(size, self.address + base, skip) + return res + + def get_vector_by_indices(self, *indices): + addresses = self.get_addresses(*indices) + return self.value_type.load_mem(addresses) + + def assign_vector_by_indices(self, vector, *indices): + addresses = self.get_addresses(*indices) + vector.store_in_mem(addresses) + def same_shape(self): """ :return: new multidimensional array with same shape and basic type """ return MultiArray(self.sizes, self.value_type) - def input_from(self, player, budget=None): + def input_from(self, player, budget=None, raw=False): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ - @library.for_range_opt(self.sizes[0], budget=budget) - def _(i): - self[i].input_from(player, budget=budget) + if (budget is None or self.total_size() < budget) and \ + self.value_type.n_elements() == 1: + if raw: + input_from = self.value_type.get_raw_input_from + else: + input_from = self.value_type.get_input_from + self.assign_vector(input_from(player, size=self.total_size())) + else: + @library.for_range_opt(self.sizes[0], budget=budget) + def _(i): + self[i].input_from(player, budget=budget, raw=raw) def schur(self, other): """ Element-wise product. @@ -4484,6 +4589,11 @@ def direct_mul(self, other, reduce=True, indices=None): self.sizes[0], *other.sizes, reduce=reduce, indices=indices) + def direct_mul_to_matrix(self, other): + res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) + res.assign_vector(self.direct_mul(other)) + return res + def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True, res=None): assert len(self.sizes) == 2 @@ -4576,7 +4686,7 @@ def _(j): class MultiArray(SubMultiArray): """ Multidimensional array. """ - def __init__(self, sizes, value_type, debug=None, address=None): + def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): """ :param sizes: shape (compile-time list of integers) :param value_type: basic type of entries @@ -4585,12 +4695,26 @@ def __init__(self, sizes, value_type, debug=None, address=None): self.array = address else: self.array = Array(reduce(operator.mul, sizes), \ - value_type, address=address) + value_type, address=address, alloc=alloc) SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \ debug=debug) if len(sizes) < 2: raise CompilerError('Use Array') + @property + def address(self): + return self.array.address + + @address.setter + def address(self, value): + self.array.address = value + + def alloc(self): + self.array.alloc() + + def delete(self): + self.array.delete() + class Matrix(MultiArray): """ Matrix. """ def __init__(self, rows, columns, value_type, debug=None, address=None): @@ -4777,8 +4901,14 @@ def reveal(self): if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) - expand_to_vector = lambda self,*args,**kwargs: \ - self.read().expand_to_vector(*args, **kwargs) + def expand_to_vector(self, size=None): + if program.curr_block == self.last_write_block: + return self.read().expand_to_vector(size) + else: + if size is None: + size = get_global_vector_size() + addresses = regint.inc(size, self.address, 0) + return self.value_type.load_mem(addresses) def __repr__(self): return 'MemValue(%s,%d)' % (self.value_type, self.address) diff --git a/Compiler/util.py b/Compiler/util.py index 4ebadf301..6bcce91c1 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -109,6 +109,7 @@ def round_to_int(x): def tree_reduce(function, sequence): sequence = list(sequence) + assert len(sequence) > 0 n = len(sequence) if n == 1: return sequence[0] @@ -162,6 +163,9 @@ def is_all_ones(x, n): else: return False +def max(x, y): + return if_else(x > y, x, y) + def long_one(x): try: return x.long_one() diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 5bb9a38dd..4701e5882 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -11,6 +11,7 @@ #include "Processor/Data_Files.h" #include "Protocols/ReplicatedPrep.h" #include "Protocols/MaliciousShamirShare.h" +#include "Protocols/Rep3Share.h" #include "GC/TinierSecret.h" #include "GC/TinierPrep.h" #include "GC/MaliciousCcdSecret.h" diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp index ed7ad7213..747cac3c9 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -2,6 +2,8 @@ #include "FHE/Matrix.h" #include "Exceptions/Exceptions.h" +#include "Math/modp.hpp" + #include #include diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index c45ae162e..2926f0486 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -19,7 +19,12 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, int output_thread, TripleSacriFactory< Share >& factory, bool write_output, bool clear, string dir) { - assert(T::length() >= 40); + if (T::length() < 40) + { + cerr << "Field too small for reasonable security" << endl; + cerr << "Use a larger field or remove this warning from " << __FILE__ << endl; + exit(1); + } ofstream outf; if (write_output) diff --git a/GC/Machine.h b/GC/Machine.h index 29e574590..bdc36afc3 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -17,7 +17,7 @@ using namespace std; namespace GC { -template class Program; +class Program; template class Memories @@ -38,7 +38,7 @@ class Machine : public ::BaseMachine, public Memories public: Memory MI; - vector > progs; + vector progs; bool use_encryption; bool more_comm_less_comp; diff --git a/GC/Memory.h b/GC/Memory.h index 89f01d4b9..baa3aab6c 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -33,9 +33,6 @@ class Memory : public vector T& operator[] (Integer i); const T& operator[] (Integer i) const; size_t capacity_in_bytes() const { return this->capacity() * sizeof(T); } - - template - Memory& cast() { return *reinterpret_cast< Memory* >(this); } }; template diff --git a/GC/Processor.h b/GC/Processor.h index 0895acca0..492e306a4 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -20,8 +20,6 @@ using namespace std; namespace GC { -template class Program; - class ExecutionStats : public map { public: @@ -39,7 +37,8 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching public: static int check_args(const vector& args, int n); - static void check_input(bigint in, int n_bits); + template + static void check_input(const U& in, int n_bits); Machine* machine; Memories& memories; @@ -56,6 +55,8 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching ExecutionStats stats; + Timer xor_timer; + Processor(Machine& machine); Processor(Memories& memories, Machine* machine = 0); ~Processor(); @@ -66,7 +67,8 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void reset(const U& program); long long get_input(const int* params, bool interactive = false); - bigint get_long_input(const int* params, ProcessorBase& input_proc, + template + U get_long_input(const int* params, ProcessorBase& input_proc, bool interactive = false); void bitcoms(T& x, const vector& regs) { x.bitcom(S, regs); } @@ -92,6 +94,7 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching Integer dest_address, Integer source_address); void xors(const vector& args); + void xors(const vector& args, size_t start, size_t end); void nots(const ::BaseInstruction& instruction); void andm(const ::BaseInstruction& instruction); void and_(const vector& args, bool repeat); @@ -118,9 +121,9 @@ inline int GC::Processor::check_args(const vector& args, int n) if (args.size() % n != 0) throw runtime_error("invalid number of arguments"); int total = 0; - for (size_t i = 0; i < args.size(); i += n) + for (auto it = args.begin(); it < args.end(); it += n) { - total += args[i]; + total += *it; } return total; } diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 46c3dcd53..16c638fb2 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -39,7 +39,10 @@ template Processor::~Processor() { #ifdef VERBOSE - cerr << "Finished after " << time << " instructions" << endl; + if (xor_timer.elapsed() > 0) + cerr << "XOR time: " << xor_timer.elapsed() << endl; + if (time > 0) + cerr << "Finished after " << time << " instructions" << endl; #endif } @@ -68,14 +71,15 @@ template inline long long GC::Processor::get_input(const int* params, bool interactive) { assert(params[0] <= 64); - return get_long_input(params, *this, interactive).get_si(); + return get_long_input(params, *this, interactive).get(); } template -bigint GC::Processor::get_long_input(const int* params, +template +U GC::Processor::get_long_input(const int* params, ProcessorBase& input_proc, bool interactive) { - bigint res = input_proc.get_input>(interactive, + U res = input_proc.get_input>(interactive, ¶ms[1]).items[0]; int n_bits = *params; check_input(res, n_bits); @@ -83,19 +87,20 @@ bigint GC::Processor::get_long_input(const int* params, } template -void GC::Processor::check_input(bigint in, int n_bits) +template +void GC::Processor::check_input(const U& in, int n_bits) { auto test = in >> (n_bits - 1); if (n_bits == 1) { if (not (in == 0 or in == 1)) - throw runtime_error("input not a bit: " + in.get_str()); + throw runtime_error("input not a bit: " + to_string(in)); } else if (not (test == 0 or test == -1)) { throw runtime_error( "input too large for a " + std::to_string(n_bits) - + "-bit signed integer: " + in.get_str()); + + "-bit signed integer: " + to_string(in)); } } @@ -193,9 +198,18 @@ void Processor::mem_op(int n, Memory& dest, const Memory& source, template void Processor::xors(const vector& args) { - assert(args.size() % 4 == 0); + xors(args, 0, args.size()); +} + +template +void Processor::xors(const vector& args, size_t start, size_t end) +{ + assert(start % 4 == 0); + assert(end % 4 == 0); + assert(start < end); + assert(args.begin() + end <= args.end()); int dl = T::default_length; - for (auto it = args.begin(); it < args.end(); it += 4) + for (auto it = args.begin() + start; it < args.begin() + end; it += 4) { if (*it == 1) S[*(it + 1)].xor_(1, S[*(it + 2)], S[*(it + 3)]); @@ -209,9 +223,6 @@ void Processor::xors(const vector& args) S[*(it + 3) + j]); } } -#ifndef FREE_XOR - complexity += args[i]; -#endif } } diff --git a/GC/Program.h b/GC/Program.h index c91dec955..5afc5fed4 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -24,11 +24,9 @@ enum BreakType { template class Processor; -template class Program { vector p; - int offline_data_used; // Maximal register used unsigned max_reg[MAX_REG_TYPE]; @@ -36,9 +34,6 @@ class Program // Memory size used directly unsigned max_mem[MAX_REG_TYPE]; - // True if program contains variable-sized loop - bool unknown_usage; - void compute_constants(); public: @@ -50,24 +45,14 @@ class Program void parse(const string& programe); void parse(istream& s); - int get_offline_data_used() const { return offline_data_used; } - void print_offline_cost() const; - - bool usage_unknown() const { return unknown_usage; } - unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } const unsigned* direct_mem(RegType reg_type) const { return &max_mem[reg_type]; } - template + template BreakType execute(Processor& Proc, U& dynamic_memory, int PC = -1) const; - - bool done(Processor& Proc) const { return Proc.PC >= p.size(); } - - template - Program& cast() { return *reinterpret_cast< Program* >(this); } }; diff --git a/GC/Program.hpp b/GC/Program.hpp index 0e1db7ee5..e78a6a4c9 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -17,15 +17,14 @@ namespace GC { -template -Program::Program() : - offline_data_used(0), unknown_usage(false) +inline +Program::Program() { compute_constants(); } -template -void Program::compute_constants() +inline +void Program::compute_constants() { for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) { @@ -34,8 +33,6 @@ void Program::compute_constants() } for (unsigned int i = 0; i < p.size(); i++) { - if (!p[i].get_offline_data_usage(offline_data_used)) - unknown_usage = true; for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) { max_reg[reg_type] = max(max_reg[reg_type], @@ -46,15 +43,15 @@ void Program::compute_constants() } } -template -void Program::parse(const string& bytecode_name) +inline +void Program::parse(const string& bytecode_name) { string filename = "Programs/Bytecode/" + bytecode_name + ".bc"; parse_file(filename); } -template -void Program::parse_file(const string& filename) +inline +void Program::parse_file(const string& filename) { ifstream s(filename.c_str()); if (s.bad() or s.fail()) @@ -62,8 +59,8 @@ void Program::parse_file(const string& filename) parse(s); } -template -void Program::parse(istream& s) +inline +void Program::parse(istream& s) { p.resize(0); Instruction instr; @@ -84,22 +81,9 @@ void Program::parse(istream& s) compute_constants(); } -template -void Program::print_offline_cost() const -{ - if (unknown_usage) - { - cerr << "Tape has unknown usage" << endl; - return; - } - - cerr << "Cost of first tape: " << offline_data_used << endl; -} - -template -template +template __attribute__((flatten)) -BreakType Program::execute(Processor& Proc, U& dynamic_memory, +BreakType Program::execute(Processor& Proc, U& dynamic_memory, int PC) const { if (PC != -1) diff --git a/GC/Secret.h b/GC/Secret.h index 97a2b371d..d2596a10d 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -34,33 +34,9 @@ class Secret; template inline void XOR(T& res, const T& left, const T& right) { -#ifdef FREE_XOR - Secret::cast(res).XOR(Secret::cast(left), Secret::cast(right)); -#else - Secret::cast(res).op(Secret::cast(left), Secret::cast(right), 0x0110); -#endif + res.XOR(left, right); } -class AuthValue -{ -public: - static string type_string() { return "authenticated value"; } - word share; - int128 mac; - AuthValue() : share(0), mac(0) {} - void assign(const word& value, const int128& mac_key, bool first_player); - void check(const word& mac_key) const; - friend ostream& operator<<(ostream& o, const AuthValue& auth_value); -}; - -class Mask -{ -public: - word share; - int128 mac; - Mask() : share(0) {} -}; - template class Processor; template class Machine; @@ -78,6 +54,8 @@ class Secret T& get_new_reg(); public: + typedef T part_type; + typedef typename T::DynamicMemory DynamicMemory; typedef NoShare bit_type; @@ -93,19 +71,13 @@ class Secret static const bool is_real = true; - static T& cast(T& reg) { return *reinterpret_cast(®); } - static const T& cast(const T& reg) { return *reinterpret_cast(®); } - static Secret input(party_id_t from, const int128& input, int n_bits = -1); static Secret input(Processor>& processor, const InputArgs& args); void random(int n_bits, int128 share); void random_bit(); - static Secret reconstruct(const int128& x, int length); template static void store_clear_in_dynamic(U& mem, const vector& accesses) { T::store_clear_in_dynamic(mem, accesses); } - void store(Memory& mem, size_t address); - static Secret carryless_mult(const Secret& x, const Secret& y); static void output(T& reg); template @@ -146,7 +118,6 @@ class Secret void load_clear(int n, const Integer& x); void operator=(const Integer& x) { load_clear(default_length, x); } - void load(int n, const Memory& mem, size_t address); Secret operator<<(int i); Secret operator>>(int i); @@ -176,8 +147,8 @@ class Secret RegVector& get_regs() { return registers; } const RegVector& get_regs() const { return registers; } - const T& get_reg(int i) const { return *reinterpret_cast(®isters.at(i)); } - T& get_reg(int i) { return *reinterpret_cast(®isters.at(i)); } + const T& get_reg(int i) const { return registers.at(i); } + T& get_reg(int i) { return registers.at(i); } void resize_regs(size_t n); }; diff --git a/GC/Secret.hpp b/GC/Secret.hpp index a12266343..f7362e990 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -63,25 +63,9 @@ void Secret::random(int n_bits, int128 share) (void)share; if (n_bits > 128) throw not_implemented(); -#ifdef NO_INPUT resize_regs(n_bits); for (int i = 0; i < n_bits; i++) get_reg(i).random(); -#else - for (int i = 0; i < CommonParty::singleton->get_n_parties(); i++) - { - Secret tmp = *this; - Secret s = input(i + 1, share, n_bits); - *this = tmp + s; -#ifdef DEBUG_DYNAMIC - int128 a,b,c; - tmp.reveal(a); - s.reveal(b); - reveal(c); - cout << c << " = " << a << " ^ " << b << " (" << dec << n_bits << ", " << share << ")" << endl; -#endif - } -#endif #ifdef DEBUG_RANDOM int128 revealed; reveal(revealed); @@ -97,11 +81,7 @@ void Secret::random(int n_bits, int128 share) template void Secret::random_bit() { -#ifdef NO_INPUT return random(1, 0); -#else - return random(1, CommonParty::s().prng.get_uchar() & 1); -#endif } template @@ -115,67 +95,7 @@ void Secret::store(U& mem, template void Secret::output(T& reg) { - cast(reg).output(); -} - -template -Secret Secret::carryless_mult(const Secret& x, const Secret& y) -{ - Secret res; - if (x.registers.size() == 0) - throw not_implemented(); -#ifdef DEBUG_DYNAMIC2 - for (int i = 0; i < x.registers.size(); i++) - output(x.registers[i]); - for (int i = 0; i < y.registers.size(); i++) - output(y.registers[i]); -#endif - for (size_t i = 0; i < x.registers.size() + y.registers.size() - 1; i++) - { - int start = max((size_t)0, i - y.registers.size() + 1); - int stop = min(i + 1, x.registers.size()); - T sum = AND(x.get_reg(start), y.get_reg(i - start)); -#ifdef DEBUG_DYNAMIC2 - output(sum); - cout << "carryless " << i << " " << start << " " << i - start << - " sum " << (int)cast(sum).get_output() << - " x " << (int)x.get_reg(start).get_output() << - " y " << (int)y.get_reg(i - start).get_output() << - " sum id " << sum.get_reg().get_id() << endl; -#endif - for (int j = start + 1; j < stop; j++) - { - T product = AND(x.get_reg(j), y.get_reg(i - j)); - sum = XOR(sum, product); -#ifdef DEBUG_DYNAMIC2 - cout << "carryless " << - " prod id " << product.get_reg().get_id() << - " sum id " << sum.get_reg().get_id() << endl << flush; - output(product); - output(sum); - cout << "carryless " << i << " " << j << " " << i - j << - " prod " << (int)cast(product).get_output() << - " sum " << (int)cast(sum).get_output() << - " x " << (int)x.get_reg(j).get_output() << - " y " << (int)y.get_reg(i - j).get_output() << endl; -#endif - } - res.registers.push_back(sum); - } -#ifdef DEBUG_DYNAMIC - word a, b; - int128 c; - x.reveal(a); - y.reveal(b); - res.reveal(c); - cout << typeid(T).name() << endl; - cout << c << " = " << hex << a << " * " << b << endl; - AuthValue d; - d.assign(a, b, false); - if (d.mac != c) - throw runtime_error("carryless mult"); -#endif - return res; + reg.output(); } template @@ -189,7 +109,7 @@ template T& GC::Secret::get_new_reg() { registers.push_back(T::new_reg()); - T &res = cast(registers.back()); + T& res = registers.back(); #ifdef DEBUG_REGS cout << "Secret: new " << typeid(T).name() << " " << res.get_id() << " at " << &res << endl; #endif diff --git a/GC/Secret_inline.h b/GC/Secret_inline.h index 1d1d8ff7e..ccc6939eb 100644 --- a/GC/Secret_inline.h +++ b/GC/Secret_inline.h @@ -28,13 +28,10 @@ inline T XOR(const T& left, const T& right) template inline void AND(T& res, const T& left, const T& right) { -#ifdef KEY_SIGNAL #ifdef DEBUG_REGS cout << "*" << res.get_id() << " = AND(*" << left.get_id() << ",*" << right.get_id() << ")" << endl; #endif -#else -#endif - Secret::cast(res).op(Secret::cast(left), Secret::cast(right), 0x0001); + res.op(left, right, 0x0001); } template diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index e98101cc4..04a7d3885 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -64,4 +64,12 @@ void SemiPrep::buffer_bits() this->bits.push_back((r >> i) & 1); } +size_t SemiPrep::data_sent() +{ + if (triple_generator) + return triple_generator->data_sent(); + else + return 0; +} + } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 619858d52..243444fb7 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -7,7 +7,7 @@ #define GC_SEMIPREP_H_ #include "Protocols/ReplicatedPrep.h" -#include "OT/TripleMachine.h" +#include "OT/MascotParams.h" #include "SemiSecret.h" #include "ShiftableTripleBuffer.h" @@ -50,6 +50,8 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer::inputb(Processor& processor, { if (x.from == party.P->my_num()) { - bigint whole_input = processor.get_long_input(x.params, + bigint whole_input = processor.template + get_long_input(x.params, input_processor, interactive); for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) input.add_mine(bigint(whole_input >> (i * dl)).get_si(), diff --git a/GC/Thread.h b/GC/Thread.h index 790b988dd..56fc87cc2 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -50,7 +50,7 @@ class Thread void run(); virtual void pre_run() {} - virtual void run(Program& program); + virtual void run(Program& program); virtual void post_run() {} void join_tape(); diff --git a/GC/Thread.hpp b/GC/Thread.hpp index fc57e8ef7..7e504755c 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -70,7 +70,7 @@ void Thread::run() } template -void Thread::run(Program& program) +void Thread::run(Program& program) { while (program.execute(processor, master.memory) != DONE_BREAK) ; diff --git a/GC/TinyPrep.h b/GC/TinyPrep.h index b8ebdc0a4..a8e0cf69f 100644 --- a/GC/TinyPrep.h +++ b/GC/TinyPrep.h @@ -7,7 +7,7 @@ #define GC_TINYPREP_H_ #include "Thread.h" -#include "OT/TripleMachine.h" +#include "OT/MascotParams.h" #include "Protocols/Beaver.h" #include "Protocols/ReplicatedPrep.h" #include "Protocols/RandomPrep.h" diff --git a/Machines/Player-Online.cpp b/Machines/Player-Online.cpp index 9c5491f12..e9599cfbc 100644 --- a/Machines/Player-Online.cpp +++ b/Machines/Player-Online.cpp @@ -6,6 +6,7 @@ #include "Processor/config.h" #include "Protocols/Share.h" #include "GC/TinierSecret.h" +#include "Math/gfp.h" #include "Player-Online.hpp" diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 6a51e48c8..64874e641 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -91,7 +91,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr 1, // Number of args expected. 0, // Delimiter if expecting multiple args. "Maximum number of parties to send to at once", // Help description. - "-B", // Flag token. + "-mb", // Flag token. "--max-broadcast" // Flag token. ); opt.add( diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index fa35fc927..b6cb49d32 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -9,7 +9,6 @@ #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" -#include "Protocols/BrainPrep.hpp" #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/MAC_Check_Base.hpp" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 644e141e4..36c9d8c50 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -4,8 +4,6 @@ */ #include "Protocols/SemiShare.h" -#include "Protocols/Semi2kShare.h" -#include "Math/gfp.h" #include "Math/gf2n.h" #include "Protocols/SemiMC.h" #include "Protocols/SemiPrep.h" @@ -20,4 +18,3 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" -#include "Math/Z2k.hpp" diff --git a/Machines/brain-party.cpp b/Machines/brain-party.cpp index af6bfcca3..8c3559da2 100644 --- a/Machines/brain-party.cpp +++ b/Machines/brain-party.cpp @@ -8,6 +8,7 @@ #include "Processor/RingOptions.h" #include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/BrainPrep.hpp" #include "Machines/RepRing.hpp" #include "Math/gfp.hpp" diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp new file mode 100644 index 000000000..40b3edda9 --- /dev/null +++ b/Machines/mama-party.cpp @@ -0,0 +1,22 @@ +/* + * mama-party.cpp + * + */ + +#include "Protocols/MamaShare.h" + +#include "Protocols/MamaPrep.hpp" +#include "Protocols/MascotPrep.hpp" +#include "SPDZ.hpp" +#include "Player-Online.hpp" +#include "Math/gfp.hpp" + +#ifndef N_MAMA_MACS +#define N_MAMA_MACS 3 +#endif + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + return spdz_main, Share>(argc, argv, opt); +} diff --git a/Makefile b/Makefile index 7426449d1..c0e12092a 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,7 @@ rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x replicated: rep-field rep-ring rep-bin spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Offline.x -mascot: mascot-party.x spdz2k +mascot: mascot-party.x spdz2k mama-party.x tldr: -echo ARCH = -march=native >> CONFIG.mine @@ -101,7 +101,7 @@ shamir: shamir-party.x malicious-shamir-party.x galois-degree.x ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC) +$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC) $(AR) -csr $@ $^ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) @@ -145,14 +145,15 @@ cnc-offline.x: $(FHEOFFLINE) spdz2-offline.x: $(FHEOFFLINE) yao-party.x: $(YAO) +static/yao-party.x: $(YAO) yao-clean: -rm Yao/*.o -galois-degree.x: Utils/galois-degree.cpp +galois-degree.x: Utils/galois-degree.o $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -default-prime-length.x: Utils/default-prime-length.cpp +default-prime-length.x: Utils/default-prime-length.o $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) secure.x: Utils/secure.o @@ -186,6 +187,7 @@ chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) mascot-party.x: Machines/SPDZ.o $(OT) static/mascot-party.x: Machines/SPDZ.o Player-Online.x: Machines/SPDZ.o $(OT) +mama-party.x: $(OT) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o @@ -200,6 +202,7 @@ OT/BaseOT.o: SimpleOT/Makefile SimpleOT/Makefile: git submodule update --init SimpleOT +.PHONY: Programs/Circuits Programs/Circuits: git submodule update --init Programs/Circuits diff --git a/Math/FixedVec.h b/Math/FixedVec.h index acfa202c5..3be4d8545 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -26,6 +26,8 @@ class FixedVec public: typedef T value_type; + typedef FixedVec Scalar; + static const int length = L; static int size() @@ -134,6 +136,12 @@ class FixedVec add(*this, x); } + void add(octetStream& os) + { + for (int i = 0; i < L; i++) + v[i].add(os); + } + void negate() { for (auto& x : v) @@ -148,6 +156,11 @@ class FixedVec return true; } + bool operator!=(const FixedVec& other) const + { + return not equal(other); + } + bool is_zero() { return equal(0); @@ -299,6 +312,13 @@ class FixedVec for (auto& x : v) x.randomize(G); } + + void almost_randomize(PRNG& G) + { + for (auto& x : v) + x.almost_randomize(G); + } + void randomize_to_sum(const T& sum, PRNG& G) { T s = 0; @@ -347,7 +367,7 @@ class FixedVec template FixedVec operator*(const U& a, const FixedVec& b) { - return b * a; + return b * T(a); } template diff --git a/Math/Integer.h b/Math/Integer.h index b9f997125..3c87bd2df 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -160,6 +160,11 @@ class Integer : public IntBase void SHR(const Integer& x, const Integer& y) { *this = (unsigned long)x.a >> y.a; } }; +inline string to_string(const Integer& x) +{ + return to_string(x.get()); +} + template<> inline void IntBase::randomize(PRNG& G) { diff --git a/Math/Setup.h b/Math/Setup.h index 6ea6fc295..aeb978e4c 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -7,7 +7,6 @@ #define MATH_SETUP_H_ #include "Math/bigint.h" -#include "Math/gfp.h" #include "Tools/mkpath.h" #include @@ -67,10 +66,6 @@ void generate_prime_setup(string dir, int nparties, int lgp) generate_online_setup(get_prep_sub_dir(dir, nparties, lgp), p, lgp); } -// Read prime from file -template -void read_setup(const string& dir_prefix, int lgp = -1); - void init_gf2n(int gf2ndegree); #endif /* MATH_SETUP_H_ */ diff --git a/Math/Setup.hpp b/Math/Setup.hpp index a3f7ee27c..04eb84f7f 100644 --- a/Math/Setup.hpp +++ b/Math/Setup.hpp @@ -8,8 +8,8 @@ #include "gfp.h" -template -void read_setup(const string& dir_prefix, int lgp) +template +void read_setup(const string& dir_prefix, int lgp = -1) { bigint p; diff --git a/Math/Square.cpp b/Math/Square.cpp index 83e1d97f2..da49d8fae 100644 --- a/Math/Square.cpp +++ b/Math/Square.cpp @@ -5,6 +5,8 @@ #include "Square.h" #include "BitVec.h" +#include "gf2n.h" +#include "gfp.h" template<> void Square::to(gf2n_short& result) diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index b7a818b50..041a3f0be 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -9,6 +9,7 @@ #include "Exceptions/Exceptions.h" class OnlineOptions; +class bigint; class ValueInterface { @@ -24,6 +25,8 @@ class ValueInterface template static void generate_setup(string, int, int) {} + static bigint pr() { throw runtime_error("no prime modulus"); } + static int power_of_two(bool, int) { throw not_implemented(); } void normalize() {} diff --git a/Math/Z2k.h b/Math/Z2k.h index 5e0833060..395f59d76 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -257,10 +257,8 @@ class SignedZ2 : public Z2 return Z2::operator-(other); } - template - SignedZ2 operator*(const SignedZ2& other) const + SignedZ2 operator*(const SignedZ2& other) const { - assert((K % 64 == 0) and (L % 64 == 0)); return Z2::operator*(other); } diff --git a/Math/bigint.h b/Math/bigint.h index 2320c107a..e39f4a6e6 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -44,7 +44,7 @@ class bigint : public mpz_class static void init_thread() { tmp = 0; } template - static mpf_class get_float(T v, Integer exp, T z, T s); + static mpf_class get_float(T v, T p, T z, T s); template static void output_float(U& o, const mpf_class& x, T nan); diff --git a/Math/bigint.hpp b/Math/bigint.hpp index 73cd81ff3..a5b195235 100644 --- a/Math/bigint.hpp +++ b/Math/bigint.hpp @@ -10,8 +10,10 @@ #include "Integer.h" template -mpf_class bigint::get_float(T v, Integer exp, T z, T s) +mpf_class bigint::get_float(T v, T p, T z, T s) { + // MPIR can't handle more precision in exponent + Integer exp = Integer(p, 31).get(); bigint tmp = v; mpf_class res = tmp; if (exp > 0) diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 9696279dd..c541ce2f8 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -152,7 +152,6 @@ class gf2n_long : public ValueInterface static gf2n_long Mul(gf2n_long a, gf2n_long b) { return a * b; } int128 get() const { return a; } - __m128i to_m128i() const { return a.a; } word get_word() const { return _mm_cvtsi128_si64(a.a); } const void* get_ptr() const { return &a.a; } diff --git a/Math/gfp.h b/Math/gfp.h index 2324b840f..dc7453452 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -155,16 +155,6 @@ class gfp_ return *this; } - void to_m128i(__m128i& ans) - { - memcpy(&ans, a.x, sizeof(ans)); - } - - __m128i to_m128i() - { - return _mm_loadu_si128((__m128i*)a.x); - } - void zero_overhang(); void check(); diff --git a/OT/MamaRectangle.h b/OT/MamaRectangle.h new file mode 100644 index 000000000..6c7f4b003 --- /dev/null +++ b/OT/MamaRectangle.h @@ -0,0 +1,84 @@ +/* + * MamaRectangle.h + * + */ + +#ifndef OT_MAMARECTANGLE_H_ +#define OT_MAMARECTANGLE_H_ + +#include "Math/FixedVec.h" +#include "Math/gfp.h" +#include "Tools/BitVector.h" + +template +class MamaRectangle +{ + typedef MamaRectangle This; + + typename T::Square squares[N]; + +public: + static const int N_ROWS = T::Square::N_ROWS; + static const int N_COLUMNS = T::Square::N_COLUMNS; + static const int N_ROW_BYTES = T::Square::N_ROW_BYTES; + + static int size() + { + return N * T::Square::size(); + } + + void conditional_add(BitVector& conditions, This& other, + int offset) + { + for (int i = 0; i < N; i++) + squares[i].conditional_add(conditions, other.squares[i], + offset * N + i); + } + + This& sub(const This& other) + { + for (int i = 0; i < N; i++) + squares[i].sub(other.squares[i]); + return *this; + } + + This& rsub(const This& other) + { + for (int i = 0; i < N; i++) + squares[i].rsub(other.squares[i]); + return *this; + } + + This& sub(const void* other) + { + for (int i = 0; i < N; i++) + squares[i].sub(other); + return *this; + } + + void randomize(int row, PRNG& G) + { + squares[row / T::Square::N_ROWS].randomize(row % T::Square::N_ROWS, G); + } + + void pack(octetStream& os) const + { + for (int i = 0; i < N; i++) + squares[i].pack(os); + } + + void unpack(octetStream& os) + { + for (int i = 0; i < N; i++) + squares[i].unpack(os); + } + + template + void to(FixedVec& result) + { + for (int i = 0; i < N; i++) + squares[i].to(result[i]); + } +}; + +#endif /* OT_MAMARECTANGLE_H_ */ diff --git a/OT/MascotParams.h b/OT/MascotParams.h new file mode 100644 index 000000000..e81d15107 --- /dev/null +++ b/OT/MascotParams.h @@ -0,0 +1,30 @@ +/* + * MascotParams.h + * + */ + +#ifndef OT_MASCOTPARAMS_H_ +#define OT_MASCOTPARAMS_H_ + +#include "Tools/OfflineMachineBase.h" + +class MascotParams : virtual public OfflineParams +{ +public: + string prep_data_dir; + bool generateMACs; + bool amplify; + bool check; + bool correlation_check; + bool generateBits; + bool use_extension; + bool fewer_rounds; + bool fiat_shamir; + struct timeval start, stop; + + MascotParams(); + + void set_passive(); +}; + +#endif /* OT_MASCOTPARAMS_H_ */ diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 7b8564cea..87c86d95e 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -8,7 +8,7 @@ #include "Processor/InputTuple.h" #include "OT/OTTripleSetup.h" -#include "OT/TripleMachine.h" +#include "OT/MascotParams.h" #include "OT/OTMultiplier.h" #include @@ -143,22 +143,42 @@ class NPartyTripleGenerator : public OTTripleGenerator }; template -class MascotTripleGenerator : public NPartyTripleGenerator +class SimpleMascotTripleGenerator : public NPartyTripleGenerator { - typedef typename T::open_type open_type; typedef typename T::mac_key_type mac_key_type; typedef typename T::MAC_Check MAC_Check; + virtual void sacrifice(typename T::MAC_Check&, PRNG&) { throw not_implemented(); } + +public: + vector< ShareTriple > uncheckedTriples; + + SimpleMascotTripleGenerator(const OTTripleSetup& setup, const Names& names, + int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, + Player* parentPlayer = 0); + virtual ~SimpleMascotTripleGenerator() {} + void generateTriples(); +}; + +template +class MascotTripleGenerator : public SimpleMascotTripleGenerator +{ + typedef typename T::open_type open_type; + typedef typename T::mac_key_type mac_key_type; + typedef typename T::MAC_Check MAC_Check; + void generateBits(); void generateBitsGf2n(); - void generateBitsFromTriples(MAC_Check& MC, ofstream& outputFile); + template + void generateBitsFromTriples(MAC_Check& MC, ofstream& outputFile, gfp_); + template + void generateBitsFromTriples(MAC_Check& MC, ofstream& outputFile, U); void sacrifice(typename T::MAC_Check& MC, PRNG& G); public: - vector< ShareTriple > uncheckedTriples; - vector bits; MascotTripleGenerator(const OTTripleSetup& setup, const Names& names, diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 5ec6e2f12..18867475a 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -41,7 +41,7 @@ NPartyTripleGenerator::NPartyTripleGenerator(const OTTripleSetup& setup, } template -MascotTripleGenerator::MascotTripleGenerator(const OTTripleSetup& setup, +SimpleMascotTripleGenerator::SimpleMascotTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, @@ -49,6 +49,15 @@ MascotTripleGenerator::MascotTripleGenerator(const OTTripleSetup& setup, { } +template +MascotTripleGenerator::MascotTripleGenerator(const OTTripleSetup& setup, + const Names& names, int thread_num, int _nTriples, int nloops, + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : + SimpleMascotTripleGenerator(setup, names, thread_num, _nTriples, nloops, + machine, mac_key, parentPlayer) +{ +} + template Spdz2kTripleGenerator::Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, @@ -332,7 +341,7 @@ void MascotTripleGenerator::generateBits() if (T::clear::characteristic_two) generateBitsGf2n(); else - generateTriples(); + this->generateTriples(); } template @@ -519,7 +528,7 @@ void OTTripleGenerator::plainTripleRound(int k) } template -void MascotTripleGenerator::generateTriples() +void SimpleMascotTripleGenerator::generateTriples() { typedef typename U::open_type T; @@ -655,7 +664,7 @@ void MascotTripleGenerator::sacrifice(typename T::MAC_Check& MC, PRNG& G) MC.Check(globalPlayer); if (machine.generateBits) - generateBitsFromTriples(MC, outputFile); + generateBitsFromTriples(MC, outputFile, typename T::clear()); else if (machine.output) for (int j = 0; j < nTriplesPerLoop; j++) @@ -715,12 +724,15 @@ void Spdz2kTripleGenerator::sacrificeZ2k(U& MC, PRNG& G) uncheckedTriples[j].template reduce().output(outputFile, 1); } -template<> -inline -void MascotTripleGenerator>::generateBitsFromTriples(MAC_Check& MC, - ofstream& outputFile) +template +template +void MascotTripleGenerator::generateBitsFromTriples(MAC_Check& MC, + ofstream& outputFile, gfp_) { + typedef gfp_ gfp1; auto& triples = this->uncheckedTriples; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& globalPlayer = this->globalPlayer; vector< Share > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop); for (int i = 0; i < nTriplesPerLoop; i++) a_plus_b[i] = triples[i].a[0] + triples[i].b; @@ -740,7 +752,7 @@ void MascotTripleGenerator>::generateBitsFromTriples(MAC_Check& MC, if (root.is_zero()) continue; Share bit = (triples[i].a[0] / root + one) / gfp1(2); - if (machine.output) + if (this->machine.output) bit.output(outputFile, false); else bits.push_back(bit); @@ -748,7 +760,8 @@ void MascotTripleGenerator>::generateBitsFromTriples(MAC_Check& MC, } template -void MascotTripleGenerator::generateBitsFromTriples(MAC_Check&, ofstream&) +template +void MascotTripleGenerator::generateBitsFromTriples(MAC_Check&, ofstream&, U) { throw how_would_that_work(); } diff --git a/OT/OTCorrelator.hpp b/OT/OTCorrelator.hpp index cc0d4f9cd..ab22a9e4d 100644 --- a/OT/OTCorrelator.hpp +++ b/OT/OTCorrelator.hpp @@ -135,8 +135,6 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, gettimeofday(&startv, NULL); #endif - typedef typename V::PartType::RowType T; - int n_rows = V::PartType::N_ROWS_ALLOCATED; int n = (nOTs + n_rows - 1) / n_rows * V::PartType::N_ROWS; for (int i = 0; i < 2; i++) @@ -165,13 +163,13 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, senderOutputMatrices[1].squares[i_outer_input].rows[i_inner_input + j]; } for (int j = 0; j < 2; j++) - mmo.hashBlocks( + mmo.hashEightBlocks( &senderOutput[j].squares[i_outer_output].rows[i_inner_output], &tmp[j]); } if (ot_role & RECEIVER) { - mmo.hashBlocks( + mmo.hashEightBlocks( &receiverOutput.squares[i_outer_output].rows[i_inner_output], &receiverOutputMatrix.squares[i_outer_input].rows[i_inner_input]); } diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 2ae9a865f..d229a5ea5 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -88,7 +88,10 @@ class MascotMultiplier : public OTMultiplier const vector& baseReceiverOutput); void multiplyForBits(); - void multiplyForGf2nBits(); + template + void multiplyForBits(U); + template + void multiplyForBits(gfp_); public: vector c_output; diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index ed14aa9a3..9c878e59a 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -400,29 +400,22 @@ void Spdz2kMultiplier::after_correlation() this->outbox.push({}); } -template<> -inline -void MascotMultiplier>::multiplyForBits() -{ - multiplyForTriples(); -} - -template<> -inline -void MascotMultiplier>::multiplyForBits() +template +void MascotMultiplier::multiplyForBits() { - multiplyForGf2nBits(); + multiplyForBits(typename T::clear()); } -template<> -inline -void MascotMultiplier>::multiplyForBits() +template +template +void MascotMultiplier::multiplyForBits(gfp_) { - multiplyForGf2nBits(); + throw runtime_error("should not be called"); } template -void MascotMultiplier::multiplyForGf2nBits() +template +void MascotMultiplier::multiplyForBits(U) { auto& macs = this->macs; auto& outbox = this->outbox; @@ -540,9 +533,3 @@ void OTMultiplier::multiplyForBits() { throw runtime_error("bit generation not implemented in this case"); } - -template -void MascotMultiplier::multiplyForBits() -{ - throw runtime_error("bit generation not implemented in this case"); -} diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index 7cabdf5af..f2bdabe23 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -20,7 +20,7 @@ void OTTripleSetup::setup() baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs; } gettimeofday(&baseOTend, NULL); -#ifdef VERBOSE +#ifdef VERBOSE_BASEOT double basetime = timeval_diff(&baseOTstart, &baseOTend); cout << "\t\tBaseTime: " << basetime/1000000 << endl << flush; #endif diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index a30b72bdc..2c40c79e7 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -38,7 +38,7 @@ class OTTripleSetup baseSenderInputs.resize(nparties - 1); baseReceiverOutputs.resize(nparties - 1); -#ifdef VERBOSE +#ifdef VERBOSE_BASEOT if (real_OTs) cout << "Doing real base OTs\n"; else diff --git a/OT/Triple.hpp b/OT/Triple.hpp index b18d7e1f2..318082e9e 100644 --- a/OT/Triple.hpp +++ b/OT/Triple.hpp @@ -140,7 +140,7 @@ class ShareTriple : public Triple int repeat = this->repeat(l, generator.machine.check); for (int j = 0; j < repeat; j++) { - T value = triple.byIndex(l,j); + typename U::share_type value = triple.byIndex(l,j); typename U::mac_type mac; mac.mul(value, generator.get_mac_key()); for (int i = 0; i < generator.nparties-1; i++) diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index 1ca3947f1..2b95ff4d8 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -9,30 +9,11 @@ #include "Math/gf2n.h" #include "Math/gfp.h" #include "Math/Z2k.h" -#include "Tools/OfflineMachineBase.h" #include "OT/OTTripleSetup.h" +#include "OT/MascotParams.h" class GeneratorThread; -class MascotParams : virtual public OfflineParams -{ -public: - string prep_data_dir; - bool generateMACs; - bool amplify; - bool check; - bool correlation_check; - bool generateBits; - bool use_extension; - bool fewer_rounds; - bool fiat_shamir; - struct timeval start, stop; - - MascotParams(); - - void set_passive(); -}; - class TripleMachine : public OfflineMachineBase, public MascotParams { Names N[2]; diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index db8627dfd..25b8beec7 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -42,6 +42,9 @@ void BaseMachine::load_schedule(string progname) inpf >> nthreads; inpf >> nprogs; + if (inpf.fail()) + throw file_error("Error reading " + fname); + #ifdef DEBUG_FILES cerr << "Number of threads I will run in parallel = " << nthreads << endl; cerr << "Number of program sequences I need to load = " << nprogs << endl; diff --git a/Processor/Input.h b/Processor/Input.h index 266943b7f..ba10becbb 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -58,7 +58,7 @@ class InputBase virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; T finalize(int player, int n_bits = -1); - void raw_input(SubProcessor& proc, const vector& args); + void raw_input(SubProcessor& proc, const vector& args, int size); }; template diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 319dc9bc9..34ac2f644 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -131,7 +131,8 @@ void InputBase::exchange() } template -void InputBase::raw_input(SubProcessor& proc, const vector& args) +void InputBase::raw_input(SubProcessor& proc, const vector& args, + int size) { auto& P = proc.P; reset_all(P); @@ -142,20 +143,24 @@ void InputBase::raw_input(SubProcessor& proc, const vector& args) it++; if (player == P.my_num()) { - clear t; - try - { - this->buffer.input(t); - } - catch (not_enough_to_buffer& e) + for (int i = 0; i < size; i++) { - throw runtime_error("Insufficient input data to buffer"); + clear t; + try + { + this->buffer.input(t); + } + catch (not_enough_to_buffer& e) + { + throw runtime_error("Insufficient input data to buffer"); + } + add_mine(t); } - add_mine(t); } else { - add_other(player); + for (int i = 0; i < size; i++) + add_other(player); } } @@ -166,7 +171,9 @@ void InputBase::raw_input(SubProcessor& proc, const vector& args) for (auto it = args.begin(); it != args.end();) { int player = *it++; - proc.get_S_ref(*it++) = finalize(player); + int base = *it++; + for (int i = 0; i < size; i++) + proc.get_S_ref(base + i) = finalize(player); } } diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 167bd1332..f81cd87fc 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -99,6 +99,7 @@ enum TRUNC_PR = 0xA9, MATMULS = 0xAA, MATMULSM = 0xAB, + CONV2DS = 0xAC, // Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 163f1c6f1..5e9b46605 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -101,6 +101,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case SUBINT: case MULINT: case DIVINT: + case CONDPRINTPLAIN: r[0]=get_int(s); r[1]=get_int(s); r[2]=get_int(s); @@ -141,7 +142,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case GPROTECTMEMS: case GPROTECTMEMC: case PROTECTMEMINT: - case CONDPRINTPLAIN: case DABIT: case SHUFFLE: r[0]=get_int(s); @@ -223,6 +223,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RUN_TAPE: case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: + case STOPPRIVATEOUTPUT: + case GSTOPPRIVATEOUTPUT: case DIGESTC: r[0]=get_int(s); r[1]=get_int(s); @@ -253,8 +255,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRINTREGB: case GPRINTREG: case LDINT: - case STOPPRIVATEOUTPUT: - case GSTOPPRIVATEOUTPUT: case INPUTMASK: case GINPUTMASK: case ACCEPTCLIENTCONNECTION: @@ -321,6 +321,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_ints(r, s, 3); get_vector(9, start, s); break; + case CONV2DS: + get_ints(r, s, 3); + get_vector(11, start, s); + break; // read from file, input is opcode num_args, // start_file_posn (read), end_file_posn(write) var1, var2, ... @@ -489,7 +493,7 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) usage.edabits[{r[0], r[1]}] = n; return int(n) >= 0; case USE_PREP: - usage.extended[gfp::field_type()][r] = n; + usage.extended[DATA_INT][r] = n; return int(n) >= 0; case GUSE_PREP: usage.extended[gf2n::field_type()][r] = n; @@ -601,6 +605,8 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case MATMULS: case MATMULSM: return r[0] + start[0] * start[2]; + case CONV2DS: + return r[0] + start[0] * start[1]; case LDMSD: case LDMSDI: skip = 3; @@ -779,6 +785,10 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(n + i)); return; + case STMSI: + for (int i = 0; i < size; i++) + Proc.machine.Mp.write_S(Proc.read_Ci(r[1] + i), Proc.read_Sp(r[0] + i), Proc.PC); + return; case STMS: for (int i = 0; i < size; i++) Proc.machine.Mp.write_S(n + i, Proc.read_Sp(r[0] + i), Proc.PC); @@ -1119,7 +1129,7 @@ inline void Instruction::execute(Processor& Proc) const break; case LEGENDREC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); - Proc.temp.aa = mpz_legendre(Proc.temp.aa.get_mpz_t(), gfp::pr().get_mpz_t()); + Proc.temp.aa = mpz_legendre(Proc.temp.aa.get_mpz_t(), sint::clear::pr().get_mpz_t()); to_gfp(Proc.temp.ansp, Proc.temp.aa); Proc.write_Cp(r[0], Proc.temp.ansp); break; @@ -1294,11 +1304,11 @@ inline void Instruction::execute(Processor& Proc) const sint::Input::input_mixed(Proc.Procp, start, size, true); return; case RAWINPUT: - Proc.Procp.input.raw_input(Proc.Procp, start); - break; + Proc.Procp.input.raw_input(Proc.Procp, start, size); + return; case GRAWINPUT: - Proc.Proc2.input.raw_input(Proc.Proc2, start); - break; + Proc.Proc2.input.raw_input(Proc.Proc2, start, size); + return; case ANDC: Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); break; @@ -1441,6 +1451,9 @@ inline void Instruction::execute(Processor& Proc) const Proc.Procp.matmulsm(Proc.machine.Mp.MS, *this, Proc.read_Ci(r[1]), Proc.read_Ci(r[2])); return; + case CONV2DS: + Proc.Procp.conv2ds(*this); + return; case TRUNC_PR: Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); return; @@ -1552,7 +1565,14 @@ inline void Instruction::execute(Processor& Proc) const break; case CONDPRINTPLAIN: if (not Proc.read_Cp(r[0]).is_zero()) - Proc.out << Proc.read_Cp(r[1]) << flush; + { + auto v = Proc.read_Cp(r[1]); + auto p = Proc.read_Cp(r[2]); + if (p.is_zero()) + Proc.out << v << flush; + else + Proc.out << bigint::get_float(v, p, {}, {}) << flush; + } break; case GPRINTREGPLAIN: { @@ -1571,9 +1591,7 @@ inline void Instruction::execute(Processor& Proc) const typename sint::clear p = Proc.read_Cp(start[1]); typename sint::clear z = Proc.read_Cp(start[2]); typename sint::clear s = Proc.read_Cp(start[3]); - // MPIR can't handle more precision in exponent - long exp = Integer(p, 31).get(); - bigint::output_float(Proc.out, bigint::get_float(v, exp, z, s), nan); + bigint::output_float(Proc.out, bigint::get_float(v, p, z, s), nan); } break; case PRINTFLOATPREC: @@ -1586,7 +1604,13 @@ inline void Instruction::execute(Processor& Proc) const break; case CONDPRINTSTR: if (not Proc.read_Cp(r[0]).is_zero()) - Proc.out << string((char*)&n,sizeof(n)) << flush; + { + string str = {(char*)&n, sizeof(n)}; + size_t n = str.find('\0'); + if (n < 4) + str.erase(n); + Proc.out << str << flush; + } break; case PRINTCHR: { @@ -1733,10 +1757,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.privateOutput2.start(n,r[0],r[1]); break; case STOPPRIVATEOUTPUT: - Proc.privateOutputp.stop(n,r[0]); + Proc.privateOutputp.stop(n,r[0],r[1]); break; case GSTOPPRIVATEOUTPUT: - Proc.privateOutput2.stop(n,r[0]); + Proc.privateOutput2.stop(n,r[0],r[1]); break; case PREP: Procp.DataF.get(Proc.Procp.get_S(), r, start, size); diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index b9624c150..bd602e4f6 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -148,7 +148,7 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, // central preprocessing auto usage = progs[tape_number].get_offline_data_used(); - if (sint::expensive and prep != 0) + if (sint::expensive and prep != 0 and OnlineOptions::singleton.bucket_size == 3) { try { @@ -185,7 +185,7 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, } typedef typename sint::bit_type bit_type; - if (bit_type::expensive_triples and bit_prep) + if (bit_type::expensive_triples and bit_prep and OnlineOptions::singleton.bucket_size == 3) { try { diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index fa55a5ce3..6f61dabb9 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -270,9 +270,12 @@ void thread_info::Sub_Main_Func() wait_timer.stop(); #ifdef VERBOSE - cerr << num << " : MAC Checking" << endl; - cerr << "\tMC2.number=" << MC2->number() << endl; - cerr << "\tMCp.number=" << MCp->number() << endl; + if (MC2->number() + MCp->number() > 0) + cerr << num << " : MAC Checking" << endl; + if (MC2->number()) + cerr << "\tMC2.number=" << MC2->number() << endl; + if (MCp->number()) + cerr << "\tMCp.number=" << MCp->number() << endl; cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl; cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 39bda1e84..410ea67c3 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -18,7 +18,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) batch_size = 10000; memtype = "empty"; direct = false; - fake_batch = false; + bucket_size = 3; } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, @@ -107,9 +107,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Where to obtain memory, new|old|empty (default: empty)\n\t" - "new: copy from Player-Memory-P file\n\t" - "old: reuse previous memory in Memory-P\n\t" + "Where to obtain memory, old|empty (default: empty)\n\t" + "old: reuse previous memory in Memory--P\n\t" "empty: create new empty memory", // Help description. "-m", // Flag token. "--memory" // Flag token. @@ -124,13 +123,13 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--direct" // Flag token. ); opt.add( - "", // Default. + "3", // Default. 0, // Required? - 0, // Number of args expected. + 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Use insecurely small batches for testing", // Help description. - "-fake-batch", // Flag token. - "--fake-batch" // Flag token. + "Batch size for sacrifice (3-5, default: 3)", // Help description. + "-B", // Flag token. + "--bucket-size" // Flag token. ); opt.parse(argc, argv); @@ -152,14 +151,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, opt.get("--memory")->getString(memtype); direct = opt.isSet("--direct"); - bool fb = opt.isSet("--fake-batch"); -#ifdef INSECURE - fake_batch = fb; -#else - if (fb) - throw runtime_error("option only supported " - "when compiled with -DINSECURE"); -#endif + opt.get("--bucket-size")->getInt(bucket_size); opt.resetArgs(); } diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 880f1fe83..68055bb51 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -23,7 +23,7 @@ class OnlineOptions int batch_size; std::string memtype; bool direct; - bool fake_batch; + int bucket_size; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h index 4cd47d0db..a0ac2a50a 100644 --- a/Processor/PrivateOutput.h +++ b/Processor/PrivateOutput.h @@ -23,7 +23,7 @@ class PrivateOutput PrivateOutput(SubProcessor& proc) : proc(proc) { }; void start(int player, int target, int source); - void stop(int player, int source); + void stop(int player, int dest, int source); }; #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index c78358bf9..700363440 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -9,6 +9,7 @@ template void PrivateOutput::start(int player, int target, int source) { + assert (player < proc.P.num_players()); open_type mask; proc.DataF.get_input(proc.get_S_ref(target), mask, player); proc.get_S_ref(target) += proc.get_S_ref(source); @@ -18,11 +19,11 @@ void PrivateOutput::start(int player, int target, int source) } template -void PrivateOutput::stop(int player, int source) +void PrivateOutput::stop(int player, int dest, int source) { if (player == proc.P.my_num() and proc.Proc) { - open_type value; + auto& value = proc.get_C_ref(dest); value.sub(proc.get_C_ref(source), masks.front()); value.output(proc.Proc->private_output, false); masks.pop_front(); diff --git a/Processor/Processor.h b/Processor/Processor.h index ec474c0a8..ae276f769 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -68,6 +68,7 @@ class SubProcessor int b); void matmulsm(const CheckVector& source, const Instruction& instruction, int a, int b); + void conv2ds(const Instruction& instruction); vector& get_S() { diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 6803ac26e..47d803d3b 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -523,6 +523,63 @@ void SubProcessor::matmulsm(const CheckVector& source, *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); } +template +void SubProcessor::conv2ds(const Instruction& instruction) +{ + protocol.init_dotprod(this); + auto& args = instruction.get_start(); + int output_h = args[0], output_w = args[1]; + int inputs_h = args[2], inputs_w = args[3]; + int weights_h = args[4], weights_w = args[5]; + int stride_h = args[6], stride_w = args[7]; + int n_channels_in = args[8]; + int padding_h = args[9]; + int padding_w = args[10]; + int r0 = instruction.get_r(0); + int r1 = instruction.get_r(1); + int r2 = instruction.get_r(2); + int lengths[output_h][output_w]; + memset(lengths, 0, sizeof(lengths)); + + for (int out_y = 0; out_y < output_h; out_y++) + for (int out_x = 0; out_x < output_w; out_x++) + { + int in_x_origin = (out_x * stride_w) - padding_w; + int in_y_origin = (out_y * stride_h) - padding_h; + + for (int filter_y = 0; filter_y < weights_h; filter_y++) + { + int in_y = in_y_origin + filter_y; + if ((0 <= in_y) and (in_y < inputs_h)) + for (int filter_x = 0; filter_x < weights_w; filter_x++) + { + int in_x = in_x_origin + filter_x; + if ((0 <= in_x) and (in_x < inputs_w)) + { + for (int in_c = 0; in_c < n_channels_in; in_c++) + protocol.prepare_dotprod( + S[r1 + (in_y * inputs_w + in_x) * + n_channels_in + in_c], + S[r2 + (filter_y * weights_w + filter_x) * + n_channels_in + in_c]); + lengths[out_y][out_x] += n_channels_in; + } + } + } + + protocol.next_dotprod(); + } + + protocol.exchange(); + + for (int out_y = 0; out_y < output_h; out_y++) + for (int out_x = 0; out_x < output_w; out_x++) + { + S[r0 + out_y * output_w + out_x] = protocol.finalize_dotprod( + lengths[out_y][out_x]); + } +} + template ostream& operator<<(ostream& s,const Processor& P) { diff --git a/Programs/Source/benchmark_mobilenet.mpc b/Programs/Source/benchmark_mobilenet.mpc index 03f8de7f6..896350897 100644 --- a/Programs/Source/benchmark_mobilenet.mpc +++ b/Programs/Source/benchmark_mobilenet.mpc @@ -35,6 +35,18 @@ program.use_edabit_for = lambda *args: args in edabits ml.QuantBase.n_threads = 8 +if 'conv2ds' in program.args: + ml.ConvBase.use_conv2ds = True + +if 'split' in program.args: + program.use_split(3) + +if 'split2' in program.args: + program.use_split(2) + +if 'cisc' in program.args: + program.options.cisc = True + if len(program.args) > 3: ml.QuantBase.n_threads = int(program.args[3]) @@ -585,7 +597,7 @@ if network == 'v1_1.0_224': QuantSoftmax((1, 1001), (1, 1001)) ] -QuantConvBase.init_temp(layers) +ConvBase.init_temp(layers) for layer in layers: layer.input_from(0) diff --git a/Programs/Source/benchmark_net.mpc b/Programs/Source/benchmark_net.mpc new file mode 100644 index 000000000..cf3931e73 --- /dev/null +++ b/Programs/Source/benchmark_net.mpc @@ -0,0 +1,80 @@ +import ml +import util +import math + +if 'trunc_pr' in program.args: + program.use_trunc_pr = True +if 'split' in program.args: + program.use_split(3) + +program.options.cisc = True + +try: + n_threads = int(program.args[2]) +except: + n_threads = None + +ml.Layer.n_threads = n_threads +ml.FixConv2d.use_conv2ds = True + +if 'full' in program.args: + sfix.set_precision(12, 63) +else: + sfix.set_precision(12, 31) + +if program.args[1] == 'A': + layers = [ + ml.Dense(1, 784, 128), + ml.Square([1, 128]), + ml.Dense(1, 128, 128), + ml.Square([1, 128]), + ml.Dense(1, 128, 10), + ml.Argmax((1, 10)), + ] +elif program.args[1] == 'B': + layers = [ + ml.FixConv2d([1, 28, 28, 1], (16, 5, 5, 1), (16,), [1, 24, 24, 16], (1, 1)), + ml.MaxPool([1, 24, 24, 16]), + ml.Relu([1, 12, 12, 16]), + ml.FixConv2d([1, 12, 12, 16], (16, 5, 5, 16), (16,), [1, 8, 8, 16], (1, 1)), + ml.MaxPool([1, 8, 8, 16]), + ml.Relu([1, 4, 4, 16]), + ml.Dense(1, 256, 100), + ml.Relu([1, 100]), + ml.Dense(1, 100, 10), + ml.Argmax((1, 10)), + ] +elif program.args[1] == 'C': + layers = [ + ml.FixConv2d([1, 28, 28, 1], (20, 5, 5, 1), (20,), [1, 24, 24, 20], (1, 1)), + ml.MaxPool([1, 24, 24, 20]), + ml.Relu([1, 12, 12, 20]), + ml.FixConv2d([1, 12, 12, 20], (50, 5, 5, 20), (50,), [1, 8, 8, 50], (1, 1)), + ml.MaxPool([1, 8, 8, 50]), + ml.Relu([1, 4, 4, 50]), + ml.Dense(1, 800, 500), + ml.Relu([1, 500]), + ml.Dense(1, 500, 10), + ml.Argmax((1, 10)), + ] +elif program.args[1] == 'D': + layers = [ + ml.FixConv2d([1, 28, 28, 1], (5, 5, 5, 1), (5,), [1, 14, 14, 5], (2, 2)), + ml.Relu([1, 14, 14, 5]), + ml.Dense(1, 980, 100), + ml.Relu([1, 100]), + ml.Dense(1, 100, 10), + ml.Argmax((1, 10)), + ] +else: + raise Exception('unknown network: ' + program.args[1]) + +opt = ml.Optimizer() +opt.layers = layers +for layer in layers: + layer.input_from(0, raw='raw' in program.args) +layers[0].X.input_from(1) +start_timer(1) +opt.forward(1) +stop_timer(1) +print_ln('guess %s', layers[-1].Y[0].reveal()) diff --git a/Programs/Source/tf.mpc b/Programs/Source/tf.mpc new file mode 100644 index 000000000..e285dd917 --- /dev/null +++ b/Programs/Source/tf.mpc @@ -0,0 +1,37 @@ +import ml +import util +import math +import subprocess + +if 'trunc_pr' in program.args: + program.use_trunc_pr = True +if 'split' in program.args: + program.use_split(3) + +program.options.cisc = True + +try: + n_threads = int(program.args[2]) +except: + n_threads = None +ml.Layer.n_threads = n_threads +ml.FixConv2d.use_conv2ds = True + +sfix.set_precision(12, 31) + +layers = [] +named = {} + +exec(subprocess.check_output(['Scripts/process-tf.py', program.args[1]])) + +opt = ml.Optimizer() +opt.set_layers_with_inputs(layers) +layers[0].X.input_from(0) +for layer in layers: + layer.input_from(0, raw='raw' in program.args) + +start_timer(1) +opt.forward(1, keep_intermediate=False) +stop_timer(1) +if isinstance(layers[-1].Y, Array): + print_ln('guess %s', layers[-1].Y[0].reveal()) diff --git a/Protocols/BrainPrep.hpp b/Protocols/BrainPrep.hpp index 4a7c2ef96..48244496a 100644 --- a/Protocols/BrainPrep.hpp +++ b/Protocols/BrainPrep.hpp @@ -7,6 +7,7 @@ #include "Processor/Processor.h" #include "Protocols/MaliciousRepMC.h" #include "Tools/Subroutines.h" +#include "Math/gfp.h" template class ZProtocol; @@ -80,6 +81,8 @@ class ZProtocol : public Replicated>> tmp.randomize(G); input.add_mine(tmp); } + for (int i = 0; i < this->P.num_players(); i++) + input.add_other(i); input.exchange(); for (int i = 0; i < buffer_size; i++) { diff --git a/Protocols/MamaPrep.h b/Protocols/MamaPrep.h new file mode 100644 index 000000000..a1e774fe7 --- /dev/null +++ b/Protocols/MamaPrep.h @@ -0,0 +1,24 @@ +/* + * MamaPrep.h + * + */ + +#ifndef PROTOCOLS_MAMAPREP_H_ +#define PROTOCOLS_MAMAPREP_H_ + +#include "MascotPrep.h" + +template +class MamaPrep : public OTPrep +{ +public: + static void basic_setup(Player&) {}; + static void teardown() {}; + + MamaPrep(SubProcessor* proc, DataPositions& usage); + + void buffer_triples(); + +}; + +#endif /* PROTOCOLS_MAMAPREP_H_ */ diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp new file mode 100644 index 000000000..1aebfe99a --- /dev/null +++ b/Protocols/MamaPrep.hpp @@ -0,0 +1,109 @@ +/* + * MamPrep.cpp + * + */ + +#include "MamaPrep.h" + +#include "SemiMC.hpp" + +template +MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage), OTPrep(proc, usage) +{ + this->params.amplify = true; + this->params.generateMACs = true; + this->params.check = false; +} + +template +void MamaPrep::buffer_triples() +{ + int mac_security = T::N_MACS * T::clear::length(); + + if (mac_security < 40) + { + cerr << T::N_MACS << " MACs are not enough for 40-bit security with " + << T::clear::length() << "-bit primes." << endl; + cerr << "Compile with -DN_MAMA_MACS=" + << DIV_CEIL(40, T::clear::length()) + << " or remove this check in " << __FILE__ << endl; + exit(1); + } + + auto& triple_generator = this->triple_generator; + assert(triple_generator != 0); + assert(this->proc != 0); + this->params.generateBits = false; + vector> triples; + ShuffleSacrifice sacrifice; + size_t required = OnlineOptions::singleton.batch_size; + + // prefer shuffling if not loosing much security and bucket size is smaller + bool use_shuffling = mac_security <= 42 + and OnlineOptions::singleton.bucket_size < T::N_MACS; + if (use_shuffling) + required = sacrifice.minimum_n_inputs(); + + while (triples.size() < required) + { + triple_generator->generateTriples(); + triple_generator->unlock(); + for (auto& x : triple_generator->uncheckedTriples) + { + triples.push_back({}); + for (int k = 0; k < 3; k++) + triples.back()[k] = x.byIndex(k, 0); + } + cerr << "Got " << triple_generator->uncheckedTriples.size() + << " triples" << endl; + } + + if (use_shuffling) + sacrifice.triple_sacrifice(triples, triples, this->proc->P, + this->proc->MC); + else + { + auto& proc = this->proc; + auto& P = proc->P; + const unsigned n_sacrifice = T::N_MACS - 1; + vector, n_sacrifice>> check_triples; + while (n_sacrifice <= triples.size()) + { + check_triples.push_back({}); + for (unsigned i = 0; i < n_sacrifice; i++) + { + check_triples.back()[i] = triples.back(); + triples.pop_back(); + } + } + auto t = GlobalPRNG(P).get(); + vector masked; + PointerVector opened; + for (auto& x : check_triples) + for (unsigned i = 1; i < n_sacrifice; i++) + { + masked.push_back(t * x[0][0] - x[i][0]); + masked.push_back(x[0][1] - x[i][1]); + } + proc->MC.POpen(opened, masked, P); + vector checks; + for (auto& x : check_triples) + { + triples.push_back(x[0]); + for (unsigned i = 1; i < n_sacrifice; i++) + { + auto rho = opened.next(); + auto sigma = opened.next(); + checks.push_back( + t * x[0][2] - x[i][2] - x[i][1] * rho - x[i][0] * sigma + - T::constant(sigma * rho, P.my_num(), + proc->MC.get_alphai())); + } + } + proc->MC.CheckFor(0, checks, P); + } + + this->triples = triples; +} diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h new file mode 100644 index 000000000..5eadcbf99 --- /dev/null +++ b/Protocols/MamaShare.h @@ -0,0 +1,73 @@ +/* + * MamaShare.h + * + */ + +#ifndef PROTOCOLS_MAMASHARE_H_ +#define PROTOCOLS_MAMASHARE_H_ + +#include "Share.h" +#include "Math/gfp.h" +#include "Math/FixedVec.h" +#include "OT/MamaRectangle.h" + +template class MamaPrep; +template class MamaMultiplier; +template class SimpleMascotTripleGenerator; + +template +class MamaShare : public Share_, FixedVec, N>> +{ + typedef MamaShare This; + +public: + typedef FixedVec, N> mac_key_type; + typedef Share_, mac_key_type> super; + + typedef Beaver Protocol; + typedef MAC_Check_ MAC_Check; + typedef Direct_MAC_Check Direct_MC; + typedef ::Input Input; + typedef ::PrivateOutput PrivateOutput; + + typedef MamaPrep LivePrep; + typedef MamaShare prep_type; + typedef SimpleMascotTripleGenerator TripleGenerator; + typedef MascotMultiplier Multiplier; + typedef FixedVec sacri_type; + typedef This input_type; + typedef MamaRectangle Square; + typedef typename T::Square Rectangle; + + static const int N_MACS = N; + + static const bool expensive = true; + + static string type_string() + { + return "Mama" + to_string(N); + } + + static void read_or_generate_mac_key(string, Names&, mac_key_type& key) + { + SeededPRNG G; + key.randomize(G); + } + + MamaShare() + { + } + + MamaShare(const super& other) : + super(other) + { + } + + template + MamaShare(const MamaShare& other) : + super(other.get_share(), other.get_mac()) + { + } +}; + +#endif /* PROTOCOLS_MAMASHARE_H_ */ diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index a4bdf70db..d7769917d 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -8,7 +8,7 @@ #include "ReplicatedPrep.h" #include "RandomPrep.h" -#include "OT/TripleMachine.h" +#include "OT/MascotParams.h" template class OTPrep : public virtual RingPrep diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 3e6c5959c..85adf8b8e 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -8,14 +8,15 @@ #include "Protocols/MaliciousRep3Share.h" #include "Protocols/MalRepRingShare.h" +#include "Protocols/Rep3Share2k.h" template class MalRepRingPrepWithBits; template class PostSacrifice; template -class PostSacriRepRingShare : public MaliciousRep3Share> +class PostSacriRepRingShare : public Rep3Share2 { - typedef MaliciousRep3Share> super; + typedef Rep3Share2 super; public: static const int BIT_LENGTH = K; @@ -32,6 +33,8 @@ class PostSacriRepRingShare : public MaliciousRep3Share> typedef ::PrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; + typedef GC::MaliciousRepSecret bit_type; + const static bool expensive = true; static string type_short() diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 431785fc7..4f0760ebe 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -38,7 +38,8 @@ class Rep3Share2 : public Rep3Share> FixedVec::operator=(other); } - static void split(vector& dest, const vector& regs, + template + static void split(vector& dest, const vector& regs, int n_bits, const Rep3Share2* source, int n_inputs, Player& P) { int my_num = P.my_num(); @@ -75,7 +76,7 @@ class Rep3Share2 : public Rep3Share> break; case 2: { - ReplicatedInput input(P); + ReplicatedInput input(P); input.reset_all(P); if (P.my_num() == 0) { @@ -86,6 +87,10 @@ class Rep3Share2 : public Rep3Share> for (int j = 0; j < n_bits; j++) input.add_mine(square.rows[j], m); } + else + for (int j = 0; j < n_bits; j++) + input.add_other(0); + input.exchange(); for (int j = 0; j < n_bits; j++) dest.at(regs.at(2 * j) + k) = input.finalize(0, m); diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 54119ef94..f38a9588b 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -41,6 +41,7 @@ class ReplicatedInput : public PrepLessInput vector os; SeededPRNG secure_prng; ReplicatedBase protocol; + vector expect; public: ReplicatedInput(SubProcessor& proc) : @@ -65,6 +66,7 @@ class ReplicatedInput : public PrepLessInput PrepLessInput(proc), proc(proc), P(P), protocol(P) { assert(T::length == 2); + expect.resize(P.num_players()); } void reset(int player); diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 606270113..3daedf0f6 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -14,6 +14,7 @@ template void ReplicatedInput::reset(int player) { + assert(P.num_players() == 3); if (player == P.my_num()) { this->shares.clear(); @@ -22,6 +23,7 @@ void ReplicatedInput::reset(int player) for (auto& o : os) o.reset_write_head(); } + expect[player] = false; } template @@ -39,7 +41,7 @@ inline void ReplicatedInput::add_mine(const typename T::open_type& input, int template void ReplicatedInput::add_other(int player) { - (void) player; + expect[player] = true; } template @@ -51,17 +53,25 @@ void ReplicatedInput::send_mine() template void ReplicatedInput::exchange() { - for (int i = 1; i < P.num_players(); i++) - { - P.pass_around(os[i - 1], InputBase::os[P.get_player(-i)], i); - } + bool receive = expect[P.get_player(1)]; + bool send = not os[1].empty(); + auto& dest = InputBase::os[P.get_player(1)]; + if (send) + if (receive) + P.pass_around(os[1], dest, -1); + else + P.send_to(P.get_player(-1), os[1], true); + else + if (receive) + P.receive_player(P.get_player(1), dest, true); } template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o, int n_bits) { - if (P.get_offset(player) == 1) + int offset = player - P.my_num(); + if (offset == 1 or offset == -2) { typename T::value_type t; t.unpack(o, n_bits); diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 02eeb7f70..bbcf73e58 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -67,7 +67,8 @@ void ReplicatedMC::finalize(vector& values, template typename T::open_type ReplicatedMC::finalize_open() { - return this->secrets.next().sum() + o.get(); + auto a = this->secrets.next().sum(); + return a + o.get(); } #endif diff --git a/Protocols/ReplicatedMachine.hpp b/Protocols/ReplicatedMachine.hpp index 722886972..595a449d7 100644 --- a/Protocols/ReplicatedMachine.hpp +++ b/Protocols/ReplicatedMachine.hpp @@ -17,7 +17,7 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, { (void) name; - OnlineOptions online_opts(opt, argc, argv, 1000, true, T::clear::invertible); + OnlineOptions online_opts(opt, argc, argv, 10000, true, T::clear::invertible); OnlineOptions::singleton = online_opts; NetworkOptionsWithNumber network_opts(opt, argc, argv, nplayers, false); opt.add( diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index a40b54435..e5c66e886 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -9,10 +9,9 @@ #include "Networking/Player.h" #include "Processor/Data_Files.h" #include "Processor/OnlineOptions.h" -#include "Processor/Machine.h" -#include "Protocols/Rep3Share.h" +#include "Processor/ThreadQueues.h" #include "Protocols/ShuffleSacrifice.h" -#include "Math/gfp.h" +#include "Protocols/MAC_Check_Base.h" #include "edabit.h" #include @@ -133,8 +132,6 @@ class RingPrep : public virtual BitPrep typedef typename T::bit_type::part_type BT; protected: - size_t sent; - void buffer_dabits_without_check(vector>& dabits, int buffer_size = -1, ThreadQueues* queues = 0); void buffer_edabits_without_check(int n_bits, vector& sums, @@ -173,8 +170,6 @@ class RingPrep : public virtual BitPrep void buffer_personal_edabits_without_check(int n_bits, vector& sums, vector >& bits, SubProcessor& proc, int input_player, int begin, int end); - - virtual size_t data_sent() { return sent; } }; template diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 7de370975..b35d8cec4 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -11,8 +11,8 @@ #include "Spdz2kPrep.h" #include "GC/BitAdder.h" -#include "Math/gfp.h" #include "Processor/OnlineOptions.h" +#include "Protocols/Rep3Share.h" #include "MaliciousRingPrep.hpp" #include "ShuffleSacrifice.hpp" @@ -63,7 +63,7 @@ BitPrep::BitPrep(SubProcessor* proc, DataPositions& usage) : template RingPrep::RingPrep(SubProcessor* proc, DataPositions& usage) : - BufferPrep(usage), BitPrep(proc, usage), sent(0) + BufferPrep(usage), BitPrep(proc, usage) { } @@ -258,22 +258,22 @@ void XOR(vector& res, vector& x, vector& y, res[i] = x[i] + y[i] - prot.finalize_mul() * two; } -template class T> -void buffer_bits_from_squares(RingPrep>& prep) +template +void buffer_bits_from_squares(RingPrep& prep) { auto proc = prep.get_proc(); assert(proc != 0); auto& bits = prep.get_bits(); - vector, 2>> squares(prep.buffer_size); - vector> s; + vector> squares(prep.buffer_size); + vector s; for (int i = 0; i < prep.buffer_size; i++) { prep.get_two(DATA_SQUARE, squares[i][0], squares[i][1]); s.push_back(squares[i][1]); } - vector open; + vector open; proc->MC.POpen(open, s, proc->P); - auto one = T::constant(1, proc->P.my_num(), proc->MC.get_alphai()); + auto one = T::constant(1, proc->P.my_num(), proc->MC.get_alphai()); for (size_t i = 0; i < s.size(); i++) if (open[i] != 0) bits.push_back((squares[i][0] / open[i].sqrRoot() + one) / 2); @@ -282,15 +282,15 @@ void buffer_bits_from_squares(RingPrep>& prep) throw runtime_error("squares were all zero"); } -template class T> -void buffer_bits_spec(ReplicatedPrep>& prep, vector>& bits, - typename T::Protocol& prot) +template class T, int X, int L> +void buffer_bits_spec(ReplicatedPrep>>& prep, vector>>& bits, + typename T>::Protocol& prot) { (void) bits, (void) prot; if (prot.get_n_relevant_players() > 10) buffer_bits_from_squares(prep); else - prep.ReplicatedRingPrep>::buffer_bits(); + prep.ReplicatedRingPrep>>::buffer_bits(); } template diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index 900a087f1..9fd997b7e 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -1,10 +1,5 @@ #include "Share.h" -#include "Math/gfp.h" -#include "Math/gf2n.h" -#include "Math/Z2k.h" -#include "Math/FixedVec.h" -#include "Math/Integer.h" template diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index f288f17c8..f54afef34 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -25,29 +25,31 @@ class ShuffleSacrifice { typedef typename T::bit_type::part_type BT; - static const int B = 3; + const int B; public: - static const int C = 3; + const int C; - static int minimum_n_inputs(int n_outputs = 1) + ShuffleSacrifice(); + + int minimum_n_inputs(int n_outputs = 1) { return max(n_outputs, minimum_n_outputs()) * B + C; } - static int minimum_n_inputs_with_combining() + int minimum_n_inputs_with_combining() { return minimum_n_inputs(B * minimum_n_outputs()); } - static int minimum_n_outputs() + int minimum_n_outputs() { -#ifdef INSECURE - if (OnlineOptions::singleton.fake_batch) - { - cout << "FAKE FAKE FAKE" << endl; - return 1 << 10; - } -#endif - return 1 << 20; + if (B == 3) + return 1 << 20; + else if (B == 4) + return 10368; + else if (B == 5) + return 1024; + else + throw runtime_error("not supported: B = " + to_string(B)); } template diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 9ca9fe75c..86bac99af 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -13,6 +13,12 @@ #include "MalRepRingPrep.hpp" #include "LimitedPrep.hpp" +template +ShuffleSacrifice::ShuffleSacrifice() : + B(OnlineOptions::singleton.bucket_size), C(this->B) +{ +} + template inline void ShuffleSacrifice::triple_combine(vector >& triples, vector >& to_combine, Player& P, diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index a39e61046..0dc8ca795 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -4,8 +4,6 @@ #include "Protocols/fake-stuff.h" #include "Processor/Data_Files.h" #include "Tools/benchmarking.h" -#include "Math/gfp.h" -#include "Math/gf2n.h" #include "Math/Setup.h" #include "Protocols/ShamirInput.hpp" diff --git a/README.md b/README.md index 9aa916275..fd0a08e68 100644 --- a/README.md +++ b/README.md @@ -297,6 +297,35 @@ Player-Data Programs $ ../spdz/Scripts/run-online.sh test ``` +### TensorFlow inference + +MP-SPDZ supports inference with selected TensorFlow graphs, in +particular DenseNet, ResNet, and SqueezeNet as used in +[CrypTFlow](https://github.com/mpc-msri/EzPC). For example, you can +run SqueezeNet inference for ImageNet as follows: + +``` +git clone https://github.com/mkskeller/EzPC +cd EzPC/Athos/Networks/SqueezeNetImgNet +axel -a -n 5 -c --output ./PreTrainedModel https://github.com/avoroshilov/tf-squeezenet/raw/master/sqz_full.mat +pip3 install scipy==1.1.0 +python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --saveTFMetadata True +python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --scalingFac 12 --saveImgAndWtData True +cd ../../../.. +Scripts/fixed-rep-to-float.py EzPC/Athos/Networks/SqueezeNetImgNet/SqNetImgNet_img_input.inp +./compile.py -R 64 tf EzPC/Athos/Networks/SqueezeNetImgNet/graphDef.bin 1 trunc_pr split +Scripts/ring.sh tf-EzPC_Athos_Networks_SqueezeNetImgNet_graphDef.bin-1-trunc_pr-split +``` + +This requires TensorFlow and the axel command-line utility to be +installed. It runs inference with +three-party semi-honest computation, similar to CrypTFlow's +Porthos. Replace 1 by the desired number of thread in the last two +lines. If you run with any other protocol, you will need to remove +`trunc_pr` and `split`. Also note that you will need to use a +CrypTFlow repository that includes the patch in +https://github.com/mkskeller/EzPC/commit/2021be90d21dc26894be98f33cd10dd26769f479. + ## Dishonest majority Some full implementations require oblivious transfer, which is @@ -310,6 +339,7 @@ The following table shows all programs for dishonest-majority computation using | Program | Protocol | Domain | Security | Script | | --- | --- | --- | --- | --- | | `mascot-party.x` | [MASCOT](https://eprint.iacr.org/2016/505) | Mod prime | Malicious | `mascot.sh` | +| `mama-party.x` | MASCOT* | Mod prime | Malicious | `mama.sh` | | `spdz2k-party.x` | [SPDZ2k](https://eprint.iacr.org/2018/482) | Mod 2^k | Malicious | `spdz2k.sh` | | `semi-party.x` | OT-based | Mod prime | Semi-honest | `semi.sh` | | `semi2k-party.x` | OT-based | Mod 2^k | Semi-honest | `semi2k.sh` | @@ -321,6 +351,12 @@ The following table shows all programs for dishonest-majority computation using | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | | `tinier-party.x` | [FKOS15](https://eprint.iacr.org/2015/901) | Binary | Malicious | `tinier.sh` | +Mama denotes MASCOT with several MACs to increase the security +parameter to a multiple of the prime length. The number of MACs +defaults to three, and it is controlled by the `N_MAMA_MACS` +compile-time parameter (add `MY_CFLAGS += -DN_MAMA_MACS=` to `CONFIG.mine`). + Semi and Semi2k denote the result of stripping MASCOT/SPDZ2k of all steps required for malicious security, namely amplifying, sacrificing, MAC generation, and OT correlation checks. What remains is the diff --git a/Scripts/fixed-rep-to-float.py b/Scripts/fixed-rep-to-float.py new file mode 100755 index 000000000..3ed34c4b0 --- /dev/null +++ b/Scripts/fixed-rep-to-float.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +import sys, operator + +try: + f = int(sys.argv[2]) +except: + f = 12 + +try: + filename = sys.argv[3] +except: + filename = 'Player-Data/Input-P0-0' + +out = open(filename, 'w') + +for line in open(sys.argv[1]): + line = (line.strip()) + if line: + x = (line.split(' ')) + out.write(' '.join(str(int(xx) / 2**f) for xx in x)) + out.write('\n') diff --git a/Scripts/mama.sh b/Scripts/mama.sh new file mode 100755 index 000000000..c1eb3020f --- /dev/null +++ b/Scripts/mama.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player mama-party.x $* || exit 1 diff --git a/Scripts/process-tf.py b/Scripts/process-tf.py new file mode 100755 index 000000000..46bbbde9e --- /dev/null +++ b/Scripts/process-tf.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 + +import sys +from functools import reduce +import operator +import math + +filename = sys.argv[1] + +import tensorflow as tf +from tensorflow.core.framework import graph_pb2 +import numpy + +graph_def = graph_pb2.GraphDef() +graph_def.ParseFromString(open(filename, mode='rb').read()) +tf.import_graph_def(graph_def) +graph = tf.compat.v1.get_default_graph() + +first = True +paddings = {} + +print('ml.Layer.input_bias = False') + +def output(op, layer, prev_input=True): + global first + print('named["%s"] = %s' % (op.name, layer)) + print('layers.append(named["%s"])' % op.name) + if prev_input and not first: + print('named["%s"].inputs = [named["%s"]]' % (op.name, + op.inputs[0].name[:-2])) + first = False + +def link(dest, source): + print('named["%s"] = named["%s"]' % (dest.name, source.name)) + +def source(dest): + print('named["%s"] = None' % dest.name) + +def activate_bias(op): + print('named["%s"].input_bias = True' % op.name) + +def get_shape(shape): + res = [] + for x in shape: + try: + res.append(int(x)) + except: + res.append(1) + return res + +def get_valid_padding(input_shape, window, strides): + return [int(math.ceil((x - y + 1) / z)) + for x, y, z in zip(input_shape, window, strides)] + +for op in graph.get_operations(): + if op.inputs: + shape = get_shape(op.inputs[0].shape) + else: + shape = None + t = op.type + if t in ('VariableV2', 'Const'): + pass + elif t in ('Reshape', 'Squeeze'): + link(op, op.inputs[0].op) + elif t == 'Placeholder': + source(op) + elif t == 'MatMul': + #print (op.inputs[0].shape) + assert reduce(operator.mul, shape) == op.inputs[1].shape[0] + output(op, 'ml.Dense(1, %d, %d)' % (op.inputs[1].shape[0], + op.inputs[1].shape[1])) + shape = [1, int(op.inputs[1].shape[1])] + elif t == 'Conv2D': + strides = op.get_attr('strides') + assert len(strides) == 4 + assert strides[0] == 1 + assert strides[3] == 1 + strides = tuple(strides[1:3]) + input_shape = get_shape(op.inputs[0].shape) + assert len(input_shape) == 4 + window = [int(x) for x in op.inputs[1].shape] + padding = op.get_attr('padding').decode('u8') + if padding not in ('SAME', 'VALID'): + padding = get_shape(padding) + if op.inputs[0].op.name in paddings: + assert padding == 'VALID' + input_shape = get_shape(op.inputs[0].op.inputs[0].shape) + p = paddings.pop(op.inputs[0].op.name) + for i in 0, 6: + assert p[i] == 0 + padding = [p[2], p[4]] + output_shape = get_shape(op.outputs[0].shape) + assert len(output_shape) == 4 + output(op, 'ml.FixConv2d(%s, %s, %s, %s, %s, %s, True, ' + 'inputs=[named["%s"]])' % \ + (input_shape, tuple(window), (window[3],), output_shape, strides, + repr(padding), op.inputs[0].op.name)) + elif t == 'Add' and op.inputs[1].op.type != 'VariableV2': + output(op, 'ml.Add([%s])' % ','.join('named["%s"]' % x.op.name + for x in op.inputs), False) + elif t in ('Add', 'BiasAdd'): + assert op.inputs[0].op.type in ('MatMul', 'Conv2D') + activate_bias(op.inputs[0].op) + link(op, op.inputs[0].op) + elif t == 'Relu': + assert len(op.inputs) == 1 + output(op, 'ml.Relu(%s, inputs=[named["%s"]])' % (shape, + op.inputs[0].op.name)) + elif t == 'Square': + output(op, 'ml.Square(%s)' % (shape,)) + elif t == 'MaxPool': + strides = op.get_attr('strides') + ksize = op.get_attr('ksize') + padding = str(op.get_attr('padding').decode('u8')) + output(op, 'ml.MaxPool(%s, %s, %s, "%s")' % (shape, strides, ksize, + padding)) + elif t == 'AvgPool': + filter_size = op.get_attr('ksize') + assert len(filter_size) == 4 + assert filter_size[0] == 1 + assert filter_size[-1] == 1 + input_shape = get_shape(op.inputs[0].shape) + strides = get_shape(op.get_attr('strides')) + assert strides[0] == 1 + assert strides[3] == 1 + padding = op.get_attr('padding').decode('u8') + if padding == 'VALID': + output_shape = get_valid_padding(input_shape, filter_size, strides) + elif padding == 'SAME': + output_shape = [int(math.ceil(x / y)) + for x, y in zip(input_shape, filter_size)] + else: + raise Exception('unknown padding type: %s' % padding) + output(op, 'ml.FixAveragePool2d(%s, %s, %s, %s)' % + (input_shape, output_shape, filter_size[1:3], strides[1:3])) + elif t == 'ArgMax': + assert len(op.inputs) == 2 + shape = get_shape(op.inputs[0].shape) + dim = int(op.inputs[1].op.get_attr('value').int_val[0]) + for i in range(1, len(shape)): + if i != dim: + assert shape[i] == 1 + output(op, 'ml.Argmax((1, %s))' % shape[dim]) + elif t == 'ConcatV2': + assert len(op.inputs) == 3 + dim = int(op.inputs[2].op.get_attr('value').int_val[0]) + output(op, 'ml.Concat([%s], %s)' % ( + ','.join('named["%s"]' % x.name[:-2] for x in op.inputs[:2]), dim), + prev_input=False) + elif t == 'FusedBatchNorm': + output(op, 'ml.FusedBatchNorm(%s, inputs=[named["%s"]])' % + (get_shape(op.inputs[0].shape), op.inputs[0].op.name)) + elif t == 'Pad': + paddings[op.name] = numpy.fromstring(op.inputs[1].op.get_attr('value'). + tensor_content, 'int32').tolist() + link(op, op.inputs[0].op) + else: + raise Exception('unknown type: %s' % t) + +if paddings: + raise Exception('padding layers only supported before valid convolution:', + paddings) diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 59fa95931..cbdd9255f 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -7,8 +7,6 @@ while getopts XYC opt; do ;; Y) dabit=2 ;; - C) cheap=1 - ;; esac done @@ -36,50 +34,30 @@ function test_vm fi } +# big buckets for smallest batches +run_opts="$run_opts -B 5" + for dabit in ${dabit:-0 1 2}; do if [[ $dabit = 1 ]]; then compile_opts="$compile_opts -X" elif [[ $dabit = 2 ]]; then - if [[ $cheap != 1 ]]; then - run_opts="$run_opts --fake-batch" - fi compile_opts="$compile_opts -Y" fi ./compile.py -R 64 $compile_opts tutorial - for i in ring semi2k; do + for i in ring semi2k brain mal-rep-ring ps-rep-ring spdz2k; do test_vm $i $run_opts done - if ! test "$dabit" = 2 -a "$cheap" = 1; then - for i in brain mal-rep-ring ps-rep-ring spdz2k; do - test_vm $i $run_opts - done - fi - ./compile.py $compile_opts tutorial - for i in rep-field shamir; do - test_vm $i - done - - if ! test "$dabit" = 2 -a "$cheap" = 1; then - for i in mal-rep-field ps-rep-field mal-shamir; do - test_vm $i $run_opts - done - fi - - for i in hemi semi soho; do - test_vm $i + for i in rep-field shamir mal-rep-field ps-rep-field mal-shamir hemi semi \ + soho cowgear mascot; do + test_vm $i $run_opts done - if ! test "$dabit" = 2 -a "$cheap" = 1; then - for i in cowgear mascot; do - test_vm $i $run_opts - done - test_vm chaigear $run_opts -l 3 -c 2 - fi + test_vm chaigear $run_opts -l 3 -c 2 done ./compile.py tutorial @@ -89,6 +67,6 @@ test_vm chaigear -T -l 3 -c 2 ./compile.py -B 16 $compile_opts tutorial -for i in replicated mal-rep-bin semi-bin ccd mal-ccd yao tinier rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr tiny; do +for i in replicated mal-rep-bin semi-bin ccd mal-ccd yao tinier rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do test_vm $i done diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 93cf22b03..d92ddec52 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -13,9 +13,8 @@ using namespace std; #include "Exceptions/Exceptions.h" #include "Networking/data.h" // just for util functions -#include "Math/bigint.h" #include "Math/gf2nlong.h" -#include "Math/gfp.h" +#include "Math/FixedVec.h" class PRNG; class octetStream; @@ -186,7 +185,8 @@ class BitVector template void set(const T& a); - + template + void set(const FixedVec& a); bool get_bit(int i) const { if (i >= (int)nbits) @@ -281,6 +281,18 @@ void inline BitVector::set_portion(int i, const T& a) memcpy(bytes + a.size() * i, a.get_ptr(), a.size()); } +template +void BitVector::set(const FixedVec& a) +{ + resize(8 * a.size()); + size_t base = 0; + for (int i = 0; i < L; i++) + { + memcpy(bytes + base, a[i].get_ptr(), a[i].size()); + base += a[i].size(); + } +} + template void inline BitVector::set(const T& a) { diff --git a/Tools/MMO.h b/Tools/MMO.h index f2fd2996e..5bc440f48 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -15,6 +15,9 @@ class MMO static const int N_KEYS = 2; octet IV[N_KEYS][176] __attribute__((aligned (16))); + template + static void encrypt_and_xor(__m128i* output, const __m128i* input, + const octet* key); template static void encrypt_and_xor(void* output, const void* input, const octet* key); @@ -32,8 +35,10 @@ class MMO void hashBlocks(void* output, const void* input, size_t alloc_size); template void hashBlocks(void* output, const void* input); + template + void hashEightBlocks(T* output, const void* input); template - void hashEightGfp(void* output, const void* input); + void hashEightBlocks(gfp_* output, const void* input); template void outputOneBlock(octet* output); Key hash(const Key& input); @@ -42,13 +47,19 @@ class MMO }; template -inline void MMO::encrypt_and_xor(void* output, const void* input, const octet* key) +inline void MMO::encrypt_and_xor(__m128i* out, const __m128i* in, const octet* key) { - __m128i in[N], out[N]; - avx_memcpy(in, input, sizeof(in)); ecb_aes_128_encrypt(out, in, key); for (int i = 0; i < N; i++) out[i] = _mm_xor_si128(out[i], in[i]); +} + +template +inline void MMO::encrypt_and_xor(void* output, const void* input, const octet* key) +{ + __m128i in[N], out[N]; + avx_memcpy(in, input, sizeof(in)); + encrypt_and_xor(out, in, key); avx_memcpy(output, out, sizeof(out)); } @@ -62,7 +73,7 @@ inline Key MMO::hash(const Key& input) template inline void MMO::hash(Key* output, const Key* input) { - encrypt_and_xor(output, input, IV[0]); + encrypt_and_xor(&output->r, &input->r, IV[0]); } #endif /* TOOLS_MMO_H_ */ diff --git a/Tools/MMO.hpp b/Tools/MMO.hpp index cde5354ff..d0f43e450 100644 --- a/Tools/MMO.hpp +++ b/Tools/MMO.hpp @@ -5,7 +5,6 @@ */ #include "MMO.h" -#include "Math/gfp.h" #include @@ -68,20 +67,8 @@ void MMO::hashBlocks(void* output, const void* input) ((T*)output + j)->normalize(); } -template <> -inline -void MMO::hashBlocks(void* output, const void* input) -{ - if (gfp1::get_ZpD().get_t() != 2) - throw not_implemented(); - encrypt_and_xor<1>(output, input, IV[0]); - while (mpn_cmp((mp_limb_t*)output, gfp1::get_ZpD().get_prA(), gfp1::t()) >= 0) - _mm_storeu_si128((__m128i *) output, - aes_128_encrypt(_mm_loadu_si128((__m128i *) output), IV[0])); -} - template -void MMO::hashEightGfp(void* output, const void* input) +void MMO::hashEightBlocks(gfp_* output, const void* input) { gfp_* out = (gfp_*)output; const int block_size = sizeof(__m128i); @@ -117,16 +104,15 @@ void MMO::hashEightGfp(void* output, const void* input) } } -template <> -inline -void MMO::hashBlocks(void* output, const void* input) +template +void MMO::hashEightBlocks(T* output, const void* input) { - hashEightGfp<1, GFP_MOD_SZ>(output, input); + hashBlocks(output, input); } template <> inline -void MMO::hashBlocks(void* output, const void* input) +void MMO::hashEightBlocks(__m128i* output, const void* input) { - hashEightGfp<3, 4>(output, input); + hashBlocks<8, 16>(output, input, 16); } diff --git a/Tools/aes.h b/Tools/aes.h index fec9bc949..7ef6c3d2a 100644 --- a/Tools/aes.h +++ b/Tools/aes.h @@ -75,7 +75,7 @@ inline __m128i aes_128_encrypt(__m128i in, const octet* key) } template -inline void software_ecb_aes_128_encrypt(__m128i* out, __m128i* in, uint* key) +inline void software_ecb_aes_128_encrypt(__m128i* out, const __m128i* in, uint* key) { for (int i = 0; i < N; i++) aes_128_encrypt((octet*)&out[i], (octet*)&in[i], key); @@ -85,7 +85,7 @@ template #ifndef __clang__ __attribute__((optimize("unroll-loops"))) #endif -inline void ecb_aes_128_encrypt(__m128i* out, __m128i* in, const octet* key) +inline void ecb_aes_128_encrypt(__m128i* out, const __m128i* in, const octet* key) { #ifdef __AES__ if (cpu_has_aes()) diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index aaa87c9f3..a8124b3eb 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -8,6 +8,7 @@ #include "Protocols/Share.h" #include "Protocols/fake-stuff.h" #include "Protocols/MAC_Check.h" +#include "Protocols/Rep3Share.h" #include "Tools/ezOptionParser.h" #include "Exceptions/Exceptions.h" #include "GC/MaliciousRepSecret.h" diff --git a/Utils/gc-emulate.cpp b/Utils/gc-emulate.cpp index bdeb357fe..f90014733 100644 --- a/Utils/gc-emulate.cpp +++ b/Utils/gc-emulate.cpp @@ -26,7 +26,7 @@ int main(int argc, char** argv) GC::Memory dynamic_memory; GC::Machine machine; GC::Processor processor(machine); - GC::Program program; + GC::Program program; program.parse(string(argv[1]) + "-0"); machine.reset(program, dynamic_memory); processor.reset(program); diff --git a/Utils/ot-offline.cpp b/Utils/ot-offline.cpp index 3df15c9ed..ea057c869 100644 --- a/Utils/ot-offline.cpp +++ b/Utils/ot-offline.cpp @@ -4,6 +4,7 @@ */ #include "OT/NPartyTripleGenerator.h" +#include "OT/TripleMachine.h" int main(int argc, const char** argv) { diff --git a/Yao/YaoAndJob.h b/Yao/YaoAndJob.h index 3eb72d9c1..d2e32fd7f 100644 --- a/Yao/YaoAndJob.h +++ b/Yao/YaoAndJob.h @@ -9,11 +9,17 @@ #include "YaoGarbleWire.h" #include "Tools/Worker.h" -class YaoGate; +enum YaoJobType +{ + YAO_AND_JOB, + YAO_XOR_JOB, + YAO_NO_JOB +}; +template class YaoAndJob { - GC::Memory< GC::Secret >* S; + GC::Processor< GC::Secret >* processor; const vector* args; size_t start, end, n_gates; YaoGate* gate; @@ -21,14 +27,15 @@ class YaoAndJob PRNG prng; map timers; bool repeat; - YaoGarbler& garbler; + typename T::Party& party; + YaoJobType type; public: Worker worker; - YaoAndJob(YaoGarbler& garbler) : - S(0), args(0), start(0), end(0), n_gates(0), gate(0), counter(0), - repeat(0), garbler(garbler) + YaoAndJob(typename T::Party& party) : + processor(0), args(0), start(0), end(0), n_gates(0), gate(0), + counter(0), repeat(0), party(party), type(YAO_NO_JOB) { prng.ReSeed(); } @@ -41,11 +48,13 @@ class YaoAndJob #endif } - void dispatch(GC::Memory >& S, const vector& args, + void dispatch(YaoJobType type, + GC::Processor >& processor, const vector& args, size_t start, size_t end, size_t n_gates, YaoGate* gate, long counter, bool repeat) { - this->S = &S; + this->type = type; + this->processor = &processor; this->args = &args; this->start = start; this->end = end; @@ -58,8 +67,20 @@ class YaoAndJob int run() { - YaoGarbleWire::and_(*S, *args, start, end, n_gates, gate, counter, - prng, timers, repeat, garbler); + switch(type) + { + case YAO_AND_JOB: + T::and_(processor->S, *args, start, end, n_gates, gate, counter, + prng, timers, repeat, party); + break; + case YAO_XOR_JOB: + T::xors(*processor, *args, start, end); + break; + default: + throw runtime_error("job not specified: " + to_string(type)); + } + + type = YAO_NO_JOB; return 0; } }; diff --git a/Yao/YaoCommon.h b/Yao/YaoCommon.h index d4ebdd850..ab65c2407 100644 --- a/Yao/YaoCommon.h +++ b/Yao/YaoCommon.h @@ -8,23 +8,47 @@ #include #include +#include +#include #include "Exceptions/Exceptions.h" #include "GC/RuntimeBranching.h" +#include "GC/ThreadMaster.h" +#include "YaoAndJob.h" +#include + +template class YaoCommon : public GC::RuntimeBranching { int log_n_threads; + GC::ThreadMaster>& master; + public: static const int DONE = -1; static const int MORE = -2; long counter; - YaoCommon() : - log_n_threads(8), counter(0) + vector*> jobs; + + YaoCommon(GC::ThreadMaster>& master) : + log_n_threads(8), master(master), counter(0) + { + } + + ~YaoCommon() + { + for (auto& job : jobs) + delete job; + } + + void init(typename T::Party& party) { + jobs.resize(get_n_worker_threads()); + for (auto& job : jobs) + job = new YaoAndJob(party); } void set_n_program_threads(int n_threads) @@ -36,6 +60,20 @@ class YaoCommon : public GC::RuntimeBranching { return counter + (thread_num << (64 - log_n_threads)); } + + int get_n_worker_threads() + { + return max(1u, thread::hardware_concurrency() / master.machine.nthreads); + } + + vector> get_splits(const vector& args, int threshold, + int total); + + void wait(int n_threads) + { + for (int i = 0; i < n_threads; i++) + jobs[i]->worker.done(); + } }; #endif /* YAO_YAOCOMMON_H_ */ diff --git a/Yao/YaoCommon.hpp b/Yao/YaoCommon.hpp new file mode 100644 index 000000000..918f774f5 --- /dev/null +++ b/Yao/YaoCommon.hpp @@ -0,0 +1,27 @@ +/* + * YaoCommon.cpp + * + */ + +#include "YaoCommon.h" + +template +vector > YaoCommon::get_splits(const vector& args, + int threshold, int total) +{ + vector> res; + size_t max_gates_per_thread = max(threshold / 2, + (total + get_n_worker_threads() - 1) / get_n_worker_threads()); + size_t i_gate = 0; + for (auto it = args.begin(); it < args.end(); it += 4) + { + i_gate += *it; + auto end = it + 4; + if (i_gate >= max_gates_per_thread or end >= args.end()) + { + res.push_back({{i_gate, size_t(end - args.begin())}}); + i_gate = 0; + } + } + return res; +} diff --git a/Yao/YaoEvalMaster.cpp b/Yao/YaoEvalMaster.cpp index 575b38001..558d2cfc3 100644 --- a/Yao/YaoEvalMaster.cpp +++ b/Yao/YaoEvalMaster.cpp @@ -14,6 +14,7 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" #include "Processor/Instruction.hpp" +#include "YaoWire.hpp" YaoEvalMaster::YaoEvalMaster(bool continuous, OnlineOptions& opts) : ThreadMaster>(opts), continuous(continuous) diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 704cf2187..9d5af9d70 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -14,6 +14,7 @@ #include "GC/Processor.hpp" #include "GC/Secret.hpp" #include "GC/Thread.hpp" +#include "YaoCommon.hpp" ostream& YaoEvalWire::out = cout; @@ -32,98 +33,100 @@ template void YaoEvalWire::and_(GC::Processor >& processor, const vector& args) { - int total_ands = processor.check_args(args, 4); - if (total_ands < 10) - return processor.andrs(args); - processor.complexity += total_ands; - Key* labels; - Key* hashes; - vector label_vec, hash_vec; - size_t n_hashes = total_ands; - Key label_arr[1000], hash_arr[1000]; - if (total_ands < 1000) + YaoEvaluator& party = YaoEvaluator::s(); + int total = processor.check_args(args, 4); + int threshold = 1024; + if (total < threshold) { - labels = label_arr; - hashes = hash_arr; + // run in single thread + and_singlethread(processor, args, total); + return; } - else + + processor.complexity += total; + int i_thread = 0, start = 0; + for (auto& x : party.get_splits(args, threshold, total)) { - label_vec.resize(n_hashes); - hash_vec.resize(n_hashes); - labels = label_vec.data(); - hashes = hash_vec.data(); + auto i_gate = x[0]; + auto end = x[1]; + YaoGate* gate = (YaoGate*) party.gates.consume( + i_gate * sizeof(YaoGate)); + party.jobs[i_thread++]->dispatch(YAO_AND_JOB, processor, args, start, + end, i_gate, gate, party.get_gate_id(), repeat); + party.counter += i_gate; + start = end; } - size_t i_label = 0; - auto& evaluator = YaoEvaluator::s(); + party.wait(i_thread); +} + +template +void YaoEvalWire::and_singlethread(GC::Processor >& processor, + const vector& args, int total_ands) +{ + if (total_ands < 10) + return processor.and_(args, repeat); + processor.complexity += total_ands; + size_t n_args = args.size(); + YaoEvaluator& party = YaoEvaluator::s(); + YaoGate* gate = (YaoGate*) party.gates.consume(total_ands * sizeof(YaoGate)); + long counter = party.get_gate_id(); + map timers; + SeededPRNG prng; + and_(processor.S, args, 0, n_args, total_ands, gate, counter, + prng, timers, repeat, party); + party.counter += counter - party.get_gate_id(); +} + +void YaoEvalWire::and_(GC::Memory >& S, + const vector& args, size_t start, size_t end, size_t, + YaoGate* gates, long& gate_id, PRNG&, map&, + bool repeat, YaoEvaluator& evaluator) +{ int dl = GC::Secret::default_length; - for (auto it = args.begin(); it < args.end(); it += 4) - { - if (*it == 1) - { - evaluator.counter++; - labels[i_label++] = YaoGate::E_input( - processor.S[*(it + 2)].get_reg(0).key, - processor.S[*(it + 3)].get_reg(0).key, - evaluator.get_gate_id()); - } - else - { - int n_units = DIV_CEIL(*it, dl); - for (int j = 0; j < n_units; j++) - { - auto& left = processor.S[*(it + 2) + j]; - auto& right = processor.S[*(it + 3) + (repeat ? 0 : j)]; - int n = min(dl, *it - j * dl); - for (int k = 0; k < n; k++) - { - auto& left_wire = left.get_reg(k); - auto& right_key = right.get_reg(repeat ? 0 : k).key; - evaluator.counter++; - labels[i_label++] = YaoGate::E_input(left_wire.key, right_key, - evaluator.get_gate_id()); - } - } - } - } MMO& mmo = evaluator.mmo; - size_t i; - for (i = 0; i + 8 <= n_hashes; i += 8) - mmo.hash<8>(&hashes[i], &labels[i]); - for (; i < n_hashes; i++) - hashes[i] = mmo.hash(labels[i]); - size_t j = 0; - for (auto it = args.begin(); it < args.end(); it += 4) + for (auto it = args.begin() + start; it < args.begin() + end; it += 4) { if (*it == 1) { - auto& out = processor.S[*(it + 1)]; + Key label[YaoGate::N_EVAL_HASHES]; + Key hash[YaoGate::N_EVAL_HASHES]; + gate_id++; + YaoGate::eval_inputs(label, + S[*(it + 2)].get_reg(0).key(), + S[*(it + 3)].get_reg(0).key(), + gate_id); + mmo.hash(hash, label); + auto& out = S[*(it + 1)]; out.resize_regs(1); - YaoGate gate; - evaluator.load_gate(gate); - gate.eval(out.get_reg(0), hashes[j++], - gate.get_entry(processor.S[*(it + 2)].get_reg(0).external, - processor.S[*(it + 3)].get_reg(0).external)); + YaoGate& gate = *gates; + gates++; + gate.eval(out.get_reg(0), hash, S[*(it + 2)].get_reg(0), + S[*(it + 3)].get_reg(0)); } else { int n_units = DIV_CEIL(*it, dl); - for (int l = 0; l < n_units; l++) + for (int j = 0; j < n_units; j++) { - auto& left = processor.S[*(it + 2) + l]; - auto& right = processor.S[*(it + 3) + (repeat ? 0 : l)]; - auto& out = processor.S[*(it + 1) + l]; - int n = min(dl, *it - l * dl); + auto& left = S[*(it + 2) + j]; + auto& right = S[*(it + 3) + (repeat ? 0 : j)]; + auto& out = S[*(it + 1) + j]; + int n = min(dl, *it - j * dl); out.resize_regs(n); - for (int k = 0; k < n; k++) { - auto& right_wire = right.get_reg(repeat ? 0 : k); + Key label[YaoGate::N_EVAL_HASHES]; + Key hash[YaoGate::N_EVAL_HASHES]; auto& left_wire = left.get_reg(k); - YaoGate gate; - evaluator.load_gate(gate); - gate.eval(out.get_reg(k), hashes[j++], - gate.get_entry(left_wire.external, - right_wire.external)); + auto& right_key = right.get_reg(repeat ? 0 : k).key(); + gate_id++; + YaoGate::eval_inputs(label, left_wire.key(), right_key, + gate_id); + mmo.hash(hash, label); + auto& right_wire = right.get_reg(repeat ? 0 : k); + YaoGate& gate = *gates; + gates++; + gate.eval(out.get_reg(k), hash, left_wire, right_wire); } } } @@ -197,23 +200,22 @@ void YaoEvalWire::op(const YaoEvalWire& left, const YaoEvalWire& right, bool YaoEvalWire::get_output() { YaoEvaluator::s().taint(); - bool res = external ^ YaoEvaluator::s().output_masks.pop_front(); + bool res = external() ^ YaoEvaluator::s().output_masks.pop_front(); #ifdef DEBUG - cout << "output " << res << " mask " << (external ^ res) << " external " - << external << endl; + cout << "output " << res << " mask " << (external() ^ res) << " external() " + << external() << endl; #endif return res; } void YaoEvalWire::set(const Key& key) { - this->key = key; - external = key.get_signal(); + this->key_ = key; } void YaoEvalWire::set(Key key, bool external) { - key.set_signal(external); + assert(key.get_signal() == external); set(key); } diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 6aedf64bd..d5787bc7c 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -10,18 +10,21 @@ #include "BMR/Gate.h" #include "BMR/Register.h" #include "Processor/DummyProtocol.h" +#include "config.h" +#include "YaoWire.h" -class YaoEvalWire : public Phase +class YaoEvaluator; + +class YaoEvalWire : public YaoWire { public: + typedef YaoEvaluator Party; + static string name() { return "YaoEvalWire"; } typedef ostream& out_type; static ostream& out; - bool external; - Key key; - static YaoEvalWire new_reg() { return {}; } static void andrs(GC::Processor>& processor, @@ -37,6 +40,14 @@ class YaoEvalWire : public Phase template static void and_(GC::Processor>& processor, const vector& args); + template + static void and_singlethread( + GC::Processor>& processor, + const vector& args, int total_ands); + static void and_(GC::Memory>& S, + const vector& args, size_t start, size_t end, + size_t total_ands, YaoGate* gate, long& counter, PRNG& prng, + map& timers, bool repeat, YaoEvaluator& garbler); static void inputb(GC::Processor>& processor, const vector& args); @@ -47,17 +58,20 @@ class YaoEvalWire : public Phase void set(const Key& key); void set(Key key, bool external); + const Key& key() const + { + return key_; + } + + bool external() const + { + return key_.get_signal(); + } + void random(); void public_input(bool value); void op(const YaoEvalWire& left, const YaoEvalWire& right, Function func); - void XOR(const YaoEvalWire& left, const YaoEvalWire& right); bool get_output(); }; -inline void YaoEvalWire::XOR(const YaoEvalWire& left, const YaoEvalWire& right) -{ - external = left.external ^ right.external; - key = left.key ^ right.key; -} - #endif /* YAO_YAOEVALWIRE_H_ */ diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 8239ffc21..750c5ed30 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -13,16 +13,19 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" #include "Tools/MMO.hpp" +#include "YaoWire.hpp" thread_local YaoEvaluator* YaoEvaluator::singleton = 0; YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : Thread>(thread_num, master), + YaoCommon(master), master(master), player(N, 0, thread_num << 24), ot_ext(OTExtensionWithMatrix::setup(player, {}, RECEIVER, true)) { set_n_program_threads(master.machine.nthreads); + this->init(*this); } void YaoEvaluator::pre_run() @@ -31,7 +34,7 @@ void YaoEvaluator::pre_run() receive_to_store(*P); } -void YaoEvaluator::run(GC::Program>& program) +void YaoEvaluator::run(GC::Program& program) { singleton = this; @@ -43,7 +46,7 @@ void YaoEvaluator::run(GC::Program>& program) } } -void YaoEvaluator::run(GC::Program>& program, Player& P) +void YaoEvaluator::run(GC::Program& program, Player& P) { auto next = GC::TIME_BREAK; do @@ -60,7 +63,7 @@ void YaoEvaluator::run(GC::Program>& program, Player& P) while(GC::DONE_BREAK != next); } -void YaoEvaluator::run_from_store(GC::Program>& program) +void YaoEvaluator::run_from_store(GC::Program& program) { machine.reset_timer(); do diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index b542304cb..8ac34d78a 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -14,7 +14,8 @@ #include "Tools/MMO.h" #include "OT/OTExtensionWithMatrix.h" -class YaoEvaluator : public GC::Thread>, public YaoCommon +class YaoEvaluator: public GC::Thread>, + public YaoCommon { protected: static thread_local YaoEvaluator* singleton; @@ -24,6 +25,9 @@ class YaoEvaluator : public GC::Thread>, public YaoCommo YaoEvalMaster& master; + friend class YaoCommon; + friend class YaoEvalWire; + public: ReceivedMsg output_masks; ReceivedMsgStore output_masks_store; @@ -40,15 +44,18 @@ class YaoEvaluator : public GC::Thread>, public YaoCommo bool continuous() { return master.continuous and master.machine.nthreads == 1; } void pre_run(); - void run(GC::Program>& program); - void run(GC::Program>& program, Player& P); - void run_from_store(GC::Program>& program); + void run(GC::Program& program); + void run(GC::Program& program, Player& P); + void run_from_store(GC::Program& program); bool receive(Player& P); void receive_to_store(Player& P); void load_gate(YaoGate& gate); long get_gate_id() { return gate_id(thread_num); } + + int get_n_worker_threads() + { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbleMaster.h b/Yao/YaoGarbleMaster.h index 40914013b..503fa8289 100644 --- a/Yao/YaoGarbleMaster.h +++ b/Yao/YaoGarbleMaster.h @@ -15,14 +15,20 @@ class YaoGarbleMaster : public GC::ThreadMaster> { typedef GC::ThreadMaster> super; + Key delta; + public: bool continuous; int threshold; - Key delta; YaoGarbleMaster(bool continuous, OnlineOptions& opts, int threshold = 1024); GC::Thread>* new_thread(int i); + + Key get_delta() + { + return delta; + } }; #endif /* YAO_YAOGARBLEMASTER_H_ */ diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 1ed519878..eff25d5bb 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -11,17 +11,16 @@ #include "GC/Processor.hpp" #include "GC/Secret.hpp" #include "GC/Thread.hpp" +#include "YaoCommon.hpp" void YaoGarbleWire::random() { - mask = YaoGarbler::s().prng.get_bit(); - key = 0; + key_ = YaoGarbler::s().prng.get_bit(); } void YaoGarbleWire::public_input(bool value) { - mask = value; - key = 0; + key_ = YaoGate::garble_public_input(value, YaoGarbler::s().get_delta()); } void YaoGarbleWire::and_(GC::Processor >& processor, @@ -51,30 +50,24 @@ void YaoGarbleWire::and_multithread(GC::Processor >& p processor.complexity += total; SendBuffer& gates = party.gates; gates.allocate(total * sizeof(YaoGate)); - int max_gates_per_thread = max(party.get_threshold() / 2, - (total + party.get_n_worker_threads() - 1) / party.get_n_worker_threads()); - int i_thread = 0, i_gate = 0, start = 0; - for (size_t j = 0; j < args.size(); j += 4) + int i_thread = 0, start = 0; + for (auto& x : party.get_splits(args, party.get_threshold(), total)) { - i_gate += args[j]; - size_t end = j + 4; - if (i_gate >= max_gates_per_thread or end >= args.size()) - { - YaoGate* gate = (YaoGate*)gates.end(); - gates.skip(i_gate * sizeof(YaoGate)); - party.timers["Dispatch"].start(); - party.and_jobs[i_thread++]->dispatch(processor.S, args, start, end, - i_gate, gate, party.get_gate_id(), repeat); - party.timers["Dispatch"].stop(); - party.counter += i_gate; - i_gate = 0; - start = end; - } + int i_gate = x[0]; + int end = x[1]; + YaoGate* gate = (YaoGate*)gates.end(); + gates.skip(i_gate * sizeof(YaoGate)); + party.timers["Dispatch"].start(); + party.jobs[i_thread++]->dispatch(YAO_AND_JOB, processor, args, start, + end, i_gate, gate, party.get_gate_id(), repeat); + party.timers["Dispatch"].stop(); + party.counter += i_gate; + i_gate = 0; + start = end; } party.and_prepare_timer.stop(); party.and_wait_timer.start(); - for (int i = 0; i < i_thread; i++) - party.and_jobs[i]->worker.done(); + party.wait(i_thread); party.and_wait_timer.stop(); } @@ -83,7 +76,7 @@ void YaoGarbleWire::and_singlethread(GC::Processor >& { int total_ands = processor.check_args(args, 4); if (total_ands < 10) - return processor.andrs(args); + return processor.and_(args, repeat); processor.complexity += total_ands; size_t n_args = args.size(); auto& garbler = YaoGarbler::s(); @@ -96,121 +89,65 @@ void YaoGarbleWire::and_singlethread(GC::Processor >& } void YaoGarbleWire::and_(GC::Memory >& S, - const vector& args, size_t start, size_t end, size_t total_ands, + const vector& args, size_t start, size_t end, size_t, YaoGate* gate, long& counter, PRNG& prng, map& timers, bool repeat, YaoGarbler& garbler) { (void)timers; - Key* labels; - Key* hashes; - vector label_vec, hash_vec; - size_t n_hashes = 4 * total_ands; - Key label_arr[400], hash_arr[400]; - if (total_ands < 100) - { - labels = label_arr; - hashes = hash_arr; - } - else - { - label_vec.resize(n_hashes); - hash_vec.resize(n_hashes); - labels = label_vec.data(); - hashes = hash_vec.data(); - } - //timers["Hash input"].start(); const Key& delta = garbler.get_delta(); - size_t i_label = 0; int dl = GC::Secret::default_length; Key left_delta = delta.doubling(1); Key right_delta = delta.doubling(2); + Key labels[4]; + Key hashes[4]; + MMO& mmo = garbler.mmo; for (auto it = args.begin() + start; it < args.begin() + end; it += 4) { if (*it == 1) { counter++; - YaoGate::E_inputs(&labels[i_label], S[*(it + 2)].get_reg(0).key, - S[*(it + 3)].get_reg(0).key, left_delta, right_delta, + YaoGate::E_inputs(labels, S[*(it + 2)].get_reg(0), + S[*(it + 3)].get_reg(0), left_delta, right_delta, counter); - i_label += 4; + mmo.hash<4>(hashes, labels); + auto& out = S[*(it + 1)]; + out.resize_regs(1); + YaoGate::randomize(out.get_reg(0), prng); + (gate++)->and_garble(out.get_reg(0), hashes, + S[*(it + 2)].get_reg(0), + S[*(it + 3)].get_reg(0), garbler.get_delta()); } else { int n_units = DIV_CEIL(*it, dl); for (int j = 0; j < n_units; j++) { + auto& out = S[*(it + 1) + j]; int left = min(dl, *it - j * dl); + out.resize_regs(left); for (int k = 0; k < left; k++) { auto& left_wire = S[*(it + 2) + j].get_reg(k); - const Key& right_key = S[*(it + 3) + j].get_reg( - repeat ? 0 : k).key; + auto& right_wire = S[*(it + 3) + j].get_reg( + repeat ? 0 : k); counter++; - YaoGate::E_inputs(&labels[i_label], left_wire.key, - right_key, left_delta, right_delta, counter); - i_label += 4; - } - } - } - } - //timers["Hash input"].stop(); - //timers["Hashing"].start(); - MMO& mmo = garbler.mmo; - size_t i; - for (i = 0; i + 8 <= n_hashes; i += 8) - mmo.hash<8>(&hashes[i], &labels[i]); - for (; i < n_hashes; i++) - hashes[i] = mmo.hash(labels[i]); - //timers["Hashing"].stop(); - //timers["Garbling"].start(); - size_t i_hash = 0; - for (auto it = args.begin() + start; it < args.begin() + end; it += 4) - { - if (*it == 1) - { - auto& out = S[*(it + 1)]; - out.resize_regs(1); - out.get_reg(0).randomize(prng); - (gate++)->and_garble(out.get_reg(0), &hashes[i_hash], - S[*(it + 2)].get_reg(0).mask, - S[*(it + 3)].get_reg(0).mask, garbler.get_delta()); - //timers["Gate computation"].stop(); - i_hash += 4; - } - else - { - int n_units = DIV_CEIL(*it, dl); - for (int j = 0; j < n_units; j++) - { - //timers["Outer ref"].start(); - auto& out = S[*(it + 1) + j]; - //timers["Outer ref"].stop(); - //timers["Resizing"].start(); - int n = min(dl, *it - j * dl); - out.resize_regs(n); - //timers["Resizing"].stop(); - for (int k = 0; k < n; k++) - { - YaoGarbleWire& right_wire = - S[*(it + 3) + (repeat ? 0 : j)].get_reg( - repeat ? 0 : k); + YaoGate::E_inputs(labels, left_wire, + right_wire, left_delta, right_delta, counter); + mmo.hash<4>(hashes, labels); //timers["Inner ref"].start(); - auto& left_wire = S[*(it + 2) + j].get_reg(k); //timers["Inner ref"].stop(); //timers["Randomizing"].start(); out.get_reg(k).randomize(prng); //timers["Randomizing"].stop(); //timers["Gate computation"].start(); - (gate++)->and_garble(out.get_reg(k), &hashes[i_hash], - left_wire.mask, right_wire.mask, + (gate++)->and_garble(out.get_reg(k), hashes, + left_wire, right_wire, garbler.get_delta()); //timers["Gate computation"].stop(); - i_hash += 4; } } } } - //timers["Garbling"].stop(); } @@ -252,7 +189,8 @@ void YaoGarbleWire::inputb(GC::Processor>& processor, for (auto& reg : processor.S[x.dest].get_regs()) { reg.set(garbler.prng.get_doubleword(), 0); - garbler.receiver_input_keys.back().push_back(reg.key); + assert(reg.mask() == 0); + garbler.receiver_input_keys.back().push_back(reg.full_key()); } } } @@ -276,7 +214,7 @@ void YaoGarbleWire::op(const YaoGarbleWire& left, const YaoGarbleWire& right, char YaoGarbleWire::get_output() { YaoGarbler::s().taint(); - YaoGarbler::s().output_masks.push_back(mask); + YaoGarbler::s().output_masks.push_back(mask()); return 0; } diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 6dae71dd4..196c8de6d 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -8,19 +8,19 @@ #include "BMR/Key.h" #include "BMR/Register.h" +#include "config.h" +#include "YaoWire.h" #include -class YaoGate; class YaoGarbler; -class YaoGarbleWire : public Phase +class YaoGarbleWire : public YaoWire { public: - static string name() { return "YaoGarbleWire"; } + typedef YaoGarbler Party; - Key key; - bool mask; + static string name() { return "YaoGarbleWire"; } static YaoGarbleWire new_reg() { return {}; } @@ -57,33 +57,47 @@ class YaoGarbleWire : public Phase void randomize(PRNG& prng); void set(Key key, bool mask); + Key full_key() const + { + return key_; + } + + void set_full_key(Key key) + { + key_ = key; + } + + Key key() const + { + Key res = key_; + res.set_signal(0); + return res; + } + + bool mask() const + { + return key_.get_signal(); + } + void random(); void public_input(bool value); void op(const YaoGarbleWire& left, const YaoGarbleWire& right, Function func); - void XOR(const YaoGarbleWire& left, const YaoGarbleWire& right); char get_output(); }; inline void YaoGarbleWire::randomize(PRNG& prng) { - key = prng.get_doubleword(); + key_ = prng.get_doubleword(); #ifdef DEBUG //key = YaoGarbler::s().counter << 1; #endif - set(key, prng.get_uchar() & 1); } inline void YaoGarbleWire::set(Key key, bool mask) { - key.set_signal(0); - this->key = key; - this->mask = mask; -} - -inline void YaoGarbleWire::XOR(const YaoGarbleWire& left, const YaoGarbleWire& right) -{ - mask = left.mask ^ right.mask; - key = left.key ^ right.key; + key.set_signal(mask); + this->key_ = key; + assert(key.get_signal() == mask); } #endif /* YAO_YAOGARBLEWIRE_H_ */ diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index ad32dfe83..754a0d3d5 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -14,30 +14,27 @@ #include "GC/Secret.hpp" #include "GC/Thread.hpp" #include "Tools/MMO.hpp" +#include "YaoWire.hpp" thread_local YaoGarbler* YaoGarbler::singleton = 0; YaoGarbler::YaoGarbler(int thread_num, YaoGarbleMaster& master) : GC::Thread>(thread_num, master), + YaoCommon(master), master(master), and_proc_timer(CLOCK_PROCESS_CPUTIME_ID), and_main_thread_timer(CLOCK_THREAD_CPUTIME_ID), player(master.N, 1, thread_num << 24), ot_ext(OTExtensionWithMatrix::setup(player, - master.delta.get<__m128i>(), SENDER, true)) + master.get_delta().get<__m128i>(), SENDER, true)) { prng.ReSeed(); set_n_program_threads(master.machine.nthreads); - - and_jobs.resize(get_n_worker_threads()); - for (auto& job : and_jobs) - job = new YaoAndJob(*this); + this->init(*this); } YaoGarbler::~YaoGarbler() { - for (auto& job : and_jobs) - delete job; #ifdef VERBOSE cerr << "Number of AND gates: " << counter << endl; #endif @@ -52,7 +49,7 @@ YaoGarbler::~YaoGarbler() #endif } -void YaoGarbler::run(GC::Program>& program) +void YaoGarbler::run(GC::Program& program) { singleton = this; diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 7a8c4b314..860ff311d 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -18,11 +18,11 @@ #include -class YaoGate; - -class YaoGarbler : public GC::Thread>, public YaoCommon +class YaoGarbler: public GC::Thread>, + public YaoCommon { friend class YaoGarbleWire; + friend class YaoCommon; protected: static thread_local YaoGarbler* singleton; @@ -42,8 +42,6 @@ class YaoGarbler : public GC::Thread>, public YaoCommo SendBuffer output_masks; MMO mmo; - vector and_jobs; - map timers; RealTwoPartyPlayer player; @@ -58,18 +56,16 @@ class YaoGarbler : public GC::Thread>, public YaoCommo bool continuous() { return master.continuous and master.machine.nthreads == 1; } - void run(GC::Program>& program); + void run(GC::Program& program); void run(Player& P, bool continuous); void post_run(); void send(Player& P); void process_receiver_inputs(); - const Key& get_delta() { return master.delta; } + Key get_delta() { return master.get_delta(); } void store_gate(const YaoGate& gate); - int get_n_worker_threads() - { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } diff --git a/Yao/YaoGate.cpp b/Yao/YaoGate.cpp index 9c6dd6279..0163a144a 100644 --- a/Yao/YaoGate.cpp +++ b/Yao/YaoGate.cpp @@ -9,7 +9,7 @@ #include "BMR/prf.h" #include "Tools/MMO.h" -YaoGate::YaoGate(const YaoGarbleWire& out, const YaoGarbleWire& left, +YaoFullGate::YaoFullGate(const YaoGarbleWire& out, const YaoGarbleWire& left, const YaoGarbleWire& right, Function func) { const Key& delta = YaoGarbler::s().get_delta(); @@ -18,25 +18,26 @@ YaoGate::YaoGate(const YaoGarbleWire& out, const YaoGarbleWire& left, for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) hashes[2 * i + j] = mmo.hash( - E_input(left.key ^ (i ? delta : 0), - right.key ^ (j ? delta : 0), + E_input(left.key() ^ (i ? delta : 0), + right.key() ^ (j ? delta : 0), YaoGarbler::s().get_gate_id())); - garble(out, hashes, left.mask, right.mask, func, delta); + garble(out, hashes, left.mask(), right.mask(), func, delta); #ifdef DEBUG - cout << "left " << left.mask << " " << left.key << " " << (left.key ^ delta) << endl; - cout << "right " << right.mask << " " << right.key << " " << (right.key ^ delta) << endl; - cout << "out " << out.mask << " " << out.key << " " << (out.key ^ delta) << endl; + cout << "left " << left.mask() << " " << left.key() << " " << (left.key() ^ delta) << endl; + cout << "right " << right.mask() << " " << right.key() << " " << (right.key() ^ delta) << endl; + cout << "out " << out.mask() << " " << out.key() << " " << (out.key() ^ delta) << endl; #endif } -void YaoGate::eval(YaoEvalWire& out, const YaoEvalWire& left, const YaoEvalWire& right) +void YaoFullGate::eval(YaoEvalWire& out, const YaoEvalWire& left, const YaoEvalWire& right) { MMO& mmo = YaoEvaluator::s().mmo; - Key key = E_input(left.key, right.key, YaoEvaluator::s().get_gate_id()); - eval(out, mmo.hash(key), get_entry(left.external, right.external)); + Key key = E_input(left.key(), right.key(), YaoEvaluator::s().get_gate_id()); + Key hash = mmo.hash(key); + eval(out, &hash, left, right); #ifdef DEBUG - cout << "external " << left.external << " " << right.external << endl; - cout << "entry " << get_entry(left.external, right.external) << endl; + cout << "external " << left.external() << " " << right.external() << endl; + cout << "entry " << get_entry(left.external(), right.external()) << endl; cout << "out " << out.key << endl; #endif } diff --git a/Yao/YaoGate.h b/Yao/YaoGate.h index 40e30a12b..06264982c 100644 --- a/Yao/YaoGate.h +++ b/Yao/YaoGate.h @@ -10,30 +10,47 @@ #include "BMR/Key.h" #include "YaoGarbleWire.h" #include "YaoEvalWire.h" +#include "YaoHalfGate.h" -class YaoGate +class YaoFullGate { Key entries[2][2]; + public: + static const int N_EVAL_HASHES = 1; + static Key E_input(const Key& left, const Key& right, long T); - static void E_inputs(Key* output, const Key& left, const Key& right, + static void E_inputs(Key* output, const YaoGarbleWire& left, + const YaoGarbleWire& right, const Key& left_delta, const Key& right_delta, long T); + static void eval_inputs(Key* out, const Key& left, const Key& right, long T) + { + *out = E_input(left, right, T); + } + static void randomize(YaoGarbleWire& out, PRNG& prng) + { + out.randomize(prng); + } + static Key garble_public_input(bool value, Key) + { + return value; + } - YaoGate() {} - YaoGate(const YaoGarbleWire& out, const YaoGarbleWire& left, + YaoFullGate() {} + YaoFullGate(const YaoGarbleWire& out, const YaoGarbleWire& left, const YaoGarbleWire& right, Function func); - void and_garble(const YaoGarbleWire& out, const Key* hashes, bool left_mask, - bool right_mask, Key delta); + void and_garble(const YaoGarbleWire& out, const Key* hashes, const YaoGarbleWire& left, + const YaoGarbleWire& right, Key delta); void garble(const YaoGarbleWire& out, const Key* hashes, bool left_mask, bool right_mask, Function func, Key delta); void eval(YaoEvalWire& out, const YaoEvalWire& left, const YaoEvalWire& right); - void eval(YaoEvalWire& out, const Key& hash, - const Key& entry); + void eval(YaoEvalWire& out, const Key* hash, const YaoEvalWire& left, + const YaoEvalWire& right); const Key& get_entry(bool left, bool right) { return entries[left][right]; } }; -inline Key YaoGate::E_input(const Key& left, const Key& right, long T) +inline Key YaoFullGate::E_input(const Key& left, const Key& right, long T) { Key res = left.doubling(1) ^ right.doubling(2) ^ T; #ifdef DEBUG @@ -43,11 +60,12 @@ inline Key YaoGate::E_input(const Key& left, const Key& right, long T) return res; } -inline void YaoGate::E_inputs(Key* output, const Key& left, const Key& right, +inline void YaoFullGate::E_inputs(Key* output, const YaoGarbleWire& left, + const YaoGarbleWire& right, const Key& left_delta, const Key& right_delta, long T) { - auto l = left.doubling(1); - auto r = right.doubling(2); + auto l = left.key().doubling(1); + auto r = right.key().doubling(2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -55,22 +73,25 @@ inline void YaoGate::E_inputs(Key* output, const Key& left, const Key& right, ^ (j ? right_delta : 0) ^ T; } -inline void YaoGate::and_garble(const YaoGarbleWire& out, const Key* hashes, - bool left_mask, bool right_mask, Key delta) +__attribute__((always_inline)) +inline void YaoFullGate::and_garble(const YaoGarbleWire& out, const Key* hashes, + const YaoGarbleWire& left, const YaoGarbleWire& right, Key delta) { + bool left_mask = left.mask(); + bool right_mask = right.mask(); #define XX(L, R, O) \ for (int left = 0; left < 2; left++) \ for (int right = 0; right < 2; right++) \ { \ int index = 2 * left + right; \ - Key key = out.key; \ + Key key = out.key(); \ if (((left ^ L) & (right ^ R)) ^ O) \ key += delta; \ key += hashes[index]; \ entries[left][right] = key; \ } #define Y(L, R) \ - if (out.mask) \ + if (out.mask()) \ XX(L, R, true) \ else \ XX(L, R, false) @@ -86,14 +107,14 @@ inline void YaoGate::and_garble(const YaoGarbleWire& out, const Key* hashes, Z(false) } -inline void YaoGate::garble(const YaoGarbleWire& out, const Key* hashes, +inline void YaoFullGate::garble(const YaoGarbleWire& out, const Key* hashes, bool left_mask, bool right_mask, Function func, Key delta) { for (int left = 0; left < 2; left++) for (int right = 0; right < 2; right++) { - Key key = out.key; - if (func.call(left ^ left_mask, right ^ right_mask) ^ out.mask) + Key key = out.key(); + if (func.call(left ^ left_mask, right ^ right_mask) ^ out.mask()) key += delta; #ifdef DEBUG cout << "start key " << key << endl; @@ -112,10 +133,11 @@ inline void YaoGate::garble(const YaoGarbleWire& out, const Key* hashes, #endif } -inline void YaoGate::eval(YaoEvalWire& out, const Key& hash, const Key& entry) +inline void YaoFullGate::eval(YaoEvalWire& out, const Key* hash, + const YaoEvalWire& left, const YaoEvalWire& right) { - Key key = entry; - key -= hash; + Key key = get_entry(left.external(), right.external()); + key -= *hash; #ifdef DEBUG cout << "after left " << key << endl; #endif diff --git a/Yao/YaoHalfGate.cpp b/Yao/YaoHalfGate.cpp new file mode 100644 index 000000000..abce2d69b --- /dev/null +++ b/Yao/YaoHalfGate.cpp @@ -0,0 +1,31 @@ +/* + * YaoHalfGate.cpp + * + */ + +#include "YaoHalfGate.h" +#include "YaoGarbler.h" +#include "YaoEvaluator.h" + +YaoHalfGate::YaoHalfGate(YaoGarbleWire& out, const YaoGarbleWire& left, + const YaoGarbleWire& right, Function function) +{ + for (int i = 0; i < 4; i++) + assert(function[i] == Function(0x0001)[i]); + Key labels[4]; + Key hashes[4]; + E_inputs(labels, left, right, YaoGarbler::s().get_delta().doubling(1), + {}, YaoGarbler::s().counter); + YaoGarbler::s().mmo.hash<4>(hashes, labels); + and_garble(out, hashes, left, right, YaoGarbler::s().get_delta()); +} + +void YaoHalfGate::eval(YaoEvalWire& out, const YaoEvalWire& left, + const YaoEvalWire& right) +{ + Key hashes[2]; + Key labels[2]; + eval_inputs(labels, left.key(), right.key(), YaoEvaluator::s().counter); + YaoEvaluator::s().mmo.hash<2>(hashes, labels); + eval(out, hashes, left, right); +} diff --git a/Yao/YaoHalfGate.h b/Yao/YaoHalfGate.h new file mode 100644 index 000000000..f10a61b4e --- /dev/null +++ b/Yao/YaoHalfGate.h @@ -0,0 +1,97 @@ +/* + * YaoHalfGate.h + * + */ + +#ifndef YAO_YAOHALFGATE_H_ +#define YAO_YAOHALFGATE_H_ + +#include "BMR/Key.h" +#include "YaoGarbleWire.h" +#include "YaoEvalWire.h" + +class YaoHalfGate +{ + Key TG, TE; + +public: + static const int N_EVAL_HASHES = 2; + + static void eval_inputs(Key* output, const Key& left, const Key& right, + long T); + static void E_inputs(Key* output, const YaoGarbleWire& left, + const YaoGarbleWire& right, const Key& left_delta, + const Key& right_delta, long T); + static void randomize(YaoGarbleWire&, PRNG&) {} + static Key garble_public_input(bool value, Key delta) + { + return value ? delta : 0; + } + + YaoHalfGate() {} + YaoHalfGate(YaoGarbleWire&, const YaoGarbleWire&, + const YaoGarbleWire&, Function); + void and_garble(YaoGarbleWire& out, const Key* hashes, + const YaoGarbleWire& left, const YaoGarbleWire& right, Key delta); + void garble(const YaoGarbleWire&, const Key*, bool, + bool, Function, Key); + void eval(YaoEvalWire&, const YaoEvalWire&, + const YaoEvalWire&); + void eval(YaoEvalWire& out, const Key* hashes, const YaoEvalWire& left, + const YaoEvalWire& right); +}; + +inline void YaoHalfGate::E_inputs(Key* output, const YaoGarbleWire& left, + const YaoGarbleWire& right, const Key& left_delta, const Key&, long T) +{ + auto l = left.full_key().doubling(1); + auto r = right.full_key().doubling(1); + long j = T << 1; + output[0] = l ^ j; + output[1] = output[0] ^ left_delta; + output[2] = r ^ (j + 1); + output[3] = output[2] ^ left_delta; +} + +inline void YaoHalfGate::and_garble(YaoGarbleWire& out, const Key* hashes, + const YaoGarbleWire& left, const YaoGarbleWire& right, Key delta) +{ + bool pa = left.mask(); + bool pb = right.mask(); + TG = hashes[0] ^ hashes[1]; + if (pb) + TG ^= delta; + Key WG = hashes[0]; + if (pa) + WG ^= TG; + TE = hashes[2] ^ hashes[3] ^ left.full_key(); + Key WE = hashes[2]; + if (pb) + WE ^= TE ^ left.full_key(); + out.set_full_key(WG ^ WE); + assert(out.mask() == out.full_key().get_signal()); +} + +inline void YaoHalfGate::eval_inputs(Key* output, const Key& left, + const Key& right, long T) +{ + long j = T << 1; + output[0] = left.doubling(1) ^ j; + output[1] = right.doubling(1) ^ (j + 1); +} + +inline void YaoHalfGate::eval(YaoEvalWire& out, const Key* hashes, + const YaoEvalWire& left, const YaoEvalWire& right) +{ + bool sa = left.external(); + bool sb = right.external(); + Key WG = hashes[0]; + if (sa) + WG ^= TG; + Key WE = hashes[1]; + if (sb) + WE ^= TE ^ left.key(); + out.set(WG ^ WE); +} + +#endif /* YAO_YAOHALFGATE_H_ */ diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index a2d32af45..2a7ef170f 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -94,6 +94,8 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) if (my_num == 1) ((YaoEvalMaster*)master)->machine.write_memory(0); + + delete master; } YaoPlayer::~YaoPlayer() diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h new file mode 100644 index 000000000..e6b8f61a9 --- /dev/null +++ b/Yao/YaoWire.h @@ -0,0 +1,30 @@ +/* + * YaoWire.h + * + */ + +#ifndef YAO_YAOWIRE_H_ +#define YAO_YAOWIRE_H_ + +#include "BMR/Key.h" +#include "BMR/Register.h" + +class YaoWire : public Phase +{ +protected: + Key key_; + +public: + template + static void xors(GC::Processor& processor, const vector& args); + template + static void xors(GC::Processor& processor, const vector& args, + size_t start, size_t end); + + void XOR(const YaoWire& left, const YaoWire& right) + { + key_ = left.key_ ^ right.key_; + } +}; + +#endif /* YAO_YAOWIRE_H_ */ diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp new file mode 100644 index 000000000..bb3b14068 --- /dev/null +++ b/Yao/YaoWire.hpp @@ -0,0 +1,49 @@ +/* + * YaoWire.hpp + * + */ + +#ifndef YAO_YAOWIRE_HPP_ +#define YAO_YAOWIRE_HPP_ + +#include "YaoWire.h" + +template +void YaoWire::xors(GC::Processor& processor, const vector& args) +{ + size_t threshold = 1024; + if (args.size() / 4 < threshold) + { + processor.xor_timer.start(); + processor.xors(args); + processor.xor_timer.stop(); + return; + } + + processor.xor_timer.start(); + + auto& party = T::part_type::Party::s(); + size_t start = 0; + int batch = args.size() / 4 / (party.get_n_worker_threads() + 1); + for (int i = 0; i < party.get_n_worker_threads(); i++) + { + size_t end = start + batch * 4; + party.jobs.at(i)->dispatch(YAO_XOR_JOB, processor, args, start, end, + 0, 0, 0, 0); + start = end; + } + assert(start <= args.size()); + xors(processor, args, start, args.size()); + party.wait(party.get_n_worker_threads()); + + processor.xor_timer.stop(); +} + +template +void YaoWire::xors(GC::Processor& processor, const vector& args, + size_t start, size_t end) +{ + processor.xors(args, start, end); +} + +#endif /* YAO_YAOWIRE_HPP_ */ diff --git a/Yao/config.h b/Yao/config.h index af86d30bf..237469c33 100644 --- a/Yao/config.h +++ b/Yao/config.h @@ -10,4 +10,15 @@ //#define CHECK_BUFFER +#define HALF_GATES + +class YaoFullGate; +class YaoHalfGate; + +#ifdef HALF_GATES +typedef YaoHalfGate YaoGate; +#else +typedef YaoFullGate YaoGate; +#endif + #endif /* YAO_CONFIG_H_ */ diff --git a/doc/Compiler.rst b/doc/Compiler.rst index 4f9a44f0d..e48ac4451 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -56,6 +56,10 @@ Compiler.ml module ------------------------- .. automodule:: Compiler.ml + :members: + :no-undoc-members: + :exclude-members: Adam, Tensor +.. autofunction:: approx_sigmoid Compiler.circuit module -----------------------