diff --git a/lib/provable/test/twisted-curve.unit-test.js b/lib/provable/test/twisted-curve.unit-test.js new file mode 100644 index 000000000..ede6cf70f --- /dev/null +++ b/lib/provable/test/twisted-curve.unit-test.js @@ -0,0 +1,34 @@ +import { TwistedCurveParams } from "../../../bindings/crypto/elliptic-curve-examples.js"; +import { createCurveTwisted } from "../../../bindings/crypto/elliptic-curve.js"; +import { array, equivalentProvable, map, onlyIf, spec, unit } from "../../testing/equivalent.js"; +import { Random } from "../../testing/random.js"; +import { assert } from "../gadgets/common.js"; +import { Point, CurveTwisted, simpleMapToCurve } from "../gadgets/twisted-curve.js"; +import { foreignField, throwError } from "./test-utils.js"; +const Ed25519 = createCurveTwisted(TwistedCurveParams.Ed25519); +let curves = [Ed25519]; +for (let Curve of curves) { + let field = foreignField(Curve.Field); + let scalar = foreignField(Curve.Scalar); + let badPoint = spec({ + rng: Random.record({ + x: field.rng, + y: field.rng, + infinity: Random.constant(false) + }), + there: Point.from, + back: Point.toBigint, + provable: Point.provable + }); + let point = map({ from: field, to: badPoint }, (x) => simpleMapToCurve(x, Curve)); + let unequalPair = onlyIf(array(point, 2), ([p, q]) => !Curve.equal(p, q)); + equivalentProvable({ from: [point], to: unit, verbose: true })((p) => Curve.isOnCurve(p) || throwError("expect on curve"), (p) => CurveTwisted.assertOnCurve(p, Curve), `${Curve.name} on curve`); + equivalentProvable({ from: [unequalPair], to: point, verbose: true })(([p, q]) => Curve.add(p, q), ([p, q]) => CurveTwisted.add(p, q, Curve), `${Curve.name} add`); + equivalentProvable({ from: [point], to: point, verbose: true })((p) => Curve.double(p), (p) => CurveTwisted.double(p, Curve), `${Curve.name} double`); + equivalentProvable({ from: [point], to: point, verbose: true })(Curve.negate, (p) => CurveTwisted.negate(p, Curve), `${Curve.name} negate`); + equivalentProvable({ from: [point, scalar], to: point, verbose: true })((p, s) => { + let sp = Curve.scale(p, s); + assert(!sp.infinity, "expect nonzero"); + return sp; + }, (p, s) => CurveTwisted.scale(s, p, Curve), `${Curve.name} scale`); +} diff --git a/src/lib/provable/gadgets/twisted-curve.ts b/src/lib/provable/gadgets/twisted-curve.ts index f48b9c603..a75382440 100644 --- a/src/lib/provable/gadgets/twisted-curve.ts +++ b/src/lib/provable/gadgets/twisted-curve.ts @@ -24,7 +24,7 @@ import { arrayGetGeneric } from './elliptic-curve.js'; export { CurveTwisted }; // internal API -export { Point, initialAggregator, simpleMapToCurve, arrayGetGeneric }; +export { Point, simpleMapToCurve, arrayGetGeneric }; const CurveTwisted = { add, @@ -272,14 +272,14 @@ function scale( scalar: Field3, point: Point, Curve: CurveTwisted, - config: { - mode?: 'assert-nonzero' | 'assert-zero'; + config?: { windowSize?: number; multiples?: Point[]; - } = { mode: 'assert-nonzero' } + } ) { + config = config ?? {}; config.windowSize ??= Point.isConstant(point) ? 4 : 3; - return multiScalarMul([scalar], [point], Curve, [config], config.mode); + return multiScalarMul([scalar], [point], Curve, [config]); } // check whether a point equals a constant point @@ -293,8 +293,7 @@ function equals(p1: Point, p2: point, Curve: { modulus: bigint }) { function multiScalarMulConstant( scalars: Field3[], points: Point[], - Curve: CurveTwisted, - mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero' + Curve: CurveTwisted ): Point { let n = points.length; assert(scalars.length === n, 'Points and scalars lengths must match'); @@ -307,11 +306,6 @@ function multiScalarMulConstant( for (let i = 0; i < n; i++) { sum = Curve.add(sum, Curve.scale(P[i], s[i])); } - if (mode === 'assert-zero') { - assert(sum.infinity, 'scalar multiplication: expected zero result'); - return Point.from(Curve.zero); - } - assert(!sum.infinity, 'scalar multiplication: expected non-zero result'); return Point.from(sum); } @@ -322,17 +316,11 @@ function multiScalarMulConstant( * * where P_i are any points. * - * By default, we prove that the result is not zero. - * - * If you set the `mode` parameter to `'assert-zero'`, on the other hand, - * we assert that the result is zero and just return the constant zero point. - * * Implementation: We double all points together and leverage a precomputed table of size 2^c to avoid all but every cth addition. * * Note: this algorithm targets a small number of points * * TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch - * TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case */ function multiScalarMul( scalars: Field3[], @@ -341,18 +329,15 @@ function multiScalarMul( tableConfigs: ( | { windowSize?: number; multiples?: Point[] } | undefined - )[] = [], - mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero', - ia?: point + )[] = [] ): Point { let n = points.length; assert(scalars.length === n, 'Points and scalars lengths must match'); assertPositiveInteger(n, 'Expected at least 1 point and scalar'); - let useGlv = Curve.hasEndomorphism; // constant case if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) { - return multiScalarMulConstant(scalars, points, Curve, mode); + return multiScalarMulConstant(scalars, points, Curve); } // parse or build point tables @@ -368,13 +353,9 @@ function multiScalarMul( sliceField3(s, { maxBits, chunkSize: windowSizes[i] }) ); - // initialize sum to the initial aggregator, which is expected to be unrelated - // to any point that this gadget is used with - // note: this is a trick to ensure _completeness_ of the gadget // soundness follows because add() and double() are sound, on all inputs that // are valid non-zero curve points - ia ??= initialAggregator(Curve); - let sum = Point.from(ia); + let sum = Point.from(Curve.zero); for (let i = maxBits - 1; i >= 0; i--) { // add in multiple of each point @@ -389,11 +370,7 @@ function multiScalarMul( : arrayGetGeneric(Point.provable, tables[j], sj); // ec addition - let added = add(sum, sjP, Curve); - - // handle degenerate case - // (if sj = 0, Gj is all zeros and the add result is garbage) - sum = Provable.if(sj.equals(0), Point, sum, added); + sum = add(sum, sjP, Curve); } } @@ -405,20 +382,6 @@ function multiScalarMul( sum = double(sum, Curve); } - // the sum is now 2^(b-1)*IA + sum_i s_i*P_i - // we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result - let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1)); - let isZero = equals(sum, iaFinal, Curve); - - if (mode === 'assert-nonzero') { - isZero.assertFalse(); - sum = add(sum, Point.from(Curve.negate(iaFinal)), Curve); - } else { - isZero.assertTrue(); - // for type consistency with the 'assert-nonzero' case - sum = Point.from(Curve.zero); - } - return sum; } @@ -450,31 +413,6 @@ function getPointTable( return table; } -/** - * For EC scalar multiplication we use an initial point which is subtracted - * at the end, to avoid encountering the point at infinity. - * - * This is a simple hash-to-group algorithm which finds that initial point. - * It's important that this point has no known discrete logarithm so that nobody - * can create an invalid proof of EC scaling. - */ -function initialAggregator(Curve: CurveTwisted) { - // hash that identifies the curve - let h = sha256.create(); - h.update('initial-aggregator'); - h.update(bigIntToBytes(Curve.modulus)); - h.update(bigIntToBytes(Curve.order)); - h.update(bigIntToBytes(Curve.a)); - h.update(bigIntToBytes(Curve.d)); - let bytes = h.array(); - - // bytes represent a 256-bit number - // use that as x coordinate - const F = Curve.Field; - let x = F.mod(bytesToBigInt(bytes)); - return simpleMapToCurve(x, Curve); -} - function random(Curve: CurveTwisted) { let x = Curve.Field.random(); return simpleMapToCurve(x, Curve); diff --git a/src/lib/provable/test/twisted-curve.unit-test.ts b/src/lib/provable/test/twisted-curve.unit-test.ts index c57ea96ac..afbc2033c 100644 --- a/src/lib/provable/test/twisted-curve.unit-test.ts +++ b/src/lib/provable/test/twisted-curve.unit-test.ts @@ -75,7 +75,6 @@ for (let Curve of curves) { equivalentProvable({ from: [point, scalar], to: point, verbose: true })( (p, s) => { let sp = Curve.scale(p, s); - assert(!sp.infinity, 'expect nonzero'); return sp; }, (p, s) => CurveTwisted.scale(s, p, Curve),