Skip to content

Commit

Permalink
Merge pull request #3069 from onflow/sainati/contract-upgrade-checker
Browse files Browse the repository at this point in the history
Contract Update Validator
  • Loading branch information
turbolent authored Feb 22, 2024
2 parents 99159e8 + 65faf29 commit 5427692
Show file tree
Hide file tree
Showing 7 changed files with 1,740 additions and 64 deletions.
3 changes: 3 additions & 0 deletions runtime/contract_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,9 @@ func TestRuntimeLegacyContractUpdate(t *testing.T) {
OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) {
return accountCodes[location], nil
},
OnGetAccountContractNames: func(_ Address) ([]string, error) {
return []string{"Foo"}, nil
},
OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error {
accountCodes[location] = code
return nil
Expand Down
1 change: 0 additions & 1 deletion runtime/contract_update_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2561,7 +2561,6 @@ func TestRuntimeContractUpdateConformanceChanges(t *testing.T) {
RequireError(t, err)

cause := getSingleContractUpdateErrorCause(t, err, "Test")

assertConformanceMismatchError(t, cause, "Foo")
})

Expand Down
30 changes: 30 additions & 0 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"math"
"math/big"
"sort"
"strconv"
"time"

Expand Down Expand Up @@ -4577,6 +4578,35 @@ func (interpreter *Interpreter) getElaboration(location common.Location) *sema.E
return subInterpreter.Program.Elaboration
}

func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Location]*sema.Elaboration) {

elaborations = map[common.Location]*sema.Elaboration{}

allInterpreters := interpreter.SharedState.allInterpreters

locations := make([]common.Location, 0, len(allInterpreters))

for location := range allInterpreters { //nolint:maprange
locations = append(locations, location)
}

sort.Slice(locations, func(i, j int) bool {
a := locations[i]
b := locations[j]
return a.ID() < b.ID()
})

for _, location := range locations {
elaboration := interpreter.getElaboration(location)
if elaboration == nil {
panic(errors.NewUnexpectedError("missing elaboration for location %s", location))
}
elaborations[location] = elaboration
}

return
}

// GetContractComposite gets the composite value of the contract at the address location.
func (interpreter *Interpreter) GetContractComposite(contractLocation common.AddressLocation) (*CompositeValue, error) {
contractGlobal := interpreter.Globals.Get(contractLocation.Name)
Expand Down
8 changes: 5 additions & 3 deletions runtime/stdlib/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,8 @@ func changeAccountContracts(

// Validate the contract update

inter := invocation.Interpreter

if isUpdate {
oldCode, err := handler.GetAccountContractCode(location)
handleContractUpdateError(err)
Expand Down Expand Up @@ -1618,11 +1620,13 @@ func changeAccountContracts(

var validator UpdateValidator
if legacyContractUpgrade {
validator = NewLegacyContractUpdateValidator(
validator = NewCadenceV042ToV1ContractUpdateValidator(
location,
contractName,
handler,
oldProgram,
program.Program,
inter.AllElaborations(),
)
} else {
validator = NewContractUpdateValidator(
Expand All @@ -1637,8 +1641,6 @@ func changeAccountContracts(
handleContractUpdateError(err)
}

inter := invocation.Interpreter

err = updateAccountContractCode(
handler,
location,
Expand Down
Loading

0 comments on commit 5427692

Please sign in to comment.