diff --git a/BMR/Register.h b/BMR/Register.h index e648f5c73..7295c637b 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -233,6 +233,8 @@ class Phase template static void ands(T& processor, const vector& args) { processor.ands(args); } template + static void xors(T& processor, const vector& args) { processor.xors(args); } + template static void inputb(T& processor, const vector& args) { processor.input(args); } template static T get_input(int from, GC::Processor& processor, int n_bits) diff --git a/BMR/common.h b/BMR/common.h index cb18cc685..7de2588ba 100644 --- a/BMR/common.h +++ b/BMR/common.h @@ -11,6 +11,8 @@ #include using namespace std; +#include "Tools/CheckVector.h" + typedef unsigned long wire_id_t; typedef unsigned long gate_id_t; typedef unsigned int party_id_t; @@ -37,20 +39,4 @@ class Function { bool call(bool left, bool right) { return rep[2 * left + right]; } }; -template -class CheckVector : public vector -{ -public: - CheckVector() : vector() {} - CheckVector(size_t size) : vector(size) {} - CheckVector(size_t size, const T& def) : vector(size, def) {} -#ifdef CHECK_SIZE - T& operator[](size_t i) { return this->at(i); } - const T& operator[](size_t i) const { return this->at(i); } -#else - T& at(size_t i) { return (*this)[i]; } - const T& at(size_t i) const { return (*this)[i]; } -#endif -}; - #endif /* CIRCUIT_INC_COMMON_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index 91efc1445..d9072297a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.5 (Mar 20, 2020) + +- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338) +- Optimized daBits +- Optimized logistic regression +- Faster compilation of repetitive code (compiler option `-C`) +- ChaiGear: [HighGear](https://eprint.iacr.org/2017/1230) with covert key generation +- [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs +- Binary computation based on Shamir secret sharing +- Fixed security bug: Prove correctness of ciphertexts in input tuple generation +- Fixed security bug: Missing check in MASCOT bit generation and various binary computations + ## 0.1.4 (Dec 23, 2019) - Mixed circuit computation with secret sharing diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index d3ed17c53..c48ae357c 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -36,6 +36,8 @@ class ClearBitsAF(base.RegisterArgFormat): STMSBI = 0x243, MOVSB = 0x244, INPUTB = 0x246, + SPLIT = 0x248, + CONVCBIT2S = 0x249, XORCBI = 0x210, BITDECC = 0x211, CONVCINT = 0x213, @@ -49,15 +51,23 @@ class ClearBitsAF(base.RegisterArgFormat): MULCBI = 0x21c, SHRCBI = 0x21d, SHLCBI = 0x21e, + CONVCINTVEC = 0x21f, PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, PRINTREGPLAINB = 0x222, PRINTFLOATPLAINB = 0x223, CONDPRINTSTRB = 0x224, CONVCBIT = 0x230, + CONVCBITVEC = 0x231, ) -class xors(base.Instruction): +class BinaryVectorInstruction(base.Instruction): + is_vec = lambda self: True + + def copy(self, size, subs): + return type(self)(*self.get_new_args(size, subs)) + +class xors(BinaryVectorInstruction): code = opcodes['XORS'] arg_format = tools.cycle(['int','sbw','sb','sb']) @@ -73,15 +83,21 @@ class xorcbi(base.Instruction): code = opcodes['XORCBI'] arg_format = ['cbw','cb','int'] -class andrs(base.Instruction): +class andrs(BinaryVectorInstruction): code = opcodes['ANDRS'] arg_format = tools.cycle(['int','sbw','sb','sb']) -class ands(base.Instruction): + def add_usage(self, req_node): + req_node.increment(('bit', 'triple'), sum(self.args[::4])) + +class ands(BinaryVectorInstruction): code = opcodes['ANDS'] arg_format = tools.cycle(['int','sbw','sb','sb']) -class andm(base.Instruction): + def add_usage(self, req_node): + req_node.increment(('bit', 'triple'), sum(self.args[::4])) + +class andm(BinaryVectorInstruction): code = opcodes['ANDM'] arg_format = ['int','sbw','sb','cb'] @@ -181,6 +197,31 @@ class convcbit(base.Instruction): code = opcodes['CONVCBIT'] arg_format = ['ciw','cb'] +@base.vectorize +class convcintvec(base.Instruction): + code = opcodes['CONVCINTVEC'] + arg_format = tools.chain(['c'], tools.cycle(['cbw'])) + +class convcbitvec(BinaryVectorInstruction): + code = opcodes['CONVCBITVEC'] + arg_format = ['int','ciw','cb'] + def __init__(self, *args): + super(convcbitvec, self).__init__(*args) + assert(args[2].n == args[0]) + args[1].set_size(args[0]) + +class convcbit2s(BinaryVectorInstruction): + code = opcodes['CONVCBIT2S'] + arg_format = ['int','sbw','cb'] + +@base.vectorize +class split(base.Instruction): + code = opcodes['SPLIT'] + arg_format = tools.chain(['int','s'], tools.cycle(['sbw'])) + def __init__(self, *args, **kwargs): + super(split_class, self).__init__(*args, **kwargs) + assert (len(args) - 2) % args[0] == 0 + class movsb(base.Instruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] @@ -196,9 +237,9 @@ class bitb(base.Instruction): code = opcodes['BITB'] arg_format = ['sbw'] -class reveal(base.Instruction): +class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable): code = opcodes['REVEAL'] - arg_format = ['int','cbw','sb'] + arg_format = tools.cycle(['int','cbw','sb']) class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): __slots__ = [] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index a6d01a5fb..968edae85 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -1,15 +1,17 @@ from Compiler.types import MemValue, read_mem_value, regint, Array, cint -from Compiler.types import _bitint, _number, _fix, _structure, _bit +from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint from Compiler.program import Tape, Program from Compiler.exceptions import * from Compiler import util, oram, floatingpoint, library +from Compiler import instructions_base import Compiler.GC.instructions as inst import operator +import math from functools import reduce class bits(Tape.Register, _structure, _bit): n = 40 - size = 1 + unit = 64 PreOp = staticmethod(floatingpoint.PreOpN) decomposed = None @staticmethod @@ -19,9 +21,7 @@ def PreOR(l): [1 - x for x in l])] @classmethod def get_type(cls, length): - if length is None: - return cls - elif length == 1: + if length == 1: return cls.bit_type if length not in cls.types: class bitsn(cls): @@ -65,6 +65,11 @@ def bit_decompose(self, bit_length=None): return res + suffix else: return self.decomposed[:n] + suffix + @staticmethod + def bit_decompose_clear(a, n_bits): + res = [cbits.get_type(a.size)() for i in range(n_bits)] + cbits.conv_cint_vec(a, *res) + return res @classmethod def malloc(cls, size): return Program.prog.malloc(size, cls) @@ -87,41 +92,61 @@ def store_in_mem(self, address): def __init__(self, value=None, n=None, size=None): if size != 1 and size is not None: raise Exception('invalid size for bit type: %s' % size) - Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape) - self.set_length(n or self.n) + self.n = n or self.n + size = math.ceil(self.n / self.unit) if self.n != None else None + Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape, + size=size) if value is not None: self.load_other(value) + def copy(self): + return type(self)(n=instructions_base.get_global_vector_size()) def set_length(self, n): if n > self.max_length: print(self.max_length) raise Exception('too long: %d' % n) self.n = n + def set_size(self, size): + pass def load_other(self, other): if isinstance(other, cint): - size = other.size - other = sum(x << i for i, x in enumerate(other)) - other = other.to_regint(size) - if isinstance(other, int): + assert(self.n == other.size) + self.conv_regint_by_bit(self.n, self, other.to_regint(1)) + elif isinstance(other, int): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): - assert(other.size == 1) - self.conv_regint(self.n, self, other) + assert(other.size == math.ceil(self.n / self.unit)) + for i, (x, y) in enumerate(zip(self, other)): + self.conv_regint(min(self.unit, self.n - i * self.unit), x, y) elif isinstance(self, type(other)) or isinstance(other, type(self)): - self.mov(self, other) + assert(self.n == other.n) + for i in range(math.ceil(self.n / self.unit)): + self.mov(self[i], other[i]) else: try: other = self.bit_compose(other.bit_decompose()) - self.mov(self, other) + self.load_other(other) except: raise CompilerError('cannot convert from %s to %s' % \ (type(other), type(self))) def long_one(self): - return 2**self.n - 1 + return 2**self.n - 1 if self.n != None else None def __repr__(self): - return '%s(%d/%d)' % \ - (super(bits, self).__repr__(), self.n, type(self).n) + if self.n != None: + suffix = '%d' % self.n + if type(self).n != None and type(self).n != self.n: + suffice += '/%d' % type(self).n + else: + suffix = 'undef' + return '%s(%s)' % (super(bits, self).__repr__(), suffix) __str__ = __repr__ + def _new_by_number(self, i, size=1): + assert(size == 1) + n = min(self.unit, self.n - (i - self.i) * self.unit) + res = self.get_type(n)() + res.i = i + res.program = self.program + return res class cbits(bits): max_length = 64 @@ -131,6 +156,12 @@ class cbits(bits): store_inst = (None, inst.stmcb) bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) + conv_cint_vec = inst.convcintvec + @classmethod + def conv_regint_by_bit(cls, n, res, other): + assert n == res.n + assert n == other.size + cls.conv_cint_vec(cint(other, size=other.size), res) types = {} def load_int(self, value): self.load_other(regint(value)) @@ -187,6 +218,13 @@ def to_regint(self, dest): if self.n > 64: raise CompilerError('too many bits') inst.convcbit(dest, self) + def to_regint_by_bit(self): + if self.n != None: + res = regint(size=self.n) + else: + res = regint() + inst.convcbitvec(self.n, res, self) + return res class sbits(bits): max_length = 128 @@ -199,6 +237,11 @@ class sbits(bits): bitdec = inst.bitdecs bitcom = inst.bitcoms conv_regint = inst.convsint + @classmethod + def conv_regint_by_bit(cls, n, res, other): + tmp = cbits.get_type(n)() + tmp.conv_regint_by_bit(n, tmp, other) + res.load_other(tmp) mov = inst.movsb types = {} def __init__(self, *args, **kwargs): @@ -250,18 +293,26 @@ def load_int(self, value): self.mov(self, lower + (upper << 64)) else: raise NotImplementedError('more than 128 bits wanted') + def load_other(self, other): + if isinstance(other, cbits) and self.n == other.n: + inst.convcbit2s(self.n, self, other) + else: + super(sbits, self).load_other(other) @read_mem_value def __add__(self, other): - if isinstance(other, int): + if isinstance(other, int) or other is None: return self.xor_int(other) else: if not isinstance(other, sbits): other = self.conv(other) - n = min(self.n, other.n) + if self.n is None or other.n is None: + assert self.n == other.n + n = None + else: + n = min(self.n, other.n) res = self.new(n=n) inst.xors(n, res, self, other) - max_n = max(self.n, other.n) - if max_n > n: + if self.n != None and max(self.n, other.n) > n: if self.n > n: longer = self else: @@ -293,17 +344,13 @@ def __mul__(self, other): return res except AttributeError: return NotImplemented - @read_mem_value - def __rmul__(self, other): - if isinstance(other, cbits): - return other * self - else: - return self.mul_int(other) + __rmul__ = __mul__ @read_mem_value def __and__(self, other): if util.is_zero(other): return 0 - elif util.is_all_ones(other, self.n): + elif util.is_all_ones(other, self.n) or \ + (other is None and self.n == None): return self res = self.new(n=self.n) if not isinstance(other, sbits): @@ -314,6 +361,7 @@ def __and__(self, other): assert(self.n == other.n) inst.ands(self.n, res, self, other) return res + __rand__ = __and__ def xor_int(self, other): if other == 0: return self @@ -326,6 +374,7 @@ def xor_int(self, other): for x,y in zip(self_bits, other_bits)] \ + extra_bits) def mul_int(self, other): + assert(util.is_constant(other)) if other == 0: return 0 elif other == 1: @@ -344,14 +393,19 @@ def __invert__(self): # res = type(self)(n=self.n) # inst.nots(res, self) # return res - one = self.new(value=self.long_one(), n=self.n) + if self.n == None or self.n > self.unit: + one = self.get_type(self.n)() + self.conv_regint_by_bit(self.n, one, regint(1, size=self.n)) + else: + one = self.new(value=self.long_one(), n=self.n) return self + one def __neg__(self): return self def reveal(self): - if self.n > self.clear_type.max_length: - raise Exception('too long to reveal') - res = self.clear_type(n=self.n) + if self.n == None or \ + self.n > max(self.max_length, self.clear_type.max_length): + assert(self.unit == self.clear_type.unit) + res = self.clear_type.get_type(self.n)() inst.reveal(self.n, res, self) return res def equal(self, other, n=None): @@ -395,8 +449,11 @@ def if_else(self, x, y): @staticmethod def bit_adder(*args, **kwargs): return sbitint.bit_adder(*args, **kwargs) + @staticmethod + def ripple_carry_adder(*args, **kwargs): + return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(object): +class sbitvec(_vec): @classmethod def get_type(cls, n): return cls @@ -414,8 +471,27 @@ def combine(cls, vectors): def from_matrix(cls, matrix): # any number of rows, limited number of columns return cls.combine(cls(row) for row in matrix) - def __init__(self, elements=None): - if elements is not None: + def __init__(self, elements=None, length=None): + if length: + assert isinstance(elements, sint) + if Program.prog.use_split(): + n = Program.prog.use_split() + columns = [[sbits.get_type(elements.size)() + for i in range(n)] for i in range(length)] + inst.split(n, elements, *sum(columns, [])) + x = sbitint.wallace_tree_without_finish(columns, False) + v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True) + else: + assert Program.prog.options.ring + l = int(Program.prog.options.ring) + r, r_bits = sint.get_edabit(length, size=elements.size) + c = ((elements - r) << (l - length)).reveal() + c >>= l - length + cb = [(c >> i) for i in range(length)] + x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb) + v = x.v + self.v = v[:length] + elif elements is not None: self.v = sbits.trans(elements) def popcnt(self): res = sbitint.wallace_tree([[b] for b in self.v]) @@ -426,8 +502,15 @@ def elements(self, start=None, stop=None): if stop is None: start, stop = stop, start return sbits.trans(self.v[start:stop]) + def coerce(self, other): + if isinstance(other, cint): + size = other.size + return (other.get_vector(base, min(64, size - base)) \ + for base in range(0, size, 64)) + return other def __xor__(self, other): - return self.from_vec(x ^ y for x, y in zip(self.v, other.v)) + other = self.coerce(other) + return self.from_vec(x ^ y for x, y in zip(self.v, other)) def __and__(self, other): return self.from_vec(x & y for x, y in zip(self.v, other.v)) def if_else(self, x, y): @@ -453,6 +536,26 @@ def store_in_mem(self, address): def bit_decompose(self): return self.v bit_compose = from_vec + def reveal(self): + assert len(self) == 1 + return self.v[0].reveal() + def long_one(self): + return [x.long_one() for x in self.v] + def __rsub__(self, other): + return self.from_vec(y - x for x, y in zip(self.v, other)) + def half_adder(self, other): + other = self.coerce(other) + res = zip(*(x.half_adder(y) for x, y in zip(self.v, other))) + return (self.from_vec(x) for x in res) + def __mul__(self, other): + if isinstance(other, int): + return self.from_vec(x * other for x in self.v) + def __add__(self, other): + return self.from_vec(x + y for x, y in zip(self.v, other)) + def bit_and(self, other): + return self & other + def bit_xor(self, other): + return self ^ other class bit(object): n = 1 diff --git a/Compiler/allocator.py b/Compiler/allocator.py index df0ef7be2..0f570e6d2 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -16,12 +16,13 @@ class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" - def __init__(self, n): + def __init__(self, n, program): self.alloc = dict_by_id() self.usage = Compiler.program.RegType.create_dict(lambda: 0) self.defined = dict_by_id() self.dealloc = set_by_id() self.n = n + self.program = program def alloc_reg(self, reg, free): base = reg.vectorbase @@ -76,7 +77,8 @@ def process(self, program, alloc_pool): # unused register self.alloc_reg(j, alloc_pool) unused_regs.append(j) - if unused_regs and len(unused_regs) == len(list(i.get_def())): + if unused_regs and len(unused_regs) == len(list(i.get_def())) and \ + self.program.verbose: # only report if all assigned registers are unused print("Register(s) %s never used, assigned by '%s' in %s" % \ (unused_regs,i,format_trace(i.caller))) @@ -175,37 +177,8 @@ def do_merge(self, merges_iter): except StopIteration: return mergecount, None - def expand_vector_args(inst): - if inst.is_vec(): - for arg in inst.args: - arg.create_vector_elements() - res = sum(list(zip(*inst.args)), ()) - return list(res) - else: - return inst.args - for i in merges_iter: - if isinstance(instructions[n], startinput_class): - instructions[n].args[1] += instructions[i].args[1] - elif isinstance(instructions[n], (stopinput, gstopinput)): - if instructions[n].get_size() != instructions[i].get_size(): - raise NotImplemented() - else: - instructions[n].args += instructions[i].args[1:] - else: - if instructions[n].get_size() != instructions[i].get_size(): - # merge as non-vector instruction - instructions[n].args = expand_vector_args(instructions[n]) + \ - expand_vector_args(instructions[i]) - if instructions[n].is_vec(): - instructions[n].size = 1 - else: - instructions[n].args += instructions[i].args - - # join arg_formats if not special iterators - # if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \ - # not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)): - # instructions[n].arg_format += instructions[i].arg_format + instructions[n].merge(instructions[i]) instructions[i] = None self.merge_nodes(n, i) mergecount += 1 @@ -343,7 +316,7 @@ def longest_paths_merge(self): merge = merges[i] t = type(self.instructions[merge[0]]) self.counter[t] += len(merge) - if len(merge) > 1000: + if len(merge) > 10000: print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) @@ -504,7 +477,8 @@ def keep_order(instr, n, t, arg_index=None): next_available_depth[type(instr), d] = depth round_type[depth] = instr.merge_id() - parallel_open[depth] += len(instr.args) * instr.get_size() + if int(options.max_parallel_open) > 0: + parallel_open[depth] += len(instr.args) * instr.get_size() depths[n] = depth if isinstance(instr, ReadMemoryInstruction): @@ -557,8 +531,9 @@ def keep_order(instr, n, t, arg_index=None): print("Processed dependency of %d/%d instructions at" % \ (n, len(block.instructions)), time.asctime()) - if len(open_nodes) > 1000: - print("Program has %d %s instructions" % (len(open_nodes), merge_classes)) + if len(open_nodes) > 1000 and self.block.parent.program.verbose: + print("Basic block has %d %s instructions" % + (len(open_nodes), merge_classes)) def merge_nodes(self, i, j): """ Merge node j into i, removing node j """ @@ -608,7 +583,7 @@ def eliminate(i): eliminate(list(G.pred[i])[0]) eliminate(i) count += 2 - if count > 0: + if count > 0 and self.block.parent.program.verbose: print('Eliminated %d dead instructions, among which %d opens: %s' \ % (count, open_count, dict(stats))) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 62834887d..8af1bdec8 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -69,13 +69,31 @@ def divide_by_two(res, x, m=1): inv2m(tmp, m) mulc(res, x, tmp) +@instructions_base.cisc def LTZ(s, a, k, kappa): """ s = (a ?< 0) k: bit length of a """ - from .types import sint + from .types import sint, _bitint + from .GC.types import sbitvec + if program.use_split(): + movs(s, sint.conv(sbitvec(a, k).v[-1])) + return + elif program.options.ring: + from . import floatingpoint + assert(int(program.options.ring) >= k) + m = k - 1 + shift = int(program.options.ring) - k + r_prime, r_bin = MaskingBitsInRing(k) + tmp = a - r_prime + c_prime = (tmp << shift).reveal() >> shift + a = r_bin[0].bit_decompose_clear(c_prime, m) + b = r_bin[:m] + u = CarryOutRaw(a[::-1], b[::-1]) + movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))) + return t = sint() Trunc(t, a, k, k - 1, kappa, True) subsfi(s, t, 0) @@ -86,6 +104,7 @@ def LessThanZero(a, k, kappa): LTZ(res, a, k, kappa) return res +@instructions_base.cisc def Trunc(d, a, k, m, kappa, signed): """ d = a >> m @@ -120,7 +139,7 @@ def TruncRing(d, a, k, m, signed): movs(d, res) return res -def TruncZeroes(a, k, m, signed): +def TruncZeros(a, k, m, signed): if program.options.ring: return TruncLeakyInRing(a, k, m, signed) else: @@ -139,9 +158,8 @@ def TruncLeakyInRing(a, k, m, signed): from .types import sint, intbitint, cint, cgf2n n_bits = k - m n_shift = int(program.options.ring) - n_bits - if program.use_dabit and n_bits > 1: - r, r_bits = zip(*(sint.get_dabit() for i in range(n_bits))) - r = sint.bit_compose(r) + if n_bits > 1: + r, r_bits = MaskingBitsInRing(n_bits, True) else: r_bits = [sint.get_random_bit() for i in range(n_bits)] r = sint.bit_compose(r_bits) @@ -150,7 +168,7 @@ def TruncLeakyInRing(a, k, m, signed): shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() masked = shifted >> n_shift u = sint() - BitLTL(u, masked, r_bits, 0) + BitLTL(u, masked, r_bits[:n_bits], 0) res = (u << n_bits) + masked - r if signed: res -= (1 << (n_bits - 1)) @@ -174,6 +192,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed) return res +@instructions_base.cisc def Mod2m(a_prime, a, k, m, kappa, signed): """ a_prime = a % 2^m @@ -199,16 +218,11 @@ def Mod2mRing(a_prime, a, k, m, signed): assert(int(program.options.ring) >= k) from Compiler.types import sint, intbitint, cint shift = int(program.options.ring) - m - if program.use_dabit: - r, r_bin = zip(*(sint.get_dabit() for i in range(m))) - else: - r = [sint.get_random_bit() for i in range(m)] - r_bin = r - r_prime = sint.bit_compose(r) + r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime c_prime = (tmp << shift).reveal() >> shift u = sint() - BitLTL(u, c_prime, r_bin, 0) + BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime if a_prime is not None: movs(a_prime, res) @@ -247,19 +261,35 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): adds(a_prime, t[5], t[4]) return r_dprime, r_prime, c, c_prime, u, t, c2k1 +def MaskingBitsInRing(m, strict=False): + from Compiler.types import sint + if program.use_edabit(): + return sint.get_edabit(m, strict) + elif program.use_dabit: + r, r_bin = zip(*(sint.get_dabit() for i in range(m))) + else: + r = [sint.get_random_bit() for i in range(m)] + r_bin = r + return sint.bit_compose(r), r_bin + def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1] r_prime = random secret integer in range [0, 2^m - 1] b = array containing bits of r_prime """ + program.curr_tape.require_bit_length(k + kappa) + from .types import sint + if program.use_edabit() and m > 1: + movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) + tmp, b[:] = sint.get_edabit(m, True) + movs(r_prime, tmp) + return t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)] t[0][1] = b[-1] PRandInt(r_dprime, k + kappa - m) # r_dprime is always multiplied by 2^m - program.curr_tape.require_bit_length(k + kappa) if use_dabit and program.use_dabit and m > 1: - from .types import sint r, b[:] = zip(*(sint.get_dabit() for i in range(m))) r = sint.bit_compose(r) movs(r_prime, r) @@ -389,17 +419,25 @@ def CarryOut(res, a, b, c=0, kappa=None): b: array of secret bits (same length as a) c: initial carry-in bit """ + from .types import sint + movs(res, sint.conv(CarryOutRaw(a, b, c))) + +def CarryOutRaw(a, b, c=0): + assert len(a) == len(b) k = len(a) from . import types d = [program.curr_block.new_reg('s') for i in range(k)] s = [program.curr_block.new_reg('s') for i in range(3)] for i in range(k): d[i] = list(b[i].half_adder(a[i])) - s[0] = d[-1][0] * c + s[0] = d[-1][0].bit_and(c) s[1] = d[-1][1] + s[0] d[-1][1] = s[1] - - movs(res, types.sint.conv(CarryOutAux(d[::-1], kappa))) + return CarryOutAux(d[::-1], None) + +def CarryOutRawLE(a, b, c=0): + """ Little-endian version """ + return CarryOutRaw(a[::-1], b[::-1], c) def CarryOutLE(a, b, c=0): """ Little-endian version """ @@ -416,13 +454,12 @@ def BitLTL(res, a, b, kappa): b: array of secret bits (same length as a) """ k = len(b) - from . import floatingpoint - a_bits = floatingpoint.bits(a, k) + a_bits = b[0].bit_decompose_clear(a, k) s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)] t = [program.curr_block.new_reg('s') for i in range(1)] for i in range(len(b)): s[0][i] = b[0].long_one() - b[i] - CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa) + CarryOut(t[0], a_bits[::-1], s[0][::-1], b[0].long_one(), kappa) subsfi(res, t[0], 1) return a_bits, s[0] diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 7a9fcbe9a..fca698a3b 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -71,9 +71,10 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ if prog.main_thread_running: prog.update_req(prog.curr_tape) - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) + if prog.verbose: + print('Program requires:', repr(prog.req_num)) + print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) + print('Memory size:', dict(prog.allocated_mem)) # finalize the memory prog.finalize_memory() diff --git a/Compiler/config.py b/Compiler/config.py index a3518c5f6..680ad64c9 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -17,22 +17,16 @@ P_VALUES[-1] = P_VALUES[128] -BIT_LENGTHS = { -1: 32, +BIT_LENGTHS = { -1: 64, 32: 16, 64: 16, 128: 64, 256: 64, 512: 64 } -STAT_SEC = { -1: 6, - 32: 6, - 64: 30, - 128: 40, - 256: 40, - 512: 40 } - -COST = { 'modp': defaultdict(lambda: 0, +COST = defaultdict(lambda: defaultdict(lambda: 0), + { 'modp': defaultdict(lambda: 0, { 'triple': 0.00020652622883106154, 'square': 0.00020652622883106154, 'bit': 0.00020652622883106154, @@ -51,7 +45,7 @@ 'all': { 'round': 0, 'inv': 0, } -} +}) try: diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index e49fc456f..530c6440e 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -4,6 +4,7 @@ from . import comparison from . import program from . import util +from . import instructions_base ## ## Helper functions for floating point arithmetic @@ -50,13 +51,14 @@ def maskField(a, k, kappa): asm_open(c, a + two_power(k) * r_dprime + r_prime)# + 2**(k-1)) return c, r +@instructions_base.ret_cisc def EQZ(a, k, kappa): if program.Program.prog.options.ring: c, r = maskRing(a, k) else: c, r = maskField(a, k, kappa) d = [None]*k - for i,b in enumerate(bits(c, k)): + for i,b in enumerate(r[0].bit_decompose_clear(c, k)): d[i] = r[i].bit_xor(b) return 1 - types.sint.conv(KOR(d, kappa)) @@ -299,6 +301,7 @@ def BitDec(a, k, m, kappa, bits_to_compute=None): def BitDecRing(a, k, m): n_shift = int(program.Program.prog.options.ring) - m + assert(n_shift >= 0) if program.Program.prog.use_dabit: r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) r = types.sint.bit_compose(r) @@ -328,10 +331,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): print('BitDec assertion failed') print('a =', a.value) print('a mod 2^%d =' % k, (a.value % 2**k)) - res = r[0].bit_adder(r, list(bits(c,m))) + res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) return [types.sint.conv(bit) for bit in res] +@instructions_base.ret_cisc def Pow2(a, l, kappa): m = int(ceil(log(l, 2))) t = BitDec(a, m, m, kappa) @@ -361,10 +365,16 @@ def B2U_from_Pow2(pow2a, l, kappa): for i in range(l): bit(r[i]) r_bits = r - comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l))) - comparison.program.curr_tape.require_bit_length(l + kappa) - c = list(bits(c, l)) + if program.Program.prog.options.ring: + n_shift = int(program.Program.prog.options.ring) - l + assert n_shift > 0 + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + else: + comparison.PRandInt(t, kappa) + asm_open(c, pow2a + two_power(l) * t + + sum(two_power(i) * r[i] for i in range(l))) + comparison.program.curr_tape.require_bit_length(l + kappa) + c = list(r_bits[0].bit_decompose_clear(c, l)) x = [r_bits[i].bit_xor(c[i]) for i in range(l)] #print ' '.join(str(b.value) for b in x) y = PreOR(x, kappa) @@ -402,10 +412,14 @@ def Trunc(a, l, m, kappa, compute_modulo=False, signed=False): r_prime += t2 r_dprime += t1 - t2 #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P)) - comparison.PRandInt(rk, kappa) - r_dprime += two_power(l) * rk - #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P)) - asm_open(c, a + r_dprime + r_prime) + if program.Program.prog.options.ring: + n_shift = int(program.Program.prog.options.ring) - l + c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + else: + comparison.PRandInt(rk, kappa) + r_dprime += two_power(l) * rk + #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P)) + asm_open(c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) #assert(ci[i].value == c.value % 2**i) @@ -439,6 +453,12 @@ def TruncInRing(to_shift, l, pow2m): bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) +def SplitInRing(a, l, m): + pow2m = Pow2(m, l, None) + upper = TruncInRing(a, l, pow2m) + lower = a - upper * pow2m + return lower, upper, pow2m + def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): t = comparison.TruncRoundNearest(a, length, length - target_length, kappa) overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa) @@ -496,6 +516,7 @@ def FLRound(x, mode): p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z) return v, p, z, s +@instructions_base.ret_cisc def TruncPr(a, k, m, kappa=None, signed=True): """ Probabilistic truncation [a/2^m + u] where Pr[u = 1] = (a % 2^m) / 2^m @@ -513,8 +534,11 @@ def TruncPrRing(a, k, m, signed=True): n_ring = int(program.Program.prog.options.ring) assert n_ring >= k, '%d too large' % k if k == n_ring: - for i in range(m): - a += types.sint.get_random_bit() << i + if program.Program.prog.use_edabit(): + a += types.sint.get_edabit(m, True)[0] + else: + for i in range(m): + a += types.sint.get_random_bit() << i return comparison.TruncLeakyInRing(a, k, m, signed=signed) else: from .types import sint @@ -525,13 +549,22 @@ def TruncPrRing(a, k, m, signed=True): trunc_pr(res, a, k, m) else: # extra bit to mask overflow - r_bits = [sint.get_random_bit() for i in range(k + 1)] - n_shift = n_ring - len(r_bits) - tmp = a + sint.bit_compose(r_bits) + if program.Program.prog.use_edabit(): + lower = sint.get_edabit(m, True)[0] + upper = sint.get_edabit(k - m, True)[0] + msb = sint.get_random_bit() + r = (msb << k) + (upper << m) + lower + else: + r_bits = [sint.get_random_bit() for i in range(k + 1)] + r = sint.bit_compose(r_bits) + upper = sint.bit_compose(r_bits[m:k]) + msb = r_bits[-1] + n_shift = n_ring - (k + 1) + tmp = a + r masked = (tmp << n_shift).reveal() shifted = (masked << 1 >> (n_shift + m + 1)) - overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1)) - res = shifted - sint.bit_compose(r_bits[m:k]) + \ + overflow = msb.bit_xor(masked >> (n_ring - 1)) + res = shifted - upper + \ (overflow << (k - m)) if signed: res -= (1 << (k - m - 1)) @@ -555,6 +588,7 @@ def TruncPrField(a, k, m, kappa=None): d = (a - a_prime) / two_to_m return d +@instructions_base.ret_cisc def SDiv(a, b, l, kappa, round_nearest=False): theta = int(ceil(log(l / 3.5) / log(2))) alpha = two_power(2*l) @@ -564,7 +598,7 @@ def SDiv(a, b, l, kappa, round_nearest=False): y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False) x2 = types.sint() comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False) - x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True) + x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True) for i in range(theta-1): y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest, @@ -576,7 +610,7 @@ def SDiv(a, b, l, kappa, round_nearest=False): signed=False) x2 = types.sint() comparison.Mod2m(x2, x, 2 * l, l, kappa, False) - x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True) + x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True) y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest, signed=False) y = y.round(2 * l + 1, l - 1, kappa, round_nearest) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 9fda1346b..fd74fe78c 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -11,8 +11,10 @@ """ import itertools +import operator from . import tools from random import randint +from functools import reduce from Compiler.config import * from Compiler.exceptions import * import Compiler.instructions_base as base @@ -318,6 +320,11 @@ class use_inp(base.Instruction): code = base.opcodes['USE_INP'] arg_format = ['int','int','int'] +class use_edabit(base.Instruction): + r""" edaBit usage. """ + code = base.opcodes['USE_EDABIT'] + arg_format = ['int','int','int'] + class run_tape(base.Instruction): r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """ code = base.opcodes['RUN_TAPE'] @@ -808,7 +815,29 @@ class dabit(base.DataInstruction): code = base.opcodes['DABIT'] arg_format = ['sw', 'sbw'] field_type = 'modp' - data_type = 'bit' + data_type = 'dabit' + +@base.vectorize +class edabit(base.Instruction): + """ edaBit """ + __slots__ = [] + code = base.opcodes['EDABIT'] + arg_format = tools.chain(['sw'], itertools.repeat('sbw')) + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment(('edabit', len(self.args) - 1), self.get_size()) + +@base.vectorize +class sedabit(base.Instruction): + """ strict edaBit """ + __slots__ = [] + code = base.opcodes['SEDABIT'] + arg_format = tools.chain(['sw'], itertools.repeat('sbw')) + field_type = 'modp' + + def add_usage(self, req_node): + req_node.increment(('sedabit', len(self.args) - 1), self.get_size()) @base.gf2n @base.vectorize @@ -988,7 +1017,19 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', self.args[0]), \ self.args[1]) -class stopinput(base.RawInputInstruction): + def merge(self, other): + self.args[1] += other.args[1] + +class StopInputInstruction(base.RawInputInstruction): + __slots__ = [] + + def merge(self, other): + if self.get_size() != other.get_size(): + raise NotImplemented() + else: + self.args += other.args[1:] + +class stopinput(StopInputInstruction): r""" Receive inputs from player $p$ and put in registers. """ __slots__ = [] code = base.opcodes['STOPINPUT'] @@ -997,7 +1038,7 @@ class stopinput(base.RawInputInstruction): def has_var_args(self): return True -class gstopinput(base.RawInputInstruction): +class gstopinput(StopInputInstruction): r""" Receive inputs from player $p$ and put in registers. """ __slots__ = [] code = 0x100 + base.opcodes['STOPINPUT'] @@ -1322,6 +1363,26 @@ class bitdecint(base.Instruction): code = base.opcodes['BITDECINT'] arg_format = tools.chain(['ci'], itertools.repeat('ciw')) +class incint(base.VectorInstruction): + __slots__ = [] + code = base.opcodes['INCINT'] + arg_format = ['ciw', 'ci', 'i', 'i', 'i'] + + def __init__(self, *args, **kwargs): + assert len(args[1]) == 1 + if len(args) == 3: + args = list(args) + [1, len(args[0])] + super(incint, self).__init__(*args, **kwargs) + +class shuffle(base.VectorInstruction): + __slots__ = [] + code = base.opcodes['SHUFFLE'] + arg_format = ['ciw','ci'] + + def __init__(self, *args, **kwargs): + super(shuffle, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + ### ### Clear comparison instructions ### @@ -1429,6 +1490,9 @@ class convmodp(base.Instruction): code = base.opcodes['CONVMODP'] arg_format = ['ciw', 'c', 'int'] def __init__(self, *args, **kwargs): + if len(args) == len(self.arg_format): + super(convmodp_class, self).__init__(*args) + return bitlength = kwargs.get('bitlength') bitlength = program.bit_length if bitlength is None else bitlength if bitlength > 64: @@ -1472,7 +1536,7 @@ def get_repeat(self): def merge_id(self): # can merge different sizes # but not if large - if self.get_size() > 100: + if self.get_size() is None or self.get_size() > 100: return type(self), self.get_size() return type(self) @@ -1561,6 +1625,31 @@ def get_used(self): for reg in self.args[i + 2:i + self.args[i]]: yield reg +class matmul_base(base.DataInstruction): + data_type = 'triple' + is_vec = lambda self: True + + def get_repeat(self): + return reduce(operator.mul, self.args[3:6]) + +class matmuls(matmul_base): + """ Secret matrix multiplication """ + code = base.opcodes['MATMULS'] + arg_format = ['sw','s','s','int','int','int'] + +class matmulsm(matmul_base): + """ Secret matrix multiplication reading directly from memory """ + code = base.opcodes['MATMULSM'] + arg_format = ['sw','ci','ci','int','int','int','ci','ci','ci','ci', + 'int','int'] + + def __init__(self, *args, **kwargs): + matmul_base.__init__(self, *args, **kwargs) + for i in range(2): + assert args[6 + i].size == args[3 + i] + for i in range(2): + assert args[8 + i].size == args[4 + i] + @base.vectorize class trunc_pr(base.VarArgsInstruction): """ Probalistic truncation for semi-honest computation """ diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 8508bab5f..70514eb62 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -8,6 +8,7 @@ from Compiler.config import * from Compiler import util from Compiler import tools +from Compiler import program ### @@ -59,6 +60,7 @@ NPLAYERS = 0xE2, THRESHOLD = 0xE3, PLAYERID = 0xE4, + USE_EDABIT = 0xE5, # Addition ADDC = 0x20, ADDS = 0x21, @@ -93,6 +95,8 @@ MULRS = 0xA7, DOTPRODS = 0xA8, TRUNC_PR = 0xA9, + MATMULS = 0xAA, + MATMULSM = 0xAB, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -103,6 +107,8 @@ INPUTMASK = 0x56, PREP = 0x57, DABIT = 0x58, + EDABIT = 0x59, + SEDABIT = 0x5A, # Input INPUT = 0x60, INPUTFIX = 0xF0, @@ -153,6 +159,8 @@ MULINT = 0x9D, DIVINT = 0x9E, PRINTINT = 0x9F, + INCINT = 0xD1, + SHUFFLE = 0xD2, # Conversion CONVINT = 0xC0, CONVMODP = 0xC1, @@ -235,13 +243,17 @@ class Vectorized_Instruction(instruction): def __init__(self, size, *args, **kwargs): self.size = size super(Vectorized_Instruction, self).__init__(*args, **kwargs) - for arg,f in zip(self.args, self.arg_format): - if issubclass(ArgFormats[f], RegisterArgFormat): - arg.set_size(size) + if not kwargs.get('copying', False): + for arg,f in zip(self.args, self.arg_format): + if issubclass(ArgFormats[f], RegisterArgFormat): + arg.set_size(size) def get_code(self): - return (self.size << 10) + self.code + return instruction.get_code(self, self.get_size()) def get_pre_arg(self): - return "%d, " % self.size + try: + return "%d, " % self.size + except: + return "{undef}, " def is_vec(self): return True def get_size(self): @@ -250,6 +262,9 @@ def expand(self): set_global_vector_size(self.size) super(Vectorized_Instruction, self).expand() reset_global_vector_size() + def copy(self, size, subs): + return type(self)(size, *self.get_new_args(size, subs), + copying=True) @functools.wraps(instruction) def maybe_vectorized_instruction(*args, **kwargs): @@ -360,6 +375,169 @@ def maybe_gf2n_instruction(*args, **kwargs): return maybe_gf2n_instruction #return instruction +class Mergeable: + pass + +def cisc(function): + class MergeCISC(Mergeable): + instructions = {} + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.calls = [(args, kwargs)] + self.params = [] + self.used = [] + for arg in self.args[1:]: + if isinstance(arg, program.curr_tape.Register): + self.used.append(arg) + self.params.append(type(arg)) + else: + self.params.append(arg) + self.function = function + program.curr_block.instructions.append(self) + + def get_def(self): + return [self.args[0]] + + def get_used(self): + return self.used + + def is_vec(self): + return True + + def merge_id(self): + return self.function, tuple(self.params), \ + tuple(sorted(self.kwargs.items())) + + def merge(self, other): + self.calls += other.calls + + def get_size(self): + return self.args[0].size + + def new_instructions(self, size, regs): + if self.merge_id() not in self.instructions: + from Compiler.program import Tape + tape = Tape(self.function.__name__, program) + old_tape = program.curr_tape + program.curr_tape = tape + block = tape.BasicBlock(tape, None, None) + tape.active_basicblock = block + set_global_vector_size(None) + args = [] + for arg in self.args: + try: + args.append(type(arg)(size=None)) + except: + args.append(arg) + program.options.cisc = False + self.function(*args, **self.kwargs) + program.options.cisc = True + reset_global_vector_size() + program.curr_tape = old_tape + from Compiler.allocator import Merger + merger = Merger(block, program.options, + tuple(program.to_merge)) + args[0].can_eliminate = False + merger.eliminate_dead_code() + assert int(program.options.max_parallel_open) == 0, \ + 'merging restriction not compatible with ' \ + 'mergeable CISC instructions' + merger.longest_paths_merge() + filtered = filter(lambda x: x is not None, block.instructions) + self.instructions[self.merge_id()] = list(filtered), args + template, args = self.instructions[self.merge_id()] + subs = util.dict_by_id() + for arg, reg in zip(args, regs): + subs[arg] = reg + set_global_vector_size(size) + for inst in template: + inst.copy(size, subs) + reset_global_vector_size() + + def expand_merged(self): + tape = program.curr_tape + block = tape.BasicBlock(tape, None, None) + tape.active_basicblock = block + size = sum(call[0][0].size for call in self.calls) + new_regs = [] + for arg in self.args: + try: + new_regs.append(type(arg)(size=size)) + except: + break + base = 0 + for call in self.calls: + for new_reg, reg in zip(new_regs[1:], call[0][1:]): + set_global_vector_size(reg.size) + reg.mov(new_reg.get_vector(base, reg.size), reg) + reset_global_vector_size() + base += reg.size + self.new_instructions(size, new_regs) + base = 0 + for call in self.calls: + reg = call[0][0] + set_global_vector_size(reg.size) + reg.mov(reg, new_regs[0].get_vector(base, reg.size)) + reset_global_vector_size() + base += reg.size + return block.instructions + + MergeCISC.__name__ = function.__name__ + + def wrapper(*args, **kwargs): + if program.options.cisc: + return MergeCISC(*args, **kwargs) + else: + return function(*args, **kwargs) + return wrapper + +def ret_cisc(function): + def instruction(res, *args, **kwargs): + res.mov(res, function(*args, **kwargs)) + instruction.__name__ = function.__name__ + instruction = cisc(instruction) + + def wrapper(*args, **kwargs): + if not program.options.cisc: + return function(*args, **kwargs) + from Compiler import types + if isinstance(args[0], types._clear): + res_type = type(args[1]) + else: + res_type = type(args[0]) + res = res_type(size=args[0].size) + instruction(res, *args, **kwargs) + return res + return wrapper + +def sfix_cisc(function): + from Compiler.types import sfix, sint, cfix, copy_doc + def instruction(res, arg, k, f): + assert k is not None + assert f is not None + old = sfix.k, sfix.f, cfix.k, cfix.f + sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4 + res.mov(res, function(sfix._new(arg, k=k, f=f)).v) + sfix.k, sfix.f, cfix.k, cfix.f = old + instruction.__name__ = function.__name__ + instruction = cisc(instruction) + + def wrapper(*args, **kwargs): + if isinstance(args[0], sfix): + assert len(args) == 1 + assert not kwargs + assert args[0].size == args[0].v.size + k = args[0].k + f = args[0].f + res = sfix._new(sint(size=args[0].size), k=k, f=f) + instruction(res.v, args[0].v, k, f) + return res + else: + return function(*args, **kwargs) + copy_doc(wrapper, function) + return wrapper class RegType(object): """ enum-like static class for Register types """ @@ -381,6 +559,8 @@ def create_dict(init_value_fn): return res class ArgFormat(object): + is_reg = False + @classmethod def check(cls, arg): return NotImplemented @@ -390,11 +570,13 @@ def encode(cls, arg): return NotImplemented class RegisterArgFormat(ArgFormat): + is_reg = True + @classmethod def check(cls, arg): if not isinstance(arg, program.curr_tape.Register): raise ArgumentError(arg, 'Invalid register argument') - if arg.i > REG_MAX: + if arg.i > REG_MAX and arg.i != float('inf'): raise ArgumentError(arg, 'Register index too large') if arg.program != program.curr_tape: raise ArgumentError(arg, 'Register from other tape, trace: %s' % \ @@ -425,7 +607,7 @@ class ClearIntAF(RegisterArgFormat): class IntArgFormat(ArgFormat): @classmethod def check(cls, arg): - if not isinstance(arg, int): + if not isinstance(arg, int) and not arg is None: raise ArgumentError(arg, 'Expected an integer-valued argument') @classmethod @@ -487,7 +669,7 @@ def encode(cls, arg): } def format_str_is_reg(format_str): - return issubclass(ArgFormats[format_str], RegisterArgFormat) + return ArgFormats[format_str].is_reg def format_str_is_writeable(format_str): return format_str_is_reg(format_str) and format_str[-1] == 'w' @@ -504,7 +686,8 @@ class Instruction(object): def __init__(self, *args, **kwargs): """ Create an instruction and append it to the program list. """ self.args = list(args) - self.check_args() + if not kwargs.get('copying', False): + self.check_args() if not program.FIRST_PASS: if kwargs.get('add_to_prog', True): program.curr_block.instructions.append(self) @@ -519,9 +702,9 @@ def __init__(self, *args, **kwargs): if Instruction.count % 100000 == 0: print("Compiled %d lines at" % self.__class__.count, time.asctime()) - def get_code(self): - return self.code - + def get_code(self, prefix=0): + return (prefix << 10) + self.code + def get_encoding(self): enc = int_to_bytes(self.get_code()) # add the number of registers if instruction flagged as has var args @@ -540,12 +723,12 @@ def execute(self): def check_args(self): """ Check the args match up with that specified in arg_format """ - for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)): - if arg is None: - if not isinstance(self.arg_format, (list, tuple)): - break # end of optional arguments - else: - raise CompilerError('Incorrect number of arguments for instruction %s' % (self)) + try: + if len(self.args) != len(self.arg_format): + raise CompilerError('Incorrect number of arguments for instruction %s' % (self)) + except TypeError: + pass + for n,(arg,f) in enumerate(zip(self.args, self.arg_format)): try: ArgFormats[f].check(arg) except ArgumentError as e: @@ -589,6 +772,42 @@ def add_usage(self, req_node): def merge_id(self): return type(self), self.get_size() + def merge(self, other): + if self.get_size() != other.get_size(): + # merge as non-vector instruction + self.args = self.expand_vector_args() + other.expand_vector_args() + if self.is_vec(): + self.size = 1 + else: + self.args += other.args + + def expand_vector_args(self): + if self.is_vec(): + for arg in self.args: + arg.create_vector_elements() + res = sum(list(zip(*self.args)), ()) + return list(res) + else: + return self.args + + def expand_merged(self): + return [self] + + def get_new_args(self, size, subs): + new_args = [] + for arg, f in zip(self.args, self.arg_format): + if arg in subs: + new_args.append(subs[arg]) + elif arg is None: + new_args.append(size) + else: + if format_str_is_writeable(f): + new_args.append(arg.copy()) + subs[arg] = new_args[-1] + else: + new_args.append(arg) + return new_args + # String version of instruction attempting to replicate encoded version def __str__(self): @@ -606,6 +825,13 @@ class VarArgsInstruction(Instruction): def has_var_args(self): return True +class VectorInstruction(Instruction): + __slots__ = [] + is_vec = lambda self: True + + def get_code(self): + return super(VectorInstruction, self).get_code(len(self.args[0])) + ### ### Basic arithmetic ### diff --git a/Compiler/library.py b/Compiler/library.py index d486e008b..e28b3f7e2 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -22,7 +22,7 @@ def get_block(): def vectorize(function): def vectorized_function(*args, **kwargs): - if len(args) > 0 and isinstance(args[0], program.Tape.Register): + if len(args) > 0 and 'size' in dir(args[0]): instructions_base.set_global_vector_size(args[0].size) res = function(*args, **kwargs) instructions_base.reset_global_vector_size() @@ -88,7 +88,7 @@ def print_plain_str(ss): raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfloat): val.print_float_plain() - elif isinstance(val, list): + elif isinstance(val, (list, tuple, Array)): print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val) else: try: @@ -362,9 +362,14 @@ def wrapped_function(*compile_args): class FunctionTape(Function): # not thread-safe + def __init__(self, function, name=None, compile_args=[], + single_thread=False): + Function.__init__(self, function, name, compile_args) + self.single_thread = single_thread def on_first_call(self, wrapped_function): self.thread = MPCThread(wrapped_function, self.name, - args=self.compile_args) + args=self.compile_args, + single_thread=self.single_thread) def on_call(self, base, bases): return FunctionTapeCall(self.thread, base, bases) @@ -376,6 +381,9 @@ def wrapper(function): return FunctionTape(function, compile_args=args) return wrapper +def single_thread_function_tape(function): + return FunctionTape(function, single_thread=True) + def memorize(x): if isinstance(x, (tuple, list)): return tuple(memorize(i) for i in x) @@ -397,13 +405,15 @@ def on_first_call(self, wrapped_function): block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node - print('Compiling function', self.name) + if get_program().verbose: + print('Compiling function', self.name) result = wrapped_function(*self.compile_args) if result is not None: self.result = memorize(result) else: self.result = None - print('Done compiling function', self.name) + if get_program().verbose: + print('Done compiling function', self.name) p_return_address = get_tape().program.malloc(1, 'ci') get_tape().function_basicblocks[block] = p_return_address return_address = regint.load_mem(p_return_address) @@ -528,7 +538,7 @@ def round(): a[m], a[m+step] = cond_swap(a[m], a[m+step]) for i in range(len(a)): a[i].store_in_mem(i * a[i].sizeof()) - chunk = MPCThread(round, 'sort-%d-%d' % (l,k)) + chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True) chunk.start() chunk.join() #round() @@ -541,7 +551,6 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use a_base = instructions.program.malloc(n, 's') for i,j in enumerate(a): store_in_mem(j, a_base + i) - instructions.program.restart_main_thread() else: a_base = a tmp_base = instructions.program.malloc(n, 's') @@ -657,7 +666,6 @@ def run_postproc(): run_postproc() if isinstance(a, list): - instructions.program.restart_main_thread() for i in range(n): a[i] = load_secret_mem(a_base + i) instructions.program.free(a_base, 's') @@ -669,7 +677,6 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= a_base = instructions.program.malloc(n, 's') for i,j in enumerate(a): store_in_mem(j, a_base + i) - instructions.program.restart_main_thread() else: a_base = a tmp_base = instructions.program.malloc(n, 's') @@ -764,7 +771,6 @@ def inner2(m): range_loop(outer, n // l) if isinstance(a, list): - instructions.program.restart_main_thread() for i in range(n): a[i] = load_secret_mem(a_base + i) instructions.program.free(a_base, 's') @@ -772,30 +778,39 @@ def inner2(m): instructions.program.free(tmp_i, 'ci') -def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32): +def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32, + n_threads=None): + steps = {} l = sorted_length while l < len(a): l *= 2 k = 1 while k < l: k *= 2 - n_outer = len(a) // l - n_inner = l // k n_innermost = 1 if k == 2 else k // 2 - 1 - @for_range_parallel(n_parallel // n_innermost // n_inner, n_outer) - def loop(i): - @for_range_parallel(n_parallel // n_innermost, n_inner) - def inner(j): - base = i*l + j - step = l//k - if k == 2: - a[base], a[base+step] = cond_swap(a[base], a[base+step]) - else: - @for_range_parallel(n_parallel, n_innermost) - def f(i): - m1 = step + i * 2 * step - m2 = m1 + base - a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step]) + key = k + if key not in steps: + @function_block + def step(l): + l = MemValue(l) + @for_range_opt_multithread(n_threads, len(a) // k) + def _(i): + n_inner = l // k + j = i % n_inner + i //= n_inner + base = i*l + j + step = l//k + if k == 2: + a[base], a[base+step] = \ + cond_swap(a[base], a[base+step]) + else: + @for_range_opt(n_innermost) + def f(i): + m1 = step + i * 2 * step + m2 = m1 + base + a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step]) + steps[key] = step + steps[key](l) def mergesort(A): B = Array(len(A), sint) @@ -899,11 +914,12 @@ def _(i): def for_range_opt(n_loops, budget=None): """ Execute loop bodies in parallel up to an optimization budget. This prevents excessive loop unrolling. The budget is respected - even with nested loops. + even with nested loops. Note that optimization is rather + rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider + using :py:func:`for_range_parallel` in this case. - :param n_loops: compile-time (int) + :param n_loops: int/regint/cint :param budget: number of instructions after which to start optimization (default is 100,000) - :type: compile-time (int) Example: @@ -912,6 +928,7 @@ def for_range_opt(n_loops, budget=None): @for_range_opt(n) def _(i): ... + """ return map_reduce_single(None, n_loops, budget=budget) @@ -929,6 +946,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], else: # use Arrays for multithread version use_array = True + if not util.is_constant(n_loops): + budget //= 10 def decorator(loop_body): my_n_parallel = n_parallel if isinstance(n_parallel, int): @@ -955,17 +974,18 @@ def f(i): r = reducer(mem_state, state) write_state_to_memory(r) else: - if n_loops == 0: + if is_zero(n_loops): return - regint.push(0) + n_opt_loops_reg = regint(0) + n_opt_loops_inst = get_block().instructions[-1] parent_block = get_block() - @while_do(lambda x: x + regint.pop() <= n_loops, regint(0)) + @while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0)) def _(i): state = tuplify(initializer()) k = 0 block = get_block() - while k < n_loops and (len(get_block()) < budget \ - or k == 0) \ + while (not util.is_constant(n_loops) or k < n_loops) \ + and (len(get_block()) < budget or k == 0) \ and block is get_block(): j = i + k state = reducer(tuplify(loop_body(j)), state) @@ -974,13 +994,13 @@ def _(i): write_state_to_memory(r) global n_opt_loops n_opt_loops = k - regint.push(k) + n_opt_loops_inst.args[1] = k return i + k my_n_parallel = n_opt_loops loop_rounds = n_loops // my_n_parallel blocks = get_tape().basicblocks n_to_merge = 5 - if loop_rounds == 1 and parent_block is blocks[-n_to_merge]: + if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]: # merge blocks started by if and do_while def exit_elimination(block): if block.exit_condition is not None: @@ -996,14 +1016,15 @@ def exit_elimination(block): for block in blocks[-n_to_merge + 1:]: merged.instructions += block.instructions exit_elimination(block) - block.purge() + block.purge(retain_usage=False) del blocks[-n_to_merge + 1:] del get_tape().req_node.children[-1] merged.children = [] get_tape().active_basicblock = merged else: req_node = get_tape().req_node.children[-1].nodes[0] - req_node.children[0].aggregator = lambda x: loop_rounds * x[0] + if util.is_constant(loop_rounds): + req_node.children[0].aggregator = lambda x: loop_rounds * x[0] if isinstance(n_loops, int): state = mem_state for j in range(loop_rounds * my_n_parallel, n_loops): @@ -1040,7 +1061,9 @@ def for_range_opt_multithread(n_threads, n_loops): """ Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads` threads, in parallel up to an optimization budget per thread - similar to :py:func:`for_range_opt`. + similar to :py:func:`for_range_opt`. Note that optimization is rather + rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider + using :py:func:`for_range_multithread` in this case. :param n_threads: compile-time (int) :param n_loops: regint/cint/int @@ -1089,7 +1112,7 @@ def f(base, size): def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ thread_mem_req={}, looping=True): - n_threads = n_threads or 1 + assert(n_threads != 0) if isinstance(n_loops, list): split = n_loops n_loops = reduce(operator.mul, n_loops) @@ -1103,9 +1126,22 @@ def new_body(i): return new_body new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req) return lambda loop_body: new_dec(decorator(loop_body)) + n_loops = MemValue.if_necessary(n_loops) + if n_threads == None or util.is_one(n_loops): + if not looping: + return lambda loop_body: loop_body(0, n_loops) + dec = map_reduce_single(n_parallel, n_loops, initializer, reducer) + if thread_mem_req: + thread_mem = Array(thread_mem_req[regint], regint) + return lambda loop_body: dec(lambda i: loop_body(i, thread_mem)) + else: + return dec def decorator(loop_body): - thread_rounds = n_loops // n_threads - remainder = n_loops % n_threads + thread_rounds = MemValue.if_necessary(n_loops // n_threads) + if util.is_constant(thread_rounds): + remainder = n_loops % n_threads + else: + remainder = 0 for t in thread_mem_req: if t != regint: raise CompilerError('Not implemented for other than regint') @@ -1113,6 +1149,11 @@ def decorator(loop_body): state = tuple(initializer()) def f(inc): base = args[get_arg()][0] + if not util.is_constant(thread_rounds): + i = base / thread_rounds + overhang = n_loops % n_threads + inc = i < overhang + base += inc.if_else(i, overhang) if not looping: return loop_body(base, thread_rounds + inc) if thread_mem_req: @@ -1129,7 +1170,7 @@ def f(i): return loop_body(base + i) prog = get_program() threads = [] - if thread_rounds: + if not util.is_zero(thread_rounds): tape = prog.new_tape(f, (0,), 'multithread') for i in range(n_threads - remainder): mem_state = make_array(initializer()) @@ -1465,6 +1506,15 @@ def get_player_id(): playerid(res._v) return res +def break_point(name=''): + """ + Insert break point. This makes sure that all following code + will be executed after preceding code. + + :param name: Name for identification (optional) + """ + get_tape().start_new_basicblock(name=name) + # Fixed point ops from math import ceil, log @@ -1590,6 +1640,7 @@ def IntDiv(a, b, k, kappa=None): return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, kappa, nearest=True) +@instructions_base.ret_cisc def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): """ Goldschmidt method as presented in Catrina10, diff --git a/Compiler/ml.py b/Compiler/ml.py index 417f440a0..d5557981c 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1,9 +1,37 @@ -import mpc_math, math - +""" +This module contains machine learning functionality. It is work in +progress, so you must expect things to change. The most tested +functionality is logistic regression. It can be run as follows:: + + sgd = ml.SGD([ml.Dense(n_examples, n_features, 1), + ml.Output(n_examples, approx=True)], n_epochs, + report_loss=True) + sgd.layers[0].X.input_from(0) + sgd.layers[1].Y.input_from(1) + sgd.reset() + sgd.run() + +This loads measurements from party 0 and labels (0/1) from party +1. After running, the model is stored in :py:obj:`sgd.layers[0].W` and +:py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines +whether to use an approximate sigmoid function. Inference can be run as +follows:: + + data = sfix.Matrix(n_test, n_features) + data.input_from(0) + res = sgd.eval(data) + print_ln('Results: %s', [x.reveal() for x in res]) +""" + +import math + +from Compiler import mpc_math from Compiler.types import * from Compiler.types import _unreduced_squant from Compiler.library import * -from Compiler.util import is_zero +from Compiler.util import is_zero, tree_reduce +from Compiler.comparison import CarryOutRawLE +from Compiler.GC.types import sbitint from functools import reduce def log_e(x): @@ -31,6 +59,29 @@ def sigmoid_prime(x): sx = sigmoid(x) return sx * (1 - sx) +@vectorize +def approx_sigmoid(x): + if approx_sigmoid.special and \ + get_program().options.ring and get_program().use_edabit(): + l = int(get_program().options.ring) + r, r_bits = sint.get_edabit(x.k, False) + c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k) + c_bits = c.bit_decompose(x.k) + lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1]) + higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:], + lower_overflow) + sign = higher_bits[-1] + higher_bits.pop(-1) + aa = sign & ~util.tree_reduce(operator.and_, higher_bits) + bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits]) + a, b = (sint.conv(x) for x in (aa, bb)) + else: + a = x < -0.5 + b = x > 0.5 + return a.if_else(0, b.if_else(1, 0.5 + x)) + +approx_sigmoid.special = False + def lse_0_from_e_x(x, e_x): return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0) @@ -56,7 +107,7 @@ class Layer: n_threads = 1 class Output(Layer): - def __init__(self, N, debug=False): + def __init__(self, N, debug=False, approx=False): self.N = N self.X = sfix.Array(N) self.Y = sfix.Array(N) @@ -64,9 +115,8 @@ def __init__(self, N, debug=False): self.l = MemValue(sfix(-1)) self.e_x = sfix.Array(N) self.debug = debug - self.weights = cint.Array(N) - self.weights.assign_all(1) - self.weight_total = N + self.weights = None + self.approx = approx nablas = lambda self: () thetas = lambda self: () @@ -75,30 +125,45 @@ def __init__(self, N, debug=False): def divisor(self, divisor, size): return cfix(1.0 / divisor, size=size) - def forward(self, N=None): - N = N or self.N + def forward(self, batch): + if self.approx: + self.l.write(999) + return + N = len(batch) lse = sfix.Array(N) @multithread(self.n_threads, N) def _(base, size): x = self.X.get_vector(base, size) - y = self.Y.get_vector(base, size) + y = self.Y.get(batch.get_vector(base, size)) e_x = exp(-x) self.e_x.assign(e_x, base) lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base) e_x = self.e_x.get_vector(0, N) self.l.write(sum(lse) * \ - self.divisor(self.N, 1)) + self.divisor(N, 1)) - def backward(self): - @multithread(self.n_threads, self.N) + def eval(self, size, base=0): + if self.approx: + return approx_sigmoid(self.X.get_vector(base, size)) + else: + return sigmoid_from_e_x(self.X.get_vector(base, size), + self.e_x.get_vector(base, size)) + + def backward(self, batch): + N = len(batch) + @multithread(self.n_threads, N) def _(base, size): - diff = sigmoid_from_e_x(self.X.get_vector(base, size), - self.e_x.get_vector(base, size)) - \ - self.Y.get_vector(base, size) + diff = self.eval(size, base) - \ + self.Y.get(batch.get_vector(base, size)) assert sfix.f == cfix.f - diff *= self.weights.get_vector(base, size) - self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \ - base) + if self.weights is None: + diff *= self.divisor(N, size) + else: + assert N == len(self.weights) + diff *= self.weights.get_vector(base, size) + if self.weight_total != 1: + diff *= self.divisor(self.weight_total, size) + self.nabla_X.assign(diff, base) # @for_range_opt(len(diff)) # def _(i): # self.nabla_X[i] = self.nabla_X[i] * self.weights[i] @@ -112,6 +177,7 @@ def _(i): #print_ln('%s', x) def set_weights(self, weights): + self.weights = cfix.Array(len(weights)) self.weights.assign(weights) self.weight_total = sum(weights) @@ -119,16 +185,28 @@ class DenseBase(Layer): thetas = lambda self: (self.W, self.b) nablas = lambda self: (self.nabla_W, self.nabla_b) - def backward_params(self, f_schur_Y): - N = self.N + def backward_params(self, f_schur_Y, batch): + N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) - @for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out]) - def _(j, k): - assert self.d == 1 - a = [f_schur_Y[i][0][k] for i in range(N)] - b = [self.X[i][0][j] for i in range(N)] - tmp[j][k] = sfix.unreduced_dot_product(a, b) + assert self.d == 1 + if self.d_out == 1: + @multithread(self.n_threads, self.d_in) + def _(base, size): + A = sfix.Matrix(1, self.N, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + mp = A.direct_mul(B, reduce=False, + indices=(regint(0, size=1), + regint.inc(N), + batch.get_vector(), + regint.inc(size, base))) + tmp.assign_vector(mp, base) + else: + @for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out]) + def _(j, k): + a = [f_schur_Y[i][0][k] for i in range(N)] + b = [self.X[i][0][j] for i in batch] + tmp[j][k] = sfix.unreduced_dot_product(a, b) if self.d_in * self.d_out < 100000: print('reduce at once') @@ -189,26 +267,34 @@ def _(j): self.W[i][j] = sfix.get_random(-r, r) self.b.assign_all(0) - def compute_f_input(self): - prod = MultiArray([self.N, self.d, self.d_out], sfix) - @for_range_opt_multithread(self.n_threads, self.N) - def _(i): - self.X[i].plain_mul(self.W, res=prod[i]) + def compute_f_input(self, batch): + N = len(batch) + prod = MultiArray([N, self.d, self.d_out], sfix) + assert self.d == 1 + assert self.d_out == 1 + @multithread(self.n_threads, N) + def _(base, size): + X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) + prod.assign_vector( + X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size), + regint.inc(self.d_in), + regint.inc(self.d_in), + regint.inc(self.d_out))), + base) - @for_range_opt_multithread(self.n_threads, self.N) - def _(i): - @for_range_opt(self.d) - def _(j): - v = prod[i][j].get_vector() + self.b.get_vector() - self.f_input[i][j].assign(v) + @multithread(self.n_threads, N) + def _(base, size): + v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) + self.f_input.assign_vector(v, base) progress('f input') - def forward(self): - self.compute_f_input() - self.Y.assign_vector(self.f(self.f_input.get_vector())) + def forward(self, batch=None): + self.compute_f_input(batch=batch) + self.Y.assign_vector(self.f( + self.f_input.get_part_vector(0, len(batch)))) - def backward(self, compute_nabla_X=True): - N = self.N + def backward(self, compute_nabla_X=True, batch=None): + N = len(batch) d = self.d d_out = self.d_out X = self.X @@ -233,6 +319,7 @@ def backward(self, compute_nabla_X=True): @for_range_opt(N) def _(i): + i = batch[i] f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i]) progress('f prime schur Y') @@ -240,6 +327,7 @@ def _(i): if compute_nabla_X: @for_range_opt(N) def _(i): + i = batch[i] if self.activation == 'id': nabla_X[i] = nabla_Y[i].mul_trans(W) else: @@ -247,7 +335,7 @@ def _(i): progress('nabla X') - self.backward_params(f_schur_Y) + self.backward_params(f_schur_Y, batch=batch) class QuantizedDense(DenseBase): def __init__(self, N, d_in, d_out): @@ -443,8 +531,8 @@ def n_summands(self): _, inputs_h, inputs_w, n_channels_in = self.input_shape return weights_h * weights_w * n_channels_in - def forward(self, N=1): - assert(N == 1) + def forward(self, batch): + assert len(batch) == 1 assert(self.weight_shape[0] == self.output_shape[-1]) _, weights_h, weights_w, _ = self.weight_shape @@ -499,8 +587,8 @@ def n_summands(self): _, weights_h, weights_w, _ = self.weight_shape return weights_h * weights_w - def forward(self, N=1): - assert(N == 1) + def forward(self, batch): + assert len(batch) == 1 assert(self.weight_shape[-1] == self.output_shape[-1]) assert(self.input_shape[-1] == self.output_shape[-1]) @@ -562,8 +650,8 @@ def input_from(self, player): for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) - def forward(self, N=1): - assert(N == 1) + def forward(self, batch): + assert len(batch) == 1 _, input_h, input_w, n_channels_in = self.input_shape _, output_h, output_w, n_channels_out = self.output_shape @@ -623,8 +711,8 @@ def input_from(self, player): for i in range(2): sint.get_input_from(player) - def forward(self, N=1): - assert(N == 1) + def forward(self, batch): + assert len(batch) == 1 # reshaping is implicit self.Y.assign(self.X) @@ -634,8 +722,8 @@ def input_from(self, player): for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) - def forward(self, N=1): - assert(N == 1) + def forward(self, batch): + assert len(batch) == 1 assert(len(self.input_shape) == 2) # just print the best @@ -648,31 +736,40 @@ def comp(left, right): class Optimizer: n_threads = Layer.n_threads - def forward(self, N): + def forward(self, N=None, batch=None): + if batch is None: + batch = regint.Array(N) + batch.assign(regint.inc(N)) for j in range(len(self.layers) - 1): - self.layers[j].forward() - self.layers[j + 1].X.assign(self.layers[j].Y) - self.layers[-1].forward(N) - - def backward(self): + self.layers[j].forward(batch=batch) + tmp = self.layers[j].Y.get_part_vector(0, len(batch)) + self.layers[j + 1].X.assign_vector(tmp) + self.layers[-1].forward(batch=batch) + + def eval(self, data): + N = len(data) + self.layers[0].X.assign(data) + self.forward(N) + return self.layers[-1].eval(N) + + def backward(self, batch): for j in range(1, len(self.layers)): - self.layers[-j].backward() - self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X) - self.layers[0].backward(compute_nabla_X=False) - - def run(self): + self.layers[-j].backward(batch=batch) + self.layers[-j - 1].nabla_Y.assign_vector( + self.layers[-j].nabla_X.get_part_vector(0, len(batch))) + self.layers[0].backward(compute_nabla_X=False, batch=batch) + + def run(self, batch_size=None): + if batch_size is not None: + N = batch_size + else: + N = self.layers[0].N i = MemValue(0) @do_while def _(): if self.X_by_label is not None: - N = self.layers[0].N - assert self.layers[-1].N == N assert N % 2 == 0 n = N // 2 - @for_range(n) - def _(i): - self.layers[-1].Y[i] = 0 - self.layers[-1].Y[i + n] = 1 n_per_epoch = int(math.ceil(1. * max(len(X) for X in self.X_by_label) / n)) print('%d runs per epoch' % n_per_epoch) @@ -680,26 +777,27 @@ def _(i): for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) indices_by_label.append(indices) - indices.assign(i % len(X) for i in range(len(indices))) + indices.assign(regint.inc(len(indices), 0, 1, 1, len(X))) indices.shuffle() @for_range(n_per_epoch) def _(j): - j = MemValue(j) + batch = regint.Array(N) for label, X in enumerate(self.X_by_label): indices = indices_by_label[label] - @for_range_multithread(self.n_threads, 1, n) - def _(i): - idx = indices[i + j * n] - self.layers[0].X[i + label * n] = X[idx] - self.forward(None) - self.backward() + batch.assign(indices.get_vector(j * n, n) + + regint(label * len(self.X_by_label[0]), size=n), + label * n) + self.forward(batch=batch) + self.backward(batch=batch) self.update(i) else: - self.forward(None) - self.backward() + batch = regint.Array(N) + batch.assign(regint.inc(N)) + self.forward(batch=batch) + self.backward(batch=batch) self.update(i) loss = self.layers[-1].l - if self.report_loss: + if self.report_loss and not self.layers[-1].approx: print_ln('loss after epoch %s: %s', i, loss.reveal()) else: print_ln('done with epoch %s', i) @@ -772,6 +870,13 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=False): def reset(self, X_by_label=None): self.X_by_label = X_by_label + if X_by_label is not None: + for label, X in enumerate(X_by_label): + @for_range_multithread(self.n_threads, 1, len(X)) + def _(i): + j = i + label * len(X_by_label[0]) + self.layers[0].X[j] = X[i] + self.layers[-1].Y[j] = label for y in self.delta_thetas: y.assign_all(0) for layer in self.layers: @@ -780,16 +885,15 @@ def reset(self, X_by_label=None): def update(self, i_epoch): for nabla, theta, delta_theta in zip(self.nablas, self.thetas, self.delta_thetas): - @for_range_opt_multithread(self.n_threads, len(nabla)) - def _(k): - old = delta_theta[k] - if isinstance(old, Array): - old = old.get_vector() + @multithread(self.n_threads, len(nabla)) + def _(base, size): + old = delta_theta.get_vector(base, size) red_old = self.momentum * old - new = self.gamma * nabla[k] + new = self.gamma * nabla.get_vector(base, size) diff = red_old - new - delta_theta[k] = diff - theta[k] = theta[k] + delta_theta[k] + delta_theta.assign_vector(diff, base) + theta.assign_vector(theta.get_vector(base, size) + + delta_theta.get_vector(base, size), base) if self.debug: for x, name in (old, 'old'), (red_old, 'red_old'), \ (new, 'new'), (diff, 'diff'): diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 84c45c42b..a8014266b 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -12,6 +12,8 @@ from Compiler import types from Compiler import comparison from Compiler import program +from Compiler import instructions_base + # polynomials as enumerated on Hart's book ## # @private @@ -154,8 +156,8 @@ def p_eval(p_c, x): def sTrigSub(x): # reduction to 2* \pi f = x * (1.0 / (2 * pi)) - f = load_sint(trunc(f), type(x)) - y = x - (f) * (2 * pi) + f = trunc(f) + y = x - (f) * x.coerce(2 * pi) # reduction to \pi b1 = y > pi w = b1 * ((2 * pi - y) - y) + y @@ -210,6 +212,7 @@ def scos(w, s): # facade method calls --it is built in a generic way +@instructions_base.sfix_cisc def sin(x): """ Returns the sine of any given fractional value. @@ -224,6 +227,7 @@ def sin(x): return ssin(w, b1) +@instructions_base.sfix_cisc def cos(x): """ Returns the cosine of any given fractional value. @@ -239,6 +243,7 @@ def cos(x): return scos(w, b2) +@instructions_base.sfix_cisc def tan(x): """ Returns the tangent of any given fractional value. @@ -258,6 +263,7 @@ def tan(x): @types.vectorize +@instructions_base.sfix_cisc def exp2_fx(a): """ Power of two for fixed-point numbers. @@ -273,39 +279,49 @@ def exp2_fx(a): n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits n_shift = int(types.program.options.ring) - a.k - r_bits = [sint.get_random_bit() for i in range(a.k)] - shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal() + if types.program.use_edabit(): + l = sint.get_edabit(a.f, True) + u = sint.get_edabit(a.k - a.f, True) + r_bits = l[1] + u[1] + r = l[0] + (u[0] << a.f) + lower_r = l[0] + else: + r_bits = [sint.get_random_bit() for i in range(a.k)] + r = sint.bit_compose(r_bits) + lower_r = sint.bit_compose(r_bits[:a.f]) + shifted = ((a.v - r) << n_shift).reveal() masked_bits = (shifted >> n_shift).bit_decompose(a.k) - lower_overflow = sint() - comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1], + lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) - lower_r = sint.bit_compose(r_bits[:a.f]) lower_masked = sint.bit_compose(masked_bits[:a.f]) - lower = lower_r + lower_masked - (lower_overflow << (a.f)) + lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f)) c = types.sfix._new(lower, k=a.k, f=a.f) - higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits], - r_bits[a.f:n_bits], + higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits], + masked_bits[a.f:n_bits], carry_in=lower_overflow, get_carry=True) - d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]), - k=a.k, f=a.f) + assert(len(higher_bits) == n_bits - a.f + 1) + pow2_bits = [sint.conv(x) for x in higher_bits] + d = floatingpoint.Pow2_from_bits(pow2_bits[:-1]) e = p_eval(p_1045, c) g = d * e - small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False, + small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, + 2 ** n_int_bits, signed=False, nearest=types.sfix.round_nearest), k=a.k, f=a.f) carry = comparison.CarryOutLE(masked_bits[n_bits:-1], r_bits[n_bits:-1], higher_bits[-1]) # should be for free - highest_bits = intbitint.ripple_carry_adder( + highest_bits = r_bits[0].ripple_carry_adder( masked_bits[n_bits:-1], [0] * (a.k - n_bits), carry_in=higher_bits[-1]) bits_to_check = [x.bit_xor(y) for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])] - t = floatingpoint.KMul(bits_to_check) + t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), + bits_to_check)) # sign - s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry) + s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1]) return s.if_else(t.if_else(small_result, 0), g) else: # obtain absolute value of a @@ -313,16 +329,17 @@ def exp2_fx(a): a = (s * (-2) + 1) * a # isolates fractional part of number b = trunc(a) - c = a - load_sint(b, type(a)) + c = a - b # squares integer part of a - d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a)) + d = b.pow2(a.k - a.f) # evaluates fractional part of a in p_1045 e = p_eval(p_1045, c) g = d * e - return (1 - s) * g + s * ((types.sfix(1)) / g) + return (1 - s) * g + s / g @types.vectorize +@instructions_base.sfix_cisc def log2_fx(x): """ Returns the result of :math:`\log_2(x)` for any unbounded @@ -347,14 +364,13 @@ def log2_fx(x): v, p, vlen = d.v, d.p, d.vlen # isolates mantisa of d, now the n can be also substituted by the # secret shared p from d in the expresion above. - v = load_sint(v, type(x)) - w = (1.0 / (2 ** (vlen))) + w = x.coerce(1.0 / (2 ** (vlen))) v = v * w # polynomials for the log_2 evaluation of f are calculated P = p_eval(p_2524, v) Q = p_eval(q_2524, v) # the log is returned by adding the result of the division plus p. - a = P / Q + load_sint(vlen + p, type(x)) + a = P / Q + (vlen + p) return a # *(1-(f.z))*(1-f.s)*(1-f.error) @@ -515,7 +531,7 @@ def norm_simplified_SQ(b, k): # @return g: approximated sqrt def sqrt_simplified_fx(x): # fix theta (number of iterations) - theta = max(int(math.ceil(math.log(types.sfix.k))), 6) + theta = max(int(math.ceil(math.log(x.k))), 6) # process to use 2^(m/2) approximation m_odd, m, w = norm_simplified_SQ(x.v, x.k) @@ -524,15 +540,15 @@ def sqrt_simplified_fx(x): m_odd = (1 - 2 * m_odd) + m_odd w = (w * 2 - w) * (1-m_odd) + w # map number to use sfix format and instantiate the number - w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2)) + w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f) # obtains correct 2 ** (m/2) - w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w + w = (w * (2 ** (1/2.0)) - w) * m_odd + w # produce x/ 2^(m/2) - y_0 = types.cfix(1.0) / w + y_0 = 1 / w # from this point on it sufices to work sfix-wise g_0 = (y_0 * x) - h_0 = y_0 * types.cfix(0.5) + h_0 = y_0 * 0.5 gh_0 = g_0 * h_0 ## initialization @@ -689,7 +705,8 @@ def sqrt_fx(x_l, k, f): @types.vectorize -def sqrt(x, k = types.sfix.k, f = types.sfix.f): +@instructions_base.sfix_cisc +def sqrt(x, k=None, f=None): """ Returns the square root (sfix) of any given fractional value as long as it can be rounded to a integral value @@ -699,7 +716,11 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f): :return: square root of :py:obj:`x` (sfix). """ - if (3 *k -2 * f >= types.sfix.f): + if k is None: + k = x.k + if f is None: + f = x.f + if (3 *k -2 * f >= f): return sqrt_simplified_fx(x) # raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ") else: @@ -707,6 +728,7 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f): return sqrt_fx(param ,k ,f) +@instructions_base.sfix_cisc def atan(x): """ Returns the arctangent (sfix) of any given fractional value. @@ -720,12 +742,12 @@ def atan(x): x_abs = (s * (-2) + 1) * x # angle isolation b = x_abs > 1 - v = (types.cfix(1.0) / x_abs) + v = 1 / x_abs v = (1 - b) * (x_abs - v) + v v_2 =v*v # range of polynomial coefficients - assert x.k - x.f >= 18 + assert x.k - x.f >= 15 P = p_eval(p_5102, v_2) Q = p_eval(q_5102, v_2) diff --git a/Compiler/oram.py b/Compiler/oram.py index 4d61b2d86..db26bcf10 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -57,10 +57,14 @@ def __init__(self, value, start, lengths, entries_per_block): length = sum(self.lengths) self.n_bits = length * entries_per_block self.start = self.value_type.hard_conv(start * length) - self.lower, self.shift = \ - floatingpoint.Trunc(self.value, self.n_bits, self.start, \ + if Program.prog.options.ring: + self.lower, trunc, self.shift = floatingpoint.SplitInRing( + self.value, self.n_bits, self.start) + else: + self.lower, self.shift = \ + floatingpoint.Trunc(self.value, self.n_bits, self.start, \ Program.prog.security, True) - trunc = (self.value - self.lower) / self.shift + trunc = (self.value - self.lower) / self.shift self.slice = trunc.mod2m(length, self.n_bits, False) self.upper = (trunc - self.slice) * self.shift def get_slice(self): @@ -810,7 +814,7 @@ def get_n_threads(n_loops): if n_loops > 2048: return 8 else: - return 1 + return None else: return n_threads @@ -1375,7 +1379,11 @@ def get_value_size(value_type): if value_type == sgf2n: return Program.prog.galois_length elif value_type == sint: - return 127 - Program.prog.security + ring = Program.prog.options.ring + if ring: + return int(ring) + else: + return 127 - Program.prog.security else: return value_type.max_length @@ -1477,11 +1485,13 @@ def translate_index(self, index): rem = mod2m(index, self.log_entries_per_block, log2(self.size), False) c = mod2m(rem, self.log_entries_per_element, \ self.log_entries_per_block, False) - b = (rem - c) / self.entries_per_element + b = (rem - c).trunc_zeros(self.log_entries_per_element, + self.log_entries_per_block) if self.small: return 0, b, c else: - return (index - rem) / self.entries_per_block, b, c + return (index - rem).trunc_zeros(self.log_entries_per_block, + log2(self.size)), b, c else: index_bits = bit_decompose(index, log2(self.size)) l1 = self.log_entries_per_element diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 79c32e275..4be4611b5 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -301,7 +301,8 @@ def f(i): conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 1) - for i in range(n): + @for_range(n) + def _(i): a[i] = a2[i] nblocks.write(nblocks*2) @@ -317,7 +318,8 @@ def f(i): conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 0) - for i in range(n): + @for_range(n) + def _(i): a[i] = a2[i] nblocks.write(nblocks//2) @@ -379,6 +381,14 @@ def config_shuffle(n, value_type): config_bits = configure_waksman(perm) # 2-D array config = Array(len(config_bits) * len(perm), value_type.reg_type) + if n > 1024: + for x in config_bits: + for y in x: + get_program().public_input(y) + @for_range(sum(len(x) for x in config_bits)) + def _(i): + config[i] = public_input() + return config for i,c in enumerate(config_bits): for j,b in enumerate(c): config[i * len(perm) + j] = b diff --git a/Compiler/program.py b/Compiler/program.py index 7b2537884..7147adc74 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -22,12 +22,14 @@ bit = 2, inverse = 3, bittriple = 4, - bitgf2ntriple = 5 + bitgf2ntriple = 5, + dabit = 6, ) field_types = dict( modp = 0, gf2n = 1, + bit = 2, ) @@ -39,6 +41,7 @@ class Program(object): and threads. """ def __init__(self, args, options, param=-1, assemblymode=False): self.options = options + self.verbose = options.verbose self.args = args self.init_names(args, assemblymode) self.P = P_VALUES[param] @@ -56,7 +59,8 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.security = 40 print('Default security parameter:', self.security) self.galois_length = int(options.galois) - print('Galois length:', self.galois_length) + if self.verbose: + print('Galois length:', self.galois_length) self.schedule = [('start', [])] self.tape_counter = 0 self.tapes = [] @@ -73,9 +77,9 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.n_threads = 1 self.free_threads = set() self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w') + self.use_public_input_file = False self.types = {} self.budget = int(self.options.budget) - self.verbose = False self.to_merge = [Compiler.instructions.asm_open_class, \ Compiler.instructions.gasm_open_class, \ Compiler.instructions.muls_class, \ @@ -89,12 +93,15 @@ def __init__(self, args, options, param=-1, assemblymode=False): Compiler.instructions.inputfix_class, Compiler.instructions.inputfloat_class, Compiler.instructions.inputmixed_class, - Compiler.instructions.trunc_pr_class] + Compiler.instructions.trunc_pr_class, + Compiler.instructions_base.Mergeable] import Compiler.GC.instructions as gc self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] self.use_trunc_pr = False self.use_dabit = options.mixed + self._edabit = options.edabit + self._split = False Program.prog = self self.reset_values() @@ -132,7 +139,8 @@ def init_names(self, args, assemblymode): else: # assume source is in main SPDZ directory self.programs_dir = sys.path[0] + '/Programs' - print('Compiling program in', self.programs_dir) + if self.verbose: + print('Compiling program in', self.programs_dir) # create extra directories if needed for dirname in ['Public-Input', 'Bytecode', 'Schedules']: @@ -161,7 +169,7 @@ def init_names(self, args, assemblymode): self.name += '-' + '-'.join(args[1:]) self.progname = progname - def new_tape(self, function, args=[], name=None): + def new_tape(self, function, args=[], name=None, single_thread=False): if name is None: name = function.__name__ name = "%s-%s" % (self.name, name) @@ -170,7 +178,7 @@ def new_tape(self, function, args=[], name=None): tape_index = len(self.tapes) self.tape_stack.append(self.curr_tape) self.curr_tape = Tape(name, self) - self.curr_tape.prevent_direct_memory_write = True + self.curr_tape.prevent_direct_memory_write = not single_thread self.tapes.append(self.curr_tape) function(*args) self.finalize_tape(self.curr_tape) @@ -183,7 +191,8 @@ def run_tape(self, tape_index, arg): raise CompilerError('Compiler does not support ' \ 'recursive spawning of threads') if self.free_threads: - thread_number = self.free_threads.pop() + thread_number = min(self.free_threads) + self.free_threads.remove(thread_number) else: thread_number = self.n_threads self.n_threads += 1 @@ -376,7 +385,7 @@ def malloc(self, size, mem_type, reg_type=None): else: addr = self.allocated_mem[mem_type] self.allocated_mem[mem_type] += size - if len(str(addr)) != len(str(addr + size)): + if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) self.allocated_mem_blocks[addr,mem_type] = size return addr @@ -404,6 +413,7 @@ def finalize_memory(self): def public_input(self, x): self.public_input_file.write('%s\n' % str(x)) + self.use_public_input_file = True def set_bit_length(self, bit_length): self.bit_length = bit_length @@ -421,6 +431,22 @@ def get_tape_counter(self): self.tape_counter += 1 return res + def use_edabit(self, change=None): + if change is None: + return self._edabit + else: + self._edabit = change + + def use_edabit_for(self, *args): + return True + + def use_split(self, change=None): + if change is None: + return self._split + else: + assert change in (2, 3) + self._split = change + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -503,13 +529,17 @@ def adjust_jump(self): self.exit_condition.set_relative_jump(offset) #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) - def purge(self): + def purge(self, retain_usage=True): def relevant(inst): req_node = Tape.ReqNode('') req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} - self.usage_instructions = list(filter(relevant, self.instructions)) + if retain_usage: + self.usage_instructions = list(filter(relevant, + self.instructions)) + else: + self.usage_instructions = [] if len(self.usage_instructions) > 1000: print('Retaining %d instructions' % len(self.usage_instructions)) del self.instructions @@ -526,6 +556,12 @@ def add_usage(self, req_node): req_node.num['all', 'round'] = self.n_rounds req_node.num['all', 'inv'] = self.n_to_merge + def expand_cisc(self): + new_instructions = [] + for inst in self.instructions: + new_instructions.extend(inst.expand_merged()) + self.instructions = new_instructions + def __str__(self): return self.name @@ -577,8 +613,9 @@ def purge(self): def unpurged(function): def wrapper(self, *args, **kwargs): if self.purged: - print('%s called on purged block %s, ignoring' % \ - (function.__name__, self.name)) + if self.program.verbose: + print('%s called on purged block %s, ignoring' % \ + (function.__name__, self.name)) return return function(self, *args, **kwargs) return wrapper @@ -592,7 +629,8 @@ def optimize(self, options): if self.if_states: raise CompilerError('Unclosed if/else blocks') - print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) + if self.program.verbose: + print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) for block in self.basicblocks: al.determine_scope(block, options) @@ -601,7 +639,7 @@ def optimize(self, options): # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: for i,block in enumerate(self.basicblocks): - if len(block.instructions) > 0: + if len(block.instructions) > 0 and self.program.verbose: print('Processing basic block %s, %d/%d, %d instructions' % \ (block.name, i, len(self.basicblocks), \ len(block.instructions))) @@ -609,7 +647,7 @@ def optimize(self, options): merger = al.Merger(block, options, \ tuple(self.program.to_merge)) if options.dead_code_elimination: - if len(block.instructions) > 10000: + if len(block.instructions) > 100000: print('Eliminate dead code...') merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: @@ -617,14 +655,14 @@ def optimize(self, options): block.used_from_scope = util.set_by_id() block.defined_registers = util.set_by_id() continue - if len(block.instructions) > 10000: + if len(block.instructions) > 100000: print('Merging instructions...') numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) - if numrounds > 0: + if numrounds > 0 and self.program.verbose: print('Program requires %d rounds of communication' % numrounds) - if merger.counter: + if merger.counter and self.program.verbose: print('Block requires', \ ', '.join('%d %s' % (y, x.__name__) \ for x, y in list(merger.counter.items()))) @@ -635,6 +673,9 @@ def optimize(self, options): if not (options.merge_opens and self.merge_opens): print('Not merging instructions in tape %s' % self.name) + if options.cisc: + self.expand_cisc() + # add jumps offset = 0 for block in self.basicblocks: @@ -659,7 +700,7 @@ def optimize(self, options): print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) print('Re-allocating...') - allocator = al.StraightlineAllocator(REG_MAX) + allocator = al.StraightlineAllocator(REG_MAX, self.program) def alloc(block): for reg in sorted(block.used_from_scope, key=lambda x: (x.reg_type, x.i)): @@ -673,7 +714,7 @@ def alloc_loop(block): if child.instructions: left.append(child) for i,block in enumerate(reversed(self.basicblocks)): - if len(block.instructions) > 10000: + if len(block.instructions) > 100000: print('Allocating %s, %d/%d' % \ (block.name, i, len(self.basicblocks))) if block.exit_condition is not None: @@ -684,9 +725,11 @@ def alloc_loop(block): allocator.process(block.instructions, block.alloc_pool) # offline data requirements - print('Compile offline data requirements...') + if self.program.verbose: + print('Compile offline data requirements...') self.req_num = self.req_tree.aggregate() - print('Tape requires', self.req_num) + if self.program.verbose: + print('Tape requires', self.req_num) for req,num in sorted(self.req_num.items()): if num == float('inf') or num >= 2 ** 32: num = -1 @@ -708,6 +751,14 @@ def alloc_loop(block): self.basicblocks[-1].instructions.append( Compiler.instructions.guse_prep(req[1], num, \ add_to_prog=False)) + elif req[0] == 'edabit': + self.basicblocks[-1].instructions.append( + Compiler.instructions.use_edabit(False, req[1], num, \ + add_to_prog=False)) + elif req[0] == 'sedabit': + self.basicblocks[-1].instructions.append( + Compiler.instructions.use_edabit(True, req[1], num, \ + add_to_prog=False)) if not self.is_empty(): # bit length requirement @@ -723,6 +774,11 @@ def alloc_loop(block): print('Tape requires prime bit length', self.req_bit_length['p']) print('Tape requires galois bit length', self.req_bit_length['2']) + @unpurged + def expand_cisc(self): + for block in self.basicblocks: + block.expand_cisc() + @unpurged def _get_instructions(self): return itertools.chain.\ @@ -786,7 +842,7 @@ def count_regs(self, reg_type=None): def reset_registers(self): """ Reset register values to zero. """ - self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX) + self.reg_values = RegType.create_dict(lambda: []) def get_value(self, reg_type, i): return self.reg_values[reg_type][i] @@ -826,7 +882,7 @@ def max(self, other): return res def cost(self): return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ - if req[1] != 'input') + if req[1] != 'input' and req[0] != 'edabit') def __str__(self): return ", ".join('%s inputs in %s from player %d' \ % (num, req[0], req[2]) \ @@ -927,9 +983,11 @@ def __init__(self, reg_type, program, value=None, size=None, i=None): self.relative_i = 0 if i is not None: self.i = i - else: + elif size is not None: self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size + else: + self.i = float('inf') self.vector = [] if value is not None: self.value = value @@ -952,8 +1010,9 @@ def i(self, value): def set_size(self, size): if self.size == size: return - elif not self.program.options.assemblymode: - raise CompilerError('Mismatch of instruction and register size') + elif not self.program.program.options.assemblymode: + raise CompilerError('Mismatch of instruction and register size:' + ' %s != %s' % (self.size, size)) elif self.size == 1 and self.vectorbase is self: if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS: # create vector register in assembly mode @@ -979,6 +1038,10 @@ def _new_by_number(self, i, size=1): return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base, size): + if base == 0 and size == self.size: + return self + if size == 1: + return self[base] res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() @@ -1007,7 +1070,10 @@ def __getitem__(self, index): def __len__(self): return self.size - + + def copy(self): + return Tape.Register(self.reg_type, Program.prog.curr_tape) + @property def value(self): return self.program.reg_values[self.reg_type][self.i] @@ -1029,6 +1095,9 @@ def is_clear(self): self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.ClearInt + def __bool__(self): + raise CompilerError('cannot derive truth value from register') + def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/types.py b/Compiler/types.py index ea8126a43..a13be46b2 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -95,7 +95,8 @@ class ClientMessageType: class MPCThread(object): - def __init__(self, target, name, args = [], runtime_arg = None): + def __init__(self, target, name, args = [], runtime_arg = 0, + single_thread = False): """ Create a thread from a callable object. """ if not callable(target): raise CompilerError('Target %s for thread %s is not callable' % (target,name)) @@ -105,18 +106,33 @@ def __init__(self, target, name, args = [], runtime_arg = None): self.args = args self.runtime_arg = runtime_arg self.running = 0 + self.tape_handle = program.new_tape(target, args, name, + single_thread=single_thread) + self.run_handles = [] def start(self, runtime_arg = None): self.running += 1 - program.start_thread(self, runtime_arg or self.runtime_arg) + self.run_handles.append(program.run_tape(self.tape_handle, \ + runtime_arg or self.runtime_arg)) def join(self): if not self.running: raise CompilerError('Thread %s is not running' % self.name) self.running -= 1 - program.stop_thread(self) + program.join_tape(self.run_handles.pop(0)) +def copy_doc(a, b): + try: + a.__doc__ = b.__doc__ + except: + pass + +def no_doc(operation): + def wrapper(*args, **kwargs): + return operation(*args, **kwargs) + return wrapper + def copy_doc(a, b): try: a.__doc__ = b.__doc__ @@ -131,7 +147,9 @@ def wrapper(*args, **kwargs): def vectorize(operation): def vectorized_operation(self, *args, **kwargs): if len(args): + from .GC.types import bits if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ + and not isinstance(args[0], bits) \ and args[0].size != self.size: raise CompilerError('Different vector sizes of operands') set_global_vector_size(self.size) @@ -292,6 +310,10 @@ class _int(object): def bit_adder(*args, **kwargs): return intbitint.bit_adder(*args, **kwargs) + @staticmethod + def ripple_carry_adder(*args, **kwargs): + return intbitint.ripple_carry_adder(*args, **kwargs) + def if_else(self, a, b): """ MUX on bit in arithmetic circuits. @@ -416,6 +438,9 @@ def row_matrix_mul(cls, row, matrix, res_params=None): res_params) \ for k in range(len(row))).reduce_after_mul() +class _vec(object): + pass + class _register(Tape.Register, _number, _structure): @staticmethod def n_elements(): @@ -427,7 +452,7 @@ def conv(cls, val): val = val.read() if isinstance(val, cls): return val - elif not isinstance(val, _register): + elif not isinstance(val, (_register, _vec)): try: return type(val)(cls.conv(v) for v in val) except TypeError: @@ -467,8 +492,7 @@ def _expand_address(address, size): address = regint.conv(address) if size > 1 and address.size == 1: res = regint(size=size) - for i in range(size): - movint(res[i], address + regint(i, size=1)) + incint(res, address, 1) return res else: return address @@ -1089,6 +1113,29 @@ def get_random(cls, bit_length): rand(res, bit_length) return res + @classmethod + def inc(cls, size, base=0, step=1, repeat=1, wrap=None): + """ + Produce :py:class:`regint` vector with certain patterns. + This is particularly useful for :py:meth:`SubMultiArray.direct_mul`. + + :param size: Result size + :param base: First value + :param step: Increase step + :param repeat: Repeate this many times + :param wrap: Start over after this many increases + + The following produces (1, 1, 1, 3, 3, 3, 5, 5, 5, 7):: + + regint.inc(10, 1, 2, 3) + + """ + res = regint(size=size) + if wrap is None: + wrap = size + incint(res, cls.conv(base, size=1), step, repeat, wrap) + return res + @vectorized_classmethod def read_from_socket(cls, client_id, n=1): """ Receive n register values from socket """ @@ -1333,6 +1380,12 @@ def bit_compose(bits): res += bit return res + def shuffle(self): + """ Returns insecure shuffle of vector. """ + res = regint(size=len(self)) + shuffle(res, self) + return res + def reveal(self): """ Identity. """ return self @@ -1371,7 +1424,7 @@ def output(self): class _secret(_register): __slots__ = [] - mov = staticmethod(movs) + mov = staticmethod(set_instruction_type(movs)) PreOR = staticmethod(lambda l: floatingpoint.PreORC(l)) PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l)) @@ -1475,9 +1528,7 @@ def matrix_mul(cls, A, B, n, res_params=None): res = cls(size=size) n_rows = len(A) // n n_cols = len(B) // n - dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)], - [B[k * n_cols + j % n_cols] for k in range(n)]] - for j in range(size)), [])) + matmuls(res, A, B, n_rows, n, n_cols) return res @no_doc @@ -1502,7 +1553,7 @@ def load_clear(self, val): @read_mem_value @vectorize def load_other(self, val): - from Compiler.GC.types import sbits + from Compiler.GC.types import sbits, sbitvec if isinstance(val, self.clear_type): self.load_clear(val) elif isinstance(val, type(self)): @@ -1510,9 +1561,16 @@ def load_other(self, val): elif isinstance(val, sbits): assert(val.n == self.size) r = self.get_dabit() - v = regint() - bitdecint_class(regint((r[1] ^ val).reveal()), *v) - movs(self, r[0].bit_xor(v)) + movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit())) + elif isinstance(val, sbitvec): + assert(sum(x.n for x in val.v) == self.size) + for val_part, base in zip(val, range(0, self.size, 64)): + left = min(64, self.size - base) + r = self.get_dabit(size=left) + v = regint(size=left) + bitdecint_class(regint((r[1] ^ val_part).reveal()), *v) + part = r[0].bit_xor(v) + vmovs(left, self.get_vector(base, left), part) else: self.load_clear(self.clear_type(val)) @@ -1555,8 +1613,8 @@ def mul(self, other): size or one size 1 for a value-vector multiplication. :param other: any compatible type """ - if isinstance(other, _secret) and max(self.size, other.size) > 1 \ - and min(self.size, other.size) == 1: + if isinstance(other, _secret) and (1 in (self.size, other.size)) \ + and (self.size, other.size) != (1, 1): x, y = (other, self) if self.size < other.size else (self, other) res = type(self)(size=x.size) mulrs(res, x, y) @@ -1667,10 +1725,35 @@ def get_dabit(cls): dabit(*res) return res + @vectorized_classmethod + def get_edabit(cls, n_bits, strict=False): + """ Bits in arithmetic and binary circuit """ + """ according to security model """ + if not program.use_edabit_for(strict, n_bits): + if program.use_dabit: + a, b = zip(*(sint.get_dabit() for i in range(n_bits))) + return sint.bit_compose(a), b + else: + a = [sint.get_random_bit() for i in range(n_bits)] + return sint.bit_compose(a), a + whole = cls() + size = get_global_vector_size() + from Compiler.GC.types import sbits, sbitvec + bits = [sbits.get_type(size)() for i in range(n_bits)] + if strict: + sedabit(whole, *bits) + else: + edabit(whole, *bits) + return whole, bits + @staticmethod def long_one(): return 1 + @staticmethod + def bit_decompose_clear(a, n_bits): + return floatingpoint.bits(a, n_bits) + @classmethod def get_raw_input_from(cls, player): res = cls() @@ -1733,6 +1816,15 @@ def store_in_mem(self, address): """ Store in memory by public address. """ self._store_in_mem(address, stms, stmsi) + @classmethod + def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None): + if indices is None: + indices = [regint.inc(i) for i in (n, m, m, l)] + res = cls(size=indices[0].size * indices[3].size) + matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]), + len(indices[3]), *(list(indices) + [m, l])) + return res + def __init__(self, val=None, size=None): """ :param val: initialization (sint/cint/regint/int/cgf2n or list thereof) @@ -1810,6 +1902,7 @@ def __mod__(self, modulus): return self.mod2m(int(l)) raise NotImplementedError('Modulo only implemented for powers of two.') + @vectorize @read_mem_value def mod2m(self, m, bit_length=None, security=None, signed=True): """ Secret modulo power of two. @@ -1941,6 +2034,10 @@ def int_div(self, other, bit_length=None, security=None): comparison.Trunc(res, tmp, 2 * k, k, kappa, True) return res + def trunc_zeros(self, n_zeros, bit_length=None, signed=True): + bit_length = bit_length or program.bit_length + return comparison.TruncZeros(self, bit_length, n_zeros, signed) + @staticmethod def two_power(n): return floatingpoint.two_power(n) @@ -2120,11 +2217,14 @@ def bit_adder(cls, a, b, carry_in=0, get_carry=False): @classmethod def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False): if cls.log_rounds: - return cls.carry_lookahead_adder(a, b, carry_in=carry_in) + return cls.carry_lookahead_adder(a, b, carry_in=carry_in, + get_carry=get_carry) elif cls.linear_rounds: - return cls.ripple_carry_adder(a, b, carry_in=carry_in) + return cls.ripple_carry_adder(a, b, carry_in=carry_in, + get_carry=get_carry) else: - return cls.carry_select_adder(a, b, carry_in=carry_in) + return cls.carry_select_adder(a, b, carry_in=carry_in, + get_carry=get_carry) @classmethod def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, @@ -2205,8 +2305,8 @@ def ripple_carry_adder(cls, a, b, carry_in=0): @staticmethod def full_adder(a, b, carry): - s = a + b - return s + carry, util.if_else(s, carry, a) + s = a ^ b + return s ^ carry, a ^ (s & (carry ^ a)) @staticmethod def bit_comparator(a, b): @@ -2243,6 +2343,7 @@ def add(self, other): b = util.bit_decompose(other, self.n_bits) return self.compose(self.bit_adder(a, b)) + @ret_cisc def mul(self, other): if type(other) == self.bin_type: raise CompilerError('Unclear multiplication') @@ -2270,12 +2371,13 @@ def mul(self, other): @classmethod def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True): columns = [[_f for _f in (bit_matrix[j][i-j] \ - for j in range(min(len(bit_matrix), i + 1))) if _f] \ + for j in range(min(len(bit_matrix), i + 1))) \ + if not is_zero(_f)] \ for i in range(len(bit_matrix[0]))] return cls.wallace_tree_from_columns(columns, get_carry) @classmethod - def wallace_tree_from_columns(cls, columns, get_carry=True): + def wallace_tree_without_finish(cls, columns, get_carry=True): self = cls while max(len(c) for c in columns) > 2: new_columns = [[] for i in range(len(columns) + 1)] @@ -2296,7 +2398,12 @@ def wallace_tree_from_columns(cls, columns, get_carry=True): columns = new_columns[:-1] for col in columns: col.extend([0] * (2 - len(col))) - return self.bit_adder(*list(zip(*columns))) + return tuple(list(x) for x in zip(*columns)) + + @classmethod + def wallace_tree_from_columns(cls, columns, get_carry=True): + summands = cls.wallace_tree_without_finish(columns, get_carry) + return cls.bit_adder(*summands) @classmethod def wallace_tree(cls, rows): @@ -2556,15 +2663,15 @@ def parse_type(other, k=None, f=None): if isinstance(other, cfix.scalars): return cfix(other, k=k, f=f) elif isinstance(other, cint): - tmp = cfix() + tmp = cfix(k=k, f=f) tmp.load_int(other) return tmp elif isinstance(other, sint): - tmp = sfix() + tmp = sfix(k=k, f=f) tmp.load_int(other) return tmp elif isinstance(other, sfloat): - tmp = sfix(other) + tmp = sfix(other, k=k, f=f) return tmp else: return other @@ -2631,8 +2738,8 @@ def n_elements(): @vectorize_init def __init__(self, v=None, k=None, f=None, size=None): """ :param v: cfix/float/int """ - f = f or self.f - k = k or self.k + f = self.f if f is None else f + k = self.k if k is None else k self.f = f self.k = k self.size = get_global_vector_size() @@ -2693,7 +2800,9 @@ def add(self, other): def mul(self, other): """ Clear fixed-point multiplication. - :param other: cfix/cint/regint/int """ + :param other: cfix/cint/regint/int/sint """ + if isinstance(other, sint): + return sfix._new(self.v * other, k=self.k, f=self.f) other = parse_type(other) if isinstance(other, cfix): assert self.f == other.f @@ -2814,8 +2923,12 @@ def __truediv__(self, other): if isinstance(other, cfix): return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f)) elif isinstance(other, sfix): - return sfix(library.FPDiv(self.v, other.v, self.k, self.f, - other.kappa, nearest=sfix.round_nearest)) + assert self.k == other.k + assert self.f == other.f + return sfix._new(library.FPDiv(self.v, other.v, self.k, self.f, + other.kappa, + nearest=sfix.round_nearest), + k=self.k, f=self.f) else: raise TypeError('Incompatible fixed point types in division') @@ -3022,25 +3135,29 @@ def from_sint(cls, other, k=None, f=None): """ Convert secret integer. :param other: sint """ - res = cls() - res.f = f or cls.f - res.k = k or cls.k + res = cls(k=k, f=f) res.load_int(cls.int_type.conv(other)) return res @classmethod def _new(cls, other, k=None, f=None): - res = cls(other) - res.k = k or cls.k - res.f = f or cls.f + res = cls(other, k=k, f=f) return res @vectorize_init - def __init__(self, _v=None, size=None): + def __init__(self, _v=None, k=None, f=None, size=None): """ :params _v: compile-time value (int/float) """ self.size = get_global_vector_size() - f = self.f - k = self.k + if k is None: + k = self.k + else: + self.k = k + if f is None: + f = self.f + else: + self.f = f + assert k is not None + assert f is not None # warning: don't initialize a sfix from a sint, this is only used in internal methods; # for external initialization use load_int. if _v is None: @@ -3110,7 +3227,7 @@ def mul(self, other): val = self.v.TruncMul(other.v, self.k + other.k, other.f, self.kappa, self.round_nearest) - if self.size >= other.size: + if 'vec' not in self.__dict__: return self._new(val, k=self.k, f=self.f) else: return self.vec._new(val, k=self.k, f=self.f) @@ -3123,7 +3240,7 @@ def mul(self, other): @vectorize def __neg__(self): """ Secret fixed-point negation. """ - return type(self)(-self.v) + return self._new(-self.v, k=self.k, f=self.f) @vectorize def __truediv__(self, other): @@ -3131,14 +3248,17 @@ def __truediv__(self, other): :param other: sfix/cfix/sint/cint/regint/int """ other = self.coerce(other) + assert self.k == other.k + assert self.f == other.f if isinstance(other, _fix): - return type(self)(library.FPDiv(self.v, other.v, self.k, self.f, - self.kappa, - nearest=self.round_nearest)) + v = library.FPDiv(self.v, other.v, self.k, self.f, self.kappa, + nearest=self.round_nearest) elif isinstance(other, cfix): - return type(self)(library.sint_cint_division(self.v, other.v, self.k, self.f, self.kappa)) + v = library.sint_cint_division(self.v, other.v, self.k, self.f, + self.kappa) else: raise TypeError('Incompatible fixed point types in division') + return self._new(v, k=self.k, f=self.f) @vectorize def __rtruediv__(self, other): @@ -3192,11 +3312,21 @@ def get_random(cls, lower, upper): lower = average - 0.5 * 2 ** log_range return cls._new(cls.int_type.get_random_int(n_bits)) + lower + @classmethod + def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None): + # pre-multiplication must be identity + tmp = cls.int_type.direct_matrix_mul(A, B, n, m, l, indices=indices) + res = unreduced_sfix._new(tmp) + if reduce: + res = res.reduce_after_mul() + return res + def coerce(self, other): return parse_type(other, k=self.k, f=self.f) def mul_no_reduce(self, other, res_params=None): assert self.f == other.f + assert self.k == other.k return self.unreduced(self.v * other.v) def pre_mul(self): @@ -3221,6 +3351,8 @@ def __init__(self, v, k, m, kappa): self.k = k self.m = m self.kappa = kappa + assert self.k is not None + assert self.m is not None def __add__(self, other): if is_zero(other): @@ -3236,7 +3368,8 @@ def __add__(self, other): def reduce_after_mul(self): return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa, nearest=sfix.round_nearest, - signed=True)) + signed=True), + k=self.k // 2, f=self.m) sfix.unreduced_type = unreduced_sfix @@ -4012,12 +4145,15 @@ def assign(self, other, base=0): pass try: other.store_in_mem(self.get_address(base)) - assert len(self) >= other.size + base + if len(self) != None and util.is_constant(base): + assert len(self) >= other.size + base except AttributeError: for i,j in enumerate(other): self[i] = j return self + assign_vector = assign + def assign_all(self, value, use_threads=True, conv=True): """ Assign the same value to all entries. @@ -4040,6 +4176,19 @@ def get_vector(self, base=0, size=None): size = size or self.length return self.value_type.load_mem(self.get_address(base), size=size) + get_part_vector = get_vector + + def get(self, indices): + return self.value_type.load_mem( + regint(self.address, size=len(indices)) + indices, + size=len(indices)) + + def expand_to_vector(self, index, size): + assert self.value_type.n_elements() == 1 + address = regint(size=size) + incint(address, regint(self.get_address(index), size=1), 0) + return self.value_type.load_mem(address, size=size) + def get_mem_value(self, index): return MemValue(self[index], self.get_address(index)) @@ -4082,12 +4231,15 @@ def __pow__(self, value): def shuffle(self): """ Insecure shuffle in place. """ - @library.for_range(len(self)) - def _(i): - j = regint.get_random(64) % (len(self) - i) - tmp = self[i] - self[i] = self[i + j] - self[i + j] = tmp + if self.value_type == regint: + self.assign(self.get_vector().shuffle()) + else: + @library.for_range(len(self)) + def _(i): + j = regint.get_random(64) % (len(self) - i) + tmp = self[i] + self[i] = self[i + j] + self[i + j] = tmp def reveal(self): """ Reveal the whole array. @@ -4184,6 +4336,14 @@ def assign(self, other): assert self.sizes == other.sizes self.assign_vector(other.get_vector()) + def get_part_vector(self, base=0, size=None): + assert self.value_type.n_elements() == 1 + part_size = reduce(operator.mul, self.sizes[1:]) + size = (size or len(self)) * part_size + assert size <= self.total_size() + return self.value_type.load_mem(self.address + base * part_size, + size=size) + def same_shape(self): """ :return: new multidimensional array with same shape and basic type """ return MultiArray(self.sizes, self.value_type) @@ -4254,6 +4414,7 @@ def mul(self, other, res_params=None): for i, x in enumerate(other): matrix[i][0] = x res = self * matrix + library.break_point() return Array.create_from(x[0] for x in res) elif isinstance(other, SubMultiArray): assert len(other.sizes) == 2 @@ -4292,6 +4453,32 @@ def _(k): else: raise NotImplementedError + def direct_mul(self, other, reduce=True, indices=None): + """ Matrix multiplication in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication) + :return: Matrix as vector of relevant type (row-major) + + The following executes a matrix multiplication selecting every third row + of :py:obj:`A`:: + + A = sfix.Matrix(7, 4) + B = sfix.Matrix(4, 5) + C = sfix.Matrix(3, 5) + C.assign_vector(A.direct_mul(B, indices=(regint.inc(3, 0, 3), + regint.inc(4), + regint.inc(4), + regint.inc(5))) + """ + assert len(self.sizes) == 2 + assert len(other.sizes) == 2 + assert self.sizes[1] == other.sizes[0] + return self.value_type.direct_matrix_mul(self.address, other.address, + self.sizes[0], *other.sizes, + reduce=reduce, indices=indices) + def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True, res=None): assert len(self.sizes) == 2 @@ -4349,6 +4536,26 @@ def trans_mul(self, other, reduce=True, res=None): lambda x, j: [x[k][j] for k in range(len(x))], reduce=reduce, res=res) + def parallel_mul(self, other): + assert self.sizes[1] == other.sizes[0] + assert len(self.sizes) == 2 + assert len(other.sizes) == 2 + assert self.value_type.n_elements() == 1 + n = self.sizes[0] * other.sizes[1] + a = [] + b = [] + for i in range(self.sizes[1]): + addresses = regint(size=n) + incint(addresses, regint(self.address + i), self.sizes[1], + other.sizes[1], n) + a.append(self.value_type.load_mem(addresses, size=n)) + addresses = regint(size=n) + incint(addresses, regint(other.address + i * other.sizes[1]), 1, + 1, other.sizes[1]) + b.append(self.value_type.load_mem(addresses, size=n)) + res = self.value_type.dot_product(a, b) + return res + def transpose(self): """ Matrix transpose. diff --git a/ECDSA/EcdsaOptions.h b/ECDSA/EcdsaOptions.h index 619aaa032..6717323e4 100644 --- a/ECDSA/EcdsaOptions.h +++ b/ECDSA/EcdsaOptions.h @@ -15,6 +15,7 @@ class EcdsaOptions bool fewer_rounds; bool check_open; bool check_beaver_open; + bool R_after_msg; EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv) { @@ -54,11 +55,21 @@ class EcdsaOptions "-B", // Flag token. "--no-beaver-open-check" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Only open R after message is known", // Help description. + "-R", // Flag token. + "--R-after-msg" // Flag token. + ); opt.parse(argc, argv); prep_mul = not opt.isSet("-D"); fewer_rounds = opt.isSet("-P"); check_open = not opt.isSet("-C"); check_beaver_open = not opt.isSet("-B"); + R_after_msg = opt.isSet("-R"); opt.resetArgs(); } }; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index bea4db5bb..68a14932a 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -17,6 +17,10 @@ #include "Processor/Input.hpp" #include "Processor/Processor.hpp" #include "Processor/Data_Files.hpp" +#include "Protocols/MascotPrep.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "OT/NPartyTripleGenerator.hpp" #include @@ -41,6 +45,10 @@ int main(int argc, const char** argv) Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); + BaseMachine machine; + machine.ot_setups.push_back({P, false}); + GC::ShareThread thread(N, + OnlineOptions::singleton, P, {}, usage); SubProcessor proc(_, MCp, prep, P); pShare sk, __; diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 3863ff16b..5a3627f07 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -11,11 +11,14 @@ #include "Math/gfp.h" #include "ECDSA/P256Element.h" #include "Tools/Bundle.h" +#include "GC/TinyMC.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/CcdSecret.h" +#include "GC/VectorInput.h" #include "ECDSA/preprocessing.hpp" #include "ECDSA/sign.hpp" #include "Protocols/MaliciousRepMC.hpp" -#include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/fake-stuff.hpp" #include "Processor/Input.hpp" @@ -24,6 +27,8 @@ #include "GC/ShareSecret.hpp" #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/Secret.hpp" +#include "Machines/ShamirMachine.hpp" #include diff --git a/ECDSA/mal-shamir-ecdsa-party.cpp b/ECDSA/mal-shamir-ecdsa-party.cpp index 17f212ffd..cd5301572 100644 --- a/ECDSA/mal-shamir-ecdsa-party.cpp +++ b/ECDSA/mal-shamir-ecdsa-party.cpp @@ -7,8 +7,6 @@ #include "Protocols/Shamir.hpp" #include "Protocols/ShamirInput.hpp" -#include "Protocols/ShamirMC.hpp" -#include "Protocols/MaliciousShamirMC.hpp" #include "hm-ecdsa-party.hpp" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index f9cc9c9fc..5cee656e7 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -18,6 +18,7 @@ #include "Processor/Processor.hpp" #include "Processor/Data_Files.hpp" #include "Processor/Input.hpp" +#include "GC/TinyPrep.hpp" #include @@ -103,6 +104,8 @@ void run(int argc, const char** argv) typename pShare::Direct_MC MCp(keyp, N, 0); ArithmeticProcessor _({}, 0); typename pShare::LivePrep sk_prep(0, usage); + GC::ShareThread thread(N, + OnlineOptions::singleton, P, {}, usage); SubProcessor sk_proc(_, MCp, sk_prep, P); pShare sk, __; // synchronize diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 449f9331f..6b17eb7a0 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -11,6 +11,13 @@ #include "Processor/Data_Files.h" #include "Protocols/ReplicatedPrep.h" #include "Protocols/MaliciousShamirShare.h" +#include "GC/TinierSecret.h" +#include "GC/TinierPrep.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/TinyMC.h" + +#include "GC/TinierSharePrep.hpp" +#include "GC/CcdSecret.h" template class T> class EcTuple @@ -18,6 +25,8 @@ class EcTuple public: T a; T b; + P256Element::Scalar c; + T secret_R; P256Element R; }; @@ -62,9 +71,10 @@ void preprocessing(vector>& tuples, int buffer_size, for (int i = 0; i < buffer_size; i++) secret_Rs.push_back(bs[i] / cs_opened[i]); } - vector opened_Rs; + vector opened_Rs(buffer_size); typename cShare::Direct_MC MCc(MCp.get_alphai()); - MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); + if (not opts.R_after_msg) + MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { protocol.init_mul(&proc); @@ -74,10 +84,13 @@ void preprocessing(vector>& tuples, int buffer_size, } if (opts.fewer_rounds) MCp.POpen_End(cs_opened, cs, extra_player); - MCc.POpen_End(opened_Rs, secret_Rs, extra_player); - if (opts.fewer_rounds) - for (int i = 0; i < buffer_size; i++) - opened_Rs[i] /= cs_opened[i]; + if (not opts.R_after_msg) + { + MCc.POpen_End(opened_Rs, secret_Rs, extra_player); + if (opts.fewer_rounds) + for (int i = 0; i < buffer_size; i++) + opened_Rs[i] /= cs_opened[i]; + } if (prep_mul) protocol.stop_exchange(); if (opts.check_open) @@ -88,7 +101,7 @@ void preprocessing(vector>& tuples, int buffer_size, { tuples.push_back( { inv_ks[i], prep_mul ? protocol.finalize_mul() : pShare(), - opened_Rs[i] }); + cs_opened[i], secret_Rs[i], opened_Rs[i] }); } timer.stop(); cout << "Generated " << buffer_size << " tuples in " << timer.elapsed() @@ -112,6 +125,7 @@ void check(vector>& tuples, T sk, assert(open_sk * inv_k == MC.POpen(tuple.b, P)); assert(tuple.R == k); } + MC.Check(P); } template<> diff --git a/ECDSA/shamir-ecdsa-party.cpp b/ECDSA/shamir-ecdsa-party.cpp index 419e4fde9..e893feb70 100644 --- a/ECDSA/shamir-ecdsa-party.cpp +++ b/ECDSA/shamir-ecdsa-party.cpp @@ -7,7 +7,6 @@ #include "Protocols/Shamir.hpp" #include "Protocols/ShamirInput.hpp" -#include "Protocols/ShamirMC.hpp" #include "hm-ecdsa-party.hpp" diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index f6f4d6631..74e554982 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -49,7 +49,10 @@ inline P256Element::Scalar hash_to_scalar(const unsigned char* message, size_t l template class T> EcSignature sign(const unsigned char* message, size_t length, EcTuple tuple, - typename T::MAC_Check& MC, Player& P, + typename T::MAC_Check& MC, + typename T::MAC_Check& MCc, + Player& P, + EcdsaOptions opts, P256Element pk, T sk = {}, SubProcessor>* proc = 0) @@ -60,16 +63,30 @@ EcSignature sign(const unsigned char* message, size_t length, size_t start = P.sent; auto stats = P.comm_stats; EcSignature signature; - signature.R = tuple.R; + vector opened_R; + if (opts.R_after_msg) + MCc.POpen_Begin(opened_R, {tuple.secret_R}, P); T prod = tuple.b; + auto& protocol = proc->protocol; if (proc) { - auto& protocol = proc->protocol; protocol.init_mul(proc); protocol.prepare_mul(sk, tuple.a); - protocol.exchange(); + protocol.start_exchange(); + } + if (opts.R_after_msg) + { + MCc.POpen_End(opened_R, {tuple.secret_R}, P); + tuple.R = opened_R[0]; + if (opts.fewer_rounds) + tuple.R /= tuple.c; + } + if (proc) + { + protocol.stop_exchange(); prod = protocol.finalize_mul(); } + signature.R = tuple.R; auto rx = tuple.R.x(); signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); @@ -132,7 +149,7 @@ void sign_benchmark(vector>& tuples, T sk, for (size_t i = 0; i < min(10lu, tuples.size()); i++) { - check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message, + check(sign(message, 1 << i, tuples[i], MCp, MCc, P, opts, pk, sk, proc), message, 1 << i, pk); if (not opts.check_open) continue; @@ -142,6 +159,7 @@ void sign_benchmark(vector>& tuples, T sk, auto stats = check_player.comm_stats; auto start = check_player.sent; MCp.Check(P); + MCc.Check(P); cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending " << (check_player.sent - start) << " bytes" << endl; auto diff = (check_player.comm_stats - stats); diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 6eb713c0c..2acc5a3e2 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -221,4 +221,12 @@ class no_singleton : runtime_error } }; +class ran_out +{ + const char* what() const + { + return "insufficient preprocessing"; + } +}; + #endif diff --git a/FHE/AddableVector.cpp b/FHE/AddableVector.cpp new file mode 100644 index 000000000..3f4ba2d99 --- /dev/null +++ b/FHE/AddableVector.cpp @@ -0,0 +1,45 @@ +/* + * AddableVector.cpp + * + */ + +#include "AddableVector.h" +#include "Rq_Element.h" +#include "FHE_Keys.h" + +template +AddableVector AddableVector::mul_by_X_i(int j, + const FHE_PK& pk) const +{ + int phi_m = this->size(); + assert(phi_m == pk.get_params().phi_m()); + AddableVector res(phi_m); + for (int i = 0; i < phi_m; i++) + { + int k = j + i, s = 1; + while (k >= phi_m) + { + k -= phi_m; + s = -s; + } + if (s == 1) + { + res[k] = (*this)[i]; + } + else + { + res[k] = -(*this)[i]; + } + } + return res; +} + +template +AddableVector> AddableVector>::mul_by_X_i(int j, + const FHE_PK& pk) const; +template +AddableVector> AddableVector>::mul_by_X_i(int j, + const FHE_PK& pk) const; +template +AddableVector> AddableVector>::mul_by_X_i(int j, + const FHE_PK& pk) const; diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 90d5e7a42..19576fe17 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -10,6 +10,7 @@ using namespace std; #include "FHE/Plaintext.h" +#include "Rq_Element.h" template class AddableVector: public vector @@ -27,6 +28,11 @@ class AddableVector: public vector this->assign(other.begin(), other.end()); } + AddableVector(const Rq_Element& other) : + AddableVector(other.to_vec_bigint()) + { + } + template void allocate_slots(const U& init) { @@ -66,6 +72,8 @@ class AddableVector: public vector (*this)[i].mul(x[i], y[i]); } + AddableVector mul_by_X_i(int i, const FHE_PK& pk) const; + void generateUniform(PRNG& G, int n_bits) { for (unsigned int i = 0; i < this->size(); i++) @@ -171,6 +179,19 @@ class AddableMatrix: public AddableVector > for (int i = 0; i < n; i++) (*this)[i].resize(m); } + + AddableMatrix mul_by_X_i(int i, const FHE_PK& pk) const; }; +template +AddableMatrix AddableMatrix::mul_by_X_i(int i, + const FHE_PK& pk) const +{ + AddableMatrix res; + res.resize(this->size()); + for (size_t j = 0; j < this->size(); j++) + res[j] = (*this)[j].mul_by_X_i(i, pk); + return res; +} + #endif /* FHEOFFLINE_ADDABLEVECTOR_H_ */ diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index b04e74423..2365c4d0d 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -1,5 +1,6 @@ #include "Ciphertext.h" #include "PPData.h" +#include "P2Data.h" #include "Exceptions/Exceptions.h" Ciphertext::Ciphertext(const FHE_PK& pk) : Ciphertext(pk.get_params()) @@ -118,12 +119,11 @@ template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c) { a.to_poly(); - const vector& aa=a.get_poly(); int lev=c.cc0.level(); Rq_Element ra((*ans.params).FFTD(),evaluation,evaluation); if (lev==0) { ra.lower_level(); } - ra.from_vec(aa); + ra.from(a.get_iterator()); ans.mul(c, ra); } diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index 84fe88cd9..26dd94727 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -35,11 +35,17 @@ class Ciphertext Ciphertext(const FHE_PK &pk); + Ciphertext(const Rq_Element& a0, const Rq_Element& a1, const Ciphertext& C) : + Ciphertext(C.get_params()) + { + set(a0, a1, C.get_pk_id()); + } + ~Ciphertext() { ; } // Rely on default copy assignment/constructor - word get_pk_id() { return pk_id; } + word get_pk_id() const { return pk_id; } void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id) { cc0=a0; cc1=a1; this->pk_id = pk_id; } @@ -93,9 +99,14 @@ class Ciphertext template Ciphertext& operator*=(const Plaintext_& other) { ::mul(*this, *this, other); return *this; } - Ciphertext mul(const Ciphertext& x, const FHE_PK& pk) const + Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const { Ciphertext res(*params); ::mul(res, *this, x, pk); return res; } + Ciphertext mul_by_X_i(int i, const FHE_PK&) const + { + return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this}; + } + int level() const { return cc0.level(); } // pack/unpack (like IO) also assume params are known and already set diff --git a/FHE/DiscreteGauss.cpp b/FHE/DiscreteGauss.cpp index 4e6e4389f..89fcd0aef 100644 --- a/FHE/DiscreteGauss.cpp +++ b/FHE/DiscreteGauss.cpp @@ -4,34 +4,30 @@ void DiscreteGauss::set(double RR) { - R=RR; - e=exp(1); e1=exp(0.25); e2=exp(-1.35); + if (RR > 0 or NewHopeB < 1) + NewHopeB = max(1, int(round(2 * RR * RR))); + assert(NewHopeB > 0); } -/* Return a value distributed normaly with std dev R */ -int DiscreteGauss::sample(PRNG& G, int stretch) const +/* This uses the approximation to a Gaussian via + * binomial distribution + * + * This procedure consumes 2*NewHopeB bits + * + */ +int DiscreteGauss::sample(PRNG &G, int stretch) const { - /* Uses the ratio method from Wikipedia to get a - Normal(0,1) variable X - Then multiplies X by R - */ - double U,V,X,R1,R2,R3,X2; - int ans; - while (true) - { U=G.get_double(); - V=G.get_double(); - R1=5-4*e1*U; - R2=4*e2/U+1.4; - R3=-4/log(U); - X=sqrt(8/e)*(V-0.5)/U; - X2=X*X; - if (X2<=R1 || (X2 0) + h=hh; DG.set(R); } @@ -124,7 +121,7 @@ bool RandomVectors::operator!=(const RandomVectors& other) const bool DiscreteGauss::operator!=(const DiscreteGauss& other) const { - if (other.R != R or other.e != e or other.e1 != e1 or other.e2 != e2) + if (other.NewHopeB != NewHopeB) return true; else return false; diff --git a/FHE/DiscreteGauss.h b/FHE/DiscreteGauss.h index 62b5bfd8a..b65247a40 100644 --- a/FHE/DiscreteGauss.h +++ b/FHE/DiscreteGauss.h @@ -7,35 +7,33 @@ #include #include "Math/modp.h" +#include "Math/gfp.h" #include "Tools/random.h" #include - -/* Uses the Ratio method as opposed to the Box-Muller method - * as the Ratio method is thread safe, but it is 50% slower - */ +#include class DiscreteGauss { - double R; // Standard deviation - double e; // Precomputed exp(1) - double e1; // Precomputed exp(0.25) - double e2; // Precomputed exp(-1.35) + /* This is the bound we use on for the NewHope approximation + * to a discrete Gaussian with sigma=sqrt(B/2) + */ + int NewHopeB; public: void set(double R); - void pack(octetStream& o) const { o.serialize(R); } - void unpack(octetStream& o) { o.unserialize(R); } + void pack(octetStream& o) const { o.serialize(NewHopeB); } + void unpack(octetStream& o) { o.unserialize(NewHopeB); } - DiscreteGauss() { set(0); } DiscreteGauss(double R) { set(R); } ~DiscreteGauss() { ; } // Rely on default copy constructor/assignment int sample(PRNG& G, int stretch = 1) const; - double get_R() const { return R; } + double get_R() const { return sqrt(0.5 * NewHopeB); } + int get_NewHopeB() const { return NewHopeB; } bool operator!=(const DiscreteGauss& other) const; }; @@ -56,8 +54,8 @@ class RandomVectors void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); } void unpack(octetStream& o) { o.get(n); o.get(h); DG.unpack(o); } - RandomVectors() { ; } - RandomVectors(int nn,int hh,double R) { set(nn,hh,R); } + RandomVectors(int h, double R) : n(0), h(h), DG(R) {} + RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); } ~RandomVectors() { ; } // Rely on default copy constructor/assignment @@ -81,7 +79,8 @@ class RandomVectors bool operator!=(const RandomVectors& other) const; }; -class RandomGenerator : public Generator +template +class RandomGenerator : public Generator { protected: mutable PRNG G; @@ -90,39 +89,42 @@ class RandomGenerator : public Generator RandomGenerator(PRNG& G) { this->G.SetSeed(G); } }; -class UniformGenerator : public RandomGenerator +template +class UniformGenerator : public RandomGenerator { int n_bits; bool positive; public: UniformGenerator(PRNG& G, int n_bits, bool positive = true) : - RandomGenerator(G), n_bits(n_bits), positive(positive) {} - Generator* clone() const { return new UniformGenerator(*this); } - void get(bigint& x) const { G.get_bigint(x, n_bits, positive); } + RandomGenerator(G), n_bits(n_bits), positive(positive) {} + Generator* clone() const { return new UniformGenerator(*this); } + void get(T& x) const { this->G.get(x, n_bits, positive); } }; -class GaussianGenerator : public RandomGenerator +template +class GaussianGenerator : public RandomGenerator { DiscreteGauss DG; int stretch; public: GaussianGenerator(const DiscreteGauss& DG, PRNG& G, int stretch = 1) : - RandomGenerator(G), DG(DG), stretch(stretch) {} - Generator* clone() const { return new GaussianGenerator(*this); } - void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), DG.sample(G, stretch)); } + RandomGenerator(G), DG(DG), stretch(stretch) {} + Generator* clone() const { return new GaussianGenerator(*this); } + void get(T& x) const { x = DG.sample(this->G, stretch); } }; int sample_half(PRNG& G); -class HalfGenerator : public RandomGenerator +template +class HalfGenerator : public RandomGenerator { public: HalfGenerator(PRNG& G) : - RandomGenerator(G) {} - Generator* clone() const { return new HalfGenerator(*this); } - void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), sample_half(G)); } + RandomGenerator(G) {} + Generator* clone() const { return new HalfGenerator(*this); } + void get(T& x) const { x = sample_half(this->G); } }; #endif diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index c38d26e2f..7b858a8bb 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -1,7 +1,7 @@ #include "FHE/FFT_Data.h" #include "FHE/FFT.h" -#include "Math/Subroutines.h" +#include "FHE/Subroutines.h" void FFT_Data::assign(const FFT_Data& FFTD) diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index 001a9c99d..4366dda3b 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -4,6 +4,7 @@ #include "Math/modp.h" #include "Math/Zp_Data.h" #include "Math/gfp.h" +#include "Math/fixint.h" #include "FHE/Ring.h" /* Class for holding modular arithmetic data wrt the ring @@ -36,6 +37,7 @@ class FFT_Data public: typedef gfp T; typedef bigint S; + typedef fixint poly_type; void init(const Ring& Rg,const Zp_Data& PrD); diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 7effaa660..e369c08e9 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -161,7 +161,7 @@ void FHE_PK::encrypt(Ciphertext& c, const vector& mess, const Random_Coins& rc) const { Rq_Element mm((*params).FFTD(),polynomial,polynomial); - mm.from_vec(mess); + mm.from(Iterator(mess)); quasi_encrypt(c, mm, rc); } @@ -302,6 +302,7 @@ void FHE_SK::dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_ { dec_sh.negate(); } // Now convert to a vector of bigint's and add the required randomness + assert(pr != 0); bigint Bd=((*params).B()<<(*params).secp())/(num_players*pr); Bd=Bd/2; // make slightly smaller due to rounding issues @@ -334,6 +335,29 @@ void FHE_SK::dist_decrypt_2(vector& vv,const vector& vv1) const } } +void FHE_PK::pack(octetStream& o) const +{ + o.append((octet*) "PKPKPKPK", 8); + a0.pack(o); + b0.pack(o); + Sw_a.pack(o); + Sw_b.pack(o); + pr.pack(o); +} + +void FHE_PK::unpack(octetStream& o) +{ + char tag[8]; + o.consume((octet*) tag, 8); + if (memcmp(tag, "PKPKPKPK", 8)) + throw runtime_error("invalid serialization of public key"); + a0.unpack(o); + b0.unpack(o); + Sw_a.unpack(o); + Sw_b.unpack(o); + pr.unpack(o); +} + bool FHE_PK::operator!=(const FHE_PK& x) const { @@ -379,7 +403,7 @@ template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; template void FHE_PK::encrypt(Ciphertext& c, const vector& mess, const Random_Coins& rc) const; -template void FHE_PK::encrypt(Ciphertext& c, const vector& mess, +template void FHE_PK::encrypt(Ciphertext& c, const vector>& mess, const Random_Coins& rc) const; template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 82eb15979..1639c436f 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -27,7 +27,7 @@ class FHE_SK // secret key always on lower level void assign(const Rq_Element& s) { sk=s; sk.lower_level(); } - FHE_SK(const FHE_Params& pms, const bigint& p = 0) + FHE_SK(const FHE_Params& pms, const bigint& p) : sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } FHE_SK(const FHE_PK& pk); @@ -110,6 +110,17 @@ class FHE_PK Sw_b(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } + FHE_PK(const FHE_Params& pms, int p) : + FHE_PK(pms, bigint(p)) + { + } + + template + FHE_PK(const FHE_Params& pms, const FD& FTD) : + FHE_PK(pms, FTD.get_prime()) + { + } + // Rely on default copy constructor/assignment const Rq_Element& a() const { return a0; } @@ -148,10 +159,8 @@ class FHE_PK friend istream& operator>>(istream& s, FHE_PK& PK) { s >> PK.a0 >> PK.b0 >> PK.Sw_a >> PK.Sw_b; return s; } - void pack(octetStream& o) const - { a0.pack(o); b0.pack(o); Sw_a.pack(o); Sw_b.pack(o); pr.pack(o); } - void unpack(octetStream& o) - { a0.unpack(o); b0.unpack(o); Sw_a.unpack(o); Sw_b.unpack(o); pr.unpack(o); } + void pack(octetStream& o) const; + void unpack(octetStream& o); bool operator!=(const FHE_PK& x) const; diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 6c5ebb792..0591cb44f 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -29,14 +29,14 @@ class FHE_Params public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), sec_p(-1) {} + FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(64, 0.7), sec_p(-1) {} int n_mults() const { return FFTData.size() - 1; } // Rely on default copy assignment/constructor (not that they should // ever be needed) - void set(const Ring& R,const vector& primes,double r=3.2,int hwt=64); + void set(const Ring& R,const vector& primes,double r=-1,int hwt=-1); void set_sec(int sec); vector sampleGaussian(PRNG& G, int noise_boost = 1) const diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 93225300b..c2ee07dab 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -24,40 +24,17 @@ NTL_CLIENT #include "FHEOffline/DataSetup.h" - -template <> -void generate_setup(int n_parties, int plaintext_length, int sec, - FHE_Params& params, FFT_Data& FTD, int slack, bool round_up) -{ - Ring Rp; - bigint p0p,p1p,p; - SPDZ_Data_Setup_Char_p(Rp, FTD, p0p, p1p, n_parties, plaintext_length, sec, - slack, round_up); - params.set(Rp, {p0p, p1p}); -} - - -template <> -void generate_setup(int n_parties, int plaintext_length, int sec, - FHE_Params& params, P2Data& P2D, int slack, bool round_up) -{ - Ring R; - bigint pr0,pr1; - SPDZ_Data_Setup_Char_2(R, P2D, pr0, pr1, n_parties, plaintext_length, sec, - slack, round_up); - params.set(R, {pr0, pr1}); -} - - void generate_setup(int n, int lgp, int lg2, int sec, bool skip_2, int slack, bool round_up) { DataSetup setup; // do the full setup for SHE data - generate_setup(n, lgp, sec, setup.setup_p.params, setup.setup_p.FieldD, slack, round_up); + Parameters(n, lgp, sec, slack, round_up).generate_setup(setup.setup_p.params, + setup.setup_p.FieldD); if (!skip_2) - generate_setup(n, lg2, sec, setup.setup_2.params, setup.setup_2.FieldD, slack, round_up); + Parameters(n, lg2, sec, slack, round_up).generate_setup( + setup.setup_2.params, setup.setup_2.FieldD); setup.write_setup(skip_2); } @@ -208,9 +185,9 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool roun /* * Subroutine for creating the FHE parameters */ -int SPDZ_Data_Setup_Char_p_Sub(Ring& R, bigint& pr0, bigint& pr1, int n, - int idx, int& m, bigint& p, int sec, int slack = 0, bool round_up = false) +int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p) { + int n = n_parties; int lg2pi[5][2][9] = { { {130,132,132,132,132,132,132,132,132}, {104,104,104,106,106,108,108,110,110} }, @@ -291,13 +268,13 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, /* * Create the char p FHE parameters */ -void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1, - int n, int lgp, int sec, int slack, bool round_up) +template <> +void Parameters::SPDZ_Data_Setup(FFT_Data& FTD) { bigint p; int idx, m; - SPDZ_Data_Setup_Primes(p, lgp, idx, m); - SPDZ_Data_Setup_Char_p_Sub(R, pr0, pr1, n, idx, m, p, sec, slack, round_up); + SPDZ_Data_Setup_Primes(p, plaintext_length, idx, m); + SPDZ_Data_Setup_Char_p_Sub(idx, m, p); Zp_Data Zp(p); gfp::init_field(p); @@ -574,9 +551,12 @@ void char_2_dimension(int& m, int& lg2) } } -void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1, - int n, int lg2, int sec, int slack, bool round_up) +template <> +void Parameters::SPDZ_Data_Setup(P2Data& P2D) { + int n = n_parties; + int lg2 = plaintext_length; + int lg2pi[2][9] = { {70,70,70,70,70,70,70,70,70}, {70,75,75,75,75,80,80,80,80} @@ -760,7 +740,8 @@ void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, { throw bad_value(); } cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl; - SPDZ_Data_Setup_Char_p_Sub(R,pr0,pr1,n,idx,m,p,sec); + Parameters parameters(n, lgp, sec); + parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p); int mx=0; for (int i=0; i + void generate_setup(FHE_Params& params, FD& FTD) + { + SPDZ_Data_Setup(FTD); + params.set(R, {pr0, pr1}); + } + + template + void SPDZ_Data_Setup(FD& FTD); + + int SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p); + +}; + // Main setup routine (need NTL if online_only is false) void generate_setup(int nparties, int lgp, int lg2, int sec, bool skip_2 = false, int slack = 0, bool round_up = false); -template -void generate_setup(int n_parties, int plaintext_length, int sec, - FHE_Params& params, FD& FTD, int slack, bool round_up); - // semi-homomorphic, includes slack template int generate_semi_setup(int plaintext_length, int sec, @@ -35,10 +63,6 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, void init(Ring& Rg,int m); void init(P2Data& P2D,const Ring& Rg); -// For use when we only care about p being of a certain size -void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1, - int n, int lgp, int sec, int slack = 0, bool round_up = false); - // For use when we want p to be a specific value void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, bigint& pr1, int n, int sec, bigint& p); @@ -53,9 +77,6 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, // pre-generated dimensions for characteristic 2 void char_2_dimension(int& m, int& lg2); -void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1, - int n, int lg2, int sec = -1, int slacke = 0, bool round_up = false); - // try to avoid expensive generation by loading from disk if possible void load_or_generate(P2Data& P2D, const Ring& Rg); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index c8c902709..dce96efd9 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -5,6 +5,7 @@ #include #include "FHEOffline/Proof.h" +#include "Protocols/CowGearOptions.h" #include @@ -13,14 +14,44 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, p(p), phi_m(phi_m), n(n), sec(sec), slack(numBits(Proof::slack(slack_param, sec, phi_m))), sigma(sigma), h(h) { + if (sigma <= 0) + this->sigma = sigma = FHE_Params().get_R(); + cerr << "Standard deviation: " << this->sigma << endl; h += extra_h * sec; - B_clean = (phi_m * p / 2 - + p * sigma - * (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m) - + 16 * sqrt(n * h * phi_m))) << slack; - B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3); - drown = 1 + (bigint(1) << sec); - cout << "log(slack): " << slack << endl; + produce_epsilon_constants(); + + if (CowGearOptions::singleton.top_gear()) + { + // according to documentation of SCALE-MAMBA 1.7 + // excluding a factor of n because we don't always add up n ciphertexts + B_clean = (bigint(phi_m) << (sec + 2)) * p + * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * sqrt(h)); + mpf_class V_s; + if (h > 0) + V_s = sqrt(h); + else + V_s = sigma * sqrt(phi_m); + B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); +#ifdef NOISY + cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; + cout << "V_s: " << V_s << endl; + cout << "c1: " << c1 << endl; + cout << "c2: " << c2 << endl; + cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl; + cout << "B_scale: " << B_scale << endl; +#endif + } + else + { + B_clean = (phi_m * p / 2 + + p * sigma + * (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m) + + 16 * sqrt(n * h * phi_m))) << slack; + B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3); + cout << "log(slack): " << slack << endl; + } + + drown = 1 + n * (bigint(1) << sec); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) @@ -34,28 +65,62 @@ bigint SemiHomomorphicNoiseBounds::min_p0() return B_clean * drown * p; } -double SemiHomomorphicNoiseBounds::min_phi_m(int log_q) +double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma) { + if (sigma <= 0) + sigma = FHE_Params().get_R(); // the constant was updated using Martin Albrecht's LWE estimator in Sep 2019 - return 37.8 * (log_q - log2(3.2)); + return 37.8 * (log_q - log2(sigma)); } +void SemiHomomorphicNoiseBounds::produce_epsilon_constants() +{ + double C[3]; + + for (int i = 0; i < 3; i++) + { + C[i] = -1; + } + for (double x = 0.1; x < 10.0; x += .1) + { + double t = erfc(x), tp = 1; + for (int i = 1; i < 3; i++) + { + tp *= t; + double lgtp = log(tp) / log(2.0); + if (C[i] < 0 && lgtp < FHE_epsilon) + { + C[i] = pow(x, i); + } + } + } + + c1 = C[1]; + c2 = C[2]; +} NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, double sigma, int h) : SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, sigma, h) { - B_KS = p * phi_m * sigma - * (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h) - + 2.77 * n * n * sqrt(h) - + pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h)) - + 4.62 * n); + if (CowGearOptions::singleton.top_gear()) + { + B_KS = p * c2 * this->sigma * phi_m / sqrt(12); + } + else + { + B_KS = p * phi_m * mpf_class(this->sigma) + * (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h) + + 2.77 * n * n * sqrt(h) + + pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h)) + + 4.62 * n); + } #ifdef NOISY cout << "p size: " << numBits(p) << endl; cout << "phi(m): " << phi_m << endl; cout << "n: " << n << endl; cout << "sec: " << sec << endl; - cout << "sigma: " << sigma << endl; + cout << "sigma: " << this->sigma << endl; cout << "h: " << h << endl; cout << "B_clean size: " << numBits(B_clean) << endl; cout << "B_scale size: " << numBits(B_scale) << endl; diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index 2afb7af45..700790735 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -13,27 +13,33 @@ int phi_N(int N); class SemiHomomorphicNoiseBounds { protected: + static const int FHE_epsilon = 55; + const bigint p; const int phi_m; const int n; const int sec; int slack; - const double sigma; + mpf_class sigma; const int h; bigint B_clean; bigint B_scale; bigint drown; + mpf_class c1, c2; + + void produce_epsilon_constants(); + public: SemiHomomorphicNoiseBounds(const bigint& p, int phi_m, int n, int sec, - int slack, bool extra_h = false, double sigma = 3.2, int h = 64); + int slack, bool extra_h = false, double sigma = -1, int h = 64); // with scaling bigint min_p0(const bigint& p1); // without scaling bigint min_p0(); bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } - static double min_phi_m(int log_q); + static double min_phi_m(int log_q, double sigma = -1); }; // as per ePrint 2012:642 for slack = 0 @@ -43,7 +49,7 @@ class NoiseBounds : public SemiHomomorphicNoiseBounds public: NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, - double sigma = 3.2, int h = 64); + double sigma = -1, int h = 64); bigint U1(const bigint& p0, const bigint& p1); bigint U2(const bigint& p0, const bigint& p1); bigint min_p0(const bigint& p0, const bigint& p1); diff --git a/FHE/P2Data.h b/FHE/P2Data.h index 53d8c73e0..e07e0e928 100644 --- a/FHE/P2Data.h +++ b/FHE/P2Data.h @@ -20,6 +20,7 @@ class P2Data public: typedef gf2n_short T; typedef int S; + typedef int poly_type; int num_slots() const { return slots; } int degree() const { return A.size() ? A.size() : 0; } diff --git a/FHE/PPData.cpp b/FHE/PPData.cpp index 3d2104646..f63b50662 100644 --- a/FHE/PPData.cpp +++ b/FHE/PPData.cpp @@ -1,4 +1,4 @@ -#include "Math/Subroutines.h" +#include "FHE/Subroutines.h" #include "FHE/PPData.h" #include "FHE/FFT.h" #include "FHE/Matrix.h" diff --git a/FHE/PPData.h b/FHE/PPData.h index 6fe0dcb9d..c31aaccac 100644 --- a/FHE/PPData.h +++ b/FHE/PPData.h @@ -4,6 +4,7 @@ #include "Math/modp.h" #include "Math/Zp_Data.h" #include "Math/gfp.h" +#include "Math/fixint.h" #include "FHE/Ring.h" /* Class for holding modular arithmetic data wrt the ring @@ -16,6 +17,7 @@ class PPData public: typedef gf2n_short T; typedef bigint S; + typedef fixint poly_type; Ring R; Zp_Data prData; diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index ea463358d..901248362 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -3,6 +3,21 @@ #include "FHE/Ring_Element.h" #include "FHE/PPData.h" #include "FHE/P2Data.h" +#include "FHE/Rq_Element.h" +#include "FHE_Keys.h" +#include "Math/Z2k.hpp" + + + +template<> +void Plaintext::from(const Generator& source) const +{ + for (auto& x : b) + { + source.get(bigint::tmp); + x = bigint::tmp; + } +} template<> @@ -11,7 +26,7 @@ void Plaintext::from_poly() const if (type!=Polynomial) { return; } Ring_Element e(*Field_Data,polynomial); - e.from_vec(b); + e.from(b); e.change_rep(evaluation); a.resize(n_slots); for (unsigned int i=0; i::to_poly() const for (unsigned int i=0; i::to_poly() const { bb[i]=a[i].get(); } (*Field_Data).from_eval(bb); for (unsigned int i=0; i::allocate_slots(const bigint& value) { b.resize(degree); for (auto& x : b) - x = value; + x.allocate_slots(value); } template<> @@ -151,52 +169,20 @@ void signed_mod(bigint& x, const bigint& mod, const bigint& half_mod, const bigi x += dest_mod; } -template<> -void Plaintext::set_poly_mod(const Generator& generator,const bigint& mod) +template +void Plaintext::set_poly_mod(const Generator& generator,const bigint& mod) { allocate(Polynomial); bigint half_mod = mod / 2; for (unsigned int i=0; iget_prime()); + generator.get(bigint::tmp); + signed_mod(bigint::tmp, mod, half_mod, Field_Data->get_prime()); + b[i] = bigint::tmp; } } -template<> -void Plaintext::set_poly_mod(const vector& vv,const bigint& mod) -{ - b = vv; - vector& pol = b; - bigint half_mod = mod / 2; - for (unsigned int i=0; i half_mod) - pol[i] -= mod; - pol[i] %= (*Field_Data).get_prime(); - if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); } - } - type = Polynomial; -} - - -template<> -void Plaintext::set_poly_mod(const vector& vv,const bigint& mod) -{ - b = vv; - vector& pol = b; - for (unsigned int i=0; imod/2) { pol[i]=vv[i]-mod; } - else { pol[i]=vv[i]; } - pol[i]=pol[i]%(*Field_Data).get_prime(); - if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); } - } - type = Polynomial; -} - - template<> @@ -229,7 +215,8 @@ void Plaintext::set_poly_mod(const Generator& gen } -void rand_poly(vector& b,PRNG& G,const bigint& pr,bool positive=true) +template +void rand_poly(vector& b,PRNG& G,const bigint& pr,bool positive=true) { for (unsigned int i=0; i::assign_constant(T constant, PT_Type t) template Plaintext& Plaintext::operator+=( - const Plaintext& y) + const Plaintext& y) { if (Field_Data!=y.Field_Data) { throw field_mismatch(); } @@ -687,8 +674,16 @@ void Plaintext::negate() +template +Rq_Element Plaintext::mul_by_X_i(int i, const FHE_PK& pk) const +{ + return Rq_Element(pk.get_params(), *this).mul_by_X_i(i); +} + + + template -bool Plaintext::equals(const Plaintext& x) const +bool Plaintext::equals(const Plaintext& x) const { if (Field_Data!=x.Field_Data) { return false; } if (type!=x.type) @@ -730,7 +725,7 @@ void Plaintext::unpack(octetStream& o) if (size != b.size()) throw length_error("unexpected length received"); for (unsigned int i = 0; i < b.size(); i++) - o.get(b[i]); + b[i] = o.get(); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index cf046ef6c..adb3c625d 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -18,10 +18,14 @@ */ #include "FHE/Generator.h" +#include "Math/fixint.h" #include using namespace std; +class FHE_PK; +class Rq_Element; + // Forward declaration as apparently this is needed for friends in templates template class Plaintext; template ostream& operator<<(ostream& s,const Plaintext& e); @@ -35,9 +39,11 @@ enum condition { Full, Diagonal, Bits }; enum PT_Type { Polynomial, Evaluation, Both }; -template +template class Plaintext { + typedef typename FD::poly_type S; + int n_slots; int degree; @@ -60,7 +66,7 @@ class Plaintext const FD& get_field() const { return *Field_Data; } unsigned int num_slots() const { return n_slots; } - void assign(const Plaintext& p) + void assign(const Plaintext& p) { Field_Data=p.Field_Data; a=p.a; b=p.b; type=p.type; n_slots = p.n_slots; @@ -70,12 +76,7 @@ class Plaintext Plaintext(const FD& FieldD, PT_Type type = Polynomial) { Field_Data=&FieldD; set_sizes(); allocate(type); } - Plaintext(const Plaintext& p) { assign(p); } - ~Plaintext() { ; } - Plaintext& operator=(const Plaintext& p) - { if (this!=&p) { assign(p); } - return *this; - } + Plaintext(const FD& FieldD, const Rq_Element& other); void allocate(PT_Type type) const; void allocate() const { allocate(type); } @@ -117,17 +118,23 @@ class Plaintext void set_poly(const vector& v) { type=Polynomial; b=v; } const vector& get_poly() const - { if (type==Evaluation) { throw rep_mismatch(); } + { + to_poly(); return b; } Iterator get_iterator() const { to_poly(); return b; } + void from(const Generator& source) const; + // This sets a poly from a vector of bigint's which needs centering // modulo mod, before assigning (used in decryption) // vv[i] is already assumed reduced modulo mod though but in // range [0,...,mod) - void set_poly_mod(const vector& vv,const bigint& mod); + void set_poly_mod(const vector& vv,const bigint& mod) + { + set_poly_mod(Iterator(vv), mod); + } void set_poly_mod(const Generator& generator, const bigint& mod); // Converts between Evaluation,Polynomial and Both representations @@ -144,36 +151,38 @@ class Plaintext void assign_one(PT_Type t = Evaluation); void assign_constant(T constant, PT_Type t = Evaluation); - friend void add<>(Plaintext& z,const Plaintext& x,const Plaintext& y); - friend void sub<>(Plaintext& z,const Plaintext& x,const Plaintext& y); - friend void mul<>(Plaintext& z,const Plaintext& x,const Plaintext& y); - friend void sqr<>(Plaintext& z,const Plaintext& x); + friend void add<>(Plaintext& z,const Plaintext& x,const Plaintext& y); + friend void sub<>(Plaintext& z,const Plaintext& x,const Plaintext& y); + friend void mul<>(Plaintext& z,const Plaintext& x,const Plaintext& y); + friend void sqr<>(Plaintext& z,const Plaintext& x); - Plaintext operator+(const Plaintext& x) const - { Plaintext res(*Field_Data); add(res, *this, x); return res; } - Plaintext operator-(const Plaintext& x) const - { Plaintext res(*Field_Data); sub(res, *this, x); return res; } + Plaintext operator+(const Plaintext& x) const + { Plaintext res(*Field_Data); add(res, *this, x); return res; } + Plaintext operator-(const Plaintext& x) const + { Plaintext res(*Field_Data); sub(res, *this, x); return res; } - void mul(const Plaintext& x, const Plaintext& y) + void mul(const Plaintext& x, const Plaintext& y) { x.from_poly(); y.from_poly(); ::mul(*this, x, y); } - Plaintext operator*(const Plaintext& x) - { Plaintext res(*Field_Data); res.mul(*this, x); return res; } + Plaintext operator*(const Plaintext& x) + { Plaintext res(*Field_Data); res.mul(*this, x); return res; } - Plaintext& operator+=(const Plaintext& y); - Plaintext& operator-=(const Plaintext& y) + Plaintext& operator+=(const Plaintext& y); + Plaintext& operator-=(const Plaintext& y) { to_poly(); y.to_poly(); ::sub(*this, *this, y); return *this; } void negate(); - bool equals(const Plaintext& x) const; - bool operator!=(const Plaintext& x) { return !equals(x); } + Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const; + + bool equals(const Plaintext& x) const; + bool operator!=(const Plaintext& x) { return !equals(x); } bool is_diagonal() const { throw not_implemented(); } bool is_binary() const { throw not_implemented(); } - friend ostream& operator<< <>(ostream& s,const Plaintext& e); - friend istream& operator>> <>(istream& s,Plaintext& e); + friend ostream& operator<< <>(ostream& s,const Plaintext& e); + friend istream& operator>> <>(istream& s,Plaintext& e); /* Pack and unpack into an octetStream * For unpack we assume the FFTD has been assigned correctly already diff --git a/FHE/Random_Coins.h b/FHE/Random_Coins.h index 24b9b8d7f..f7d75bbb1 100644 --- a/FHE/Random_Coins.h +++ b/FHE/Random_Coins.h @@ -9,8 +9,10 @@ class FHE_PK; -class Int_Random_Coins : public AddableMatrix +class Int_Random_Coins : public AddableMatrix> { + typedef value_type::value_type T; + const FHE_Params* params; public: Int_Random_Coins(const FHE_Params& params) : params(¶ms) @@ -20,14 +22,16 @@ class Int_Random_Coins : public AddableMatrix void sample(PRNG& G) { - (*this)[0].from(HalfGenerator(G)); + (*this)[0].from(HalfGenerator(G)); for (int i = 1; i < 3; i++) - (*this)[i].from(GaussianGenerator(params->get_DG(), G)); + (*this)[i].from(GaussianGenerator(params->get_DG(), G)); } }; class Random_Coins { + typedef bigint T; + Rq_Element uu,vv,ww; const FHE_Params *params; @@ -56,16 +60,25 @@ class Random_Coins template void assign(const vector& u,const vector& v,const vector& w) - { uu.from_vec(u); vv.from_vec(v); ww.from_vec(w); } + { + uu.from(u); + vv.from(v); + ww.from(w); + } void assign(const Int_Random_Coins& rc) - { uu.from_vec(rc[0]); vv.from_vec(rc[1]); ww.from_vec(rc[2]); } + { + uu.from(rc[0]); + vv.from(rc[1]); + ww.from(rc[2]); + } /* Generate a standard distribution */ void generate(PRNG& G) - { uu.from(HalfGenerator(G)); - vv.from(GaussianGenerator(params->get_DG(), G)); - ww.from(GaussianGenerator(params->get_DG(), G)); + { + uu.from(HalfGenerator(G)); + vv.from(GaussianGenerator(params->get_DG(), G)); + ww.from(GaussianGenerator(params->get_DG(), G)); } // Generate all from Uniform in range (-B,...B) @@ -74,9 +87,9 @@ class Random_Coins if (B1 == 0) uu.assign_zero(); else - uu.from(UniformGenerator(G,numBits(B1))); - vv.from(UniformGenerator(G,numBits(B2))); - ww.from(UniformGenerator(G,numBits(B3))); + uu.from(UniformGenerator(G,numBits(B1))); + vv.from(UniformGenerator(G,numBits(B2))); + ww.from(UniformGenerator(G,numBits(B3))); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index a73bcbe56..a3c0e1d18 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -147,6 +147,47 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b) } +Ring_Element Ring_Element::mul_by_X_i(int j) const +{ + Ring_Element ans; + auto& a = *this; + ans.partial_assign(a); + if (ans.rep == evaluation) + { + modp xj, xj2; + Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD()); + Sqr(xj2, xj, (*a.FFTD).get_prD()); + for (int i= 0; i < (*ans.FFTD).phi_m(); i++) + { + Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD()); + Mul(xj, xj, xj2, (*a.FFTD).get_prD()); + } + } + else + { + Ring_Element aa(*ans.FFTD, ans.rep); + for (int i= 0; i < (*ans.FFTD).phi_m(); i++) + { + int k= j + i, s= 1; + while (k >= (*ans.FFTD).phi_m()) + { + k-= (*ans.FFTD).phi_m(); + s= -s; + } + if (s == 1) + { + aa.element[k]= a.element[i]; + } + else + { + Negate(aa.element[k], a.element[i], (*a.FFTD).get_prD()); + } + } + ans= aa; + } + return ans; +} + void Ring_Element::randomize(PRNG& G,bool Diag) { @@ -318,20 +359,6 @@ void Ring_Element::from_vec(const vector& v) // cout << "RE:from_vec:: " << *this << endl; } -template -void Ring_Element::from(const Generator& generator) -{ - RepType t=rep; - rep=polynomial; - T tmp; - for (int i=0; i<(*FFTD).phi_m(); i++) - { - generator.get(tmp); - element[i].convert_destroy(tmp, (*FFTD).get_prD()); - } - change_rep(t); -} - ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) @@ -389,6 +416,7 @@ modp Ring_Element::get_constant() const void store(octetStream& o,const vector& v,const Zp_Data& ZpD) { + ZpD.pack(o); o.store((int)v.size()); for (unsigned int i=0; i& v,const Zp_Data& ZpD) void get(octetStream& o,vector& v,const Zp_Data& ZpD) { + Zp_Data check_Zpd; + check_Zpd.unpack(o); + if (check_Zpd != ZpD) + throw runtime_error( + "mismatch: " + to_string(check_Zpd.pr_bit_length) + "/" + + to_string(ZpD.pr_bit_length)); unsigned int length; o.get(length); v.resize(length); @@ -408,7 +442,7 @@ void get(octetStream& o,vector& v,const Zp_Data& ZpD) void Ring_Element::pack(octetStream& o) const { check_size(); - o.store(rep); + o.store(unsigned(rep)); store(o,element,(*FFTD).get_prD()); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index cbebc1e70..647270563 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -90,6 +90,8 @@ class Ring_Element friend void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b); friend void mul(Ring_Element& ans,const Ring_Element& a,const modp& b); + Ring_Element mul_by_X_i(int i) const; + void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; @@ -116,6 +118,12 @@ class Ring_Element template void from(const Generator& generator); + template + void from(const vector& source) + { + from(Iterator(source)); + } + // This gets the constant term of the poly rep as a modp element modp get_constant() const; modp get_element(int i) const { return element[i]; } @@ -167,5 +175,20 @@ class RingReadIterator : public ConversionIterator inline void mul(Ring_Element& ans,const modp& a,const Ring_Element& b) { mul(ans,b,a); } + +template +void Ring_Element::from(const Generator& generator) +{ + RepType t=rep; + rep=polynomial; + T tmp; + for (int i=0; i<(*FFTD).phi_m(); i++) + { + generator.get(tmp); + element[i].convert_destroy(tmp, (*FFTD).get_prD()); + } + change_rep(t); +} + #endif diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index c6d77e89d..a1df2634c 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -1,6 +1,12 @@ #include "Rq_Element.h" +#include "FHE_Keys.h" #include "Exceptions/Exceptions.h" +Rq_Element::Rq_Element(const FHE_PK& pk) : + Rq_Element(pk.get_params().FFTD()) +{ +} + Rq_Element::Rq_Element(const vector& prd, RepType r0, RepType r1) { if (prd.size() > 0) @@ -57,6 +63,19 @@ void Rq_Element::negate() a[i].negate(); } +Rq_Element Rq_Element::mul_by_X_i(int i) const +{ + Rq_Element res; + res.lev = lev; + res.a.clear(); + for (auto& x : a) + { + auto tmp = x.mul_by_X_i(i); + res.a.push_back(tmp); + } + return res; +} + void add(Rq_Element& ans,const Rq_Element& ra,const Rq_Element& rb) { ans.partial_assign(ra, rb); @@ -173,19 +192,6 @@ void Rq_Element::from_vec(const vector& v,int level) a[i].from_vec(v); } -template -void Rq_Element::from(const Generator& generator, int level) -{ - set_level(level); - if (lev == 1) - { - auto clone = generator.clone(); - a[1].from(*clone); - delete clone; - } - a[0].from(generator); -} - vector Rq_Element::to_vec_bigint() const { vector v; @@ -220,7 +226,7 @@ void Rq_Element::to_vec_bigint(vector& v) const } } -ConversionIterator Rq_Element::get_iterator() +ConversionIterator Rq_Element::get_iterator() const { if (lev != 0) throw not_implemented(); @@ -339,7 +345,8 @@ void Rq_Element::raise_level() void Rq_Element::check_level() const { if ((unsigned)lev > (unsigned)n_mults()) - throw range_error("level out of range"); + throw range_error( + "level out of range: " + to_string(lev) + "/" + to_string(n_mults())); } void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y) diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index 7479adb00..c58690131 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -17,6 +17,7 @@ #include "FHE/FHE_Params.h" #include "FHE/tools.h" #include "FHE/Generator.h" +#include "Plaintext.h" #include // Forward declare the friend functions @@ -62,9 +63,18 @@ class Rq_Element Rq_Element(const FHE_Params& params) : Rq_Element(params.FFTD()) {} + Rq_Element(const FHE_PK& pk); + Rq_Element(const Ring_Element& b0,const Ring_Element& b1) : a({b0, b1}), lev(n_mults()) {} + template + Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : + Rq_Element(params) + { + from(plaintext.get_iterator()); + } + // Destructor ~Rq_Element() { ; } @@ -91,6 +101,8 @@ class Rq_Element // Multiply something by p1 and make level 1 void mul_by_p1(); + Rq_Element mul_by_X_i(int i) const; + void randomize(PRNG& G,int lev=-1); // Scale from level 1 to level 0, if at level 0 do nothing @@ -113,10 +125,16 @@ class Rq_Element vector to_vec_bigint() const; void to_vec_bigint(vector& v) const; - ConversionIterator get_iterator(); + ConversionIterator get_iterator() const; template void from(const Generator& generator, int level=-1); + template + void from(const vector& source, int level=-1) + { + from(Iterator(source), level); + } + bigint infinity_norm() const; bigint get_prime(int i) const @@ -156,9 +174,22 @@ template Rq_Element& Rq_Element::operator+=(const vector& other) { Rq_Element tmp = *this; - tmp.from_vec(other, lev); + tmp.from(Iterator(other), lev); add(*this, *this, tmp); return *this; } +template +void Rq_Element::from(const Generator& generator, int level) +{ + set_level(level); + if (lev == 1) + { + auto clone = generator.clone(); + a[1].from(*clone); + delete clone; + } + a[0].from(generator); +} + #endif diff --git a/Math/Subroutines.cpp b/FHE/Subroutines.cpp similarity index 98% rename from Math/Subroutines.cpp rename to FHE/Subroutines.cpp index 40b08cadb..8ec6a7cee 100644 --- a/Math/Subroutines.cpp +++ b/FHE/Subroutines.cpp @@ -1,8 +1,8 @@ #include "Subroutines.h" -#include "modp.h" +#include "Math/modp.h" -#include "modp.hpp" +#include "Math/modp.hpp" void Subs(modp& ans,const vector& poly,const modp& x,const Zp_Data& ZpD) { diff --git a/Math/Subroutines.h b/FHE/Subroutines.h similarity index 100% rename from Math/Subroutines.h rename to FHE/Subroutines.h diff --git a/FHEOffline/CutAndChooseMachine.cpp b/FHEOffline/CutAndChooseMachine.cpp index f68645b36..c6c431e20 100644 --- a/FHEOffline/CutAndChooseMachine.cpp +++ b/FHEOffline/CutAndChooseMachine.cpp @@ -18,6 +18,11 @@ CutAndChooseMachine::CutAndChooseMachine(int argc, const char** argv) "--covert" // Flag token. ); parse_options(argc, argv); + if (produce_inputs) + { + cerr << "Producing input tuples is not implemented" << endl; + exit(1); + } covert = opt.isSet("--covert"); if (not covert and sec != 40) throw runtime_error("active cut-and-choose only implemented for 40-bit security"); diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 11a628d5f..dd2715bf1 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -8,6 +8,10 @@ #include "Protocols/fake-stuff.h" #include "FHE/NTL-Subs.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" +#include "PairwiseSetup.h" +#include "Proof.h" +#include "SimpleMachine.h" #include using namespace std; @@ -58,8 +62,8 @@ void PartSetup::generate_setup(int n_parties, int plaintext_length, int sec, int slack, bool round_up) { sec = max(sec, 40); - ::generate_setup(n_parties, plaintext_length, sec, params, FieldD, - slack, round_up); + Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup( + params, FieldD); params.set_sec(sec); pk = FHE_PK(params, FieldD.get_prime()); sk = FHE_SK(params, FieldD.get_prime()); @@ -254,6 +258,7 @@ void DataSetup::output(int my_number, int nn, bool specific_dir) template void PartSetup::pack(octetStream& os) { + os.append((octet*)"PARTSETU", 8); params.pack(os); FieldD.pack(os); pk.pack(os); @@ -265,8 +270,15 @@ void PartSetup::pack(octetStream& os) template void PartSetup::unpack(octetStream& os) { + char tag[8]; + os.consume((octet*) tag, 8); + if (memcmp(tag, "PARTSETU", 8)) + throw runtime_error("invalid serialization of setup"); params.unpack(os); FieldD.unpack(os); + pk = {params, FieldD}; + sk = pk; + calpha = params; pk.unpack(os); sk.unpack(os); calpha.unpack(os); @@ -305,5 +317,94 @@ bool PartSetup::operator!=(const PartSetup& other) return false; } +template +void PartSetup::secure_init(Player& P, MachineBase& machine, + int plaintext_length, int sec) +{ + ::secure_init(*this, P, machine, plaintext_length, sec); +} + +template +void PartSetup::generate(Player& P, MachineBase&, int plaintext_length, + int sec) +{ + generate_setup(P.num_players(), plaintext_length, sec, + INTERACTIVE_SPDZ1_SLACK, false); +} + +template +void PartSetup::check(Player& P, MachineBase& machine) +{ + Bundle bundle(P); + bundle.mine.store(machine.extra_slack); + auto& os = bundle.mine; + params.pack(os); + FieldD.hash(os); + pk.pack(os); + calpha.pack(os); + bundle.compare(P); +} + +template +void PartSetup::covert_key_generation(Player& P, + MultiplicativeMachine& machine, int num_runs) +{ + auto& setup = machine.setup.part(); + Run_Gen_Protocol(setup.pk, setup.sk, P, num_runs, false); +} + +template +void PartSetup::covert_mac_generation(Player& P, + MultiplicativeMachine& machine, int num_runs) +{ + auto& setup = machine.setup.part(); + generate_mac_key(setup.alphai, setup.calpha, setup.FieldD, setup.pk, P, + num_runs); +} + +template +void PartSetup::covert_secrets_generation(Player& P, + MultiplicativeMachine& machine, int num_runs) +{ + octetStream os; + params.pack(os); + FieldD.pack(os); + string filename = PREP_DIR "ChaiGear-Secrets-" + to_string(num_runs) + "-" + + os.check_sum(20).get_str(16) + "-P" + to_string(P.my_num()); + + string error; + + try + { + ifstream input(filename); + os.input(input); + unpack(os); + } + catch (exception& e) + { + error = e.what(); + } + + try + { + check(P, machine); + } + catch (mismatch_among_parties& e) + { + error = e.what(); + } + + if (not error.empty()) + { + cerr << "Running secrets generation because " << error << endl; + covert_key_generation(P, machine, num_runs); + covert_mac_generation(P, machine, num_runs); + ofstream output(filename); + octetStream os; + pack(os); + os.output(output); + } +} + template class PartSetup; template class PartSetup; diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h index 49683fb48..561d11508 100644 --- a/FHEOffline/DataSetup.h +++ b/FHEOffline/DataSetup.h @@ -12,6 +12,8 @@ #include "Math/Setup.h" class DataSetup; +class MachineBase; +class MultiplicativeMachine; template class PartSetup @@ -26,6 +28,11 @@ class PartSetup Ciphertext calpha; typename FD::T alphai; + static string name() + { + return "GlobalParams-" + T::type_string(); + } + PartSetup(); void generate_setup(int n_parties, int plaintext_length, int sec, int slack, bool round_up); @@ -42,6 +49,19 @@ class PartSetup void check(int sec) const; bool operator!=(const PartSetup& other); + + void secure_init(Player& P, MachineBase& machine, int plaintext_length, + int sec); + void generate(Player& P, MachineBase& machine, int plaintext_length, + int sec); + void check(Player& P, MachineBase& machine); + + void covert_key_generation(Player& P, MultiplicativeMachine& machine, + int num_runs); + void covert_mac_generation(Player& P, MultiplicativeMachine& machine, + int num_runs); + void covert_secrets_generation(Player& P, MultiplicativeMachine& machine, + int num_runs); }; class DataSetup diff --git a/FHEOffline/DistDecrypt.cpp b/FHEOffline/DistDecrypt.cpp index c3a585470..39b6001ee 100644 --- a/FHEOffline/DistDecrypt.cpp +++ b/FHEOffline/DistDecrypt.cpp @@ -13,7 +13,7 @@ DistDecrypt::DistDecrypt(const Player& P, const FHE_SK& share, bigint limit = pk.get_params().Q() << 64; vv.allocate_slots(limit); vv1.allocate_slots(limit); - mf.allocate_slots(pk.get_params().p0() << 64); + mf.allocate_slots(pk.p() << 64); } template diff --git a/FHEOffline/EncCommit.cpp b/FHEOffline/EncCommit.cpp index 28300ebe4..693c0fb82 100644 --- a/FHEOffline/EncCommit.cpp +++ b/FHEOffline/EncCommit.cpp @@ -262,7 +262,7 @@ void EncCommit::Create_More() const { if (cond!=Full) { throw not_implemented(); } else - { m_Delta.from(UniformGenerator(Gseed[i],numBits(Bound1))); } + { m_Delta.from(UniformGenerator(Gseed[i],numBits(Bound1))); } rc_Delta.generateUniform(Gseed[i],Bound2,Bound3,Bound3); Ciphertext Delta(params); (*pk).quasi_encrypt(Delta,m_Delta,rc_Delta); @@ -319,7 +319,7 @@ void EncCommit::Create_More() const if (cond!=Full) { throw not_implemented(); } else - { mm.from(UniformGenerator(G,numBits(Bound1))); } + { mm.from(UniformGenerator(G, numBits(Bound1))); } rr.generateUniform(G,Bound2,Bound3,Bound3); (*pk).quasi_encrypt(cc,mm,rr); occ.reset_write_head(); @@ -357,10 +357,10 @@ void EncCommit::Create_More() const { throw not_implemented(); } else - { m_Delta.from(UniformGenerator(G,numBits(Bound1))); } + { m_Delta.from(UniformGenerator(G, numBits(Bound1))); } rc_Delta.generateUniform(G,Bound2,Bound3,Bound3); - Iterator vm=m[i].get_iterator(); + auto vm=m[i].get_iterator(); z[0].from(vm); add(z[0],z[0],m_Delta); add(rr,rc[i],rc_Delta); diff --git a/FHEOffline/EncCommit.h b/FHEOffline/EncCommit.h index abddb0bdb..211a886b6 100644 --- a/FHEOffline/EncCommit.h +++ b/FHEOffline/EncCommit.h @@ -12,13 +12,17 @@ #include "FHE/Plaintext.h" #include "Tools/MemoryUsage.h" +class MachineBase; + template class EncCommitBase { public: size_t volatile_memory; - EncCommitBase() : volatile_memory(0) {} + const MachineBase* machine; + + EncCommitBase(const MachineBase* machine = 0) : volatile_memory(0), machine(machine) {} virtual ~EncCommitBase() {} virtual condition get_condition() { return Full; } virtual void next(Plaintext& mess, Ciphertext& c) diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 460fee18e..3f6ecbf51 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -27,7 +27,7 @@ Multiplier::Multiplier(int offset, PairwiseMachine& machine, Player& P, product_share(machine.setup().FieldD), rc(machine.pk), volatile_capacity(0) { - product_share.allocate_slots(machine.setup().params.p0() << 64); + product_share.allocate_slots(machine.pk.p() << 64); } template @@ -35,7 +35,7 @@ void Multiplier::multiply_and_add(Plaintext_& res, const Ciphertext& enc_a, const Plaintext_& b) { Rq_Element bb(enc_a.get_params(), evaluation, evaluation); - bb.from_vec(b.get_poly()); + bb.from(b.get_iterator()); multiply_and_add(res, enc_a, bb); } diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index 8f19c3a2d..ac46d3ae0 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -7,6 +7,8 @@ #include "FHEOffline/PairwiseMachine.h" #include "FHEOffline/Producer.h" #include "Protocols/SemiShare.h" +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiInput.hpp" @@ -21,20 +23,21 @@ PairwiseGenerator::PairwiseGenerator(int thread_num, thread_num, machine.output), EC(P, machine.other_pks, machine.setup().FieldD, timers, machine, *this), MC(machine.setup().alphai), - C(machine.sec, machine.setup().params), volatile_memory(0), + n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)), + C(n_ciphertexts, machine.setup().params), volatile_memory(0), machine(machine) { for (int i = 1; i < P.num_players(); i++) multipliers.push_back(new Multiplier(i, *this)); const FD& FieldD = machine.setup().FieldD; - a.resize(machine.sec, FieldD); - b.resize(machine.sec, FieldD); - c.resize(machine.sec, FieldD); + a.resize(n_ciphertexts, FieldD); + b.resize(n_ciphertexts, FieldD); + c.resize(n_ciphertexts, FieldD); a.allocate_slots(FieldD.get_prime()); b.allocate_slots(FieldD.get_prime()); // extra limb for addition c.allocate_slots((bigint)FieldD.get_prime() << 64); - b_mod_q.resize(machine.sec, + b_mod_q.resize(n_ciphertexts, { machine.setup().params, evaluation, evaluation }); } @@ -64,8 +67,8 @@ void PairwiseGenerator::run() c.mul(a, b); timers["Plaintext multiplication"].stop(); timers["FFT of b"].start(); - for (int i = 0; i < machine.sec; i++) - b_mod_q.at(i).from_vec(b.at(i).get_poly()); + for (int i = 0; i < n_ciphertexts; i++) + b_mod_q.at(i).from(b.at(i).get_iterator()); timers["FFT of b"].stop(); timers["Proof exchange"].start(); size_t verifier_memory = EC.create_more(ciphertexts, cleartexts); @@ -73,7 +76,7 @@ void PairwiseGenerator::run() volatile_memory = max(prover_memory, verifier_memory); Rq_Element values({machine.setup().params, evaluation, evaluation}); - for (int k = 0; k < machine.sec; k++) + for (int k = 0; k < n_ciphertexts; k++) { producer.ai = a[k]; producer.bi = b[k]; @@ -90,7 +93,7 @@ void PairwiseGenerator::run() else { timers["Plaintext conversion"].start(); - values.from_vec(producer.values[j].get_poly()); + values.from(producer.values[j].get_iterator()); timers["Plaintext conversion"].stop(); } @@ -122,7 +125,7 @@ void PairwiseGenerator::generate_inputs(int player) { SeededPRNG G; b[0].randomize(G); - b_mod_q.at(0).from_vec(b.at(0).get_poly()); + b_mod_q.at(0).from(b.at(0).get_iterator()); producer.macs[0].mul(machine.setup().alpha, b[0]); } else diff --git a/FHEOffline/PairwiseGenerator.h b/FHEOffline/PairwiseGenerator.h index 51e4fa923..ef587b48b 100644 --- a/FHEOffline/PairwiseGenerator.h +++ b/FHEOffline/PairwiseGenerator.h @@ -29,6 +29,8 @@ class PairwiseGenerator : public GeneratorBase MultiEncCommit EC; MAC_Check MC; + int n_ciphertexts; + // temporary data AddableVector C; octetStream ciphertexts, cleartexts; diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index e9e9f3049..3da7d8bcf 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -48,12 +48,21 @@ void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, template void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) +{ + ::secure_init(*this, P, machine, plaintext_length, sec); + alpha = FieldD; +} + +template +void secure_init(T& setup, Player& P, MachineBase& machine, + int plaintext_length, int sec) { machine.sec = sec; sec = max(sec, 40); machine.drown_sec = sec; - string filename = PREP_DIR "Params-" + FD::T::type_string() + "-" - + to_string(plaintext_length) + "-" + to_string(sec) + "-P" + string filename = PREP_DIR + T::name() + "-" + + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()); try { @@ -61,38 +70,54 @@ void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int pla octetStream os; os.input(file); os.get(machine.extra_slack); - params.unpack(os); - FieldD.unpack(os); - FieldD.init_field(); - check(P, machine); + setup.unpack(os); + setup.check(P, machine); } catch (...) { cout << "Finding parameters for security " << sec << " and field size ~2^" << plaintext_length << endl; - machine.extra_slack = generate_semi_setup(plaintext_length, sec, params, FieldD, true); - check(P, machine); + setup.generate(P, machine, plaintext_length, sec); + setup.check(P, machine); octetStream os; os.store(machine.extra_slack); - params.pack(os); - FieldD.pack(os); + setup.pack(os); ofstream file(filename); os.output(file); } - alpha = FieldD; } template -void PairwiseSetup::check(Player& P, PairwiseMachine& machine) +void PairwiseSetup::generate(Player&, MachineBase& machine, + int plaintext_length, int sec) +{ + machine.extra_slack = generate_semi_setup(plaintext_length, sec, params, + FieldD, true); +} + +template +void PairwiseSetup::pack(octetStream& os) const +{ + params.pack(os); + FieldD.pack(os); +} + +template +void PairwiseSetup::unpack(octetStream& os) +{ + params.unpack(os); + FieldD.unpack(os); + FieldD.init_field(); +} + +template +void PairwiseSetup::check(Player& P, MachineBase& machine) { Bundle bundle(P); bundle.mine.store(machine.extra_slack); params.pack(bundle.mine); FieldD.hash(bundle.mine); - P.Broadcast_Receive(bundle, true); - for (auto& os : bundle) - if (os != bundle.mine) - throw runtime_error("mismatch of parameters among parties"); + bundle.compare(P); } template @@ -161,3 +186,6 @@ void PairwiseSetup::set_alphai(T alphai) template class PairwiseSetup; template class PairwiseSetup; + +template void secure_init(PartSetup&, Player&, MachineBase&, int, int); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int); diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index 3e5fa0e59..d8ec74ac3 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -11,6 +11,11 @@ #include "Networking/Player.h" class PairwiseMachine; +class MachineBase; + +template +void secure_init(T& setup, Player& P, MachineBase& machine, + int plaintext_length, int sec); template class PairwiseSetup @@ -24,15 +29,24 @@ class PairwiseSetup Plaintext_ alpha; string dirname; + static string name() + { + return "PairwiseParams-" + FD::T::type_string(); + } + PairwiseSetup() : params(0), alpha(FieldD) {} void init(const Player& P, int sec, int plaintext_length, int& extra_slack); void secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec); - void check(Player& P, PairwiseMachine& machine); + void generate(Player& P, MachineBase& machine, int plaintext_length, int sec); + void check(Player& P, MachineBase& machine); void covert_key_generation(Player& P, PairwiseMachine& machine, int num_runs); void covert_mac_generation(Player& P, PairwiseMachine& machine, int num_runs); + void pack(octetStream& os) const; + void unpack(octetStream& os); + void set_alphai(T alphai); }; diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index ed44122be..e31bd5609 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -9,6 +9,8 @@ #include "Sacrificing.h" #include "Reshare.h" #include "DistDecrypt.h" +#include "SimpleEncCommit.h" +#include "SimpleMachine.h" #include "Tools/mkpath.h" template @@ -610,7 +612,7 @@ InputProducer::~InputProducer() template void InputProducer::run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha, EncCommitBase_& EC, DistDecrypt& dd, - const T& alphai) + const T& alphai, int player) { (void)alphai; @@ -620,43 +622,86 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, G.ReSeed(); Ciphertext gama(params),dummyc(params),ca(params); - vector oca(P.num_players()); const FD& FieldD = dd.f.get_field(); Plaintext a(FieldD),ai(FieldD),gai(FieldD); - Random_Coins rc(params); this->n_slots = FieldD.num_slots(); Share Sh; - a.randomize(G); - rc.generate(G); - pk.encrypt(ca, a, rc); - ca.pack(oca[P.my_num()]); - P.Broadcast_Receive(oca); + inputs.resize(P.num_players()); - for (int j = 0; j < P.num_players(); j++) + int min, max; + if (player < 0) + { + min = 0; + max = P.num_players(); + } + else { - ca.unpack(oca[j]); - // Reshare the aj values - dd.reshare(ai, ca, EC); + min = player; + max = player + 1; + } - // Generate encrypted MAC values - mul(gama, calpha, ca, pk); + map timers; + assert(EC.machine); + SimpleEncCommit_ personal_EC(P, pk, FieldD, timers, *EC.machine, 0); + octetStream ciphertexts, cleartexts; - // Get shares on the MACs - dd.reshare(gai, gama, EC); + for (int j = min; j < max; j++) + { + AddableVector C; + vector> m(EC.machine->sec, FieldD); + if (j == P.my_num()) + { + for (auto& x : m) + x.randomize(G); + personal_EC.generate_proof(C, m, ciphertexts, cleartexts); + P.send_all(ciphertexts, true); + P.send_all(cleartexts, true); + } + else + { + P.receive_player(j, ciphertexts, true); + P.receive_player(j, cleartexts, true); + C.resize(personal_EC.machine->sec, pk.get_params()); + Verifier>(personal_EC.proof).NIZKPoK(C, ciphertexts, + cleartexts, pk, false, false); + } + + inputs[j].clear(); - for (unsigned int i = 0; i < ai.num_slots(); i++) + for (size_t i = 0; i < C.size(); i++) { - Sh.set_share(ai.element(i)); - Sh.set_mac(gai.element(i)); - if (write_output) + auto& ca = C.at(i); + auto& a = m.at(i); + + // Reshare the aj values + dd.reshare(ai, ca, EC); + + // Generate encrypted MAC values + mul(gama, calpha, ca, pk); + + // Get shares on the MACs + dd.reshare(gai, gama, EC); + + for (unsigned int i = 0; i < ai.num_slots(); i++) { - Sh.output(outf[j], false); - if (j == P.my_num()) + Sh.set_share(ai.element(i)); + Sh.set_mac(gai.element(i)); + if (write_output) + { + Sh.output(outf[j], false); + if (j == P.my_num()) + { + a.element(i).output(outf[j], false); + } + } + else { - a.element(i).output(outf[j], false); + inputs[j].push_back({Sh, {}}); + if (j == P.my_num()) + inputs[j].back().value = a.element(i); } } } diff --git a/FHEOffline/Producer.h b/FHEOffline/Producer.h index 20cf3f27e..95c2e631f 100644 --- a/FHEOffline/Producer.h +++ b/FHEOffline/Producer.h @@ -221,6 +221,8 @@ class InputProducer : public Producer bool write_output; public: + vector>>> inputs; + InputProducer(const Player& P, int output_thread = 0, bool write_output = true, string dir = PREP_DIR); ~InputProducer(); @@ -228,7 +230,15 @@ class InputProducer : public Producer string data_type() { return "Inputs"; } void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha, - EncCommitBase_& EC, DistDecrypt& dd, const T& alphai); + EncCommitBase_& EC, DistDecrypt& dd, const T& alphai) + { + run(P, pk, calpha, EC, dd, alphai, -1); + } + + void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha, + EncCommitBase_& EC, DistDecrypt& dd, const T& alphai, + int player); + int sacrifice(const Player& P, MAC_Check& MC); // no ops diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp index 2836283b8..33061b1fd 100644 --- a/FHEOffline/Proof.cpp +++ b/FHEOffline/Proof.cpp @@ -6,6 +6,7 @@ #include "Proof.h" #include "FHE/P2Data.h" #include "FHEOffline/EncCommit.h" +#include "Math/Z2k.hpp" double Proof::dist = 0; @@ -32,22 +33,50 @@ bigint Proof::slack(int slack, int sec, int phim) } } -void Proof::get_challenge(vector& e, const octetStream& ciphertexts) const +void Proof::set_challenge(const octetStream& ciphertexts) +{ + octetStream hash = ciphertexts.hash(); + PRNG G; + assert(hash.get_length() >= SEED_SIZE); + G.SetSeed(hash.get_data()); + set_challenge(G); +} + +void Proof::set_challenge(PRNG& G) { unsigned int i; - bigint hashout = ciphertexts.check_sum(); - for (i=0; i>(i))&1; } + if (top_gear) + { + W.resize(V, vector(U)); + for (i = 0; i < V; i++) + for (unsigned j = 0; j < U; j++) + W[i][j] = G.get_uint(2 * phim) - 1; + } + else + { + e.resize(sec); + for (i = 0; i < sec; i++) + { + e[i] = G.get_bit(); + } + } +} + +void Proof::generate_challenge(const Player& P) +{ + GlobalPRNG G(P); + set_challenge(G); } +template class AbsoluteBoundChecker { - bigint bound, neg_bound; + T bound, neg_bound; public: - AbsoluteBoundChecker(bigint bound) : bound(bound), neg_bound(-bound) {} - bool outside(const bigint& value, double& dist) + AbsoluteBoundChecker(T bound) : bound(bound), neg_bound(-this->bound) {} + bool outside(const T& value, double& dist) { (void)dist; #ifdef PRINT_MIN_DIST @@ -57,17 +86,17 @@ class AbsoluteBoundChecker } }; -template -bool Proof::check_bounds(T& z, AddableMatrix& t, int i) const +template +bool Proof::check_bounds(T& z, X& t, int i) const { unsigned int j,k; // Check Bound 1 and Bound 2 - AbsoluteBoundChecker plain_checker(plain_check * n_proofs); - AbsoluteBoundChecker rand_checker(rand_check * n_proofs); + AbsoluteBoundChecker> plain_checker(plain_check * n_proofs); + AbsoluteBoundChecker> rand_checker(rand_check * n_proofs); for (j=0; j& t, int i) const } for (k=0; k<3; k++) { - const vector& coeffs = t[k]; + auto& coeffs = t[k]; for (j=0; j& z, AddableMatrix& t, int i) const; -template bool Proof::check_bounds(AddableVector& z, AddableMatrix& t, int i) const; - -template bool Proof::check_bounds(Plaintext_& z, AddableMatrix& t, int i) const; +template bool Proof::check_bounds(AddableVector>& z, AddableMatrix>& t, int i) const; +template bool Proof::check_bounds(AddableVector>& z, AddableMatrix>& t, int i) const; +template bool Proof::check_bounds(AddableVector>& z, AddableMatrix>& t, int i) const; diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index 5ef558606..d94710d11 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -8,6 +8,7 @@ using namespace std; #include "Math/bigint.h" #include "FHE/Ciphertext.h" #include "FHE/AddableVector.h" +#include "Protocols/CowGearOptions.h" #include "config.h" @@ -21,6 +22,8 @@ enum SlackType class Proof { + unsigned int sec; + Proof(); // Private to avoid default public: @@ -29,12 +32,14 @@ class Proof class Preimages { - bigint m_tmp; - AddableVector r_tmp; + typedef Int_Random_Coins::value_type::value_type r_type; + + fixint m_tmp; + AddableVector r_tmp; public: Preimages(int size, const FHE_PK& pk, const bigint& p, int n_players); - AddableMatrix m; + AddableMatrix> m; Randomness r; void add(octetStream& os); void pack(octetStream& os); @@ -43,18 +48,22 @@ class Proof size_t report_size(ReportType type) { return m.report_size(type) + r.report_size(type); } }; - unsigned int sec; bigint tau,rho; unsigned int phim; int B_plain_length, B_rand_length; bigint plain_check, rand_check; unsigned int V; + unsigned int U; const FHE_PK* pk; int n_proofs; + vector e; + vector> W; + bool top_gear; + static double dist; protected: @@ -65,19 +74,67 @@ class Proof tau=Tau; rho=Rho; phim=(pk.get_params()).phi_m(); - V=2*sec-1; + + top_gear = use_top_gear(pk); + if (top_gear) + { + V = ceil((sec + 2) / log2(2 * phim + 1)); + U = 2 * V; +#ifdef VERBOSE + cerr << "Using " << U << " ciphertexts per proof" << endl; +#endif + } + else + { + U = sec; + V = 2 * sec - 1; + } } Proof(int sec, const FHE_PK& pk, int n_proofs = 1) : - Proof(sec, pk.p() / 2, 2 * 3.2 * sqrt(pk.get_params().phi_m()), pk, + Proof(sec, pk.p() / 2, + pk.get_params().get_DG().get_NewHopeB(), pk, n_proofs) {} + virtual ~Proof() {} + public: static bigint slack(int slack, int sec, int phim); - void get_challenge(vector& e, const octetStream& ciphertexts) const; - template - bool check_bounds(T& z, AddableMatrix& t, int i) const; + static bool use_top_gear(const FHE_PK& pk) + { + return CowGearOptions::singleton.top_gear() and pk.p() > 2; + } + + static int n_ciphertext_per_proof(int sec, const FHE_PK& pk) + { + return Proof(sec, pk, 1).U; + } + + void set_challenge(const octetStream& ciphertexts); + void set_challenge(PRNG& G); + void generate_challenge(const Player& P); + + template + bool check_bounds(T& z, X& t, int i) const; + + template + void apply_challenge(int i, T& output, const U& input, const FHE_PK& pk) const + { + if (top_gear) + { + for (unsigned j = 0; j < this->U; j++) + if (W[i][j] >= 0) + output += input[j].mul_by_X_i(W[i][j], pk); + } + else + for (unsigned k = 0; k < sec; k++) + { + unsigned j = (i + 1) - (k + 1); + if (j < sec && e.at(j)) + output += input.at(j); + } + } }; class NonInteractiveProof : public Proof @@ -111,8 +168,7 @@ class InteractiveProof : public Proof Proof(sec, pk, n_proofs) { bigint B; - // using mu = 1 - B = bigint(1) << (sec - 1); + B = bigint(1) << sec; B_plain_length = numBits(B * tau); B_rand_length = numBits(B * rho); // leeway for completeness diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index 6425896a1..6bf7977d3 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -3,6 +3,7 @@ #include "FHE/P2Data.h" #include "Tools/random.h" +#include "Math/Z2k.hpp" template @@ -62,35 +63,25 @@ template bool Prover::Stage_2(Proof& P, octetStream& cleartexts, const vector& x, const Proof::Randomness& r, - const vector& e) + const FHE_PK& pk) { size_t allocate = P.V * P.phim * (5 + numBytes(P.plain_check) + 3 * (5 + numBytes(P.rand_check))); cleartexts.resize_precise(allocate); cleartexts.reset_write_head(); - unsigned int i,k; - int j,ee; + unsigned int i; #ifndef LESS_ALLOC_MORE_MEM - AddableVector z; - AddableMatrix t; + AddableVector> z; + AddableMatrix> t; #endif cleartexts.reset_write_head(); cleartexts.store(P.V); for (i=0; i=(int) P.sec) { ee=0; } - else { ee=e[j]; } - - if (ee!=0) - { - z += x[j]; - t += r[j]; - } - } + P.apply_challenge(i, z, x, pk); + P.apply_challenge(i, t, r, pk); if (not P.check_bounds(z, t, i)) return false; z.pack(cleartexts); @@ -118,8 +109,6 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl const Proof::Randomness& r, bool Diag,bool binary) { - vector e(P.sec); - // AElement AE; // for (i=0; i::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl while (!ok) { cnt++; Stage_1(P,ciphertexts,c,pk,Diag,binary); - P.get_challenge(e, ciphertexts); + P.set_challenge(ciphertexts); // Check check whether we are OK, or whether we should abort - ok = Stage_2(P,cleartexts,x,r,e); + ok = Stage_2(P,cleartexts,x,r,pk); } if (cnt > 1) cout << "\t\tNumber iterations of prover = " << cnt << endl; @@ -173,7 +162,7 @@ void Prover::report_size(ReportType type, MemoryUsage& res) template class Prover >; -template class Prover >; +//template class Prover >; template class Prover >; -template class Prover >; +//template class Prover >; diff --git a/FHEOffline/Prover.h b/FHEOffline/Prover.h index d111b6a4b..f6f121212 100644 --- a/FHEOffline/Prover.h +++ b/FHEOffline/Prover.h @@ -14,8 +14,8 @@ class Prover AddableVector< Plaintext_ > y; #ifdef LESS_ALLOC_MORE_MEM - AddableVector z; - AddableMatrix t; + AddableVector> z; + AddableMatrix t; #endif public: @@ -30,7 +30,7 @@ class Prover bool Stage_2(Proof& P, octetStream& cleartexts, const vector& x, const Proof::Randomness& r, - const vector& e); + const FHE_PK& pk); /* Only has a non-interactive version using the ROM - If Diag is true then the plaintexts x are assumed to be diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index d3d43a0e7..da18fced8 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -85,6 +85,7 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, // Triple checking int left_todo=nm; + factory.triples.clear(); while (left_todo>0) { int this_loop=amortize; if (this_loop>left_todo) @@ -121,7 +122,17 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, for (int i=0; i& MC, int ns, // Do the square checking int left_todo=ns; + square_factory.tuples.clear(); while (left_todo>0) { int this_loop=amortize; if (this_loop>left_todo) @@ -363,6 +375,8 @@ void Square_Checking(const Player& P, MAC_Check& MC, int ns, { a[i].output(outf_s,false); b[i].output(outf_s,false); } + else + square_factory.tuples.push_back({{a[i], b[i]}}); } left_todo-=this_loop; } diff --git a/FHEOffline/Sacrificing.h b/FHEOffline/Sacrificing.h index b0593b29c..444dfc130 100644 --- a/FHEOffline/Sacrificing.h +++ b/FHEOffline/Sacrificing.h @@ -25,6 +25,8 @@ template class TupleSacriFactory { public: + vector> tuples; + virtual ~TupleSacriFactory() {} virtual void get(T& a, T& b) = 0; }; diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index bd25a4371..8db28756b 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -14,7 +14,8 @@ template SimpleEncCommitBase::SimpleEncCommitBase(const MachineBase& machine) : - sec(machine.sec), extra_slack(machine.extra_slack), n_rounds(0) + EncCommitBase_(&machine), + extra_slack(machine.extra_slack), n_rounds(0) { } @@ -36,7 +37,7 @@ NonInteractiveProofSimpleEncCommit::NonInteractiveProofSimpleEncCommit( P(P), pk(pk), FTD(FTD), proof(machine.sec, pk, machine.extra_slack), #ifdef LESS_ALLOC_MORE_MEM - r(this->sec, this->pk.get_params()), prover(proof, FTD), + r(proof.U, this->pk.get_params()), prover(proof, FTD), verifier(proof), #endif timers(timers) @@ -46,9 +47,9 @@ NonInteractiveProofSimpleEncCommit::NonInteractiveProofSimpleEncCommit( template SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk, const FD& FTD, const MachineBase& machine) : - cnt(-1), n_calls(0) + cnt(-1), n_calls(0), pk(pk) { - int sec = machine.sec; + int sec = Proof::n_ciphertext_per_proof(machine.sec, pk); c.resize(sec, pk.get_params()); m.resize(sec, FTD); for (int i = 0; i < sec; i++) @@ -70,6 +71,13 @@ void SimpleEncCommitFactory::next(Plaintext_& mess, Ciphertext& C) create_more(); mess = m[cnt]; C = c[cnt]; + + if (Proof::use_top_gear(pk)) + { + mess = mess + mess; + C = C + C; + } + cnt--; n_calls++; } @@ -84,18 +92,21 @@ void SimpleEncCommitFactory::prepare_plaintext(PRNG& G) template void SimpleEncCommitBase::generate_ciphertexts( AddableVector& c, const vector >& m, - Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers) + Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers, + Proof& proof) { timers["Generating"].start(); PRNG G; G.ReSeed(); prepare_plaintext(G); Random_Coins rc(pk.get_params()); - for (int i = 0; i < sec; i++) + c.resize(proof.U, pk); + r.resize(proof.U, pk); + for (unsigned i = 0; i < proof.U; i++) { r[i].sample(G); rc.assign(r[i]); - pk.encrypt(c[i], m[i], rc); + pk.encrypt(c[i], m.at(i), rc); } timers["Generating"].stop(); memory_usage.update("random coins", rc.report_size(CAPACITY)); @@ -103,14 +114,14 @@ void SimpleEncCommitBase::generate_ciphertexts( template size_t NonInteractiveProofSimpleEncCommit::generate_proof(AddableVector& c, - const vector >& m, octetStream& ciphertexts, + vector >& m, octetStream& ciphertexts, octetStream& cleartexts) { timers["Proving"].start(); #ifndef LESS_ALLOC_MORE_MEM - Proof::Randomness r(this->sec, pk.get_params()); + Proof::Randomness r(proof.U, pk.get_params()); #endif - this->generate_ciphertexts(c, m, r, pk, timers); + this->generate_ciphertexts(c, m, r, pk, timers, proof); #ifndef LESS_ALLOC_MORE_MEM Prover > prover(proof, FTD); #endif @@ -118,6 +129,13 @@ size_t NonInteractiveProofSimpleEncCommit::generate_proof(AddableVector().NIZKPoK(c[P.my_num()], proofs[P.my_num()], pk, false, false)) // throw runtime_error("proof check failed"); @@ -137,7 +155,7 @@ void SimpleEncCommit::create_more() cleartexts); cout << "Done checking proofs in round " << this->n_rounds << endl; this->n_rounds++; - this->cnt = this->sec - 1; + this->cnt = this->proof.U - 1; this->memory_usage.update("serialized ciphertexts", ciphertexts.get_max_length()); this->memory_usage.update("serialized cleartexts", cleartexts.get_max_length()); @@ -150,7 +168,7 @@ size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& cipherte octetStream& cleartexts) { AddableVector others_ciphertexts; - others_ciphertexts.resize(this->sec, pk.get_params()); + others_ciphertexts.resize(proof.U, pk.get_params()); for (int i = 1; i < P.num_players(); i++) { #ifdef VERBOSE_HE @@ -189,18 +207,31 @@ void SimpleEncCommit::add_ciphertexts( vector& ciphertexts, int offset) { (void)offset; - for (int j = 0; j < this->sec; j++) + for (unsigned j = 0; j < this->proof.U; j++) add(this->c[j], this->c[j], ciphertexts[j]); } +template +SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, + const FD& FTD, map& timers, const MachineBase& machine, + int thread_num) : + SimpleEncCommitFactory(pk, FTD, machine), SimpleEncCommitBase_( + machine), proof(machine.sec, pk, P.num_players()), pk(pk), FTD( + FTD), P(P), thread_num(thread_num), +#ifdef LESS_ALLOC_MORE_MEM + prover(proof, FTD), verifier(proof), preimages(proof.V, + this->pk, FTD.get_prime(), P.num_players()), +#endif + timers(timers) +{ +} + template void SummingEncCommit::create_more() { octetStream cleartexts; const Player& P = this->P; - InteractiveProof proof(this->sec, this->pk, P.num_players()); AddableVector commitments; - vector e(this->sec); size_t prover_size; MemoryUsage& memory_usage = this->memory_usage; TreeSum tree_sum(2, 2, thread_num % P.num_players()); @@ -210,10 +241,10 @@ void SummingEncCommit::create_more() #ifdef LESS_ALLOC_MORE_MEM Proof::Randomness& r = preimages.r; #else - Proof::Randomness r(this->sec, this->pk.get_params()); + Proof::Randomness r(proof.U, this->pk.get_params()); Prover > prover(proof, this->FTD); #endif - this->generate_ciphertexts(this->c, this->m, r, pk, timers); + this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof); this->timers["Stage 1 of proof"].start(); prover.Stage_1(proof, ciphertexts, this->c, this->pk, false, false); this->timers["Stage 1 of proof"].stop(); @@ -228,10 +259,10 @@ void SummingEncCommit::create_more() tree_sum.run(commitments, P); this->timers["Exchanging ciphertexts"].stop(); - generate_challenge(e, P); + proof.generate_challenge(P); this->timers["Stage 2 of proof"].start(); - prover.Stage_2(proof, cleartexts, this->m, r, e); + prover.Stage_2(proof, cleartexts, this->m, r, pk); this->timers["Stage 2 of proof"].stop(); prover_size = prover.report_size(CAPACITY) + r.report_size(CAPACITY) @@ -273,10 +304,10 @@ void SummingEncCommit::create_more() #else Verifier verifier(proof); #endif - verifier.Stage_2(e, this->c, ciphertexts, cleartexts, + verifier.Stage_2(this->c, ciphertexts, cleartexts, this->pk, false, false); this->timers["Verifying"].stop(); - this->cnt = this->sec - 1; + this->cnt = proof.U - 1; this->volatile_memory = + commitments.report_size(CAPACITY) + ciphertexts.get_max_length() @@ -339,7 +370,7 @@ template void MultiEncCommit::add_ciphertexts(vector& ciphertexts, int offset) { - for (int i = 0; i < this->sec; i++) + for (unsigned i = 0; i < this->proof.U; i++) generator.multipliers[offset - 1]->multiply_and_add(generator.c.at(i), ciphertexts.at(i), generator.b_mod_q.at(i)); } diff --git a/FHEOffline/SimpleEncCommit.h b/FHEOffline/SimpleEncCommit.h index 98da2f294..c286faf2f 100644 --- a/FHEOffline/SimpleEncCommit.h +++ b/FHEOffline/SimpleEncCommit.h @@ -20,14 +20,13 @@ template class SimpleEncCommitBase : public EncCommitBase { protected: - int sec; int extra_slack; int n_rounds; void generate_ciphertexts(AddableVector& c, const vector >& m, Proof::Randomness& r, - const FHE_PK& pk, map& timers); + const FHE_PK& pk, map& timers, Proof& proof); virtual void prepare_plaintext(PRNG& G) = 0; @@ -46,12 +45,16 @@ template class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_ { protected: - typedef bigint S; + typedef fixint S; const PlayerBase& P; const FHE_PK& pk; const FD& FTD; + virtual const FHE_PK& get_pk_for_verification(int offset) = 0; + virtual void add_ciphertexts(vector& ciphertexts, int offset) = 0; + +public: NonInteractiveProof proof; #ifdef LESS_ALLOC_MORE_MEM @@ -60,17 +63,13 @@ class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_ Verifier verifier; #endif - virtual const FHE_PK& get_pk_for_verification(int offset) = 0; - virtual void add_ciphertexts(vector& ciphertexts, int offset) = 0; - -public: map& timers; NonInteractiveProofSimpleEncCommit(const PlayerBase& P, const FHE_PK& pk, const FD& FTD, map& timers, const MachineBase& machine); virtual ~NonInteractiveProofSimpleEncCommit() {} - size_t generate_proof(AddableVector& c, const vector >& m, + size_t generate_proof(AddableVector& c, vector >& m, octetStream& ciphertexts, octetStream& cleartexts); size_t create_more(octetStream& my_ciphertext, octetStream& my_cleartext); virtual size_t report_size(ReportType type); @@ -87,6 +86,8 @@ class SimpleEncCommitFactory int n_calls; + const FHE_PK& pk; + void prepare_plaintext(PRNG& G); virtual void create_more() = 0; @@ -104,7 +105,8 @@ class SimpleEncCommit: public NonInteractiveProofSimpleEncCommit, public SimpleEncCommitFactory { protected: - const FHE_PK& get_pk_for_verification(int offset) { (void)offset; return this->pk; } + const FHE_PK& get_pk_for_verification(int) + { return NonInteractiveProofSimpleEncCommit::pk; } void prepare_plaintext(PRNG& G) { SimpleEncCommitFactory::prepare_plaintext(G); } void add_ciphertexts(vector& ciphertexts, int offset); @@ -127,7 +129,7 @@ template class SummingEncCommit: public SimpleEncCommitFactory, public SimpleEncCommitBase_ { - typedef bigint S; + typedef fixint S; InteractiveProof proof; const FHE_PK& pk; @@ -148,15 +150,8 @@ class SummingEncCommit: public SimpleEncCommitFactory, map& timers; SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD, - map& timers, const MachineBase& machine, int thread_num) : - SimpleEncCommitFactory(pk, FTD, machine), SimpleEncCommitBase_(machine), - proof(this->sec, pk, P.num_players()), pk(pk), FTD(FTD), P(P), - thread_num(thread_num), -#ifdef LESS_ALLOC_MORE_MEM - prover(proof, FTD), verifier(proof), preimages(proof.V, this->pk, - FTD.get_prime(), P.num_players()), -#endif - timers(timers) {} + map& timers, const MachineBase& machine, int thread_num); + void next(Plaintext_& mess, Ciphertext& C) { SimpleEncCommitFactory::next(mess, C); } void create_more(); size_t report_size(ReportType type); diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index 5fb970373..9456a5f83 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,10 +12,11 @@ template