Skip to content

Commit

Permalink
Merge pull request #1530 from o1-labs/feature/no-shifted-scale
Browse files Browse the repository at this point in the history
Efficient scalar mul and other Scalar improvements
  • Loading branch information
mitschabaude authored Apr 16, 2024
2 parents 4e36d3c + 9822654 commit cfdb2a5
Show file tree
Hide file tree
Showing 23 changed files with 721 additions and 383 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

### Breaking changes

- Native curve improvements https://github.com/o1-labs/o1js/pull/1530
- Change the internal representation of `Scalar` from 255 Bools to 1 Bool and 1 Field (low bit and high 254 bits)
- Make `Group.scale()` support all scalars (previously did not support 0, 1 and -1)
- Make `Group.scale()` directly accept `Field` elements, and much more efficient than previous methods of scaling by Fields
- As a result, `Signature.verify()` and `Nullifier.verify()` use much fewer constraints
- Fix `Scalar.fromBits()` to not produce a shifted scalar; shifting is no longer exposed to users of `Scalar`.
- Add assertion to the foreign EC addition gadget that prevents degenerate cases https://github.com/o1-labs/o1js/pull/1545
- Fixes soundness of ECDSA; slightly increases its constraints from ~28k to 29k
- Breaks circuits that used EC addition, like ECDSA
Expand Down
1 change: 0 additions & 1 deletion src/examples/nullifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import {
State,
method,
MerkleMap,
Circuit,
MerkleMapWitness,
Mina,
AccountUpdate,
Expand Down
8 changes: 4 additions & 4 deletions src/lib/ml/conversion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ function varToField(x: FieldVar): Field {
return Field(x);
}

function fromScalar(s: Scalar) {
return s.toConstant().constantValue;
function fromScalar(s: Scalar): ScalarConst {
return [0, s.toBigInt()];
}
function toScalar(s: ScalarConst) {
return Scalar.from(s);
return Scalar.from(s[1]);
}

function fromPrivateKey(sk: PrivateKey) {
return fromScalar(sk.s);
}
function toPrivateKey(sk: ScalarConst) {
return new PrivateKey(Scalar.from(sk));
return new PrivateKey(Scalar.from(sk[1]));
}

function fromPublicKey(pk: PublicKey): MlPublicKey {
Expand Down
10 changes: 3 additions & 7 deletions src/lib/provable/crypto/nullifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Struct } from '../types/struct.js';
import { Field, Group, Scalar } from '../wrapped.js';
import { Poseidon } from './poseidon.js';
import { MerkleMapWitness } from '../merkle-map.js';
import { PrivateKey, PublicKey, scaleShifted } from './signature.js';
import { PrivateKey, PublicKey } from './signature.js';
import { Provable } from '../provable.js';

export { Nullifier };
Expand Down Expand Up @@ -50,7 +50,6 @@ class Nullifier extends Struct({
public: { nullifier, s },
private: { c },
} = this;

// generator
let G = Group.generator;

Expand All @@ -68,9 +67,8 @@ class Nullifier extends Struct({

let h_m_pk = Group.fromFields([x, x0]);

// shifted scalar see https://github.com/o1-labs/o1js/blob/5333817a62890c43ac1b9cb345748984df271b62/src/lib/signature.ts#L220
// pk^c
let pk_c = scaleShifted(this.publicKey, Scalar.fromBits(c.toBits()));
let pk_c = this.publicKey.scale(c);

// g^r = g^s / pk^c
let g_r = G.scale(s).sub(pk_c);
Expand All @@ -79,9 +77,7 @@ class Nullifier extends Struct({
let h_m_pk_s = h_m_pk.scale(s);

// h_m_pk_r = h(m,pk)^s / nullifier^c
let h_m_pk_s_div_nullifier_s = h_m_pk_s.sub(
scaleShifted(nullifier, Scalar.fromBits(c.toBits()))
);
let h_m_pk_s_div_nullifier_s = h_m_pk_s.sub(nullifier.scale(c));

// this is supposed to match the entries generated on "the other side" of the nullifier (mina-signer, in an wallet enclave)
Poseidon.hash([
Expand Down
52 changes: 14 additions & 38 deletions src/lib/provable/crypto/signature.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { Field, Bool, Group, Scalar } from '../wrapped.js';
import { AnyConstructor } from '../types/struct.js';
import { hashWithPrefix } from './poseidon.js';
import { Fq } from '../../../bindings/crypto/finite-field.js';
import {
deriveNonce,
Signature as SignatureBigint,
Expand All @@ -11,16 +10,12 @@ import {
PrivateKey as PrivateKeyBigint,
PublicKey as PublicKeyBigint,
} from '../../../mina-signer/src/curve-bigint.js';
import { constantScalarToBigint } from '../scalar.js';
import { toConstantField } from '../field.js';
import { CircuitValue, prop } from '../types/circuit-value.js';

// external API
export { PrivateKey, PublicKey, Signature };

// internal API
export { scaleShifted };

/**
* A signing key. You can generate one via {@link PrivateKey.random}.
*/
Expand Down Expand Up @@ -71,7 +66,7 @@ class PrivateKey extends CircuitValue {
* Convert this {@link PrivateKey} to a bigint
*/
toBigInt() {
return constantScalarToBigint(this.s, 'PrivateKey.toBigInt');
return this.s.toBigInt();
}

/**
Expand Down Expand Up @@ -117,9 +112,7 @@ class PrivateKey extends CircuitValue {
* @returns a base58 encoded string
*/
static toBase58(privateKey: { s: Scalar }) {
return PrivateKeyBigint.toBase58(
constantScalarToBigint(privateKey.s, 'PrivateKey.toBase58')
);
return PrivateKeyBigint.toBase58(privateKey.s.toBigInt());
}
}

Expand Down Expand Up @@ -249,29 +242,29 @@ class Signature extends CircuitValue {
* @returns a {@link Signature}
*/
static create(privKey: PrivateKey, msg: Field[]): Signature {
const publicKey = PublicKey.fromPrivateKey(privKey).toGroup();
const d = privKey.s;
let publicKey = PublicKey.fromPrivateKey(privKey).toGroup();
let d = privKey.s;

// we chose an arbitrary prefix for the signature, and it happened to be 'testnet'
// there's no consequences in practice and the signatures can be used with any network
// if there needs to be a custom nonce, include it in the message itself
const kPrime = Scalar.from(
let kPrime = Scalar.from(
deriveNonce(
{ fields: msg.map((f) => f.toBigInt()) },
{ x: publicKey.x.toBigInt(), y: publicKey.y.toBigInt() },
d.toBigInt(),
'testnet'
)
);

let { x: r, y: ry } = Group.generator.scale(kPrime);
const k = ry.isOdd().toBoolean() ? kPrime.neg() : kPrime;
let k = ry.isOdd().toBoolean() ? kPrime.neg() : kPrime;
let h = hashWithPrefix(
signaturePrefix('testnet'),
msg.concat([publicKey.x, publicKey.y, r])
);
// TODO: Scalar.fromBits interprets the input as a "shifted scalar"
// therefore we have to unshift e before using it
let e = unshift(Scalar.fromBits(h.toBits()));
const s = e.mul(d).add(k);
let e = Scalar.fromField(h);
let s = e.mul(d).add(k);
return new Signature(r, s);
}

Expand All @@ -280,18 +273,17 @@ class Signature extends CircuitValue {
* @returns a {@link Bool}
*/
verify(publicKey: PublicKey, msg: Field[]): Bool {
const point = publicKey.toGroup();
let point = publicKey.toGroup();

// we chose an arbitrary prefix for the signature, and it happened to be 'testnet'
// there's no consequences in practice and the signatures can be used with any network
// if there needs to be a custom nonce, include it in the message itself
let h = hashWithPrefix(
signaturePrefix('testnet'),
msg.concat([point.x, point.y, this.r])
);
// TODO: Scalar.fromBits interprets the input as a "shifted scalar"
// therefore we have to use scaleShifted which is very inefficient
let e = Scalar.fromBits(h.toBits());
let r = scaleShifted(point, e).neg().add(Group.generator.scale(this.s));

let r = point.scale(h).neg().add(Group.generator.scale(this.s));
return r.x.equals(this.r).and(r.y.isEven());
}

Expand All @@ -311,19 +303,3 @@ class Signature extends CircuitValue {
return SignatureBigint.toBase58({ r, s });
}
}

// performs scalar multiplication s*G assuming that instead of s, we got s' = 2s + 1 + 2^255
// cost: 2x scale by constant, 1x scale by variable
function scaleShifted(point: Group, shiftedScalar: Scalar) {
let oneHalfGroup = point.scale(Scalar.from(oneHalf));
let shiftGroup = oneHalfGroup.scale(Scalar.from(shift));
return oneHalfGroup.scale(shiftedScalar).sub(shiftGroup);
}
// returns s, assuming that instead of s, we got s' = 2s + 1 + 2^255
// (only works out of snark)
function unshift(shiftedScalar: Scalar) {
return shiftedScalar.sub(Scalar.from(shift)).mul(Scalar.from(oneHalf));
}

let shift = Fq.mod(1n + 2n ** 255n);
let oneHalf = Fq.inverse(2n)!;
20 changes: 2 additions & 18 deletions src/lib/provable/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { setFieldConstructor } from './core/field-constructor.js';
import {
assertLessThanFull,
assertLessThanOrEqualFull,
isOddAndHigh,
lessThanFull,
lessThanOrEqualFull,
} from './gadgets/comparison.js';
Expand Down Expand Up @@ -320,24 +321,7 @@ class Field {
* See {@link Field.isEven} for examples.
*/
isOdd() {
if (this.isConstant()) return new Bool((this.toBigInt() & 1n) === 1n);

// witness a bit b such that x = b + 2z for some z <= (p-1)/2
// this is always possible, and unique _except_ in the edge case where x = 0 = 0 + 2*0 = 1 + 2*(p-1)/2
// so we can compute isOdd = b AND (x != 0)
let [b, z] = exists(2, () => {
let x = this.toBigInt();
return [x & 1n, x >> 1n];
});
let isOdd = b.assertBool();
z.assertLessThan((Field.ORDER + 1n) / 2n);

// x == b + 2z
b.add(z.mul(2)).assertEquals(this);

// avoid overflow case when x = 0
let isNonZero = this.equals(0).not();
return isOdd.and(isNonZero);
return isOddAndHigh(this).isOdd;
}

/**
Expand Down
39 changes: 37 additions & 2 deletions src/lib/provable/gadgets/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@ import { Tuple } from '../../util/types.js';
import type { Bool } from '../bool.js';
import { fieldVar } from '../gates.js';
import { existsOne } from '../core/exists.js';
import { createField } from '../core/field-constructor.js';
import { createField, isBool } from '../core/field-constructor.js';

export { toVars, toVar, isVar, assert, bitSlice, divideWithRemainder };
export {
toVars,
toVar,
isVar,
assert,
bitSlice,
bit,
divideWithRemainder,
packBits,
isConstant,
};

/**
* Given a Field, collapse its AST to a pure Var. See {@link FieldVar}.
Expand Down Expand Up @@ -56,8 +66,33 @@ function bitSlice(x: bigint, start: number, length: number) {
return (x >> BigInt(start)) & ((1n << BigInt(length)) - 1n);
}

function bit(x: bigint, i: number) {
return (x >> BigInt(i)) & 1n;
}

function divideWithRemainder(numerator: bigint, denominator: bigint) {
const quotient = numerator / denominator;
const remainder = numerator - denominator * quotient;
return { quotient, remainder };
}

// pack bools into a single field element

/**
* Helper function to provably pack bits into a single field element.
* Just returns the sum without any boolean checks.
*/
function packBits(bits: (Field | Bool)[]): Field {
let n = bits.length;
let sum = createField(0n);
for (let i = 0; i < n; i++) {
let bit = bits[i];
if (isBool(bit)) bit = bit.toField();
sum = sum.add(bit.mul(1n << BigInt(i)));
}
return sum.seal();
}

function isConstant(...args: (Field | Bool)[]): boolean {
return args.every((x) => x.isConstant());
}
48 changes: 47 additions & 1 deletion src/lib/provable/gadgets/comparison.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import type { Field } from '../field.js';
import type { Bool } from '../bool.js';
import { createBoolUnsafe, createField } from '../core/field-constructor.js';
import {
createBool,
createBoolUnsafe,
createField,
} from '../core/field-constructor.js';
import { Fp } from '../../../bindings/crypto/finite-field.js';
import { assert } from '../../../lib/util/assert.js';
import { exists, existsOne } from '../core/exists.js';
Expand All @@ -15,13 +19,21 @@ export {
assertLessThanOrEqualGeneric,
lessThanGeneric,
lessThanOrEqualGeneric,

// comparison gadgets for full range inputs
assertLessThanFull,
assertLessThanOrEqualFull,
lessThanFull,
lessThanOrEqualFull,

// gadgets that are based on full comparisons
isOddAndHigh,

// legacy, unused
compareCompatible,

// internal helper
fieldToField3,
};

/**
Expand Down Expand Up @@ -181,6 +193,40 @@ function lessThanOrEqualFull(x: Field, y: Field) {
return lessThanFull(y, x).not();
}

/**
* Splits a field element into a low bit `isOdd` and a 254-bit `high` part.
*
* There are no assumptions on the range of x and y, they can occupy the full range [0, p).
*/
function isOddAndHigh(x: Field) {
if (x.isConstant()) {
let x0 = x.toBigInt();
return { isOdd: createBool((x0 & 1n) === 1n), high: createField(x0 >> 1n) };
}

// witness a bit b such that x = b + 2z for some z <= (p-1)/2
// this is always possible, and unique _except_ in the edge case where x = 0 = 0 + 2*0 = 1 + 2*(p-1)/2
// so we must assert that x = 0 implies b = 0
let [b, z] = exists(2, () => {
let x0 = x.toBigInt();
return [x0 & 1n, x0 >> 1n];
});
let isOdd = b.assertBool();
z.assertLessThan((Fp.modulus + 1n) / 2n);

// x == b + 2z
b.add(z.mul(2n)).assertEquals(x);

// prevent overflow case when x = 0
// we witness x' such that b == x * x', which makes it impossible to have x = 0 and b = 1
let x_ = existsOne(() =>
b.toBigInt() === 0n ? 0n : Fp.inverse(x.toBigInt()) ?? 0n
);
x.mul(x_).assertEquals(b);

return { isOdd, high: z };
}

/**
* internal helper, split Field into a 3-limb bigint
*
Expand Down
Loading

0 comments on commit cfdb2a5

Please sign in to comment.