From 9110e83d6d33b8502c541bc56a9a91b601c839dd Mon Sep 17 00:00:00 2001 From: Markku Rossi Date: Wed, 8 May 2024 13:47:51 +0200 Subject: [PATCH] WIP: circuit rewrite patterns. --- compiler/circuits/rewrite.go | 52 ++++++++++++++++++++++++++++++++++++ compiler/ssa/circuitgen.go | 3 ++- 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 compiler/circuits/rewrite.go diff --git a/compiler/circuits/rewrite.go b/compiler/circuits/rewrite.go new file mode 100644 index 00000000..b5bebc98 --- /dev/null +++ b/compiler/circuits/rewrite.go @@ -0,0 +1,52 @@ +// +// Copyright (c) 2024 Markku Rossi +// +// All rights reserved. +// + +package circuits + +import ( + "fmt" + "time" + + "github.com/markkurossi/mpc/circuit" +) + +// Rewrite applies rewrite patterns to the circuit. +func (cc *Compiler) Rewrite() { + var stats circuit.Stats + + start := time.Now() + + for _, g := range cc.Gates { + switch g.Op { + case circuit.AND: + // AND(A,A) = A + if g.A == g.B && g.O.NumOutputs() > 0 { + stats[g.Op]++ + g.ShortCircuit(g.A) + } + case circuit.OR: + // OR(A,A) = A + if g.A == g.B && g.O.NumOutputs() > 0 { + stats[g.Op]++ + g.ShortCircuit(g.A) + } + case circuit.XOR: + // XOR(A,A) = 0 + if g.A == g.B && g.O.NumOutputs() > 0 { + stats[g.Op]++ + g.O.SetValue(Zero) + } + } + } + + elapsed := time.Since(start) + + if cc.Params.Diagnostics { + fmt.Printf(" - Rewrite: %12s: %d/%d (%.2f%%)\n", + elapsed, stats.Count(), len(cc.Gates), + float64(stats.Count())/float64(len(cc.Gates))*100) + } +} diff --git a/compiler/ssa/circuitgen.go b/compiler/ssa/circuitgen.go index 200d42f4..ed7dc81c 100644 --- a/compiler/ssa/circuitgen.go +++ b/compiler/ssa/circuitgen.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2023 Markku Rossi +// Copyright (c) 2020-2024 Markku Rossi // // All rights reserved. // @@ -44,6 +44,7 @@ func (prog *Program) CompileCircuit(params *utils.Params) ( fmt.Printf("Compiling circuit...\n") } cc.ConstPropagate() + cc.Rewrite() cc.ShortCircuitXORZero() if params.OptPruneGates { orig := float64(len(cc.Gates))