Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RecoverY hint #506

Merged
merged 17 commits into from
Jul 3, 2024
39 changes: 37 additions & 2 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,45 @@ func sign(n *big.Int) (int, big.Int) {

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("Division by zero.")
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v.", x, y)
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
// We have omitted the prime as it will be CAIRO_PRIME

return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1

ySquaredBigInt := new(big.Int).Set(x)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, beta).Mod(ySquaredBigInt, fieldPrime)

return ySquaredBigInt
}

func Sqrt(x, p *big.Int) *big.Int {
// Finds the minimum non-negative integer m such that (m*m) % p == x.

halfPrimeBigInt := new(big.Int).Rsh(p, 1)
m := new(big.Int).ModSqrt(x, p)

if m.Cmp(halfPrimeBigInt) > 0 {
m.Sub(p, m)
}

return m
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)"
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroPackHinter(resolver)
case isZeroDivModCode:
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand Down
83 changes: 83 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,86 @@ func newIsZeroDivModHint() hinter.Hinter {
func createIsZeroDivModHinter() (hinter.Hinter, error) {
return newIsZeroDivModHint(), nil
}

// RecoverY hint Recovers the y coordinate of a point on the elliptic curve
// y^2 = x^3 + alpha * x + beta (mod field_prime) of a given x coordinate.
//
// `newRecoverYHint` takes 2 operanders as arguments
// - `x` is the x coordinate of an elliptic curve point
// - `p` is one of the two EC points with the given x coordinate (x, y)
func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "RecoverY",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME
//> from starkware.python.math_utils import recover_y
//> ids.p.x = ids.x
//> # This raises an exception if `x` is not on the curve.
//> ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)

pXAddr, err := p.GetAddress(vm)
if err != nil {
return err
}

pYAddr, err := pXAddr.AddOffset(1)
if err != nil {
return err
}

xFelt, err := hinter.ResolveAsFelt(vm, x)
if err != nil {
return err
}

valueX := mem.MemoryValueFromFieldElement(xFelt)

err = vm.Memory.WriteToAddress(&pXAddr, &valueX)
if err != nil {
return err
}

const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665"
betaBigInt, ok := new(big.Int).SetString(betaString, 10)
if !ok {
panic("failed to convert BETA string to big.Int")
}

const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10)
if !ok {
panic("failed to convert FIELD_PRIME string to big.Int")
}

xBigInt := new(big.Int)
xFelt.BigInt(xBigInt)

// y^2 = x^3 + alpha * x + beta (mod field_prime)
ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt)
ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt)

if secp_utils.IsQuadResidue(ySquaredFelt) {
result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt))
value := mem.MemoryValueFromFieldElement(result)
return vm.Memory.WriteToAddress(&pYAddr, &value)
} else {
ySquaredString := ySquaredBigInt.String()
return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString)
}
},
}
}

func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
x, err := resolver.GetResOperander("x")
if err != nil {
return nil, err
}

p, err := resolver.GetResOperander("p")
if err != nil {
return nil, err
}

return newRecoverYHint(x, p), nil
}
80 changes: 80 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,86 @@ func TestZeroHintEc(t *testing.T) {
check: varValueInScopeEquals("value", bigIntString("4", 10)),
},
},
"RecoverY": {
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129"),
"p.y": feltString("205857351767627712295703269674687767888261140702556021834663354704341414042"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("205857351767627712295703269674687767888261140702556021834663354704341414042")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020"),
"p.y": feltString("386236054595386575795345623791920124827519018828430310912260655089307618738"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("138597138396302485058562442936200017709939129389766076747102238692717075504")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("138597138396302485058562442936200017709939129389766076747102238692717075504"),
"p.y": feltString("1116947097676727397390632683964789044871379304271794004325353078455954290524"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245"),
"p.y": feltString("903372048565605391120071143811887302063650776015287438589675702929494830362"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("42424242424242424242")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
},
},
)
}
42 changes: 14 additions & 28 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,41 +1152,27 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter {
xBigInt := math_utils.AsInt(x)

var value = memory.MemoryValue{}
var result *fp.Element = new(fp.Element)

if x.IsZero() || x.IsOne() {
value = memory.MemoryValueFromFieldElement(x)
const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
if !ok {
panic("failed to convert prime string to big.Int")
}

if math_utils.IsQuadResidue(x) {
result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt))
} else {
var result *fp.Element = new(fp.Element)

if x.Legendre() == 1 {
// result = x.Sqrt(x)

const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
if !ok {
panic("failed to convert prime string to big.Int")
}

// divide primeBigInt by 2
halfPrimeBigInt := new(big.Int).Rsh(primeBigInt, 1)

tempResult := new(big.Int).ModSqrt(&xBigInt, primeBigInt)

// ensures that tempResult is the smaller of the two possible square roots in the prime field.
if tempResult.Cmp(halfPrimeBigInt) > 0 {
tempResult.Sub(primeBigInt, tempResult)
}

result.SetBigInt(tempResult)

} else {
result = x.Sqrt(new(fp.Element).Div(x, new(fp.Element).SetUint64(3)))
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt)
if err != nil {
return err
}

value = memory.MemoryValueFromFieldElement(result)
result.SetBigInt(math_utils.Sqrt(&y, primeBigInt))
}

value = memory.MemoryValueFromFieldElement(result)

return vm.Memory.WriteToAddress(&yAddr, &value)
},
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/vm/builtins/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (e *ECDSA) CheckWrite(segment *memory.Segment, offset uint64, value *memory
pubKey := &ecdsa.PublicKey{A: key}
sig, ok := e.signatures[pubOffset]
if !ok {
return fmt.Errorf("signature is missing form ECDA builtin")
return fmt.Errorf("signature is missing from ECDSA builtin")
}

msgBytes := msgField.Bytes()
Expand Down
Loading