From 24926df83be7f572cb2d2615bdc212b224d5562e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 2 Apr 2020 18:06:14 +1100 Subject: [PATCH] Bristol Fashion. --- .gitmodules | 3 + CHANGELOG.md | 7 + Compiler/GC/instructions.py | 69 ++-- Compiler/GC/types.py | 42 +- Compiler/allocator.py | 99 +---- Compiler/circuit.py | 201 ++++++++++ Compiler/comparison.py | 1 + Compiler/compilerLib.py | 39 +- Compiler/config.py | 6 - Compiler/graph.py | 2 +- Compiler/instructions.py | 206 +--------- Compiler/instructions_base.py | 37 +- Compiler/library.py | 8 +- Compiler/ml.py | 34 +- Compiler/program.py | 134 ++----- Compiler/types.py | 60 ++- Compiler/util.py | 6 + ExternalIO/README.md | 76 ++-- ExternalIO/bankers-bonus-client.cpp | 129 ++++--- ExternalIO/bankers-bonus-commsec-client.cpp | 407 -------------------- FHE/FFT_Data.cpp | 15 +- FHE/NTL-Subs.cpp | 37 +- FHE/NoiseBounds.cpp | 2 + FHE/Ring.cpp | 24 +- FHE/Ring.h | 3 +- GC/ArgTuples.h | 2 + GC/Instruction.hpp | 9 +- GC/Instruction_inline.h | 1 + GC/Processor.h | 10 +- GC/Processor.hpp | 61 ++- GC/Secret.h | 6 + GC/Secret.hpp | 8 +- GC/SemiSecret.cpp | 25 +- GC/ShareSecret.h | 6 +- GC/ShareSecret.hpp | 56 ++- GC/ShareThread.hpp | 2 + GC/Thread.hpp | 17 +- GC/TinySecret.h | 3 + GC/instructions.h | 28 +- GC/square64.cpp | 4 + Machines/Player-Online.hpp | 2 - Makefile | 8 +- Math/Z2k.h | 10 +- Math/Z2k.hpp | 6 + Math/gfp.cpp | 6 + Math/gfp.h | 2 + Math/modp.h | 3 + Math/modp.hpp | 11 + Networking/CryptoPlayer.cpp | 44 +-- Networking/CryptoPlayer.h | 2 +- Networking/Player.cpp | 1 - Networking/STS.cpp | 228 ----------- Networking/STS.h | 70 ---- Networking/ServerSocket.cpp | 94 +++-- Networking/ServerSocket.h | 19 +- Networking/ssl_sockets.h | 60 ++- Processor/ExternalClients.cpp | 139 +------ Processor/ExternalClients.h | 28 +- Processor/FixInput.cpp | 13 +- Processor/FixInput.h | 18 +- Processor/Input.h | 5 +- Processor/Input.hpp | 46 +-- Processor/Instruction.h | 8 +- Processor/Instruction.hpp | 64 ++- Processor/Machine.hpp | 12 +- Processor/Memory.h | 15 - Processor/Memory.hpp | 37 -- Processor/Processor.h | 12 +- Processor/Processor.hpp | 175 +-------- Programs/Circuits | 1 + Programs/Source/aes_circuit.mpc | 8 + Programs/Source/bankers_bonus.mpc | 2 - Programs/Source/bankers_bonus_commsec.mpc | 144 ------- Programs/Source/idash_predict.mpc | 14 +- Programs/Source/idash_train.mpc | 34 +- Programs/Source/logreg.mpc | 2 - Programs/Source/regression.mpc | 9 +- Programs/Source/test_gc.mpc | 5 - Protocols/Rep3Share.h | 8 +- Protocols/ReplicatedInput.h | 6 +- Protocols/ReplicatedInput.hpp | 41 -- Protocols/ShamirShare.h | 14 +- Protocols/ShareInterface.h | 2 + Protocols/SohoPrep.h | 2 + Protocols/SohoPrep.hpp | 49 +++ Protocols/Spdz2kPrep.h | 1 - README.md | 25 +- Scripts/setup-clients.sh | 13 + Scripts/test_ecdsa.sh | 11 +- Tools/Config.cpp | 107 ----- Tools/Config.h | 15 - Tools/octetStream.cpp | 104 +---- Tools/octetStream.h | 51 ++- Tools/random.cpp | 1 + Utils/client-setup.cpp | 178 --------- azure-pipelines.yml | 24 ++ compile.py | 17 +- doc/Compiler.rst | 6 + 98 files changed, 1263 insertions(+), 2654 deletions(-) create mode 100644 Compiler/circuit.py delete mode 100644 ExternalIO/bankers-bonus-commsec-client.cpp delete mode 100644 Networking/STS.cpp delete mode 100644 Networking/STS.h create mode 160000 Programs/Circuits create mode 100644 Programs/Source/aes_circuit.mpc delete mode 100644 Programs/Source/bankers_bonus_commsec.mpc create mode 100755 Scripts/setup-clients.sh delete mode 100644 Tools/Config.cpp delete mode 100644 Tools/Config.h delete mode 100644 Utils/client-setup.cpp create mode 100644 azure-pipelines.yml diff --git a/.gitmodules b/.gitmodules index 7c1438dd6..193b36775 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "mpir"] path = mpir url = git://github.com/wbhart/mpir.git +[submodule "Programs/Circuits"] + path = Programs/Circuits + url = https://github.com/mkskeller/bristol-fashion diff --git a/CHANGELOG.md b/CHANGELOG.md index d9072297a..e44fc713d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.6 (Apr 2, 2020) + +- Bristol Fashion circuits +- Semi-honest computation with somewhat homomorphic encryption +- Use SSL for client connections +- Client facilities for all arithmetic protocols + ## 0.1.5 (Mar 20, 2020) - Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index c48ae357c..636f532bb 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -67,19 +67,33 @@ class BinaryVectorInstruction(base.Instruction): def copy(self, size, subs): return type(self)(*self.get_new_args(size, subs)) +class NonVectorInstruction(base.Instruction): + is_vec = lambda self: False + + def __init__(self, *args, **kwargs): + assert(args[0].n <= args[0].unit) + super(NonVectorInstruction, self).__init__(*args, **kwargs) + +class NonVectorInstruction1(base.Instruction): + is_vec = lambda self: False + + def __init__(self, *args, **kwargs): + assert(args[1].n <= args[1].unit) + super(NonVectorInstruction1, self).__init__(*args, **kwargs) + class xors(BinaryVectorInstruction): code = opcodes['XORS'] arg_format = tools.cycle(['int','sbw','sb','sb']) -class xorm(base.Instruction): +class xorm(NonVectorInstruction): code = opcodes['XORM'] arg_format = ['int','sbw','sb','cb'] -class xorcb(base.Instruction): +class xorcb(NonVectorInstruction): code = opcodes['XORCB'] arg_format = ['cbw','cb','cb'] -class xorcbi(base.Instruction): +class xorcbi(NonVectorInstruction): code = opcodes['XORCBI'] arg_format = ['cbw','cb','int'] @@ -101,47 +115,48 @@ class andm(BinaryVectorInstruction): code = opcodes['ANDM'] arg_format = ['int','sbw','sb','cb'] -class addcb(base.Instruction): +class addcb(NonVectorInstruction): code = opcodes['ADDCB'] arg_format = ['cbw','cb','cb'] -class addcbi(base.Instruction): +class addcbi(NonVectorInstruction): code = opcodes['ADDCBI'] arg_format = ['cbw','cb','int'] -class mulcbi(base.Instruction): +class mulcbi(NonVectorInstruction): code = opcodes['MULCBI'] arg_format = ['cbw','cb','int'] -class bitdecs(base.VarArgsInstruction): +class bitdecs(NonVectorInstruction, base.VarArgsInstruction): code = opcodes['BITDECS'] arg_format = tools.chain(['sb'], itertools.repeat('sbw')) -class bitcoms(base.VarArgsInstruction): +class bitcoms(NonVectorInstruction, base.VarArgsInstruction): code = opcodes['BITCOMS'] arg_format = tools.chain(['sbw'], itertools.repeat('sb')) -class bitdecc(base.VarArgsInstruction): +class bitdecc(NonVectorInstruction, base.VarArgsInstruction): code = opcodes['BITDECC'] arg_format = tools.chain(['cb'], itertools.repeat('cbw')) -class shrcbi(base.Instruction): +class shrcbi(NonVectorInstruction): code = opcodes['SHRCBI'] arg_format = ['cbw','cb','int'] -class shlcbi(base.Instruction): +class shlcbi(NonVectorInstruction): code = opcodes['SHLCBI'] arg_format = ['cbw','cb','int'] -class ldbits(base.Instruction): +class ldbits(NonVectorInstruction): code = opcodes['LDBITS'] arg_format = ['sbw','i','i'] -class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction): +class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, + base.VectorInstruction): code = opcodes['LDMSB'] arg_format = ['sbw','int'] -class stmsb(base.DirectMemoryWriteInstruction): +class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMSB'] arg_format = ['sb','int'] # def __init__(self, *args, **kwargs): @@ -149,19 +164,20 @@ class stmsb(base.DirectMemoryWriteInstruction): # import inspect # self.caller = [frame[1:] for frame in inspect.stack()[1:]] -class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction): +class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, + base.VectorInstruction): code = opcodes['LDMCB'] arg_format = ['cbw','int'] -class stmcb(base.DirectMemoryWriteInstruction): +class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMCB'] arg_format = ['cb','int'] -class ldmsbi(base.ReadMemoryInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] -class stmsbi(base.WriteMemoryInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): code = opcodes['STMSBI'] arg_format = ['sb','ci'] @@ -185,15 +201,15 @@ class stmsdci(base.WriteMemoryInstruction): code = opcodes['STMSDCI'] arg_format = tools.cycle(['cb','cb']) -class convsint(base.Instruction): +class convsint(NonVectorInstruction1): code = opcodes['CONVSINT'] arg_format = ['int','sbw','ci'] -class convcint(base.Instruction): +class convcint(NonVectorInstruction): code = opcodes['CONVCINT'] arg_format = ['cbw','ci'] -class convcbit(base.Instruction): +class convcbit(NonVectorInstruction1): code = opcodes['CONVCBIT'] arg_format = ['ciw','cb'] @@ -222,18 +238,19 @@ def __init__(self, *args, **kwargs): super(split_class, self).__init__(*args, **kwargs) assert (len(args) - 2) % args[0] == 0 -class movsb(base.Instruction): +class movsb(NonVectorInstruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] class trans(base.VarArgsInstruction): code = opcodes['TRANS'] + is_vec = lambda self: True def __init__(self, *args): self.arg_format = ['int'] + ['sbw'] * args[0] + \ ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) -class bitb(base.Instruction): +class bitb(NonVectorInstruction): code = opcodes['BITB'] arg_format = ['sbw'] @@ -245,20 +262,22 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): __slots__ = [] code = opcodes['INPUTB'] arg_format = tools.cycle(['p','int','int','sbw']) + is_vec = lambda self: True -class print_regb(base.IOInstruction): +class print_regb(base.VectorInstruction, base.IOInstruction): code = opcodes['PRINTREGB'] arg_format = ['cb','i'] def __init__(self, reg, comment=''): super(print_regb, self).__init__(reg, self.str_to_int(comment)) -class print_reg_plainb(base.IOInstruction): +class print_reg_plainb(NonVectorInstruction, base.IOInstruction): code = opcodes['PRINTREGPLAINB'] arg_format = ['cb'] class print_reg_signed(base.IOInstruction): code = opcodes['PRINTREGSIGNED'] arg_format = ['int','cb'] + is_vec = lambda self: True class print_float_plainb(base.IOInstruction): __slots__ = [] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 968edae85..21ccedf53 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -77,6 +77,9 @@ def malloc(cls, size): def n_elements(): return 1 @classmethod + def mem_size(cls): + return math.ceil(cls.n / cls.unit) + @classmethod def load_mem(cls, address, mem_type=None, size=None): if size not in (None, 1): v = [cls.load_mem(address + i) for i in range(size)] @@ -101,9 +104,8 @@ def __init__(self, value=None, n=None, size=None): def copy(self): return type(self)(n=instructions_base.get_global_vector_size()) def set_length(self, n): - if n > self.max_length: - print(self.max_length) - raise Exception('too long: %d' % n) + if n > self.n: + raise Exception('too long: %d/%d' % (n, self.n)) self.n = n def set_size(self, size): pass @@ -135,7 +137,7 @@ def __repr__(self): if self.n != None: suffix = '%d' % self.n if type(self).n != None and type(self).n != self.n: - suffice += '/%d' % type(self).n + suffix += '/%d' % type(self).n else: suffix = 'undef' return '%s(%s)' % (super(bits, self).__repr__(), suffix) @@ -237,6 +239,7 @@ class sbits(bits): bitdec = inst.bitdecs bitcom = inst.bitcoms conv_regint = inst.convsint + one_cache = {} @classmethod def conv_regint_by_bit(cls, n, res, other): tmp = cbits.get_type(n)() @@ -285,14 +288,12 @@ def load_int(self, value): % (value, self.n)) if self.n <= 32: inst.ldbits(self, self.n, value) - elif self.n <= 64: - self.load_other(regint(value, size=1)) - elif self.n <= 128: - lower = sbits.get_type(64)(value % 2**64) - upper = sbits.get_type(self.n - 64)(value >> 64) - self.mov(self, lower + (upper << 64)) else: - raise NotImplementedError('more than 128 bits wanted') + size = math.ceil(self.n / self.unit) + tmp = regint(size=size) + for i in range(size): + tmp[i].load_int((value >> (i * 64)) % 2**64) + self.load_other(tmp) def load_other(self, other): if isinstance(other, cbits) and self.n == other.n: inst.convcbit2s(self.n, self, other) @@ -393,11 +394,10 @@ def __invert__(self): # res = type(self)(n=self.n) # inst.nots(res, self) # return res - if self.n == None or self.n > self.unit: - one = self.get_type(self.n)() - self.conv_regint_by_bit(self.n, one, regint(1, size=self.n)) - else: - one = self.new(value=self.long_one(), n=self.n) + key = self.n, library.get_block() + if key not in self.one_cache: + self.one_cache[key] = self.new(value=self.long_one(), n=self.n) + one = self.one_cache[key] return self + one def __neg__(self): return self @@ -432,12 +432,12 @@ def popcnt(self): @classmethod def trans(cls, rows): rows = list(rows) - if len(rows) == 1: + if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() n_columns = rows[0].n for row in rows: assert(row.n == n_columns) - if n_columns == 1: + if n_columns == 1 and len(rows) <= cls.unit: return [cls.bit_compose(rows)] else: res = [cls.new(n=len(rows)) for i in range(n_columns)] @@ -452,6 +452,10 @@ def bit_adder(*args, **kwargs): @staticmethod def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) + def to_sint(self, n_bits): + bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0] + bits = sint(bits, size=n_bits) + return sint.bit_compose(bits) class sbitvec(_vec): @classmethod @@ -524,6 +528,8 @@ def __iter__(self): return iter(self.v) def __len__(self): return len(self.v) + def __getitem__(self, index): + return self.v[index] @classmethod def conv(cls, other): return cls.from_vec(other.v) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 0f570e6d2..8d6b89c49 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -210,85 +210,6 @@ def compute_max_depths(self, depth_of): max_depth_of[v] = min(max_depth_of[u], max_depth_of[v]) return max_depth_of - def merge_inputs(self): - merges = defaultdict(list) - remaining_input_nodes = [] - def do_merge(nodes): - if len(nodes) > 1000: - print('Merging %d inputs...' % len(nodes)) - self.do_merge(iter(nodes)) - for n in self.input_nodes: - inst = self.instructions[n] - merge = merges[inst.args[0],inst.__class__] - if len(merge) == 0: - remaining_input_nodes.append(n) - merge.append(n) - if len(merge) >= self.max_parallel_open: - do_merge(merge) - merge[:] = [] - for merge in reversed(sorted(merges.values())): - if merge: - do_merge(merge) - self.input_nodes = remaining_input_nodes - - def compute_preorder(self, merges, rev_depth_of): - # find flexible nodes that can be on several levels - # and find sources on level 0 - G = self.G - merge_nodes_set = self.open_nodes - depth_of = self.depths - instructions = self.instructions - flex_nodes = defaultdict(dict) - starters = [] - for n in range(len(G)): - if n not in merge_nodes_set and \ - depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction): - #print n, depth_of[n], rev_depth_of[n] - flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n) - elif len(G.pred[n]) == 0 and \ - not isinstance(self.instructions[n], RawInputInstruction): - starters.append(n) - if n % 10000000 == 0 and n > 0: - print("Processed %d nodes at" % n, time.asctime()) - - inputs = defaultdict(list) - for node in self.input_nodes: - player = self.instructions[node].args[0] - inputs[player].append(node) - first_inputs = [l[0] for l in inputs.values()] - other_inputs = [] - i = 0 - while True: - i += 1 - found = False - for l in inputs.values(): - if i < len(l): - other_inputs.append(l[i]) - found = True - if not found: - break - other_inputs.reverse() - - preorder = [] - # magical preorder for topological search - max_depth = max(merges) - if max_depth > 10000: - print("Computing pre-ordering ...") - for i in range(max_depth, 0, -1): - preorder.append(G.get_attr(merges[i], 'stop')) - for j in flex_nodes[i-1].values(): - preorder.extend(j) - preorder.extend(flex_nodes[0].get(i, [])) - preorder.append(merges[i]) - if i % 100000 == 0 and i > 0: - print("Done level %d at" % i, time.asctime()) - preorder.extend(other_inputs) - preorder.extend(starters) - preorder.extend(first_inputs) - if max_depth > 10000: - print("Done at", time.asctime()) - return preorder - def longest_paths_merge(self): """ Attempt to merge instructions of type instruction_type (which are given in merge_nodes) using longest paths algorithm. @@ -301,7 +222,7 @@ def longest_paths_merge(self): instructions = self.instructions merge_nodes = self.open_nodes depths = self.depths - if not merge_nodes and not self.input_nodes: + if not merge_nodes: return 0 # merge opens at same depth @@ -321,8 +242,6 @@ def longest_paths_merge(self): (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) - self.merge_inputs() - preorder = None if len(instructions) > 100000: @@ -340,7 +259,6 @@ def dependency_graph(self, merge_classes): options = self.options open_nodes = set() self.open_nodes = open_nodes - self.input_nodes = [] colordict = defaultdict(lambda: 'gray', asm_open='red',\ ldi='lightblue', ldm='lightblue', stm='blue',\ mov='yellow', mulm='orange', mulc='orange',\ @@ -507,14 +425,7 @@ def keep_order(instr, n, t, arg_index=None): elif isinstance(instr, PublicFileIOInstruction): keep_order(instr, n, instr.__class__) elif isinstance(instr, RawInputInstruction): - keep_order(instr, n, instr.__class__, 0) - self.input_nodes.append(n) - G.add_node(n, merges=[]) - player = instr.args[0] - if isinstance(instr, stopinput): - add_edge(last[startinput_class][player], n) - elif isinstance(instr, gstopinput): - add_edge(last[gstartinput][player], n) + keep_order(instr, n, instr.__class__) elif isinstance(instr, startprivateoutput_class): keep_order(instr, n, startprivateoutput_class, 2) elif isinstance(instr, stopprivateoutput_class): @@ -559,18 +470,14 @@ def eliminate_dead_code(self): unused_result = not G.degree(i) and len(list(inst.get_def())) \ and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \ and not isinstance(inst, (DoNotEliminateInstruction)) - stop_node = G.get_attr(i, 'stop') - unused_startopen = stop_node != -1 and instructions[stop_node] is None def eliminate(i): G.remove_node(i) merge_nodes.discard(i) stats[type(instructions[i]).__name__] += 1 instructions[i] = None - if unused_result or unused_startopen: + if unused_result: eliminate(i) count += 1 - if unused_startopen: - open_count += len(inst.args) # remove unnecessary stack instructions # left by optimization with budget if isinstance(inst, popint_class) and \ diff --git a/Compiler/circuit.py b/Compiler/circuit.py new file mode 100644 index 000000000..1f3685d77 --- /dev/null +++ b/Compiler/circuit.py @@ -0,0 +1,201 @@ +""" +This module contains functionality using circuits in the so-called +`Bristol Fashion`_ format. You can download a few examples including +the ones used below into ``Programs/Circuits`` as follows:: + + make Programs/Circuits + +.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC + +""" + +from Compiler.GC.types import sbitvec, sbits +from Compiler.library import function_block +from Compiler import util +import itertools + +class Circuit: + """ + Use a Bristol Fashion circuit in a high-level program. The + following example adds signed 64-bit inputs from two different + parties and prints the result:: + + from circuit import Circuit + sb64 = sbits.get_type(64) + adder = Circuit('adder64') + a, b = [sbitvec(sb64.get_input_from(i)) for i in (0, 1)] + print_ln('%s', adder(a, b).elements()[0].reveal()) + + Circuits can also be executed in parallel as the following example + shows:: + + from circuit import Circuit + sb128 = sbits.get_type(128) + key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c) + plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a) + n = 1000 + aes128 = Circuit('aes_128') + ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n)) + ciphertexts.elements()[n - 1].reveal().print_reg() + + This executes AES-128 1000 times in parallel and then outputs the + last result, which should be ``0x3ad77bb40d7a3660a89ecaf32466ef97``, + one of the test vectors for AES-128. + + """ + + def __init__(self, name): + self.filename = 'Programs/Circuits/%s.txt' % name + f = open(self.filename) + self.functions = {} + + def __call__(self, *inputs): + return self.run(*inputs) + + def run(self, *inputs): + n = inputs[0][0].n + if n not in self.functions: + self.functions[n] = function_block(lambda *args: + self.compile(*args)) + flat_res = self.functions[n](*itertools.chain(*inputs)) + res = [] + i = 0 + for l in self.n_output_wires: + v = [] + for i in range(l): + v.append(flat_res[i]) + i += 1 + res.append(sbitvec.from_vec(v)) + return util.untuplify(res) + + def compile(self, *all_inputs): + f = open(self.filename) + lines = iter(f) + next_line = lambda: next(lines).split() + n_gates, n_wires = (int(x) for x in next_line()) + self.n_wires = n_wires + input_line = [int(x) for x in next_line()] + n_inputs = input_line[0] + n_input_wires = input_line[1:] + assert(n_inputs == len(n_input_wires)) + inputs = [] + s = 0 + for n in n_input_wires: + inputs.append(all_inputs[s:s + n]) + s += n + output_line = [int(x) for x in next_line()] + n_outputs = output_line[0] + self.n_output_wires = output_line[1:] + assert(n_outputs == len(self.n_output_wires)) + next(lines) + + wires = [None] * n_wires + self.wires = wires + i_wire = 0 + for input, input_wires in zip(inputs, n_input_wires): + assert(len(input) == input_wires) + for i, reg in enumerate(input): + wires[i_wire] = reg + i_wire += 1 + + for i in range(n_gates): + line = next_line() + t = line[-1] + if t in ('XOR', 'AND'): + assert line[0] == '2' + assert line[1] == '1' + assert len(line) == 6 + ins = [wires[int(line[2 + i])] for i in range(2)] + if t == 'XOR': + wires[int(line[4])] = ins[0] ^ ins[1] + else: + wires[int(line[4])] = ins[0] & ins[1] + elif t == 'INV': + assert line[0] == '1' + assert line[1] == '1' + assert len(line) == 5 + wires[int(line[3])] = ~wires[int(line[2])] + + return self.wires[-sum(self.n_output_wires):] + +Keccak_f = None + +def sha3_256(x): + """ + This function implements SHA3-256 for inputs of up to 1080 bits:: + + from circuit import sha3_256 + a = sbitvec.from_vec([]) + b = sbitvec(sint(0xcc), 8) + for x in a, b: + sha3_256(x).elements()[0].reveal().print_reg() + + This should output the first two test vectors of SHA3-256 in + byte-reversed order:: + + 0x5375f6fb6aa989b0c287a923afe81e79ff875921cacc956666d71ebff8c6ffa7 + 0x17c7e0d65c285af8406d4f21c071851a312b739a8ecdf25c1270d31c39357067 + + Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only + implemented for computation modulo a power of two. + """ + + global Keccak_f + if Keccak_f is None: + # only one instance + Keccak_f = Circuit('Keccak_f') + + # whole bytes + assert len(x.v) % 8 == 0 + # only one block + r = 1088 + assert len(x.v) < 1088 + if x.v: + n = x.v[0].n + else: + n = 1 + d = sbitvec([sbits.get_type(8)(0x06)] * n) + sbn = sbits.get_type(n) + padding = [sbn(0)] * (r - 8 - len(x.v)) + P_flat = x.v + d.v + padding + assert len(P_flat) == r + P_flat[-1] = ~P_flat[-1] + w = 64 + P1 = [P_flat[i * w:(i + 1) * w] for i in range(r // w)] + + S = [[[sbn(0) for i in range(w)] for i in range(5)] for i in range(5)] + for x in range(5): + for y in range(5): + if x + 5 * y < r // w: + for i in range(w): + S[x][y][i] ^= P1[x + 5 * y][i] + + def flatten(S): + res = [None] * 1600 + for y in range(5): + for x in range(5): + for i in range(w): + j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8 + res[1600 - 1 - j] = S[x][y][i] + return res + + def unflatten(S_flat): + res = [[[None] * w for j in range(5)] for i in range(5)] + for y in range(5): + for x in range(5): + for i in range(w): + j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8 + res[x][y][i] = S_flat[1600 - 1 -j] + return res + + S = unflatten(Keccak_f(flatten(S))) + + Z = [] + while len(Z) <= 256: + for y in range(5): + for x in range(5): + if x + 5 * y < r // w: + Z += S[y][x] + if len(Z) <= 256: + S = unflatten(Keccak_f(flatten(S))) + return sbitvec.from_vec(Z[:256]) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 8af1bdec8..3ba37ecd0 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -262,6 +262,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): return r_dprime, r_prime, c, c_prime, u, t, c2k1 def MaskingBitsInRing(m, strict=False): + program.curr_tape.require_bit_length(1) from Compiler.types import sint if program.use_edabit(): return sint.get_edabit(m, strict) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index fca698a3b..70d489cc2 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -9,19 +9,18 @@ import sys -def run(args, options, param=-1, merge_opens=True, emulate=True, \ - reallocate=True, assemblymode=False, debug=False): +def run(args, options, param=-1, merge_opens=True, + reallocate=True, debug=False): """ Compile a file and output a Program object. If merge_opens is set to True, will attempt to merge any parallelisable open instructions. """ - prog = Program(args, options, param, assemblymode) + prog = Program(args, options, param) instructions.program = prog instructions_base.program = prog types.program = prog comparison.program = prog - prog.EMULATE = emulate prog.DEBUG = debug VARS['program'] = prog if options.binary: @@ -31,26 +30,9 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ print('Compiling file', prog.infile) - # no longer needed, but may want to support assembly in future (?) - if assemblymode: - prog.restart_main_thread() - for i in range(INIT_REG_MAX): - VARS['c%d'%i] = prog.curr_block.new_reg('c') - VARS['s%d'%i] = prog.curr_block.new_reg('s') - VARS['cg%d'%i] = prog.curr_block.new_reg('cg') - VARS['sg%d'%i] = prog.curr_block.new_reg('sg') - if i % 10000000 == 0 and i > 0: - print("Initialized %d register variables at" % i, time.asctime()) - - # first pass determines how many assembler registers are used - prog.FIRST_PASS = True - exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) - if instructions_base.Instruction.count != 0: print('instructions count', instructions_base.Instruction.count) instructions_base.Instruction.count = 0 - prog.FIRST_PASS = False - prog.reset_values() # make compiler modules directly accessible sys.path.insert(0, 'Compiler') # create the tapes @@ -60,17 +42,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ for tape in prog.tapes: tape.optimize(options) - # check program still does the same thing after optimizations - if emulate: - clearmem = list(prog.mem_c) - sharedmem = list(prog.mem_s) - prog.emulate() - if prog.mem_c != clearmem or prog.mem_s != sharedmem: - print('Warning: emulated memory values changed after compiler optimization') - # raise CompilerError('Compiler optimization caused incorrect memory write.') - if prog.main_thread_running: prog.update_req(prog.curr_tape) + + if prog.req_num: + print('Program requires:') + for x in prog.req_num.pretty(): + print(x) + if prog.verbose: print('Program requires:', repr(prog.req_num)) print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) diff --git a/Compiler/config.py b/Compiler/config.py index 680ad64c9..7a30c7382 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -1,13 +1,7 @@ from collections import defaultdict -#INIT_REG_MAX = 655360 -INIT_REG_MAX = 1310720 REG_MAX = 2 ** 32 USER_MEM = 8192 -TMP_MEM = 8192 -TMP_MEM_BASE = USER_MEM -TMP_REG = 3 -TMP_REG_BASE = REG_MAX - TMP_REG P_VALUES = { 32: 2147565569, \ 64: 9223372036855103489, \ diff --git a/Compiler/graph.py b/Compiler/graph.py index 43bea6ba5..e004a148d 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -17,7 +17,7 @@ def __init__(self, max_nodes, default_attributes=None): """ max_nodes: maximum no of nodes default_attributes: dict of node attributes and default values """ if default_attributes is None: - default_attributes = { 'merges': None, 'stop': -1, 'start': -1 } + default_attributes = { 'merges': None } self.default_attributes = default_attributes self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes)))))) self.n = max_nodes diff --git a/Compiler/instructions.py b/Compiler/instructions.py index fd74fe78c..d46464bfb 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -34,9 +34,6 @@ class ldi(base.Instruction): __slots__ = [] code = base.opcodes['LDI'] arg_format = ['cw','i'] - - def execute(self): - self.args[0].value = self.args[1] @base.gf2n @base.vectorize @@ -45,9 +42,6 @@ class ldsi(base.Instruction): __slots__ = [] code = base.opcodes['LDSI'] arg_format = ['sw','i'] - - def execute(self): - self.args[0].value = self.args[1] @base.gf2n @base.vectorize @@ -57,9 +51,6 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): code = base.opcodes['LDMC'] arg_format = ['cw','int'] - def execute(self): - self.args[0].value = program.mem_c[self.args[1]] - @base.gf2n @base.vectorize class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @@ -68,9 +59,6 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): code = base.opcodes['LDMS'] arg_format = ['sw','int'] - def execute(self): - self.args[0].value = program.mem_s[self.args[1]] - @base.gf2n @base.vectorize class stmc(base.DirectMemoryWriteInstruction): @@ -79,9 +67,6 @@ class stmc(base.DirectMemoryWriteInstruction): code = base.opcodes['STMC'] arg_format = ['c','int'] - def execute(self): - program.mem_c[self.args[1]] = self.args[0].value - @base.gf2n @base.vectorize class stms(base.DirectMemoryWriteInstruction): @@ -90,9 +75,6 @@ class stms(base.DirectMemoryWriteInstruction): code = base.opcodes['STMS'] arg_format = ['s','int'] - def execute(self): - program.mem_s[self.args[1]] = self.args[0].value - @base.vectorize class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """ @@ -100,9 +82,6 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): code = base.opcodes['LDMINT'] arg_format = ['ciw','int'] - def execute(self): - self.args[0].value = program.mem_i[self.args[1]] - @base.vectorize class stmint(base.DirectMemoryWriteInstruction): r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """ @@ -110,18 +89,12 @@ class stmint(base.DirectMemoryWriteInstruction): code = base.opcodes['STMINT'] arg_format = ['ci','int'] - def execute(self): - program.mem_i[self.args[1]] = self.args[0].value - # must have seperate instructions because address is always modp @base.vectorize class ldmci(base.ReadMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """ code = base.opcodes['LDMCI'] arg_format = ['cw','ci'] - - def execute(self): - self.args[0].value = program.mem_c[self.args[1].value] @base.vectorize class ldmsi(base.ReadMemoryInstruction): @@ -129,53 +102,35 @@ class ldmsi(base.ReadMemoryInstruction): code = base.opcodes['LDMSI'] arg_format = ['sw','ci'] - def execute(self): - self.args[0].value = program.mem_s[self.args[1].value] - @base.vectorize class stmci(base.WriteMemoryInstruction): r""" Sets \verb+C[cj]+ to be the value $c_i$. """ code = base.opcodes['STMCI'] arg_format = ['c','ci'] - def execute(self): - program.mem_c[self.args[1].value] = self.args[0].value - @base.vectorize class stmsi(base.WriteMemoryInstruction): r""" Sets \verb+S[cj]+ to be the value $s_i$. """ code = base.opcodes['STMSI'] arg_format = ['s','ci'] - def execute(self): - program.mem_s[self.args[1].value] = self.args[0].value - @base.vectorize class ldminti(base.ReadMemoryInstruction): r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """ code = base.opcodes['LDMINTI'] arg_format = ['ciw','ci'] - def execute(self): - self.args[0].value = program.mem_i[self.args[1].value] - @base.vectorize class stminti(base.WriteMemoryInstruction): r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """ code = base.opcodes['STMINTI'] arg_format = ['ci','ci'] - def execute(self): - program.mem_i[self.args[1].value] = self.args[0].value - @base.vectorize class gldmci(base.ReadMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """ code = base.opcodes['LDMCI'] + 0x100 arg_format = ['cgw','ci'] - - def execute(self): - self.args[0].value = program.mem_c[self.args[1].value] @base.vectorize class gldmsi(base.ReadMemoryInstruction): @@ -183,27 +138,18 @@ class gldmsi(base.ReadMemoryInstruction): code = base.opcodes['LDMSI'] + 0x100 arg_format = ['sgw','ci'] - def execute(self): - self.args[0].value = program.mem_s[self.args[1].value] - @base.vectorize class gstmci(base.WriteMemoryInstruction): r""" Sets \verb+C[cj]+ to be the value $c_i$. """ code = base.opcodes['STMCI'] + 0x100 arg_format = ['cg','ci'] - def execute(self): - program.mem_c[self.args[1].value] = self.args[0].value - @base.vectorize class gstmsi(base.WriteMemoryInstruction): r""" Sets \verb+S[cj]+ to be the value $s_i$. """ code = base.opcodes['STMSI'] + 0x100 arg_format = ['sg','ci'] - def execute(self): - program.mem_s[self.args[1].value] = self.args[0].value - @base.gf2n @base.vectorize class protectmems(base.Instruction): @@ -233,9 +179,6 @@ class movc(base.Instruction): code = base.opcodes['MOVC'] arg_format = ['cw','c'] - def execute(self): - self.args[0].value = self.args[1].value - @base.gf2n @base.vectorize class movs(base.Instruction): @@ -244,9 +187,6 @@ class movs(base.Instruction): code = base.opcodes['MOVS'] arg_format = ['sw','s'] - def execute(self): - self.args[0].value = self.args[1].value - @base.vectorize class movint(base.Instruction): r""" Assigns register $ci_i$ the value in the register $ci_j$. """ @@ -452,9 +392,6 @@ class divc(base.InvertInstruction): __slots__ = [] code = base.opcodes['DIVC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = self.args[1].value * pow(self.args[2].value, program.P-2, program.P) % program.P @base.gf2n @base.vectorize @@ -464,9 +401,6 @@ class modc(base.Instruction): code = base.opcodes['MODC'] arg_format = ['cw','c','c'] - def execute(self): - self.args[0].value = self.args[1].value % self.args[2].value - @base.vectorize class inv2m(base.InvertInstruction): __slots__ = [] @@ -498,9 +432,6 @@ class andc(base.Instruction): __slots__ = [] code = base.opcodes['ANDC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = (self.args[1].value & self.args[2].value) % program.P @base.gf2n @base.vectorize @@ -509,9 +440,6 @@ class orc(base.Instruction): __slots__ = [] code = base.opcodes['ORC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = (self.args[1].value | self.args[2].value) % program.P @base.gf2n @base.vectorize @@ -520,9 +448,6 @@ class xorc(base.Instruction): __slots__ = [] code = base.opcodes['XORC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = (self.args[1].value ^ self.args[2].value) % program.P @base.vectorize class notc(base.Instruction): @@ -530,9 +455,6 @@ class notc(base.Instruction): __slots__ = [] code = base.opcodes['NOTC'] arg_format = ['cw','c', 'int'] - - def execute(self): - self.args[0].value = (~self.args[1].value + 2 ** self.args[2]) % program.P @base.vectorize class gnotc(base.Instruction): @@ -544,9 +466,6 @@ class gnotc(base.Instruction): def is_gf2n(self): return True - def execute(self): - self.args[0].value = ~self.args[1].value - @base.vectorize class gbitdec(base.Instruction): r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """ @@ -672,8 +591,6 @@ class divci(base.InvertInstruction, base.ClearImmediate): r""" Clear division by immediate value $c_i=c_j/n$. """ __slots__ = [] code = base.opcodes['DIVCI'] - def execute(self): - self.args[0].value = self.args[1].value * pow(self.args[2], program.P-2, program.P) % program.P @base.gf2n @base.vectorize @@ -719,9 +636,6 @@ class shlc(base.Instruction): __slots__ = [] code = base.opcodes['SHLC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = (self.args[1].value << self.args[2].value) % program.P @base.gf2n @base.vectorize @@ -730,9 +644,6 @@ class shrc(base.Instruction): __slots__ = [] code = base.opcodes['SHRC'] arg_format = ['cw','c','c'] - - def execute(self): - self.args[0].value = (self.args[1].value >> self.args[2].value) % program.P @base.gf2n @base.vectorize @@ -764,11 +675,6 @@ class triple(base.DataInstruction): code = base.opcodes['TRIPLE'] arg_format = ['sw','sw','sw'] data_type = 'triple' - - def execute(self): - self.args[0].value = randint(0,program.P) - self.args[1].value = randint(0,program.P) - self.args[2].value = (self.args[0].value * self.args[1].value) % program.P @base.vectorize class gbittriple(base.DataInstruction): @@ -804,9 +710,6 @@ class bit(base.DataInstruction): code = base.opcodes['BIT'] arg_format = ['sw'] data_type = 'bit' - - def execute(self): - self.args[0].value = randint(0,1) @base.vectorize class dabit(base.DataInstruction): @@ -848,10 +751,6 @@ class square(base.DataInstruction): code = base.opcodes['SQUARE'] arg_format = ['sw','sw'] data_type = 'square' - - def execute(self): - self.args[0].value = randint(0,program.P) - self.args[1].value = (self.args[0].value * self.args[0].value) % program.P @base.gf2n @base.vectorize @@ -868,11 +767,6 @@ def __init__(self, *args, **kwargs): raise CompilerError('random inverse in ring not implemented') base.DataInstruction.__init__(self, *args, **kwargs) - def execute(self): - self.args[0].value = randint(0,program.P) - import gmpy - self.args[1].value = int(gmpy.invert(self.args[0].value, program.P)) - @base.gf2n @base.vectorize class inputmask(base.Instruction): @@ -920,8 +814,6 @@ def add_usage(self, req_node): for player in self.args[1::2]: req_node.increment((self.field_type, 'input', player), \ self.get_size()) - def execute(self): - self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P @base.vectorize class inputfix(base.TextInputInstruction): @@ -1006,46 +898,18 @@ def add_usage(self, req_node): req_node.increment((self.field_type, 'input', 0), float('inf')) @base.gf2n -class startinput(base.RawInputInstruction): +class rawinput(base.RawInputInstruction, base.Mergeable): r""" Receive inputs from player $p$. """ __slots__ = [] - code = base.opcodes['STARTINPUT'] - arg_format = ['p', 'int'] + code = base.opcodes['RAWINPUT'] + arg_format = tools.cycle(['p','sw']) field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[0]), \ - self.args[1]) - - def merge(self, other): - self.args[1] += other.args[1] - -class StopInputInstruction(base.RawInputInstruction): - __slots__ = [] - - def merge(self, other): - if self.get_size() != other.get_size(): - raise NotImplemented() - else: - self.args += other.args[1:] - -class stopinput(StopInputInstruction): - r""" Receive inputs from player $p$ and put in registers. """ - __slots__ = [] - code = base.opcodes['STOPINPUT'] - arg_format = tools.chain(['p'], itertools.repeat('sw')) - - def has_var_args(self): - return True - -class gstopinput(StopInputInstruction): - r""" Receive inputs from player $p$ and put in registers. """ - __slots__ = [] - code = 0x100 + base.opcodes['STOPINPUT'] - arg_format = tools.chain(['p'], itertools.repeat('sgw')) - - def has_var_args(self): - return True + for i in range(0, len(self.args), 2): + player = self.args[i] + req_node.increment((self.field_type, 'input', player), \ + self.get_size()) @base.gf2n @base.vectorize @@ -1054,9 +918,6 @@ class print_mem(base.IOInstruction): __slots__ = [] code = base.opcodes['PRINTMEM'] arg_format = ['c'] - - def execute(self): - pass @base.gf2n @base.vectorize @@ -1069,9 +930,6 @@ class print_reg(base.IOInstruction): def __init__(self, reg, comment=''): super(print_reg_class, self).__init__(reg, self.str_to_int(comment)) - def execute(self): - pass - @base.gf2n @base.vectorize class print_reg_plain(base.IOInstruction): @@ -1238,41 +1096,6 @@ class acceptclientconnection(base.IOInstruction): code = base.opcodes['ACCEPTCLIENTCONNECTION'] arg_format = ['ciw', 'int'] -class connectipv4(base.IOInstruction): - """Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|""" - __slots__ = [] - code = base.opcodes['CONNECTIPV4'] - arg_format = ['ciw', 'ci', 'int'] - -class readclientpublickey(base.IOInstruction): - """Read a client public key as 8 32-bit ints for a specified client id""" - __slots__ = [] - code = base.opcodes['READCLIENTPUBLICKEY'] - arg_format = tools.chain(['ci'], itertools.repeat('ci')) - - def has_var_args(self): - return True - -class initsecuresocket(base.IOInstruction): - """Read a client public key as 8 32-bit ints for a specified client id, - negotiate a shared key via STS and use it for replay resistant comms""" - __slots__ = [] - code = base.opcodes['INITSECURESOCKET'] - arg_format = tools.chain(['ci'], itertools.repeat('ci')) - - def has_var_args(self): - return True - -class respsecuresocket(base.IOInstruction): - """Read a client public key as 8 32-bit ints for a specified client id, - negotiate a shared key via STS and use it for replay resistant comms""" - __slots__ = [] - code = base.opcodes['RESPSECURESOCKET'] - arg_format = tools.chain(['ci'], itertools.repeat('ci')) - - def has_var_args(self): - return True - class writesharestofile(base.IOInstruction): """Write shares to a file""" __slots__ = [] @@ -1392,12 +1215,6 @@ class eqzc(base.UnaryComparisonInstruction): r""" Clear comparison $c_i = (c_j \stackrel{?}{==} 0)$. """ __slots__ = [] code = base.opcodes['EQZC'] - - def execute(self): - if self.args[1].value == 0: - self.args[0].value = 1 - else: - self.args[0].value = 0 @base.vectorize class ltzc(base.UnaryComparisonInstruction): @@ -1435,9 +1252,6 @@ class jmp(base.JumpInstruction): arg_format = ['int'] jump_arg = 0 - def execute(self): - pass - class jmpi(base.JumpInstruction): """ Unconditional relative jump of $c_i+1$ instructions. """ __slots__ = [] @@ -1457,9 +1271,6 @@ class jmpnz(base.JumpInstruction): code = base.opcodes['JMPNZ'] arg_format = ['ci', 'int'] jump_arg = 1 - - def execute(self): - pass class jmpeqz(base.JumpInstruction): r""" Jump $n+1$ instructions if $c_i == 0$. """ @@ -1467,9 +1278,6 @@ class jmpeqz(base.JumpInstruction): code = base.opcodes['JMPEQZ'] arg_format = ['ci', 'int'] jump_arg = 1 - - def execute(self): - pass ### ### Conversions diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 70514eb62..373a9e5fe 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -115,6 +115,7 @@ INPUTFLOAT = 0xF1, INPUTMIXED = 0xF2, INPUTMIXEDREG = 0xF3, + RAWINPUT = 0xF4, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -688,15 +689,12 @@ def __init__(self, *args, **kwargs): self.args = list(args) if not kwargs.get('copying', False): self.check_args() - if not program.FIRST_PASS: - if kwargs.get('add_to_prog', True): - program.curr_block.instructions.append(self) - if program.DEBUG: - self.caller = [frame[1:] for frame in inspect.stack()[1:]] - else: - self.caller = None - if program.EMULATE: - self.execute() + if kwargs.get('add_to_prog', True): + program.curr_block.instructions.append(self) + if program.DEBUG: + self.caller = [frame[1:] for frame in inspect.stack()[1:]] + else: + self.caller = None Instruction.count += 1 if Instruction.count % 100000 == 0: @@ -717,10 +715,6 @@ def get_encoding(self): def get_bytes(self): return bytearray(self.get_encoding()) - def execute(self): - """ Emulate execution of this instruction """ - raise NotImplementedError('execute method must be implemented') - def check_args(self): """ Check the args match up with that specified in arg_format """ try: @@ -839,21 +833,12 @@ def get_code(self): class AddBase(Instruction): __slots__ = [] - def execute(self): - self.args[0].value = (self.args[1].value + self.args[2].value) % program.P - class SubBase(Instruction): __slots__ = [] - def execute(self): - self.args[0].value = (self.args[1].value - self.args[2].value) % program.P - class MulBase(Instruction): __slots__ = [] - def execute(self): - self.args[0].value = (self.args[1].value * self.args[2].value) % program.P - ### ### Basic arithmetic with immediate values ### @@ -861,9 +846,6 @@ def execute(self): class ImmediateBase(Instruction): __slots__ = ['op'] - def execute(self): - exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op) - class SharedImmediate(ImmediateBase): __slots__ = [] arg_format = ['sw', 's', 'i'] @@ -1023,10 +1005,7 @@ class CISC(Instruction): def __init__(self, *args): self.args = args self.check_args() - #if EMULATE: - # self.expand() - if not program.FIRST_PASS: - self.expand() + self.expand() def expand(self): """ Expand this into a sequence of RISC instructions. """ diff --git a/Compiler/library.py b/Compiler/library.py index e28b3f7e2..a65870430 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -342,8 +342,10 @@ def wrapped_function(*compile_args): x.reg_type))) runtime_args = [None] * len(args) for t in sorted(type_args, key=lambda x: x.reg_type): - for i,i_arg in enumerate(type_args[t]): + i = 0 + for i_arg in type_args[t]: runtime_args[i_arg] = t.load_mem(bases[t] + i) + i += util.mem_size(t) return self.function(*(list(compile_args) + runtime_args)) self.on_first_call(wrapped_function) self.type_args[len(args)] = type_args @@ -354,10 +356,12 @@ def wrapped_function(*compile_args): for i,reg_type in enumerate(sorted(type_args, key=lambda x: x.reg_type)): store_in_mem(bases[reg_type], base + i) - for j,i_arg in enumerate(type_args[reg_type]): + j = 0 + for i_arg in type_args[reg_type]: if get_reg_type(args[i_arg]) != reg_type: raise CompilerError('type mismatch') store_in_mem(args[i_arg], bases[reg_type] + j) + j += util.mem_size(reg_type) return self.on_call(base, bases) class FunctionTape(Function): diff --git a/Compiler/ml.py b/Compiler/ml.py index d5557981c..528c51cb4 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -60,27 +60,21 @@ def sigmoid_prime(x): return sx * (1 - sx) @vectorize -def approx_sigmoid(x): - if approx_sigmoid.special and \ - get_program().options.ring and get_program().use_edabit(): - l = int(get_program().options.ring) - r, r_bits = sint.get_edabit(x.k, False) - c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k) - c_bits = c.bit_decompose(x.k) - lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1]) - higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:], - lower_overflow) - sign = higher_bits[-1] - higher_bits.pop(-1) - aa = sign & ~util.tree_reduce(operator.and_, higher_bits) - bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits]) - a, b = (sint.conv(x) for x in (aa, bb)) +def approx_sigmoid(x, n=3): + if n == 5: + cuts = [-5, -2.5, 2.5, 5] + le = [0] + [x <= cut for cut in cuts] + [1] + select = [le[i + 1] - le[i] for i in range(5)] + outputs = [cfix(10 ** -4), + 0.02776 * x + 0.145, + 0.17 * x + 0.5, + 0.02776 * x + 0.85498, + cfix(1 - 10 ** -4)] + return sum(a * b for a, b in zip(select, outputs)) else: a = x < -0.5 b = x > 0.5 - return a.if_else(0, b.if_else(1, 0.5 + x)) - -approx_sigmoid.special = False + return a.if_else(0, b.if_else(1, 0.5 + x)) def lse_0_from_e_x(x, e_x): return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0) @@ -144,7 +138,7 @@ def _(base, size): def eval(self, size, base=0): if self.approx: - return approx_sigmoid(self.X.get_vector(base, size)) + return approx_sigmoid(self.X.get_vector(base, size), self.approx) else: return sigmoid_from_e_x(self.X.get_vector(base, size), self.e_x.get_vector(base, size)) @@ -531,7 +525,7 @@ def n_summands(self): _, inputs_h, inputs_w, n_channels_in = self.input_shape return weights_h * weights_w * n_channels_in - def forward(self, batch): + def forward(self, batch=[None]): assert len(batch) == 1 assert(self.weight_shape[0] == self.output_shape[-1]) diff --git a/Compiler/program.py b/Compiler/program.py index 7147adc74..d9794e2c3 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -39,11 +39,11 @@ class Program(object): These are created by executing a file containing appropriate instructions and threads. """ - def __init__(self, args, options, param=-1, assemblymode=False): + def __init__(self, args, options, param=-1): self.options = options self.verbose = options.verbose self.args = args - self.init_names(args, assemblymode) + self.init_names(args) self.P = P_VALUES[param] self.param = param if (param != -1) + sum(x != 0 for x in(options.ring, options.field, @@ -65,8 +65,6 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.tape_counter = 0 self.tapes = [] self._curr_tape = None - self.EMULATE = True # defaults - self.FIRST_PASS = False self.DEBUG = False self.main_thread_running = False self.allocated_mem = RegType.create_dict(lambda: USER_MEM) @@ -102,9 +100,8 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.use_dabit = options.mixed self._edabit = options.edabit self._split = False + self._square = False Program.prog = self - - self.reset_values() def get_args(self): return self.args @@ -131,7 +128,7 @@ def max_par_tapes(self): res = max(res, sum(running.values())) return res - def init_names(self, args, assemblymode): + def init_names(self, args): # ignore path to file - source must be in Programs/Source if 'Programs' in os.listdir(os.getcwd()): # compile prog in ./Programs/Source directory @@ -153,8 +150,6 @@ def init_names(self, args, assemblymode): if os.path.exists(args[0]): self.infile = args[0] - elif assemblymode: - self.infile = self.programs_dir + '/Source/' + progname + '.asm' else: self.infile = self.programs_dir + '/Source/' + progname + '.mpc' """ @@ -234,40 +229,6 @@ def update_req(self, tape): else: self.req_num += tape.req_num - def read_memory(self, filename): - """ Read the clear and shared memory from a file """ - f = open(filename) - n = int(next(f)) - self.mem_c = [0]*n - self.mem_s = [0]*n - mem = self.mem_c - done_c = False - for line in f: - line = line.split(' ') - a = int(line[0]) - b = int(line[1]) - if a != -1: - mem[a] = b - elif done_c: - break - else: - mem = self.mem_s - done_c = True - - def get_memory(self, mem_type, i): - if mem_type == 'c': - return self.mem_c[i] - elif mem_type == 's': - return self.mem_s[i] - raise CompilerError('Invalid memory type') - - def reset_values(self): - """ Reset register and memory values. """ - for tape in self.tapes: - tape.reset_registers() - self.mem_c = list(range(USER_MEM + TMP_MEM)) - self.mem_s = list(range(USER_MEM + TMP_MEM)) - def write_bytes(self, outfile=None): """ Write all non-empty threads and schedule to files. """ # runtime doesn't support 'new-style' parallelism yet @@ -329,17 +290,6 @@ def finalize_tape(self, tape): tape.write_str(self.options.asmoutfile + '-' + tape.name) tape.purge() - def emulate(self): - """ Emulate execution of entire program. """ - self.reset_values() - for sch in self.schedule: - if sch[0] == 'start': - for tape in sch[1]: - self._curr_tape = tape - for block in tape.basicblocks: - for line in block.instructions: - line.execute() - def restart_main_thread(self): if self.main_thread_running: # wait for main thread to finish @@ -375,6 +325,10 @@ def malloc(self, size, mem_type, reg_type=None): if size == 0: return if isinstance(mem_type, type): + try: + size *= math.ceil(mem_type.n / mem_type.unit) + except AttributeError: + pass self.types[mem_type.reg_type] = mem_type mem_type = mem_type.reg_type elif reg_type is not None: @@ -447,6 +401,12 @@ def use_split(self, change=None): assert change in (2, 3) self._split = change + def use_square(self, change=None): + if change is None: + return self._square + else: + self._square = change + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -588,7 +548,6 @@ def start_new_basicblock(self, scope=False, name=''): #print 'Compiling basic block', sub.name def init_registers(self): - self.reset_registers() self.reg_counter = RegType.create_dict(lambda: 0) def init_names(self, name): @@ -605,7 +564,6 @@ def purge(self): for block in self.basicblocks: block.purge() self._is_empty = (len(self.basicblocks) == 0) - del self.reg_values del self.basicblocks del self.active_basicblock self.purged = True @@ -840,13 +798,6 @@ def count_regs(self, reg_type=None): else: return self.reg_counter[reg_type] - def reset_registers(self): - """ Reset register values to zero. """ - self.reg_values = RegType.create_dict(lambda: []) - - def get_value(self, reg_type, i): - return self.reg_values[reg_type][i] - def __str__(self): return self.name @@ -883,12 +834,28 @@ def max(self, other): def cost(self): return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ if req[1] != 'input' and req[0] != 'edabit') + def pretty(self): + t = lambda x: 'integer' if x == 'modp' else x + res = [] + for req, num in self.items(): + domain = t(req[0]) + n = '%12.0f' % num + if req[1] == 'input': + res += ['%s %s inputs from player %d' \ + % (n, domain, req[2])] + elif domain.endswith('edabit'): + if domain == 'sedabit': + eda = 'strict edabits' + else: + eda = 'loose edabits' + res += ['%s %s of length %d' % (n, eda, req[1])] + elif req[0] != 'all': + res += ['%s %s %ss' % (n, domain, req[1])] + if self['all','round']: + res += ['% 12.0f virtual machine rounds' % self['all','round']] + return res def __str__(self): - return ", ".join('%s inputs in %s from player %d' \ - % (num, req[0], req[2]) \ - if req[1] == 'input' \ - else '%s %ss in %s' % (num, req[1], req[0]) \ - for req,num in list(self.items())) + return ', '.join(self.pretty()) def __repr__(self): return repr(dict(self)) @@ -959,14 +926,12 @@ class Register(object): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. - - The 'value' property is for emulation. """ __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ "size", "vector", "vectorbase", "caller", \ "can_eliminate"] - def __init__(self, reg_type, program, value=None, size=None, i=None): + def __init__(self, reg_type, program, size=None, i=None): """ Creates a new register. reg_type must be one of those defined in RegType. """ if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': @@ -989,8 +954,6 @@ def __init__(self, reg_type, program, value=None, size=None, i=None): else: self.i = float('inf') self.vector = [] - if value is not None: - self.value = value self.can_eliminate = True if Program.prog.DEBUG: self.caller = [frame[1:] for frame in inspect.stack()[1:]] @@ -1010,22 +973,9 @@ def i(self, value): def set_size(self, size): if self.size == size: return - elif not self.program.program.options.assemblymode: + else: raise CompilerError('Mismatch of instruction and register size:' ' %s != %s' % (self.size, size)) - elif self.size == 1 and self.vectorbase is self: - if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS: - # create vector register in assembly mode - self.size = size - self.vector = [self] - for i in range(1,size): - reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)] - reg.set_vectorbase(self) - self.vector.append(reg) - else: - raise CompilerError('Cannot find %s in VARS' % str(self)) - else: - raise CompilerError('Cannot reset size of vector register') def set_vectorbase(self, vectorbase): if self.vectorbase is not self: @@ -1074,16 +1024,6 @@ def __len__(self): def copy(self): return Tape.Register(self.reg_type, Program.prog.curr_tape) - @property - def value(self): - return self.program.reg_values[self.reg_type][self.i] - - @value.setter - def value(self, val): - while (len(self.program.reg_values[self.reg_type]) <= self.i): - self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX - self.program.reg_values[self.reg_type][self.i] = val - @property def is_gf2n(self): return self.reg_type == RegType.ClearGF2N or \ diff --git a/Compiler/types.py b/Compiler/types.py index a13be46b2..c31bf730c 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -232,7 +232,7 @@ def inputmixed(*args): if isinstance(args[-1], int): instructions.inputmixed(*args) else: - instructions.inputmixedreg(*args) + instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),))) class _number(object): """ Number functionality. """ @@ -762,7 +762,11 @@ def __init__(self, val=None, size=None): def load_int(self, val): if val: # +1 for sign - program.curr_tape.require_bit_length(1 + int(math.ceil(math.log(abs(val))))) + bit_length = 1 + int(math.ceil(math.log(abs(val)))) + if program.options.ring: + assert(bit_length <= int(program.options.ring)) + elif program.param != -1 or program.options.field: + program.curr_tape.require_bit_length(bit_length) if self.in_immediate_range(val): ldi(self, val) else: @@ -783,7 +787,7 @@ def load_int(self, val): sum += sign * chunk @vectorize - def to_regint(self, n_bits=None, dest=None): + def to_regint(self, n_bits=64, dest=None): """ Convert to regint. :param n_bits: bit length (int) @@ -1146,23 +1150,6 @@ def read_from_socket(cls, client_id, n=1): else: return res - @vectorized_classmethod - def read_client_public_key(cls, client_id): - """ Receive 8 register values from socket containing client public key.""" - res = [cls() for i in range(8)] - readclientpublickey(client_id, *res) - return res - - @vectorized_classmethod - def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): - """ Use 8 register values containing client public key.""" - initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) - - @vectorized_classmethod - def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): - """ Receive 8 register values from socket containing client public key.""" - respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) - @vectorize def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): writesocketint(client_id, message_type, self) @@ -1439,7 +1426,7 @@ def protect_memory(cls, start, end): def get_input_from(cls, player): """ Secret input from player. - :param: player (compile-time int) """ + :param player: public (regint/cint/int) """ res = cls() asm_input(res, player) return res @@ -1648,9 +1635,12 @@ def __rtruediv__(self, other): @vectorize def square(self): """ Secret square. """ - res = self.__class__() - sqrs(res, self) - return res + if program.use_square(): + res = self.__class__() + sqrs(res, self) + return res + else: + return self * self @set_instruction_type @vectorize @@ -1712,7 +1702,7 @@ def get_random_int(cls, bits): def get_input_from(cls, player): """ Secret input. - :param player: compile-time integer (int) """ + :param player: public (regint/cint/int) """ res = cls() inputmixed('int', res, player) return res @@ -1757,8 +1747,7 @@ def bit_decompose_clear(a, n_bits): @classmethod def get_raw_input_from(cls, player): res = cls() - startinput(player, 1) - stopinput(player, res) + rawinput(player, res) return res @classmethod @@ -2056,8 +2045,7 @@ def get_type(cls, length): @classmethod def get_raw_input_from(cls, player): res = cls() - gstartinput(player, 1) - gstopinput(player, res) + grawinput(player, res) return res def add(self, other): @@ -3293,7 +3281,7 @@ class sfix(_fix): def get_input_from(cls, player): """ Secret fixed-point input. - :param player: int """ + :param player: public (regint/cint/int) """ v = cls.int_type() inputmixed('fix', v, cls.f, player) return cls._new(v) @@ -3674,7 +3662,7 @@ def convert_float(v, vlen, plen): def get_input_from(cls, player): """ Secret floating-point input. - :param player: int """ + :param player: public (regint/cint/int) """ v = sint() p = sint() z = sint() @@ -4195,7 +4183,7 @@ def get_mem_value(self, index): def input_from(self, player, budget=None): """ Fill with inputs from player if supported by type. - :param player: compile-time (int) """ + :param player: public (regint/cint/int) """ self.assign(self.value_type.get_input_from(player, size=len(self))) def __add__(self, other): @@ -4351,7 +4339,7 @@ def same_shape(self): def input_from(self, player, budget=None): """ Fill with inputs from player if supported by type. - :param player: compile-time (int) """ + :param player: public (regint/cint/int) """ @library.for_range_opt(self.sizes[0], budget=budget) def _(i): self[i].input_from(player, budget=budget) @@ -4597,6 +4585,12 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + def set_column(self, index, vector): + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.sizes[1]) + vector.store_in_mem(addresses) + class VectorArray(object): def __init__(self, length, value_type, vector_size, address=None): self.array = Array(length * vector_size, value_type, address) diff --git a/Compiler/util.py b/Compiler/util.py index 57e875fe7..110596d0a 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -182,6 +182,12 @@ def expand(x, size): except AttributeError: return x +def mem_size(x): + try: + return x.mem_size() + except AttributeError: + return 1 + class set_by_id(object): def __init__(self, init=[]): self.content = {} diff --git a/ExternalIO/README.md b/ExternalIO/README.md index b23bbc2e6..504a9a7e2 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -1,4 +1,32 @@ -The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md). +The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md). + +## Working Examples + +[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a +client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) +and demonstrates sending input and receiving output as described by +[Damgård et al.](https://eprint.iacr.org/2015/1006) The computation +allows up to eight clients to input a number and computes the client +with the largest input. You can run it as follows from the main +directory: +``` +make bankers-bonus-client.x +./compile.py bankers_bonus 1 +Scripts/setup-ssl.sh +Scripts/setup-clients.sh 3 +Scripts/.sh & +./bankers-bonus-client.x 0 200 0 & +./bankers-bonus-client.x 2 50 1 +``` +This should output that the winning id is 1. Note that the ids have to +be incremental, and the client with the highest id has to input 1 as +the last argument while the others have to input 0 there. Furthermore, +`` refers to the number of parties running the computation +not the number of clients, and `` can be the name of +protocol script. The setup scripts generate the necessary SSL +certificates and keys. Therefore, if you run the computation on +different hosts, you will have to distribute the `*.pem` files. ## I/O MPC Instructions @@ -55,49 +83,3 @@ Receive shares of private inputs from a client, blocking on client send. This is *message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. *[inputs]* - returned list of shares of private input. - - -## Securing communications - -Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness: - -1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering. -2. Authenticated Diffie-Hellman without message ordering. - - Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes. - -[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols. - -#### MPC instructions - -**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) - -STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. - -*client_socket_id* - an identifier used to refer to the client socket. - -*public_signing_key* - client public key supplied as list of 8 32-bit ints. - -**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) - -STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. - -*client_socket_id* - an identifier used to refer to the client socket. - -*public_signing_key* - client public key supplied as list of 8 32-bit ints. - -*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*) - -Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted. - -*client_socket_id* - an identifier used to refer to the client socket. - -*public_key* - client public key made available to mpc programs as list of 8 32-bit ints. - -## Working Examples - -See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security. - -See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols. - -More instructions on how to run these are provided in the *-client files. diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index db5f2f1d4..cf5cdf440 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -16,11 +16,10 @@ * - share of random value [r] * - share of winning unique id * random value [w] * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] - * - * No communications security is used. * * To run with 2 parties / SPDZ engines: * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * ./Scripts/setup-clients.sh to create SSL keys and certificates for clients * ./compile.py bankers_bonus * ./Scripts/run-online.sh bankers_bonus to run the engines. * @@ -34,6 +33,7 @@ #include "Math/gfp.h" #include "Math/gf2n.h" #include "Networking/sockets.h" +#include "Networking/ssl_sockets.h" #include "Tools/int.h" #include "Math/Setup.h" #include "Protocols/fake-stuff.h" @@ -46,12 +46,13 @@ // Send the private inputs masked with a random value. // Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. // Add the private input value to triple[0] and send to each spdz engine. -void send_private_inputs(const vector& values, vector& sockets, int nparties) +template +void send_private_inputs(const vector& values, vector& sockets, int nparties) { int num_inputs = values.size(); octetStream os; - vector< vector > triples(num_inputs, vector(3)); - vector triple_shares(3); + vector< vector > triples(num_inputs, vector(3)); + vector triple_shares(3); // Receive num_inputs triples from SPDZ for (int j = 0; j < nparties; j++) @@ -59,6 +60,10 @@ void send_private_inputs(const vector& values, vector& sockets, int np os.reset_write_head(); os.Receive(sockets[j]); +#ifdef VERBOSE_COMM + cerr << "received " << os.get_length() << " from " << j << endl; +#endif + for (int j = 0; j < num_inputs; j++) { for (int k = 0; k < 3; k++) @@ -72,49 +77,30 @@ void send_private_inputs(const vector& values, vector& sockets, int np // Check triple relations (is a party cheating?) for (int i = 0; i < num_inputs; i++) { - if (triples[i][0] * triples[i][1] != triples[i][2]) + if (T(triples[i][0] * triples[i][1]) != triples[i][2]) { + cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl; cerr << "Incorrect triple at " << i << ", aborting\n"; - exit(1); + throw mac_fail(); } } // Send inputs + triple[0], so SPDZ can compute shares of each value os.reset_write_head(); for (int i = 0; i < num_inputs; i++) { - gfp y = values[i] + triples[i][0]; + T y = values[i] + triples[i][0]; y.pack(os); } for (int j = 0; j < nparties; j++) os.Send(sockets[j]); } -// Assumes that Scripts/setup-online.sh has been run to compute prime -void initialise_fields(const string& dir_prefix) -{ - int lg2; - bigint p; - - string filename = dir_prefix + "Params-Data"; - cout << "loading params from: " << filename << endl; - - ifstream inpf(filename.c_str()); - if (inpf.fail()) { throw file_error(filename.c_str()); } - inpf >> p; - inpf >> lg2; - - inpf.close(); - - gfp::init_field(p); - gf2n::init_field(lg2); -} - - // Receive shares of the result and sum together. // Also receive authenticating values. -gfp receive_result(vector& sockets, int nparties) +template +T receive_result(vector& sockets, int nparties) { - vector output_values(3); + vector output_values(3); octetStream os; for (int i = 0; i < nparties; i++) { @@ -122,20 +108,32 @@ gfp receive_result(vector& sockets, int nparties) os.Receive(sockets[i]); for (unsigned int j = 0; j < 3; j++) { - gfp value; + T value; value.unpack(os); output_values[j] += value; } } - if (output_values[0] * output_values[1] != output_values[2]) + if (T(output_values[0] * output_values[1]) != output_values[2]) { cerr << "Unable to authenticate output value as correct, aborting." << endl; - exit(1); + throw mac_fail(); } return output_values[0]; } +template +void run(int salary_value, vector& sockets, int nparties) +{ + // Run the computation + send_private_inputs({salary_value}, sockets, nparties); + cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; + + // Get the result back (client_id of winning client) + T result = receive_result(sockets, nparties); + + cout << "Winning client id is : " << result << endl; +} int main(int argc, char** argv) { @@ -162,34 +160,65 @@ int main(int argc, char** argv) if (argc > 6) port_base = atoi(argv[6]); - // init static gfp - string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree()); - initialise_fields(prep_data_prefix); bigint::init_thread(); // Setup connections from this client to each party socket - vector sockets(nparties); + vector plain_sockets(nparties); + vector sockets(nparties); + ssl_ctx ctx("C" + to_string(my_client_id)); + ssl_service io_service; + octetStream specification; for (int i = 0; i < nparties; i++) { - set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); - send(sockets[i], (octet*) &my_client_id, sizeof(int)); + set_up_client_socket(plain_sockets[i], host_name.c_str(), port_base + i); + send(plain_sockets[i], (octet*) &my_client_id, sizeof(int)); + sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], + "P" + to_string(i), "C" + to_string(my_client_id), true); + if (i == 0) + specification.Receive(sockets[0]); octetStream os; os.store(finish); os.Send(sockets[i]); } cout << "Finish setup socket connections to SPDZ engines." << endl; - // Run the commputation - send_private_inputs({salary_value}, sockets, nparties); - cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; - - // Get the result back (client_id of winning client) - gfp result = receive_result(sockets, nparties); + int type = specification.get(); + switch (type) + { + case 'p': + { + gfp::init_field(specification.get()); + cerr << "using prime " << gfp::pr() << endl; + run(salary_value, sockets, nparties); + break; + } + case 'R': + { + int R = specification.get(); + switch (R) + { + case 64: + run>(salary_value, sockets, nparties); + break; + case 104: + run>(salary_value, sockets, nparties); + break; + case 128: + run>(salary_value, sockets, nparties); + break; + default: + cerr << R << "-bit ring not implemented"; + exit(1); + } + break; + } + default: + cerr << "Type " << type << " not implemented"; + exit(1); + } - cout << "Winning client id is : " << result << endl; - - for (unsigned int i = 0; i < sockets.size(); i++) - close_client_socket(sockets[i]); + for (int i = 0; i < nparties; i++) + delete sockets[i]; return 0; } diff --git a/ExternalIO/bankers-bonus-commsec-client.cpp b/ExternalIO/bankers-bonus-commsec-client.cpp deleted file mode 100644 index 4a250880b..000000000 --- a/ExternalIO/bankers-bonus-commsec-client.cpp +++ /dev/null @@ -1,407 +0,0 @@ -/* - * Demonstrate external client inputing and receiving outputs from a SPDZ process, - * following the protocol described in https://eprint.iacr.org/2015/1006.pdf. - * Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp - * for a simpler client with no crypto. - * - * Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on - * the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running - * the bankers_bonus.mpc program. - * - * Each connecting client: - * - sends an increasing id to identify the client, starting with 0 - * - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result). - * - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security. - * - sends an integer input (bonus value to compare) - * - * The result is returned authenticated with a share of a random value: - * - share of winning unique id [y] - * - share of random value [r] - * - share of winning unique id * random value [w] - * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] - * - * To run with 2 parties (SPDZ engines) and 3 external clients: - * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). - * ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients. - * ./compile.py bankers_bonus_commsec - * ./Scripts/run-online.sh bankers_bonus_commsec to run the engines. - * - * ./bankers-bonus-commsec-client.x 0 2 100 0 - * ./bankers-bonus-commsec-client.x 1 2 200 0 - * ./bankers-bonus-commsec-client.x 2 2 50 1 - * - * Expect winner to be second client with id 1. - * Note here client id must match id used in generating client key material, Client-Keys-C. - */ - -#include "Math/gfp.h" -#include "Math/gf2n.h" -#include "Networking/sockets.h" -#include "Networking/STS.h" -#include "Tools/int.h" -#include "Math/Setup.h" -#include "Protocols/fake-stuff.h" - -#include -#include -#include -#include -#include - -typedef pair< vector, vector > keypair_t; // A pair of send/recv keys for talking to SPDZ -typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server -typedef struct { - unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES]; - unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES]; - vector client_publickey_ints; - vector< vector >server_publickey; -} sign_key_container_t; - -keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) -{ - STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); - sts_msg1_t m1; - sts_msg2_t m2; - sts_msg3_t m3; - octetStream os; - - os.Receive(sockets[server_id]); - os.consume(m1.bytes, sizeof m1.bytes); - m2 = ke.recv_msg1(m1); - os.reset_write_head(); - os.append(m2.pubkey, sizeof m2.pubkey); - os.append(m2.sig, sizeof m2.sig); - os.Send(sockets[server_id]); - os.Receive(sockets[server_id]); - os.consume(m3.bytes, sizeof m3.bytes); - ke.recv_msg3(m3); - vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - return make_pair(sendKey,recvKey); -} - -keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) -{ - STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); - sts_msg1_t m1; - sts_msg2_t m2; - sts_msg3_t m3; - octetStream os; - - m1 = ke.send_msg1(); - cout << "m1: "; - for (unsigned int j = 0; j < 32; j++) - cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j]; - cout << dec << endl; - os.reset_write_head(); - os.append(m1.bytes, sizeof m1.bytes); - os.Send(sockets[server_id]); - - os.reset_write_head(); - os.Receive(sockets[server_id]); - os.consume(m2.pubkey, sizeof m2.pubkey); - os.consume(m2.sig, sizeof m2.sig); - m3 = ke.recv_msg2(m2); - - os.reset_write_head(); - os.append(m3.bytes, sizeof m3.bytes); - os.Send(sockets[server_id]); - - vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - return make_pair(sendKey,recvKey); -} - -pair< vector, vector > sts_response_role(sign_key_container_t keys, vector& sockets, int server_id) -{ - pair< vector, vector > res; - try { - res = sts_response_role_exceptions(keys, sockets, server_id); - } catch(char const *e) { - cerr << "Error in STS: " << e << endl; - exit(1); - } - return res; -} - -pair< vector, vector > sts_initiator_role(sign_key_container_t keys, vector& sockets, int server_id) -{ - pair< vector, vector > res; - try { - res = sts_initiator_role_exceptions(keys, sockets, server_id); - } catch(char const *e) { - cerr << "Error in STS: " << e << endl; - exit(1); - } - return res; -} - -// Send the private inputs masked with a random value. -// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. -// Add the private input value to triple[0] and send to each spdz engine. -void send_private_inputs(const vector& values, vector& sockets, int nparties, - commsec_t commsec, vector& keys) -{ - int num_inputs = values.size(); - octetStream os; - vector< vector > triples(num_inputs, vector(3)); - vector triple_shares(3); - - // Receive num_inputs triples from SPDZ - for (int j = 0; j < nparties; j++) - { - os.reset_write_head(); - os.Receive(sockets[j]); - os.decrypt_sequence(&commsec[j].second[0],0); - os.decrypt(keys[j]); - - for (int j = 0; j < num_inputs; j++) - { - for (int k = 0; k < 3; k++) - { - triple_shares[k].unpack(os); - triples[j][k] += triple_shares[k]; - } - } - } - // Check triple relations - for (int i = 0; i < num_inputs; i++) - { - if (triples[i][0] * triples[i][1] != triples[i][2]) - { - cerr << "Incorrect triple at " << i << ", aborting\n"; - exit(1); - } - } - // Send inputs + triple[0], so SPDZ can compute shares of each value - os.reset_write_head(); - for (int i = 0; i < num_inputs; i++) - { - gfp y = values[i] + triples[i][0]; - y.pack(os); - } - for (int j = 0; j < nparties; j++) - { - octetStream temp = os; - temp.encrypt_sequence(&commsec[j].first[0], 0); - temp.Send(sockets[j]); - } - } - -// Send public key in clear to each SPDZ engine. -void send_public_key(vector& pubkey, int socket) -{ - octetStream os; - os.reset_write_head(); - - for (unsigned int i = 0; i < pubkey.size(); i++) - { - os.store(pubkey[i]); - } - - os.Send(socket); -} - -// Assumes that Scripts/setup-online.sh has been run to compute prime -void initialise_fields(const string& dir_prefix) -{ - int lg2; - bigint p; - - string filename = dir_prefix + "Params-Data"; - cout << "loading params from: " << filename << endl; - - ifstream inpf(filename.c_str()); - if (inpf.fail()) { throw file_error(filename.c_str()); } - inpf >> p; - inpf >> lg2; - - inpf.close(); - - gfp::init_field(p); - gf2n::init_field(lg2); -} - -// Assumes that client-setup has been run to create key pairs for clients and parties -void generate_symmetric_keys(vector& keys, vector& client_public_key_ints, - sign_key_container_t *sts_key, const string& dir_prefix, int client_no) -{ - unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; - unsigned char client_secretkey[crypto_box_SECRETKEYBYTES]; - unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; - unsigned char scalarmult_q[crypto_scalarmult_BYTES]; - crypto_generichash_state h; - - // read client public/secret keys + SPDZ server public keys - ifstream keyfile; - stringstream client_filename; - client_filename << dir_prefix << "Client-Keys-C" << client_no; - keyfile.open(client_filename.str().c_str()); - if (keyfile.fail()) - throw file_error(client_filename.str()); - keyfile.read((char*)client_publickey, sizeof client_publickey); - if (keyfile.eof()) - throw end_of_file(client_filename.str(), "client public key" ); - - // Convert client public key unsigned char to int, reverse endianness - for(unsigned int j = 0; j < client_public_key_ints.size(); j++) { - int keybyte = 0; - for(unsigned int k = 0; k < 4; k++) { - keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8)); - } - client_public_key_ints[j] = keybyte; - } - - keyfile.read((char*)client_secretkey, sizeof client_secretkey); - if (keyfile.eof()) { - throw end_of_file(client_filename.str(), "client private key" ); - } - - keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES); - keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES); - // Convert client public key unsigned char to int, reverse endianness - sts_key->client_publickey_ints.resize(8); - for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) { - int keybyte = 0; - for(unsigned int k = 0; k < 4; k++) { - keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8)); - } - sts_key->client_publickey_ints[j] = keybyte; - } - - for (unsigned int i = 0; i < keys.size(); i++) - { - keys[i] = new octet[crypto_generichash_BYTES]; - keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES); - if (keyfile.eof()) - throw end_of_file(client_filename.str(), "server public key for party " + to_string(i)); - keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES); - if (keyfile.eof()) - throw end_of_file(client_filename.str(), "server public signing key for party " + to_string(i)); - - // Derive a shared key from this server's secret key and the client's public key - // shared key = h(q || client_secretkey || server_publickey) - if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) { - cerr << "Scalar mult failed\n"; - exit(1); - } - crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); - crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q); - crypto_generichash_update(&h, client_publickey, sizeof client_publickey); - crypto_generichash_update(&h, server_publickey, sizeof server_publickey); - crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES); - } - keyfile.close(); - - cout << "My public key is: "; - for (unsigned int j = 0; j < 32; j++) - cout << setfill('0') << setw(2) << hex << (int) client_publickey[j]; - cout << dec << endl; -} - - -// Receive shares of the result and sum together. -// Also receive authenticating values. -gfp receive_result(vector& sockets, int nparties, commsec_t commsec, vector& keys) -{ - vector output_values(3); - octetStream os; - for (int i = 0; i < nparties; i++) - { - os.reset_write_head(); - os.Receive(sockets[i]); - - os.decrypt_sequence(&commsec[i].second[0],1); - os.decrypt(keys[i]); - - for (unsigned int j = 0; j < 3; j++) - { - gfp value; - value.unpack(os); - output_values[j] += value; - } - } - - if (output_values[0] * output_values[1] != output_values[2]) - { - cerr << "Unable to authenticate output value as correct, aborting." << endl; - exit(1); - } - return output_values[0]; -} - - -int main(int argc, char** argv) -{ - int my_client_id; - int nparties; - int salary_value; - int finish; - int port_base = 14000; - sign_key_container_t sts_key; - string host_name = "localhost"; - - if (argc < 5) { - cout << "Usage is external-client " - << " " - << "" << endl; - exit(0); - } - - my_client_id = atoi(argv[1]); - nparties = atoi(argv[2]); - salary_value = atoi(argv[3]); - finish = atoi(argv[4]); - if (argc > 5) - host_name = argv[5]; - if (argc > 6) - port_base = atoi(argv[6]); - - sts_key.server_publickey.resize(nparties); - for(int i = 0 ; i < nparties; i++) { - sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES); - } - - // init static gfp - string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree()); - initialise_fields(prep_data_prefix); - bigint::init_thread(); - - // Generate session keys to decrypt data sent from each spdz engine (party) - vector session_keys(nparties); - vector client_public_key_ints(8); - - generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id); - - // Setup connections from this client to each party socket and send the client public keys - vector sockets(nparties); - // vector< pair , vector > > commseckey(nparties); - commsec_t commseckey(nparties); - for (int i = 0; i < nparties; i++) - { - set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); - send(sockets[i], (octet*) &my_client_id, sizeof(int)); - octetStream os; - os.store(finish); - os.Send(sockets[i]); - - send_public_key(sts_key.client_publickey_ints, sockets[i]); - send_public_key(client_public_key_ints, sockets[i]); - commseckey[i] = sts_initiator_role(sts_key, sockets, i); - } - cout << "Finish setup socket connections to SPDZ engines." << endl; - - // Send the inputs to the SPDZ Engines - send_private_inputs({salary_value}, sockets, nparties, commseckey, session_keys); - cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; - - // Get the result back - gfp result = receive_result(sockets, nparties, commseckey, session_keys); - - cout << "Winning client id is : " << result << endl; - - for (unsigned int i = 0; i < sockets.size(); i++) - close_client_socket(sockets[i]); - - return 0; -} diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index 7b858a8bb..124a3794e 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -180,6 +180,13 @@ void FFT_Data::pack(octetStream& o) const { R.pack(o); prData.pack(o); + o.store(root); + o.store(twop); + o.store(two_root); + o.store(b); + iphi.pack(o); + o.store(powers); + o.store(powers_i); } @@ -187,7 +194,13 @@ void FFT_Data::unpack(octetStream& o) { R.unpack(o); prData.unpack(o); - init(R, prData); + o.get(root); + o.get(twop); + o.get(two_root); + o.get(b); + iphi.unpack(o); + o.get(powers); + o.get(powers_i); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index c2ee07dab..912db76cb 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -345,19 +345,38 @@ void init(Ring& Rg,int m) Rg.pi.resize(Rg.phim); Rg.pi_inv.resize(Rg.mm); for (int i=0; isigma = sigma = FHE_Params().get_R(); +#ifdef VERBOSE cerr << "Standard deviation: " << this->sigma << endl; +#endif h += extra_h * sec; produce_epsilon_constants(); diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index b0346116e..d6a8f7686 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -29,19 +29,27 @@ istream& operator>>(istream& s,Ring& R) void Ring::pack(octetStream& o) const { o.store(mm); - o.store(phim); - o.store(pi); - o.store(pi_inv); - o.store(poly); + if (((mm - 1) & mm) != 0) + { + o.store(phim); + o.store(pi); + o.store(pi_inv); + o.store(poly); + } } void Ring::unpack(octetStream& o) { o.get(mm); - o.get(phim); - o.get(pi); - o.get(pi_inv); - o.get(poly); + if (((mm - 1) & mm) != 0) + { + o.get(phim); + o.get(pi); + o.get(pi_inv); + o.get(poly); + } + else + init(*this, mm); } bool Ring::operator !=(const Ring& other) const diff --git a/FHE/Ring.h b/FHE/Ring.h index ab7aedfcf..4f5554fec 100644 --- a/FHE/Ring.h +++ b/FHE/Ring.h @@ -7,6 +7,7 @@ #include #include +#include using namespace std; #include "Tools/octetStream.h" @@ -31,7 +32,7 @@ class Ring int p(int i) const { return pi.at(i); } int p_inv(int i) const { return pi_inv.at(i); } - const vector& Phi() const { return poly; } + const vector& Phi() const { assert(poly.size()); return poly; } friend ostream& operator<<(ostream& s,const Ring& R); friend istream& operator>>(istream& s,Ring& R); diff --git a/GC/ArgTuples.h b/GC/ArgTuples.h index 5b2de4fb9..59904bfb2 100644 --- a/GC/ArgTuples.h +++ b/GC/ArgTuples.h @@ -96,6 +96,8 @@ class InputArgList : public ArgList res += x.from == from; return res; } + + int n_interactive_inputs_from_me(int my_num); }; #endif /* GC_ARGTUPLES_H_ */ diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index 865f0432b..0336a0155 100644 --- a/GC/Instruction.hpp +++ b/GC/Instruction.hpp @@ -76,14 +76,7 @@ unsigned Instruction::get_mem(RegType reg_type) const inline void Instruction::parse(istream& s, int pos) { - n = 0; - start.resize(0); - ::memset(r, 0, sizeof(r)); - - int file_pos = s.tellg(); - opcode = ::get_int(s); - - parse_operands(s, pos, file_pos); + BaseInstruction::parse(s, pos); switch(opcode) { diff --git a/GC/Instruction_inline.h b/GC/Instruction_inline.h index 2eddb3775..61465c7dc 100644 --- a/GC/Instruction_inline.h +++ b/GC/Instruction_inline.h @@ -42,6 +42,7 @@ MAYBE_INLINE bool Instruction::execute(Processor& processor, cout << endl; #endif const Instruction& instruction = *this; + auto& Ci = processor.I; switch (opcode) { #define X(NAME, CODE) case NAME: CODE; return true; diff --git a/GC/Processor.h b/GC/Processor.h index cb54b2bd8..1b00db911 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -65,6 +65,8 @@ class Processor : public ::ProcessorBase void reset(const U& program); long long get_input(const int* params, bool interactive = false); + bigint get_long_input(const int* params, ProcessorBase& input_proc, + bool interactive = false); void bitcoms(T& x, const vector& regs) { x.bitcom(S, regs); } void bitdecs(const vector& regs, const T& x) { x.bitdec(S, regs); } @@ -84,6 +86,10 @@ class Processor : public ::ProcessorBase template void store_clear_in_dynamic(const vector& args, U& dynamic_memory); + template + void mem_op(int n, Memory& dest, const Memory& source, + Integer dest_address, Integer source_address); + void xors(const vector& args); void andm(const ::BaseInstruction& instruction); void and_(const vector& args, bool repeat); @@ -95,9 +101,9 @@ class Processor : public ::ProcessorBase void reveal(const ::BaseInstruction& instruction); - void print_reg(int reg, int n); + void print_reg(int reg, int n, int size); void print_reg_plain(Clear& value); - void print_reg_signed(unsigned n_bits, Clear& value); + void print_reg_signed(unsigned n_bits, Integer value); void print_chr(int n); void print_str(int n); void print_float(const vector& args); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index cf454bc68..6a2654e37 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -67,11 +67,19 @@ void Processor::reset(const U& program) template inline long long GC::Processor::get_input(const int* params, bool interactive) { - bigint res = ProcessorBase::get_input(interactive, ¶ms[1]).items[0]; + assert(params[0] <= 64); + return get_long_input(params, *this, interactive).get_si(); +} + +template +bigint GC::Processor::get_long_input(const int* params, + ProcessorBase& input_proc, bool interactive) +{ + bigint res = input_proc.get_input>(interactive, + ¶ms[1]).items[0]; int n_bits = *params; check_input(res, n_bits); - assert(n_bits <= 64); - return res.get_si(); + return res; } template @@ -171,6 +179,17 @@ void GC::Processor::store_clear_in_dynamic(const vector& args, T::store_clear_in_dynamic(dynamic_memory, accesses); } +template +template +void Processor::mem_op(int n, Memory& dest, const Memory& source, + Integer dest_address, Integer source_address) +{ + for (int i = 0; i < n; i++) + { + dest[dest_address + i] = source[source_address + i]; + } +} + template void Processor::xors(const vector& args) { @@ -234,12 +253,15 @@ void Processor::reveal(const vector& args) } template -void Processor::print_reg(int reg, int n) +void Processor::print_reg(int reg, int n, int size) { #ifdef DEBUG_VALUES cout << "print_reg " << typeid(T).name() << " " << reg << " " << &C[reg] << endl; #endif - T::out << "Reg[" << reg << "] = " << hex << showbase << C[reg] << dec << " # "; + bigint output; + for (int i = 0; i < size; i++) + output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i); + T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # "; print_str(n); T::out << endl << flush; } @@ -251,14 +273,29 @@ void Processor::print_reg_plain(Clear& value) } template -void Processor::print_reg_signed(unsigned n_bits, Clear& value) +void Processor::print_reg_signed(unsigned n_bits, Integer reg) { - unsigned n_shift = 0; - if (n_bits > 1) - n_shift = sizeof(value.get()) * 8 - n_bits; - if (n_shift > 63) - n_shift = 0; - T::out << dec << (value.get() << n_shift >> n_shift) << flush; + if (n_bits <= Clear::N_BITS) + { + auto value = C[reg]; + unsigned n_shift = 0; + if (n_bits > 1) + n_shift = sizeof(value.get()) * 8 - n_bits; + if (n_shift > 63) + n_shift = 0; + T::out << dec << (value.get() << n_shift >> n_shift) << flush; + } + else + { + bigint tmp = 0; + for (int i = 0; i < DIV_CEIL(n_bits, Clear::N_BITS); i++) + { + tmp += bigint((unsigned long)C[reg + i].get()) << (i * Clear::N_BITS); + } + if (tmp >= bigint(1) << (n_bits - 1)) + tmp -= bigint(1) << n_bits; + T::out << dec << tmp << flush; + } } template diff --git a/GC/Secret.h b/GC/Secret.h index 27e1a0006..928f5f900 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -21,6 +21,8 @@ #include +class ProcessorBase; + namespace GC { @@ -116,6 +118,10 @@ class Secret static void inputb(Processor& processor, const vector& args) { T::inputb(processor, args); } template + static void inputb(Processor& processor, ProcessorBase& input_proc, + const vector& args) + { T::inputb(processor, input_proc, args); } + template static void reveal_inst(Processor& processor, const vector& args) { processor.reveal(args); } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index d71adede7..a12266343 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -290,12 +290,14 @@ void Secret::trans(Processor& processor, int n_outputs, const vector& args) { int n_inputs = args.size() - n_outputs; + int dl = U::default_length; for (int i = 0; i < n_outputs; i++) { - processor.S[args[i]].resize_regs(n_inputs); + for (int j = 0; j < DIV_CEIL(n_inputs, dl); j++) + processor.S[args[i] + j].resize_regs(min(dl, n_inputs - j * dl)); for (int j = 0; j < n_inputs; j++) - processor.S[args[i]].registers[j] = - processor.S[args[n_outputs + j]].registers[i]; + processor.S[args[i] + j / dl].registers[j % dl] = + processor.S[args[n_outputs + j] + i / dl].registers[i % dl]; } } diff --git a/GC/SemiSecret.cpp b/GC/SemiSecret.cpp index ce2f7ce21..704e2a2fb 100644 --- a/GC/SemiSecret.cpp +++ b/GC/SemiSecret.cpp @@ -25,12 +25,25 @@ SemiSecret::MC* SemiSecret::new_mc(mac_key_type) void SemiSecret::trans(Processor& processor, int n_outputs, const vector& args) { - square64 square; - for (size_t i = n_outputs; i < args.size(); i++) - square.rows[i - n_outputs] = processor.S[args[i]].get(); - square.transpose(args.size() - n_outputs, n_outputs); - for (int i = 0; i < n_outputs; i++) - processor.S[args[i]] = square.rows[i]; + int N_BITS = default_length; + for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++) + for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++) + { + square64 square; + size_t input_base = n_outputs + l * N_BITS; + for (size_t i = input_base; + i < min(input_base + N_BITS, args.size()); i++) + square.rows[i - input_base] = processor.S[args[i] + j].get(); + square.transpose( + min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS), + min(N_BITS, n_outputs - j * N_BITS)); + int output_base = j * N_BITS; + for (int i = output_base; i < min(n_outputs, output_base + N_BITS); + i++) + { + processor.S[args[i] + l] = square.rows[i - output_base]; + } + } } void SemiSecret::load_clear(int n, const Integer& x) diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 8dac7360a..dc483371e 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -20,6 +20,7 @@ using namespace std; #include "Protocols/Replicated.h" #include "Protocols/ReplicatedMC.h" #include "Processor/DummyProtocol.h" +#include "Processor/ProcessorBase.h" namespace GC { @@ -58,7 +59,10 @@ class ShareSecret { and_(processor, args, false); } static void and_(Processor& processor, const vector& args, bool repeat); static void xors(Processor& processor, const vector& args); - static void inputb(Processor& processor, const vector& args); + static void inputb(Processor& processor, const vector& args) + { inputb(processor, processor, args); } + static void inputb(Processor& processor, ProcessorBase& input_processor, + const vector& args); static void reveal_inst(Processor& processor, const vector& args); static void convcbit(Integer& dest, const Clear& source) { dest = source; } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 1eff55b5e..9d90df598 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -26,7 +26,7 @@ namespace GC { template -const int VectorSecret::default_length; +const int ReplicatedSecret::N_BITS; template const int ReplicatedSecret::default_length; @@ -92,6 +92,7 @@ void ShareSecret::store_clear_in_dynamic(Memory& mem, template void ShareSecret::inputb(Processor& processor, + ProcessorBase& input_processor, const vector& args) { auto& party = ShareThread::s(); @@ -99,16 +100,22 @@ void ShareSecret::inputb(Processor& processor, input.reset_all(*party.P); InputArgList a(args); - bool interactive = Thread::s().n_interactive_inputs_from_me(a) > 0; + bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0; + int dl = U::default_length; for (auto x : a) { if (x.from == party.P->my_num()) { - input.add_mine(processor.get_input(x.params, interactive), x.n_bits); + bigint whole_input = processor.get_long_input(x.params, + input_processor, interactive); + for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) + input.add_mine(bigint(whole_input >> (i * dl)).get_si(), + min(dl, x.n_bits - i * dl)); } else - input.add_other(x.from); + for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) + input.add_other(x.from); } if (interactive) @@ -120,8 +127,12 @@ void ShareSecret::inputb(Processor& processor, { int from = x.from; int n_bits = x.n_bits; - auto& res = processor.S[x.dest]; - res = input.finalize(from, n_bits).mask(n_bits); + for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) + { + auto& res = processor.S[x.dest + i]; + int n = min(dl, n_bits - i * dl); + res = input.finalize(from, n).mask(n); + } } } @@ -139,7 +150,11 @@ void ShareSecret::reveal_inst(Processor& processor, if (n > max(U::default_length, Clear::N_BITS)) assert(U::default_length == Clear::N_BITS); for (int j = 0; j < DIV_CEIL(n, U::default_length); j++) - shares.push_back(processor.S[r1 + j].mask(n)); + { + shares.push_back( + processor.S[r1 + j].mask( + min(U::default_length, n - j * U::default_length))); + } } assert(party.MC); PointerVector opened; @@ -149,7 +164,10 @@ void ShareSecret::reveal_inst(Processor& processor, int n = args[i]; int r0 = args[i + 1]; for (int j = 0; j < DIV_CEIL(n, U::default_length); j++) - processor.C[r0 + j] = opened.next().mask(n); + { + processor.C[r0 + j] = opened.next().mask( + min(U::default_length, n - j * U::default_length)); + } } } @@ -180,12 +198,22 @@ void ReplicatedSecret::trans(Processor& processor, assert(length == 2); for (int k = 0; k < 2; k++) { - square64 square; - for (size_t i = n_outputs; i < args.size(); i++) - square.rows[i - n_outputs] = processor.S[args[i]][k].get(); - square.transpose(args.size() - n_outputs, n_outputs); - for (int i = 0; i < n_outputs; i++) - processor.S[args[i]][k] = square.rows[i]; + for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++) + for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++) + { + square64 square; + size_t input_base = n_outputs + l * N_BITS; + for (size_t i = input_base; i < min(input_base + N_BITS, args.size()); i++) + square.rows[i - input_base] = processor.S[args[i] + j][k].get(); + square.transpose( + min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS), + min(N_BITS, n_outputs - j * N_BITS)); + int output_base = j * N_BITS; + for (int i = output_base; i < min(n_outputs, output_base + N_BITS); i++) + { + processor.S[args[i] + l][k] = square.rows[i - output_base]; + } + } } } diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index bc9f2fceb..41ff87901 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -75,7 +75,9 @@ void ShareThread::post_run() { MC->Check(*this->P); #ifndef INSECURE +#ifdef VERBOSE cerr << "Removing used pre-processed data" << endl; +#endif DataF.prune(); #endif } diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 28a377f2f..fc57e8ef7 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -90,18 +90,23 @@ void Thread::finish() pthread_join(thread, 0); } - template -int GC::Thread::n_interactive_inputs_from_me(InputArgList& args) +int Thread::n_interactive_inputs_from_me(InputArgList& args) +{ + return args.n_interactive_inputs_from_me(P->my_num()); +} + +} /* namespace GC */ + + +inline int InputArgList::n_interactive_inputs_from_me(int my_num) { int res = 0; - if (thread_num == 0 and master.opts.interactive) - res = args.n_inputs_from(P->my_num()); + if (ArithmeticProcessor().use_stdin()) + res = n_inputs_from(my_num); if (res > 0) cout << "Please enter " << res << " numbers:" << endl; return res; } -} /* namespace GC */ - #endif diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 1928af31b..da560cb9a 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -218,6 +218,9 @@ class TinySecret : public VectorSecret> } }; +template +const int VectorSecret::default_length; + template inline VectorSecret operator*(const BitVec& clear, const VectorSecret& share) { diff --git a/GC/instructions.h b/GC/instructions.h index cca304ae4..27bd8b781 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -30,12 +30,11 @@ #define IMM instruction.get_n() #define EXTRA instruction.get_start() +#define SIZE instruction.get_size() -#define MSD processor.memories.MS[IMM] -#define MMC processor.memories.MC[IMM] +#define MMS processor.memories.MS +#define MMC processor.memories.MC #define MID MACH->MI[IMM] - -#define MSI processor.memories.MS[PI1.get()] #define MII MACH->MI[PI1.get()] #define BIT_INSTRUCTIONS \ @@ -44,7 +43,6 @@ X(XORCBI, C0.xor_(PC1, IMM)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ - X(INPUTB, T::inputb(PROC, EXTRA)) \ X(ADDCB, C0 = PC1 + PC2) \ X(ADDCBI, C0 = PC1 + IMM) \ X(MULCBI, C0 = PC1 * IMM) \ @@ -54,24 +52,25 @@ X(SHRCBI, C0 = PC1 >> IMM) \ X(SHLCBI, C0 = PC1 << IMM) \ X(LDBITS, S0.load_clear(REG1, IMM)) \ - X(LDMSB, S0 = MSD) \ - X(STMSB, MSD = S0) \ - X(LDMCB, C0 = MMC) \ - X(STMCB, MMC = C0) \ + X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \ + X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \ + X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \ + X(STMCB, PROC.mem_op(SIZE, MMC, PROC.C, IMM, R0)) \ + X(LDMSBI, PROC.mem_op(SIZE, PROC.S, MMS, R0, Ci[REG1])) \ + X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \ X(MOVSB, S0 = PS1) \ X(TRANS, T::trans(PROC, IMM, EXTRA)) \ X(BITB, PROC.random_bit(S0)) \ X(REVEAL, T::reveal_inst(PROC, EXTRA)) \ - X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \ - X(PRINTREGB, PROC.print_reg(R0, IMM)) \ + X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, R0)) \ + X(PRINTREGB, PROC.print_reg(R0, IMM, SIZE)) \ X(PRINTREGPLAINB, PROC.print_reg_plain(C0)) \ X(PRINTFLOATPLAINB, PROC.print_float(EXTRA)) \ X(CONDPRINTSTRB, if(C0.get()) PROC.print_str(IMM)) \ #define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ + X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \ X(ANDM, processor.andm(instruction)) \ - X(LDMSBI, S0 = processor.memories.MS[Proc.read_Ci(REG1)]) \ - X(STMSBI, processor.memories.MS[Proc.read_Ci(REG1)] = S0) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ @@ -84,8 +83,7 @@ X(SPLIT, Proc.split(INST)) \ #define GC_INSTRUCTIONS \ - X(LDMSBI, S0 = MSI) \ - X(STMSBI, MSI = S0) \ + X(INPUTB, T::inputb(PROC, EXTRA)) \ X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \ X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \ X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \ diff --git a/GC/square64.cpp b/GC/square64.cpp index 9c4774ed6..5217a0905 100644 --- a/GC/square64.cpp +++ b/GC/square64.cpp @@ -7,6 +7,7 @@ #include "Tools/cpu_support.h" #include #include +#include using namespace std; union matrix32x8 @@ -98,6 +99,9 @@ void square64::transpose(int n_rows, int n_cols) print(); #endif + assert(n_rows <= 64); + assert(n_cols <= 64); + square64 tmp = *this; *this = {}; diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 9a23ebd35..c65934ddb 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -3,9 +3,7 @@ #include "Math/Setup.h" #include "Protocols/Share.h" #include "Tools/ezOptionParser.h" -#include "Tools/Config.h" #include "Networking/Server.h" -#include "GC/TinierSecret.h" #include #include diff --git a/Makefile b/Makefile index 2f0bc4379..1f28abc28 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,7 @@ offline: $(OT_EXE) Check-Offline.x gen_input: gen_input_f2n.x gen_input_fp.x -externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x +externalIO: bankers-bonus-client.x bmr: bmr-program-party.x bmr-program-tparty.x @@ -134,9 +134,6 @@ bmr-clean: bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - simple-offline.x: $(FHEOFFLINE) pairwise-offline.x: $(FHEOFFLINE) cnc-offline.x: $(FHEOFFLINE) @@ -195,6 +192,9 @@ OT/BaseOT.o: SimpleOT/Makefile SimpleOT/Makefile: git submodule update --init SimpleOT +Programs/Circuits: + git submodule update --init Programs/Circuits + .PHONY: mpir-setup mpir-global mpir mpir-setup: git submodule update --init mpir diff --git a/Math/Z2k.h b/Math/Z2k.h index 2c2bb1545..70239353d 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -73,6 +73,8 @@ class Z2 : public ValueInterface static void reqbl(int n); static bool allows(Dtype dtype); + static void specification(octetStream& os); + typedef Z2 next; typedef Z2 Scalar; @@ -254,12 +256,18 @@ class SignedZ2 : public Z2 } template - SignedZ2 operator*(const Z2& other) const + SignedZ2 operator*(const SignedZ2& other) const { assert((K % 64 == 0) and (L % 64 == 0)); return Z2::operator*(other); } + template + Z2 operator*(const Z2& other) const + { + return Z2::operator*(other); + } + SignedZ2 operator*(int other) const { return operator*(SignedZ2<64>(other)); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index bef167256..791c4af70 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -36,6 +36,12 @@ bool Z2::allows(Dtype dtype) return Integer::allows(dtype); } +template +void Z2::specification(octetStream& os) +{ + os.store(K); +} + template Z2::Z2(const bigint& x) : Z2() { diff --git a/Math/gfp.cpp b/Math/gfp.cpp index aafca38e1..dd36516ca 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -205,6 +205,12 @@ bool gfp_::allows(Dtype type) } } +template +void gfp_::specification(octetStream& os) +{ + os.store(pr()); +} + void to_signed_bigint(bigint& ans, const gfp& x) { to_bigint(ans, x); diff --git a/Math/gfp.h b/Math/gfp.h index bfcb5ea85..667bf9f9c 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -92,6 +92,8 @@ class gfp_ static bool allows(Dtype type); + static void specification(octetStream& os); + static const bool invertible = true; static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; } diff --git a/Math/modp.h b/Math/modp.h index 107a1e36a..185e3b848 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -61,6 +61,9 @@ class modp_ void pack(octetStream& o,const Zp_Data& ZpD) const; void unpack(octetStream& o,const Zp_Data& ZpD); + void pack(octetStream& o) const; + void unpack(octetStream& o); + bool operator==(const modp_& other) const { return 0 == mpn_cmp(x, other.x, L); } bool operator!=(const modp_& other) const { return not (*this == other); } diff --git a/Math/modp.hpp b/Math/modp.hpp index c43987c25..9a35170ab 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -20,6 +20,17 @@ void modp_::unpack(octetStream& o,const Zp_Data& ZpD) o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t)); } +template +void modp_::unpack(octetStream& o) +{ + o.consume((octet*) x,L*sizeof(mp_limb_t)); +} + +template +void modp_::pack(octetStream& o) const +{ + o.append((octet*) x,L*sizeof(mp_limb_t)); +} template void Negate(modp_& ans,const modp_& x,const Zp_Data& ZpD) diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 012d37602..ad4695d0e 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -14,29 +14,19 @@ void check_ssl_file(string filename) "You can use `Scripts/setup-ssl.sh `."); } -void ssl_error(string side, string pronoun, int other, int server) +void ssl_error(string side, string pronoun, string other, string server) { - cerr << side << "-side handshake with party " << other + cerr << side << "-side handshake with " << other << " failed. Make sure " << pronoun - << " have the necessary certificate (" << PREP_DIR "P" << server + << " have the necessary certificate (" << PREP_DIR << server << ".pem in the default configuration)," << " and run `c_rehash ` on its location." << endl; } CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : MultiPlayer(Nms, id_base), plaintext_player(Nms, id_base), - ctx(boost::asio::ssl::context::tlsv12) + ctx("P" + to_string(my_num())) { - string prefix = PREP_DIR "P" + to_string(my_num()); - string cert_file = prefix + ".pem"; - string key_file = prefix + ".key"; - check_ssl_file(cert_file); - check_ssl_file(key_file); - - ctx.use_certificate_file(cert_file, ctx.pem); - ctx.use_private_key_file(key_file, ctx.pem); - ctx.add_verify_path("Player-Data"); - sockets.resize(num_players()); senders.resize(num_players()); @@ -49,30 +39,8 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : continue; } - sockets[i] = new ssl_socket(io_service, ctx); - sockets[i]->lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_player.socket(i)); - sockets[i]->set_verify_mode(boost::asio::ssl::verify_peer); - sockets[i]->set_verify_callback(boost::asio::ssl::rfc2818_verification("P" + to_string(i))); - if (i < my_num()) - try - { - sockets[i]->handshake(ssl_socket::client); - } - catch (...) - { - ssl_error("Client", "we", i, i); - throw; - } - if (i > my_num()) - try - { - sockets[i]->handshake(ssl_socket::server); - } - catch (...) - { - ssl_error("Server", "they", i, my_num()); - throw; - } + sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i), + "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); senders[i] = new Sender(sockets[i]); } diff --git a/Networking/CryptoPlayer.h b/Networking/CryptoPlayer.h index 54bde9646..29c08ea6a 100644 --- a/Networking/CryptoPlayer.h +++ b/Networking/CryptoPlayer.h @@ -15,7 +15,7 @@ class CryptoPlayer : public MultiPlayer { PlainPlayer plaintext_player; - boost::asio::ssl::context ctx; + ssl_ctx ctx; boost::asio::io_service io_service; vector*> senders; diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 193bf3c3d..b8fac653c 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -2,7 +2,6 @@ #include "Player.h" #include "ssl_sockets.h" #include "Exceptions/Exceptions.h" -#include "Networking/STS.h" #include "Tools/int.h" #include "Tools/NetworkOptions.h" #include "Networking/Server.h" diff --git a/Networking/STS.cpp b/Networking/STS.cpp deleted file mode 100644 index 8534afe08..000000000 --- a/Networking/STS.cpp +++ /dev/null @@ -1,228 +0,0 @@ -#include "Networking/STS.h" -#include -#include -#include -#include -#include -#include -#include - -void STS::kdf_block(unsigned char *block) -{ - crypto_hash_sha512_state state; - crypto_hash_sha512_init(&state); - unsigned char ctrbytes[sizeof kdf_counter]; - kdf_counter++; - - // Little endian serialization - for(size_t i=0; i> i*8) & 0xFF); - } - crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes); - crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES); - crypto_hash_sha512_final(&state, block); -} - -vector STS::unsafe_derive_secret(size_t sz) -{ - // KDF ~ H(cnt || raw_secret) - vector resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES)); - size_t total=0; - while(total < sz) { - unsigned char *block = &resultSecret[total]; - kdf_block(block); - total += crypto_hash_sha512_BYTES; - } - return resultSecret; -} - -STS::STS() -{ - phase = UNDEFINED; -} - -void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) -{ - phase = UNKNOWN; - memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); - memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); - memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); - memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); - kdf_counter = 0; -} - -STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) -{ - phase = UNKNOWN; - memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); - memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); - memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); - memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); - kdf_counter = 0; -} - -STS::~STS() -{ - memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES); - memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES); - memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); - memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); - memset(raw_secret, 0, crypto_hash_sha512_BYTES); - kdf_counter = 0; - phase = UNKNOWN; -} - -sts_msg1_t STS::send_msg1() -{ - sts_msg1_t m; - if(UNKNOWN != phase) { - throw "STS BAD PHASE"; - } - - crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); - memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); - phase = SENT1; - return m; -} - -// If the incoming signature is valid, compute: -// shared secret = H(DH(pubB,privA) || pubA || pubB) -// msg = Sign_{privED-A} (pubA || pubB ) -// -sts_msg3_t STS::recv_msg2(sts_msg2_t msg2) -{ - unsigned char *theirPublicKey = msg2.pubkey; - unsigned char *theirSig = msg2.sig; - unsigned char theirSigDec[crypto_sign_BYTES]; - unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; - const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; - int ret; - crypto_hash_sha512_state state; - sts_msg3_t msg; - - if(SENT1 != phase) { - throw "STS BAD PHASE"; - } - ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); - if(0 != ret) { - throw "crypto_scalarmult failed"; - } - - crypto_hash_sha512_init(&state); - crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); - crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); - crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); - crypto_hash_sha512_final(&state,raw_secret); - - vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); - vector expectedMessage; - expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); - expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); - - crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]); - - int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); - - if(badSig) { - throw "Bad signature received in message 2."; - } else { - unsigned char *mySigEnc = msg.bytes; - unsigned char mySig[crypto_sign_BYTES]; - vector signMessage; - signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); - signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); - if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) { - throw "Signing failed."; - } - vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); - crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); - - phase = FINISHED; - return msg; - } -} - -sts_msg2_t STS::recv_msg1(sts_msg1_t msg1) -{ - unsigned char *theirPublicKey = msg1.bytes; - unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; - crypto_hash_sha512_state state; - sts_msg2_t m; - int ret; - - if(UNKNOWN != phase) { - throw "recv_msg1 called on non-unknown phase"; - } - - memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES); - - crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); - memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); - ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); - if(0 != ret) { - throw "crypto_scalarmult failed when processing message 1"; - } - - crypto_hash_sha512_init(&state); - crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); - crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); - crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); - crypto_hash_sha512_final(&state,raw_secret); - - vector livenessProof; - livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); - livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); - unsigned char mySig[crypto_sign_BYTES]; - unsigned char *mySigEnc = m.sig; - vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); - - unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; - if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) { - throw "Signing failed."; - } - crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]); - - phase = SENT2; - return m; -} - -void STS::recv_msg3(sts_msg3_t msg3) -{ - unsigned char *theirSig=msg3.bytes; - unsigned char theirSigDec[crypto_sign_BYTES]; - vector expectedMessage; - if(SENT2 != phase) { - throw "recv_msg3 called out of order"; - } - - expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES); - expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); - unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; - vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); - - crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); - int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); - - if(badSig) { - throw "Bad signature received in message 3."; - } else { - phase = FINISHED; - } -} - -vector STS::derive_secret(size_t sz) -{ - if(phase != FINISHED) { - throw "Can not derive secrets till the key exchange has completed."; - } - return unsafe_derive_secret(sz); -} diff --git a/Networking/STS.h b/Networking/STS.h deleted file mode 100644 index e6c044cd2..000000000 --- a/Networking/STS.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef _NETWORK_STS -#define _NETWORK_STS - -/* The Station to Station protocol - */ - -#include -#include -#include -#include - -using namespace std; - -typedef enum - { UNKNOWN // Have not started the interaction or have cleared the memory - , SENT1 // Sent initial message - , SENT2 // Received 1, sent 2 - , FINISHED // Done (received msg 2 & sent 3 or received msg 3) - , UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later. -} phase_t; - -struct msg1_st { - unsigned char bytes[crypto_box_PUBLICKEYBYTES]; -}; -typedef struct msg1_st sts_msg1_t; -struct msg2_st { - unsigned char pubkey[crypto_box_PUBLICKEYBYTES]; - unsigned char sig[crypto_sign_BYTES]; -}; -typedef struct msg2_st sts_msg2_t; -struct msg3_st { - unsigned char bytes[crypto_sign_BYTES]; -}; -typedef struct msg3_st sts_msg3_t; - -class STS -{ - phase_t phase; - unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES]; - unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES]; - unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES]; - unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES]; - unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; - unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; - unsigned char raw_secret[crypto_hash_sha512_BYTES]; - uint64_t kdf_counter; - public: - STS(); - STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); - ~STS(); - - void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] - , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); - - sts_msg1_t send_msg1(); - sts_msg3_t recv_msg2(sts_msg2_t msg2); - - sts_msg2_t recv_msg1(sts_msg1_t msg1); - void recv_msg3(sts_msg3_t msg3); - - vector derive_secret(size_t); - private: - vector unsafe_derive_secret(size_t); - void kdf_block(unsigned char *block); -}; - -#endif /* _NETWORK_STS */ diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index 3408bf716..fbeb13881 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -74,13 +74,61 @@ void ServerSocket::init() pthread_create(&thread, 0, accept_thread, this); } +class ServerJob +{ + ServerSocket& server; + int socket; + sockaddr dest; + +public: + pthread_t thread; + + ServerJob(ServerSocket& server, int socket, sockaddr dest) : + server(server), socket(socket), dest(dest), thread(0) + { + } + + static void* run(void* job) + { + auto& server_job = *(ServerJob*)(job); + server_job.server.wait_for_client_id(server_job.socket, server_job.dest); + return 0; + } +}; + ServerSocket::~ServerSocket() { + for (auto& job : jobs) + { + pthread_cancel(job->thread); + pthread_join(job->thread, 0); + delete job; + } + pthread_cancel(thread); pthread_join(thread, 0); if (close(main_socket)) { error("close(main_socket"); }; } +void ServerSocket::wait_for_client_id(int socket, sockaddr dest) +{ + (void) dest; + int client_id; + try + { + receive(socket, (unsigned char*) &client_id, sizeof(client_id)); + process_connection(socket, client_id); + } + catch (closed_connection&) + { +#ifdef DEBUG_NETWORKING + auto& conn = *(sockaddr_in*) &dest; + fprintf(stderr, "client on %s:%d left without identification\n", + inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); +#endif + } +} + void ServerSocket::accept_clients() { while (true) @@ -92,25 +140,19 @@ void ServerSocket::accept_clients() if (consocket<0) { error("set_up_socket:accept"); } int client_id; - try - { - receive(consocket, (unsigned char*)&client_id, sizeof(client_id)); - } - catch (closed_connection&) - { + if (receive_all_or_nothing(consocket, (unsigned char*)&client_id, sizeof(client_id))) + process_connection(consocket, client_id); + else + { #ifdef DEBUG_NETWORKING - auto& conn = *(sockaddr_in*)&dest; - cerr << "client on " << inet_ntoa(conn.sin_addr) << ":" - << ntohs(conn.sin_port) << " left without identification" - << endl; + auto& conn = *(sockaddr_in*) &dest; + fprintf(stderr, "deferring client on %s:%d to thread\n", + inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); #endif - } - - data_signal.lock(); - process_client(client_id); - clients[client_id] = consocket; - data_signal.broadcast(); - data_signal.unlock(); + // defer to thread + jobs.push_back(new ServerJob(*this, consocket, dest)); + pthread_create(&jobs.back()->thread, 0, ServerJob::run, jobs.back()); + } #ifdef __APPLE__ int flags = fcntl(consocket, F_GETFL, 0); @@ -121,15 +163,19 @@ void ServerSocket::accept_clients() } } -int ServerSocket::get_connection_count() +void ServerSocket::process_connection(int consocket, int client_id) { data_signal.lock(); - int connection_count = clients.size(); +#ifdef DEBUG_NETWORKING + cerr << "client " << hex << client_id << " is on socket " << dec << consocket + << endl; +#endif + process_client(client_id); + clients[client_id] = consocket; + data_signal.broadcast(); data_signal.unlock(); - return connection_count; } - int ServerSocket::get_connection_socket(int id) { data_signal.lock(); @@ -163,16 +209,10 @@ void AnonymousServerSocket::init() pthread_create(&thread, 0, anonymous_accept_thread, this); } -int AnonymousServerSocket::get_connection_count() -{ - return num_accepted_clients; -} - void AnonymousServerSocket::process_client(int client_id) { if (clients.find(client_id) != clients.end()) close_client_socket(clients[client_id]); - num_accepted_clients++; client_connection_queue.push(client_id); } diff --git a/Networking/ServerSocket.h b/Networking/ServerSocket.h index e1e16dac9..348349f79 100644 --- a/Networking/ServerSocket.h +++ b/Networking/ServerSocket.h @@ -12,10 +12,13 @@ using namespace std; #include +#include #include "Tools/WaitQueue.h" #include "Tools/Signal.h" +class ServerJob; + class ServerSocket { protected: @@ -25,9 +28,13 @@ class ServerSocket Signal data_signal; pthread_t thread; + vector jobs; + // disable copying ServerSocket(const ServerSocket& other); + void process_connection(int socket, int client_id); + virtual void process_client(int) {} public: @@ -38,14 +45,11 @@ class ServerSocket virtual void accept_clients(); + void wait_for_client_id(int socket, sockaddr dest); + // This depends on clients sending their id as int. // Has to be thread-safe. int get_connection_socket(int number); - - // How many client connections have been made. - virtual int get_connection_count(); - - void close_socket(); }; /* @@ -55,18 +59,15 @@ class AnonymousServerSocket : public ServerSocket { private: // No. of accepted connections in this instance - int num_accepted_clients; queue client_connection_queue; void process_client(int client_id); public: AnonymousServerSocket(int Portnum) : - ServerSocket(Portnum), num_accepted_clients(0) { }; + ServerSocket(Portnum) { }; void init(); - virtual int get_connection_count(); - // Get socket and id for the last client who connected int get_connection_socket(int& client_id); }; diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index a078126ce..25515d69f 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -12,7 +12,65 @@ #include #include -typedef boost::asio::ssl::stream ssl_socket; +typedef boost::asio::io_service ssl_service; + +void check_ssl_file(string filename); +void ssl_error(string side, string pronoun, string other, string server); + +class ssl_ctx : public boost::asio::ssl::context +{ +public: + ssl_ctx(string me) : + boost::asio::ssl::context(boost::asio::ssl::context::tlsv12) + { + string prefix = PREP_DIR + me; + string cert_file = prefix + ".pem"; + string key_file = prefix + ".key"; + check_ssl_file(cert_file); + check_ssl_file(key_file); + + use_certificate_file(cert_file, pem); + use_private_key_file(key_file, pem); + add_verify_path(PREP_DIR); + } +}; + +class ssl_socket : public boost::asio::ssl::stream +{ + typedef boost::asio::ssl::stream parent; + +public: + ssl_socket(boost::asio::io_service& io_service, + boost::asio::ssl::context& ctx, int plaintext_socket, string other, + string me, bool client) : + parent(io_service, ctx) + { + lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket); + set_verify_mode(boost::asio::ssl::verify_peer); + set_verify_callback(boost::asio::ssl::rfc2818_verification(other)); + if (client) + try + { + handshake(ssl_socket::client); + } catch (...) + { + ssl_error("Client", "we", other, other); + throw; + } + else + { + try + { + handshake(ssl_socket::server); + } catch (...) + { + ssl_error("Server", "they", other, me); + throw; + } + + } + } +}; inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length) { diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index ec83a2011..95abcceb6 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -5,41 +5,24 @@ #include ExternalClients::ExternalClients(int party_num, const string& prep_data_dir): - party_num(party_num), prep_data_dir(prep_data_dir), server_connection_count(-1) + party_num(party_num), prep_data_dir(prep_data_dir), + ctx(0) { } ExternalClients::~ExternalClients() { // close client sockets - for (map::iterator it = external_client_sockets.begin(); + for (auto it = external_client_sockets.begin(); it != external_client_sockets.end(); it++) { - if (close(it->second)) - { - error("failed to close external client connection socket)"); - } + delete it->second; } for (map::iterator it = client_connection_servers.begin(); it != client_connection_servers.end(); it++) { delete it->second; } - for (map::iterator it = symmetric_client_keys.begin(); - it != symmetric_client_keys.end(); it++) - { - delete[] it->second; - } - for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_send_keys.begin(); - it_cs != symmetric_client_commsec_send_keys.end(); it_cs++) - { - memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); - } - for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_recv_keys.begin(); - it_cs != symmetric_client_commsec_recv_keys.end(); it_cs++) - { - memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); - } } void ExternalClients::start_listening(int portnum_base) @@ -62,125 +45,21 @@ int ExternalClients::get_client_connection(int portnum_base) cerr << "Thread " << this_thread::get_id() << " found server." << endl; int client_id, socket; socket = client_connection_servers[portnum_base]->get_connection_socket(client_id); - external_client_sockets[client_id] = socket; - if (symmetric_client_keys.find(client_id) != symmetric_client_keys.end()) - delete symmetric_client_keys[client_id]; - symmetric_client_commsec_send_keys.erase(client_id); - symmetric_client_commsec_recv_keys.erase(client_id); + if (ctx == 0) + ctx = new ssl_ctx("P" + to_string(get_party_num())); + external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket, + "C" + to_string(client_id), "P" + to_string(get_party_num()), false); cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; return client_id; } -int ExternalClients::connect_to_server(int portnum_base, int ipv4_address) -{ - struct in_addr addr = { (unsigned int)ipv4_address }; - int csocket; - const char* address_str = inet_ntoa(addr); - cerr << "Party " << get_party_num() << " connecting to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; - set_up_client_socket(csocket, address_str, portnum_base + get_party_num()); - cerr << "Party " << get_party_num() << " connected to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; - int server_id = server_connection_count; - // server identifiers are -1, -2, ... to avoid conflict with client identifiers - server_connection_count--; - external_client_sockets[server_id] = csocket; - return server_id; -} - -void ExternalClients::curve25519_ints_to_bytes(unsigned char *bytes, const vector& key_ints) -{ - for(unsigned int j = 0; j < key_ints.size(); j++) { - for(unsigned int k = 0; k < 4; k++) { - bytes[j*sizeof(int) + k] = (key_ints[j] >> ((3-k)*8)) & 0xFF; - } - } -} - -// Generate sesssion key for a newly connected client, store in symmetric_client_keys -// public_key is expected to be size 8 and contain integer values of public key bytes. -// Assumes load_server_keys has been run. -void ExternalClients::generate_session_key_for_client(int client_id, const vector& public_key) -{ - assert(public_key.size() * sizeof(int) == crypto_box_PUBLICKEYBYTES); - - load_server_keys_once(); - - unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; - - curve25519_ints_to_bytes(client_publickey, public_key); - - cerr << "Recevied client public key for client " << dec << client_id << " :"; - for (unsigned int j = 0; j < crypto_box_PUBLICKEYBYTES; j++) - cerr << hex << (int) client_publickey[j] << " "; - cerr << dec << endl; - - unsigned char scalarmult_q_by_server[crypto_scalarmult_BYTES]; - crypto_generichash_state h; - - symmetric_client_keys[client_id] = new octet[crypto_generichash_BYTES]; - - // Derive a shared key from this server's secret key and the client's public key - // shared key = h(q || server_secretkey || client_publickey) - if (crypto_scalarmult(scalarmult_q_by_server, server_secretkey, client_publickey) != 0) { - cerr << "Scalar mult failed\n"; - exit(1); - } - crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); - crypto_generichash_update(&h, scalarmult_q_by_server, sizeof scalarmult_q_by_server); - crypto_generichash_update(&h, client_publickey, sizeof client_publickey); - crypto_generichash_update(&h, server_publickey, sizeof server_publickey); - crypto_generichash_final(&h, symmetric_client_keys[client_id], crypto_generichash_BYTES); -} - -// Read pre-computed server keys from client-setup for this SPDZ engine. -// Only needs to be done once per run, but is only necessary if an external connection -// is being requested. -void ExternalClients::load_server_keys_once() -{ - if (server_keys_loaded) { - return; - } - - ifstream keyfile; - stringstream filename; - filename << prep_data_dir << "Player-SPDZ-Keys-P" << get_party_num(); - keyfile.open(filename.str().c_str()); - if (keyfile.fail()) - throw file_error(filename.str().c_str()); - - keyfile.read((char*)server_publickey, sizeof server_publickey); - if (keyfile.eof()) - throw end_of_file(filename.str(), "server public key" ); - keyfile.read((char*)server_secretkey, sizeof server_secretkey); - if (keyfile.eof()) - throw end_of_file(filename.str(), "server private key" ); - - bool loaded_ed25519 = true; - - keyfile.read((char*)server_publickey_ed25519, sizeof server_publickey_ed25519); - if (keyfile.eof() || keyfile.bad()) - loaded_ed25519 = false; - keyfile.read((char*)server_secretkey_ed25519, sizeof server_secretkey_ed25519); - if (keyfile.eof() || keyfile.bad()) - loaded_ed25519 = false; - - keyfile.close(); - - ed25519_keys_loaded = loaded_ed25519; - server_keys_loaded = true; -} - -void ExternalClients::require_ed25519_keys() -{ - if (!ed25519_keys_loaded) - throw "Ed25519 keys required but not found in player key files"; -} int ExternalClients::get_party_num() { return party_num; } -int ExternalClients::get_socket(int id) +ssl_socket* ExternalClients::get_socket(int id) { if (external_client_sockets.find(id) == external_client_sockets.end()) throw runtime_error("external connection not found for id " + to_string(id)); diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h index 687cac942..0e1941d44 100644 --- a/Processor/ExternalClients.h +++ b/Processor/ExternalClients.h @@ -2,6 +2,7 @@ #define _ExternalClients #include "Networking/sockets.h" +#include "Networking/ssl_sockets.h" #include "Exceptions/Exceptions.h" #include #include @@ -23,23 +24,14 @@ class ExternalClients int party_num; const string prep_data_dir; - int server_connection_count; - unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; - unsigned char server_secretkey[crypto_box_SECRETKEYBYTES]; - bool server_keys_loaded = false; - bool ed25519_keys_loaded = false; // Maps holding per client values (indexed by unique 32-bit id) - std::map external_client_sockets; + std::map external_client_sockets; - public: - - unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES]; - unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES]; + ssl_service io_service; + ssl_ctx* ctx; - std::map symmetric_client_keys; - std::map,uint64_t>> symmetric_client_commsec_send_keys; - std::map,uint64_t>> symmetric_client_commsec_recv_keys; + public: ExternalClients(int party_num, const string& prep_data_dir); ~ExternalClients(); @@ -48,18 +40,10 @@ class ExternalClients int get_client_connection(int portnum_base); - int connect_to_server(int portnum_base, int ipv4_address); - // return the socket for a given client or server identifier - int get_socket(int socket_id); - - void curve25519_ints_to_bytes(unsigned char bytes[crypto_box_PUBLICKEYBYTES], const vector& key_ints); - void generate_session_key_for_client(int client_id, const vector& public_key); - - void load_server_keys_once(); + ssl_socket* get_socket(int socket_id); int get_party_num(); - void require_ed25519_keys(); }; #endif diff --git a/Processor/FixInput.cpp b/Processor/FixInput.cpp index 8f4e6aa23..6820697ec 100644 --- a/Processor/FixInput.cpp +++ b/Processor/FixInput.cpp @@ -5,17 +5,18 @@ #include "FixInput.h" -const char* FixInput::NAME = "real number"; - -void FixInput::read(std::istream& in, const int* params) +template<> +void FixInput_::read(std::istream& in, const int* params) { -#ifdef LOW_PREC_INPUT double x; in >> x; items[0] = x * (1 << *params); -#else +} + +template<> +void FixInput_::read(std::istream& in, const int* params) +{ mpf_class x; in >> x; items[0] = x << *params; -#endif } diff --git a/Processor/FixInput.h b/Processor/FixInput.h index db3c1e02e..92a9ba84d 100644 --- a/Processor/FixInput.h +++ b/Processor/FixInput.h @@ -11,7 +11,8 @@ #include "Math/bigint.h" #include "Math/Integer.h" -class FixInput +template +class FixInput_ { public: const static int N_DEST = 1; @@ -20,13 +21,18 @@ class FixInput const static int TYPE = 1; -#ifdef LOW_PREC_INPUT - Integer items[N_DEST]; -#else - bigint items[N_DEST]; -#endif + T items[N_DEST]; void read(std::istream& in, const int* params); }; +template +const char* FixInput_::NAME = "real number"; + +#ifdef LOW_PREC_INPUT +typedef FixInput_ FixInput; +#else +typedef FixInput_ FixInput; +#endif + #endif /* PROCESSOR_FIXINPUT_H_ */ diff --git a/Processor/Input.h b/Processor/Input.h index 2b04890fd..266943b7f 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -57,6 +57,8 @@ class InputBase virtual T finalize_mine() = 0; virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; T finalize(int player, int n_bits = -1); + + void raw_input(SubProcessor& proc, const vector& args); }; template @@ -88,9 +90,6 @@ class Input : public InputBase T finalize_mine(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); - - void start(int player, int n_inputs); - void stop(int player, const vector& targets); }; #endif /* PROCESSOR_INPUT_H_ */ diff --git a/Processor/Input.hpp b/Processor/Input.hpp index cb50d6d05..319dc9bc9 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -100,7 +100,7 @@ template void Input::add_other(int player) { open_type t; - shares[player].push_back({}); + shares.at(player).push_back({}); prep.get_input(shares[player].back(), t, player); } @@ -131,12 +131,16 @@ void InputBase::exchange() } template -void Input::start(int player, int n_inputs) +void InputBase::raw_input(SubProcessor& proc, const vector& args) { - reset(player); - if (player == P.my_num()) + auto& P = proc.P; + reset_all(P); + + for (auto it = args.begin(); it != args.end();) { - for (int i = 0; i < n_inputs; i++) + int player = *it++; + it++; + if (player == P.my_num()) { clear t; try @@ -149,32 +153,20 @@ void Input::start(int player, int n_inputs) } add_mine(t); } - send_mine(); - } - else - { - for (int i = 0; i < n_inputs; i++) + else + { add_other(player); + } } -} -template -void Input::stop(int player, const vector& targets) -{ - assert(proc != 0); - if (P.my_num() == player) - for (unsigned int i = 0; i < targets.size(); i++) - proc->get_S_ref(targets[i]) = finalize_mine(); - else + timer.start(); + exchange(); + timer.stop(); + + for (auto it = args.begin(); it != args.end();) { - octetStream o; - this->timer.start(); - P.receive_player(player, o, true); - this->timer.stop(); - for (unsigned int i = 0; i < targets.size(); i++) - { - finalize_other(player, proc->get_S_ref(targets[i]), o); - } + int player = *it++; + proc.get_S_ref(*it++) = finalize(player); } } diff --git a/Processor/Instruction.h b/Processor/Instruction.h index a38470397..8c6005be3 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -115,6 +115,7 @@ enum INPUTFLOAT = 0xF1, INPUTMIXED = 0xF2, INPUTMIXEDREG = 0xF3, + RAWINPUT = 0xF4, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -248,6 +249,7 @@ enum GSTOPINPUT = 0x162, GREADSOCKETS = 0x164, GWRITESOCKETS = 0x166, + GRAWINPUT = 0x1F4, // Bitwise logic GANDC = 0x170, GXORC = 0x171, @@ -328,6 +330,8 @@ class BaseInstruction int get_opcode() const { return opcode; } int get_size() const { return size; } + // Reads a single instruction from the istream + void parse(istream& s, int inst_pos); void parse_operands(istream& s, int pos, int file_pos); bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } @@ -347,9 +351,6 @@ class DataPositions; class Instruction : public BaseInstruction { public: - // Reads a single instruction from the istream - void parse(istream& s, int inst_pos); - // Return whether usage is known bool get_offline_data_usage(DataPositions& usage); @@ -361,6 +362,5 @@ class Instruction : public BaseInstruction void execute(Processor& Proc) const; }; - #endif diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index ca3806755..0d265d10d 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -34,7 +34,7 @@ #include "Tools/callgrind.h" inline -void Instruction::parse(istream& s, int inst_pos) +void BaseInstruction::parse(istream& s, int inst_pos) { n=0; start.resize(0); r[0]=0; r[1]=0; r[2]=0; r[3]=0; @@ -224,7 +224,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: case DIGESTC: - case CONNECTIPV4: // write socket handle, read IPv4 address, portnum r[0]=get_int(s); r[1]=get_int(s); n = get_int(s); @@ -254,8 +253,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRINTREGB: case GPRINTREG: case LDINT: - case STARTINPUT: - case GSTARTINPUT: case STOPPRIVATEOUTPUT: case GSTOPPRIVATEOUTPUT: case INPUTMASK: @@ -310,6 +307,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case INPUTFLOAT: case INPUTMIXED: case INPUTMIXEDREG: + case RAWINPUT: + case GRAWINPUT: case TRUNC_PR: num_var_args = get_int(s); get_vector(num_var_args, start, s); @@ -336,7 +335,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case READSOCKETC: case READSOCKETS: case READSOCKETINT: - case READCLIENTPUBLICKEY: num_var_args = get_int(s) - 1; r[0] = get_int(s); get_vector(num_var_args, start, s); @@ -352,20 +350,18 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) r[1] = get_int(s); get_vector(num_var_args, start, s); break; + case CONNECTIPV4: + throw runtime_error("parties as clients not supported any more"); + case READCLIENTPUBLICKEY: case INITSECURESOCKET: case RESPSECURESOCKET: - num_var_args = get_int(s) - 1; - r[0] = get_int(s); - get_vector(num_var_args, start, s); - break; + throw runtime_error("VM-controlled encryption not supported any more"); // raw input + case STARTINPUT: + case GSTARTINPUT: case STOPINPUT: case GSTOPINPUT: - // subtract player number argument - num_var_args = get_int(s) - 1; - n = get_int(s); - get_vector(num_var_args, start, s); - break; + throw runtime_error("two-stage input not supported any more"); case GBITDEC: case GBITCOM: num_var_args = get_int(s) - 2; @@ -621,6 +617,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case INPUTB: skip = 4; offset = 3; + size_offset = -2; break; case ANDM: size = DIV_CEIL(n, 64); @@ -733,6 +730,7 @@ inline void Instruction::execute(Processor& Proc) const typedef typename sint::bit_type T; auto& processor = Proc.Procb; auto& instruction = *this; + auto& Ci = Proc.get_Ci(); // optimize some instructions switch (opcode) @@ -1292,17 +1290,11 @@ inline void Instruction::execute(Processor& Proc) const case INPUTMIXEDREG: sint::Input::input_mixed(Proc.Procp, start, size, true); return; - case STARTINPUT: - Proc.Procp.input.start(r[0],n); - break; - case GSTARTINPUT: - Proc.Proc2.input.start(r[0],n); - break; - case STOPINPUT: - Proc.Procp.input.stop(n,start); + case RAWINPUT: + Proc.Procp.input.raw_input(Proc.Procp, start); break; - case GSTOPINPUT: - Proc.Proc2.input.stop(n,start); + case GRAWINPUT: + Proc.Proc2.input.raw_input(Proc.Proc2, start); break; case ANDC: Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2])); @@ -1670,26 +1662,16 @@ inline void Instruction::execute(Processor& Proc) const ss << "No connection on port " << r[0] << endl; throw Processor_Error(ss.str()); } + if (Proc.P.my_num() == 0) + { + octetStream os; + os.store(int(sint::open_type::type_char())); + sint::open_type::specification(os); + os.Send(Proc.external_clients.get_socket(client_handle)); + } Proc.write_Ci(r[0], client_handle); break; } - case CONNECTIPV4: - { - // connect to server at port n + my_num() - int ipv4 = Proc.read_Ci(r[1]); - int server_handle = Proc.external_clients.connect_to_server(n, ipv4); - Proc.write_Ci(r[0], server_handle); - break; - } - case READCLIENTPUBLICKEY: - Proc.read_client_public_key(Proc.read_Ci(r[0]), start); - break; - case INITSECURESOCKET: - Proc.init_secure_socket(Proc.read_Ci(r[i]), start); - break; - case RESPSECURESOCKET: - Proc.resp_secure_socket(Proc.read_Ci(r[i]), start); - break; case READSOCKETINT: Proc.read_socket_ints(Proc.read_Ci(r[0]), start); break; diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 62fab3981..2c0b4f138 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -35,7 +35,6 @@ Machine::Machine(int my_number, Names& playerNames, // Set up the fields prep_dir_prefix = get_prep_dir(N.num_players(), opts.lgp, lg2); - char filename[2048]; bool read_mac_keys = false; sgf2n::clear::init_field(lg2); @@ -96,16 +95,7 @@ Machine::Machine(int my_number, Names& playerNames, sint::clear::next::template init(false); // Initialize the global memory - if (memtype.compare("new")==0) - {sprintf(filename, PREP_DIR "Player-Memory-P%d", my_number); - ifstream memfile(filename); - if (memfile.fail()) { throw file_error(filename); } - M2.Load_Memory(memfile); - Mp.Load_Memory(memfile); - Mi.Load_Memory(memfile); - memfile.close(); - } - else if (memtype.compare("old")==0) + if (memtype.compare("old")==0) { inpf.open(memory_filename(), ios::in | ios::binary); if (inpf.fail()) { throw file_error(memory_filename()); } diff --git a/Processor/Memory.h b/Processor/Memory.h index 28bc76594..4146ef78a 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -77,21 +77,6 @@ class Memory friend ostream& operator<< <>(ostream& s,const Memory& M); friend istream& operator>> <>(istream& s,Memory& M); - - /* This function loads a un-shared global memory from disk and - * produces the memory - * - * The global unshared memory is of the form - * sz <- Size - * n val <- Clear values - * n val <- Clear values - * -1 -1 <- End of clear values - * n val <- Shared values - * n val <- Shared values - * -1 -1 - */ - void Load_Memory(ifstream& inpf); - }; #endif diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 67ea21a8c..9686b83cc 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -106,40 +106,3 @@ istream& operator>>(istream& s,Memory& M) return s; } - - -template -void Memory::Load_Memory(ifstream& inpf) -{ - Memory& M = *this; - - int a; - typename T::clear val; - T S; - - inpf >> a; - M.resize_s(a); - inpf >> a; - M.resize_c(a); - - cerr << "Reading Clear Memory" << endl; - - // Read clear memory - inpf >> a; - val.input(inpf,true); - while (a!=-1) - { M.write_C(a,val); - inpf >> a; - val.input(inpf,true); - } - cerr << "Reading Shared Memory" << endl; - - // Read shared memory - inpf >> a; - S.input(inpf,true); - while (a!=-1) - { M.write_S(a,S); - inpf >> a; - S.input(inpf,true); - } -} diff --git a/Processor/Processor.h b/Processor/Processor.h index be2c59fe7..b57a22017 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -112,6 +112,10 @@ class ArithmeticProcessor : public ProcessorBase OnlineOptions opts; + ArithmeticProcessor() : + ArithmeticProcessor(OnlineOptions::singleton, BaseMachine::thread_num) + { + } ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), sent(0), rounds(0), opts(opts) {} @@ -217,12 +221,6 @@ class Processor : public ArithmeticProcessor // Access to external client sockets for reading clear/shared data void read_socket_ints(int client_id, const vector& registers); - // Setup client public key - void read_client_public_key(int client_id, const vector& registers); - void init_secure_socket(int client_id, const vector& registers); - void init_secure_socket_internal(int client_id, const vector& registers); - void resp_secure_socket(int client_id, const vector& registers); - void resp_secure_socket_internal(int client_id, const vector& registers); void write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, int socket_id, int message_type, const vector& registers); @@ -239,8 +237,6 @@ class Processor : public ArithmeticProcessor friend ostream& operator<<(ostream& s,const Processor& P); private: - void maybe_decrypt_sequence(int client_id); - void maybe_encrypt_sequence(int client_id); template friend class SPDZ; template friend class SubProcessor; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 69b76792a..4c37776de 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -3,7 +3,6 @@ #include "Processor/Processor.h" #include "Processor/Program.h" -#include "Networking/STS.h" #include "Protocols/fake-stuff.h" #include "GC/square64.h" @@ -68,7 +67,7 @@ Processor::Processor(int thread_num,Player& P, Procb(machine.bit_memories), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), privateOutput2(Proc2),privateOutputp(Procp), - external_clients(ExternalClients(P.my_num(), machine.prep_dir_prefix)), + external_clients(P.my_num(), machine.prep_dir_prefix), binary_file_io(Binary_File_IO()) { reset(program,0); @@ -222,7 +221,6 @@ void Processor::split(const Instruction& instruction) // RegType and SecrecyType determines how registers are read and the socket stream is packed. // If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to // determine the data structure being sent in a message. -// Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup. template void Processor::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, int socket_id, int message_type, const vector& registers) @@ -239,7 +237,11 @@ void Processor::write_socket(const RegType reg_type, const SecrecyT { if (reg_type == MODP && secrecy_type == SECRET) { // Send vector of secret shares and optionally macs - get_Sp_ref(registers[i]).pack(socket_stream, send_macs); + if (send_macs) + get_Sp_ref(registers[i]).pack(socket_stream); + else + get_Sp_ref(registers[i]).pack(socket_stream, + sint::get_rec_factor(P.my_num(), P.num_players())); } else if (reg_type == MODP && secrecy_type == CLEAR) { // Send vector of clear public field elements @@ -257,15 +259,7 @@ void Processor::write_socket(const RegType reg_type, const SecrecyT } } - // Apply DH Auth encryption if session keys have been created. - map::iterator it = external_clients.symmetric_client_keys.find(socket_id); - if (it != external_clients.symmetric_client_keys.end()) { - socket_stream.encrypt(it->second); - } - - // Apply STS commsec encryption if session keys have been created. try { - maybe_encrypt_sequence(socket_id); socket_stream.Send(external_clients.get_socket(socket_id)); } catch (bad_value& e) { @@ -282,7 +276,6 @@ void Processor::read_socket_ints(int client_id, const vector& int m = registers.size(); socket_stream.reset_write_head(); socket_stream.Receive(external_clients.get_socket(client_id)); - maybe_decrypt_sequence(client_id); for (int i = 0; i < m; i++) { int val; @@ -298,7 +291,6 @@ void Processor::read_socket_vector(int client_id, const vector int m = registers.size(); socket_stream.reset_write_head(); socket_stream.Receive(external_clients.get_socket(client_id)); - maybe_decrypt_sequence(client_id); for (int i = 0; i < m; i++) { get_Cp_ref(registers[i]).unpack(socket_stream); @@ -312,146 +304,13 @@ void Processor::read_socket_private(int client_id, const vector::iterator it = external_clients.symmetric_client_keys.find(client_id); - if (it != external_clients.symmetric_client_keys.end()) - { - socket_stream.decrypt(it->second); - } for (int i = 0; i < m; i++) { get_Sp_ref(registers[i]).unpack(socket_stream, read_macs); } } -// Read socket for client public key as 8 ints, calculate session key for client. -template -void Processor::read_client_public_key(int client_id, const vector& registers) { - - read_socket_ints(client_id, registers); - - // After read into registers, need to extract values - vector client_public_key (registers.size(), 0); - for(unsigned int i = 0; i < registers.size(); i++) { - client_public_key[i] = (int&)get_Ci_ref(registers[i]); - } - - external_clients.generate_session_key_for_client(client_id, client_public_key); -} - -template -void Processor::init_secure_socket_internal(int client_id, const vector& registers) { - external_clients.symmetric_client_commsec_send_keys.erase(client_id); - external_clients.symmetric_client_commsec_recv_keys.erase(client_id); - unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; - sts_msg1_t m1; - sts_msg2_t m2; - sts_msg3_t m3; - - external_clients.load_server_keys_once(); - external_clients.require_ed25519_keys(); - - // Validate inputs and state - if(registers.size() != 8) { - throw "Invalid call to init_secure_socket."; - } - - // Extract client long term public key into bytes - vector client_public_key (registers.size(), 0); - for(unsigned int i = 0; i < registers.size(); i++) { - client_public_key[i] = (int&)get_Ci_ref(registers[i]); - } - external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); - - // Start Station to Station Protocol - STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); - m1 = ke.send_msg1(); - socket_stream.reset_write_head(); - socket_stream.append(m1.bytes, sizeof m1.bytes); - socket_stream.Send(external_clients.get_socket(client_id)); - socket_stream.ReceiveExpected(external_clients.get_socket(client_id), - 96); - socket_stream.consume(m2.pubkey, sizeof m2.pubkey); - socket_stream.consume(m2.sig, sizeof m2.sig); - m3 = ke.recv_msg2(m2); - socket_stream.reset_write_head(); - socket_stream.append(m3.bytes, sizeof m3.bytes); - socket_stream.Send(external_clients.get_socket(client_id)); - - // Use results of STS to generate send and receive keys. - vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); - external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); -} - -template -void Processor::init_secure_socket(int client_id, const vector& registers) { - - try { - init_secure_socket_internal(client_id, registers); - } catch (char const *e) { - cerr << "STS initiator role failed with: " << e << endl; - throw Processor_Error("STS initiator failed"); - } -} - -template -void Processor::resp_secure_socket(int client_id, const vector& registers) { - try { - resp_secure_socket_internal(client_id, registers); - } catch (char const *e) { - cerr << "STS responder role failed with: " << e << endl; - throw Processor_Error("STS responder failed"); - } -} - -template -void Processor::resp_secure_socket_internal(int client_id, const vector& registers) { - external_clients.symmetric_client_commsec_send_keys.erase(client_id); - external_clients.symmetric_client_commsec_recv_keys.erase(client_id); - unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; - sts_msg1_t m1; - sts_msg2_t m2; - sts_msg3_t m3; - - external_clients.load_server_keys_once(); - external_clients.require_ed25519_keys(); - - // Validate inputs and state - if(registers.size() != 8) { - throw "Invalid call to init_secure_socket."; - } - vector client_public_key (registers.size(), 0); - for(unsigned int i = 0; i < registers.size(); i++) { - client_public_key[i] = (int&)get_Ci_ref(registers[i]); - } - external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); - - // Start Station to Station Protocol for the responder - STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); - socket_stream.reset_read_head(); - socket_stream.ReceiveExpected(external_clients.get_socket(client_id), - 32); - socket_stream.consume(m1.bytes, sizeof m1.bytes); - m2 = ke.recv_msg1(m1); - socket_stream.reset_write_head(); - socket_stream.append(m2.pubkey, sizeof m2.pubkey); - socket_stream.append(m2.sig, sizeof m2.sig); - socket_stream.Send(external_clients.get_socket(client_id)); - - socket_stream.ReceiveExpected(external_clients.get_socket(client_id), - 64); - socket_stream.consume(m3.bytes, sizeof m3.bytes); - ke.recv_msg3(m3); - - // Use results of STS to generate send and receive keys. - vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); - external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); - external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); -} // Read share data from a file starting at file_pos until registers filled. // file_pos_register is written with new file position (-1 is eof). @@ -722,26 +581,4 @@ ostream& operator<<(ostream& s,const Processor& P) return s; } -template -void Processor::maybe_decrypt_sequence(int client_id) -{ - map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id); - if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end()) - { - socket_stream.decrypt_sequence(&it_cs->second.first[0], it_cs->second.second); - it_cs->second.second++; - } -} - -template -void Processor::maybe_encrypt_sequence(int client_id) -{ - map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id); - if (it_cs != external_clients.symmetric_client_commsec_send_keys.end()) - { - socket_stream.encrypt_sequence(&it_cs->second.first[0], it_cs->second.second); - it_cs->second.second++; - } -} - #endif diff --git a/Programs/Circuits b/Programs/Circuits new file mode 160000 index 000000000..82dfda9d1 --- /dev/null +++ b/Programs/Circuits @@ -0,0 +1 @@ +Subproject commit 82dfda9d12b6fd2865f21f02809f6e7a5323d0be diff --git a/Programs/Source/aes_circuit.mpc b/Programs/Source/aes_circuit.mpc new file mode 100644 index 000000000..f8f84842f --- /dev/null +++ b/Programs/Source/aes_circuit.mpc @@ -0,0 +1,8 @@ +from circuit import Circuit +sb128 = sbits.get_type(128) +key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c) +plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a) +n = 1000 +aes128 = Circuit('aes_128') +ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n)) +ciphertexts.elements()[n - 1].reveal().print_reg() diff --git a/Programs/Source/bankers_bonus.mpc b/Programs/Source/bankers_bonus.mpc index e523fd9aa..a6fdc63c3 100644 --- a/Programs/Source/bankers_bonus.mpc +++ b/Programs/Source/bankers_bonus.mpc @@ -4,8 +4,6 @@ to deduce the maximum value from a range of integer input. Demonstrate clients external to computing parties supplying input and receiving an authenticated result. See bankers-bonus-client.cpp for client (and setup instructions). - - For an implementation with communications security see bankers_bonus_commsec.mpc. Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent before calculating the maximum. diff --git a/Programs/Source/bankers_bonus_commsec.mpc b/Programs/Source/bankers_bonus_commsec.mpc deleted file mode 100644 index 1d5f364a8..000000000 --- a/Programs/Source/bankers_bonus_commsec.mpc +++ /dev/null @@ -1,144 +0,0 @@ -# coding=latin1 - -""" - Solve Bankers bonus, aka Millionaires problem. - to deduce the maximum value from a range of integer input. - - Demonstrate clients external to computing parties supplying input and receiving - an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions). - - For an implementation without communications security see bankers_bonus.mpc. - - Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent - before calculating the maximum. - - Note each client connects in a single thread and so is potentially blocked. - - Each round / game will reset and so this runs indefinitiely. -""" - -from Compiler.types import sint, regint, Array, Matrix, MemValue -from Compiler.instructions import listen, acceptclientconnection -from Compiler.library import print_ln, do_while, if_e, else_, for_range -from Compiler.util import if_else - -PORTNUM = 14000 -MAX_NUM_CLIENTS = 8 -n_rounds = 0 - -if len(program.args) > 1: - n_rounds = int(program.args[1]) - -def accept_client(): - client_socket_id = regint() - acceptclientconnection(client_socket_id, PORTNUM) - last = regint.read_from_socket(client_socket_id) - - # Crypto setup - public_signing_key = regint.read_from_socket(client_socket_id, 8) - public_key = regint.read_client_public_key(client_socket_id) - regint.resp_secure_socket(client_socket_id,*public_signing_key) - - return client_socket_id, last - -def client_input(client_socket_id): - """ - Send share of random value, receive input and deduce share. - """ - - client_inputs = sint.receive_from_client(1, client_socket_id) - - return client_inputs[0] - - -def determine_winner(number_clients, client_values, client_ids): - """Work out and return client_id which corresponds to max client_value""" - max_value = Array(1, sint) - max_value[0] = client_values[0] - win_client_id = Array(1, sint) - win_client_id[0] = client_ids[0] - - @for_range(number_clients-1) - def loop_body(i): - # Is this client input a new maximum, will be sint(1) if true, else sint(0) - is_new_max = max_value[0] < client_values[i+1] - # Keep latest max_value - max_value[0] = if_else(is_new_max, client_values[i+1], max_value[0]) - # Keep current winning client id - win_client_id[0] = if_else(is_new_max, client_ids[i+1], win_client_id[0]) - - return win_client_id[0] - - -def write_winner_to_clients(sockets, number_clients, winning_client_id): - """Send share of winning client id to all clients who joined game.""" - - # Setup authenticate result using share of random. - # client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result - rnd_from_triple = sint.get_random_triple()[0] - auth_result = winning_client_id * rnd_from_triple - - @for_range(number_clients) - def loop_body(i): - sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result]) - - -def main(): - """Listen in while loop for players to join a game. - Once maxiumum reached or have notified that round finished, run comparison and return result.""" - # Start listening for client socket connections - listen(PORTNUM) - print_ln('Listening for client connections on base port %s', PORTNUM) - - def game_loop(_=None): - print_ln('Starting a new round of the game.') - - # Clients socket id (integer). - client_sockets = Array(MAX_NUM_CLIENTS, regint) - # Number of clients - number_clients = MemValue(regint(0)) - # Clients secret input. - client_values = Array(MAX_NUM_CLIENTS, sint) - # Client ids to identity client - client_ids = Array(MAX_NUM_CLIENTS, sint) - # Keep track of received inputs - seen = Array(MAX_NUM_CLIENTS, regint) - seen.assign_all(0) - - # Loop round waiting for each client to connect - @do_while - def client_connections(): - client_id, last = accept_client() - @if_(client_id >= MAX_NUM_CLIENTS) - def _(): - print_ln('client id too high') - crash() - client_sockets[client_id] = client_id - client_ids[client_id] = client_id - seen[client_id] = 1 - @if_(last == 1) - def _(): - number_clients.write(client_id + 1) - - return (sum(seen) < number_clients) + (number_clients == 0) - - @for_range(number_clients) - def _(client_id): - client_values[client_id] = client_input(client_id) - - winning_client_id = determine_winner(number_clients, client_values, client_ids) - - print_ln('Found winner, index: %s.', winning_client_id.reveal()) - - write_winner_to_clients(client_sockets, number_clients, winning_client_id) - - return True - - if n_rounds > 0: - print('run %d rounds' % n_rounds) - for_range(n_rounds)(game_loop) - else: - print('run forever') - do_while(game_loop) - -main() diff --git a/Programs/Source/idash_predict.mpc b/Programs/Source/idash_predict.mpc index 637ec3e9c..ed45ce87f 100644 --- a/Programs/Source/idash_predict.mpc +++ b/Programs/Source/idash_predict.mpc @@ -1,5 +1,6 @@ import ml import random +import re program.use_trunc_pr = True sfix.round_nearest = True @@ -10,6 +11,13 @@ cfix.set_precision(16, 31) N = int(program.args[1]) n_features = int(program.args[2]) +n_threads = None + +for arg in program.args: + m = re.match('n_threads=(.*)', arg) + if m: + n_threads = int(m.group(1)) + program.allocated_mem['s'] = 1 + n_features b = sfix.load_mem(0) @@ -24,13 +32,15 @@ dense.W.assign_vector(W) print_ln('b=%s W[-1]=%s', dense.b[0].reveal(), dense.W[n_features - 1][0][0].reveal()) -@for_range_opt(n_features) +@for_range_opt_multithread(n_threads, n_features) def _(i): @for_range_opt(N) def _(j): dense.X[j][0][i] = sfix.get_input_from(0) -dense.forward() +batch = regint.Array(N) +batch.assign(regint.inc(N)) +dense.forward(batch) print_str('predictions: ') diff --git a/Programs/Source/idash_train.mpc b/Programs/Source/idash_train.mpc index 25bcd693c..ccef647f8 100644 --- a/Programs/Source/idash_train.mpc +++ b/Programs/Source/idash_train.mpc @@ -1,28 +1,52 @@ import ml import random +import re program.use_trunc_pr = True -sfix.round_nearest = True sfix.set_precision(16, 31) cfix.set_precision(16, 31) sfloat.vlen = sfix.f -n_epochs = 200 +n_epochs = 100 n_normal = int(program.args[1]) n_pos = int(program.args[2]) n_features = int(program.args[3]) +if 'approx' in program.args: + approx = 3 +elif 'approx5' in program.args: + approx = 5 +else: + approx = False + +if 'split' in program.args: + program.use_split(3) + +n_threads = None + +for arg in program.args: + m = re.match('n_threads=(.*)', arg) + if m: + n_threads = int(m.group(1)) + debug = 'debug' in program.args +ml.set_n_threads(n_threads) + n_examples = n_normal + n_pos N = max(n_normal, n_pos) * 2 +if 'mini' in program.args: + batch_size = 32 +else: + batch_size = N + X_normal = sfix.Matrix(n_normal, n_features) X_pos = sfix.Matrix(n_pos, n_features) -@for_range_opt(n_features) +@for_range_opt_multithread(n_threads, n_features) def _(i): @for_range_opt(n_normal) def _(j): @@ -32,11 +56,11 @@ def _(i): X_pos[j][i] = sfix.get_input_from(0) dense = ml.Dense(N, n_features, 1) -layers = [dense, ml.Output(N)] +layers = [dense, ml.Output(N, approx=approx)] sgd = ml.SGD(layers, n_epochs, report_loss=debug) sgd.reset([X_normal, X_pos]) -sgd.run() +sgd.run(batch_size) if debug: @for_range(N) diff --git a/Programs/Source/logreg.mpc b/Programs/Source/logreg.mpc index 7ecb6ad3d..e7cb42624 100644 --- a/Programs/Source/logreg.mpc +++ b/Programs/Source/logreg.mpc @@ -32,8 +32,6 @@ sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False) sgd.reset([X_normal, X_pos]) sgd.run(batch_size=batch) -ml.approx_sigmoid.special = False - # @for_range(1000) # def _(i): # sgd.backward() diff --git a/Programs/Source/regression.mpc b/Programs/Source/regression.mpc index 297917dda..333aebe4e 100644 --- a/Programs/Source/regression.mpc +++ b/Programs/Source/regression.mpc @@ -82,7 +82,14 @@ if 'quant' in program.args: else: dense = ml.Dense(N, n_features, 1) -layers = [dense, ml.Output(N, debug=debug, approx='approx' in program.args)] +if 'approx' in program.args: + approx = 3 +elif 'approx5' in program.args: + approx = 5 +else: + approx = False + +layers = [dense, ml.Output(N, debug=debug, approx=approx)] Y = sfix.Array(n_examples) X = sfix.Matrix(n_examples, n_features) diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index bdf695d8b..73df18cb9 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -97,8 +97,3 @@ test(c[0], 0) test(c[1], 1) test(c[2], 1) test(c[3], 0) - -k = 41 -a = int(2.9142 * 2**k) -alpha = sbitint.get_type(2 * k)(a) -test(sbits.bit_compose((alpha >> 64).bit_decompose()[:64]), 0) diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 5f8cca5b5..9530c0f62 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -127,12 +127,14 @@ class Rep3Share : public FixedVec, public ShareInterface void pack(octetStream& os, bool full = true) const { - (void)full; - FixedVec::pack(os); + if (full) + FixedVec::pack(os); + else + (*this)[0].pack(os); } void unpack(octetStream& os, bool full = true) { - (void)full; + assert(full); FixedVec::unpack(os); } }; diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 7834c6d98..54119ef94 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -14,18 +14,14 @@ template class PrepLessInput : public InputBase { protected: - SubProcessor* processor; vector shares; size_t i_share; public: PrepLessInput(SubProcessor* proc) : - InputBase(proc ? proc->Proc : 0), processor(proc), i_share(0) {} + InputBase(proc ? proc->Proc : 0), i_share(0) {} virtual ~PrepLessInput() {} - void start(int player, int n_inputs); - void stop(int player, vector targets); - virtual void reset(int player) = 0; virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index da380a431..606270113 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -57,47 +57,6 @@ void ReplicatedInput::exchange() } } -template -void PrepLessInput::start(int player, int n_inputs) -{ - assert(processor != 0); - auto& proc = *processor; - reset(player); - - if (player == proc.P.my_num()) - { - for (int i = 0; i < n_inputs; i++) - { - typename T::clear t; - this->buffer.input(t); - add_mine(t); - } - - send_mine(); - } -} - -template -void PrepLessInput::stop(int player, vector targets) -{ - assert(processor != 0); - auto& proc = *processor; - if (proc.P.my_num() == player) - { - for (unsigned int i = 0; i < targets.size(); i++) - proc.get_S_ref(targets[i]) = finalize_mine(); - } - else - { - octetStream o; - this->timer.start(); - proc.P.receive_player(player, o, true); - this->timer.stop(); - for (unsigned int i = 0; i < targets.size(); i++) - finalize_other(player, proc.get_S_ref(targets[i]), o); - } -} - template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o, int n_bits) diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index d206ce262..8d2ef1537 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -57,6 +57,11 @@ class ShamirShare : public T, public ShareInterface return ShamirMachine::s().threshold; } + static T get_rec_factor(int i, int n) + { + return Protocol::get_rec_factor(i, n); + } + static ShamirShare constant(T value, int my_num, const T& alphai = {}) { return ShamirShare(value, my_num, alphai); @@ -135,14 +140,17 @@ class ShamirShare : public T, public ShareInterface throw runtime_error("never call this"); } - void pack(octetStream& os, bool full = true) const + void pack(octetStream& os, const T& rec_factor) const + { + (*this * rec_factor).pack(os); + } + void pack(octetStream& os) const { - (void)full; T::pack(os); } void unpack(octetStream& os, bool full = true) { - (void)full; + assert(full); T::unpack(os); } }; diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 5a37efc87..42f0e4006 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -24,6 +24,8 @@ class ShareInterface template static void split(vector, vector, int, T*, int, Player&) { throw runtime_error("split not implemented"); } + + static bool get_rec_factor(int, int) { return false; } }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/SohoPrep.h b/Protocols/SohoPrep.h index bd5ecc67a..4f9e45561 100644 --- a/Protocols/SohoPrep.h +++ b/Protocols/SohoPrep.h @@ -25,7 +25,9 @@ class SohoPrep : public SemiHonestRingPrep } void buffer_triples(); + void buffer_squares(); void buffer_inverses(); + void buffer_bits(); }; #endif /* PROTOCOLS_SOHOPREP_H_ */ diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index 81ef780e6..cd0826f08 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -70,9 +70,58 @@ void SohoPrep::buffer_triples() ci.element(i)}}); } +template +void SohoPrep::buffer_squares() +{ + + auto& proc = this->proc; + assert(proc != 0); + lock.lock(); + if (not setup) + { + PlainPlayer P(proc->P.N, T::clear::type_char()); + basic_setup(P); + } + lock.unlock(); + + Plaintext_ ai(setup->FieldD); + SeededPRNG G; + ai.randomize(G); + Ciphertext Ca = setup->pk.encrypt(ai); + octetStream os; + Ca.pack(os); + + for (int i = 1; i < proc->P.num_players(); i++) + { + proc->P.pass_around(os); + Ca.add<0>(os); + } + + Ciphertext Cc = Ca.mul(setup->pk, Ca); + Plaintext_ ci(setup->FieldD); + SimpleDistDecrypt dd(proc->P, *setup); + EncCommitBase_ EC; + dd.reshare(ci, Cc, EC); + + for (unsigned i = 0; i < ai.num_slots(); i++) + this->squares.push_back({{ai.element(i), ci.element(i)}}); +} + template void SohoPrep::buffer_inverses() { assert(this->proc != 0); ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); } + +template<> +void SohoPrep>::buffer_bits() +{ + buffer_bits_from_squares(*this); +} + +template<> +void SohoPrep>::buffer_bits() +{ + buffer_bits_without_check(); +} diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 19774aecd..e0c23a938 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -9,7 +9,6 @@ #include "MascotPrep.h" #include "RingOnlyPrep.h" #include "Spdz2kShare.h" -#include "GC/TinySecret.h" template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); diff --git a/README.md b/README.md index 72e2632be..c713899c3 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ compute the preprocessing time for a particular computation. required. This includes mainstream processors released 2014 or later. For older models you need to deactivate the respective extensions in the `ARCH` variable. - - To benchmark online-only protocols or Overdrive, add the following line at the top: `MY_CFLAGS = -DINSECURE` + - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to should be a local, unversioned directory to store preprocessing data (default is `Player-Data` in the current directory). - For homomorphic encryption, set `USE_NTL = 1`. @@ -240,6 +240,29 @@ al.](https://eprint.iacr.org/2020/338) You can activate them by using `-Y` instead of `-X`. Note that this also activates classic daBits when useful. +#### Bristol Fashion circuits + +Bristol Fashion is the name of a description format of binary circuits +used by +[SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA). You can +access such circuits from the high-level language if they are present +in `Programs/Circuits`. To run the AES-128 circuit provided with +SCALE-MAMBA, you can run the following: +``` +make Programs/Circuits +./compile.py aes_circuit +Scripts/semi.sh aes_circuit +``` +This downloads the circuit, compiles it to MP-SPDZ bytecode, and runs +it as semi-honest two-party computation 1000 times in parallel. It +should then output the AES test vector +`0x3ad77bb40d7a3660a89ecaf32466ef97`. You can run it with any other +protocol as well. + +See the +[documentation](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.circuit) +for further examples. + #### Compiling and running programs from external directories Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: diff --git a/Scripts/setup-clients.sh b/Scripts/setup-clients.sh new file mode 100755 index 000000000..48a406f0b --- /dev/null +++ b/Scripts/setup-clients.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +n=$1 + +test -e Player-Data || mkdir Player-Data + +echo Setting up SSL for $n parties + +for i in `seq 0 $[n-1]`; do + openssl req -newkey rsa -nodes -x509 -out Player-Data/C$i.pem -keyout Player-Data/C$i.key -subj "/CN=C$i" +done + +c_rehash Player-Data diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index 57ba11557..f59243d3e 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,6 +1,6 @@ #!/bin/bash -make -j4 ecdsa Fake-ECDSA.x +make -j4 ecdsa Fake-ECDSA.x secure.x run() { @@ -19,8 +19,11 @@ for i in rep mal-rep shamir mal-shamir; do run $i 2 done -./Fake-ECDSA.x - -for i in semi mascot fake-spdz; do +for i in semi mascot; do run $i 1 done + +if ! ./secure.x; then + ./Fake-ECDSA.x + run fake-spdz 1 +fi diff --git a/Tools/Config.cpp b/Tools/Config.cpp deleted file mode 100644 index 5a7287792..000000000 --- a/Tools/Config.cpp +++ /dev/null @@ -1,107 +0,0 @@ -// Client key file format: -// X25519 Public Key -// X25519 Secret Key -// Ed25519 Public Key -// Ed25519 Secret Key -// Server 1 X25519 Public Key -// Server 1 Ed25519 Public Key -// ... -// Server N Public Key -// Server N Ed25519 Public Key -// -// Player key file format: -// X25519 Public Key -// X25519 Secret Key -// Ed25519 Public Key -// Ed25519 Secret Key -// Number of clients [64 bit little endian] -// Client 1 X25519 Public Key -// Client 1 Ed25519 Public Key -// ... -// Client N X25519 Public Key -// Client N Ed25519 Public Key -// Number of servers [64 bit little endian] -// Server 1 X25519 Public Key -// Server 1 Ed25519 Public Key -// ... -// Server N X25519 Public Key -// Server N Ed25519 Public Key -#include "Tools/octetStream.h" -#include "Networking/Player.h" -#include "Math/gf2n.h" -#include "Config.h" -#include -#include -#include - -namespace Config { - static void output(const vector &vec, ofstream &of) - { - copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); - } - - void putW64le(ofstream &outf, uint64_t nr) - { - char buf[8]; - for(int i=0;i<8;i++) { - char byte = (uint8_t)(nr >> (i*8)); - buf[i] = (char)byte; - } - outf.write(buf,sizeof buf); - } - - void write_player_config_file(string config_dir - ,int player_number, public_key my_pub, secret_key my_priv - , public_signing_key my_signing_pub, secret_signing_key my_signing_priv - , vector client_pubs, vector client_signing_pubs - , vector player_pubs, vector player_signing_pubs) - { - stringstream filename; - filename << config_dir << "Player-SPDZ-Keys-P" << player_number; - ofstream outf(filename.str().c_str(), ios::out | ios::binary); - if (outf.fail()) - throw file_error(filename.str().c_str()); - if(crypto_box_PUBLICKEYBYTES != my_pub.size() || - crypto_box_SECRETKEYBYTES != my_priv.size() || - crypto_sign_PUBLICKEYBYTES != my_signing_pub.size() || - crypto_sign_SECRETKEYBYTES != my_signing_priv.size()) { - throw "Invalid key sizes"; - } else if(client_pubs.size() != client_signing_pubs.size()) { - throw "Incorrect number of client keys"; - } else if(player_pubs.size() != player_signing_pubs.size()) { - throw "Incorrect number of player keys"; - } else { - for(size_t i = 0; i < client_pubs.size(); i++) { - if(crypto_box_PUBLICKEYBYTES != client_pubs[i].size() || - crypto_sign_PUBLICKEYBYTES != client_signing_pubs[i].size()) { - throw "Incorrect size of client key."; - } - } - for(size_t i = 0; i < player_pubs.size(); i++) { - if(crypto_box_PUBLICKEYBYTES != player_pubs[i].size() || - crypto_sign_PUBLICKEYBYTES != player_signing_pubs[i].size()) { - throw "Incorrect size of player key."; - } - } - } - // Write public and secret X25519 keys - output(my_pub, outf); - output(my_priv, outf); - output(my_signing_pub, outf); - output(my_signing_priv, outf); - - putW64le(outf, (uint64_t)client_pubs.size()); - // Write all client public keys - for (size_t j = 0; j < client_pubs.size(); j++) { - output(client_pubs[j], outf); - output(client_signing_pubs[j], outf); - } - putW64le(outf, (uint64_t)player_pubs.size()); - for (size_t j = 0; j < player_pubs.size(); j++) { - output(player_pubs[j], outf); - output(player_signing_pubs[j], outf); - } - outf.flush(); - outf.close(); - } -} diff --git a/Tools/Config.h b/Tools/Config.h deleted file mode 100644 index d10523672..000000000 --- a/Tools/Config.h +++ /dev/null @@ -1,15 +0,0 @@ -#include "Tools/octetStream.h" -#include "Networking/Player.h" -#include -namespace Config { - typedef vector public_key; - typedef vector public_signing_key; - typedef vector secret_key; - typedef vector secret_signing_key; - void write_player_config_file(string config_dir - ,int player_number, public_key my_pub, secret_key my_priv - , public_signing_key my_signing_pub, secret_signing_key my_signing_priv - , vector client_pubs, vector client_signing_pubs - , vector player_pubs, vector player_signing_pubs); - void putW64le(ofstream &outf, uint64_t nr); -} diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index 55eb9fda6..5f4b073b6 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -105,9 +105,7 @@ bigint octetStream::check_sum(int req_bytes) const bool octetStream::equals(const octetStream& a) const { if (len!=a.len) { return false; } - for (size_t i=0; i& v) -{ - store(v.size()); - for (int x : v) - store(x); -} - - -void octetStream::get(vector& v) -{ - size_t size; - get(size); - v.resize(size); - for (int& x : v) - get(x); -} - - -// Construct the ciphertext as `crypto_secretbox(pt, counter||random)` -void octetStream::encrypt_sequence(const octet* key, uint64_t counter) -{ - octet nonce[crypto_secretbox_NONCEBYTES]; - int i; - int message_len_bytes = len; - randombytes_buf(nonce, sizeof nonce); - if(counter == UINT64_MAX) { - throw Processor_Error("Encryption would overflow counter. Too many messages."); - } else { - counter++; - } - for(i=0; i<8; i++) { - nonce[i] = uint8_t ((counter >> (8*i)) & 0xFF); - } - int ciphertext_len = message_len_bytes + crypto_secretbox_MACBYTES; - octet ciphertext[ciphertext_len]; - - crypto_secretbox_easy(ciphertext, data, message_len_bytes, nonce, key); - // append the ciphertext to an empty octet stream - reset_read_head(); - reset_write_head(); - append(ciphertext, ciphertext_len*sizeof(octet)); - // append the nonce - append(nonce, crypto_secretbox_NONCEBYTES * sizeof(octet)); -} - -void octetStream::decrypt_sequence(const octet* key, uint64_t counter) -{ - int ciphertext_len = len - crypto_box_NONCEBYTES; - const octet *nonce = data + ciphertext_len; - int i; - uint64_t recvCounter=0; - // Numbers are typically 24U + 16U so cast to int is safe. - if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) - { - throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); - } - for(i=7; i>=0; i--) { - recvCounter |= ((uint64_t) *(nonce + i)) << (i*8); - } - if(recvCounter != counter + 1) { - throw Processor_Error("Incorrect counter on stream. Possible MITM."); - } - if (crypto_secretbox_open_easy(data, data, ciphertext_len, nonce, key) != 0) - { - throw Processor_Error("octetStream decryption failed!"); - } - rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); - //prepare for unpack after decryption by resetting the read head - reset_read_head(); -} - -void octetStream::encrypt(const octet* key) -{ - octet nonce[crypto_secretbox_NONCEBYTES]; - randombytes_buf(nonce, sizeof nonce); - int message_len_bytes = len; - resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES); - - // Encrypt data in-place - crypto_secretbox_easy(data, data, message_len_bytes, nonce, key); - // Adjust length to account for MAC, then append nonce - len += crypto_secretbox_MACBYTES; - append(nonce, sizeof nonce); -} - -void octetStream::decrypt(const octet* key) -{ - int ciphertext_len = len - crypto_box_NONCEBYTES; - // Numbers are typically 24U + 16U so cast to int is safe. - if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) - { - throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); - } - if (crypto_secretbox_open_easy(data, data, ciphertext_len, data + ciphertext_len, key) != 0) - { - throw Processor_Error("octetStream decryption failed!"); - } - rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); -} - void octetStream::input(istream& s) { size_t size; diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 2b5a83427..3bd0b96a6 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -127,6 +127,8 @@ class octetStream template T get(); + template + void get(T& ans); // works for all statically allocated types template @@ -134,8 +136,10 @@ class octetStream template void unserialize(T& x) { consume((octet*)&x, sizeof(x)); } - void store(const vector& v); - void get(vector& v); + template + void store(const vector& v); + template + void get(vector& v); void consume(octetStream& s,size_t l) { s.resize(l); @@ -147,20 +151,8 @@ class octetStream void Send(T socket_num) const; template void Receive(T socket_num); - void ReceiveExpected(int socket_num, size_t expected); - - // In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES - // ciphertext = Enc(message) | MAC | counter - // - // This is much like 'encrypt' but uses a deterministic counter for the nonce, - // allowing enforcement of message order. - void encrypt_sequence(const octet* key, uint64_t counter); - void decrypt_sequence(const octet* key, uint64_t counter); - - // In-place authenticated encryption using sodium; key of length crypto_secretbox_KEYBYTES - // ciphertext = Enc(message) | MAC | nonce - void encrypt(const octet* key); - void decrypt(const octet* key); + template + void ReceiveExpected(T socket_num, size_t expected); void input(istream& s); void output(ostream& s); @@ -278,7 +270,8 @@ inline void octetStream::Receive(T socket_num) reset_read_head(); } -inline void octetStream::ReceiveExpected(int socket_num, size_t expected) +template +inline void octetStream::ReceiveExpected(T socket_num, size_t expected) { size_t nlen = 0; receive(socket_num, nlen, LENGTH_SIZE); @@ -310,11 +303,35 @@ T octetStream::get() return res; } +template +void octetStream::get(T& res) +{ + res.unpack(*this); +} + template<> inline int octetStream::get() { return get_int(sizeof(int)); } +template +void octetStream::store(const vector& v) +{ + store(v.size()); + for (auto& x : v) + store(x); +} + +template +void octetStream::get(vector& v) +{ + size_t size; + get(size); + v.resize(size); + for (auto& x : v) + get(x); +} + #endif diff --git a/Tools/random.cpp b/Tools/random.cpp index 12244b344..f61519f68 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -3,6 +3,7 @@ #include "Math/bigint.h" #include "Math/fixint.h" #include "Math/Z2k.hpp" +#include "Math/gfp.h" #include "Tools/Subroutines.h" #include #include diff --git a/Utils/client-setup.cpp b/Utils/client-setup.cpp deleted file mode 100644 index ee4034ab6..000000000 --- a/Utils/client-setup.cpp +++ /dev/null @@ -1,178 +0,0 @@ -// Preprocessing stage to: -// Create the public/private key pairs for each client -// Create the public/private key pairs for each spdz engine -// For each client store the client keys + all spdz engine public keys -// in a file named Client-Keys-C -// For each spdz engine store the spdz engine keys + all client public keys -// in a file named Player-SPDZ-Keys-P -// - -#include - -#include "Math/gf2n.h" -#include "Math/gfp.h" -#include "Protocols/Share.h" -#include "Math/Setup.h" -#include "Protocols/fake-stuff.h" -#include "Exceptions/Exceptions.h" - -#include "Math/Setup.h" -#include "Processor/Data_Files.h" -#include "Tools/mkpath.h" -#include "Tools/ezOptionParser.h" -#include "Tools/Config.h" - -#include -#include -using namespace std; - -static void output(const vector &vec, ofstream &of) -{ - copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); -} - -int main(int argc, const char** argv) -{ - ez::ezOptionParser opt; - - opt.syntax = "./client-setup.x [OPTIONS]\n"; - - opt.add( - "0", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Number of external clients (default: nplayers)", // Help description. - "-nc", // Flag token. - "--numclients" // Flag token. - ); - opt.add( - "128", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Bit length of GF(p) field (default: 128)", // Help description. - "-lgp", // Flag token. - "--lgp" // Flag token. - ); - opt.add( - to_string(gf2n::default_degree()).c_str(), // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - ("Bit length of GF(2^n) field (default: " + to_string(gf2n::default_degree()) + ")").c_str(), // Help description. - "-lg2", // Flag token. - "--lg2" // Flag token. - ); - opt.parse(argc, argv); - - string prep_data_prefix; - - string usage; - - int nplayers; - if (opt.firstArgs.size() == 2) - { - nplayers = atoi(opt.firstArgs[1]->c_str()); - } - else if (opt.lastArgs.size() == 1) - { - nplayers = atoi(opt.lastArgs[0]->c_str()); - } - else - { - cerr << "ERROR: invalid number of arguments\n"; - opt.getUsage(usage); - cout << usage; - return 1; - } - - int lg2, lgp, nclients; - opt.get("--numclients")->getInt(nclients); - if (nclients <= 0) - nclients = nplayers; - opt.get("--lgp")->getInt(lgp); - opt.get("--lg2")->getInt(lg2); - - cout << "nplayers = " << nplayers << endl; - cout << "nclients = " << nclients << endl; - cout << "lgp = " << lgp << endl; - cout << "lgp2 = " << lg2 << endl; - - prep_data_prefix = get_prep_dir(nplayers, lgp, lg2); - cout << "prep dir = " << prep_data_prefix << endl; - - vector client_publickeys; - vector client_secretkeys; - client_publickeys.resize(nclients); - client_secretkeys.resize(nclients); - for (int i = 0; i < nclients; i++) { - client_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); - client_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); - randombytes_buf(&client_secretkeys[i][0], client_secretkeys[i].size()); - crypto_scalarmult_base(&client_publickeys[i][0], &client_secretkeys[i][0]); - } - - vector client_signing_publickeys; - vector client_signing_secretkeys; - client_signing_publickeys.resize(nclients); - client_signing_secretkeys.resize(nclients); - for (int i = 0; i < nclients; i++) { - client_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); - client_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); - crypto_sign_keypair(&client_signing_publickeys[i][0], &client_signing_secretkeys[i][0]); - } - - vector server_publickeys; - vector server_secretkeys; - server_publickeys.resize(nplayers); - server_secretkeys.resize(nplayers); - for (int i = 0; i < nplayers; i++) { - server_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); - server_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); - randombytes_buf(&server_secretkeys[i][0], server_secretkeys[i].size()); - crypto_scalarmult_base(&server_publickeys[i][0], &server_secretkeys[i][0]); - } - vector server_signing_publickeys; - vector server_signing_secretkeys; - server_signing_publickeys.resize(nplayers); - server_signing_secretkeys.resize(nplayers); - for (int i = 0; i < nplayers; i++) { - server_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); - server_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); - crypto_sign_keypair(&server_signing_publickeys[i][0], &server_signing_secretkeys[i][0]); - } - - /* Write client files */ - for (int i = 0; i < nclients; i++) { - stringstream filename; - filename << prep_data_prefix << "Client-Keys-C" << i; - ofstream outf(filename.str().c_str()); - if (outf.fail()) - throw file_error(filename.str().c_str()); - // Write public key and secret key - output(client_publickeys[i],outf); - output(client_secretkeys[i],outf); - output(client_signing_publickeys[i],outf); - output(client_signing_secretkeys[i],outf); - int keycount = 2; - - // Write all spdz engine public keys - for (int j = 0; j < nplayers; j++) { - output(server_publickeys[j], outf); - output(server_signing_publickeys[j], outf); - keycount++; - } - outf.close(); - cout << "Wrote " << keycount << " keys to " << filename.str() << endl; - } - - /* Write spdz engine files */ - for (int i = 0; i < nplayers; i++) { - Config::write_player_config_file( prep_data_prefix, i - , server_publickeys[i], server_secretkeys[i] - , server_signing_publickeys[i], server_signing_secretkeys[i] - , client_publickeys, client_signing_publickeys - , server_publickeys, server_signing_publickeys); - } -} diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 000000000..27dfc6c85 --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,24 @@ +# C/C++ with GCC +# Build your C/C++ project with GCC using make. +# Add steps that publish test results, save build artifacts, deploy, and more: +# https://docs.microsoft.com/azure/devops/pipelines/apps/c-cpp/gcc + +trigger: +- master + +pool: + vmImage: 'ubuntu-latest' + +steps: + - script: | + bash -c "sudo apt-get install libsodium-dev libntl-dev yasm texinfo libboost-dev libboost-thread-dev python3-gmpy2 libcrypto++-dev python-networkx" + - script: | + make mpir + - script: + echo USE_NTL=1 >> CONFIG.mine + - script: | + make + - script: + Scripts/setup-ssl.sh + - script: + Scripts/test_tutorial.sh -C diff --git a/compile.py b/compile.py index 13fd33f7d..0be47586e 100755 --- a/compile.py +++ b/compile.py @@ -26,24 +26,22 @@ def main(): help="specify output file") parser.add_option("-a", "--asm-output", dest="asmoutfile", help="asm output file for debugging") - parser.add_option("-l", "--asm-input", action="store_true", dest="assemblymode", - help="old-style asm input") parser.add_option("-p", "--primesize", dest="param", default=-1, help="bit length of modulus") parser.add_option("-g", "--galoissize", dest="galois", default=40, help="bit length of Galois field") parser.add_option("-d", "--debug", action="store_true", dest="debug", help="keep track of trace for debugging") - parser.add_option("-e", "--emulate", action="store_true", dest="emulate", default=False, - help="emulate register values for debugging") parser.add_option("-c", "--comparison", dest="comparison", default="log", help="comparison variant: log|plain|inv|sinv") parser.add_option("-r", "--noreorder", dest="reorder_between_opens", action="store_false", default=True, help="don't attempt to place instructions between start/stop opens") - parser.add_option("-O", "--optimize-hard", action="store_false", + parser.add_option("-M", "--preserve-mem-order", action="store_true", dest="preserve_mem_order", default=False, - help="don't preserve order of memory instructions; possible loss of correctness") + help="preserve order of memory instructions; possible efficiency loss") + parser.add_option("-O", "--optimize-hard", action="store_true", + dest="optimize_hard", help="currently not in use") parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate", default=False, help="don't reallocate") parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open", @@ -79,10 +77,13 @@ def main(): parser.print_help() return + if options.optimize_hard: + print('Note that -O/--optimize-hard currently has no effect') + def compilation(): prog = Compiler.run(args, options, param=int(options.param), - merge_opens=options.merge_opens, emulate=options.emulate, - assemblymode=options.assemblymode, debug=options.debug) + merge_opens=options.merge_opens, + debug=options.debug) prog.write_bytes(options.outfile) if options.asmoutfile: diff --git a/doc/Compiler.rst b/doc/Compiler.rst index aaaeb53c6..4f9a44f0d 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -56,3 +56,9 @@ Compiler.ml module ------------------------- .. automodule:: Compiler.ml + +Compiler.circuit module +----------------------- + +.. automodule:: Compiler.circuit + :members: