Skip to content

Commit

Permalink
Write out MPSI
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 10, 2025
1 parent 97eb741 commit 1d62b70
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
1 change: 0 additions & 1 deletion oraqle/compiler/boolean/bool_neg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
raise NotImplementedError("TODO!")

def transform_to_reduced_boolean(self) -> ReducedBoolean:
Circuit([self._node]).to_pdf("debug.pdf")
return ReducedNeg(self._node.transform_to_reduced_boolean(), self._gf)

def transform_to_unreduced_boolean(self) -> UnreducedBoolean:
Expand Down
22 changes: 11 additions & 11 deletions oraqle/compiler/sets/bitset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# TODO: At some point we can implement __and__ for intersections
class BitSet(Node): # TODO: Node should become Boolean
class BitSet(Node):

def __getitem__(self, index) -> Boolean:
assert 0 <= index < len(self)
Expand Down Expand Up @@ -43,8 +43,8 @@ def _hash_name(self) -> str:
def _node_label(self) -> str:
return "Bitset"

def __init__(self, bits: Sequence[Boolean], gf: Type[FieldArray]):
super().__init__(gf)
def __init__(self, bits: Sequence[Boolean]):
super().__init__(bits[0]._gf)
self._bits = list(bits)

# def apply_function_to_operands(self, function: Callable[[Node], None]):
Expand Down Expand Up @@ -98,7 +98,7 @@ def _node_label(self) -> str:
return "Bitset"

def __init__(self, bitset: BitSetContainer):
super().__init__(gf)
super().__init__(bitset._gf)
self._bitset = bitset

# def apply_function_to_operands(self, function: Callable[[Node], None]):
Expand Down Expand Up @@ -131,7 +131,7 @@ def operation(self, operands: List[FieldArray]) -> FieldArray:

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: Consider changing the arithmetize type
return BitSetContainer([bit.transform_to_reduced_boolean().arithmetize(strategy) for bit in self._bitset._bits], self._gf) # type: ignore
return BitSetContainer([bit.transform_to_reduced_boolean().arithmetize(strategy) for bit in self._bitset._bits]) # type: ignore

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO")
Expand All @@ -152,7 +152,7 @@ def _node_label(self) -> str:
return "Bitset"

def __init__(self, bitset: BitSetContainer):
super().__init__(gf)
super().__init__(bitset._gf)
self._bitset = bitset

# def apply_function_to_operands(self, function: Callable[[Node], None]):
Expand Down Expand Up @@ -185,7 +185,7 @@ def operation(self, operands: List[FieldArray]) -> FieldArray:

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: Consider changing the arithmetize type
return BitSetContainer([bit.transform_to_inv_unreduced_boolean().arithmetize(strategy) for bit in self._bitset._bits], self._gf) # type: ignore
return BitSetContainer([bit.transform_to_inv_unreduced_boolean().arithmetize(strategy) for bit in self._bitset._bits]) # type: ignore

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO")
Expand Down Expand Up @@ -405,7 +405,7 @@ def _arithmetize_inner(self, strategy: str) -> BitSet:
# TODO: Assert all lengths are equal? Or that they map the same universe?
bit_count = len(self)
# TODO: After arithmetizing one of the bitsets, we can consider reusing that arithmetization for the rest (so not to run in O(n))
return BitSetContainer([all_(*(operand.node[i] for operand in self._operands)).arithmetize(strategy) for i in range(bit_count)], self._gf) # type: ignore
return BitSetContainer([all_(*(operand.node[i] for operand in self._operands)).arithmetize(strategy) for i in range(bit_count)]) # type: ignore

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError()
Expand All @@ -419,13 +419,13 @@ def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
gf = GF(11)

bits1 = [BooleanInput(f"b1_{i}", gf) for i in range(10)]
bitset1 = BitSetContainer(bits1, gf)
bitset1 = BitSetContainer(bits1)

bits2 = [BooleanInput(f"b2_{i}", gf) for i in range(10)]
bitset2 = BitSetContainer(bits2, gf)
bitset2 = BitSetContainer(bits2)

bits3 = [BooleanInput(f"b3_{i}", gf) for i in range(10)]
bitset3 = BitSetContainer(bits3, gf)
bitset3 = BitSetContainer(bits3)

final_bitset = BitSet.intersection(bitset1, bitset2, bitset3)

Expand Down
32 changes: 22 additions & 10 deletions oraqle/mpc/nodes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Set

from galois import GF

from oraqle.compiler.boolean.bool import ReducedBooleanInput
from oraqle.compiler.circuit import Circuit
from oraqle.compiler.nodes.abstract import Node
from oraqle.compiler.sets.bitset import BitSet, BitSetContainer
from oraqle.mpc.parties import PartyId


Expand All @@ -14,14 +19,21 @@ def __init__(self, node: Node, known_by: Set[PartyId], leakable_to: Set[PartyId]
self._computed_by = computed_by



if __name__ == "__main__":
# gf = GF(11)
# bits = [Input(f"b_{i}", gf) for i in range(10)]
# circuit = Circuit([BitSet(bits, gf).contains_element(3)]).to_pdf("debug.pdf")

# TODO: Encode bitset, whose inputs are known by party 1
# TODO: Encode bitset, whose inputs are known by party 2
# TODO: Intersect bitsets, computed by both 1 and 2
# TODO: Query bitset on inputs known by party 1
pass
# TODO: Add proper set intersection interface
gf = GF(11)

# TODO: Consider immediately creating a bitset (container) using bitset params/set params
party_bitsets = []
for party_id in range(5):
bits = [ReducedBooleanInput(f"b{party_id}_{i}", gf) for i in range(10)]
bitset = BitSetContainer(bits)
party_bitsets.append(bitset)

intersection = BitSet.intersection(*party_bitsets)

circuit = Circuit([intersection.contains_element(element) for element in [1, 4, 5, 9]])
circuit.to_pdf("debug.pdf")

arithmetic_circuit = circuit.arithmetize()
arithmetic_circuit.to_pdf("debug2.pdf")

0 comments on commit 1d62b70

Please sign in to comment.