Skip to content

Commit

Permalink
Add consistent Boolean types
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 7, 2025
1 parent 59196d5 commit 2ae6517
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 73 deletions.
21 changes: 11 additions & 10 deletions oraqle/circuits/cardio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Type
from galois import GF, FieldArray

from oraqle.compiler.boolean.bool import BooleanInput
from oraqle.compiler.boolean.bool_neg import Neg
from oraqle.compiler.boolean.bool_or import any_
from oraqle.compiler.circuit import Circuit
Expand All @@ -12,12 +13,12 @@

def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:
"""Returns the cardio circuit from https://arxiv.org/abs/2101.07078."""
man = Input("man", gf)
woman = Input("woman", gf)
smoking = Input("smoking", gf)
man = BooleanInput("man", gf)
woman = BooleanInput("woman", gf)
smoking = BooleanInput("smoking", gf)
age = Input("age", gf)
diabetic = Input("diabetic", gf)
hbp = Input("hbp", gf)
diabetic = BooleanInput("diabetic", gf)
hbp = BooleanInput("hbp", gf)
cholesterol = Input("cholesterol", gf)
weight = Input("weight", gf)
height = Input("height", gf)
Expand All @@ -40,12 +41,12 @@ def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:

def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node:
"""Returns a variant of the cardio circuit that returns a Boolean indicating whether any risk factor returned true."""
man = Input("man", gf)
woman = Input("woman", gf)
smoking = Input("smoking", gf)
man = BooleanInput("man", gf)
woman = BooleanInput("woman", gf)
smoking = BooleanInput("smoking", gf)
age = Input("age", gf)
diabetic = Input("diabetic", gf)
hbp = Input("hbp", gf)
diabetic = BooleanInput("diabetic", gf)
hbp = BooleanInput("hbp", gf)
cholesterol = Input("cholesterol", gf)
weight = Input("weight", gf)
height = Input("height", gf)
Expand Down
93 changes: 84 additions & 9 deletions oraqle/compiler/boolean/bool.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from galois import GF
from oraqle.compiler.nodes.abstract import Node, UnoverloadedWrapper
from abc import ABCMeta
from typing import Callable, Dict, Type
from galois import GF, FieldArray
from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
from oraqle.compiler.nodes.leafs import Constant
from oraqle.compiler.nodes.leafs import Input


# TODO: Make outgoing edges a specific color
class Boolean(Node):
"""A Boolean node indicates that the wrapped Node is a Boolean."""
"""A Boolean node indicates that this Node outputs a Boolean."""

def __invert__(self) -> "Node":
from oraqle.compiler.boolean.bool_neg import Neg

return Neg(self, self._gf)

def bool_or(self, other: "Node", flatten=True) -> "Node":
def bool_or(self, other: "Boolean", flatten=True) -> "Boolean":
"""Performs an OR operation between `self` and `other`, possibly flattening the result into an OR operation between many operands.
It is possible to disable flattening by setting `flatten=False`.
Expand All @@ -27,7 +31,7 @@ def bool_or(self, other: "Node", flatten=True) -> "Node":

if isinstance(other, Constant):
if bool(other._value):
return Constant(self._gf(1))
return BooleanConstant(self._gf(1))
else:
return self

Expand All @@ -36,13 +40,13 @@ def bool_or(self, other: "Node", flatten=True) -> "Node":
else:
return Or({UnoverloadedWrapper(self), UnoverloadedWrapper(other)}, self._gf)

def __or__(self, other) -> "Node":
def __or__(self, other: "Boolean") -> "Boolean":
if not isinstance(other, Node):
raise Exception(f"The RHS of this OR is not a Node: {self} | {other}")

return self.bool_or(other)

def bool_and(self, other: "Node", flatten=True) -> "Node":
def bool_and(self, other: "Boolean", flatten=True) -> "Boolean":
"""Performs an AND operation between `self` and `other`, possibly flattening the result into an AND operation between many operands.
It is possible to disable flattening by setting `flatten=False`.
Expand All @@ -60,18 +64,89 @@ def bool_and(self, other: "Node", flatten=True) -> "Node":
if bool(other._value):
return self
else:
return Constant(self._gf(0))
return BooleanConstant(self._gf(0))

if self.is_equivalent(other):
return self
else:
return And({UnoverloadedWrapper(self), UnoverloadedWrapper(other)}, self._gf)

def __and__(self, other) -> "Node":
def __and__(self, other: "Boolean") -> "Boolean":
if not isinstance(other, Node):
raise Exception(f"The RHS of this AND is not a Node: {self} & {other}")

return self.bool_and(other)

def arithmetize(self, strategy: str) -> "Boolean":
return super().arithmetize(strategy) # type: ignore

# TODO: Also make the output of depth-aware arithmetization a front of Booleans


class BooleanInput(Input, Boolean):

def evaluate(self, actual_inputs: Dict[str, FieldArray]) -> FieldArray:
output = super().evaluate(actual_inputs)
if not (output == 0 or output == 1):
raise ValueError(f"Not a Boolean: {output}")
return output


# FIXME: This is actually a ReducedBooleanConstant
class BooleanConstant(Constant, Boolean):

def bool_or(self, other: Boolean, flatten=True) -> Node: # noqa: D102
if isinstance(other, Constant):
return Constant(self._gf(bool(self._value) | bool(other._value)))

return other.bool_or(self, flatten)

def bool_and(self, other: Boolean, flatten=True) -> Node: # noqa: D102
if isinstance(other, Constant):
return Constant(self._gf(bool(self._value) & bool(other._value)))

return other.bool_and(self, flatten)


# TODO: Make the old implementations part of ReducedBoolean
class ReducedBoolean(Boolean):
pass


_class_cache = {}

def _get_dynamic_class(name, bases, attrs):
"""Tracks dynamic classes so that cast_to_reduced_boolean on a specific class always returns the same dynamic Boolean class."""
key = (name, bases)
if key not in _class_cache:
_class_cache[key] = type(name, bases, attrs)
return _class_cache[key]


def cast_to_reduced_boolean(node: Node) -> ReducedBoolean:
"""
Casts this Node to a Boolean. This results in a new class called <node's class name>_ReducedBool.
!!! warning
This modifies the node *in place*, so the node is now a Boolean node.
"""
BooleanNode = _get_dynamic_class(f'{node.__class__.__name__}_ReducedBool', (node.__class__, ReducedBoolean), dict(node.__class__.__dict__)) # type: ignore
node.__class__ = BooleanNode
return node # type: ignore


def test_isinstance_cast_reduced_boolean():
from oraqle.compiler.nodes.leafs import Input
from oraqle.compiler.nodes.arbitrary_arithmetic import Sum, sum_

gf = GF(7)
a = Input("a", gf)
b = Input("b", gf)
c = Input("c", gf)

s = sum_(a, b, c)
assert isinstance(s, Sum)
assert isinstance(cast_to_reduced_boolean(s), Sum)


if __name__ == "__main__":
Expand Down
33 changes: 17 additions & 16 deletions oraqle/compiler/boolean/bool_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from oraqle.add_chains.addition_chains_front import gen_pareto_front
from oraqle.add_chains.addition_chains_mod import chain_cost
from oraqle.add_chains.solving import extract_indices
from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, BooleanInput
from oraqle.compiler.boolean.bool_neg import Neg
from oraqle.compiler.comparison.equality import IsNonZero
from oraqle.compiler.nodes.abstract import (
Expand All @@ -30,7 +31,7 @@
from oraqle.compiler.nodes.leafs import Constant, Input


class And(CommutativeUniqueReducibleNode):
class And(CommutativeUniqueReducibleNode[Boolean], Boolean):
"""Performs an AND operation over several operands. The user must ensure that the operands are Booleans."""

@property
Expand Down Expand Up @@ -171,7 +172,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF

return front

def and_flatten(self, other: Node) -> Node:
def and_flatten(self, other: Boolean) -> Boolean:
"""Performs an AND operation with `other`, flattening the `And` node if either of the two is also an `And` and absorbing `Constant`s.
Returns:
Expand All @@ -181,7 +182,7 @@ def and_flatten(self, other: Node) -> Node:
if bool(other._value):
return self
else:
return Constant(self._gf(0))
return BooleanConstant(self._gf(0))

if isinstance(other, And):
return And(self._operands | other._operands, self._gf)
Expand All @@ -194,8 +195,8 @@ def and_flatten(self, other: Node) -> Node:
def test_evaluate_mod3(): # noqa: D103
gf = GF(3)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf) # TODO: Define ReducedBooleanInput
b = BooleanInput("b", gf)
node = (a & b).arithmetize("best-effort")

assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(0)
Expand All @@ -210,8 +211,8 @@ def test_evaluate_mod3(): # noqa: D103
def test_evaluate_arithmetized_mod3(): # noqa: D103
gf = GF(3)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = (a & b).arithmetize("best-effort")

node.clear_cache(set())
Expand All @@ -227,8 +228,8 @@ def test_evaluate_arithmetized_mod3(): # noqa: D103
def test_evaluate_arithmetized_depth_aware_mod2(): # noqa: D103
gf = GF(2)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = a & b
front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

Expand All @@ -246,8 +247,8 @@ def test_evaluate_arithmetized_depth_aware_mod2(): # noqa: D103
def test_evaluate_arithmetized_depth_aware_mod3(): # noqa: D103
gf = GF(3)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = a & b
front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

Expand Down Expand Up @@ -368,7 +369,7 @@ def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNo
_, result = _generate_multiplication_tree(((math.ceil(math.log2(operand.breadth)), operand.to_arithmetic_node(is_and, gf) if is_and else Neg(operand.to_arithmetic_node(is_and, gf), gf).arithmetize("best-effort").to_arithmetic()) for operand in self._operands), (1 for _ in range(len(self._operands)))) # type: ignore

if not is_and:
result = Neg(result, gf)
result = Neg(result, gf) # type: ignore

self._arithmetic_node = result.arithmetize(
"best-effort"
Expand Down Expand Up @@ -422,11 +423,11 @@ def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNo
Counter(
{
UnoverloadedWrapper(
Neg(operand.to_arithmetic_node(is_and, gf), gf)
Neg(operand.to_arithmetic_node(is_and, gf), gf) # type: ignore
): 1
for operand in self._operands
}
),
), # type: ignore
gf,
)
.arithmetize("best-effort")
Expand Down Expand Up @@ -455,7 +456,7 @@ def to_arithmetic_node(self, is_and: bool, gf: Type[FieldArray]) -> ArithmeticNo
result = nodes[-1]

if is_and:
result = Neg(result, gf).arithmetize("best-effort")
result = Neg(result, gf).arithmetize("best-effort") # type: ignore

self._arithmetic_node = result.to_arithmetic() # TODO: This could be more elegant
self._is_and = is_and
Expand Down Expand Up @@ -737,7 +738,7 @@ def minimize_depth_cost_recursive( # noqa: PLR0912, PLR0914, PLR0915
return output


def all_(*operands: Node) -> And:
def all_(*operands: Boolean) -> And:
"""Returns an `And` node that evaluates to true if any of the given `operands` evaluates to true."""
assert len(operands) > 0
return And(set(UnoverloadedWrapper(operand) for operand in operands), operands[0]._gf)
23 changes: 12 additions & 11 deletions oraqle/compiler/boolean/bool_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from galois import GF, FieldArray

from oraqle.compiler.boolean.bool import Boolean, BooleanConstant, BooleanInput
from oraqle.compiler.boolean.bool_and import And, _find_depth_cost_front
from oraqle.compiler.boolean.bool_neg import Neg
from oraqle.compiler.nodes.abstract import CostParetoFront, Node, UnoverloadedWrapper
Expand All @@ -13,7 +14,7 @@
# TODO: Reduce code duplication between OR and AND


class Or(CommutativeUniqueReducibleNode):
class Or(CommutativeUniqueReducibleNode[Boolean], Boolean):
"""Performs an OR operation over several operands. The user must ensure that the operands are Booleans."""

@property
Expand Down Expand Up @@ -87,15 +88,15 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF

return front

def or_flatten(self, other: Node) -> Node:
def or_flatten(self, other: Boolean) -> Boolean:
"""Performs an OR operation with `other`, flattening the `Or` node if either of the two is also an `Or` and absorbing `Constant`s.
Returns:
An `Or` node containing the flattened OR operation, or a `Constant` node.
"""
if isinstance(other, Constant):
if bool(other._value):
return Constant(self._gf(1))
return BooleanConstant(self._gf(1))
else:
return self

Expand All @@ -107,7 +108,7 @@ def or_flatten(self, other: Node) -> Node:
return Or(new_operands, self._gf)


def any_(*operands: Node) -> Or:
def any_(*operands: Boolean) -> Or:
"""Returns an `Or` node that evaluates to true if any of the given `operands` evaluates to true."""
assert len(operands) > 0
return Or(set(UnoverloadedWrapper(operand) for operand in operands), operands[0]._gf)
Expand All @@ -116,8 +117,8 @@ def any_(*operands: Node) -> Or:
def test_evaluate_mod3(): # noqa: D103
gf = GF(3)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = a | b

assert node.evaluate({"a": gf(0), "b": gf(0)}) == gf(0)
Expand All @@ -132,8 +133,8 @@ def test_evaluate_mod3(): # noqa: D103
def test_evaluate_arithmetized_depth_aware_mod2(): # noqa: D103
gf = GF(2)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = a | b
front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

Expand All @@ -151,8 +152,8 @@ def test_evaluate_arithmetized_depth_aware_mod2(): # noqa: D103
def test_evaluate_arithmetized_mod3(): # noqa: D103
gf = GF(3)

a = Input("a", gf)
b = Input("b", gf)
a = BooleanInput("a", gf)
b = BooleanInput("b", gf)
node = (a | b).arithmetize("best-effort")

node.clear_cache(set())
Expand All @@ -168,7 +169,7 @@ def test_evaluate_arithmetized_mod3(): # noqa: D103
def test_evaluate_arithmetized_depth_aware_50_mod31(): # noqa: D103
gf = GF(31)

xs = {Input(f"x{i}", gf) for i in range(50)}
xs = {BooleanInput(f"x{i}", gf) for i in range(50)}
node = Or({UnoverloadedWrapper(x) for x in xs}, gf)
front = node.arithmetize_depth_aware(cost_of_squaring=1.0)

Expand Down
Loading

0 comments on commit 2ae6517

Please sign in to comment.