Skip to content

Commit

Permalink
Fix several extended arithmetization bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 13, 2025
1 parent 0bb9f3b commit 7fc3631
Show file tree
Hide file tree
Showing 20 changed files with 337 additions and 94 deletions.
14 changes: 14 additions & 0 deletions oraqle/compiler/arithmetic/exponentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, node: Node, exponent: int, gf: Type[FieldArray]):

def _operation_inner(self, input: FieldArray, gf: Type[FieldArray]) -> FieldArray:
return input**self._exponent # type: ignore

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> "Node":
if strategy == "naive":
Expand Down Expand Up @@ -91,6 +94,17 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
final_front.add(nodes[-1], depth=depth1 + depth2)

return final_front

# TODO: This is duplicated from arithmetized, can we somehow use arithmetize while still propagating using arithmetize_extended?
def _arithmetize_extended_inner(self) -> "Node":
addition_chain = add_chain_guaranteed(self._exponent, self._gf.characteristic - 1, squaring_cost=1.0)

nodes = [self._node.arithmetize_extended().to_arithmetic()]

for i, j in addition_chain:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))

return nodes[-1]


def test_depth_aware_arithmetization(): # noqa: D103
Expand Down
15 changes: 5 additions & 10 deletions oraqle/compiler/arithmetic/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,11 @@ def _node_label(self) -> str:

def _operation_inner(self, x, y) -> FieldArray:
return x - y

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: Reorganize the files: let the arithmetic folder only contain pure arithmetic (including add and mul) and move exponentiation elsewhere.
# TODO: For schemes that support subtraction we do not need to do this. We should only do this transformation during the compiler stage.
return (self._left.arithmetize(strategy) + (Constant(-self._gf(1)) * self._right.arithmetize(strategy))).arithmetize(strategy) # type: ignore # TODO: Should we always perform a final arithmetization in every node for constant folding? E.g. in Node?

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
result = self._left + (Constant(-self._gf(1)) * self._right)
front = result.arithmetize_depth_aware(cost_of_squaring)
return front

# TODO: Reorganize the files: let the arithmetic folder only contain pure arithmetic (including add and mul) and move exponentiation elsewhere.
# TODO: For schemes that support subtraction we do not need to do this. We should only do this transformation during the compiler stage.
def _expansion(self) -> Node:
return self._left + Constant(-self._gf(1)) * self._right


def test_evaluate_mod5(): # noqa: D103
Expand Down
100 changes: 97 additions & 3 deletions oraqle/compiler/boolean/bool_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def _node_label(self) -> str:
def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
raise NotImplementedError()

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: Currently only supports the reduced representation
return self.transform_to_reduced_boolean().arithmetize(strategy)
Expand All @@ -57,10 +60,14 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
return self.transform_to_reduced_boolean().arithmetize_depth_aware(cost_of_squaring)

def _arithmetize_extended_inner(self) -> ExtendedArithmeticNode:
print(list(operand.node for operand in self._operands))
# Choose the best of the reduced and unreduced implementations
reduced = self.transform_to_reduced_boolean().arithmetize_extended()
inv_unreduced = self.transform_to_inv_unreduced_boolean().arithmetize_extended()

Circuit([reduced.to_arithmetic()]).to_pdf("reduced.pdf")
Circuit([inv_unreduced.to_arithmetic()]).to_pdf("inv_unreduced.pdf")

# TODO: Consider multiplicative cost
# TODO: Consider other metrics as well?
if reduced.to_arithmetic().multiplicative_size() <= inv_unreduced.to_arithmetic().multiplicative_size():
Expand Down Expand Up @@ -109,13 +116,19 @@ def _node_label(self) -> str:

def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
return a + b

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
# TODO: Consider not supporting additions between Booleans unless they are cast to field elements
return cast_to_inv_unreduced_boolean(SecretRandom(self._gf) * sum_(*(operand.node * PublicRandom(self._gf) for operand in self._operands))).arithmetize(strategy)
raise NotImplementedError("This requires randomization")

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")
raise NotImplementedError("This requires randomization")

def _arithmetize_extended_inner(self) -> ExtendedArithmeticNode:
return cast_to_inv_unreduced_boolean(SecretRandom(self._gf) * sum_(*(operand.node * PublicRandom(self._gf) for operand in self._operands))).arithmetize_extended()


class ReducedAnd(CommutativeUniqueReducibleNode[ReducedBoolean], ReducedBoolean):
Expand All @@ -131,6 +144,9 @@ def _node_label(self) -> str:

def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
return self._gf(bool(a) & bool(b))

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912
new_operands: Set[UnoverloadedWrapper] = set()
Expand All @@ -139,7 +155,7 @@ def _arithmetize_inner(self, strategy: str) -> Node: # noqa: PLR0911, PLR0912

if isinstance(new_operand, Constant):
if not bool(new_operand._value):
return Constant(self._gf(0))
return BooleanConstant(self._gf(0))
continue

new_operands.add(UnoverloadedWrapper(new_operand))
Expand Down Expand Up @@ -259,6 +275,84 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF

# TODO: Consider implementing and_flatten

# TODO: This is copied from arithmetize
def _arithmetize_extended_inner(self) -> Node: # noqa: PLR0911, PLR0912
new_operands: Set[UnoverloadedWrapper] = set()
for operand in self._operands:
new_operand = operand.node.arithmetize_extended()

if isinstance(new_operand, Constant):
if not bool(new_operand._value):
return BooleanConstant(self._gf(0))
continue

new_operands.add(UnoverloadedWrapper(new_operand))

if len(new_operands) == 0:
return BooleanConstant(self._gf(1))
elif len(new_operands) == 1:
return next(iter(new_operands)).node

# TODO: Calling to_arithmetic here should not be necessary if we can decide the predicted depth
queue = [
(
_PrioritizedItem(
0, operand.node
) # TODO: We should just maybe make a breadth method on Node
if isinstance(operand.node, Constant)
else _PrioritizedItem(
operand.node.to_arithmetic().multiplicative_depth(), operand.node
)
)
for operand in new_operands
]
heapify(queue)

while len(queue) > (self._gf._characteristic - 1):
total_sum = None
max_depth = None
for _ in range(self._gf._characteristic - 1):
if len(queue) == 0:
break

popped = heappop(queue)
if max_depth is None or max_depth < popped.priority:
max_depth = popped.priority

if total_sum is None:
total_sum = ReducedNeg(popped.item)
else:
total_sum += ReducedNeg(popped.item)

assert total_sum is not None
final_result = ReducedNeg(ReducedIsNonZero(total_sum)).arithmetize_extended()

assert max_depth is not None
heappush(queue, _PrioritizedItem(max_depth, final_result))

if len(queue) == 1:
return heappop(queue).item

dummy_node = ReducedBooleanInput("dummy_node", self._gf)
is_non_zero = ReducedIsNonZero(dummy_node).arithmetize_extended().to_arithmetic()
cost = is_non_zero.multiplicative_cost(
1.0
) # FIXME: This needs to be the actual squaring cost

if len(queue) - 1 < cost:
return cast_to_reduced_boolean(Product(
Counter({UnoverloadedWrapper(operand.item): 1 for operand in queue}), self._gf
)).arithmetize_extended()

return ReducedNeg(
ReducedIsNonZero(
Sum(
Counter({UnoverloadedWrapper(ReducedNeg(node.item)): 1 for node in queue}),
self._gf,
),
),
).arithmetize_extended()


def test_evaluate_mod3(): # noqa: D103
gf = GF(3)
Expand Down
22 changes: 10 additions & 12 deletions oraqle/compiler/boolean/bool_neg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ def _node_label(self) -> str:
def _operation_inner(self, input: FieldArray) -> FieldArray:
raise NotImplementedError("TODO: Not sure if it makes sense to implement this")

# TODO: We can create a high-level implementation that tries all three transformations and chooses the lowest size one
def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
return self.arithmetize_all_representations(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")

def _arithmetize_extended_inner(self) -> Node:
# FIXME: We should make a version of arithmetize_all_repr that does extended arithmetization
return self.arithmetize_all_representations("best-effort")

def transform_to_reduced_boolean(self) -> ReducedBoolean:
return ReducedNeg(self._node.transform_to_reduced_boolean())

Expand Down Expand Up @@ -62,14 +68,6 @@ def _node_label(self) -> str:
def _operation_inner(self, input: FieldArray) -> FieldArray:
assert input in {0, 1}
return self._gf(not bool(input))

def _arithmetize_inner(self, strategy: str) -> Boolean:
return Subtraction(
Constant(self._gf(1)), self._node.arithmetize(strategy), self._gf
).arithmetize(strategy) # type: ignore

# FIXME: CostParetoFront should be generic
def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
return Subtraction(Constant(self._gf(1)), self._node, self._gf).arithmetize_depth_aware(
cost_of_squaring
)

def _expansion(self) -> Node:
return 1 - self._node
20 changes: 16 additions & 4 deletions oraqle/compiler/boolean/bool_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, InvUnreducedBoolean, ReducedBoolean, ReducedBooleanInput, UnreducedBoolean, cast_to_unreduced_boolean
from oraqle.compiler.boolean.bool_and import ReducedAnd, _find_depth_cost_front
from oraqle.compiler.boolean.bool_neg import ReducedNeg
from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
from oraqle.compiler.nodes.abstract import CostParetoFront, ExtendedArithmeticNode, Node, UnoverloadedWrapper
from oraqle.compiler.nodes.arbitrary_arithmetic import sum_
from oraqle.compiler.nodes.extended import PublicRandom, SecretRandom
from oraqle.compiler.nodes.flexible import CommutativeUniqueReducibleNode
Expand All @@ -29,6 +29,9 @@ def _node_label(self) -> str:
def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
raise NotImplementedError()

def _expansion(self) -> Node:
return super()._expansion()

def _arithmetize_inner(self, strategy: str) -> Node:
# Choose the best of the reduced and unreduced implementations
reduced = self.transform_to_reduced_boolean().arithmetize(strategy)
Expand Down Expand Up @@ -84,12 +87,18 @@ def _node_label(self) -> str:

def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
return self._gf(a + b)

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
return cast_to_unreduced_boolean(SecretRandom(self._gf) * sum_(*(operand.node * PublicRandom(self._gf) for operand in self._operands))).arithmetize(strategy)
raise NotImplementedError("This requires randomization")

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError("TODO!")
raise NotImplementedError("This requires randomization")

def _arithmetize_extended_inner(self) -> ExtendedArithmeticNode:
return cast_to_unreduced_boolean(SecretRandom(self._gf) * sum_(*(operand.node * PublicRandom(self._gf) for operand in self._operands))).arithmetize_extended()


class ReducedOr(CommutativeUniqueReducibleNode[ReducedBoolean], ReducedBoolean):
Expand All @@ -105,6 +114,9 @@ def _node_label(self) -> str:

def _inner_operation(self, a: FieldArray, b: FieldArray) -> FieldArray:
return self._gf(bool(a) | bool(b))

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
# FIXME: Handle what happens when arithmetize outputs a constant!
Expand Down
29 changes: 28 additions & 1 deletion oraqle/compiler/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import ArithmeticProgram, OutputInstruction
from oraqle.compiler.nodes.abstract import ArithmeticNode, Node
from oraqle.compiler.nodes.abstract import ArithmeticNode, ExtendedArithmeticNode, Node


class Circuit:
Expand Down Expand Up @@ -133,6 +133,23 @@ def arithmetize_depth_aware(

arithmetic_circuit._clear_cache()
return front

def arithmetize_extended(self) -> "ExtendedArithmeticCircuit":
"""Performs *extended* arithmetization on this circuit by calling arithmetize_extended on all outputs.
This replaces all high-level operations with extended arithmetic operations (constants, additions, multiplications, random, and reveal).
The current implementation only aims at reducing the total number of multiplications.
Returns:
An equivalent extended arithmetic circuit with low multiplicative size.
"""
extended_arithmetic_circuit = ExtendedArithmeticCircuit(
[output.arithmetize_extended() for output in self._outputs]
)
# FIXME: Also call to_arithmetic
extended_arithmetic_circuit._clear_cache()

return extended_arithmetic_circuit

def _clear_cache(self):
already_cleared = set()
Expand Down Expand Up @@ -193,6 +210,16 @@ def _clear_cache(self):
"""


class ExtendedArithmeticCircuit(Circuit):

def __init__(self, outputs: List[ExtendedArithmeticNode]):
"""Initialize a circuit with the given `outputs`."""
assert len(outputs) > 0
self._outputs = outputs
self._gf = outputs[0]._gf


# TODO: This should probably be a subclass of ExtendedArithmeticCircuit
class ArithmeticCircuit(Circuit):
"""Represents an arithmetic circuit over a fixed finite field, so it only contains arithmetic nodes."""

Expand Down
15 changes: 9 additions & 6 deletions oraqle/compiler/comparison/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def is_equivalent(self, other: Node) -> bool: # noqa: D102
return self._left.is_equivalent(other._right) and self._right.is_equivalent(other._left)
else:
return self._left.is_equivalent(other._left) and self._right.is_equivalent(other._right)

def _expansion(self) -> Node:
raise NotImplementedError()


class SemiStrictComparison(AbstractComparison):
Expand Down Expand Up @@ -287,8 +290,8 @@ def _node_label(self) -> str:

def _operation_inner(self, x, y) -> FieldArray:
return self._gf(int(int(x) < int(y)))

def _arithmetize_inner(self, strategy: str) -> Node:
def _expansion(self) -> Node:
out = Constant(self._gf(0))

p = self._gf.characteristic
Expand All @@ -297,10 +300,7 @@ def _arithmetize_inner(self, strategy: str) -> Node:
p - 1
)

return out.arithmetize(strategy)

def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoFront:
raise NotImplementedError()
return out


class IliashenkoZuccaSemiLessThan(NonCommutativeBinaryNode):
Expand All @@ -316,6 +316,9 @@ def _node_label(self) -> str:

def _operation_inner(self, x, y) -> FieldArray:
return self._gf(int(int(x) < int(y)))

def _expansion(self) -> Node:
raise NotImplementedError()

def _arithmetize_inner(self, strategy: str) -> Node:
return IliashenkoZuccaInUpperHalf(
Expand Down
Loading

0 comments on commit 7fc3631

Please sign in to comment.