diff --git a/cmd/casm-inspect/disasm.go b/cmd/casm-inspect/disasm.go new file mode 100644 index 000000000..11abd26f6 --- /dev/null +++ b/cmd/casm-inspect/disasm.go @@ -0,0 +1,572 @@ +package main + +import ( + "errors" + "fmt" + "os" + "sort" + "strconv" + "strings" + + "github.com/NethermindEth/cairo-vm-go/pkg/assembler" + "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero" + runnerzero "github.com/NethermindEth/cairo-vm-go/pkg/runners/zero" + f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" + "github.com/urfave/cli/v2" +) + +type instruction struct { + data *assembler.Instruction + offset int64 +} + +type programSymbol struct { + name string + offset int64 + kind string +} + +type casmFunc struct { + sym *programSymbol + code []*f.Element +} + +type casmFuncArg struct { + name string + kind string + pos int +} + +func (sym *programSymbol) PrettyName() string { + return strings.TrimPrefix(sym.name, "__main__.") +} + +type funcDisassembly struct { + sym *programSymbol + lines []string +} + +type disasmProgram struct { + pathToFile string + + fileContent []byte + + runprog *runnerzero.Program + rawprog *zero.ZeroProgram + + funcSymbols []programSymbol + pcToFunc map[int64]*programSymbol + + labelSeq int + jumpTargets map[int64]int + casmFuncs []*casmFunc + + disassembled []*funcDisassembly +} + +func (p *disasmProgram) Action(ctx *cli.Context) error { + p.pathToFile = ctx.Args().Get(0) + if p.pathToFile == "" { + return fmt.Errorf("path to cairo zero file not set") + } + + p.jumpTargets = map[int64]int{} + + type step struct { + name string + fn func() error + } + steps := []step{ + {"read file", p.readFileStep}, + {"load program", p.loadProgramStep}, + {"extract func symbols", p.extractFuncSymbolsStep}, + {"prepare casm funcs", p.prepareCasmFuncs}, + {"disassemble", p.disassembleStep}, + {"print", p.printStep}, + } + for _, s := range steps { + if err := s.fn(); err != nil { + return fmt.Errorf("%s: %w", s.name, err) + } + } + + return nil +} + +func (p *disasmProgram) readFileStep() error { + content, err := os.ReadFile(p.pathToFile) + if err != nil { + return err + } + p.fileContent = content + return nil +} + +func (p *disasmProgram) loadProgramStep() error { + // Runnerzero program is used only for a more convenient access to the bytecode. + program, err := runnerzero.LoadCairoZeroProgram(p.fileContent) + if err != nil { + return err + } + p.runprog = program + return nil +} + +func (p *disasmProgram) extractFuncSymbolsStep() error { + // Need to do some repeated efforts to get the original JSON. + programMetadata, err := zero.ZeroProgramFromJSON(p.fileContent) + if err != nil { + return fmt.Errorf("cannot load program: %w", err) + } + p.rawprog = programMetadata + + p.funcSymbols = p.extractFuncSymbols(programMetadata.Identifiers) + + p.pcToFunc = make(map[int64]*programSymbol, len(p.funcSymbols)) + for i := range p.funcSymbols { + sym := &p.funcSymbols[i] + p.pcToFunc[sym.offset] = sym + } + + if len(p.funcSymbols) == 0 { + return errors.New("can't find any functions") + } + return nil +} + +func (p *disasmProgram) prepareCasmFuncs() error { + for i := range p.funcSymbols { + sym := &p.funcSymbols[i] + + funcSize := int64(len(p.runprog.Bytecode)) - sym.offset + if i+1 < len(p.funcSymbols) { + funcSize = p.funcSymbols[i+1].offset - sym.offset + } + + p.casmFuncs = append(p.casmFuncs, &casmFunc{ + sym: sym, + code: p.runprog.Bytecode[sym.offset : sym.offset+funcSize], + }) + } + + return nil +} + +func (p *disasmProgram) disassembleStep() error { + for _, fn := range p.casmFuncs { + d, err := p.disassembleFunc(fn) + if err != nil { + return fmt.Errorf("disassemble %s: %w", fn.sym.PrettyName(), err) + } + p.disassembled = append(p.disassembled, d) + } + return nil +} + +func (p *disasmProgram) printStep() error { + for _, fn := range p.disassembled { + p.printFunc(fn) + } + return nil +} + +func (p *disasmProgram) instSize(inst *assembler.Instruction) int64 { + switch { + case inst.Op1Source == assembler.Imm: + return 2 + case inst.Opcode == assembler.OpCodeCall: + return 2 + default: + return 1 + } +} + +func (p *disasmProgram) disassembleFunc(fn *casmFunc) (*funcDisassembly, error) { + result := &funcDisassembly{ + sym: fn.sym, + } + + // It's easier to resolve jump targets and sources by doing two traversals. + // Create a decoded instructions slice to make these traversals easier to read. + var instructions []instruction + offset := int64(0) + for offset < int64(len(fn.code)) { + decoded, err := assembler.DecodeInstruction(fn.code[offset]) + if err != nil { + return nil, fmt.Errorf("while decoding at offset %d: %w", offset, err) + } + instructions = append(instructions, instruction{ + data: decoded, + offset: offset, + }) + offset += p.instSize(decoded) + } + + // Pass 1: collect jump targetes. + // Since there can be a "jump backwards" destinations, it's easier + // to do it in 2 passes. + for _, inst := range instructions { + if !p.isResolvableRelJump(inst.data) { + continue + } + imm := fn.code[inst.offset+1] + p.internLabel(fn.sym.offset + feltToInt64(imm) + inst.offset) + } + + // Pass 2: actually disassemble instructions now that we have jump targets info (labels). + for _, inst := range instructions { + address := fn.sym.offset + inst.offset + if labelID, ok := p.jumpTargets[address]; ok { + result.lines = append(result.lines, fmt.Sprintf("L%d:", labelID)) + } + line, err := p.formatInstruction(fn, inst) + if err != nil { + return nil, fmt.Errorf("while formatting %s at offset %d: %w", inst.data.Opcode.String(), offset, err) + } + result.lines = append(result.lines, line) + } + + return result, nil +} + +func (p *disasmProgram) printFunc(fn *funcDisassembly) { + var args []string + for _, a := range p.funcArgs(fn.sym) { + args = append(args, fmt.Sprintf("%s: %s", a.name, a.kind)) + } + + var implicitArgs []string + for _, a := range p.funcImplicitArgs(fn.sym) { + implicitArgs = append(implicitArgs, fmt.Sprintf("%s: %s", a.name, a.kind)) + } + implicitArgsString := "" + if len(implicitArgs) != 0 { + implicitArgsString = "{" + strings.Join(implicitArgs, ", ") + "}" + } + + fmt.Printf("// func entry pc=%d\n", fn.sym.offset) + + // The arguments are below current FP. + // At the beginning of a function, FP=AP. + // FP-1 is a return address. + // FP-2 is a previous FP value. + // Then explicit arguments go in a reverse order. + // At last, implicit arguments follow. + // + // Consider a function like f(x, y); + // It's called by pushing x and then y to AP. + // X was pushed before Y, so it's deeper into the stack. + // Therefore, X would be located at [fp-4] while Y can be found at [fp-3]. + fpOffset := -3 + for i := len(args) - 1; i >= 0; i-- { + fmt.Printf("// [fp%+d] => %s\n", fpOffset, args[i]) + fpOffset-- + } + for i := len(implicitArgs) - 1; i >= 0; i-- { + fmt.Printf("// [fp%+d] => %s (implicit arg)\n", fpOffset, implicitArgs[i]) + fpOffset-- + } + + fmt.Printf("func %s%s(%s) ", fn.sym.PrettyName(), implicitArgsString, strings.Join(args, ", ")) + returnType := p.funcReturnType(fn.sym) + if returnType != "()" && returnType != "" { + fmt.Printf("-> %s ", returnType) + } + fmt.Printf("{ \n") + + for _, l := range fn.lines { + fmt.Printf(" %s\n", l) + } + + fmt.Printf("}\n") +} + +func (p *disasmProgram) extractFuncSymbols(identifiers map[string]any) []programSymbol { + result := make([]programSymbol, 0, len(identifiers)) + + for symbolName, symbolData := range identifiers { + symbolData, ok := symbolData.(map[string]any) + if !ok { + continue + } + symbolKind, ok := symbolData["type"].(string) + if !ok { + continue + } + if symbolKind != "function" { + continue + } + offset, ok := symbolData["pc"].(float64) + if !ok { + continue + } + result = append(result, programSymbol{ + name: symbolName, + kind: symbolKind, + offset: int64(offset), + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].offset < result[j].offset + }) + + return result +} + +func (p *disasmProgram) funcReturnType(sym *programSymbol) string { + ret, _ := p.rawprog.Identifiers[sym.name+".Return"].(map[string]any) + return ret["cairo_type"].(string) +} + +func (p *disasmProgram) funcImplicitArgs(sym *programSymbol) []casmFuncArg { + members, ok := p.lookupKeys(p.rawprog.Identifiers, sym.name+".ImplicitArgs", "members").(map[string]any) + if !ok { + return nil + } + return p.collectArgMembers(members) +} + +func (p *disasmProgram) funcArgs(sym *programSymbol) []casmFuncArg { + members, ok := p.lookupKeys(p.rawprog.Identifiers, sym.name+".Args", "members").(map[string]any) + if !ok { + return nil + } + return p.collectArgMembers(members) +} + +func (p *disasmProgram) lookupKeys(m map[string]any, keys ...string) any { + var current any = m + for _, k := range keys { + asMap, ok := current.(map[string]any) + if !ok { + return nil + } + current = asMap[k] + } + return current +} + +func (p *disasmProgram) collectArgMembers(members map[string]any) []casmFuncArg { + var args []casmFuncArg + for argName, argData := range members { + argData := argData.(map[string]any) + args = append(args, casmFuncArg{ + name: argName, + kind: argData["cairo_type"].(string), + pos: int(argData["offset"].(float64)), + }) + } + // members was a map, the order could be randomized. + // Sort by the argument offsets (this gives us the correct declaration order). + sort.SliceStable(args, func(i, j int) bool { + return args[i].pos < args[j].pos + }) + return args +} + +func (p *disasmProgram) isResolvableRelJump(inst *assembler.Instruction) bool { + switch inst.PcUpdate { + case assembler.PcUpdateJumpRel, assembler.PcUpdateJnz: + return inst.Op1Source == assembler.Imm + default: + return false + } +} + +func (p *disasmProgram) formatInstruction(fn *casmFunc, inst instruction) (string, error) { + var comments []string + + bytecode := fn.code + offset := inst.offset + + var buf strings.Builder + + buf.WriteString(" ") // Indentation + + switch inst.data.Opcode { + case assembler.OpCodeRet: + buf.WriteString("ret") + + case assembler.OpCodeCall: + var callee *programSymbol + var callSuffix string + var address int64 + offset := feltToInt64(bytecode[offset+1]) + switch inst.data.PcUpdate { + case assembler.PcUpdateJump: + callSuffix = "abs" + address = offset + case assembler.PcUpdateJumpRel: + callSuffix = "rel" + address = fn.sym.offset + offset + inst.offset + } + callee = p.pcToFunc[address] + fmt.Fprintf(&buf, "call %s %+d", callSuffix, offset) + if callee != nil { + comments = append(comments, fmt.Sprintf("func %s", callee.PrettyName())) + } + + case assembler.OpCodeAssertEq: + buf.WriteString("assert ") + buf.WriteString(p.formatMemoryOperand(inst.data.DstRegister, int(inst.data.OffDest))) + buf.WriteString(" = ") + + if inst.data.Res != assembler.Op1 { + buf.WriteString(p.formatMemoryOperand(inst.data.Op0Register, int(inst.data.OffOp0))) + } + switch inst.data.Res { + case assembler.AddOperands: + buf.WriteString(" + ") + case assembler.MulOperands: + buf.WriteString(" * ") + } + + buf.WriteString(p.formatOperand1(fn, inst.data, offset)) + + if inst.data.Op1Source == assembler.Imm { + // Try to recognize the division. + // So, instead of just this: + // > assert [fp+1] = [fp] * 2894802230932904970957858226476056084498485772265277359978473644908697616385 + // ...the user sees this (note the comment): + // > assert [fp+1] = [fp] * 2894802230932904970957858226476056084498485772265277359978473644908697616385 // div 5 + imm := bytecode[offset+1] + if inst.data.Res == assembler.MulOperands && !imm.IsUint64() { + dividend := f.NewElement(0) + dividend.Inverse(imm) + comments = append(comments, "div "+dividend.String()) + } + } + + case assembler.OpCodeNop: + // Jumps use the same opcode=0. + switch inst.data.PcUpdate { + case assembler.PcUpdateJump: + buf.WriteString("jmp abs " + p.formatOperand1(fn, inst.data, offset)) + case assembler.PcUpdateJumpRel: + buf.WriteString("jmp rel " + p.formatOperand1(fn, inst.data, offset)) + case assembler.PcUpdateJnz: + fmt.Fprintf(&buf, "jmp rel %s if %s != 0", + p.formatOperand1(fn, inst.data, offset), + p.formatMemoryOperand(inst.data.DstRegister, int(inst.data.OffDest))) + default: + if inst.data.Op1Source == assembler.Imm && inst.data.ApUpdate == assembler.AddRes { + comments = append(comments, "alloc_locals") + } + buf.WriteString("nop") + } + + if p.isResolvableRelJump(inst.data) { + imm := bytecode[offset+1] + addr := feltToInt64(imm) + fn.sym.offset + inst.offset + if labelID, ok := p.jumpTargets[addr]; ok { + comments = append(comments, fmt.Sprintf("targets L%d", labelID)) + } + } + + default: + return "", fmt.Errorf("unexpected opcode: %v", inst.data.Opcode) + } + + // ap++ is a valid cairo0 syntax, so use that whether possible. + // For any other ap change use a comment annotation. + switch inst.data.ApUpdate { + case assembler.Add1: + buf.WriteString(", ap++") + case assembler.Add2: + comments = append(comments, "ap += 2") + case assembler.SameAp: + // Nothing to do. + case assembler.AddRes: + if inst.data.Op1Source == assembler.Imm { + imm := bytecode[offset+1] + comments = append(comments, "ap += "+imm.String()) + } else { + comments = append(comments, "ap += $result") + } + default: + return "", fmt.Errorf("unexpected ap update: %v", inst.data.ApUpdate) + } + + buf.WriteString(";") + + if len(comments) != 0 { + fmt.Fprintf(&buf, " // %s", strings.Join(comments, "; ")) + } + + return buf.String(), nil +} + +func (p *disasmProgram) formatOperand1(fn *casmFunc, inst *assembler.Instruction, offset int64) string { + var buf strings.Builder + + switch inst.Op1Source { + case assembler.ApPlusOffOp1: + buf.WriteString(p.formatMemoryOperand(assembler.Ap, int(inst.OffOp1))) + case assembler.FpPlusOffOp1: + buf.WriteString(p.formatMemoryOperand(assembler.Fp, int(inst.OffOp1))) + case assembler.Imm: + imm := fn.code[offset+1] + buf.WriteString(imm.String()) + case assembler.Op0: + // Things like [[fp+10]+20]. + buf.WriteString(p.formatMemoryOperand2(inst.Op0Register, int(inst.OffOp0), int(inst.OffOp1))) + } + + return buf.String() +} + +func (p *disasmProgram) formatMemoryOperand(reg assembler.Register, offset int) string { + var buf strings.Builder + buf.WriteString("[") + buf.WriteString(strings.ToLower(reg.String())) + if offset != 0 { + fmt.Fprintf(&buf, "%+d", offset) + } + buf.WriteString("]") + return buf.String() +} + +func (p *disasmProgram) formatMemoryOperand2(reg assembler.Register, offset, offset2 int) string { + var buf strings.Builder + + buf.WriteString("[[") + buf.WriteString(strings.ToLower(reg.String())) + if offset != 0 { + fmt.Fprintf(&buf, "%+d", offset) + } + buf.WriteString("]") + if offset2 != 0 { + fmt.Fprintf(&buf, "%+d", offset2) + } + buf.WriteString("]") + + return buf.String() +} + +func (p *disasmProgram) internLabel(address int64) int { + if id, ok := p.jumpTargets[address]; ok { + return id + } + id := p.labelSeq + p.labelSeq++ + p.jumpTargets[address] = id + return id +} + +func feltToInt64(felt *f.Element) int64 { + // This would not be correct: int64(felt.Uint64) + // since signed values will reside in more than one 64-bit word. + // + // BigInt().Int64() would not work neither. + // + // String() handles signed values pretty well for our use-case. + // Maybe there is another way to avoid the redundant String()+Parsing? + + s := felt.String() + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0 + } + return v +} diff --git a/cmd/casm-inspect/main.go b/cmd/casm-inspect/main.go new file mode 100644 index 000000000..0c4b043b5 --- /dev/null +++ b/cmd/casm-inspect/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "os" + + "github.com/urfave/cli/v2" +) + +func main() { + disasm := &disasmProgram{} + + app := &cli.App{ + Name: "casm-inspect", + Usage: "casm-inspect [args...]", + Description: "A cairo zero file inspector", + EnableBashCompletion: true, + Suggest: true, + DefaultCommand: "help", + Commands: []*cli.Command{ + { + Name: "disasm", + Usage: "disasm compiled_cairo0.json", + Description: "disassemble the casm from the compiled cairo zero program", + Action: disasm.Action, + }, + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Println(err) + os.Exit(1) + } +}