Skip to content

Commit

Permalink
Bristol Fashion.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Apr 2, 2020
1 parent cb8e46d commit 24926df
Show file tree
Hide file tree
Showing 98 changed files with 1,263 additions and 2,654 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "mpir"]
path = mpir
url = git://github.com/wbhart/mpir.git
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.1.6 (Apr 2, 2020)

- Bristol Fashion circuits
- Semi-honest computation with somewhat homomorphic encryption
- Use SSL for client connections
- Client facilities for all arithmetic protocols

## 0.1.5 (Mar 20, 2020)

- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338)
Expand Down
69 changes: 44 additions & 25 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,33 @@ class BinaryVectorInstruction(base.Instruction):
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))

class NonVectorInstruction(base.Instruction):
is_vec = lambda self: False

def __init__(self, *args, **kwargs):
assert(args[0].n <= args[0].unit)
super(NonVectorInstruction, self).__init__(*args, **kwargs)

class NonVectorInstruction1(base.Instruction):
is_vec = lambda self: False

def __init__(self, *args, **kwargs):
assert(args[1].n <= args[1].unit)
super(NonVectorInstruction1, self).__init__(*args, **kwargs)

class xors(BinaryVectorInstruction):
code = opcodes['XORS']
arg_format = tools.cycle(['int','sbw','sb','sb'])

class xorm(base.Instruction):
class xorm(NonVectorInstruction):
code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb']

class xorcb(base.Instruction):
class xorcb(NonVectorInstruction):
code = opcodes['XORCB']
arg_format = ['cbw','cb','cb']

class xorcbi(base.Instruction):
class xorcbi(NonVectorInstruction):
code = opcodes['XORCBI']
arg_format = ['cbw','cb','int']

Expand All @@ -101,67 +115,69 @@ class andm(BinaryVectorInstruction):
code = opcodes['ANDM']
arg_format = ['int','sbw','sb','cb']

class addcb(base.Instruction):
class addcb(NonVectorInstruction):
code = opcodes['ADDCB']
arg_format = ['cbw','cb','cb']

class addcbi(base.Instruction):
class addcbi(NonVectorInstruction):
code = opcodes['ADDCBI']
arg_format = ['cbw','cb','int']

class mulcbi(base.Instruction):
class mulcbi(NonVectorInstruction):
code = opcodes['MULCBI']
arg_format = ['cbw','cb','int']

class bitdecs(base.VarArgsInstruction):
class bitdecs(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITDECS']
arg_format = tools.chain(['sb'], itertools.repeat('sbw'))

class bitcoms(base.VarArgsInstruction):
class bitcoms(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITCOMS']
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))

class bitdecc(base.VarArgsInstruction):
class bitdecc(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITDECC']
arg_format = tools.chain(['cb'], itertools.repeat('cbw'))

class shrcbi(base.Instruction):
class shrcbi(NonVectorInstruction):
code = opcodes['SHRCBI']
arg_format = ['cbw','cb','int']

class shlcbi(base.Instruction):
class shlcbi(NonVectorInstruction):
code = opcodes['SHLCBI']
arg_format = ['cbw','cb','int']

class ldbits(base.Instruction):
class ldbits(NonVectorInstruction):
code = opcodes['LDBITS']
arg_format = ['sbw','i','i']

class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
base.VectorInstruction):
code = opcodes['LDMSB']
arg_format = ['sbw','int']

class stmsb(base.DirectMemoryWriteInstruction):
class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
code = opcodes['STMSB']
arg_format = ['sb','int']
# def __init__(self, *args, **kwargs):
# super(type(self), self).__init__(*args, **kwargs)
# import inspect
# self.caller = [frame[1:] for frame in inspect.stack()[1:]]

class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
base.VectorInstruction):
code = opcodes['LDMCB']
arg_format = ['cbw','int']

class stmcb(base.DirectMemoryWriteInstruction):
class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
code = opcodes['STMCB']
arg_format = ['cb','int']

class ldmsbi(base.ReadMemoryInstruction):
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
code = opcodes['LDMSBI']
arg_format = ['sbw','ci']

class stmsbi(base.WriteMemoryInstruction):
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
code = opcodes['STMSBI']
arg_format = ['sb','ci']

Expand All @@ -185,15 +201,15 @@ class stmsdci(base.WriteMemoryInstruction):
code = opcodes['STMSDCI']
arg_format = tools.cycle(['cb','cb'])

class convsint(base.Instruction):
class convsint(NonVectorInstruction1):
code = opcodes['CONVSINT']
arg_format = ['int','sbw','ci']

class convcint(base.Instruction):
class convcint(NonVectorInstruction):
code = opcodes['CONVCINT']
arg_format = ['cbw','ci']

class convcbit(base.Instruction):
class convcbit(NonVectorInstruction1):
code = opcodes['CONVCBIT']
arg_format = ['ciw','cb']

Expand Down Expand Up @@ -222,18 +238,19 @@ def __init__(self, *args, **kwargs):
super(split_class, self).__init__(*args, **kwargs)
assert (len(args) - 2) % args[0] == 0

class movsb(base.Instruction):
class movsb(NonVectorInstruction):
code = opcodes['MOVSB']
arg_format = ['sbw','sb']

class trans(base.VarArgsInstruction):
code = opcodes['TRANS']
is_vec = lambda self: True
def __init__(self, *args):
self.arg_format = ['int'] + ['sbw'] * args[0] + \
['sb'] * (len(args) - 1 - args[0])
super(trans, self).__init__(*args)

class bitb(base.Instruction):
class bitb(NonVectorInstruction):
code = opcodes['BITB']
arg_format = ['sbw']

Expand All @@ -245,20 +262,22 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = []
code = opcodes['INPUTB']
arg_format = tools.cycle(['p','int','int','sbw'])
is_vec = lambda self: True

class print_regb(base.IOInstruction):
class print_regb(base.VectorInstruction, base.IOInstruction):
code = opcodes['PRINTREGB']
arg_format = ['cb','i']
def __init__(self, reg, comment=''):
super(print_regb, self).__init__(reg, self.str_to_int(comment))

class print_reg_plainb(base.IOInstruction):
class print_reg_plainb(NonVectorInstruction, base.IOInstruction):
code = opcodes['PRINTREGPLAINB']
arg_format = ['cb']

class print_reg_signed(base.IOInstruction):
code = opcodes['PRINTREGSIGNED']
arg_format = ['int','cb']
is_vec = lambda self: True

class print_float_plainb(base.IOInstruction):
__slots__ = []
Expand Down
42 changes: 24 additions & 18 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def malloc(cls, size):
def n_elements():
return 1
@classmethod
def mem_size(cls):
return math.ceil(cls.n / cls.unit)
@classmethod
def load_mem(cls, address, mem_type=None, size=None):
if size not in (None, 1):
v = [cls.load_mem(address + i) for i in range(size)]
Expand All @@ -101,9 +104,8 @@ def __init__(self, value=None, n=None, size=None):
def copy(self):
return type(self)(n=instructions_base.get_global_vector_size())
def set_length(self, n):
if n > self.max_length:
print(self.max_length)
raise Exception('too long: %d' % n)
if n > self.n:
raise Exception('too long: %d/%d' % (n, self.n))
self.n = n
def set_size(self, size):
pass
Expand Down Expand Up @@ -135,7 +137,7 @@ def __repr__(self):
if self.n != None:
suffix = '%d' % self.n
if type(self).n != None and type(self).n != self.n:
suffice += '/%d' % type(self).n
suffix += '/%d' % type(self).n
else:
suffix = 'undef'
return '%s(%s)' % (super(bits, self).__repr__(), suffix)
Expand Down Expand Up @@ -237,6 +239,7 @@ class sbits(bits):
bitdec = inst.bitdecs
bitcom = inst.bitcoms
conv_regint = inst.convsint
one_cache = {}
@classmethod
def conv_regint_by_bit(cls, n, res, other):
tmp = cbits.get_type(n)()
Expand Down Expand Up @@ -285,14 +288,12 @@ def load_int(self, value):
% (value, self.n))
if self.n <= 32:
inst.ldbits(self, self.n, value)
elif self.n <= 64:
self.load_other(regint(value, size=1))
elif self.n <= 128:
lower = sbits.get_type(64)(value % 2**64)
upper = sbits.get_type(self.n - 64)(value >> 64)
self.mov(self, lower + (upper << 64))
else:
raise NotImplementedError('more than 128 bits wanted')
size = math.ceil(self.n / self.unit)
tmp = regint(size=size)
for i in range(size):
tmp[i].load_int((value >> (i * 64)) % 2**64)
self.load_other(tmp)
def load_other(self, other):
if isinstance(other, cbits) and self.n == other.n:
inst.convcbit2s(self.n, self, other)
Expand Down Expand Up @@ -393,11 +394,10 @@ def __invert__(self):
# res = type(self)(n=self.n)
# inst.nots(res, self)
# return res
if self.n == None or self.n > self.unit:
one = self.get_type(self.n)()
self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
else:
one = self.new(value=self.long_one(), n=self.n)
key = self.n, library.get_block()
if key not in self.one_cache:
self.one_cache[key] = self.new(value=self.long_one(), n=self.n)
one = self.one_cache[key]
return self + one
def __neg__(self):
return self
Expand Down Expand Up @@ -432,12 +432,12 @@ def popcnt(self):
@classmethod
def trans(cls, rows):
rows = list(rows)
if len(rows) == 1:
if len(rows) == 1 and rows[0].n <= rows[0].unit:
return rows[0].bit_decompose()
n_columns = rows[0].n
for row in rows:
assert(row.n == n_columns)
if n_columns == 1:
if n_columns == 1 and len(rows) <= cls.unit:
return [cls.bit_compose(rows)]
else:
res = [cls.new(n=len(rows)) for i in range(n_columns)]
Expand All @@ -452,6 +452,10 @@ def bit_adder(*args, **kwargs):
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
def to_sint(self, n_bits):
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
bits = sint(bits, size=n_bits)
return sint.bit_compose(bits)

class sbitvec(_vec):
@classmethod
Expand Down Expand Up @@ -524,6 +528,8 @@ def __iter__(self):
return iter(self.v)
def __len__(self):
return len(self.v)
def __getitem__(self, index):
return self.v[index]
@classmethod
def conv(cls, other):
return cls.from_vec(other.v)
Expand Down
Loading

0 comments on commit 24926df

Please sign in to comment.