Skip to content

Commit

Permalink
Merge pull request #7 from jellevos/more-experiments
Browse files Browse the repository at this point in the history
More experiments
  • Loading branch information
jellevos authored Jan 29, 2025
2 parents e33a437 + a329889 commit 424af2f
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ instructions.txt
.sphinx_build/
/dist
*.svg
*.pyc
oraqle/addchain_cache.db
Binary file removed addchain_cache.db
Binary file not shown.
10 changes: 5 additions & 5 deletions oraqle/add_chains/addition_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pysat.card import CardEnc
from pysat.formula import WCNF

from oraqle.add_chains.memoization import ADDCHAIN_CACHE_PATH, cache_to_disk
from oraqle.add_chains.memoization import cache_to_disk
from oraqle.add_chains.solving import solve, solve_with_time_limit
from oraqle.config import MAXSAT_TIMEOUT

Expand All @@ -24,24 +24,24 @@ def thurber_bounds(target: int, max_size: int) -> List[Tuple[int, int]]:
denominator = (1 << (t + 1)) * ((1 << (max_size - t - (step + 2))) + 1)
else:
denominator = (1 << t) * ((1 << (max_size - t - (step + 1))) + 1)
bound = int(math.ceil(target / denominator))
bound = math.ceil(target / denominator)
bounds.append((bound, min(1 << step, target)))

step = max_size - t - 2
if step > 0:
denominator = (1 << t) * ((1 << (max_size - t - (step + 1))) + 1)
bound = int(math.ceil(target / denominator))
bound = math.ceil(target / denominator)
bounds.append((bound, min(1 << step, target)))

if max_size - t - 1 > 0:
for step in range(max_size - t - 1, max_size + 1):
bound = int(math.ceil(target / (1 << (max_size - step))))
bound = math.ceil(target / (1 << (max_size - step)))
bounds.append((bound, min(1 << step, target)))

return bounds


@cache_to_disk(ADDCHAIN_CACHE_PATH, ignore_args={"solver", "encoding", "thurber"})
@cache_to_disk(ignore_args={"solver", "encoding", "thurber"})
def add_chain( # noqa: PLR0912, PLR0913, PLR0915, PLR0917
target: int,
max_depth: Optional[int],
Expand Down
12 changes: 9 additions & 3 deletions oraqle/add_chains/memoization.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
"""This module contains tools for memoizing addition chains, as these are expensive to compute."""
from hashlib import sha3_256
from importlib.resources import files
import inspect
import shelve
from typing import Set

from sympy import sieve

import oraqle

ADDCHAIN_CACHE_PATH = "addchain_cache"

ADDCHAIN_CACHE_FILENAME = "addchain_cache"


# Adapted from: https://stackoverflow.com/questions/16463582/memoize-to-disk-python-persistent-memoization
def cache_to_disk(file_name, ignore_args: Set[str]):
def cache_to_disk(ignore_args: Set[str]):
"""This decorator caches the calls to this function in a file on disk, ignoring the arguments listed in `ignore_args`.
Returns:
A cached output
"""
d = shelve.open(file_name) # noqa: SIM115
# Always opens the database in the root of where the package is located
oraqle_path = files(oraqle)
database_path = oraqle_path.joinpath(ADDCHAIN_CACHE_FILENAME)
d = shelve.open(str(database_path)) # noqa: SIM115

def decorator(func):
signature = inspect.signature(func)
Expand Down
6 changes: 2 additions & 4 deletions oraqle/circuits/cardio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:
"""Returns the cardio circuit from https://arxiv.org/abs/2101.07078."""
man = Input("man", gf)
woman = Input("woman", gf)
smoking = Input("smoking", gf)
age = Input("age", gf)
diabetic = Input("diabetic", gf)
Expand All @@ -26,7 +25,7 @@ def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:

return sum_(
man & (age > 50),
woman & (age > 60),
Neg(man, gf) & (age > 60),
smoking,
diabetic,
hbp,
Expand All @@ -41,7 +40,6 @@ def construct_cardio_risk_circuit(gf: Type[FieldArray]) -> Node:
def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node:
"""Returns a variant of the cardio circuit that returns a Boolean indicating whether any risk factor returned true."""
man = Input("man", gf)
woman = Input("woman", gf)
smoking = Input("smoking", gf)
age = Input("age", gf)
diabetic = Input("diabetic", gf)
Expand All @@ -54,7 +52,7 @@ def construct_cardio_elevated_risk_circuit(gf: Type[FieldArray]) -> Node:

return any_(
man & (age > 50),
woman & (age > 60),
Neg(man, gf) & (age > 60),
smoking,
diabetic,
hbp,
Expand Down
66 changes: 66 additions & 0 deletions oraqle/compiler/circuit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""This module contains classes for representing circuits."""
from importlib.resources import files
import os
import shutil
import subprocess
import tempfile
from typing import Dict, List, Optional, Tuple
Expand All @@ -7,6 +10,7 @@
from fhegen.util import estsecurity
from galois import FieldArray

import oraqle.helib_template
from oraqle.compiler.graphviz import DotFile
from oraqle.compiler.instructions import ArithmeticProgram, OutputInstruction
from oraqle.compiler.nodes.abstract import ArithmeticNode, Node
Expand Down Expand Up @@ -405,6 +409,68 @@ def generate_code(
file.write(helib_postamble)

return params

def run_using_helib(self,
iterations: int,
measure_time: bool = False,
decrypt_outputs: bool = False,
**kwargs) -> float:
"""Generate a program using HElib and execute it, measuring the average run time.
Raises:
Exception: If an error occured during the build or execution.
Returns:
Average run time in seconds as a float
"""
assert measure_time
assert not decrypt_outputs

original_directory = os.getcwd()

try:
with tempfile.TemporaryDirectory() as temp_dir:
# Copy the template folder to the temporary directory
build_dir = os.path.join(temp_dir, "build")
template_path = files(oraqle.helib_template)
shutil.copytree(str(template_path), build_dir)

# Generate the main.cpp file
main_cpp_path = os.path.join(build_dir, "main.cpp")
self.generate_code(main_cpp_path, iterations, measure_time, decrypt_outputs)

# Call cmake and build
os.chdir(build_dir)
subprocess.run(["cmake", "-S", ".", "-B", "build"], check=True, capture_output=True)
subprocess.run(["cmake", "--build", "build"], check=True, capture_output=True)

# Run the executable
executable_path = os.path.join(build_dir, "build", "main")
program_args = [f"{keyword}={value}" for keyword, value in kwargs.items()]
print(f"Build completed. Running with parameters: {', '.join(program_args)}...")
result = subprocess.run([executable_path, *program_args], check=True, text=True, capture_output=True)

# Check that all ciphertexts are valid
lines = result.stdout.splitlines()
for line in lines[:-1]:
assert line.endswith("1")

run_time = float(lines[-1]) / iterations
return run_time
except subprocess.CalledProcessError as e:
print("An error occurred during the build or execution process.")
print(e)
try:
print("stderr:")
print(result.stderr)
print()
print("stdout:")
print(result.stdout)
except Exception:
pass
raise Exception("Cannot continue since an error occured.") from e
finally:
os.chdir(original_directory)


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions oraqle/compiler/comparison/in_upper_half.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def _arithmetize_inner(self, strategy: str) -> Node:
(2 * exp) % (p - 1),
power_node.multiplicative_depth() - input_node.multiplicative_depth(),
)
for exp, power_node in precomputed_powers.items()
for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0
)

addition_chain = add_chain_guaranteed(p - 1, p - 1, squaring_cost=1.0, precomputed_values=precomputed_values)

nodes = [input_node]
nodes.extend(power_node for _, power_node in precomputed_powers.items())
nodes.extend(power_node for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0)

for i, j in addition_chain:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down Expand Up @@ -123,7 +123,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
# Compute the final coefficient using an exponentiation
precomputed_values = tuple(
((2 * exp) % (p - 1), power_node.multiplicative_depth() - node_depth)
for exp, power_node in precomputed_powers[depth].items()
for exp, power_node in precomputed_powers[depth].items() if ((2 * exp) % (p - 1)) != 0
)
# TODO: This is copied from Power, but in the future we can probably remove this if we have augmented circuits
if p <= 200:
Expand All @@ -148,7 +148,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
)

nodes = [node]
nodes.extend(power_node for _, power_node in precomputed_powers[depth].items())
nodes.extend(power_node for exp, power_node in precomputed_powers[depth].items() if ((2 * exp) % (p - 1)) != 0)

for i, j in c:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down Expand Up @@ -224,13 +224,13 @@ def _arithmetize_inner(self, strategy: str) -> Node:
(2 * exp) % (p - 1),
power_node.multiplicative_depth() - input_node.multiplicative_depth(),
)
for exp, power_node in precomputed_powers.items()
for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0
)

addition_chain = add_chain_guaranteed(p - 1, p - 1, squaring_cost=1.0, precomputed_values=precomputed_values)

nodes = [input_node]
nodes.extend(power_node for _, power_node in precomputed_powers.items())
nodes.extend(power_node for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0)

for i, j in addition_chain:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down
4 changes: 2 additions & 2 deletions oraqle/compiler/polynomials/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def arithmetize_custom(self, strategy: str) -> Tuple[ArithmeticNode, Dict[int, A

lowest_multiplicative_size = 1_000_000_000 # TODO: Not elegant
optimal_k = math.sqrt(2 * len(self._coefficients))
bound = min(int(math.ceil(PS_METHOD_FACTOR_K * optimal_k)), len(self._coefficients))
bound = min(math.ceil(PS_METHOD_FACTOR_K * optimal_k), len(self._coefficients))
for k in range(1, bound):
(
arithmetization,
Expand Down Expand Up @@ -178,7 +178,7 @@ def arithmetize_depth_aware_custom(

for _, _, x in self._node.arithmetize_depth_aware(cost_of_squaring):
optimal_k = math.sqrt(2 * len(self._coefficients))
bound = min(int(math.ceil(PS_METHOD_FACTOR_K * optimal_k)), len(self._coefficients))
bound = min(math.ceil(PS_METHOD_FACTOR_K * optimal_k), len(self._coefficients))
for k in range(1, bound):
(
arithmetization,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import random
import time
from typing import Dict

from galois import GF

from oraqle.circuits.cardio import (
construct_cardio_elevated_risk_circuit,
construct_cardio_risk_circuit,
)
from oraqle.compiler.circuit import Circuit


def gen_params() -> Dict[str, int]:
params = {}

params["man"] = random.randint(0, 1)
params["smoking"] = random.randint(0, 1)
params["diabetic"] = random.randint(0, 1)
params["hbp"] = random.randint(0, 1)

params["age"] = random.randint(0, 100)
params["cholesterol"] = random.randint(0, 60)
params["weight"] = random.randint(40, 150)
params["height"] = random.randint(80, 210)
params["activity"] = random.randint(0, 250)
params["alcohol"] = random.randint(0, 5)

return params


if __name__ == "__main__":
gf = GF(257)
iterations = 10

for cost_of_squaring in [0.75]:
print(f"--- Cardio risk assessment ({cost_of_squaring}) ---")
circuit = Circuit([construct_cardio_risk_circuit(gf)])

start = time.monotonic()
front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring)
print("Compile time:", time.monotonic() - start, "s")

for depth, cost, arithmetic_circuit in front:
print(depth, cost)
run_time = arithmetic_circuit.run_using_helib(iterations, True, False, **gen_params())
print("Run time:", run_time)

print(f"--- Cardio elevated risk assessment ({cost_of_squaring}) ---")
circuit = Circuit([construct_cardio_elevated_risk_circuit(gf)])

start = time.monotonic()
front = circuit.arithmetize_depth_aware(cost_of_squaring=cost_of_squaring)
print("Compile time:", time.monotonic() - start, "s")

for depth, cost, arithmetic_circuit in front:
print(depth, cost)
run_time = arithmetic_circuit.run_using_helib(iterations, True, False, **gen_params())
print("Run time:", run_time)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from galois import GF

from oraqle.compiler.circuit import Circuit
from oraqle.compiler.nodes.leafs import Input


if __name__ == "__main__":
iterations = 10

for p in [29, 43, 61, 101, 131]:
gf = GF(p)

x = Input("x", gf)
y = Input("y", gf)

circuit = Circuit([x == y])

for d, c, arith in circuit.arithmetize_depth_aware(0.75):
print(d, c, arith.run_using_helib(10, True, False, x=13, y=19))

arith = circuit.arithmetize('naive')
print('square and multiply', arith.multiplicative_depth(), arith.multiplicative_size(), arith.multiplicative_cost(0.75), arith.run_using_helib(10, True, False, x=13, y=19))
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from oraqle.compiler.nodes.leafs import Input

if __name__ == "__main__":
iterations = 10

for p in [29, 43, 61, 101, 131]:
gf = GF(p)

Expand All @@ -21,33 +23,37 @@
print("Our circuits:", our_front)

our_front[0][2].to_graph(f"comp_{p}_ours.dot")
for d, s, circ in our_front:
print(d, s, circ.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22))

t2_circuit = Circuit([T2SemiLessThan(x, y, gf)])
t2_arithmetization = t2_circuit.arithmetize()
print(
"T2 circuit:",
t2_arithmetization.multiplicative_depth(),
t2_arithmetization.multiplicative_size(),
t2_arithmetization.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22)
)
t2_arithmetization.eliminate_subexpressions()
print(
"T2 circuit CSE:",
t2_arithmetization.multiplicative_depth(),
t2_arithmetization.multiplicative_size(),
t2_arithmetization.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22)
)

iz21_circuit = Circuit([IliashenkoZuccaSemiLessThan(x, y, gf)])
iz21_arithmetization = iz21_circuit.arithmetize()
iz21_arithmetization.to_graph(f"comp_{p}_iz21.dot")
print(
"IZ21 circuits:",
iz21_arithmetization.multiplicative_depth(),
iz21_arithmetization.multiplicative_size(),
iz21_arithmetization.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22)
)
iz21_arithmetization.eliminate_subexpressions()
iz21_arithmetization.to_graph(f"comp_{p}_iz21_cse.dot")
print(
"IZ21 circuit CSE:",
iz21_arithmetization.multiplicative_depth(),
iz21_arithmetization.multiplicative_size(),
iz21_arithmetization.run_using_helib(iterations=iterations, measure_time=True, x=15, y=22)
)
1 change: 1 addition & 0 deletions oraqle/helib_template/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Template containing all the things to build an HElib program."""
Loading

0 comments on commit 424af2f

Please sign in to comment.