Skip to content

Commit

Permalink
Add notes
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 10, 2025
1 parent 1d62b70 commit 83d4d0f
Show file tree
Hide file tree
Showing 16 changed files with 188 additions and 39 deletions.
3 changes: 2 additions & 1 deletion oraqle/compiler/boolean/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Boolean(Node):
def __invert__(self) -> "Node":
from oraqle.compiler.boolean.bool_neg import Neg

return Neg(self, self._gf)
return Neg(self)

def bool_or(self, other: "Boolean", flatten=True) -> "Boolean":
"""Performs an OR operation between `self` and `other`, possibly flattening the result into an OR operation between many operands.
Expand Down Expand Up @@ -210,6 +210,7 @@ def cast_to_inv_unreduced_boolean(node: Node) -> InvUnreducedBoolean:
return _cast_to(node, InvUnreducedBoolean)


# TODO: Think about the security of the transformations below
class UnreducedBoolean(Boolean):

def transform_to_reduced_boolean(self) -> ReducedBoolean:
Expand Down
13 changes: 6 additions & 7 deletions oraqle/compiler/boolean/bool_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
return a + b

def _arithmetize_inner(self, strategy: str) -> Node:
Randomize!
# TODO: We need to randomize (i.e. make it a Sum with random multiplicities)
# TODO: Consider not supporting additions between Booleans unless they are cast to field elements
return cast_to_inv_unreduced_boolean(sum_(*self._operands)).arithmetize(strategy)
Expand Down Expand Up @@ -175,12 +176,12 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
max_depth = popped.priority

if total_sum is None:
total_sum = ReducedNeg(popped.item, self._gf)
total_sum = ReducedNeg(popped.item)
else:
total_sum += ReducedNeg(popped.item, self._gf)
total_sum += ReducedNeg(popped.item)

assert total_sum is not None
final_result = ReducedNeg(ReducedIsNonZero(total_sum, self._gf), self._gf).arithmetize(strategy)
final_result = ReducedNeg(ReducedIsNonZero(total_sum)).arithmetize(strategy)

assert max_depth is not None
heappush(queue, _PrioritizedItem(max_depth, final_result))
Expand All @@ -189,7 +190,7 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
return heappop(queue).item

dummy_node = ReducedBooleanInput("dummy_node", self._gf)
is_non_zero = ReducedIsNonZero(dummy_node, self._gf).arithmetize(strategy).to_arithmetic()
is_non_zero = ReducedIsNonZero(dummy_node).arithmetize(strategy).to_arithmetic()
cost = is_non_zero.multiplicative_cost(
1.0
) # FIXME: This needs to be the actual squaring cost
Expand All @@ -202,12 +203,10 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
return ReducedNeg(
ReducedIsNonZero(
Sum(
Counter({UnoverloadedWrapper(ReducedNeg(node.item, self._gf)): 1 for node in queue}),
Counter({UnoverloadedWrapper(ReducedNeg(node.item)): 1 for node in queue}),
self._gf,
),
self._gf,
),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
Expand Down
2 changes: 1 addition & 1 deletion oraqle/compiler/boolean/bool_neg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
raise NotImplementedError("TODO!")

def transform_to_reduced_boolean(self) -> ReducedBoolean:
return ReducedNeg(self._node.transform_to_reduced_boolean(), self._gf)
return ReducedNeg(self._node.transform_to_reduced_boolean())

def transform_to_unreduced_boolean(self) -> UnreducedBoolean:
raise NotImplementedError("TODO!")
Expand Down
1 change: 0 additions & 1 deletion oraqle/compiler/boolean/bool_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def _arithmetize_inner(self, strategy: str) -> Node:
}, # type: ignore
self._gf,
),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
Expand Down
15 changes: 4 additions & 11 deletions oraqle/compiler/comparison/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def _arithmetize_inner(self, strategy: str) -> Node:

return InUpperHalf(
Subtraction(left.arithmetize(strategy), right.arithmetize(strategy), self._gf),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
Expand All @@ -100,7 +99,6 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF

sub_front = InUpperHalf(
Subtraction(left_node, right_node, self._gf),
self._gf,
).arithmetize_depth_aware(cost_of_squaring)

front.add_front(sub_front)
Expand Down Expand Up @@ -148,15 +146,15 @@ def _arithmetize_inner(self, strategy: str) -> Node:

# Test whether left and right are in the same range
same_range = (left_is_small & right_is_small) + (
ReducedNeg(left_is_small, self._gf) & ReducedNeg(right_is_small, self._gf)
ReducedNeg(left_is_small) & ReducedNeg(right_is_small)
)

# Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison
comparison = SemiStrictComparison(left, right, less_than=True, gf=self._gf)
result = same_range * comparison

# Performs left < right when one if small and the other is large
right_is_larger = left_is_small & ReducedNeg(right_is_small, self._gf)
right_is_larger = left_is_small & ReducedNeg(right_is_small)
result += right_is_larger

return result.arithmetize(strategy)
Expand Down Expand Up @@ -188,7 +186,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF

# Test whether left and right are in the same range
same_range = (left_is_small & right_is_small) + (
ReducedNeg(left_is_small, self._gf) & ReducedNeg(right_is_small, self._gf)
ReducedNeg(left_is_small) & ReducedNeg(right_is_small)
)

# Performs left < right on the reduced inputs, note that if both are in the upper half the difference is still small enough for a semi-comparison
Expand All @@ -198,7 +196,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
result = same_range * comparison

# Performs left < right when one if small and the other is large
right_is_larger = left_is_small & ReducedNeg(right_is_small, self._gf)
right_is_larger = left_is_small & ReducedNeg(right_is_small)
result += right_is_larger

front.add_front(result.arithmetize_depth_aware(cost_of_squaring))
Expand Down Expand Up @@ -233,15 +231,13 @@ def _arithmetize_inner(self, strategy: str) -> Node:
less_than=not self._less_than,
gf=self._gf,
),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
return ReducedNeg(
SemiStrictComparison(
self._left, self._right, less_than=not self._less_than, gf=self._gf
),
self._gf,
).arithmetize_depth_aware(cost_of_squaring)


Expand Down Expand Up @@ -270,13 +266,11 @@ def _arithmetize_inner(self, strategy: str) -> Node:
less_than=not self._less_than,
gf=self._gf,
),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
return ReducedNeg(
StrictComparison(self._left, self._right, less_than=not self._less_than, gf=self._gf),
self._gf,
).arithmetize_depth_aware(cost_of_squaring)


Expand Down Expand Up @@ -328,7 +322,6 @@ def _arithmetize_inner(self, strategy: str) -> Node:
Subtraction(
self._left.arithmetize(strategy), self._right.arithmetize(strategy), self._gf
),
self._gf,
).arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
Expand Down
14 changes: 13 additions & 1 deletion oraqle/compiler/comparison/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from oraqle.compiler.arithmetic.subtraction import Subtraction
from oraqle.compiler.boolean.bool import Boolean, InvUnreducedBoolean, ReducedBoolean, UnreducedBoolean, cast_to_reduced_boolean
from oraqle.compiler.boolean.bool_neg import Neg, ReducedNeg
from oraqle.compiler.nodes.abstract import CostParetoFront, Node
from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node
from oraqle.compiler.nodes.binary_arithmetic import CommutativeBinaryNode
from oraqle.compiler.nodes.extended import Random, Reveal
from oraqle.compiler.nodes.leafs import Input
from oraqle.compiler.nodes.univariate import UnivariateNode

Expand All @@ -32,6 +33,17 @@ def _operation_inner(self, input: FieldArray) -> FieldArray:
def _arithmetize_inner(self, strategy: str) -> Node:
return self.arithmetize_all_representations(strategy)

def _arithmetize_extended_inner(self) -> ExtendedArithmeticNode:
arithmetic = self.arithmetize_all_representations("best-effort")
extended_arithmetic = Reveal(Random(self._gf) * self._node.arithmetize()) == 0

if metric(extended_arithmetic) < metric(arithmetic):
return extended_arithmetic

return arithmetic

Keep in mind that this is just an example; it is not necessary for the thesis!!

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")

Expand Down
4 changes: 2 additions & 2 deletions oraqle/compiler/comparison/in_upper_half.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_evaluate_mod7(): # noqa: D103
gf = GF(7)

x = Input("x", gf)
node = InUpperHalf(x, gf)
node = InUpperHalf(x)

for i in range(3):
assert node.evaluate({"x": gf(i)}) == gf(0)
Expand All @@ -264,7 +264,7 @@ def test_evaluate_arithmetized_mod7(): # noqa: D103
gf = GF(7)

x = Input("x", gf)
node = InUpperHalf(x, gf).arithmetize("best-effort")
node = InUpperHalf(x).arithmetize("best-effort")
node.clear_cache(set())

for i in range(3):
Expand Down
2 changes: 1 addition & 1 deletion oraqle/compiler/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def add_link(self, from_id: int, to_id: int, **kwargs):
self._links.append((from_id, to_id, kwargs))

def add_cluster(self, ids: Sequence[int], **kwargs):
"""Adds a cluster containings the nodes with the given IDs. The keyword arguments are directly put into the DOT file."""
"""Adds a cluster containing the nodes with the given IDs. The keyword arguments are directly put into the DOT file."""
assert -1 not in ids
self._clusters.append((ids, kwargs))

Expand Down
29 changes: 28 additions & 1 deletion oraqle/compiler/nodes/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if TYPE_CHECKING:
from oraqle.compiler.boolean.bool import Boolean
from oraqle.mpc.protocol import Protocol

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import ArithmeticInstruction
Expand Down Expand Up @@ -302,6 +303,7 @@ def __init__(self, gf: Type[FieldArray]):
self._instruction_cache: Optional[int] = None
self._arithmetic_cache: Optional[ArithmeticNode] = None
self._parent_count_cache: Optional[int] = None
self._arithmetize_extended_cache: Optional[ExtendedArithmeticNode] = None

self._hash = None

Expand Down Expand Up @@ -336,6 +338,7 @@ def clear_cache(self, already_cleared: Set[int]):
self._instruction_cache: Optional[int] = None
self._arithmetic_cache: Optional[ArithmeticNode] = None
self._parent_count_cache: Optional[int] = None
self._arithmetize_extended_cache: Optional[ExtendedArithmeticNode] = None

self._hash = None

Expand Down Expand Up @@ -440,6 +443,19 @@ def to_arithmetic(self) -> "ArithmeticNode":
raise Exception(
f"This node does not have a direct arithmetic equivalent: {self}. Consider first calling `arithmetize`."
)

def arithmetize_extended(self) -> "ExtendedArithmeticNode":
"""Arithmetizes this node as an extended arithmetic circuit, which includes random and reveal nodes.
The default implementation simply calls arithmetize, because every arithmetic circuit is also an extended arithmetic circuit.
Returns:
An ExtendedArithmeticNode that computes this Node.
"""
# TODO: propagate known by?
# TODO: Add leak to? E.g. by adding reveal after it.

return self.arithmetize("best-effort").to_arithmetic()

def add(self, other: "Node", flatten=True) -> "Node":
"""Performs a summation between `self` and `other`, possibly flattening any sums.
Expand Down Expand Up @@ -620,10 +636,21 @@ def __eq__(self, other) -> bool:
return False

return self.node.is_equivalent(other.node)


class ExtendedArithmeticNode(Node):

@abstractmethod
def operands(self) -> List["ExtendedArithmeticNode"]:
"""Returns the operands (children) of this node. The list can be empty. The nodes MUST be extended arithmetic nodes."""

@abstractmethod
def set_operands(self, operands: List["ExtendedArithmeticNode"]):
"""Overwrites the operands of this node. The nodes MUST be extended arithmetic nodes."""


# TODO: Do we need a separate class to distinguish nodes from arithmetic nodes (which only have arithmetic operands)?
class ArithmeticNode(Node):
class ArithmeticNode(ExtendedArithmeticNode):
"""Extension of Node to indicate that this is a node permitted in a purely arithmetic circuit (with binary additions and multiplications).
The ArithmeticNode 'mixin' must always come before the base class in the class declaration.
Expand Down
69 changes: 69 additions & 0 deletions oraqle/compiler/nodes/extended.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import random
from typing import List, Type
from galois import FieldArray
from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node
from oraqle.compiler.nodes.leafs import LeafNode
from oraqle.compiler.nodes.univariate import UnivariateNode


class Reveal(UnivariateNode, ExtendedArithmeticNode):

@property
def _node_shape(self) -> str:
return "circle"

@property
def _hash_name(self) -> str:
return "reveal"

@property
def _node_label(self) -> str:
return "Reveal"

def _arithmetize_inner(self, strategy: str) -> Node:
raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.")

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.")

def _operation_inner(self, input: FieldArray) -> FieldArray:
return input

# FIXME: Overload operators to create *plaintext* operations


class Random(LeafNode, ExtendedArithmeticNode):

@property
def _node_shape(self) -> str:
return "circle"

@property
def _hash_name(self) -> str:
return "random"

@property
def _node_label(self) -> str:
return "Random"

def __init__(self, gf: type[FieldArray]):
self._hash = hash(random.randbytes(16)) # TODO: Not neat
super().__init__(gf)

def __hash__(self) -> int:
return self._hash

def _arithmetize_inner(self, strategy: str) -> Node:
raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.")

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("Reveal cannot be arithmetized: arithmetic circuits only contain arithmetic operations.")

def is_equivalent(self, other: Node) -> bool:
if not isinstance(other, self.__class__):
return False

return self._hash == other._hash

def operation(self, operands: List[FieldArray]) -> FieldArray:
return self._gf.Random()
13 changes: 12 additions & 1 deletion oraqle/compiler/nodes/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from galois import FieldArray

from oraqle.compiler.nodes.abstract import CostParetoFront, Node
from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node


class FixedNode[N: Node](Node):
Expand Down Expand Up @@ -95,6 +95,17 @@ def arithmetize_depth_aware(self, cost_of_squaring: float) -> CostParetoFront:
def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
pass

def arithmetize_extended(self) -> ExtendedArithmeticNode:
# TODO: Handle constants similarly as above
if self._arithmetize_extended_cache is None:
self._arithmetize_extended_cache = self._arithmetize_extended_inner()

return self._arithmetize_extended_cache

def _arithmetize_extended_inner(self) -> "ExtendedArithmeticNode":
# TODO: Check if this is a good default implementation, add documentation
return self.arithmetize("best-effort").to_arithmetic()


class BinaryNode(FixedNode):
"""A node with two operands."""
Loading

0 comments on commit 83d4d0f

Please sign in to comment.