From 53b730be302644c16437b49d46963f277647483b Mon Sep 17 00:00:00 2001 From: Iskander Sharipov Date: Tue, 30 Jan 2024 17:54:41 +0400 Subject: [PATCH] pkg/parsers/zero: improve Identifiers typing This allows us to do less redundant type assertions while handling the zero prog identifiers. We may go even further, but this step is enough for now to make zero hints handler less awkward. The `.([]zero.Reference)` assertion was not working as the slice contained `[]any` instead. Instead of rewriting the code to handle even more weakly typed `any` values, this patch improves the Identifier objects typing. This makes type assertions and slice conversions unnecessary. --- pkg/hintrunner/zero/zerohint.go | 12 ++------ pkg/parsers/zero/zero.go | 20 +++++++++++-- pkg/parsers/zero/zero_test.go | 51 +++++++++++++++++---------------- pkg/runners/zero/program.go | 36 ++++++----------------- 4 files changed, 55 insertions(+), 64 deletions(-) diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index 2fb7f4ece..eb1ec2a3e 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -64,19 +64,11 @@ func GetParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) if !ok { return nil, nil, fmt.Errorf("missing identifier %s", referenceName) } - identifier, ok := rawIdentifier.(map[string]any) - if !ok { - return nil, nil, fmt.Errorf("wrong structure for identifier") - } - rawReferences, ok := identifier["references"] - if !ok { + if len(rawIdentifier.References) == 0 { return nil, nil, fmt.Errorf("identifier %s should have at least one reference", referenceName) } - references, ok := rawReferences.([]zero.Reference) - if !ok { - return nil, nil, fmt.Errorf("expected a list of references") - } + references := rawIdentifier.References // Go through the references in reverse order to get the one with biggest pc smaller or equal to the hint pc var reference zero.Reference diff --git a/pkg/parsers/zero/zero.go b/pkg/parsers/zero/zero.go index 68225cf2c..925350266 100644 --- a/pkg/parsers/zero/zero.go +++ b/pkg/parsers/zero/zero.go @@ -2,8 +2,9 @@ package zero import ( "encoding/json" - starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "os" + + starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" ) type FlowTrackingData struct { @@ -73,12 +74,27 @@ type ZeroProgram struct { Hints map[string][]Hint `json:"hints"` CompilerVersion string `json:"version"` MainScope string `json:"main_scope"` - Identifiers map[string]any `json:"identifiers"` + Identifiers map[string]*Identifier `json:"identifiers"` ReferenceManager ReferenceManager `json:"reference_manager"` Attributes []AttributeScope `json:"attributes"` DebugInfo DebugInfo `json:"debug_info"` } +type Identifier struct { + FullName string `json:"full_name"` + IdentifierType string `json:"type"` + CairoType string `json:"cairo_type"` + Destination string `json:"destination"` + Pc int `json:"pc"` + Size int `json:"size"` + Members map[string]any `json:"members"` + References []Reference `json:"references"` + + // These fields are listed as any-typed fields before we need them. + Decorators any `json:"decorators"` + Value any `json:"value"` +} + func (z ZeroProgram) MarshalToFile(filepath string) error { // Marshal Output struct into JSON bytes data, err := json.MarshalIndent(z, "", " ") diff --git a/pkg/parsers/zero/zero_test.go b/pkg/parsers/zero/zero_test.go index 41d53b581..0db00ef56 100644 --- a/pkg/parsers/zero/zero_test.go +++ b/pkg/parsers/zero/zero_test.go @@ -1,9 +1,10 @@ package zero import ( - starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" "testing" + starknetParser "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" + "github.com/stretchr/testify/require" ) @@ -220,41 +221,41 @@ func TestIdentifiers(t *testing.T) { require.Equal( t, &ZeroProgram{ - Identifiers: map[string]any{ - "__main__.fib": map[string]any{ - "decorators": make([]any, 0), - "pc": float64(9), - "type": "function", + Identifiers: map[string]*Identifier{ + "__main__.fib": { + Decorators: make([]any, 0), + Pc: 9, + IdentifierType: "function", }, - "__main__.BitwiseBuiltin": map[string]any{ - "destination": "starkware.cairo.common.cairo_builtins.BitwiseBuiltin", - "type": "alias", + "__main__.BitwiseBuiltin": { + Destination: "starkware.cairo.common.cairo_builtins.BitwiseBuiltin", + IdentifierType: "alias", }, - "__main__.fill_array.Args": map[string]any{ - "full_name": "__main__.fill_array.Args", - "members": map[string]any{ + "__main__.fill_array.Args": { + FullName: "__main__.fill_array.Args", + Members: map[string]any{ "array": map[string]any{ "cairo_type": "felt*", "offset": float64(0), }, }, - "size": float64(1), - "type": "struct", + Size: 1, + IdentifierType: "struct", }, - "__main__.fill_array.__temp18": map[string]any{ - "cairo_type": "felt", - "full_name": "__main__.fill_array.__temp18", - "references": []any{ - map[string]any{ - "ap_tracking_data": map[string]any{ - "group": float64(26), - "offset": float64(1), + "__main__.fill_array.__temp18": { + CairoType: "felt", + FullName: "__main__.fill_array.__temp18", + References: []Reference{ + { + ApTrackingData: ApTracking{ + Group: 26, + Offset: 1, }, - "pc": float64(312), - "value": "[cast(ap + (-1), felt*)]", + Pc: 312, + Value: "[cast(ap + (-1), felt*)]", }, }, - "type": "reference", + IdentifierType: "reference", }, }, }, diff --git a/pkg/runners/zero/program.go b/pkg/runners/zero/program.go index 8bee01b4c..3ca5ce40b 100644 --- a/pkg/runners/zero/program.go +++ b/pkg/runners/zero/program.go @@ -1,7 +1,6 @@ package zero import ( - "errors" "fmt" sn "github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet" @@ -56,14 +55,10 @@ func extractEntrypoints(json *zero.ZeroProgram) (map[string]uint64, error) { result := make(map[string]uint64) err := scanIdentifiers( json, - func(key string, typex string, value map[string]any) error { - if typex == "function" { - pc, ok := value["pc"].(float64) - if !ok { - return fmt.Errorf("%s: unknown entrypoint pc", key) - } + func(key string, ident *zero.Identifier) error { + if ident.IdentifierType == "function" { name := key[len(json.MainScope)+1:] - result[name] = uint64(pc) + result[name] = uint64(ident.Pc) } return nil }, @@ -79,14 +74,10 @@ func extractLabels(json *zero.ZeroProgram) (map[string]uint64, error) { labels := make(map[string]uint64, 2) err := scanIdentifiers( json, - func(key string, typex string, value map[string]any) error { - if typex == "label" { - pc, ok := value["pc"].(float64) - if !ok { - return fmt.Errorf("%s: unknown entrypoint pc", key) - } + func(key string, ident *zero.Identifier) error { + if ident.IdentifierType == "label" { name := key[len(json.MainScope)+1:] - labels[name] = uint64(pc) + labels[name] = uint64(ident.Pc) } return nil }, @@ -98,18 +89,9 @@ func extractLabels(json *zero.ZeroProgram) (map[string]uint64, error) { return labels, nil } -func scanIdentifiers( - json *zero.ZeroProgram, - f func(key string, typex string, value map[string]any) error, -) error { - for key, value := range json.Identifiers { - properties := value.(map[string]any) - - typex, ok := properties["type"].(string) - if !ok { - return errors.New("unnespecified identifier type") - } - if err := f(key, typex, properties); err != nil { +func scanIdentifiers(json *zero.ZeroProgram, f func(key string, ident *zero.Identifier) error) error { + for key, ident := range json.Identifiers { + if err := f(key, ident); err != nil { return err } }