Skip to content

Commit

Permalink
Fix a bug in exponentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
jellevos committed Jan 22, 2025
1 parent 52f536a commit dd13914
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
5 changes: 3 additions & 2 deletions oraqle/compiler/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,13 @@ def run_using_helib(self,

# Call cmake and build
os.chdir(build_dir)
subprocess.run(["cmake", "-S", ".", "-B", "build"], check=True)
subprocess.run(["cmake", "--build", "build"], check=True)
subprocess.run(["cmake", "-S", ".", "-B", "build"], check=True, capture_output=True)
subprocess.run(["cmake", "--build", "build"], check=True, capture_output=True)

# Run the executable
executable_path = os.path.join(build_dir, "build", "main")
program_args = [f"{keyword}={value}" for keyword, value in kwargs.items()]
print(f"Build completed. Running with parameters: {", ".join(program_args)}...")
result = subprocess.run([executable_path] + program_args, check=True, text=True, capture_output=True)

print(result.stdout)
Expand Down
14 changes: 7 additions & 7 deletions oraqle/compiler/comparison/in_upper_half.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from oraqle.add_chains.solving import extract_indices
from oraqle.compiler.nodes.abstract import CostParetoFront, Node
from oraqle.compiler.nodes.binary_arithmetic import Addition, Multiplication
from oraqle.compiler.nodes.leafs import Input
from oraqle.compiler.nodes.leafs import Constant, Input
from oraqle.compiler.nodes.unary_arithmetic import ConstantMultiplication
from oraqle.compiler.nodes.univariate import UnivariateNode
from oraqle.compiler.polynomials.univariate import UnivariatePoly, _eval_poly
Expand Down Expand Up @@ -71,13 +71,13 @@ def _arithmetize_inner(self, strategy: str) -> Node:
(2 * exp) % (p - 1),
power_node.multiplicative_depth() - input_node.multiplicative_depth(),
)
for exp, power_node in precomputed_powers.items()
for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0
)

addition_chain = add_chain_guaranteed(p - 1, p - 1, squaring_cost=1.0, precomputed_values=precomputed_values)

nodes = [input_node]
nodes.extend(power_node for _, power_node in precomputed_powers.items())
nodes.extend(power_node for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0)

for i, j in addition_chain:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down Expand Up @@ -123,7 +123,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
# Compute the final coefficient using an exponentiation
precomputed_values = tuple(
((2 * exp) % (p - 1), power_node.multiplicative_depth() - node_depth)
for exp, power_node in precomputed_powers[depth].items()
for exp, power_node in precomputed_powers[depth].items() if ((2 * exp) % (p - 1)) != 0
)
# TODO: This is copied from Power, but in the future we can probably remove this if we have augmented circuits
if p <= 200:
Expand All @@ -148,7 +148,7 @@ def _arithmetize_depth_aware_inner(self, cost_of_squaring: float) -> CostParetoF
)

nodes = [node]
nodes.extend(power_node for _, power_node in precomputed_powers[depth].items())
nodes.extend(power_node for exp, power_node in precomputed_powers[depth].items() if ((2 * exp) % (p - 1)) != 0)

for i, j in c:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down Expand Up @@ -224,13 +224,13 @@ def _arithmetize_inner(self, strategy: str) -> Node:
(2 * exp) % (p - 1),
power_node.multiplicative_depth() - input_node.multiplicative_depth(),
)
for exp, power_node in precomputed_powers.items()
for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0
)

addition_chain = add_chain_guaranteed(p - 1, p - 1, squaring_cost=1.0, precomputed_values=precomputed_values)

nodes = [input_node]
nodes.extend(power_node for _, power_node in precomputed_powers.items())
nodes.extend(power_node for exp, power_node in precomputed_powers.items() if ((2 * exp) % (p - 1)) != 0)

for i, j in addition_chain:
nodes.append(Multiplication(nodes[i], nodes[j], self._gf))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,30 @@
"T2 circuit:",
t2_arithmetization.multiplicative_depth(),
t2_arithmetization.multiplicative_size(),
t2_arithmetization.run_using_helib(measure_time=True, x=15, y=22)
)
t2_arithmetization.eliminate_subexpressions()
print(
"T2 circuit CSE:",
t2_arithmetization.multiplicative_depth(),
t2_arithmetization.multiplicative_size(),
t2_arithmetization.run_using_helib(measure_time=True, x=15, y=22)
)

iz21_circuit = Circuit([IliashenkoZuccaSemiLessThan(x, y, gf)])
iz21_arithmetization = iz21_circuit.arithmetize()
iz21_arithmetization.to_graph(f"comp_{p}_iz21.dot")
#iz21_arithmetization.to_graph(f"comp_{p}_iz21.dot")
print(
"IZ21 circuits:",
iz21_arithmetization.multiplicative_depth(),
iz21_arithmetization.multiplicative_size(),
iz21_arithmetization.run_using_helib(measure_time=True, x=15, y=22)
)
iz21_arithmetization.eliminate_subexpressions()
iz21_arithmetization.to_graph(f"comp_{p}_iz21_cse.dot")
#iz21_arithmetization.to_graph(f"comp_{p}_iz21_cse.dot")
print(
"IZ21 circuit CSE:",
iz21_arithmetization.multiplicative_depth(),
iz21_arithmetization.multiplicative_size(),
iz21_arithmetization.run_using_helib(measure_time=True, x=15, y=22)
)

0 comments on commit dd13914

Please sign in to comment.