Skip to content

Commit

Permalink
Add a cut to optimize solving
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 16, 2025
1 parent 8d4c54a commit e46297c
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 49 deletions.
5 changes: 3 additions & 2 deletions oraqle/compiler/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from oraqle.compiler.nodes.abstract import ArithmeticNode, ExtendedArithmeticNode, Node, ExtendedArithmeticCosts

from pysat.formula import WCNF, IDPool
from pysat.card import EncType

from oraqle.compiler.nodes.binary_arithmetic import Addition, Multiplication
from oraqle.compiler.nodes.unary_arithmetic import ConstantAddition, ConstantMultiplication
Expand Down Expand Up @@ -224,9 +225,9 @@ def __init__(self, outputs: List[ExtendedArithmeticNode]):
self._outputs = outputs
self._gf = outputs[0]._gf

def _add_constraints_minimize_cost_formulation(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int):
def _add_constraints_minimize_cost_formulation(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int, at_most_1_enc: Optional[int]):
for output in self._outputs:
output._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, party_count)
output._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, party_count, at_most_1_enc)
self._clear_cache()

def replace_randomness(self, party_count: int) -> "ExtendedArithmeticCircuit": # TODO: Think if this is the right type
Expand Down
10 changes: 5 additions & 5 deletions oraqle/compiler/nodes/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@

if TYPE_CHECKING:
from oraqle.compiler.boolean.bool import Boolean
from oraqle.mpc.protocol import Protocol

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import ArithmeticInstruction

from pysat.formula import IDPool, WCNF
from pysat.card import EncType

from oraqle.compiler.graphviz import DotFile

Expand Down Expand Up @@ -740,15 +740,15 @@ def _computational_cost(self, costs: Sequence[ExtendedArithmeticCosts], party_id
pass

@abstractmethod
def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], parties: int):
def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int, at_most_1_enc: Optional[int]):
pass

def _add_constraints_minimize_cost_formulation(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], parties: int):
def _add_constraints_minimize_cost_formulation(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int, at_most_1_enc: Optional[int]):
# TODO: We may not have to keep this cache, it might be done by apply_function_to_operands
if not self._added_constraints:
self._add_constraints_minimize_cost_formulation_inner(wcnf, id_pool, costs, parties)
self._add_constraints_minimize_cost_formulation_inner(wcnf, id_pool, costs, party_count, at_most_1_enc)
self._added_constraints = True
self.apply_function_to_operands(lambda node: node._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, parties)) # type: ignore
self.apply_function_to_operands(lambda node: node._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, party_count, at_most_1_enc)) # type: ignore

def replace_randomness(self, party_count: int) -> ExtendedArithmeticNode: # TODO: Think about types
if self._replace_randomness_cache is None:
Expand Down
11 changes: 9 additions & 2 deletions oraqle/compiler/nodes/binary_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from galois import FieldArray
from pysat.formula import WCNF, IDPool
from pysat.card import CardEnc, EncType

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import (
Expand Down Expand Up @@ -179,7 +180,7 @@ def create_instructions( # noqa: D102
def _expansion(self) -> Node:
raise NotImplementedError()

def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int):
def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int, at_most_1_enc: Optional[int]):
print("yeet", id(self), self)
for party_id in range(party_count):
# We can compute a value if we hold both inputs
Expand All @@ -194,7 +195,7 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# If we do not already know this value, then
if not PartyId(party_id) in self._known_by:
# We hold h if we compute it
sources = [-h, c]
sources = [c]

# Or when it is sent by another party
for other_party_id in range(party_count):
Expand All @@ -207,7 +208,13 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# Add the cost for receiving a value from other_party_id
wcnf.append([-received], weight=costs[party_id].receive(PartyId(other_party_id)))

# Add a cut: we only want to compute/receive from one source
if at_most_1_enc is not None:
at_most_1 = CardEnc.atmost(sources, encoding=at_most_1_enc, vpool=id_pool) # type: ignore
wcnf.extend(at_most_1)

# Add to WCNF
sources.append(-h)
wcnf.append(sources)

# We can only send if we hold the value
Expand Down
11 changes: 9 additions & 2 deletions oraqle/compiler/nodes/leafs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from galois import FieldArray
from pysat.formula import WCNF, IDPool
from pysat.card import CardEnc, EncType

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import ArithmeticInstruction, InputInstruction
Expand Down Expand Up @@ -47,7 +48,7 @@ def _expansion(self) -> Node:

class ExtendedArithmeticLeafNode(LeafNode, ExtendedArithmeticNode):

def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int):
def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], party_count: int, at_most_1_enc: Optional[int]):
print('leaf', self, self._known_by)

# We can only send if we hold the value
Expand All @@ -57,7 +58,7 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# If we do not already know this value, then
if PartyId(party_id) not in self._known_by:
# We hold h if we compute it
sources = [-h]
sources = []

# Or when it is sent by another party
for other_party_id in range(party_count):
Expand All @@ -70,7 +71,13 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# Add the cost for receiving a value from other_party_id
wcnf.append([-received], weight=costs[party_id].receive(PartyId(other_party_id)))

# Add a cut: we only want to compute/receive from one source
if at_most_1_enc is not None:
at_most_1 = CardEnc.atmost(sources, encoding=at_most_1_enc, vpool=id_pool) # type: ignore
wcnf.extend(at_most_1)

# Add to WCNF
sources.append(-h)
wcnf.append(sources)

for other_party_id in range(party_count):
Expand Down
11 changes: 9 additions & 2 deletions oraqle/compiler/nodes/unary_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from galois import FieldArray
from pysat.formula import WCNF, IDPool
from pysat.card import CardEnc, EncType

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import (
Expand All @@ -23,7 +24,7 @@ def __init__(self, node: ArithmeticNode, is_constant_mul: bool):
super().__init__(node)
self._is_constant_mul = is_constant_mul

def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], parties: int):
def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool: IDPool, costs: Sequence[ExtendedArithmeticCosts], parties: int, at_most_1_enc: Optional[int]):
# TODO: Consider reducing duplication with bivariate arithmetic

for party_id in range(parties):
Expand All @@ -37,7 +38,7 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# If we do not already know this value, then
if not PartyId(party_id) in self._known_by:
# We hold h if we compute it
sources = [-h, c]
sources = [c]

# Or when it is sent by another party
for other_party_id in range(parties):
Expand All @@ -50,7 +51,13 @@ def _add_constraints_minimize_cost_formulation_inner(self, wcnf: WCNF, id_pool:
# Add the cost for receiving a value from other_party_id
wcnf.append([-received], weight=costs[party_id].receive(PartyId(other_party_id)))

# Add a cut: we only want to compute/receive from one source
if at_most_1_enc is not None:
at_most_1 = CardEnc.atmost(sources, encoding=at_most_1_enc, vpool=id_pool) # type: ignore
wcnf.extend(at_most_1)

# Add to WCNF
sources.append(-h)
wcnf.append(sources)

# We can only send if we hold the value
Expand Down
52 changes: 29 additions & 23 deletions oraqle/mpc/compilation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import subprocess
from typing import Any, List, Sequence, Set
from typing import Any, List, Optional, Sequence, Set

from galois import GF

Expand All @@ -16,6 +16,9 @@
from oraqle.mpc.parties import PartyId

from pysat.formula import WCNF, IDPool
from pysat.card import EncType

import time


class LeaderCosts(ExtendedArithmeticCosts):
Expand Down Expand Up @@ -63,7 +66,7 @@ def create_star_topology_costs(leader_arithmetic_costs: ArithmeticCosts, other_a
return parties


def minimize_total_protocol_cost(circuit: Circuit, supported_multiplications: int, precomputed_randomness: bool, max_colluders: int, costs: Sequence[ExtendedArithmeticCosts]):
def minimize_total_protocol_cost(circuit: Circuit, supported_multiplications: int, precomputed_randomness: bool, max_colluders: int, costs: Sequence[ExtendedArithmeticCosts], at_most_1_enc: Optional[int]):
# FIXME: Add collusion threshold to signature

extended_arithmetic_circuit = circuit.arithmetize_extended()
Expand All @@ -80,7 +83,7 @@ def minimize_total_protocol_cost(circuit: Circuit, supported_multiplications: in
# Let each node add their own clauses to the formulation
wcnf = WCNF()
id_pool = IDPool()
processed_circuit._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, party_count)
processed_circuit._add_constraints_minimize_cost_formulation(wcnf, id_pool, costs, party_count, at_most_1_enc)
# TODO: We now assume that party 0 must learn the final results
for output in processed_circuit._outputs:
party_zero = 0
Expand Down Expand Up @@ -138,33 +141,33 @@ def minimize_total_protocol_cost(circuit: Circuit, supported_multiplications: in


if __name__ == "__main__":
party_count = 3
gf = GF(11)
# party_count = 3
# gf = GF(11)

# a = Input("a", gf, {PartyId(0)})
# a_neg = (a * Constant(gf(10))) + 1
# # a = Input("a", gf, {PartyId(0)})
# # a_neg = (a * Constant(gf(10))) + 1

b = Input("b", gf, {PartyId(1)})
b_neg = (b * Constant(gf(10))) + 1
# b = Input("b", gf, {PartyId(1)})
# b_neg = (b * Constant(gf(10))) + 1

c = Input("c", gf, {PartyId(2)})
c_neg = (c * KnownRandom(gf, {PartyId(0)})) + 1
# c = Input("c", gf, {PartyId(2)})
# c_neg = (c * KnownRandom(gf, {PartyId(0)})) + 1

circuit = Circuit([(b_neg + c_neg) * KnownRandom(gf, {PartyId(1)})])
circuit.to_pdf("simple.pdf")
circuit = circuit.arithmetize_extended()
circuit.to_pdf("simple-arith.pdf")
# circuit = Circuit([(b_neg + c_neg) * KnownRandom(gf, {PartyId(1)})])
# circuit.to_pdf("simple.pdf")
# circuit = circuit.arithmetize_extended()
# circuit.to_pdf("simple-arith.pdf")

leader_arithmetic_costs = ArithmeticCosts(1., float('inf'), 1., 100.)
other_arithmetic_costs = leader_arithmetic_costs * 10.
print(leader_arithmetic_costs)
print(other_arithmetic_costs)
all_costs = create_star_topology_costs(leader_arithmetic_costs, other_arithmetic_costs, 1000., 1000., 2000., 2000., party_count)
minimize_total_protocol_cost(circuit, 0, False, party_count - 1, all_costs)
# leader_arithmetic_costs = ArithmeticCosts(1., float('inf'), 1., 100.)
# other_arithmetic_costs = leader_arithmetic_costs * 10.
# print(leader_arithmetic_costs)
# print(other_arithmetic_costs)
# all_costs = create_star_topology_costs(leader_arithmetic_costs, other_arithmetic_costs, 1000., 1000., 2000., 2000., party_count)
# minimize_total_protocol_cost(circuit, 0, False, party_count - 1, all_costs)



exit(0)
# exit(0)
# TODO: Add proper set intersection interface
gf = GF(11)

Expand Down Expand Up @@ -193,4 +196,7 @@ def minimize_total_protocol_cost(circuit: Circuit, supported_multiplications: in
print(leader_arithmetic_costs)
print(other_arithmetic_costs)
all_costs = create_star_topology_costs(leader_arithmetic_costs, other_arithmetic_costs, 1000., 1000., 2000., 2000., party_count)
minimize_total_protocol_cost(circuit, 0, False, party_count - 1, all_costs)

t = time.monotonic()
minimize_total_protocol_cost(circuit, 0, False, party_count - 1, all_costs, at_most_1_enc=EncType.ladder)
print(time.monotonic() - t)
2 changes: 1 addition & 1 deletion oraqle/mpc/parties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


# TODO: For now, parties have no properties and their IDs are simply unique ints
# TODO: For now, parties have no properties and their IDs are simply unique ints in [0, party_count)

class PartyId(int):
pass
12 changes: 0 additions & 12 deletions oraqle/mpc/protocol.py

This file was deleted.

0 comments on commit e46297c

Please sign in to comment.