Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
har777 committed Mar 20, 2024
2 parents 83cad9e + 41bd195 commit 9a3a736
Show file tree
Hide file tree
Showing 13 changed files with 912 additions and 96 deletions.
24 changes: 22 additions & 2 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
func main() {
var proofmode bool
var maxsteps uint64
var entrypointOffset uint64
var traceLocation string
var memoryLocation string

Expand Down Expand Up @@ -42,6 +43,12 @@ func main() {
Required: false,
Destination: &maxsteps,
},
&cli.Uint64Flag{
Name: "entrypoint",
Usage: "a PC offset that will be used as an entry point (by default it executes a main function)",
Value: 0,
Destination: &entrypointOffset,
},
&cli.StringFlag{
Name: "tracefile",
Usage: "location to store the relocated trace",
Expand All @@ -56,6 +63,9 @@ func main() {
},
},
Action: func(ctx *cli.Context) error {
// TODO: move this action's body to a separate function to decrease the
// code nesting significantly.

pathToFile := ctx.Args().Get(0)
if pathToFile == "" {
return fmt.Errorf("path to cairo file not set")
Expand Down Expand Up @@ -86,8 +96,18 @@ func main() {
return fmt.Errorf("cannot create runner: %w", err)
}

if err := runner.Run(); err != nil {
return fmt.Errorf("runtime error: %w", err)
// Run executes main(), RunEntryPoint is used to test contract_class-style entry points.
// In theory, calling RunEntryPoint with main's offset should behave identically,
// but these functions are implemented differently in both this and cairo-rs VMs
// and the difference is quite subtle.
if entrypointOffset == 0 {
if err := runner.Run(); err != nil {
return fmt.Errorf("runtime error: %w", err)
}
} else {
if err := runner.RunEntryPoint(entrypointOffset); err != nil {
return fmt.Errorf("runtime error (entrypoint=%d): %w", entrypointOffset, err)
}
}

if proofmode {
Expand Down
19 changes: 19 additions & 0 deletions pkg/hintrunner/hinter/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)
Expand Down Expand Up @@ -253,3 +254,21 @@ func GetConsecutiveValues(vm *VM.VirtualMachine, ref ResOperander, size int16) (

return values, nil
}

func WriteToNthStructField(vm *VM.VirtualMachine, addr mem.MemoryAddress, value mem.MemoryValue, field int16) error {
nAddr, err := addr.AddOffset(field)
if err != nil {
return err
}

return vm.Memory.WriteToAddress(&nAddr, &value)
}

func WriteUint256ToAddress(vm *VM.VirtualMachine, addr mem.MemoryAddress, low, high *f.Element) error {
lowMemoryValue := memory.MemoryValueFromFieldElement(low)
err := vm.Memory.WriteToAddress(&addr, &lowMemoryValue)
if err != nil {
return err
}
return WriteToNthStructField(vm, addr, memory.MemoryValueFromFieldElement(high), 1)
}
16 changes: 13 additions & 3 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,20 @@ const (

unsignedDivRemCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.div)\nassert 0 < ids.div <= PRIME // range_check_builtin.bound, \\\n f'div={hex(ids.div)} is out of the valid range.'\nids.q, ids.r = divmod(ids.value, ids.div)"

// split_felt() hints.
splitFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128"

// sqrt() hint
sqrtCode string = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)"

// ------ Uint256 hints related code ------
uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
uint256AddLowCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0"
split64Code string = "ids.low = ids.a & ((1<<64) - 1)\nids.high = ids.a >> 64"
uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
uint256AddLowCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0"
split64Code string = "ids.low = ids.a & ((1<<64) - 1)\nids.high = ids.a >> 64"
uint256SignedNNCode string = "memory[ap] = 1 if 0 <= (ids.a.high % PRIME) < 2 ** 127 else 0"
uint256UnsignedDivRemCode string = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a, div)\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
uint256SqrtCode string = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0"
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low/n b = (ids.b.high << 128) + ids.b.low/n div = (ids.div.high << 128) + ids.div.low/n quotient, remainder = divmod(a * b, div)/n ids.quotient_low.low = quotient & ((1 << 128) - 1)/n ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)/n ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)/n ids.quotient_high.high = quotient >> 384/n ids.remainder.low = remainder & ((1 << 128) - 1)/n ids.remainder.high = remainder >> 128"

// ------ Usort hints related code ------

Expand Down
12 changes: 12 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createUint256AddHinter(resolver, true)
case split64Code:
return createSplit64Hinter(resolver)
case uint256SignedNNCode:
return createUint256SignedNNHinter(resolver)
case splitFeltCode:
return createSplitFeltHinter(resolver)
case uint256UnsignedDivRemCode:
return createUint256UnsignedDivRemHinter(resolver)
case uint256SqrtCode:
return createUint256SqrtHinter(resolver)
case uint256MulDivModCode:
return createUint256MulDivModHinter(resolver)
case sqrtCode:
return createSqrtHinter(resolver)
case unsignedDivRemCode:
return createUnsignedDivRemHinter(resolver)
default:
Expand Down
133 changes: 132 additions & 1 deletion pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package zero

import (
"fmt"
"math/big"

"github.com/holiman/uint256"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"math/big"
)

func newIsLeFeltHint(a, b hinter.ResOperander) hinter.Hinter {
Expand Down Expand Up @@ -503,6 +505,135 @@ func createSplitIntHinter(resolver hintReferenceResolver) (hinter.Hinter, error)
return newSplitIntHint(output, value, base, bound), nil
}

func newSplitFeltHint(low, high, value hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "SplitFelt",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> from starkware.cairo.common.math_utils import assert_integer
// assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128
// assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW
// assert_integer(ids.value)
// ids.low = ids.value & ((1 << 128) - 1)
// ids.high = ids.value >> 128

//> assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128
maxHigh := new(fp.Element).Div(new(fp.Element).SetInt64(-1), &utils.FeltMax128)
maxLow := &utils.FeltZero

//> assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW
leftHandSide := new(fp.Element).SetInt64(-1)
rightHandSide := new(fp.Element).Add(new(fp.Element).Mul(maxHigh, &utils.FeltMax128), maxLow)
if leftHandSide.Cmp(rightHandSide) != 0 {
return fmt.Errorf("assertion `split_felt(): The sum of MAX_HIGH and MAX_LOW does not equal to PRIME - 1` failed")
}

//> assert_integer(ids.value)
value, err := hinter.ResolveAsFelt(vm, value)
if err != nil {
return err
}

var valueBigInt big.Int
value.BigInt(&valueBigInt)
lowAddr, err := low.GetAddress(vm)
if err != nil {
return err
}

highAddr, err := high.GetAddress(vm)
if err != nil {
return err
}

//> ids.low = ids.value & ((1 << 128) - 1)
felt128 := new(big.Int).Lsh(big.NewInt(1), 128)
felt128 = new(big.Int).Sub(felt128, big.NewInt(1))
lowBigInt := new(big.Int).And(&valueBigInt, felt128)
lowValue := memory.MemoryValueFromFieldElement(new(fp.Element).SetBigInt(lowBigInt))

err = vm.Memory.WriteToAddress(&lowAddr, &lowValue)
if err != nil {
return err
}
//> ids.high = ids.value >> 128
highBigInt := new(big.Int).Rsh(&valueBigInt, 128)
highValue := memory.MemoryValueFromFieldElement(new(fp.Element).SetBigInt(highBigInt))

return vm.Memory.WriteToAddress(&highAddr, &highValue)

},
}
}

func createSplitFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
low, err := resolver.GetResOperander("low")
if err != nil {
return nil, err
}

high, err := resolver.GetResOperander("high")
if err != nil {
return nil, err
}

value, err := resolver.GetResOperander("value")
if err != nil {
return nil, err
}

return newSplitFeltHint(low, high, value), nil
}

func newSqrtHint(root, value hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "Sqrt",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> from starkware.python.math_utils import isqrt
// value = ids.value % PRIME
// assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)."
// assert 2 ** 250 < PRIME
// ids.root = isqrt(value)

rootAddr, err := root.GetAddress(vm)
if err != nil {
return err
}

value, err := hinter.ResolveAsFelt(vm, value)
if err != nil {
return err
}

if !utils.FeltLt(value, &utils.FeltUpperBound) {
return fmt.Errorf("assertion failed: %v is outside of the range [0, 2**250)", value)
}

// Conversion needed to handle non-square values
valueU256 := uint256.Int(value.Bits())
valueU256.Sqrt(&valueU256)

result := fp.Element{}
result.SetBytes(valueU256.Bytes())

v := memory.MemoryValueFromFieldElement(&result)
return vm.Memory.WriteToAddress(&rootAddr, &v)
},
}
}

func createSqrtHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {

root, err := resolver.GetResOperander("root")
if err != nil {
return nil, err
}
value, err := resolver.GetResOperander("value")
if err != nil {
return nil, err
}
return newSqrtHint(root, value), nil
}

func newUnsignedDivRemHinter(value, div, q, r hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "UnsignedDivRem",
Expand Down
75 changes: 75 additions & 0 deletions pkg/hintrunner/zero/zerohint_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,81 @@ func TestZeroHintMath(t *testing.T) {
errCheck: errorTextContains("outside of the range [0, 2**250)"),
},
},

"SplitFelt": {
{
operanders: []*hintOperander{
{Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
{Name: "value", Kind: apRelative, Value: feltString("100000000000000000000000000000000000000")},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSplitFeltHint(ctx.operanders["low"], ctx.operanders["high"], ctx.operanders["value"])
},
check: allVarValueEquals(map[string]*fp.Element{
"low": feltString("100000000000000000000000000000000000000"),
"high": feltInt64(0),
}),
},
{
operanders: []*hintOperander{
{Name: "low", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "high", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
{Name: "value", Kind: apRelative, Value: &utils.FeltMax128},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSplitFeltHint(ctx.operanders["low"], ctx.operanders["high"], ctx.operanders["value"])
},
check: allVarValueEquals(map[string]*fp.Element{
"low": feltInt64(0),
"high": feltInt64(1),
}),
},
},

"SqrtHint": {
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(25)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(5)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(0)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(0)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(50)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(7)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(-128)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
errCheck: errorTextContains("outside of the range [0, 2**250)"),
},
},

"UnsignedDivRem": {
{
operanders: []*hintOperander{
Expand Down
Loading

0 comments on commit 9a3a736

Please sign in to comment.