Skip to content

Commit

Permalink
remove complexity of msm because twisted add is complete (no degenera…
Browse files Browse the repository at this point in the history
…te cases)
  • Loading branch information
querolita committed Jan 16, 2025
1 parent c954595 commit f91c36b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 73 deletions.
34 changes: 34 additions & 0 deletions lib/provable/test/twisted-curve.unit-test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { TwistedCurveParams } from "../../../bindings/crypto/elliptic-curve-examples.js";
import { createCurveTwisted } from "../../../bindings/crypto/elliptic-curve.js";
import { array, equivalentProvable, map, onlyIf, spec, unit } from "../../testing/equivalent.js";
import { Random } from "../../testing/random.js";
import { assert } from "../gadgets/common.js";
import { Point, CurveTwisted, simpleMapToCurve } from "../gadgets/twisted-curve.js";
import { foreignField, throwError } from "./test-utils.js";
const Ed25519 = createCurveTwisted(TwistedCurveParams.Ed25519);
let curves = [Ed25519];
for (let Curve of curves) {
let field = foreignField(Curve.Field);
let scalar = foreignField(Curve.Scalar);
let badPoint = spec({
rng: Random.record({
x: field.rng,
y: field.rng,
infinity: Random.constant(false)
}),
there: Point.from,
back: Point.toBigint,
provable: Point.provable
});
let point = map({ from: field, to: badPoint }, (x) => simpleMapToCurve(x, Curve));
let unequalPair = onlyIf(array(point, 2), ([p, q]) => !Curve.equal(p, q));
equivalentProvable({ from: [point], to: unit, verbose: true })((p) => Curve.isOnCurve(p) || throwError("expect on curve"), (p) => CurveTwisted.assertOnCurve(p, Curve), `${Curve.name} on curve`);
equivalentProvable({ from: [unequalPair], to: point, verbose: true })(([p, q]) => Curve.add(p, q), ([p, q]) => CurveTwisted.add(p, q, Curve), `${Curve.name} add`);
equivalentProvable({ from: [point], to: point, verbose: true })((p) => Curve.double(p), (p) => CurveTwisted.double(p, Curve), `${Curve.name} double`);
equivalentProvable({ from: [point], to: point, verbose: true })(Curve.negate, (p) => CurveTwisted.negate(p, Curve), `${Curve.name} negate`);
equivalentProvable({ from: [point, scalar], to: point, verbose: true })((p, s) => {
let sp = Curve.scale(p, s);
assert(!sp.infinity, "expect nonzero");
return sp;
}, (p, s) => CurveTwisted.scale(s, p, Curve), `${Curve.name} scale`);
}
82 changes: 10 additions & 72 deletions src/lib/provable/gadgets/twisted-curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { arrayGetGeneric } from './elliptic-curve.js';
export { CurveTwisted };

// internal API
export { Point, initialAggregator, simpleMapToCurve, arrayGetGeneric };
export { Point, simpleMapToCurve, arrayGetGeneric };

const CurveTwisted = {
add,
Expand Down Expand Up @@ -272,14 +272,14 @@ function scale(
scalar: Field3,
point: Point,
Curve: CurveTwisted,
config: {
mode?: 'assert-nonzero' | 'assert-zero';
config?: {
windowSize?: number;
multiples?: Point[];
} = { mode: 'assert-nonzero' }
}
) {
config = config ?? {};
config.windowSize ??= Point.isConstant(point) ? 4 : 3;
return multiScalarMul([scalar], [point], Curve, [config], config.mode);
return multiScalarMul([scalar], [point], Curve, [config]);
}

// check whether a point equals a constant point
Expand All @@ -293,8 +293,7 @@ function equals(p1: Point, p2: point, Curve: { modulus: bigint }) {
function multiScalarMulConstant(
scalars: Field3[],
points: Point[],
Curve: CurveTwisted,
mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero'
Curve: CurveTwisted
): Point {
let n = points.length;
assert(scalars.length === n, 'Points and scalars lengths must match');
Expand All @@ -307,11 +306,6 @@ function multiScalarMulConstant(
for (let i = 0; i < n; i++) {
sum = Curve.add(sum, Curve.scale(P[i], s[i]));
}
if (mode === 'assert-zero') {
assert(sum.infinity, 'scalar multiplication: expected zero result');
return Point.from(Curve.zero);
}
assert(!sum.infinity, 'scalar multiplication: expected non-zero result');
return Point.from(sum);
}

Expand All @@ -322,17 +316,11 @@ function multiScalarMulConstant(
*
* where P_i are any points.
*
* By default, we prove that the result is not zero.
*
* If you set the `mode` parameter to `'assert-zero'`, on the other hand,
* we assert that the result is zero and just return the constant zero point.
*
* Implementation: We double all points together and leverage a precomputed table of size 2^c to avoid all but every cth addition.
*
* Note: this algorithm targets a small number of points
*
* TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch
* TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case
*/
function multiScalarMul(
scalars: Field3[],
Expand All @@ -341,18 +329,15 @@ function multiScalarMul(
tableConfigs: (
| { windowSize?: number; multiples?: Point[] }
| undefined
)[] = [],
mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero',
ia?: point
)[] = []
): Point {
let n = points.length;
assert(scalars.length === n, 'Points and scalars lengths must match');
assertPositiveInteger(n, 'Expected at least 1 point and scalar');
let useGlv = Curve.hasEndomorphism;

// constant case
if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) {
return multiScalarMulConstant(scalars, points, Curve, mode);
return multiScalarMulConstant(scalars, points, Curve);
}

// parse or build point tables
Expand All @@ -368,13 +353,9 @@ function multiScalarMul(
sliceField3(s, { maxBits, chunkSize: windowSizes[i] })
);

// initialize sum to the initial aggregator, which is expected to be unrelated
// to any point that this gadget is used with
// note: this is a trick to ensure _completeness_ of the gadget
// soundness follows because add() and double() are sound, on all inputs that
// are valid non-zero curve points
ia ??= initialAggregator(Curve);
let sum = Point.from(ia);
let sum = Point.from(Curve.zero);

for (let i = maxBits - 1; i >= 0; i--) {
// add in multiple of each point
Expand All @@ -389,11 +370,7 @@ function multiScalarMul(
: arrayGetGeneric(Point.provable, tables[j], sj);

// ec addition
let added = add(sum, sjP, Curve);

// handle degenerate case
// (if sj = 0, Gj is all zeros and the add result is garbage)
sum = Provable.if(sj.equals(0), Point, sum, added);
sum = add(sum, sjP, Curve);
}
}

Expand All @@ -405,20 +382,6 @@ function multiScalarMul(
sum = double(sum, Curve);
}

// the sum is now 2^(b-1)*IA + sum_i s_i*P_i
// we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result
let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1));
let isZero = equals(sum, iaFinal, Curve);

if (mode === 'assert-nonzero') {
isZero.assertFalse();
sum = add(sum, Point.from(Curve.negate(iaFinal)), Curve);
} else {
isZero.assertTrue();
// for type consistency with the 'assert-nonzero' case
sum = Point.from(Curve.zero);
}

return sum;
}

Expand Down Expand Up @@ -450,31 +413,6 @@ function getPointTable(
return table;
}

/**
* For EC scalar multiplication we use an initial point which is subtracted
* at the end, to avoid encountering the point at infinity.
*
* This is a simple hash-to-group algorithm which finds that initial point.
* It's important that this point has no known discrete logarithm so that nobody
* can create an invalid proof of EC scaling.
*/
function initialAggregator(Curve: CurveTwisted) {
// hash that identifies the curve
let h = sha256.create();
h.update('initial-aggregator');
h.update(bigIntToBytes(Curve.modulus));
h.update(bigIntToBytes(Curve.order));
h.update(bigIntToBytes(Curve.a));
h.update(bigIntToBytes(Curve.d));
let bytes = h.array();

// bytes represent a 256-bit number
// use that as x coordinate
const F = Curve.Field;
let x = F.mod(bytesToBigInt(bytes));
return simpleMapToCurve(x, Curve);
}

function random(Curve: CurveTwisted) {
let x = Curve.Field.random();
return simpleMapToCurve(x, Curve);
Expand Down
1 change: 0 additions & 1 deletion src/lib/provable/test/twisted-curve.unit-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ for (let Curve of curves) {
equivalentProvable({ from: [point, scalar], to: point, verbose: true })(
(p, s) => {
let sp = Curve.scale(p, s);
assert(!sp.infinity, 'expect nonzero');
return sp;
},
(p, s) => CurveTwisted.scale(s, p, Curve),
Expand Down

0 comments on commit f91c36b

Please sign in to comment.