Skip to content

Commit

Permalink
Merge branch 'main' into quasilyte_casm_disasm
Browse files Browse the repository at this point in the history
  • Loading branch information
quasilyte authored Feb 21, 2024
2 parents 053de96 + 2be67c9 commit d2edbc7
Show file tree
Hide file tree
Showing 11 changed files with 726 additions and 70 deletions.
50 changes: 34 additions & 16 deletions pkg/hintrunner/hinter/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type ResOperander interface {
fmt.Stringer

ApplyApTracking(hint, ref zero.ApTracking) Reference
GetAddress(vm *VM.VirtualMachine) (mem.MemoryAddress, error)
Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error)
}

Expand All @@ -67,15 +68,19 @@ func (deref Deref) String() string {
}

func (deref Deref) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) {
address, err := deref.Deref.Get(vm)
address, err := deref.GetAddress(vm)
if err != nil {
return mem.MemoryValue{}, fmt.Errorf("get cell: %w", err)
return mem.UnknownValue, fmt.Errorf("get cell address: %w", err)
}
return vm.Memory.ReadFromAddress(&address)
}

func (deref Deref) GetAddress(vm *VM.VirtualMachine) (mem.MemoryAddress, error) {
return deref.Deref.Get(vm)
}

type DoubleDeref struct {
Deref CellRefer
Deref Deref
Offset int16
}

Expand All @@ -84,36 +89,40 @@ func (dderef DoubleDeref) String() string {
}

func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) {
lhsAddr, err := dderef.Deref.Get(vm)
addr, err := dderef.GetAddress(vm)
if err != nil {
return mem.UnknownValue, fmt.Errorf("get lhs address %s: %w", lhsAddr, err)
return mem.UnknownValue, err
}
lhs, err := vm.Memory.ReadFromAddress(&lhsAddr)
value, err := vm.Memory.ReadFromAddress(&addr)
if err != nil {
return mem.UnknownValue, fmt.Errorf("read lhs address %s: %w", lhsAddr, err)
return mem.UnknownValue, fmt.Errorf("read result at %s: %w", addr, err)
}

return value, nil
}

func (dderef DoubleDeref) GetAddress(vm *VM.VirtualMachine) (mem.MemoryAddress, error) {
lhs, err := dderef.Deref.Resolve(vm)
if err != nil {
return mem.UnknownAddress, fmt.Errorf("get lhs address: %w", err)
}

// Double deref implies the left hand side read must be an address
address, err := lhs.MemoryAddress()
if err != nil {
return mem.UnknownValue, err
return mem.UnknownAddress, err
}

newOffset, overflow := utils.SafeOffset(address.Offset, dderef.Offset)
if overflow {
return mem.UnknownValue, fmt.Errorf("overflow %d + %d", address.Offset, dderef.Offset)
return mem.UnknownAddress, fmt.Errorf("overflow %d + %d", address.Offset, dderef.Offset)
}
resAddr := mem.MemoryAddress{
SegmentIndex: address.SegmentIndex,
Offset: newOffset,
}

value, err := vm.Memory.ReadFromAddress(&resAddr)
if err != nil {
return mem.UnknownValue, fmt.Errorf("read result at %s: %w", resAddr, err)
}

return value, nil
return resAddr, nil
}

type Immediate f.Element
Expand All @@ -128,6 +137,10 @@ func (imm Immediate) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) {
return mem.MemoryValueFromFieldElement(&felt), nil
}

func (imm Immediate) GetAddress(vm *VM.VirtualMachine) (mem.MemoryAddress, error) {
return mem.UnknownAddress, fmt.Errorf("cannot get an address from an immediate value %s", imm)
}

type Operator uint8

const (
Expand Down Expand Up @@ -174,6 +187,11 @@ func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (mem.MemoryValue, error) {
}
}

func (bop BinaryOp) GetAddress(vm *VM.VirtualMachine) (mem.MemoryAddress, error) {
// TODO: Check if it's possible in some cases such as Deref + Immediate
return mem.UnknownAddress, fmt.Errorf("cannot get an address from a Binary Operation operand")
}

type Reference interface {
ApplyApTracking(hint, ref zero.ApTracking) Reference
}
Expand All @@ -197,7 +215,7 @@ func (v Deref) ApplyApTracking(hint, ref zero.ApTracking) Reference {
}

func (v DoubleDeref) ApplyApTracking(hint, ref zero.ApTracking) Reference {
v.Deref = v.Deref.ApplyApTracking(hint, ref).(CellRefer)
v.Deref = v.Deref.ApplyApTracking(hint, ref).(Deref)
return v
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/hintrunner/hinter/operand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestResolveDoubleDerefPositiveOffset(t *testing.T) {
)

var apCell ApCellRef = 7
dderf := DoubleDeref{apCell, 14}
dderf := DoubleDeref{Deref{apCell}, 14}

value, err := dderf.Resolve(vm)
require.NoError(t, err)
Expand All @@ -92,7 +92,7 @@ func TestResolveDoubleDerefNegativeOffset(t *testing.T) {
)

var apCell ApCellRef = 7
dderf := DoubleDeref{apCell, -14}
dderf := DoubleDeref{Deref{apCell}, -14}

value, err := dderf.Resolve(vm)
require.NoError(t, err)
Expand Down
8 changes: 8 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package zero

const (
// This is a block for hint code strings where there is a single
// hint per function it belongs to (with some exceptions like testAssignCode).
allocSegmentCode string = "memory[ap] = segments.add()"
isLeFeltCode string = "memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1"
assertLtFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\nassert (ids.a % PRIME) < (ids.b % PRIME), \\\n f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.'"

// This is a very simple Cairo0 hint that allows us to test
// the identifier resolution code.
Expand All @@ -13,4 +17,8 @@ const (
assertLeFeltExcluded0Code string = "memory[ap] = 1 if excluded != 0 else 0"
assertLeFeltExcluded1Code string = "memory[ap] = 1 if excluded != 1 else 0"
assertLeFeltExcluded2Code string = "assert excluded == 2"

// is_nn() hints.
isNNCode string = "memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1"
isNNOutOfRangeCode string = "memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1"
)
6 changes: 4 additions & 2 deletions pkg/hintrunner/zero/hintparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,15 @@ func (expression DerefCastExp) Evaluate() (hinter.Reference, error) {
return hinter.Deref{Deref: result}, nil
case hinter.Deref:
return hinter.DoubleDeref{
Deref: result.Deref,
Deref: result,
Offset: 0,
},
nil
case DerefOffset:
return hinter.DoubleDeref{
Deref: result.Deref.Deref,
Deref: hinter.Deref{
Deref: result.Deref.Deref,
},
Offset: int16(*result.Offset),
},
nil
Expand Down
11 changes: 7 additions & 4 deletions pkg/hintrunner/zero/hintparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ func TestHintParser(t *testing.T) {
Parameter: "[cast([ap + 2], felt)]",
ExpectedCellRefer: nil,
ExpectedResOperander: hinter.DoubleDeref{
Deref: hinter.ApCellRef(2),
Offset: 0},
Deref: hinter.Deref{
Deref: hinter.ApCellRef(2),
},
Offset: 0,
},
},
{
Parameter: "cast([ap + 2] + [ap], felt)",
Expand Down Expand Up @@ -72,11 +75,11 @@ func TestHintParser(t *testing.T) {
require.NoError(t, err)

if test.ExpectedCellRefer != nil {
require.Equal(t, test.ExpectedCellRefer, output, "Expected CellRefer type")
require.Equal(t, test.ExpectedCellRefer, output, "unexpected CellRefer type")
}

if test.ExpectedResOperander != nil {
require.Equal(t, test.ExpectedResOperander, output, "Expected ResOperander type")
require.Equal(t, test.ExpectedResOperander, output, "unexpected ResOperander type")
}
}
}
54 changes: 8 additions & 46 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
switch rawHint.Code {
case allocSegmentCode:
return CreateAllocSegmentHinter(resolver)
case isLeFeltCode:
return createIsLeFeltHinter(resolver)
case assertLtFeltCode:
return createAssertLtFeltHinter(resolver)
case testAssignCode:
return createTestAssignHinter(resolver)
case assertLeFeltCode:
Expand All @@ -64,6 +68,10 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createAssertLeFeltExcluded1Hinter(resolver)
case assertLeFeltExcluded2Code:
return createAssertLeFeltExcluded2Hinter(resolver)
case isNNCode:
return createIsNNHinter(resolver)
case isNNOutOfRangeCode:
return createIsNNOutOfRangeHinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand Down Expand Up @@ -98,52 +106,6 @@ func createTestAssignHinter(resolver hintReferenceResolver) (hinter.Hinter, erro
return h, nil
}

func createAssertLeFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
a, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
b, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}
rangeCheckPtr, err := resolver.GetResOperander("range_check_ptr")
if err != nil {
return nil, err
}

h := &core.AssertLeFindSmallArc{
A: a,
B: b,
RangeCheckPtr: rangeCheckPtr,
}
return h, nil
}

func createAssertLeFeltExcluded0Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsFirstArcExcluded{SkipExcludeAFlag: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded1Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsSecondArcExcluded{SkipExcludeBMinusA: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded2Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
// This hint is Cairo0-specific.
// It only does a python-scoped variable named "excluded" assert.
// We store that variable inside a hinter context.
h := &GenericZeroHinter{
Name: "AssertLeFeltExcluded2",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
if ctx.ExcludedArc != 2 {
return fmt.Errorf("assertion `excluded == 2` failed")
}
return nil
},
}
return h, nil
}

func getParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) (hintReferenceResolver, error) {
resolver := NewReferenceResolver()

Expand Down
Loading

0 comments on commit d2edbc7

Please sign in to comment.