Skip to content

Commit

Permalink
Fixed- and floating-point inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Jul 11, 2019
1 parent 8f0f25f commit 5ef7058
Show file tree
Hide file tree
Showing 34 changed files with 304 additions and 83 deletions.
18 changes: 10 additions & 8 deletions BMR/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,14 +532,16 @@ class InputAccess
GC::Secret<EvalRegister>& dest;
GC::Processor<GC::Secret<EvalRegister> >& processor;
ProgramParty& party;
InputArgs args;

public:
InputAccess(int from, int n_bits, GC::Secret<EvalRegister>& dest,
InputAccess(const InputArgs& args,
GC::Processor<GC::Secret<EvalRegister> >& processor) :
from(from), n_bits(n_bits), dest(dest), processor(processor), party(
ProgramParty::s())
from(args.from + 1), n_bits(args.n_bits), dest(
processor.S[args.dest]), processor(processor), party(
ProgramParty::s()), args(args)
{
if (from > party.get_n_parties() or n_bits > 100)
if (from > unsigned(party.get_n_parties()) or n_bits > 100)
throw runtime_error("invalid input parameters");
}

Expand All @@ -550,7 +552,7 @@ class InputAccess
party.load_wire(reg);
if (from == party.get_id())
{
long long in = processor.get_input(n_bits);
long long in = processor.get_input(args.params);
for (size_t i = 0; i < n_bits; i++)
{
auto& reg = dest.get_reg(i);
Expand Down Expand Up @@ -599,10 +601,10 @@ void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
vector<octetStream> oss(party.get_n_parties());
octetStream& my_os = oss[party.get_id() - 1];
vector<InputAccess> accesses;
for (size_t j = 0; j < args.size(); j += 3)
InputArgList a(args);
for (auto x : a)
{
accesses.push_back(
{ args[j] + 1, args[j + 1], processor.S[args[j + 2]], processor });
accesses.push_back({x , processor});
}
for (auto& access : accesses)
access.prepare_masks(my_os);
Expand Down
9 changes: 5 additions & 4 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using namespace std;
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2n.h"
#include "Tools/FlexBuffer.h"

Expand Down Expand Up @@ -261,8 +262,8 @@ class ProgramRegister : public Phase, public Register

// most BMR phases don't need actual input
template<class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
{ (void)processor; return T::input(from, 0, n_bits); }
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{ (void)processor; return T::input(args.from + 1, 0, args.n_bits); }

char get_output() { return 0; }

Expand Down Expand Up @@ -314,9 +315,9 @@ class EvalRegister : public ProgramRegister
static void inputb(T& processor, const vector<int>& args);

template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{
(void)from, (void)processor, (void)n_bits;
(void)processor, (void)args;
throw runtime_error("use EvalRegister::inputb()");
}

Expand Down
2 changes: 1 addition & 1 deletion Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class reveal(base.Instruction):
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = []
code = opcodes['INPUTB']
arg_format = tools.cycle(['p','int','sbw'])
arg_format = tools.cycle(['p','int','int','sbw'])

class print_reg(base.IOInstruction):
code = base.opcodes['PRINTREG']
Expand Down
7 changes: 6 additions & 1 deletion Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_input_from(cls, player, n_bits=None):
if n_bits is None:
n_bits = cls.n
res = cls()
inst.inputb(player, n_bits, res)
inst.inputb(player, n_bits, 0, res)
return res
# compatiblity to sint
get_raw_input_from = get_input_from
Expand Down Expand Up @@ -648,6 +648,11 @@ def load_mem(cls, address, size=None):
return sbitfixvec._new(sbitintvec(v))
else:
return super(sbitfix, cls).load_mem(address)
@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
def __xor__(self, other):
return type(self)(self.v ^ other.v)
def __mul__(self, other):
Expand Down
8 changes: 8 additions & 0 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def dependency_graph(self, merge_classes):
last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
last_text_input = None

depths = [0] * len(block.instructions)
self.depths = depths
Expand Down Expand Up @@ -471,6 +472,13 @@ def keep_order(instr, n, t, arg_index=None):
else:
write(reg, n)

# will be merged
if isinstance(instr, TextInputInstruction):
if last_text_input is not None and \
type(block.instructions[last_text_input]) is not type(instr):
add_edge(last_text_input, n)
last_text_input = n

if isinstance(instr, merge_classes):
open_nodes.add(n)
G.add_node(n, merges=[])
Expand Down
24 changes: 23 additions & 1 deletion Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def has_var_args(self):

@base.gf2n
@base.vectorize
class asm_input(base.VarArgsInstruction):
class asm_input(base.TextInputInstruction):
r""" Receive input from player $p$ and put in register $s_i$. """
__slots__ = []
code = base.opcodes['INPUT']
Expand All @@ -870,6 +870,28 @@ def add_usage(self, req_node):
def execute(self):
self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P

class inputfix(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFIX']
arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp'

def add_usage(self, req_node):
for player in self.args[2::3]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())

class inputfloat(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFLOAT']
arg_format = tools.cycle(['sw', 'sw', 'sw', 'sw', 'int', 'p'])
field_type = 'modp'

def add_usage(self, req_node):
for player in self.args[5::6]:
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())

@base.gf2n
class startinput(base.RawInputInstruction):
r""" Receive inputs from player $p$. """
Expand Down
15 changes: 10 additions & 5 deletions Compiler/instructions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
PREP = 0x57,
# Input
INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
Expand Down Expand Up @@ -592,6 +594,10 @@ def __str__(self):
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'

class VarArgsInstruction(Instruction):
def has_var_args(self):
return True

###
### Basic arithmetic
###
Expand Down Expand Up @@ -692,6 +698,10 @@ class PublicFileIOInstruction(DoNotEliminateInstruction):
""" Instruction to reads/writes public information from/to files. """
__slots__ = []

class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """
__slots__ = []

###
### Data access instructions
###
Expand Down Expand Up @@ -784,11 +794,6 @@ def get_relative_jump(self):
return self.args[self.jump_arg]


class VarArgsInstruction(Instruction):
def has_var_args(self):
return True


class CISC(Instruction):
"""
Base class for a CISC instruction.
Expand Down
4 changes: 3 additions & 1 deletion Compiler/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def __init__(self, args, options, param=-1, assemblymode=False):
Compiler.instructions.dotprods_class, \
Compiler.instructions.gdotprods_class, \
Compiler.instructions.asm_input_class, \
Compiler.instructions.gasm_input_class]
Compiler.instructions.gasm_input_class,
Compiler.instructions.inputfix,
Compiler.instructions.inputfloat]
import Compiler.GC.instructions as gc
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]
Expand Down
15 changes: 15 additions & 0 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2408,6 +2408,12 @@ class sfix(_fix):
int_type = sint
clear_type = cfix

@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inputfix(v, cls.f, player)
return cls._new(v)

@classmethod
def coerce(cls, other):
return parse_type(other)
Expand Down Expand Up @@ -2728,6 +2734,15 @@ def convert_float(v, vlen, plen):
'with %d exponent bits' % (vv, plen))
return v, p, z, s

@classmethod
def get_input_from(cls, player):
v = sint()
p = sint()
z = sint()
s = sint()
inputfloat(v, p, z, s, cls.vlen, player)
return cls(v, p, z, s)

@vectorize_init
@read_mem_value
def __init__(self, v, p=None, z=None, s=None, size=None):
Expand Down
11 changes: 7 additions & 4 deletions GC/ArgTuples.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ArgIter
ArgIter<T> operator++()
{
auto res = it;
it += 3;
it += T::n;
return res;
}

Expand Down Expand Up @@ -64,16 +64,19 @@ class ArgList
class InputArgs
{
public:
static const int n = 3;
static const int n = 4;

int from;
int n_bits;
int& n_bits;
int& n_shift;
int params[2];
int dest;

InputArgs(vector<int>::const_iterator it)
InputArgs(vector<int>::const_iterator it) : n_bits(params[0]), n_shift(params[1])
{
from = *it++;
n_bits = *it++;
n_shift = *it++;
dest = *it++;
}
};
Expand Down
4 changes: 2 additions & 2 deletions GC/FakeSecret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ void FakeSecret::trans(Processor<FakeSecret>& processor, int n_outputs,
processor.S[args[i]] = square.rows[i];
}

FakeSecret FakeSecret::input(int from, GC::Processor<FakeSecret>& processor, int n_bits)
FakeSecret FakeSecret::input(GC::Processor<FakeSecret>& processor, const InputArgs& args)
{
return input(from, processor.get_input(n_bits), n_bits);
return input(args.from, processor.get_input(args.params), args.n_bits);
}

FakeSecret FakeSecret::input(int from, const int128& input, int n_bits)
Expand Down
3 changes: 2 additions & 1 deletion GC/FakeSecret.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"

#include "Math/gf2nlong.h"

Expand Down Expand Up @@ -62,7 +63,7 @@ class FakeSecret

static void convcbit(Integer& dest, const Clear& source) { dest = source; }

static FakeSecret input(int from, GC::Processor<FakeSecret>& processor, int n_bits);
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, const int128& input, int n_bits);

FakeSecret() : a(0) {}
Expand Down
4 changes: 2 additions & 2 deletions GC/Processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Processor : public ::ProcessorBase
public:
static int check_args(const vector<int>& args, int n);

static void check_input(long long in, int n_bits);
static void check_input(bigint in, int n_bits);

Machine<T>& machine;

Expand All @@ -61,7 +61,7 @@ class Processor : public ::ProcessorBase
template<class U>
void reset(const U& program);

long long get_input(int n_bits, bool interactive = false);
long long get_input(const int* params, bool interactive = false);

void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); }
void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); }
Expand Down
22 changes: 12 additions & 10 deletions GC/Processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using namespace std;

#include "GC/Program.h"
#include "Access.h"
#include "Processor/FixInput.h"

namespace GC
{
Expand Down Expand Up @@ -50,27 +51,29 @@ void Processor<T>::reset(const U& program)
}

template<class T>
inline long long GC::Processor<T>::get_input(int n_bits, bool interactive)
inline long long GC::Processor<T>::get_input(const int* params, bool interactive)
{
long long res = ProcessorBase::get_input(interactive);
bigint res = ProcessorBase::get_input<FixInput>(interactive, &params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits);
return res;
assert(n_bits <= 64);
return res.get_si();
}

template<class T>
void GC::Processor<T>::check_input(long long in, int n_bits)
void GC::Processor<T>::check_input(bigint in, int n_bits)
{
auto test = in >> (n_bits - 1);
if (n_bits == 1)
{
if (not (in == 0 or in == 1))
throw runtime_error("input not a bit: " + to_string(in));
throw runtime_error("input not a bit: " + in.get_str());
}
else if (not (test == 0 or test == -1))
{
throw runtime_error(
"input too large for a " + std::to_string(n_bits)
+ "-bit signed integer: " + to_string(in));
+ "-bit signed integer: " + in.get_str());
}
}

Expand Down Expand Up @@ -182,11 +185,10 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
template <class T>
void Processor<T>::input(const vector<int>& args)
{
check_args(args, 3);
for (size_t i = 0; i < args.size(); i += 3)
InputArgList a(args);
for (auto x : a)
{
int n_bits = args[i + 1];
S[args[i+2]] = T::input(args[i] + 1, *this, n_bits);
S[x.dest] = T::input(*this, x);
#ifdef DEBUG_INPUT
cout << "input to " << args[i+2] << "/" << &S[args[i+2]] << endl;
#endif
Expand Down
Loading

0 comments on commit 5ef7058

Please sign in to comment.