Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RandomEcPoint hint #513

Merged
merged 9 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 49 additions & 47 deletions integration_tests/BenchMarks.txt

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ func TestCairoFiles(t *testing.T) {
errorExpected := false
if name == "range_check.small.cairo" {
errorExpected = true
} else if name == "ecop.starknet_with_keccak.cairo" {
// temporary, being fixed in another PR soon
continue
}

path := filepath.Join(root, name)
Expand Down
16 changes: 15 additions & 1 deletion pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func IsQuadResidue(x *fp.Element) bool {
return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
func ySquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1
Expand All @@ -171,3 +171,17 @@ func Sqrt(x, p *big.Int) *big.Int {

return m
}

func RecoverY(x, beta, fieldPrime *big.Int) (*big.Int, error) {
ySquared := ySquaredFromX(x, beta, fieldPrime)
if IsQuadResidue(new(fp.Element).SetBigInt(ySquared)) {
return Sqrt(ySquared, fieldPrime), nil
}
return nil, fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquared.String())
}

func GetCairoPrime() (big.Int, bool) {
// 2**251 + 17 * 2**192 + 1
cairoPrime, ok := new(big.Int).SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10)
return *cairoPrime, ok
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"
randomEcPointCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import random_ec_point\nfrom starkware.python.utils import to_bytes\n\n# Define a seed for random_ec_point that's dependent on all the input, so that:\n# (1) The added point s is deterministic.\n# (2) It's hard to choose inputs for which the builtin will fail.\nseed = b\"\".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y]))\nids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
case randomEcPointCode:
return createRandomEcPointHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand Down
165 changes: 147 additions & 18 deletions pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package zero

import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/big"

Expand Down Expand Up @@ -901,33 +904,25 @@ func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter {
return err
}

const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665"
betaBigInt, ok := new(big.Int).SetString(betaString, 10)
if !ok {
panic("failed to convert BETA string to big.Int")
}
betaBigInt := new(big.Int)
utils.Beta.BigInt(betaBigInt)

const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10)
fieldPrimeBigInt, ok := secp_utils.GetCairoPrime()
if !ok {
panic("failed to convert FIELD_PRIME string to big.Int")
return fmt.Errorf("GetCairoPrime failed")
}

xBigInt := new(big.Int)
xFelt.BigInt(xBigInt)

// y^2 = x^3 + alpha * x + beta (mod field_prime)
ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt)
ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt)

if secp_utils.IsQuadResidue(ySquaredFelt) {
result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt))
value := mem.MemoryValueFromFieldElement(result)
return vm.Memory.WriteToAddress(&pYAddr, &value)
} else {
ySquaredString := ySquaredBigInt.String()
return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString)
resultBigInt, err := secp_utils.RecoverY(xBigInt, betaBigInt, &fieldPrimeBigInt)
if err != nil {
return err
}
resultFelt := new(fp.Element).SetBigInt(resultBigInt)
resultMv := mem.MemoryValueFromFieldElement(resultFelt)
return vm.Memory.WriteToAddress(&pYAddr, &resultMv)
},
}
}
Expand All @@ -945,3 +940,137 @@ func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error)

return newRecoverYHint(x, p), nil
}

// RandomEcPoint hint returns a random non-zero point on the elliptic curve
// y^2 = x^3 + alpha * x + beta (mod field_prime).
// The point is created deterministically from the seed.
//
// `newRandomEcPointHint` takes 4 operanders as arguments
// - `p` is an EC point used for seed generation
// - `m` the multiplication coefficient of Q used for seed generation
// - `q` an EC point used for seed generation
// - `s` is where the generated random EC point is written to
func newRandomEcPointHint(p, m, q, s hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "RandomEcPoint",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME
//> from starkware.python.math_utils import random_ec_point
//> from starkware.python.utils import to_bytes
//>
//> # Define a seed for random_ec_point that's dependent on all the input, so that:
//> # (1) The added point s is deterministic.
//> # (2) It's hard to choose inputs for which the builtin will fail.
//> seed = b"".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y]))
//> ids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed)

pAddr, err := p.GetAddress(vm)
if err != nil {
return err
}
pValues, err := vm.Memory.ResolveAsEcPoint(pAddr)
if err != nil {
return err
}
mFelt, err := hinter.ResolveAsFelt(vm, m)
if err != nil {
return err
}
qAddr, err := q.GetAddress(vm)
if err != nil {
return err
}
qValues, err := vm.Memory.ResolveAsEcPoint(qAddr)
if err != nil {
return err
}

var bytesArray []byte
writeFeltToBytesArray := func(n *fp.Element) {
for _, byteValue := range n.Bytes() {
bytesArray = append(bytesArray, byteValue)
}
}
for _, felt := range pValues {
writeFeltToBytesArray(felt)
}
writeFeltToBytesArray(mFelt)
for _, felt := range qValues {
writeFeltToBytesArray(felt)
}
seed := sha256.Sum256(bytesArray)

alphaBig := new(big.Int)
utils.Alpha.BigInt(alphaBig)
betaBig := new(big.Int)
utils.Beta.BigInt(betaBig)
fieldPrime, ok := secp_utils.GetCairoPrime()
if !ok {
return fmt.Errorf("GetCairoPrime failed")
}

for i := uint64(0); i < 100; i++ {
iBytes := make([]byte, 10)
binary.LittleEndian.PutUint64(iBytes, i)
concatenated := append(seed[1:], iBytes...)
hash := sha256.Sum256(concatenated)
hashHex := hex.EncodeToString(hash[:])
x := new(big.Int)
x.SetString(hashHex, 16)

yCoef := big.NewInt(1)
if seed[0]&1 == 1 {
yCoef.Neg(yCoef)
}

// Try to recover y
if !ok {
return fmt.Errorf("failed to get field prime value")
}
if y, err := secp_utils.RecoverY(x, betaBig, &fieldPrime); err == nil {
y.Mul(yCoef, y)
y.Mod(y, &fieldPrime)

sAddr, err := s.GetAddress(vm)
if err != nil {
return err
}

sXFelt := new(fp.Element).SetBigInt(x)
sYFelt := new(fp.Element).SetBigInt(y)
sXMv := mem.MemoryValueFromFieldElement(sXFelt)
sYMv := mem.MemoryValueFromFieldElement(sYFelt)

err = vm.Memory.WriteToNthStructField(sAddr, sXMv, 0)
if err != nil {
return err
}
return vm.Memory.WriteToNthStructField(sAddr, sYMv, 1)
}
}

return fmt.Errorf("could not find a point on the curve")
},
}
}

func createRandomEcPointHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
p, err := resolver.GetResOperander("p")
if err != nil {
return nil, err
}
m, err := resolver.GetResOperander("m")
if err != nil {
return nil, err
}
q, err := resolver.GetResOperander("q")
if err != nil {
return nil, err
}
s, err := resolver.GetResOperander("s")
if err != nil {
return nil, err
}

return newRandomEcPointHint(p, m, q, s), nil
}
46 changes: 46 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,52 @@ func TestZeroHintEc(t *testing.T) {
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
},
"RandomEcPoint": {
{
operanders: []*hintOperander{
{Name: "p.x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")},
{Name: "p.y", Kind: apRelative, Value: feltString("3232266734070744637901977159303149980795588196503166389060831401046564401743")},
{Name: "m", Kind: apRelative, Value: feltUint64(34)},
{Name: "q.x", Kind: apRelative, Value: feltString("2864041794633455918387139831609347757720597354645583729611044800117714995244")},
{Name: "q.y", Kind: apRelative, Value: feltString("2252415379535459416893084165764951913426528160630388985542241241048300343256")},
{Name: "s.x", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRandomEcPointHint(
ctx.operanders["p.x"],
ctx.operanders["m"],
ctx.operanders["q.x"],
ctx.operanders["s.x"],
)
},
check: consecutiveVarValueEquals("s.x", []*fp.Element{
feltString("96578541406087262240552119423829615463800550101008760434566010168435227837635"),
feltString("3412645436898503501401619513420382337734846074629040678138428701431530606439"),
}),
},
{
operanders: []*hintOperander{
{Name: "p.x", Kind: apRelative, Value: feltUint64(12345)},
{Name: "p.y", Kind: apRelative, Value: feltUint64(6789)},
{Name: "m", Kind: apRelative, Value: feltUint64(101)},
{Name: "q.x", Kind: apRelative, Value: feltUint64(98765)},
{Name: "q.y", Kind: apRelative, Value: feltUint64(4321)},
{Name: "s.x", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRandomEcPointHint(
ctx.operanders["p.x"],
ctx.operanders["m"],
ctx.operanders["q.x"],
ctx.operanders["s.x"],
)
},
check: consecutiveVarValueEquals("s.x", []*fp.Element{
feltString("39190969885360777615413526676655883809466222002423777590585892821354159079496"),
feltString("533983185449702770508526175744869430974740140562200547506631069957329272485"),
}),
},
},
},
)
}
11 changes: 5 additions & 6 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -1154,21 +1154,20 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter {
var value = memory.MemoryValue{}
var result *fp.Element = new(fp.Element)

const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
primeBigInt, ok := math_utils.GetCairoPrime()
if !ok {
panic("failed to convert prime string to big.Int")
return fmt.Errorf("GetCairoPrime failed")
}

if math_utils.IsQuadResidue(x) {
result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt))
result.SetBigInt(math_utils.Sqrt(&xBigInt, &primeBigInt))
} else {
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt)
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), &primeBigInt)
if err != nil {
return err
}

result.SetBigInt(math_utils.Sqrt(&y, primeBigInt))
result.SetBigInt(math_utils.Sqrt(&y, &primeBigInt))
}

value = memory.MemoryValueFromFieldElement(result)
Expand Down
18 changes: 18 additions & 0 deletions pkg/vm/memory/memory_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,21 @@ func (memory *Memory) ResolveAsBigInt3(valAddr MemoryAddress) ([3]*f.Element, er

return valValues, nil
}

func (memory *Memory) ResolveAsEcPoint(valAddr MemoryAddress) ([2]*f.Element, error) {
valMemoryValues, err := memory.GetConsecutiveMemoryValues(valAddr, int16(2))
if err != nil {
return [2]*f.Element{}, err
}

var valValues [2]*f.Element
for i := 0; i < 2; i++ {
valValue, err := valMemoryValues[i].FieldElement()
if err != nil {
return [2]*f.Element{}, err
}
valValues[i] = valValue
}

return valValues, nil
}
Loading