Skip to content

Commit

Permalink
pkg/hintrunner/zero: allow multiplication binary ops in references
Browse files Browse the repository at this point in the history
References may include the multiplication inside them.

Let's take this like of code for the example:

    https://github.com/starkware-libs/cairo-lang/blob/caba294d82eeeccc3d86a158adb8ba209bf2d8fc/src/starkware/cairo/common/math.cairo#L193

It will produce a reference like this:
```json
    {
        "cairo_type": "felt",
        "full_name": "starkware.cairo.common.math.assert_le_felt.arc_prod",
        "references": [
            {
                "ap_tracking_data": {
                    "group": 1,
                    "offset": 8
                },
                "pc": 14,
                "value": "cast([ap + (-5)] * [ap + (-1)], felt)"
            }
        ],
        "type": "reference"
    }
```
  • Loading branch information
quasilyte committed Feb 7, 2024
1 parent 88587ae commit 1e87be8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
33 changes: 30 additions & 3 deletions pkg/hintrunner/zero/hintparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zero
import (
"fmt"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
op "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/alecthomas/participle/v2"
)
Expand All @@ -22,6 +23,7 @@ var parser *participle.Parser[IdentifierExp] = participle.MustBuild[IdentifierEx
// 2 dereferences off1 omitted: cast([reg] + [reg + off2], type)
// 2 dereferences off2 omitted: cast([reg + off1] + [reg], type)
// 2 dereferences both offs omitted: cast([reg] + [reg], type)
// 2 dereferences with multiplication: cast([reg + off1] * [reg + off2], felt)
// Reference no dereference 2 offsets - + : cast(reg - off1 + off2, type)

// Note: The same cases apply with an external dereference. Example: [cast(number, type)]
Expand Down Expand Up @@ -62,7 +64,8 @@ type DerefExp struct {
}

type BinOpExp struct {
LeftExp *LeftExp `@@ "+"`
LeftExp *LeftExp `@@`
Operator string `@("+" | "*")`
RightExp *RightExp `@@`
}

Expand All @@ -83,10 +86,12 @@ type RightExp struct {

type DerefOffset struct {
Deref op.Deref
Op op.Operator
Offset *int
}
type DerefDeref struct {
LeftDeref op.Deref
Op op.Operator
RightDeref op.Deref
}

Expand Down Expand Up @@ -141,8 +146,9 @@ func (expression CastExp) Evaluate() (any, error) {
return result, nil
case DerefOffset:
return op.BinaryOp{
Operator: 0,
Operator: result.Op,
Lhs: result.Deref.Deref,
// TODO: why we're not using something like f.NewElement here?
Rhs: op.Immediate{
uint64(0),
uint64(0),
Expand All @@ -152,7 +158,7 @@ func (expression CastExp) Evaluate() (any, error) {
}, nil
case DerefDeref:
return op.BinaryOp{
Operator: 0,
Operator: result.Op,
Lhs: result.LeftDeref.Deref,
Rhs: result.RightDeref,
}, nil
Expand Down Expand Up @@ -238,8 +244,16 @@ func (expression BinOpExp) Evaluate() (any, error) {
return nil, err
}

operation, err := parseOperator(expression.Operator)
if err != nil {
return nil, err
}

switch lResult := leftExp.(type) {
case op.CellRefer:
// Right now we assume that there is no expression like `reg - off1 * off2`,
// but if there are, we would need to come up with an idea how to handle it.
// Right now we only cover `off1 + off2` expressions here.
offset, ok := rightExp.(*int)
if !ok {
return nil, fmt.Errorf("invalid type operation")
Expand Down Expand Up @@ -267,11 +281,13 @@ func (expression BinOpExp) Evaluate() (any, error) {
case op.Deref:
return DerefDeref{
lResult,
operation,
rResult,
}, nil
case *int:
return DerefOffset{
lResult,
operation,
rResult,
}, nil
}
Expand Down Expand Up @@ -308,3 +324,14 @@ func ParseIdentifier(value string) (any, error) {

return identifierExp.Evaluate()
}

func parseOperator(op string) (hinter.Operator, error) {
switch op {
case "+":
return hinter.Add, nil
case "*":
return hinter.Mul, nil
default:
return 0, fmt.Errorf("unexpected op: %q", op)
}
}
20 changes: 20 additions & 0 deletions pkg/hintrunner/zero/hintparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ func TestHintParser(t *testing.T) {
},
},
},
{
Parameter: "cast([ap + (-5)] * [ap + (-1)], felt)",
ExpectedCellRefer: nil,
ExpectedResOperander: hinter.BinaryOp{
Operator: hinter.Mul,
Lhs: hinter.ApCellRef(-5),
Rhs: hinter.Deref{
Deref: hinter.ApCellRef(-1),
},
},
},
{
Parameter: "cast([ap] * 3, felt)",
ExpectedCellRefer: nil,
ExpectedResOperander: hinter.BinaryOp{
Operator: hinter.Mul,
Lhs: hinter.ApCellRef(0),
Rhs: hinter.Immediate{0, 0, 0, 3},
},
},
}

for _, test := range testSet {
Expand Down

0 comments on commit 1e87be8

Please sign in to comment.