diff --git a/oraqle/compiler/boolean/bool.py b/oraqle/compiler/boolean/bool.py index 8a1b8da..5468eb7 100644 --- a/oraqle/compiler/boolean/bool.py +++ b/oraqle/compiler/boolean/bool.py @@ -17,7 +17,7 @@ class Boolean(Node): def __invert__(self) -> "Node": from oraqle.compiler.boolean.bool_neg import Neg - return Neg(self, self._gf) + return Neg(self) def bool_or(self, other: "Boolean", flatten=True) -> "Boolean": """Performs an OR operation between `self` and `other`, possibly flattening the result into an OR operation between many operands. @@ -210,6 +210,7 @@ def cast_to_inv_unreduced_boolean(node: Node) -> InvUnreducedBoolean: return _cast_to(node, InvUnreducedBoolean) +# TODO: Think about the security of the transformations below class UnreducedBoolean(Boolean): def transform_to_reduced_boolean(self) -> ReducedBoolean: diff --git a/oraqle/compiler/boolean/bool_and.py b/oraqle/compiler/boolean/bool_and.py index cb22d5d..070b210 100644 --- a/oraqle/compiler/boolean/bool_and.py +++ b/oraqle/compiler/boolean/bool_and.py @@ -104,6 +104,7 @@ def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray: return a + b def _arithmetize_inner(self, strategy: str) -> Node: + Randomize! # TODO: We need to randomize (i.e. make it a Sum with random multiplicities) # TODO: Consider not supporting additions between Booleans unless they are cast to field elements return cast_to_inv_unreduced_boolean(sum_(*self._operands)).arithmetize(strategy) @@ -175,12 +176,12 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912 max_depth = popped.priority if total_sum is None: - total_sum = ReducedNeg(popped.item, self._gf) + total_sum = ReducedNeg(popped.item) else: - total_sum += ReducedNeg(popped.item, self._gf) + total_sum += ReducedNeg(popped.item) assert total_sum is not None - final_result = ReducedNeg(ReducedIsNonZero(total_sum, self._gf), self._gf).arithmetize(strategy) + final_result = ReducedNeg(ReducedIsNonZero(total_sum)).arithmetize(strategy) assert max_depth is not None heappush(queue, _PrioritizedItem(max_depth, final_result)) @@ -189,7 +190,7 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912 return heappop(queue).item dummy_node = ReducedBooleanInput("dummy_node", self._gf) - is_non_zero = ReducedIsNonZero(dummy_node, self._gf).arithmetize(strategy).to_arithmetic() + is_non_zero = ReducedIsNonZero(dummy_node).arithmetize(strategy).to_arithmetic() cost = is_non_zero.multiplicative_cost( 1.0 ) # FIXME: This needs to be the actual squaring cost @@ -202,12 +203,10 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912 return ReducedNeg( ReducedIsNonZero( Sum( - Counter({UnoverloadedWrapper(ReducedNeg(node.item, self._gf)): 1 for node in queue}), + Counter({UnoverloadedWrapper(ReducedNeg(node.item)): 1 for node in queue}), self._gf, ), - self._gf, ), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: diff --git a/oraqle/compiler/boolean/bool_neg.py b/oraqle/compiler/boolean/bool_neg.py index d7acdde..25937a8 100644 --- a/oraqle/compiler/boolean/bool_neg.py +++ b/oraqle/compiler/boolean/bool_neg.py @@ -35,7 +35,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF raise NotImplementedError("TODO!") def transform_to_reduced_boolean(self) -> ReducedBoolean: - return ReducedNeg(self._node.transform_to_reduced_boolean(), self._gf) + return ReducedNeg(self._node.transform_to_reduced_boolean()) def transform_to_unreduced_boolean(self) -> UnreducedBoolean: raise NotImplementedError("TODO!") diff --git a/oraqle/compiler/boolean/bool_or.py b/oraqle/compiler/boolean/bool_or.py index 6bf961b..295e2fb 100644 --- a/oraqle/compiler/boolean/bool_or.py +++ b/oraqle/compiler/boolean/bool_or.py @@ -117,7 +117,6 @@ def _arithmetize_inner(self, strategy: str) -> Node: }, # type: ignore self._gf, ), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: diff --git a/oraqle/compiler/comparison/comparison.py b/oraqle/compiler/comparison/comparison.py index 3ce3819..66975f8 100644 --- a/oraqle/compiler/comparison/comparison.py +++ b/oraqle/compiler/comparison/comparison.py @@ -78,7 +78,6 @@ def _arithmetize_inner(self, strategy: str) -> Node: return InUpperHalf( Subtraction(left.arithmetize(strategy), right.arithmetize(strategy), self._gf), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: @@ -100,7 +99,6 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF sub_front = InUpperHalf( Subtraction(left_node, right_node, self._gf), - self._gf, ).arithmetize_depth_aware(cost_of_squaring) front.add_front(sub_front) @@ -148,7 +146,7 @@ def _arithmetize_inner(self, strategy: str) -> Node: # Test whether left and right are in the same range same_range = (left_is_small & right_is_small) + ( - ReducedNeg(left_is_small, self._gf) & ReducedNeg(right_is_small, self._gf) + ReducedNeg(left_is_small) & ReducedNeg(right_is_small) ) # Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison @@ -156,7 +154,7 @@ def _arithmetize_inner(self, strategy: str) -> Node: result = same_range * comparison # Performs left < right when one if small and the other is large - right_is_larger = left_is_small & ReducedNeg(right_is_small, self._gf) + right_is_larger = left_is_small & ReducedNeg(right_is_small) result += right_is_larger return result.arithmetize(strategy) @@ -188,7 +186,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF # Test whether left and right are in the same range same_range = (left_is_small & right_is_small) + ( - ReducedNeg(left_is_small, self._gf) & ReducedNeg(right_is_small, self._gf) + ReducedNeg(left_is_small) & ReducedNeg(right_is_small) ) # Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison @@ -198,7 +196,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF result = same_range * comparison # Performs left < right when one if small and the other is large - right_is_larger = left_is_small & ReducedNeg(right_is_small, self._gf) + right_is_larger = left_is_small & ReducedNeg(right_is_small) result += right_is_larger front.add_front(result.arithmetize_depth_aware(cost_of_squaring)) @@ -233,7 +231,6 @@ def _arithmetize_inner(self, strategy: str) -> Node: less_than=not self._less_than, gf=self._gf, ), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: @@ -241,7 +238,6 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF SemiStrictComparison( self._left, self._right, less_than=not self._less_than, gf=self._gf ), - self._gf, ).arithmetize_depth_aware(cost_of_squaring) @@ -270,13 +266,11 @@ def _arithmetize_inner(self, strategy: str) -> Node: less_than=not self._less_than, gf=self._gf, ), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: return ReducedNeg( StrictComparison(self._left, self._right, less_than=not self._less_than, gf=self._gf), - self._gf, ).arithmetize_depth_aware(cost_of_squaring) @@ -328,7 +322,6 @@ def _arithmetize_inner(self, strategy: str) -> Node: Subtraction( self._left.arithmetize(strategy), self._right.arithmetize(strategy), self._gf ), - self._gf, ).arithmetize(strategy) def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: diff --git a/oraqle/compiler/comparison/equality.py b/oraqle/compiler/comparison/equality.py index 63450cf..0a1d87c 100644 --- a/oraqle/compiler/comparison/equality.py +++ b/oraqle/compiler/comparison/equality.py @@ -5,8 +5,9 @@ from oraqle.compiler.arithmetic.subtraction import Subtraction from oraqle.compiler.boolean.bool import Boolean, InvUnreducedBoolean, ReducedBoolean, UnreducedBoolean, cast_to_reduced_boolean from oraqle.compiler.boolean.bool_neg import Neg, ReducedNeg -from oraqle.compiler.nodes.abstract import CostParetoFront, Node +from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node from oraqle.compiler.nodes.binary_arithmetic import CommutativeBinaryNode +from oraqle.compiler.nodes.extended import Random, Reveal from oraqle.compiler.nodes.leafs import Input from oraqle.compiler.nodes.univariate import UnivariateNode @@ -32,6 +33,17 @@ def _operation_inner(self, input: FieldArray) -> FieldArray: def _arithmetize_inner(self, strategy: str) -> Node: return self.arithmetize_all_representations(strategy) + def _arithmetize_extended_inner(self) -> ExtendedArithmeticNode: + arithmetic = self.arithmetize_all_representations("best-effort") + extended_arithmetic = Reveal(Random(self._gf) * self._node.arithmetize()) == 0 + + if metric(extended_arithmetic) < metric(arithmetic): + return extended_arithmetic + + return arithmetic + + Keep in mind that this is just an example; it is not necessary for the thesis!! + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: raise NotImplementedError("TODO!") diff --git a/oraqle/compiler/comparison/in_upper_half.py b/oraqle/compiler/comparison/in_upper_half.py index 5cdc1f6..294a618 100644 --- a/oraqle/compiler/comparison/in_upper_half.py +++ b/oraqle/compiler/comparison/in_upper_half.py @@ -250,7 +250,7 @@ def test_evaluate_mod7(): # noqa: D103 gf = GF(7) x = Input("x", gf) - node = InUpperHalf(x, gf) + node = InUpperHalf(x) for i in range(3): assert node.evaluate({"x": gf(i)}) == gf(0) @@ -264,7 +264,7 @@ def test_evaluate_arithmetized_mod7(): # noqa: D103 gf = GF(7) x = Input("x", gf) - node = InUpperHalf(x, gf).arithmetize("best-effort") + node = InUpperHalf(x).arithmetize("best-effort") node.clear_cache(set()) for i in range(3): diff --git a/oraqle/compiler/graphviz.py b/oraqle/compiler/graphviz.py index c2dac7c..91a3bcb 100644 --- a/oraqle/compiler/graphviz.py +++ b/oraqle/compiler/graphviz.py @@ -33,7 +33,7 @@ def add_link(self, from_id: int, to_id: int, **kwargs): self._links.append((from_id, to_id, kwargs)) def add_cluster(self, ids: Sequence[int], **kwargs): - """Adds a cluster containings the nodes with the given IDs. The keyword arguments are directly put into the DOT file.""" + """Adds a cluster containing the nodes with the given IDs. The keyword arguments are directly put into the DOT file.""" assert -1 not in ids self._clusters.append((ids, kwargs)) diff --git a/oraqle/compiler/nodes/abstract.py b/oraqle/compiler/nodes/abstract.py index a30afd3..4b11fbd 100644 --- a/oraqle/compiler/nodes/abstract.py +++ b/oraqle/compiler/nodes/abstract.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from oraqle.compiler.boolean.bool import Boolean + from oraqle.mpc.protocol import Protocol from oraqle.compiler.graphviz import DotFile from oraqle.compiler.instructions import ArithmeticInstruction @@ -302,6 +303,7 @@ def __init__(self, gf: Type[FieldArray]): self._instruction_cache: Optional[int] = None self._arithmetic_cache: Optional[ArithmeticNode] = None self._parent_count_cache: Optional[int] = None + self._arithmetize_extended_cache: Optional[ExtendedArithmeticNode] = None self._hash = None @@ -336,6 +338,7 @@ def clear_cache(self, already_cleared: Set[int]): self._instruction_cache: Optional[int] = None self._arithmetic_cache: Optional[ArithmeticNode] = None self._parent_count_cache: Optional[int] = None + self._arithmetize_extended_cache: Optional[ExtendedArithmeticNode] = None self._hash = None @@ -440,6 +443,19 @@ def to_arithmetic(self) -> "ArithmeticNode": raise Exception( f"This node does not have a direct arithmetic equivalent: {self}. Consider first calling `arithmetize`." ) + + def arithmetize_extended(self) -> "ExtendedArithmeticNode": + """Arithmetizes this node as an extended arithmetic circuit, which includes random and reveal nodes. + + The default implementation simply calls arithmetize, because every arithmetic circuit is also an extended arithmetic circuit. + + Returns: + An ExtendedArithmeticNode that computes this Node. + """ + # TODO: propagate known by? + # TODO: Add leak to? E.g. by adding reveal after it. + + return self.arithmetize("best-effort").to_arithmetic() def add(self, other: "Node", flatten=True) -> "Node": """Performs a summation between `self` and `other`, possibly flattening any sums. @@ -620,10 +636,21 @@ def __eq__(self, other) -> bool: return False return self.node.is_equivalent(other.node) + + +class ExtendedArithmeticNode(Node): + + @abstractmethod + def operands(self) -> List["ExtendedArithmeticNode"]: + """Returns the operands (children) of this node. The list can be empty. The nodes MUST be extended arithmetic nodes.""" + + @abstractmethod + def set_operands(self, operands: List["ExtendedArithmeticNode"]): + """Overwrites the operands of this node. The nodes MUST be extended arithmetic nodes.""" # TODO: Do we need a separate class to distinguish nodes from arithmetic nodes (which only have arithmetic operands)? -class ArithmeticNode(Node): +class ArithmeticNode(ExtendedArithmeticNode): """Extension of Node to indicate that this is a node permitted in a purely arithmetic circuit (with binary additions and multiplications). The ArithmeticNode 'mixin' must always come before the base class in the class declaration. diff --git a/oraqle/compiler/nodes/extended.py b/oraqle/compiler/nodes/extended.py new file mode 100644 index 0000000..6640ad6 --- /dev/null +++ b/oraqle/compiler/nodes/extended.py @@ -0,0 +1,69 @@ +import random +from typing import List, Type +from galois import FieldArray +from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node +from oraqle.compiler.nodes.leafs import LeafNode +from oraqle.compiler.nodes.univariate import UnivariateNode + + +class Reveal(UnivariateNode, ExtendedArithmeticNode): + + @property + def _node_shape(self) -> str: + return "circle" + + @property + def _hash_name(self) -> str: + return "reveal" + + @property + def _node_label(self) -> str: + return "Reveal" + + def _arithmetize_inner(self, strategy: str) -> Node: + raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.") + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.") + + def _operation_inner(self, input: FieldArray) -> FieldArray: + return input + + # FIXME: Overload operators to create *plaintext* operations + + +class Random(LeafNode, ExtendedArithmeticNode): + + @property + def _node_shape(self) -> str: + return "circle" + + @property + def _hash_name(self) -> str: + return "random" + + @property + def _node_label(self) -> str: + return "Random" + + def __init__(self, gf: type[FieldArray]): + self._hash = hash(random.randbytes(16)) # TODO: Not neat + super().__init__(gf) + + def __hash__(self) -> int: + return self._hash + + def _arithmetize_inner(self, strategy: str) -> Node: + raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.") + + def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: + raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.") + + def is_equivalent(self, other: Node) -> bool: + if not isinstance(other, self.__class__): + return False + + return self._hash == other._hash + + def operation(self, operands: List[FieldArray]) -> FieldArray: + return self._gf.Random() diff --git a/oraqle/compiler/nodes/fixed.py b/oraqle/compiler/nodes/fixed.py index acf34ac..5e051f1 100644 --- a/oraqle/compiler/nodes/fixed.py +++ b/oraqle/compiler/nodes/fixed.py @@ -4,7 +4,7 @@ from galois import FieldArray -from oraqle.compiler.nodes.abstract import CostParetoFront, Node +from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node class FixedNode[N: Node](Node): @@ -95,6 +95,17 @@ def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront: def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront: pass + def arithmetize_extended(self) -> ExtendedArithmeticNode: + # TODO: Handle constants similarly as above + if self._arithmetize_extended_cache is None: + self._arithmetize_extended_cache = self._arithmetize_extended_inner() + + return self._arithmetize_extended_cache + + def _arithmetize_extended_inner(self) -> "ExtendedArithmeticNode": + # TODO: Check if this is a good default implementation, add documentation + return self.arithmetize("best-effort").to_arithmetic() + class BinaryNode(FixedNode): """A node with two operands.""" diff --git a/oraqle/compiler/nodes/leafs.py b/oraqle/compiler/nodes/leafs.py index 211d9ad..8572535 100644 --- a/oraqle/compiler/nodes/leafs.py +++ b/oraqle/compiler/nodes/leafs.py @@ -9,8 +9,8 @@ from oraqle.compiler.nodes.fixed import FixedNode -class ArithmeticLeafNode(FixedNode, ArithmeticNode): - """An ArithmeticLeafNode is an ArithmeticNode with no inputs.""" +class LeafNode(FixedNode): + """A LeafNode is a FixedNode with no inputs.""" def operands(self) -> List[Node]: # noqa: D102 return [] @@ -35,6 +35,10 @@ def multiplications(self) -> Set[int]: # noqa: D102 def squarings(self) -> Set[int]: # noqa: D102 return set() + + +class ArithmeticLeafNode(LeafNode, ArithmeticNode): + """An ArithmeticLeafNode is an ArithmeticNode with no inputs.""" # TODO: Merge ArithmeticInput and Input using multiple inheritance diff --git a/oraqle/compiler/nodes/plaintext_ops.py b/oraqle/compiler/nodes/plaintext_ops.py new file mode 100644 index 0000000..3cdbe18 --- /dev/null +++ b/oraqle/compiler/nodes/plaintext_ops.py @@ -0,0 +1,5 @@ + + +Create PlainFunction that performs any named Python (?) code on Revealed inputs. We can maybe create a Plain class to flag revealed values. +We should then also overload the operators on Plain to reflect that (which create functions). +When we finally generate code, we can let people implement those named functions themselves. diff --git a/oraqle/compiler/nodes/univariate.py b/oraqle/compiler/nodes/univariate.py index ca255ba..be2996f 100644 --- a/oraqle/compiler/nodes/univariate.py +++ b/oraqle/compiler/nodes/univariate.py @@ -19,11 +19,11 @@ class UnivariateNode[N: Node](FixedNode[N]): def _node_shape(self) -> str: """Graphviz node shape.""" - def __init__(self, node: N, gf: Type[FieldArray]): + def __init__(self, node: N): """Initialize a univariate node.""" self._node = node assert not isinstance(node, Constant) - super().__init__(gf) + super().__init__(node._gf) def operands(self) -> List["N"]: # noqa: D102 diff --git a/oraqle/mpc/nodes.py b/oraqle/mpc/nodes.py index b2e73ed..965a940 100644 --- a/oraqle/mpc/nodes.py +++ b/oraqle/mpc/nodes.py @@ -2,7 +2,7 @@ from galois import GF -from oraqle.compiler.boolean.bool import ReducedBooleanInput +from oraqle.compiler.boolean.bool import ReducedBooleanInput, _cast_to from oraqle.compiler.circuit import Circuit from oraqle.compiler.nodes.abstract import Node from oraqle.compiler.sets.bitset import BitSet, BitSetContainer @@ -10,13 +10,27 @@ # FIXME: all the inputs must also be mpc nodes... -class MpcNode: +class MpcNode(Node): - def __init__(self, node: Node, known_by: Set[PartyId], leakable_to: Set[PartyId], computed_by: Set[PartyId]): - self._node = node - self._known_by = known_by - self._leakable_to = leakable_to - self._computed_by = computed_by + _known_by: Set[PartyId] + _leakable_to: Set[PartyId] + _computed_by: Set[PartyId] + + # def __init__(self, node: Node, known_by: Set[PartyId], leakable_to: Set[PartyId], computed_by: Set[PartyId]): + # self._node = node + # self._known_by = known_by + # self._leakable_to = leakable_to # TODO: Leakable to should always be a superset of known_by + # self._computed_by = computed_by # TODO: This is an inconvient interface + + +def to_mpc(node: Node, known_by: Set[PartyId], leakable_to: Set[PartyId], computed_by: Set[PartyId]) -> MpcNode: + result = _cast_to(node, MpcNode) + + result._known_by = known_by + result._leakable_to = leakable_to + result._computed_by = computed_by + + return result if __name__ == "__main__": @@ -26,14 +40,17 @@ def __init__(self, node: Node, known_by: Set[PartyId], leakable_to: Set[PartyId] # TODO: Consider immediately creating a bitset (container) using bitset params/set params party_bitsets = [] for party_id in range(5): + #bits = [to_mpc(ReducedBooleanInput(f"b{party_id}_{i}", gf), {PartyId(party_id)}, {PartyId(party_id)}, {PartyId(i) for i in range(5)}) for i in range(10)] bits = [ReducedBooleanInput(f"b{party_id}_{i}", gf) for i in range(10)] bitset = BitSetContainer(bits) party_bitsets.append(bitset) intersection = BitSet.intersection(*party_bitsets) - circuit = Circuit([intersection.contains_element(element) for element in [1, 4, 5, 9]]) + circuit = Circuit([intersection.contains_element(element) for element in [1, 4, 5, 9]]) # TODO: Currently we output to party 1 circuit.to_pdf("debug.pdf") arithmetic_circuit = circuit.arithmetize() arithmetic_circuit.to_pdf("debug2.pdf") + + Forget this whole mpc folder for now: I want to first generate extended arithmetic circuits, which can then be assigned (maybe even using a MILP) diff --git a/oraqle/mpc/protocol.py b/oraqle/mpc/protocol.py new file mode 100644 index 0000000..456f375 --- /dev/null +++ b/oraqle/mpc/protocol.py @@ -0,0 +1,12 @@ +from typing import List +from oraqle.compiler.nodes.abstract import Node +from oraqle.mpc.parties import PartyId + + +class Protocol: + + def __init__(self, party_count) -> None: + self._operations: List[List[Node]] = [[]] * party_count + + def assign_operation(self, to_party: PartyId, node: Node): + self._operations[to_party - 1].append(node)