From 285c954ef5a73cbb5aaddbc1012d4dc57d8d405f Mon Sep 17 00:00:00 2001 From: "M. Mahdi Khosravi" Date: Wed, 1 Nov 2023 16:52:01 +0300 Subject: [PATCH 1/3] Sqrt fix (#143) * added UpdatePc tests and fixed a bug in UpdatePc * Fixed accessing field value without Read() * fixed failing tests * small refactor for TestUpdatePcJump * fixed sqrtRoot function for non-square values * merge tests --------- Co-authored-by: Rodrigo --- pkg/hintrunner/hint.go | 9 +++++++-- pkg/hintrunner/hint_test.go | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pkg/hintrunner/hint.go b/pkg/hintrunner/hint.go index 0d0164871..b2e750501 100644 --- a/pkg/hintrunner/hint.go +++ b/pkg/hintrunner/hint.go @@ -362,14 +362,19 @@ func (hint SquareRoot) Execute(vm *VM.VirtualMachine) error { return err } - sqrt := valueFelt.Sqrt(valueFelt) + // Need to do this conversion to handle non-square values + valueU256 := uint256.Int(valueFelt.Bits()) + valueU256.Sqrt(&valueU256) + + sqrt := f.Element{} + sqrt.SetBytes(valueU256.Bytes()) dstAddr, err := hint.dst.Get(vm) if err != nil { return fmt.Errorf("get destination cell: %v", err) } - dstVal := memory.MemoryValueFromFieldElement(sqrt) + dstVal := memory.MemoryValueFromFieldElement(&sqrt) err = vm.Memory.WriteToAddress(&dstAddr, &dstVal) if err != nil { return fmt.Errorf("write cell: %v", err) diff --git a/pkg/hintrunner/hint_test.go b/pkg/hintrunner/hint_test.go index 9d8ad78a2..394955e23 100644 --- a/pkg/hintrunner/hint_test.go +++ b/pkg/hintrunner/hint_test.go @@ -338,4 +338,20 @@ func TestSquareRoot(t *testing.T) { memory.MemoryValueFromInt(6), readFrom(vm, VM.ExecutionSegment, 1), ) + + dst = 2 + value = Immediate(*big.NewInt(30)) + hint = SquareRoot{ + value: value, + dst: dst, + } + + err = hint.Execute(vm) + + require.NoError(t, err) + require.Equal( + t, + memory.MemoryValueFromInt(5), + readFrom(vm, VM.ExecutionSegment, 2), + ) } From 77f493ddb1320137bcf0717721ba2fff0be8d2b7 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 1 Nov 2023 16:42:49 -0400 Subject: [PATCH 2/3] Feat: Add Dictionaries (#146) * Minor optimization infer res * Add assert equal benchmarks * Add AllocDict hint * very minor fix * alloc segment returns address instead of just segment index * minor improvements to hintrunner * Add Felt252DictEntryInit * [wip] add dict entry update hint * Minor changes in operand.go * update error fmt * Add entry update hint * Minor improvements to operand * Minor fixes in hintrunner * Set fixed seed for benchmark tests * Add segment arena index hint * Add squash init hint * Add get current access index hint * Add remaining hints * ShouldSkipSquashLoop * GetCurrentAccessDelta * ShouldContinueSquashLoop * GetNextDictKey * Add helper functions for memory * Add some comments * Some overall refactoring * minor change * Add some more comments * Bug fix * Add context initialization * Add error support for squashed dicts op * minor code improvs * Address review * Fix tests * random sqrt values in benchmarks * Address review --- pkg/hintrunner/constants.go | 7 - pkg/hintrunner/hint.go | 440 ++++++++++++++++++++++++--- pkg/hintrunner/hint_bechmark_test.go | 55 +++- pkg/hintrunner/hint_test.go | 71 +++-- pkg/hintrunner/hintrunner.go | 198 +++++++++++- pkg/hintrunner/hintrunner_test.go | 4 +- pkg/hintrunner/operand.go | 65 ++-- pkg/hintrunner/operand_test.go | 4 +- pkg/hintrunner/utils.go | 51 ++++ pkg/runners/zero/zero.go | 19 +- pkg/safemath/arrays.go | 7 + pkg/safemath/constant.go | 17 ++ pkg/vm/builtins/range_check.go | 7 +- pkg/vm/memory/memory.go | 59 +++- pkg/vm/memory/memory_value.go | 18 ++ 15 files changed, 868 insertions(+), 154 deletions(-) delete mode 100644 pkg/hintrunner/constants.go create mode 100644 pkg/hintrunner/utils.go create mode 100644 pkg/safemath/arrays.go create mode 100644 pkg/safemath/constant.go diff --git a/pkg/hintrunner/constants.go b/pkg/hintrunner/constants.go deleted file mode 100644 index 461ba3453..000000000 --- a/pkg/hintrunner/constants.go +++ /dev/null @@ -1,7 +0,0 @@ -package hintrunner - -import "github.com/holiman/uint256" - -func MaxU128() uint256.Int { - return uint256.Int{18446744073709551615, 18446744073709551615, 0, 0} -} diff --git a/pkg/hintrunner/hint.go b/pkg/hintrunner/hint.go index b2e750501..9fed309e1 100644 --- a/pkg/hintrunner/hint.go +++ b/pkg/hintrunner/hint.go @@ -2,31 +2,33 @@ package hintrunner import ( "fmt" + "sort" "github.com/holiman/uint256" + "github.com/NethermindEth/cairo-vm-go/pkg/safemath" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" - "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) type Hinter interface { fmt.Stringer - Execute(vm *VM.VirtualMachine) error + Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error } type AllocSegment struct { dst CellRefer } -func (hint AllocSegment) String() string { +func (hint *AllocSegment) String() string { return "AllocSegment" } -func (hint AllocSegment) Execute(vm *VM.VirtualMachine) error { - segmentIndex := vm.Memory.AllocateEmptySegment() - memAddress := memory.MemoryValueFromSegmentAndOffset(segmentIndex, 0) +func (hint *AllocSegment) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { + newSegment := vm.Memory.AllocateEmptySegment() + memAddress := mem.MemoryValueFromMemoryAddress(&newSegment) regAddr, err := hint.dst.Get(vm) if err != nil { @@ -35,7 +37,7 @@ func (hint AllocSegment) Execute(vm *VM.VirtualMachine) error { err = vm.Memory.WriteToAddress(®Addr, &memAddress) if err != nil { - return fmt.Errorf("write to address %s: %v", regAddr, err) + return fmt.Errorf("write to address %s: %w", regAddr, err) } return nil @@ -47,11 +49,11 @@ type TestLessThan struct { rhs ResOperander } -func (hint TestLessThan) String() string { +func (hint *TestLessThan) String() string { return "TestLessThan" } -func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error { +func (hint *TestLessThan) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { lhsVal, err := hint.lhs.Resolve(vm) if err != nil { return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err) @@ -82,7 +84,7 @@ func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error { return fmt.Errorf("get dst address %s: %w", dstAddr, err) } - mv := memory.MemoryValueFromFieldElement(&resFelt) + mv := mem.MemoryValueFromFieldElement(&resFelt) err = vm.Memory.WriteToAddress(&dstAddr, &mv) if err != nil { return fmt.Errorf("write to dst address %s: %w", dstAddr, err) @@ -97,11 +99,11 @@ type TestLessThanOrEqual struct { rhs ResOperander } -func (hint TestLessThanOrEqual) String() string { +func (hint *TestLessThanOrEqual) String() string { return "TestLessThanOrEqual" } -func (hint TestLessThanOrEqual) Execute(vm *VM.VirtualMachine) error { +func (hint *TestLessThanOrEqual) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { lhsVal, err := hint.lhs.Resolve(vm) if err != nil { return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err) @@ -132,7 +134,7 @@ func (hint TestLessThanOrEqual) Execute(vm *VM.VirtualMachine) error { return fmt.Errorf("get dst address %s: %w", dstAddr, err) } - mv := memory.MemoryValueFromFieldElement(&resFelt) + mv := mem.MemoryValueFromFieldElement(&resFelt) err = vm.Memory.WriteToAddress(&dstAddr, &mv) if err != nil { return fmt.Errorf("write to dst address %s: %w", dstAddr, err) @@ -206,13 +208,13 @@ func (hint LinearSplit) Execute(vm *VM.VirtualMachine) error { yFiled := &f.Element{} xFiled.SetBytes(x.Bytes()) yFiled.SetBytes(y.Bytes()) - mv := memory.MemoryValueFromFieldElement(xFiled) + mv := mem.MemoryValueFromFieldElement(xFiled) err = vm.Memory.WriteToAddress(&xAddr, &mv) if err != nil { return fmt.Errorf("write to x address %s: %w", xAddr, err) } - mv = memory.MemoryValueFromFieldElement(yFiled) + mv = mem.MemoryValueFromFieldElement(yFiled) err = vm.Memory.WriteToAddress(&yAddr, &mv) if err != nil { return fmt.Errorf("write to y address %s: %w", yAddr, err) @@ -228,20 +230,20 @@ type WideMul128 struct { low CellRefer } -func (hint WideMul128) String() string { +func (hint *WideMul128) String() string { return "WideMul128" } -func (hint WideMul128) Execute(vm *VM.VirtualMachine) error { - mask := MaxU128() +func (hint *WideMul128) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { + mask := &safemath.Uint256Max128 lhs, err := hint.lhs.Resolve(vm) if err != nil { - return fmt.Errorf("resolve lhs operand %s: %v", hint.lhs, err) + return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err) } rhs, err := hint.rhs.Resolve(vm) if err != nil { - return fmt.Errorf("resolve rhs operand %s: %v", hint.rhs, err) + return fmt.Errorf("resolve rhs operand %s: %w", hint.rhs, err) } lhsFelt, err := lhs.FieldElement() @@ -256,10 +258,10 @@ func (hint WideMul128) Execute(vm *VM.VirtualMachine) error { lhsU256 := uint256.Int(lhsFelt.Bits()) rhsU256 := uint256.Int(rhsFelt.Bits()) - if lhsU256.Gt(&mask) { + if lhsU256.Gt(mask) { return fmt.Errorf("lhs operand %s should be u128", lhsFelt) } - if rhsU256.Gt(&mask) { + if rhsU256.Gt(mask) { return fmt.Errorf("rhs operand %s should be u128", rhsFelt) } @@ -275,9 +277,9 @@ func (hint WideMul128) Execute(vm *VM.VirtualMachine) error { lowAddr, err := hint.low.Get(vm) if err != nil { - return fmt.Errorf("get destination cell: %v", err) + return fmt.Errorf("get destination cell: %w", err) } - mvLow := memory.MemoryValueFromFieldElement(&low) + mvLow := mem.MemoryValueFromFieldElement(&low) err = vm.Memory.WriteToAddress(&lowAddr, &mvLow) if err != nil { return fmt.Errorf("write cell: %v", err) @@ -285,12 +287,12 @@ func (hint WideMul128) Execute(vm *VM.VirtualMachine) error { highAddr, err := hint.high.Get(vm) if err != nil { - return fmt.Errorf("get destination cell: %v", err) + return fmt.Errorf("get destination cell: %w", err) } - mvHigh := memory.MemoryValueFromFieldElement(&high) + mvHigh := mem.MemoryValueFromFieldElement(&high) err = vm.Memory.WriteToAddress(&highAddr, &mvHigh) if err != nil { - return fmt.Errorf("write cell: %v", err) + return fmt.Errorf("write cell: %w", err) } return nil } @@ -326,7 +328,7 @@ func (hint DebugPrint) Execute(vm *VM.VirtualMachine) error { current := startAddr.Offset for current < endAddr.Offset { - v, err := vm.Memory.ReadFromAddress(&memory.MemoryAddress{ + v, err := vm.Memory.ReadFromAddress(&mem.MemoryAddress{ SegmentIndex: startAddr.SegmentIndex, Offset: current, }) @@ -347,14 +349,14 @@ type SquareRoot struct { dst CellRefer } -func (hint SquareRoot) String() string { +func (hint *SquareRoot) String() string { return "SquareRoot" } -func (hint SquareRoot) Execute(vm *VM.VirtualMachine) error { +func (hint *SquareRoot) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { value, err := hint.value.Resolve(vm) if err != nil { - return fmt.Errorf("resolve value operand %s: %v", hint.value, err) + return fmt.Errorf("resolve value operand %s: %w", hint.value, err) } valueFelt, err := value.FieldElement() @@ -371,13 +373,383 @@ func (hint SquareRoot) Execute(vm *VM.VirtualMachine) error { dstAddr, err := hint.dst.Get(vm) if err != nil { - return fmt.Errorf("get destination cell: %v", err) + return fmt.Errorf("get destination cell: %w", err) } - dstVal := memory.MemoryValueFromFieldElement(&sqrt) + dstVal := mem.MemoryValueFromFieldElement(&sqrt) err = vm.Memory.WriteToAddress(&dstAddr, &dstVal) if err != nil { - return fmt.Errorf("write cell: %v", err) + return fmt.Errorf("write cell: %w", err) } return nil } + +// +// Dictionary Hints +// + +type AllocFelt252Dict struct { + SegmentArenaPtr ResOperander +} + +func (hint *AllocFelt252Dict) String() string { + return "AllocFelt252Dict" +} +func (hint *AllocFelt252Dict) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + InitializeDictionaryManagerIfNot(ctx) + + arenaPtr, err := ResolveAsAddress(vm, hint.SegmentArenaPtr) + if err != nil { + return fmt.Errorf("resolve segment arena pointer: %w", err) + } + + // find for the amount of initialized dicts + initializedDictsOffset, overflow := safemath.SafeOffset(arenaPtr.Offset, -2) + if overflow { + return fmt.Errorf("look for initialized dicts: overflow: %s - 2", arenaPtr) + } + initializedDictsFelt, err := vm.Memory.Read(arenaPtr.SegmentIndex, initializedDictsOffset) + if err != nil { + return fmt.Errorf("read initialized dicts: %w", err) + } + initializedDicts, err := initializedDictsFelt.Uint64() + if err != nil { + return fmt.Errorf("read initialized dicts: %w", err) + } + + // find for the segment info pointer + segmentInfoOffset, overflow := safemath.SafeOffset(arenaPtr.Offset, -3) + if overflow { + return fmt.Errorf("look for segment info pointer: overflow: %s - 3", arenaPtr) + } + segmentInfoMv, err := vm.Memory.Read(arenaPtr.SegmentIndex, segmentInfoOffset) + if err != nil { + return fmt.Errorf("read segment info pointer: %w", err) + } + segmentInfoPtr, err := segmentInfoMv.MemoryAddress() + if err != nil { + return fmt.Errorf("expected pointer to segment info but got a felt: %w", err) + } + + // with the segment info pointer and the number of initialized dictionaries we know + // where to write the new dictionary + newDictAddress := ctx.DictionaryManager.NewDictionary(vm) + mv := mem.MemoryValueFromMemoryAddress(&newDictAddress) + insertOffset := segmentInfoPtr.Offset + initializedDicts*3 + if err = vm.Memory.Write(segmentInfoPtr.SegmentIndex, insertOffset, &mv); err != nil { + return fmt.Errorf("write new dictionary to segment info: %w", err) + } + return nil +} + +type Felt252DictEntryInit struct { + DictPtr ResOperander + Key ResOperander +} + +func (hint Felt252DictEntryInit) String() string { + return "Felt252DictEntryInit" +} + +func (hint *Felt252DictEntryInit) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + dictPtr, err := ResolveAsAddress(vm, hint.DictPtr) + if err != nil { + return fmt.Errorf("resolve dictionary pointer: %w", err) + } + + key, err := ResolveAsFelt(vm, hint.Key) + if err != nil { + return fmt.Errorf("resolve key: %w", err) + } + + prevValue, err := ctx.DictionaryManager.At(&dictPtr, &key) + if err != nil { + return fmt.Errorf("get dictionary entry: %w", err) + } + if prevValue == nil { + mv := mem.EmptyMemoryValueAsFelt() + prevValue = &mv + _ = ctx.DictionaryManager.Set(&dictPtr, &key, prevValue) + } + return vm.Memory.Write(dictPtr.SegmentIndex, dictPtr.Offset+1, prevValue) +} + +type Felt252DictEntryUpdate struct { + DictPtr ResOperander + Value ResOperander +} + +func (hint Felt252DictEntryUpdate) String() string { + return "Felt252DictEntryUpdate" +} + +func (hint *Felt252DictEntryUpdate) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + dictPtr, err := ResolveAsAddress(vm, hint.DictPtr) + if err != nil { + return fmt.Errorf("resolve dictionary pointer: %w", err) + } + + keyPtr, err := dictPtr.AddOffset(-3) + if err != nil { + return fmt.Errorf("get key pointer: %w", err) + } + keyMv, err := vm.Memory.ReadFromAddress(&keyPtr) + if err != nil { + return fmt.Errorf("read key pointer: %w", err) + } + key, err := keyMv.FieldElement() + if err != nil { + return fmt.Errorf("expected key to be a field element: %w", err) + } + + value, err := hint.Value.Resolve(vm) + if err != nil { + return fmt.Errorf("resolve value: %w", err) + } + + return ctx.DictionaryManager.Set(&dictPtr, key, &value) +} + +type GetSegmentArenaIndex struct { + DictIndex CellRefer + DictEndPtr ResOperander +} + +func (hint *GetSegmentArenaIndex) String() string { + return "GetSegmentArenaIndex" +} + +func (hint *GetSegmentArenaIndex) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + dictIndex, err := hint.DictIndex.Get(vm) + if err != nil { + return fmt.Errorf("get dict index: %w", err) + } + + dictEndPtr, err := ResolveAsAddress(vm, hint.DictEndPtr) + if err != nil { + return fmt.Errorf("resolve dict end pointer: %w", err) + } + + dict, err := ctx.DictionaryManager.GetDictionary(&dictEndPtr) + if err != nil { + return fmt.Errorf("get dictionary: %w", err) + } + + initNum := mem.MemoryValueFromUint(dict.InitNumber()) + return vm.Memory.WriteToAddress(&dictIndex, &initNum) +} + +// +// Squashed Dictionary Hints +// + +type InitSquashData struct { + FirstKey CellRefer + BigKeys CellRefer + DictAccesses ResOperander + NumAccesses ResOperander +} + +func (hint *InitSquashData) String() string { + return "InitSquashData" +} + +func (hint *InitSquashData) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + // todo(rodro): Don't know if it could be called multiple times, or + err := InitializeSquashedDictionaryManager(ctx) + if err != nil { + return err + } + + dictAccessPtr, err := ResolveAsAddress(vm, hint.DictAccesses) + if err != nil { + return fmt.Errorf("resolve dict access: %w", err) + } + + numAccess, err := ResolveAsUint64(vm, hint.NumAccesses) + if err != nil { + return fmt.Errorf("resolve num access: %w", err) + } + + const dictAccessSize = 3 + for i := uint64(0); i < numAccess; i++ { + keyPtr := mem.MemoryAddress{ + SegmentIndex: dictAccessPtr.SegmentIndex, + Offset: dictAccessPtr.Offset + i*dictAccessSize, + } + key, err := vm.Memory.ReadFromAddressAsElement(&keyPtr) + if err != nil { + return fmt.Errorf("reading key at %s: %w", keyPtr, err) + } + + ctx.SquashedDictionaryManager.Insert(&key, i) + } + for key, val := range ctx.SquashedDictionaryManager.KeyToIndices { + // reverse each indice access list per key + safemath.Reverse(val) + // store each key + ctx.SquashedDictionaryManager.Keys = append(ctx.SquashedDictionaryManager.Keys, key) + } + + // sort the keys in descending order + sort.Slice(ctx.SquashedDictionaryManager.Keys, func(i, j int) bool { + return ctx.SquashedDictionaryManager.Keys[i].Cmp(&ctx.SquashedDictionaryManager.Keys[j]) < 0 + }) + + // if the first key is bigger than 2^128, signal it + bigKeysAddr, err := hint.BigKeys.Get(vm) + if err != nil { + return fmt.Errorf("get big keys address: %w", err) + } + biggestKey := ctx.SquashedDictionaryManager.Keys[0] + cmpRes := mem.MemoryValueFromUint[uint64](0) + if biggestKey.Cmp(&safemath.FeltMax128) > 0 { + cmpRes = mem.MemoryValueFromUint[uint64](1) + } + err = vm.Memory.WriteToAddress(&bigKeysAddr, &cmpRes) + if err != nil { + return fmt.Errorf("write big keys address: %w", err) + } + + // store the left most, smaller key + firstKeyAddr, err := hint.FirstKey.Get(vm) + if err != nil { + return fmt.Errorf("get first key address: %w", err) + } + firstKey, err := ctx.SquashedDictionaryManager.LastKey() + if err != nil { + return fmt.Errorf("get first key: %w", err) + } + + mv := mem.MemoryValueFromFieldElement(&firstKey) + return vm.Memory.WriteToAddress(&firstKeyAddr, &mv) +} + +type GetCurrentAccessIndex struct { + RangeCheckPtr ResOperander +} + +func (hint *GetCurrentAccessIndex) String() string { + return "GetCurrentAccessIndex" +} + +func (hint *GetCurrentAccessIndex) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + rangeCheckPtr, err := ResolveAsAddress(vm, hint.RangeCheckPtr) + if err != nil { + return fmt.Errorf("resolve range check pointer: %w", err) + } + + lastIndex64, err := ctx.SquashedDictionaryManager.LastIndex() + if err != nil { + return fmt.Errorf("get last index: %w", err) + } + + lastIndex := f.NewElement(lastIndex64) + mv := mem.MemoryValueFromFieldElement(&lastIndex) + + return vm.Memory.WriteToAddress(&rangeCheckPtr, &mv) +} + +type ShouldSkipSquashLoop struct { + ShouldSkipLoop CellRefer +} + +func (hint *ShouldSkipSquashLoop) String() string { + return "ShouldSkipSquashLoop" +} + +func (hint *ShouldSkipSquashLoop) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + shouldSkipLoopAddr, err := hint.ShouldSkipLoop.Get(vm) + if err != nil { + return fmt.Errorf("get should skip loop address: %w", err) + } + + var shouldSkipLoop f.Element + if lastIndices, err := ctx.SquashedDictionaryManager.LastIndices(); err == nil && len(lastIndices) > 1 { + shouldSkipLoop.SetOne() + } else if err != nil { + return fmt.Errorf("get last indices: %w", err) + } + + mv := mem.MemoryValueFromFieldElement(&shouldSkipLoop) + return vm.Memory.WriteToAddress(&shouldSkipLoopAddr, &mv) +} + +type GetCurrentAccessDelta struct { + IndexDeltaMinusOne CellRefer +} + +func (hint *GetCurrentAccessDelta) String() string { + return "GetCurrentAccessDelta" +} + +func (hint *GetCurrentAccessDelta) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + indexDeltaPtr, err := hint.IndexDeltaMinusOne.Get(vm) + if err != nil { + return fmt.Errorf("get index delta address: %w", err) + } + + previousKeyIndex, err := ctx.SquashedDictionaryManager.PopIndex() + if err != nil { + return fmt.Errorf("pop index: %w", err) + } + + currentKeyIndex, err := ctx.SquashedDictionaryManager.LastIndex() + if err != nil { + return fmt.Errorf("get last index: %w", err) + } + + // todo(rodro): could previousKeyIndex be bigger than currentKeyIndex? + indexDeltaMinusOne := currentKeyIndex - previousKeyIndex - 1 + mv := mem.MemoryValueFromUint(indexDeltaMinusOne) + + return vm.Memory.WriteToAddress(&indexDeltaPtr, &mv) +} + +type ShouldContinueSquashLoop struct { + ShouldContinue CellRefer +} + +func (hint *ShouldContinueSquashLoop) String() string { + return "ShouldContinueSquashLoop" +} + +func (hint *ShouldContinueSquashLoop) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + shouldContinuePtr, err := hint.ShouldContinue.Get(vm) + if err != nil { + return fmt.Errorf("get should continue address: %w", err) + } + + var shouldContinueLoop f.Element + if lastIndices, err := ctx.SquashedDictionaryManager.LastIndices(); err == nil && len(lastIndices) <= 1 { + shouldContinueLoop.SetOne() + } else if err != nil { + return fmt.Errorf("get last indices: %w", err) + } + + mv := mem.MemoryValueFromFieldElement(&shouldContinueLoop) + return vm.Memory.WriteToAddress(&shouldContinuePtr, &mv) +} + +type GetNextDictKey struct { + NextKey CellRefer +} + +func (hint *GetNextDictKey) String() string { + return "GetNextDictKey" +} + +func (hint *GetNextDictKey) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContext) error { + nextKeyAddr, err := hint.NextKey.Get(vm) + if err != nil { + return fmt.Errorf("get next key address: %w", err) + } + + nextKey, err := ctx.SquashedDictionaryManager.PopKey() + if err != nil { + return fmt.Errorf("pop key: %w", err) + } + + mv := mem.MemoryValueFromFieldElement(&nextKey) + return vm.Memory.WriteToAddress(&nextKeyAddr, &mv) +} diff --git a/pkg/hintrunner/hint_bechmark_test.go b/pkg/hintrunner/hint_bechmark_test.go index 08034690c..5a87a8bfc 100644 --- a/pkg/hintrunner/hint_bechmark_test.go +++ b/pkg/hintrunner/hint_bechmark_test.go @@ -1,12 +1,12 @@ package hintrunner import ( - "math/big" "math/rand" "testing" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) func BenchmarkAllocSegment(b *testing.B) { @@ -14,10 +14,11 @@ func BenchmarkAllocSegment(b *testing.B) { vm.Context.Ap = 0 vm.Context.Fp = 0 var ap ApCellRef = 1 + b.ResetTimer() for i := 0; i < b.N; i++ { alloc := AllocSegment{ap} - err := alloc.Execute(vm) + err := alloc.Execute(vm, nil) if err != nil { b.Error(err) break @@ -35,12 +36,19 @@ func BenchmarkLessThan(b *testing.B) { var dst ApCellRef = 0 var rhsRef ApCellRef = 1 cell := uint64(0) - b.ResetTimer() + rand := defaultRandGenerator() + + b.ResetTimer() for i := 0; i < b.N; i++ { - writeTo(vm, VM.ExecutionSegment, vm.Context.Ap+uint64(rhsRef), memory.MemoryValueFromInt(rand.Int63())) + writeTo( + vm, + VM.ExecutionSegment, + vm.Context.Ap+uint64(rhsRef), + memory.MemoryValueFromInt(rand.Int63()), + ) rhs := Deref{rhsRef} - lhs := Immediate(*big.NewInt(rand.Int63())) + lhs := Immediate(randomFeltElement(rand)) hint := TestLessThan{ dst: dst, @@ -48,7 +56,7 @@ func BenchmarkLessThan(b *testing.B) { rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) if err != nil { b.Error(err) break @@ -66,16 +74,17 @@ func BenchmarkSquareRoot(b *testing.B) { var dst ApCellRef = 1 + rand := defaultRandGenerator() + b.ResetTimer() for i := 0; i < b.N; i++ { - //TODO: Change to rand.Uint64() - value := Immediate(*big.NewInt(int64(i * i))) + value := Immediate(randomFeltElement(rand)) hint := SquareRoot{ value: value, dst: dst, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) if err != nil { b.Error(err) break @@ -94,10 +103,11 @@ func BenchmarkWideMul128(b *testing.B) { var dstLow ApCellRef = 0 var dstHigh ApCellRef = 1 + rand := defaultRandGenerator() b.ResetTimer() for i := 0; i < b.N; i++ { - lhs := Immediate(*new(big.Int).SetUint64(rand.Uint64())) - rhs := Immediate(*new(big.Int).SetUint64(rand.Uint64())) + lhs := Immediate(randomFeltElement(rand)) + rhs := Immediate(randomFeltElement(rand)) hint := WideMul128{ low: dstLow, @@ -106,7 +116,7 @@ func BenchmarkWideMul128(b *testing.B) { rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) if err != nil { b.Error(err) break @@ -121,12 +131,14 @@ func BenchmarkLinearSplit(b *testing.B) { vm.Context.Ap = 0 vm.Context.Fp = 0 + rand := defaultRandGenerator() + var x ApCellRef = 0 var y ApCellRef = 1 for i := 0; i < b.N; i++ { - value := Immediate(*big.NewInt(rand.Int63())) - scalar := Immediate(*big.NewInt(rand.Int63())) - maxX := Immediate(*big.NewInt(rand.Int63())) + value := Immediate(randomFeltElement(rand)) + scalar := Immediate(randomFeltElement(rand)) + maxX := Immediate(randomFeltElement(rand)) hint := LinearSplit{ value: value, scalar: scalar, @@ -142,5 +154,18 @@ func BenchmarkLinearSplit(b *testing.B) { } vm.Context.Ap += 2 } +} + +func randomFeltElement(rand *rand.Rand) f.Element { + data := [4]uint64{ + rand.Uint64(), + rand.Uint64(), + rand.Uint64(), + rand.Uint64(), + } + return f.Element(data) +} +func defaultRandGenerator() *rand.Rand { + return rand.New(rand.NewSource(0)) } diff --git a/pkg/hintrunner/hint_test.go b/pkg/hintrunner/hint_test.go index 394955e23..76f7698ba 100644 --- a/pkg/hintrunner/hint_test.go +++ b/pkg/hintrunner/hint_test.go @@ -9,6 +9,7 @@ import ( VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/holiman/uint256" "github.com/stretchr/testify/require" ) @@ -23,7 +24,7 @@ func TestAllocSegment(t *testing.T) { alloc1 := AllocSegment{ap} alloc2 := AllocSegment{fp} - err := alloc1.Execute(vm) + err := alloc1.Execute(vm, nil) require.Nil(t, err) require.Equal(t, 3, len(vm.Memory.Segments)) require.Equal( @@ -32,7 +33,7 @@ func TestAllocSegment(t *testing.T) { readFrom(vm, VM.ExecutionSegment, vm.Context.Ap+5), ) - err = alloc2.Execute(vm) + err = alloc2.Execute(vm, nil) require.Nil(t, err) require.Equal(t, 4, len(vm.Memory.Segments)) require.Equal( @@ -53,7 +54,7 @@ func TestTestLessThanTrue(t *testing.T) { var rhsRef FpCellRef = 0 rhs := Deref{rhsRef} - lhs := Immediate(*big.NewInt(13)) + lhs := Immediate(f.NewElement(13)) hint := TestLessThan{ dst: dst, @@ -61,7 +62,7 @@ func TestTestLessThanTrue(t *testing.T) { rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) require.NoError(t, err) require.Equal( t, @@ -72,11 +73,11 @@ func TestTestLessThanTrue(t *testing.T) { } func TestTestLessThanFalse(t *testing.T) { testCases := []struct { - lhsValue *big.Int + lhsValue f.Element expectedMsg string }{ - {big.NewInt(32), "Expected the hint to evaluate to False when lhs is larger"}, - {big.NewInt(17), "Expected the hint to evaluate to False when values are equal"}, + {f.NewElement(32), "Expected the hint to evaluate to False when lhs is larger"}, + {f.NewElement(17), "Expected the hint to evaluate to False when values are equal"}, } for _, tc := range testCases { @@ -90,14 +91,14 @@ func TestTestLessThanFalse(t *testing.T) { var rhsRef FpCellRef = 0 rhs := Deref{rhsRef} - lhs := Immediate(*tc.lhsValue) + lhs := Immediate(tc.lhsValue) hint := TestLessThan{ dst: dst, lhs: lhs, rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) require.NoError(t, err) require.Equal( t, @@ -111,11 +112,11 @@ func TestTestLessThanFalse(t *testing.T) { func TestTestLessThanOrEqTrue(t *testing.T) { testCases := []struct { - lhsValue *big.Int + lhsValue f.Element expectedMsg string }{ - {big.NewInt(13), "Expected the hint to evaluate to True when lhs is less than rhs"}, - {big.NewInt(23), "Expected the hint to evaluate to True when values are equal"}, + {f.NewElement(13), "Expected the hint to evaluate to True when lhs is less than rhs"}, + {f.NewElement(23), "Expected the hint to evaluate to True when values are equal"}, } for _, tc := range testCases { @@ -129,14 +130,14 @@ func TestTestLessThanOrEqTrue(t *testing.T) { var rhsRef FpCellRef = 0 rhs := Deref{rhsRef} - lhs := Immediate(*tc.lhsValue) + lhs := Immediate(tc.lhsValue) hint := TestLessThanOrEqual{ dst: dst, lhs: lhs, rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) require.NoError(t, err) require.Equal( t, @@ -158,7 +159,7 @@ func TestTestLessThanOrEqFalse(t *testing.T) { var rhsRef FpCellRef = 0 rhs := Deref{rhsRef} - lhs := Immediate(*big.NewInt(32)) + lhs := Immediate(f.NewElement(32)) hint := TestLessThanOrEqual{ dst: dst, @@ -166,7 +167,7 @@ func TestTestLessThanOrEqFalse(t *testing.T) { rhs: rhs, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) require.NoError(t, err) require.Equal( t, @@ -181,9 +182,9 @@ func TestLinearSplit(t *testing.T) { vm.Context.Ap = 0 vm.Context.Fp = 0 - value := Immediate(*big.NewInt(42*223344 + 14)) - scalar := Immediate(*big.NewInt(42)) - maxX := Immediate(*big.NewInt(9999999999)) + value := Immediate(f.NewElement(42*223344 + 14)) + scalar := Immediate(f.NewElement(42)) + maxX := Immediate(f.NewElement(9999999999)) var x ApCellRef = 0 var y ApCellRef = 1 @@ -207,7 +208,7 @@ func TestLinearSplit(t *testing.T) { vm.Context.Fp = 0 //Lower max_x - maxX = Immediate(*big.NewInt(223343)) + maxX = Immediate(f.NewElement(223343)) hint = LinearSplit{ value: value, scalar: scalar, @@ -232,8 +233,14 @@ func TestWideMul128(t *testing.T) { var dstLow ApCellRef = 1 var dstHigh ApCellRef = 2 - lhs := Immediate(*big.NewInt(1).Lsh(big.NewInt(1), 127)) - rhs := Immediate(*big.NewInt(1<<8 + 1)) + lhsBytes := new(uint256.Int).Lsh(uint256.NewInt(1), 127).Bytes32() + lhsFelt, err := f.BigEndian.Element(&lhsBytes) + require.NoError(t, err) + + rhsFelt := f.NewElement(1<<8 + 1) + + lhs := Immediate(lhsFelt) + rhs := Immediate(rhsFelt) hint := WideMul128{ low: dstLow, @@ -242,7 +249,7 @@ func TestWideMul128(t *testing.T) { rhs: rhs, } - err := hint.Execute(vm) + err = hint.Execute(vm, nil) require.Nil(t, err) low := &f.Element{} @@ -268,8 +275,12 @@ func TestWideMul128IncorrectRange(t *testing.T) { var dstLow ApCellRef = 1 var dstHigh ApCellRef = 2 - lhs := Immediate(*big.NewInt(1).Lsh(big.NewInt(1), 128)) - rhs := Immediate(*big.NewInt(1)) + lhsBytes := new(uint256.Int).Lsh(uint256.NewInt(1), 128).Bytes32() + lhsFelt, err := f.BigEndian.Element(&lhsBytes) + require.NoError(t, err) + + lhs := Immediate(lhsFelt) + rhs := Immediate(f.NewElement(1)) hint := WideMul128{ low: dstLow, @@ -278,7 +289,7 @@ func TestWideMul128IncorrectRange(t *testing.T) { rhs: rhs, } - err := hint.Execute(vm) + err = hint.Execute(vm, nil) require.ErrorContains(t, err, "should be u128") } @@ -324,13 +335,13 @@ func TestSquareRoot(t *testing.T) { vm.Context.Fp = 0 var dst ApCellRef = 1 - value := Immediate(*big.NewInt(36)) + value := Immediate(f.NewElement(36)) hint := SquareRoot{ value: value, dst: dst, } - err := hint.Execute(vm) + err := hint.Execute(vm, nil) require.NoError(t, err) require.Equal( @@ -340,13 +351,13 @@ func TestSquareRoot(t *testing.T) { ) dst = 2 - value = Immediate(*big.NewInt(30)) + value = Immediate(f.NewElement(30)) hint = SquareRoot{ value: value, dst: dst, } - err = hint.Execute(vm) + err = hint.Execute(vm, nil) require.NoError(t, err) require.Equal( diff --git a/pkg/hintrunner/hintrunner.go b/pkg/hintrunner/hintrunner.go index 40a1dd6af..af3979e1e 100644 --- a/pkg/hintrunner/hintrunner.go +++ b/pkg/hintrunner/hintrunner.go @@ -4,25 +4,215 @@ import ( "fmt" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" + mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -// todo: Can two or more hints be assigned to a specific PC? +// Used to keep track of all dictionaries data +type Dictionary struct { + // The data contained on a dictionary + data map[f.Element]*mem.MemoryValue + // Unique id assigned at the moment of creation + idx uint64 +} + +// Gets the memory value at certain key +func (d *Dictionary) At(key *f.Element) (*mem.MemoryValue, error) { + if value, ok := d.data[*key]; ok { + return value, nil + } + return nil, fmt.Errorf("no value for key %s", key) +} + +// Given a key and a value, it sets the value at the given key +func (d *Dictionary) Set(key *f.Element, value *mem.MemoryValue) { + d.data[*key] = value +} + +// Returns the initialization number when the dictionary was created +func (d *Dictionary) InitNumber() uint64 { + return d.idx +} + +// Used to manage dictionaries creation +type DictionaryManager struct { + // a map that links a segment index to a dictionary + dictionaries map[uint64]Dictionary +} + +func InitializeDictionaryManagerIfNot(ctx *HintRunnerContext) { + if ctx.DictionaryManager.dictionaries == nil { + ctx.DictionaryManager.dictionaries = make(map[uint64]Dictionary) + } +} + +// It creates a new segment which will hold dictionary values. It links this +// segment with the current dictionary and returns the address that points +// to the start of this segment +func (dm *DictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.MemoryAddress { + newDictAddr := vm.Memory.AllocateEmptySegment() + dm.dictionaries[newDictAddr.SegmentIndex] = Dictionary{ + data: make(map[f.Element]*mem.MemoryValue), + idx: uint64(len(dm.dictionaries)), + } + return newDictAddr +} + +// Given a memory address, it looks for the right dictionary using the segment index. If no +// segment is associated with the given segment index, it errors +func (dm *DictionaryManager) GetDictionary(dictAddr *mem.MemoryAddress) (Dictionary, error) { + dict, ok := dm.dictionaries[dictAddr.SegmentIndex] + if ok { + return dict, nil + } + return Dictionary{}, fmt.Errorf("no dictionary at address %s", dictAddr) +} + +// Given a memory address and a key it returns the value held at that position. The address is used +// to locate the correct dictionary and the key to index on it +func (dm *DictionaryManager) At(dictAddr *mem.MemoryAddress, key *f.Element) (*mem.MemoryValue, error) { + if dict, ok := dm.dictionaries[dictAddr.SegmentIndex]; ok { + return dict.At(key) + } + return nil, fmt.Errorf("no dictionary at address %s", dictAddr) +} + +// Given a memory address,a key and a value it stores the value at the correct position. +func (dm *DictionaryManager) Set(dictAddr *mem.MemoryAddress, key *f.Element, value *mem.MemoryValue) error { + if dict, ok := dm.dictionaries[dictAddr.SegmentIndex]; ok { + dict.Set(key, value) + return nil + } + return fmt.Errorf("no dictionary at address %s", dictAddr) +} + +// Used to keep track of squashed dictionaries +type SquashedDictionaryManager struct { + // A map from each key to a list of indices where the key is present + // the list in reversed order. + // Note: The indices should be Felts, but current memory limitations + // make it impossible to use an index that big so we use uint64 instead + KeyToIndices map[f.Element][]uint64 + + // A descending list of keys + Keys []f.Element +} + +func InitializeSquashedDictionaryManager(ctx *HintRunnerContext) error { + if ctx.SquashedDictionaryManager.KeyToIndices != nil || + ctx.SquashedDictionaryManager.Keys != nil { + return fmt.Errorf("squashed dictionary manager already initialized") + } + ctx.SquashedDictionaryManager.KeyToIndices = make(map[f.Element][]uint64, 100) + ctx.SquashedDictionaryManager.Keys = make([]f.Element, 0, 100) + return nil +} + +// It adds another index to the list of indices associated to the given key +// If the key is not present, it creates a new entry +func (sdm *SquashedDictionaryManager) Insert(key *f.Element, index uint64) { + keyIndex := *key + if indices, ok := sdm.KeyToIndices[keyIndex]; ok { + sdm.KeyToIndices[keyIndex] = append(indices, index) + } else { + sdm.KeyToIndices[keyIndex] = []uint64{index} + } +} + +// It returns the smallest key in the key list +func (sdm *SquashedDictionaryManager) LastKey() (f.Element, error) { + if len(sdm.Keys) == 0 { + return f.Element{}, fmt.Errorf("no keys left") + } + return sdm.Keys[len(sdm.Keys)-1], nil +} + +// It pops out the smallest key in the key list +func (sdm *SquashedDictionaryManager) PopKey() (f.Element, error) { + key, err := sdm.LastKey() + if err != nil { + return key, err + } + + sdm.Keys = sdm.Keys[:len(sdm.Keys)-1] + return key, nil +} + +// It returns the list of indices associated to the smallest key +func (sdm *SquashedDictionaryManager) LastIndices() ([]uint64, error) { + key, err := sdm.LastKey() + if err != nil { + return nil, err + } + + return sdm.KeyToIndices[key], nil +} + +// It returns smallest index associated with the smallest key +func (sdm *SquashedDictionaryManager) LastIndex() (uint64, error) { + key, err := sdm.LastKey() + if err != nil { + return 0, err + } + + indices := sdm.KeyToIndices[key] + if len(indices) == 0 { + return 0, fmt.Errorf("no indices for key %s", &key) + } + + return indices[len(indices)-1], nil +} + +// It pops out smallest index associated with the smallest key +func (sdm *SquashedDictionaryManager) PopIndex() (uint64, error) { + key, err := sdm.LastKey() + if err != nil { + return 0, err + } + + indices := sdm.KeyToIndices[key] + if len(indices) == 0 { + return 0, fmt.Errorf("no indices for key %s", &key) + } + + index := indices[len(indices)-1] + sdm.KeyToIndices[key] = indices[:len(indices)-1] + return index, nil +} + +// Global context to keep track of different results across different +// hints execution. +type HintRunnerContext struct { + DictionaryManager DictionaryManager + SquashedDictionaryManager SquashedDictionaryManager +} + type HintRunner struct { + // Execution context required by certain hints such as dictionaires + context HintRunnerContext // A mapping from program counter to hint implementation hints map[uint64]Hinter } func NewHintRunner(hints map[uint64]Hinter) HintRunner { - return HintRunner{hints} + return HintRunner{ + // Context for certain hints that require it. Each manager is + // initialized only when required by the hint + context: HintRunnerContext{ + DictionaryManager{}, + SquashedDictionaryManager{}, + }, + hints: hints, + } } -func (hr HintRunner) RunHint(vm *VM.VirtualMachine) error { +func (hr *HintRunner) RunHint(vm *VM.VirtualMachine) error { hint := hr.hints[vm.Context.Pc.Offset] if hint == nil { return nil } - err := hint.Execute(vm) + err := hint.Execute(vm, &hr.context) if err != nil { return fmt.Errorf("execute hint %s: %v", hint, err) } diff --git a/pkg/hintrunner/hintrunner_test.go b/pkg/hintrunner/hintrunner_test.go index f71ab7698..4dcfd6670 100644 --- a/pkg/hintrunner/hintrunner_test.go +++ b/pkg/hintrunner/hintrunner_test.go @@ -16,7 +16,7 @@ func TestExistingHint(t *testing.T) { allocHint := AllocSegment{ap} hr := NewHintRunner(map[uint64]Hinter{ - 10: allocHint, + 10: &allocHint, }) vm.Context.Pc = memory.MemoryAddress{ @@ -40,7 +40,7 @@ func TestNoHint(t *testing.T) { allocHint := AllocSegment{ap} hr := NewHintRunner(map[uint64]Hinter{ - 10: allocHint, + 10: &allocHint, }) vm.Context.Pc = memory.MemoryAddress{ diff --git a/pkg/hintrunner/operand.go b/pkg/hintrunner/operand.go index dbae19c84..3a545264b 100644 --- a/pkg/hintrunner/operand.go +++ b/pkg/hintrunner/operand.go @@ -2,11 +2,10 @@ package hintrunner import ( "fmt" - "math/big" "github.com/NethermindEth/cairo-vm-go/pkg/safemath" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" - "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -16,7 +15,7 @@ import ( type CellRefer interface { fmt.Stringer - Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) + Get(vm *VM.VirtualMachine) (mem.MemoryAddress, error) } type ApCellRef int16 @@ -25,12 +24,12 @@ func (ap ApCellRef) String() string { return fmt.Sprintf("ApCellRef(%d)", ap) } -func (ap ApCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) { +func (ap ApCellRef) Get(vm *VM.VirtualMachine) (mem.MemoryAddress, error) { res, overflow := safemath.SafeOffset(vm.Context.Ap, int16(ap)) if overflow { - return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap)) + return mem.UnknownAddress, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap)) } - return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil + return mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } type FpCellRef int16 @@ -39,12 +38,12 @@ func (fp FpCellRef) String() string { return fmt.Sprintf("FpCellRef(%d)", fp) } -func (fp FpCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) { +func (fp FpCellRef) Get(vm *VM.VirtualMachine) (mem.MemoryAddress, error) { res, overflow := safemath.SafeOffset(vm.Context.Fp, int16(fp)) if overflow { - return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp)) + return mem.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp)) } - return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil + return mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } // @@ -53,7 +52,7 @@ func (fp FpCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) { type ResOperander interface { fmt.Stringer - Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) + Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) } type Deref struct { @@ -64,10 +63,10 @@ func (deref Deref) String() string { return "Deref" } -func (deref Deref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { +func (deref Deref) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) { address, err := deref.deref.Get(vm) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("get cell: %w", err) + return mem.MemoryValue{}, fmt.Errorf("get cell: %w", err) } return vm.Memory.ReadFromAddress(&address) } @@ -77,55 +76,49 @@ type DoubleDeref struct { offset int16 } -func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { +func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) { lhsAddr, err := dderef.deref.Get(vm) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", lhsAddr, err) + return mem.UnknownValue, fmt.Errorf("get lhs address %s: %w", lhsAddr, err) } lhs, err := vm.Memory.ReadFromAddress(&lhsAddr) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %w", lhsAddr, err) + return mem.UnknownValue, fmt.Errorf("read lhs address %s: %w", lhsAddr, err) } // Double deref implies the left hand side read must be an address address, err := lhs.MemoryAddress() if err != nil { - return memory.MemoryValue{}, err + return mem.UnknownValue, err } newOffset, overflow := safemath.SafeOffset(address.Offset, dderef.offset) if overflow { - return memory.MemoryValue{}, safemath.NewSafeOffsetError(address.Offset, dderef.offset) + return mem.UnknownValue, safemath.NewSafeOffsetError(address.Offset, dderef.offset) } - resAddr := memory.MemoryAddress{ + resAddr := mem.MemoryAddress{ SegmentIndex: address.SegmentIndex, Offset: newOffset, } value, err := vm.Memory.ReadFromAddress(&resAddr) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("read result at %s: %w", resAddr, err) + return mem.UnknownValue, fmt.Errorf("read result at %s: %w", resAddr, err) } return value, nil } -type Immediate big.Int +type Immediate f.Element func (imm Immediate) String() string { return "Immediate" } -// todo(rodro): Specs from Starkware stablish this can be uint256 and not a felt. // Should we respect that, or go straight to felt? -func (imm Immediate) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { - felt := &f.Element{} - bigInt := (big.Int)(imm) - // todo(rodro): do we require to check that big int is lesser than P, or do we - // just take: big_int `mod` P? - felt.SetBigInt(&bigInt) - - return memory.MemoryValueFromFieldElement(felt), nil +func (imm Immediate) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) { + felt := f.Element(imm) + return mem.MemoryValueFromFieldElement(&felt), nil } type Operator uint8 @@ -145,31 +138,31 @@ func (bop BinaryOp) String() string { return "BinaryOperator" } -func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { +func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) { lhsAddr, err := bop.lhs.Get(vm) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", bop.lhs, err) + return mem.UnknownValue, fmt.Errorf("get lhs address %s: %w", bop.lhs, err) } lhs, err := vm.Memory.ReadFromAddress(&lhsAddr) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %v", lhsAddr, err) + return mem.UnknownValue, fmt.Errorf("read lhs address %s: %w", lhsAddr, err) } rhs, err := bop.rhs.Resolve(vm) if err != nil { - return memory.MemoryValue{}, fmt.Errorf("resolve rhs operand %s: %v", rhs, err) + return mem.UnknownValue, fmt.Errorf("resolve rhs operand %s: %w", rhs, err) } switch bop.operator { case Add: - mv := memory.EmptyMemoryValueAs(lhs.IsAddress() || rhs.IsAddress()) + mv := mem.EmptyMemoryValueAs(lhs.IsAddress() || rhs.IsAddress()) err := mv.Add(&lhs, &rhs) return mv, err case Mul: - mv := memory.EmptyMemoryValueAsFelt() + mv := mem.EmptyMemoryValueAsFelt() err := mv.Mul(&lhs, &rhs) return mv, err default: - return memory.MemoryValue{}, fmt.Errorf("unknown binary operator: %d", bop.operator) + return mem.UnknownValue, fmt.Errorf("unknown binary operator: %d", bop.operator) } } diff --git a/pkg/hintrunner/operand_test.go b/pkg/hintrunner/operand_test.go index 9a93d4bdb..0446b8c3c 100644 --- a/pkg/hintrunner/operand_test.go +++ b/pkg/hintrunner/operand_test.go @@ -1,11 +1,11 @@ package hintrunner import ( - "math/big" "testing" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "github.com/stretchr/testify/require" ) @@ -102,7 +102,7 @@ func TestResolveImmediate(t *testing.T) { // Immediate does not need the vm for resolving itself var vm *VM.VirtualMachine = nil - imm := Immediate(*big.NewInt(99)) + imm := Immediate(f.NewElement(99)) solved, err := imm.Resolve(vm) require.NoError(t, err) diff --git a/pkg/hintrunner/utils.go b/pkg/hintrunner/utils.go new file mode 100644 index 000000000..c75e8571a --- /dev/null +++ b/pkg/hintrunner/utils.go @@ -0,0 +1,51 @@ +package hintrunner + +import ( + "fmt" + + VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" + mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +func ResolveAsAddress(vm *VM.VirtualMachine, op ResOperander) (mem.MemoryAddress, error) { + mv, err := op.Resolve(vm) + if err != nil { + return mem.UnknownAddress, fmt.Errorf("%s: %w", op, err) + } + + addr, err := mv.MemoryAddress() + if err != nil { + return mem.UnknownAddress, fmt.Errorf("%s: %w", op, err) + } + + return *addr, nil +} + +func ResolveAsFelt(vm *VM.VirtualMachine, op ResOperander) (f.Element, error) { + mv, err := op.Resolve(vm) + if err != nil { + return f.Element{}, fmt.Errorf("%s: %w", op, err) + } + + felt, err := mv.FieldElement() + if err != nil { + return f.Element{}, fmt.Errorf("%s: %w", op, err) + } + + return *felt, nil +} + +func ResolveAsUint64(vm *VM.VirtualMachine, op ResOperander) (uint64, error) { + mv, err := op.Resolve(vm) + if err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + uint64Value, err := mv.Uint64() + if err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + return uint64Value, nil +} diff --git a/pkg/runners/zero/zero.go b/pkg/runners/zero/zero.go index e428e8b4c..ec89f280c 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runners/zero/zero.go @@ -105,11 +105,9 @@ func (runner *ZeroRunner) InitializeMainEntrypoint() (mem.MemoryAddress, error) return mem.MemoryAddress{SegmentIndex: vm.ProgramSegment, Offset: endPcOffset}, nil } - returnFp := mem.MemoryValueFromSegmentAndOffset( - memory.AllocateEmptySegment(), - 0, - ) - return runner.InitializeEntrypoint("main", nil, &returnFp, memory) + returnFp := memory.AllocateEmptySegment() + mvReturnFp := mem.MemoryValueFromMemoryAddress(&returnFp) + return runner.InitializeEntrypoint("main", nil, &mvReturnFp, memory) } func (runner *ZeroRunner) InitializeEntrypoint( @@ -124,10 +122,7 @@ func (runner *ZeroRunner) InitializeEntrypoint( for i := range arguments { stack = append(stack, mem.MemoryValueFromFieldElement(arguments[i])) } - end := mem.MemoryAddress{ - SegmentIndex: uint64(memory.AllocateEmptySegment()), - Offset: 0, - } + end := memory.AllocateEmptySegment() stack = append(stack, *returnFp, mem.MemoryValueFromMemoryAddress(&end)) return end, runner.initializeVm(&mem.MemoryAddress{ @@ -141,7 +136,7 @@ func (runner *ZeroRunner) initializeBuiltins(memory *mem.Memory) []mem.MemoryVal for _, builtin := range runner.program.Builtins { bRunner := builtins.Runner(builtin) builtinSegment := memory.AllocateBuiltinSegment(bRunner) - stack = append(stack, mem.MemoryValueFromSegmentAndOffset(builtinSegment, 0)) + stack = append(stack, mem.MemoryValueFromMemoryAddress(&builtinSegment)) } return stack } @@ -178,7 +173,7 @@ func (runner *ZeroRunner) RunUntilPc(pc *mem.MemoryAddress) error { runner.maxsteps, ) } - if err := runner.vm.RunStep(runner.hintrunner); err != nil { + if err := runner.vm.RunStep(&runner.hintrunner); err != nil { return fmt.Errorf("pc %s step %d: %w", runner.pc(), runner.steps(), err) } } @@ -196,7 +191,7 @@ func (runner *ZeroRunner) RunFor(steps uint64) error { runner.maxsteps, ) } - if err := runner.vm.RunStep(runner.hintrunner); err != nil { + if err := runner.vm.RunStep(&runner.hintrunner); err != nil { return fmt.Errorf( "pc %s step %d: %w", runner.pc(), diff --git a/pkg/safemath/arrays.go b/pkg/safemath/arrays.go new file mode 100644 index 000000000..a78d4b344 --- /dev/null +++ b/pkg/safemath/arrays.go @@ -0,0 +1,7 @@ +package safemath + +func Reverse[T any](a []T) { + for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { + a[i], a[j] = a[j], a[i] + } +} diff --git a/pkg/safemath/constant.go b/pkg/safemath/constant.go new file mode 100644 index 000000000..7fd5b70f7 --- /dev/null +++ b/pkg/safemath/constant.go @@ -0,0 +1,17 @@ +package safemath + +import ( + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/holiman/uint256" +) + +var FeltZero = fp.Element{} + +var FeltOne = fp.Element{ + 18446744073709551585, 18446744073709551615, 18446744073709551615, 576460752303422960, +} + +// 1 << 128 +var FeltMax128 = fp.Element{18446744073700081665, 17407, 18446744073709551584, 576460752142434320} + +var Uint256Max128 = uint256.Int{18446744073709551615, 18446744073709551615, 0, 0} diff --git a/pkg/vm/builtins/range_check.go b/pkg/vm/builtins/range_check.go index 28447fcd9..84fd58b7b 100644 --- a/pkg/vm/builtins/range_check.go +++ b/pkg/vm/builtins/range_check.go @@ -4,17 +4,14 @@ import ( "errors" "fmt" + "github.com/NethermindEth/cairo-vm-go/pkg/safemath" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" - "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) const RangeCheckName = "range_check" type RangeCheck struct{} -// 1 << 128 -var max128 = fp.Element{18446744073700081665, 17407, 18446744073709551584, 576460752142434320} - func (r *RangeCheck) CheckWrite(segment *memory.Segment, offset uint64, value *memory.MemoryValue) error { felt, err := value.FieldElement() if err != nil { @@ -22,7 +19,7 @@ func (r *RangeCheck) CheckWrite(segment *memory.Segment, offset uint64, value *m } // felt >= (2^128) - if felt.Cmp(&max128) != -1 { + if felt.Cmp(&safemath.FeltMax128) != -1 { return fmt.Errorf("check write: 2**128 < %s", value) } return nil diff --git a/pkg/vm/memory/memory.go b/pkg/vm/memory/memory.go index b369fa495..d69c6a715 100644 --- a/pkg/vm/memory/memory.go +++ b/pkg/vm/memory/memory.go @@ -189,30 +189,39 @@ func InitializeEmptyMemory() *Memory { } // Allocates a new segment providing its initial data and returns its index -func (memory *Memory) AllocateSegment(data []*f.Element) (int, error) { +func (memory *Memory) AllocateSegment(data []*f.Element) (MemoryAddress, error) { newSegment := EmptySegmentWithLength(len(data)) for i := range data { memVal := MemoryValueFromFieldElement(data[i]) err := newSegment.Write(uint64(i), &memVal) if err != nil { - return 0, err + return UnknownAddress, err } } memory.Segments = append(memory.Segments, newSegment) - return len(memory.Segments) - 1, nil + return MemoryAddress{ + SegmentIndex: uint64(len(memory.Segments) - 1), + Offset: 0, + }, nil } // Allocates an empty segment and returns its index -func (memory *Memory) AllocateEmptySegment() int { +func (memory *Memory) AllocateEmptySegment() MemoryAddress { memory.Segments = append(memory.Segments, EmptySegment()) - return len(memory.Segments) - 1 + return MemoryAddress{ + SegmentIndex: uint64(len(memory.Segments) - 1), + Offset: 0, + } } // Allocate a Builtin segment -func (memory *Memory) AllocateBuiltinSegment(builtinRunner BuiltinRunner) int { +func (memory *Memory) AllocateBuiltinSegment(builtinRunner BuiltinRunner) MemoryAddress { builtinSegment := EmptySegment().WithBuiltinRunner(builtinRunner) memory.Segments = append(memory.Segments, builtinSegment) - return len(memory.Segments) - 1 + return MemoryAddress{ + SegmentIndex: uint64(len(memory.Segments) - 1), + Offset: 0, + } } // Writes to a given segment index and offset a new memory value. Errors if writing @@ -252,6 +261,42 @@ func (memory *Memory) ReadFromAddress(address *MemoryAddress) (MemoryValue, erro return memory.Read(address.SegmentIndex, address.Offset) } +// Works the same as `Read` but `MemoryValue` is converted to `Element` first +func (memory *Memory) ReadAsElement(segmentIndex uint64, offset uint64) (f.Element, error) { + mv, err := memory.Read(segmentIndex, offset) + if err != nil { + return f.Element{}, err + } + felt, err := mv.FieldElement() + if err != nil { + return f.Element{}, err + } + return *felt, nil +} + +// Works the same as `ReadFromAddress` but `MemoryValue` is converted to `Element` first +func (memory *Memory) ReadFromAddressAsElement(address *MemoryAddress) (f.Element, error) { + return memory.ReadAsElement(address.SegmentIndex, address.Offset) +} + +// Works the same as `Read` but `MemoryValue` is converted to `MemoryAddress` first +func (memory *Memory) ReadAsAddress(address *MemoryAddress) (MemoryAddress, error) { + mv, err := memory.Read(address.SegmentIndex, address.Offset) + if err != nil { + return UnknownAddress, err + } + addr, err := mv.MemoryAddress() + if err != nil { + return UnknownAddress, err + } + return *addr, nil +} + +// Works the same as `ReadFromAddress` but `MemoryValue` is converted to `MemoryAddress` first +func (memory *Memory) ReadFromAddressAsAddress(address *MemoryAddress) (MemoryAddress, error) { + return memory.ReadAsAddress(address) +} + // Given a segment index and offset, returns the memory value at that position, without // modifying it in any way. Errors if peeking from an unallocated segment func (memory *Memory) Peek(segmentIndex uint64, offset uint64) (MemoryValue, error) { diff --git a/pkg/vm/memory/memory_value.go b/pkg/vm/memory/memory_value.go index 599f1080a..14404fdf5 100644 --- a/pkg/vm/memory/memory_value.go +++ b/pkg/vm/memory/memory_value.go @@ -5,6 +5,7 @@ import ( "fmt" "unsafe" + "github.com/NethermindEth/cairo-vm-go/pkg/safemath" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "golang.org/x/exp/constraints" ) @@ -23,6 +24,23 @@ func (address *MemoryAddress) Equal(other *MemoryAddress) bool { return address.SegmentIndex == other.SegmentIndex && address.Offset == other.Offset } +// It crates a new memory address with the modified offset +func (address *MemoryAddress) AddOffset(offset int16) (MemoryAddress, error) { + newOffset, overflow := safemath.SafeOffset(address.Offset, offset) + if overflow { + return UnknownAddress, + fmt.Errorf( + "address new invalid offseet: %d + %d = %d", + address.Offset, offset, newOffset, + ) + } + return MemoryAddress{ + SegmentIndex: address.SegmentIndex, + Offset: newOffset, + }, nil + +} + // Adds a memory address and a field element func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) error { lhsOffset := new(f.Element).SetUint64(lhs.Offset) From a302d452f16e4d5a69e8db8c5119b37082eaf0f9 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 2 Nov 2023 08:05:23 -0400 Subject: [PATCH 3/3] Refactor: Safemath is now the Util pkg (#149) * rename pkg safemath to utils * rename pkg name to utils * Fix compilation errors due to pkg rename * safemath to math * remover error file: --- pkg/hintrunner/hint.go | 12 +++++------ pkg/hintrunner/operand.go | 14 ++++++------- pkg/runners/zero/zero.go | 4 ++-- pkg/safemath/error.go | 21 ------------------- pkg/{safemath => utils}/arrays.go | 2 +- pkg/{safemath => utils}/constant.go | 2 +- pkg/{safemath/safemath.go => utils/math.go} | 2 +- .../safemath_test.go => utils/math_test.go} | 2 +- pkg/vm/builtins/range_check.go | 4 ++-- pkg/vm/memory/memory.go | 4 ++-- pkg/vm/memory/memory_value.go | 4 ++-- pkg/vm/vm.go | 8 +++---- 12 files changed, 29 insertions(+), 50 deletions(-) delete mode 100644 pkg/safemath/error.go rename pkg/{safemath => utils}/arrays.go (86%) rename pkg/{safemath => utils}/constant.go (96%) rename pkg/{safemath/safemath.go => utils/math.go} (99%) rename pkg/{safemath/safemath_test.go => utils/math_test.go} (97%) diff --git a/pkg/hintrunner/hint.go b/pkg/hintrunner/hint.go index 9fed309e1..1dbb0f132 100644 --- a/pkg/hintrunner/hint.go +++ b/pkg/hintrunner/hint.go @@ -6,7 +6,7 @@ import ( "github.com/holiman/uint256" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -235,7 +235,7 @@ func (hint *WideMul128) String() string { } func (hint *WideMul128) Execute(vm *VM.VirtualMachine, _ *HintRunnerContext) error { - mask := &safemath.Uint256Max128 + mask := &utils.Uint256Max128 lhs, err := hint.lhs.Resolve(vm) if err != nil { @@ -404,7 +404,7 @@ func (hint *AllocFelt252Dict) Execute(vm *VM.VirtualMachine, ctx *HintRunnerCont } // find for the amount of initialized dicts - initializedDictsOffset, overflow := safemath.SafeOffset(arenaPtr.Offset, -2) + initializedDictsOffset, overflow := utils.SafeOffset(arenaPtr.Offset, -2) if overflow { return fmt.Errorf("look for initialized dicts: overflow: %s - 2", arenaPtr) } @@ -418,7 +418,7 @@ func (hint *AllocFelt252Dict) Execute(vm *VM.VirtualMachine, ctx *HintRunnerCont } // find for the segment info pointer - segmentInfoOffset, overflow := safemath.SafeOffset(arenaPtr.Offset, -3) + segmentInfoOffset, overflow := utils.SafeOffset(arenaPtr.Offset, -3) if overflow { return fmt.Errorf("look for segment info pointer: overflow: %s - 3", arenaPtr) } @@ -586,7 +586,7 @@ func (hint *InitSquashData) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContex } for key, val := range ctx.SquashedDictionaryManager.KeyToIndices { // reverse each indice access list per key - safemath.Reverse(val) + utils.Reverse(val) // store each key ctx.SquashedDictionaryManager.Keys = append(ctx.SquashedDictionaryManager.Keys, key) } @@ -603,7 +603,7 @@ func (hint *InitSquashData) Execute(vm *VM.VirtualMachine, ctx *HintRunnerContex } biggestKey := ctx.SquashedDictionaryManager.Keys[0] cmpRes := mem.MemoryValueFromUint[uint64](0) - if biggestKey.Cmp(&safemath.FeltMax128) > 0 { + if biggestKey.Cmp(&utils.FeltMax128) > 0 { cmpRes = mem.MemoryValueFromUint[uint64](1) } err = vm.Memory.WriteToAddress(&bigKeysAddr, &cmpRes) diff --git a/pkg/hintrunner/operand.go b/pkg/hintrunner/operand.go index 3a545264b..31cba58fe 100644 --- a/pkg/hintrunner/operand.go +++ b/pkg/hintrunner/operand.go @@ -3,7 +3,7 @@ package hintrunner import ( "fmt" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" @@ -25,9 +25,9 @@ func (ap ApCellRef) String() string { } func (ap ApCellRef) Get(vm *VM.VirtualMachine) (mem.MemoryAddress, error) { - res, overflow := safemath.SafeOffset(vm.Context.Ap, int16(ap)) + res, overflow := utils.SafeOffset(vm.Context.Ap, int16(ap)) if overflow { - return mem.UnknownAddress, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap)) + return mem.UnknownAddress, fmt.Errorf("overflow %d + %d", vm.Context.Ap, int16(ap)) } return mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } @@ -39,9 +39,9 @@ func (fp FpCellRef) String() string { } func (fp FpCellRef) Get(vm *VM.VirtualMachine) (mem.MemoryAddress, error) { - res, overflow := safemath.SafeOffset(vm.Context.Fp, int16(fp)) + res, overflow := utils.SafeOffset(vm.Context.Fp, int16(fp)) if overflow { - return mem.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp)) + return mem.UnknownAddress, fmt.Errorf("overflow %d + %d", vm.Context.Fp, int16(fp)) } return mem.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } @@ -92,9 +92,9 @@ func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error return mem.UnknownValue, err } - newOffset, overflow := safemath.SafeOffset(address.Offset, dderef.offset) + newOffset, overflow := utils.SafeOffset(address.Offset, dderef.offset) if overflow { - return mem.UnknownValue, safemath.NewSafeOffsetError(address.Offset, dderef.offset) + return mem.UnknownValue, fmt.Errorf("overflow %d + %d", address.Offset, dderef.offset) } resAddr := mem.MemoryAddress{ SegmentIndex: address.SegmentIndex, diff --git a/pkg/runners/zero/zero.go b/pkg/runners/zero/zero.go index ec89f280c..3b20ca46b 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runners/zero/zero.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -57,7 +57,7 @@ func (runner *ZeroRunner) Run() error { if runner.proofmode { // +1 because proof mode require an extra instruction run // pow2 because proof mode also requires that the trace is a power of two - pow2Steps := safemath.NextPowerOfTwo(runner.vm.Step + 1) + pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1) if err := runner.RunFor(pow2Steps); err != nil { return err } diff --git a/pkg/safemath/error.go b/pkg/safemath/error.go deleted file mode 100644 index f479fc117..000000000 --- a/pkg/safemath/error.go +++ /dev/null @@ -1,21 +0,0 @@ -package safemath - -import "fmt" - -type SafeMathError struct { - msg string -} - -func NewSafeOffsetError(a uint64, b int16) *SafeMathError { - return &SafeMathError{ - msg: fmt.Sprintf("offset calculation of %d using %d is out of [0, 2**64) range", a, b), - } -} - -func (e *SafeMathError) Error() string { - return fmt.Sprintf("math error: %s", e.msg) -} - -func (e *SafeMathError) Unwrap() error { - return nil -} diff --git a/pkg/safemath/arrays.go b/pkg/utils/arrays.go similarity index 86% rename from pkg/safemath/arrays.go rename to pkg/utils/arrays.go index a78d4b344..13a42fac2 100644 --- a/pkg/safemath/arrays.go +++ b/pkg/utils/arrays.go @@ -1,4 +1,4 @@ -package safemath +package utils func Reverse[T any](a []T) { for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { diff --git a/pkg/safemath/constant.go b/pkg/utils/constant.go similarity index 96% rename from pkg/safemath/constant.go rename to pkg/utils/constant.go index 7fd5b70f7..5053981c5 100644 --- a/pkg/safemath/constant.go +++ b/pkg/utils/constant.go @@ -1,4 +1,4 @@ -package safemath +package utils import ( "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" diff --git a/pkg/safemath/safemath.go b/pkg/utils/math.go similarity index 99% rename from pkg/safemath/safemath.go rename to pkg/utils/math.go index 7ccdf24f2..a6272cac9 100644 --- a/pkg/safemath/safemath.go +++ b/pkg/utils/math.go @@ -1,4 +1,4 @@ -package safemath +package utils import ( "math/bits" diff --git a/pkg/safemath/safemath_test.go b/pkg/utils/math_test.go similarity index 97% rename from pkg/safemath/safemath_test.go rename to pkg/utils/math_test.go index 538c03435..461b959c3 100644 --- a/pkg/safemath/safemath_test.go +++ b/pkg/utils/math_test.go @@ -1,4 +1,4 @@ -package safemath +package utils import ( "testing" diff --git a/pkg/vm/builtins/range_check.go b/pkg/vm/builtins/range_check.go index 84fd58b7b..cf8d28431 100644 --- a/pkg/vm/builtins/range_check.go +++ b/pkg/vm/builtins/range_check.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" ) @@ -19,7 +19,7 @@ func (r *RangeCheck) CheckWrite(segment *memory.Segment, offset uint64, value *m } // felt >= (2^128) - if felt.Cmp(&safemath.FeltMax128) != -1 { + if felt.Cmp(&utils.FeltMax128) != -1 { return fmt.Errorf("check write: 2**128 < %s", value) } return nil diff --git a/pkg/vm/memory/memory.go b/pkg/vm/memory/memory.go index d69c6a715..51e317ebc 100644 --- a/pkg/vm/memory/memory.go +++ b/pkg/vm/memory/memory.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -135,7 +135,7 @@ func (segment *Segment) IncreaseSegmentSize(newSize uint64) { if cap(segmentData) > int(newSize) { newSegmentData = segmentData[:cap(segmentData)] } else { - newSegmentData = make([]MemoryValue, safemath.Max(newSize, uint64(len(segmentData)*2))) + newSegmentData = make([]MemoryValue, utils.Max(newSize, uint64(len(segmentData)*2))) copy(newSegmentData, segmentData) } segment.Data = newSegmentData diff --git a/pkg/vm/memory/memory_value.go b/pkg/vm/memory/memory_value.go index 14404fdf5..856e2a4e4 100644 --- a/pkg/vm/memory/memory_value.go +++ b/pkg/vm/memory/memory_value.go @@ -5,7 +5,7 @@ import ( "fmt" "unsafe" - "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "golang.org/x/exp/constraints" ) @@ -26,7 +26,7 @@ func (address *MemoryAddress) Equal(other *MemoryAddress) bool { // It crates a new memory address with the modified offset func (address *MemoryAddress) AddOffset(offset int16) (MemoryAddress, error) { - newOffset, overflow := safemath.SafeOffset(address.Offset, offset) + newOffset, overflow := utils.SafeOffset(address.Offset, offset) if overflow { return UnknownAddress, fmt.Errorf( diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index a368c5cef..4afbcdae2 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -5,7 +5,7 @@ import ( "fmt" a "github.com/NethermindEth/cairo-vm-go/pkg/assembler" - safemath "github.com/NethermindEth/cairo-vm-go/pkg/safemath" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) @@ -215,7 +215,7 @@ func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddr dstRegister = vm.Context.Fp } - addr, isOverflow := safemath.SafeOffset(dstRegister, instruction.OffDest) + addr, isOverflow := utils.SafeOffset(dstRegister, instruction.OffDest) if isOverflow { return mem.UnknownAddress, fmt.Errorf("offset overflow: %d + %d", dstRegister, instruction.OffDest) } @@ -230,7 +230,7 @@ func (vm *VirtualMachine) getOp0Addr(instruction *a.Instruction) (mem.MemoryAddr op0Register = vm.Context.Fp } - addr, isOverflow := safemath.SafeOffset(op0Register, instruction.OffOp0) + addr, isOverflow := utils.SafeOffset(op0Register, instruction.OffOp0) if isOverflow { return mem.UnknownAddress, fmt.Errorf("offset overflow: %d + %d", op0Register, instruction.OffOp0) @@ -261,7 +261,7 @@ func (vm *VirtualMachine) getOp1Addr(instruction *a.Instruction, op0Addr *mem.Me op1Address = vm.Context.AddressAp() } - newOffset, isOverflow := safemath.SafeOffset(op1Address.Offset, instruction.OffOp1) + newOffset, isOverflow := utils.SafeOffset(op1Address.Offset, instruction.OffOp1) if isOverflow { return mem.UnknownAddress, fmt.Errorf("offset overflow: %d + %d", op1Address.Offset, instruction.OffOp1) }