Skip to content

Commit

Permalink
Refactor main CLI, offset the hints indexes by entry code size, load …
Browse files Browse the repository at this point in the history
…arguments and initial gas to the memory
  • Loading branch information
MaksymMalicki committed Dec 10, 2024
1 parent 83872d8 commit cd04099
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 23 deletions.
21 changes: 3 additions & 18 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"path/filepath"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
hintrunner "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet"
Expand Down Expand Up @@ -203,29 +202,15 @@ func main() {
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}
mainFunc, ok := cairoProgram.EntryPointsByFunction["main"]
if !ok {
return fmt.Errorf("cannot find main function")
}
hints, err := core.GetCairoHints(cairoProgram)
if err != nil {
return fmt.Errorf("cannot get hints: %w", err)
}
program, err := runner.LoadCairoProgram(cairoProgram)
program, hints, err := runner.AssembleProgram(cairoProgram)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
return fmt.Errorf("cannot assemble program: %w", err)
}
entryCodeInstructions, err := runner.GetEntryCodeInstructions(mainFunc, false, 0)
if err != nil {
return fmt.Errorf("cannot load entry code instructions: %w", err)
}
program.Bytecode = append(entryCodeInstructions, program.Bytecode...)
program.Bytecode = append(program.Bytecode, runner.GetFooterInstructions()...)
runnerMode := runner.ExecutionMode
if proofmode {
runnerMode = runner.ProofModeCairo1
}
return runVM(*program, proofmode, maxsteps, entrypointOffset, collectTrace, traceLocation, buildMemory, memoryLocation, layoutName, airPublicInputLocation, airPrivateInputLocation, hints, runnerMode)
return runVM(program, proofmode, maxsteps, entrypointOffset, collectTrace, traceLocation, buildMemory, memoryLocation, layoutName, airPublicInputLocation, airPrivateInputLocation, hints, runnerMode)
},
},
},
Expand Down
56 changes: 51 additions & 5 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/NethermindEth/cairo-vm-go/pkg/assembler"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
Expand Down Expand Up @@ -57,6 +58,34 @@ func NewRunner(program *Program, hints map[uint64][]hinter.Hinter, runnerMode Ru
}, nil
}

func AssembleProgram(cairoProgram *starknet.StarknetProgram) (Program, map[uint64][]hinter.Hinter, error) {
mainFunc, ok := cairoProgram.EntryPointsByFunction["main"]
if !ok {
return Program{}, nil, fmt.Errorf("cannot find main function")
}
program, err := LoadCairoProgram(cairoProgram)
if err != nil {
return Program{}, nil, fmt.Errorf("cannot load program: %w", err)
}
entryCodeInstructions, err := GetEntryCodeInstructions(mainFunc, false, 0)
if err != nil {
return Program{}, nil, fmt.Errorf("cannot load entry code instructions: %w", err)
}
program.Bytecode = append(entryCodeInstructions, program.Bytecode...)
program.Bytecode = append(program.Bytecode, GetFooterInstructions()...)

hints, err := core.GetCairoHints(cairoProgram)
if err != nil {
return Program{}, nil, fmt.Errorf("cannot get hints: %w", err)
}
offset := uint64(len(entryCodeInstructions))
shiftedHintsMap := make(map[uint64][]hinter.Hinter)
for key, value := range hints {
shiftedHintsMap[key+offset] = value
}
return *program, shiftedHintsMap, nil
}

// RunEntryPoint is like Run, but it executes the program starting from the given PC offset.
// This PC offset is expected to be a start from some function inside the loaded program.
func (runner *Runner) RunEntryPoint(pc uint64) error {
Expand Down Expand Up @@ -86,7 +115,10 @@ func (runner *Runner) RunEntryPoint(pc uint64) error {
if err != nil {
return err
}

err = runner.loadArguments(uint64(0), uint64(8979879877))
if err != nil {
return err
}
if err := runner.RunUntilPc(&end); err != nil {
return err
}
Expand All @@ -104,12 +136,16 @@ func (runner *Runner) Run() error {
return fmt.Errorf("initializing main entry point: %w", err)
}

err = runner.loadArguments(uint64(0), uint64(8979879877))
if err != nil {
return err
}
err = runner.RunUntilPc(&end)
if err != nil {
return err
}

if runner.runnerMode == ProofModeCairo0 || runner.runnerMode == ProofModeCairo1 {
if runner.isProofMode() {
// +1 because proof mode require an extra instruction run
// pow2 because proof mode also requires that the trace is a power of two
pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1)
Expand Down Expand Up @@ -222,14 +258,18 @@ func (runner *Runner) initializeBuiltins(memory *mem.Memory) ([]mem.MemoryValue,
stack := []mem.MemoryValue{}
// adding to the stack only the builtins that are both in the program and in the layout
for _, bRunner := range runner.layout.Builtins {
builtinSegment := memory.AllocateBuiltinSegment(bRunner.Runner)
if utils.Contains(runner.program.Builtins, bRunner.Builtin) {
if utils.Contains(runner.program.Builtins, bRunner.Builtin) || runner.isProofMode() {
builtinSegment := memory.AllocateBuiltinSegment(bRunner.Runner)
stack = append(stack, mem.MemoryValueFromMemoryAddress(&builtinSegment))
}
}
return stack, nil
}

func (runner *Runner) isProofMode() bool {
return runner.runnerMode == ProofModeCairo0 || runner.runnerMode == ProofModeCairo1
}

func (runner *Runner) initializeVm(
initialPC *mem.MemoryAddress, stack []mem.MemoryValue, memory *mem.Memory, cairo1FpOffset uint64,
) error {
Expand All @@ -250,12 +290,18 @@ func (runner *Runner) initializeVm(
Ap: initialFp,
Fp: initialFp,
}, memory, vm.VirtualMachineConfig{
ProofMode: runner.runnerMode == ProofModeCairo0 || runner.runnerMode == ProofModeCairo1,
ProofMode: runner.isProofMode(),
CollectTrace: runner.collectTrace,
})
return err
}

func (runner *Runner) loadArguments(args, initialGas uint64) error {
mv := mem.MemoryValueFromUint(initialGas)
runner.vm.Memory.Segments[vm.ExecutionSegment].Write(runner.vm.Context.Ap+1, &mv)
return nil
}

// run until the program counter equals the `pc` parameter
func (runner *Runner) RunUntilPc(pc *mem.MemoryAddress) error {
for !runner.vm.Context.Pc.Equal(pc) {
Expand Down

0 comments on commit cd04099

Please sign in to comment.