From 202640bf164db9a3090011b9d8638cc4ac4afa4f Mon Sep 17 00:00:00 2001 From: antazoey Date: Thu, 4 May 2023 10:40:59 -0500 Subject: [PATCH] feat: show source code lines / traceback from `ReceiptAPI` [APE-708] (#1337) --- codeql-config.yml | 1 - setup.py | 6 +- src/ape/api/compiler.py | 40 ++- src/ape/api/projects.py | 13 +- src/ape/api/transactions.py | 25 +- src/ape/contracts/base.py | 2 +- src/ape/exceptions.py | 108 ++++++- src/ape/managers/chain.py | 46 ++- src/ape/managers/project/manager.py | 26 +- src/ape/managers/project/types.py | 11 +- src/ape/types/__init__.py | 6 +- src/ape/types/trace.py | 438 +++++++++++++++++++++++++- src/ape_ethereum/transactions.py | 35 +- src/ape_geth/provider.py | 2 +- tests/functional/test_geth.py | 47 +-- tests/integration/cli/test_compile.py | 11 +- 16 files changed, 743 insertions(+), 74 deletions(-) diff --git a/codeql-config.yml b/codeql-config.yml index 0521ecd4b7..5498c8e9f5 100644 --- a/codeql-config.yml +++ b/codeql-config.yml @@ -5,4 +5,3 @@ queries: paths: - src - diff --git a/setup.py b/setup.py index 170154fe91..00018162d8 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ ], "lint": [ "black>=23.3.0,<24", # Auto-formatter and linter - "mypy>=0.991", # Static type analyzer + "mypy>=0.991,<1", # Static type analyzer "types-PyYAML", # Needed due to mypy typeshed "types-requests", # Needed due to mypy typeshed "types-setuptools", # Needed due to mypy typeshed @@ -121,8 +121,8 @@ "web3[tester]>=6.0.0,<7", # ** Dependencies maintained by ApeWorX ** "eip712>=0.2.1,<0.3", - "ethpm-types>=0.4.5,<0.5", - "evm-trace>=0.1.0a18", + "ethpm-types>=0.5.0,<0.6", + "evm-trace>=0.1.0a19", ], entry_points={ "console_scripts": ["ape=ape._cli:cli"], diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index 1a5dd02a95..494fdbb7bf 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -1,10 +1,14 @@ from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Dict, Iterator, List, Optional, Set, Tuple -from ethpm_types import ContractType +from ethpm_types import ContractType, HexBytes +from ethpm_types.source import ContractSource +from evm_trace.geth import TraceFrame as EvmTraceFrame +from evm_trace.geth import create_call_node_data from semantic_version import Version # type: ignore from ape.exceptions import ContractLogicError +from ape.types.trace import SourceTraceback, TraceFrame from ape.utils import BaseInterfaceModel, abstractmethod, raises_not_implemented @@ -125,3 +129,35 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: """ return err + + @raises_not_implemented + def trace_source( # type: ignore[empty-body] + self, contract_type: ContractType, trace: Iterator[TraceFrame], calldata: HexBytes + ) -> SourceTraceback: + """ + Get a source-traceback for the given contract type. + The source traceback object contains all the control paths taken in the transaction. + When available, source-code location information is accessible from the object. + + Args: + contract_type (``ContractType``): A contract type that was created by this compiler. + trace (Iterator[:class:`~ape.types.trace.TraceFrame`]): The resulting frames from + executing a function defined in the given contract type. + calldata (``HexBytes``): Calldata passed to the top-level call. + + Returns: + :class:`~ape.types.trace.SourceTraceback` + """ + + def _create_contract_from_call( + self, frame: TraceFrame + ) -> Tuple[Optional[ContractSource], HexBytes]: + evm_frame = EvmTraceFrame(**frame.raw) + data = create_call_node_data(evm_frame) + calldata = data["calldata"] + address = self.provider.network.ecosystem.decode_address(data["address"]) + if address not in self.chain_manager.contracts: + return None, calldata + + called_contract = self.chain_manager.contracts[address] + return self.project_manager._create_contract_source(called_contract), calldata diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index e2277a4a2d..d1b740c276 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -1,7 +1,7 @@ import os.path import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import yaml from ethpm_types import Checksum, ContractType, PackageManifest, Source @@ -171,17 +171,20 @@ def _create_manifest( @classmethod def _create_source_dict( - cls, contract_filepaths: List[Path], base_path: Path + cls, contract_filepaths: Union[Path, List[Path]], base_path: Path ) -> Dict[str, Source]: + filepaths = ( + [contract_filepaths] if isinstance(contract_filepaths, Path) else contract_filepaths + ) source_imports: Dict[str, List[str]] = cls.compiler_manager.get_imports( - contract_filepaths, base_path + filepaths, base_path ) # {source_id: [import_source_ids, ...], ...} source_references: Dict[str, List[str]] = cls.compiler_manager.get_references( imports_dict=source_imports ) # {source_id: [referring_source_ids, ...], ...} source_dict: Dict[str, Source] = {} - for source_path in contract_filepaths: + for source_path in filepaths: key = str(get_relative_path(source_path, base_path)) source_dict[key] = Source( checksum=Checksum( @@ -354,7 +357,7 @@ def compile(self) -> PackageManifest: # Create content, including sub-directories. source_path.parent.mkdir(parents=True, exist_ok=True) source_path.touch() - source_path.write_text(content) + source_path.write_text(str(content)) # Handle import remapping entries indicated in the manifest file target_config_file = project.path / project.config_file_name diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 075ea4481b..19a87d2534 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -19,7 +19,13 @@ TransactionNotFoundError, ) from ape.logging import logger -from ape.types import AddressType, ContractLogContainer, TraceFrame, TransactionSignature +from ape.types import ( + AddressType, + ContractLogContainer, + SourceTraceback, + TraceFrame, + TransactionSignature, +) from ape.utils import BaseInterfaceModel, abstractmethod, cached_property, raises_not_implemented if TYPE_CHECKING: @@ -428,6 +434,15 @@ def return_value(self) -> Any: return output + @property + @raises_not_implemented + def source_traceback(self) -> SourceTraceback: # type: ignore[empty-body] + """ + A pythonic style traceback for both failing and non-failing receipts. + Requires a provider that implements + :meth:~ape.api.providers.ProviderAPI.get_transaction_trace`. + """ + @raises_not_implemented def show_trace(self, verbose: bool = False, file: IO[str] = sys.stdout): """ @@ -445,6 +460,14 @@ def show_gas_report(self, file: IO[str] = sys.stdout): Display a gas report for the calls made in this transaction. """ + @raises_not_implemented + def show_source_traceback(self): + """ + Show a receipt traceback mapping to lines in the source code. + Only works when the contract type and source code are both available, + like in local projects. + """ + def track_gas(self): """ Track this receipt's gas in the on-going session gas-report. diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 5097766dbb..c98dfba619 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -756,7 +756,7 @@ def receipt(self) -> Optional[ReceiptAPI]: if not self._cached_receipt and self.txn_hash: try: receipt = self.chain_manager.get_receipt(self.txn_hash) - except TransactionNotFoundError: + except (TransactionNotFoundError, ValueError): return None self._cached_receipt = receipt diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index 715dbbea4a..ccadbca000 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -1,10 +1,12 @@ import sys +import tempfile import time import traceback from collections import deque from functools import cached_property from inspect import getframeinfo, stack from pathlib import Path +from types import CodeType, TracebackType from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional import click @@ -19,7 +21,7 @@ from ape.api.networks import NetworkAPI from ape.api.providers import SubprocessProvider from ape.api.transactions import TransactionAPI - from ape.types import AddressType, BlockID, SnapshotID, TraceFrame + from ape.types import AddressType, BlockID, SnapshotID, SourceTraceback, TraceFrame class ApeException(Exception): @@ -112,9 +114,24 @@ def __init__( self.txn = txn self.trace = trace self.contract_address = contract_address + self.source_traceback: Optional["SourceTraceback"] = None ex_message = f"({code}) {message}" if code else message + + # Finalizes expected revert message. super().__init__(ex_message) + if not txn: + return + + ape_tb = _get_ape_traceback(self, txn) + if not ape_tb: + return + + self.source_traceback = ape_tb + py_tb = _get_custom_python_traceback(self, txn, ape_tb) + if py_tb: + self.__traceback__ = py_tb + class VirtualMachineError(TransactionError): """ @@ -532,9 +549,10 @@ def handle_ape_exception(err: ApeException, base_paths: List[Path]) -> bool: an exception on the exc-stack. Args: - err (:class:`~ape.exceptions.ApeException`): The transaction error + err (:class:`~ape.exceptions.TransactionError`): The transaction error being handled. - base_paths (List[Path]): Source base paths for allowed frames. + base_paths (Optional[List[Path]]): Optionally include additional + source-path prefixes to use when finding relevant frames. Returns: bool: ``True`` if outputted something. @@ -621,3 +639,87 @@ def name(self) -> str: The name of the error. """ return self.abi.name + + +def _get_ape_traceback(err: TransactionError, txn: "TransactionAPI") -> Optional["SourceTraceback"]: + receipt = txn.receipt + if not receipt: + return None + + try: + ape_traceback = receipt.source_traceback + except (ApeException, NotImplementedError): + return None + + if ape_traceback is None or not len(ape_traceback): + return None + + return ape_traceback + + +def _get_custom_python_traceback( + err: TransactionError, txn: "TransactionAPI", ape_traceback: "SourceTraceback" +) -> Optional[TracebackType]: + # Manipulate python traceback to show lines from contract. + # Help received from Jinja lib: + # https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py#L142 + + _, exc_value, tb = sys.exc_info() + depth = None + idx = len(ape_traceback) - 1 + frames = [] + project_path = txn.project_manager.path.as_posix() + while tb is not None: + if not tb.tb_frame.f_code.co_filename.startswith(project_path): + # Ignore frames outside the project. + # This allows both contract code an scripts to appear. + tb = tb.tb_next + continue + + frames.append(tb) + tb = tb.tb_next + + while (depth is None or depth > 1) and idx >= 0: + exec_item = ape_traceback[idx] + if depth is not None and exec_item.depth >= depth: + # Wait for decreasing depth. + continue + + depth = exec_item.depth + lineno = exec_item.begin_lineno + if lineno is None: + continue + + if exec_item.source_path is None: + # File is not local. Create a temporary file in its place. + # This is necessary for tracebacks to work in Python. + temp_file = tempfile.NamedTemporaryFile(prefix="unknown_contract_") + filename = temp_file.name + else: + filename = exec_item.source_path.as_posix() + + # Raise an exception at the correct line number. + py_code: CodeType = compile( + "\n" * (lineno - 1) + "raise __ape_exception__", filename, "exec" + ) + py_code = py_code.replace(co_name=exec_item.closure.name) + + # Execute the new code to get a new (fake) tb with contract source info. + try: + exec(py_code, {"__ape_exception__": err}, {}) + except BaseException: + fake_tb = sys.exc_info()[2].tb_next # type: ignore + if isinstance(fake_tb, TracebackType): + frames.append(fake_tb) + + idx -= 1 + + if not frames: + return None + + tb_next = None + for tb in frames: + tb.tb_next = tb_next + tb_next = tb + + return frames[-1] diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index a9cc6b3e09..96df14e1c0 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -28,12 +28,13 @@ ChainError, ConversionError, CustomError, + ProviderNotConnectedError, QueryEngineError, UnknownSnapshotError, ) from ape.logging import logger from ape.managers.base import BaseManager -from ape.types import AddressType, BlockID, CallTreeNode, SnapshotID +from ape.types import AddressType, BlockID, CallTreeNode, SnapshotID, SourceTraceback from ape.utils import BaseInterfaceModel, TraceStyles, singledispatchmethod @@ -604,7 +605,7 @@ def __getitem_str(self, account_or_hash: str) -> Union[AccountHistory, ReceiptAP to retrieve it. Args: - transaction_hash (str): The hash of the transaction. + account_or_hash (str): The hash of the transaction. Returns: :class:`~ape.api.transactions.ReceiptAPI`: The receipt. @@ -615,15 +616,22 @@ def __getitem_str(self, account_or_hash: str) -> Union[AccountHistory, ReceiptAP return self._get_account_history(address) except Exception: # Use Transaction hash - receipt = self._hash_to_receipt_map.get(account_or_hash) - if not receipt: - # TODO: Replace with query manager once supports receipts - # instead of transactions. - # TODO: Add timeout = 0 once in API method to not wait for txns - receipt = self.provider.get_receipt(account_or_hash) - self.append(receipt) + try: + return self._get_receipt(account_or_hash) + except Exception: + pass + + # If we get here, we failed to get an account or receipt. + # Raise top-level exception. + raise - return receipt + def _get_receipt(self, txn_hash: str) -> ReceiptAPI: + receipt = self._hash_to_receipt_map.get(txn_hash) + if not receipt: + receipt = self.provider.get_receipt(txn_hash, timeout=0) + self.append(receipt) + + return receipt def append(self, txn_receipt: ReceiptAPI): """ @@ -1283,6 +1291,13 @@ def append_gas(self, *args, **kwargs): if self._test_runner: self._test_runner.gas_tracker.append_gas(*args, **kwargs) + def show_source_traceback( + self, traceback: SourceTraceback, file: Optional[IO[str]] = None, failing: bool = True + ): + console = self._get_console(file) + style = "red" if failing else None + console.print(str(traceback), style=style) + def _get_console(self, file: Optional[IO[str]] = None) -> RichConsole: if not file: return get_console() @@ -1329,11 +1344,16 @@ def history(self) -> TransactionHistory: """ A mapping of transactions from the active session to the account responsible. """ - if self.chain_id not in self._transaction_history_map: + try: + chain_id = self.chain_id + except ProviderNotConnectedError: + return TransactionHistory() # Empty list. + + if chain_id not in self._transaction_history_map: history = TransactionHistory() - self._transaction_history_map[self.chain_id] = history + self._transaction_history_map[chain_id] = history - return self._transaction_history_map[self.chain_id] + return self._transaction_history_map[chain_id] @property def chain_id(self) -> int: diff --git a/src/ape/managers/project/manager.py b/src/ape/managers/project/manager.py index 90273eaca1..9da0c27e7d 100644 --- a/src/ape/managers/project/manager.py +++ b/src/ape/managers/project/manager.py @@ -2,11 +2,11 @@ from pathlib import Path from typing import Dict, Iterable, List, Optional, Type, Union -from ethpm_types import Compiler from ethpm_types import ContractInstance as EthPMContractInstance from ethpm_types import ContractType, PackageManifest, PackageMeta, Source from ethpm_types.contract_type import BIP122_URI from ethpm_types.manifest import PackageName +from ethpm_types.source import Compiler, ContractSource from ethpm_types.utils import AnyUrl from ape.api import DependencyAPI, ProjectAPI @@ -580,6 +580,9 @@ def lookup_path(self, key_contract_path: Union[Path, str]) -> Optional[Path]: ext = path.suffix or None def find_in_dir(dir_path: Path) -> Optional[Path]: + if not dir_path.is_dir(): + return None + for file_path in dir_path.iterdir(): if file_path.is_dir(): result = find_in_dir(file_path) @@ -726,6 +729,27 @@ def track_deployment(self, contract: ContractInstance): destination.write_text(artifact.json()) + def _create_contract_source(self, contract_type: ContractType) -> Optional[ContractSource]: + if not contract_type.source_id: + return None + + src = self._lookup_source(contract_type.source_id) + if not src: + return None + + try: + return ContractSource.create(contract_type, src, self.contracts_folder) + except (ValueError, FileNotFoundError): + return None + + def _lookup_source(self, source_id: str) -> Optional[Source]: + source_path = self.lookup_path(source_id) + if source_path and source_path.is_file(): + result = self.local_project._create_source_dict(source_path, self.contracts_folder) + return next(iter(result.values())) if result else None + + return None + def _get_contract(self, name: str) -> Optional[ContractContainer]: if name in self.contracts: return self.chain_manager.contracts.get_container(self.contracts[name]) diff --git a/src/ape/managers/project/types.py b/src/ape/managers/project/types.py index cc1f19c371..416d3274aa 100644 --- a/src/ape/managers/project/types.py +++ b/src/ape/managers/project/types.py @@ -77,12 +77,13 @@ def _check_needs_compiling(self, source_path: Path) -> bool: cached_source = self.cached_sources[source_id] cached_checksum = cached_source.calculate_checksum() - source_file = self.contracts_folder / source_path - checksum = compute_checksum( - source_file.read_text("utf8").encode("utf8"), - algorithm=cached_checksum.algorithm, - ) + + # ethpm_types strips trailing white space and ensures + # a newline at the end so content so `splitlines()` works. + # We need to do the same here for to prevent the endless recompiling bug. + content = f"{source_file.read_text('utf8').rstrip()}\n" + checksum = compute_checksum(content.encode("utf8"), algorithm=cached_checksum.algorithm) # NOTE: Filter by checksum to only update what's needed return checksum != cached_checksum.hash # Contents changed diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index ea5659fbd1..2f8c9f382b 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -16,13 +16,14 @@ Source, ) from ethpm_types.abi import EventABI +from ethpm_types.source import Closure from hexbytes import HexBytes from pydantic import BaseModel, root_validator, validator from web3.types import FilterParams from ape.types.address import AddressType, RawAddress from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature -from ape.types.trace import CallTreeNode, GasReport, TraceFrame +from ape.types.trace import CallTreeNode, ControlFlow, GasReport, SourceTraceback, TraceFrame from ape.utils import BaseInterfaceModel, cached_property from ape.utils.misc import to_int @@ -295,6 +296,7 @@ def filter(self, event: "ContractEvent", **kwargs) -> List[ContractLog]: "Bytecode", "CallTreeNode", "Checksum", + "Closure", "Compiler", "ContractLog", "ContractLogContainer", @@ -307,6 +309,8 @@ def filter(self, event: "ContractEvent", **kwargs) -> List[ContractLog]: "SignableMessage", "SnapshotID", "Source", + "SourceTraceback", + "ControlFlow", "TraceFrame", "TransactionSignature", ] diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index 227f0ab911..44feba4a61 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -1,6 +1,11 @@ from fnmatch import fnmatch -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from itertools import tee +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional +from ethpm_types import ASTNode, BaseModel, ContractType, HexBytes +from ethpm_types.ast import SourceLocation +from ethpm_types.source import Closure, Content, Function, SourceStatement, Statement from evm_trace.gas import merge_reports from pydantic import Field from rich.table import Table @@ -225,3 +230,434 @@ class TraceFrame(BaseInterfaceModel): """ The raw trace frame from the provider. """ + + +class ControlFlow(BaseModel): + """ + A collection of linear source nodes up until a jump. + """ + + statements: List[Statement] + """ + The source node statements. + """ + + closure: Closure + """ + The defining closure, such as a function or module, of the code sequence. + """ + + source_path: Optional[Path] = None + """ + The path to the local contract file. + Only exists when is from a local contract. + """ + + depth: int + """ + The depth at which this flow was executed, + where 1 is the first calling function. + """ + + def __str__(self) -> str: + return f"{self.source_header}\n{self.format()}" + + def __repr__(self) -> str: + source_name = f" {self.source_path.name} " if self.source_path is not None else " " + representation = f" 0: + representation = f"{representation} num_statements={len(self.statements)}" + + if self.begin_lineno is None: + return f"{representation}>" + + else: + # Include line number info. + end_lineno = self.end_lineno or self.begin_lineno + line_range = ( + f"line {self.begin_lineno}" + if self.begin_lineno == end_lineno + else f"lines {self.begin_lineno}-{end_lineno}" + ) + return f"{representation}, {line_range}>" + + def __getitem__(self, idx: int) -> Statement: + try: + return self.statements[idx] + except IndexError as err: + raise IndexError(f"Statement index '{idx}' out of range.") from err + + def __len__(self) -> int: + return len(self.statements) + + @property + def source_statements(self) -> List[SourceStatement]: + """ + All statements coming directly from a contract's source. + Excludes implicit-compiler statements. + """ + return [x for x in self.statements if isinstance(x, SourceStatement)] + + @property + def begin_lineno(self) -> Optional[int]: + """ + The first line number in the sequence. + """ + stmts = self.source_statements + return stmts[0].begin_lineno if stmts else None + + @property + def ws_begin_lineno(self) -> Optional[int]: + """ + The first line number in the sequence, including whitespace. + """ + stmts = self.source_statements + return stmts[0].ws_begin_lineno if stmts else None + + @property + def line_numbers(self) -> List[int]: + """ + The list of all line numbers as part of this node. + """ + + if self.begin_lineno is None: + return [] + + elif self.end_lineno is None: + return [self.begin_lineno] + + return list(range(self.begin_lineno, self.end_lineno + 1)) + + @property + def content(self) -> Content: + result: Dict[int, str] = {} + for node in self.source_statements: + result = {**result, **node.content.__root__} + + return Content(__root__=result) + + @property + def source_header(self) -> str: + result = "" + if self.source_path is not None: + result += f"File {self.source_path}, in " + + result += f"{self.closure.name}" + return result.strip() + + @property + def end_lineno(self) -> Optional[int]: + """ + The last line number. + """ + stmts = self.source_statements + return stmts[-1].end_lineno if stmts else None + + def extend(self, location: SourceLocation, ws_start: Optional[int] = None): + """ + Extend this node's content with other content that follows it directly. + + Raises: + ValueError: When there is a gap in content. + + Args: + location (SourceLocation): The location of the content, in the form + (lineno, col_offset, end_lineno, end_coloffset). + ws_start (Optional[int]): Optionally provide a white-space starting point + to back-fill. + """ + + if ws_start is not None and ws_start > location[0]: + # No new lines. + return + + function = self.closure + if not isinstance(function, Function): + # No source code supported for closure type. + return + + # NOTE: Use non-ws prepending location to fetch AST nodes. + asts = function.get_content_asts(location) + if not asts: + return + + location = ( + (ws_start, location[1], location[2], location[3]) if ws_start is not None else location + ) + content = function.get_content(location) + start = ( + max(location[0], self.end_lineno + 1) if self.end_lineno is not None else location[0] + ) + end = location[0] + len(content) - 1 + if end < start: + # No new lines. + return + + elif start - end > 1: + raise ValueError( + "Cannot extend when gap in lines > 1. " + "If because of whitespace lines, must include it the given content." + ) + + content_start = len(content) - (end - start) - 1 + new_lines = {no: ln.rstrip() for no, ln in content.items() if no >= content_start} + if new_lines: + # Add the next statement in this sequence. + content = Content(__root__=new_lines) + statement = SourceStatement(asts=asts, content=content) + self.statements.append(statement) + + else: + # Add ASTs to latest statement. + self.source_statements[-1].asts.extend(asts) + + def format(self, use_arrow: bool = True) -> str: + """ + Format this trace node into a string presentable to the user. + """ + + # NOTE: Only show last 2 statements. + relevant_stmts = self.statements[-2:] + content = "" + end_lineno = self.content.end_lineno + + for stmt in relevant_stmts: + for lineno, line in getattr(stmt, "content", {}).items(): + if not content and not line.strip(): + # Prevent starting on whitespace. + continue + + if content: + # Add newline to carry over from last. + content = f"{content.rstrip()}\n" + + space = " " if lineno < end_lineno or not use_arrow else " --> " + content = f"{content}{space}{lineno} {line}" + + return content + + @property + def next_statement(self) -> Optional[SourceStatement]: + """ + Returns the next statement that _would_ execute + if the program were to progress to the next line. + """ + + # Check for more statements that _could_ execute. + if not self.statements: + return None + + last_stmt = self.source_statements[-1] + function = self.closure + if not isinstance(function, Function): + return None + + rest_asts = [a for a in function.ast.children if a.lineno > last_stmt.end_lineno] + if not rest_asts: + # At the end of a function. + return None + + # Filter out to only the ASTs for the next statement. + next_stmt_start = min(rest_asts, key=lambda x: x.lineno).lineno + next_stmt_asts = [a for a in rest_asts if a.lineno == next_stmt_start] + content_dict = {} + for ast in next_stmt_asts: + sub_content = function.get_content(ast.line_numbers) + content_dict = {**sub_content.__root__} + + if not content_dict: + return None + + sorted_dict = {k: content_dict[k] for k in sorted(content_dict)} + content = Content(__root__=sorted_dict) + return SourceStatement(asts=next_stmt_asts, content=content) + + +class SourceTraceback(BaseModel): + """ + A full execution traceback including source code. + """ + + __root__: List[ControlFlow] + + @classmethod + def create(cls, contract_type: ContractType, trace: Iterator[TraceFrame], data: HexBytes): + source_id = contract_type.source_id + if not source_id: + return cls.parse_obj([]) + + trace, second_trace = tee(trace) + if second_trace: + accessor = next(second_trace, None) + if not accessor: + return cls.parse_obj([]) + else: + return cls.parse_obj([]) + + ext = f".{source_id.split('.')[-1]}" + if ext not in accessor.compiler_manager.registered_compilers: + return cls.parse_obj([]) + + compiler = accessor.compiler_manager.registered_compilers[ext] + try: + return compiler.trace_source(contract_type, trace, data) + except NotImplementedError: + return cls.parse_obj([]) + + def __str__(self) -> str: + return self.format() + + def __repr__(self) -> str: + return f"" + + def __len__(self) -> int: + return len(self.__root__) + + def __iter__(self): + yield from self.__root__ + + def __getitem__(self, idx: int) -> ControlFlow: + try: + return self.__root__[idx] + except IndexError as err: + raise IndexError(f"Control flow index '{idx}' out of range.") from err + + def __setitem__(self, key, value): + return self.__root__.__setitem__(key, value) + + def append(self, __object) -> None: + self.__root__.append(__object) + + def extend(self, __iterable) -> None: + if not isinstance(__iterable, SourceTraceback): + raise TypeError("Can only extend another traceback object.") + + self.__root__.extend(__iterable.__root__) + + @property + def last(self) -> Optional[ControlFlow]: + return self.__root__[-1] if len(self.__root__) else None + + @property + def execution(self) -> List[ControlFlow]: + return list(self.__root__) + + def format(self) -> str: + if not len(self.__root__): + # No calls. + return "" + + header = "Traceback (most recent call last)" + indent = " " + last_depth = None + segments = [] + for control_flow in reversed(self.__root__): + if last_depth is None or control_flow.depth == last_depth - 1: + last_depth = control_flow.depth + segment = f"{indent}{control_flow.source_header}\n{control_flow.format()}" + + # Try to include next statement for display purposes. + next_stmt = control_flow.next_statement + if next_stmt: + if ( + next_stmt.begin_lineno is not None + and control_flow.end_lineno is not None + and next_stmt.begin_lineno > control_flow.end_lineno + 1 + ): + # Include whitespace. + for ws_no in range(control_flow.end_lineno + 1, next_stmt.begin_lineno): + function = control_flow.closure + if not isinstance(function, Function): + continue + + ws = function.content[ws_no] + segment = f"{segment}\n {ws_no} {ws}".rstrip() + + for no, line in next_stmt.content.items(): + segment = f"{segment}\n {no} {line}".rstrip() + + segments.append(segment) + + builder = "" + for idx, segment in enumerate(reversed(segments)): + builder = f"{builder}\n{segment}" + + if idx < len(segments) - 1: + builder = f"{builder}\n" + + return f"{header}{builder}" + + def add_jump( + self, + location: SourceLocation, + function: Function, + depth: int, + source_path: Optional[Path] = None, + ): + """ + Add an execution sequence from a jump. + + Args: + location (``SourceLocation``): The location to add. + function (``Function``): The function executing. + source_path (Optional[``Path``]): The path of the source file. + depth (int): The depth of the function call in the call tree. + """ + + # Exclude signature ASTs. + asts = function.get_content_asts(location) + content = function.get_content(location) + if not asts or not content: + return + + Statement.update_forward_refs() + ControlFlow.update_forward_refs() + self._add(asts, content, function, depth, source_path=source_path) + + def extend_last(self, location: SourceLocation): + """ + Extend the last node with more content. + + Args: + location (``SourceLocation``): The location of the new content. + """ + + if not self.last: + raise ValueError( + "`progress()` should only be called when " + "there is at least 1 ongoing execution trail." + ) + + start = ( + 1 + if self.last is not None and self.last.end_lineno is None + else self.last.end_lineno + 1 + ) + self.last.extend(location, ws_start=start) + + def add_builtin_jump(self, name: str, _type: str, compiler_name: str): + closure = Closure(name=name) + depth = self.last.depth - 1 if self.last else 0 + statement = Statement(type=_type) + flow = ControlFlow( + statements=[statement], + closure=closure, + source_path=Path("") / compiler_name, + depth=depth, + ) + self.append(flow) + + def _add( + self, + asts: List[ASTNode], + content: Content, + function: Function, + depth: int, + source_path: Optional[Path] = None, + ): + statement = SourceStatement(asts=asts, content=content) + exec_sequence = ControlFlow( + statements=[statement], source_path=source_path, closure=function, depth=depth + ) + self.append(exec_sequence) diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 6b1f8b851d..9aa572b627 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -9,14 +9,14 @@ serializable_unsigned_transaction_from_dict, ) from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int -from ethpm_types import HexBytes +from ethpm_types import ContractType, HexBytes from ethpm_types.abi import EventABI, MethodABI from pydantic import BaseModel, Field, root_validator, validator from ape.api import ReceiptAPI, TransactionAPI from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError -from ape.types import CallTreeNode, ContractLog, ContractLogContainer +from ape.types import CallTreeNode, ContractLog, ContractLogContainer, SourceTraceback from ape.utils import cached_property @@ -149,23 +149,37 @@ def failed(self) -> bool: @cached_property def call_tree(self) -> Optional[CallTreeNode]: + if self.receiver: + return self.provider.get_call_tree(self.txn_hash) + + # Not an function invoke + return None + + @cached_property + def contract_type(self) -> Optional[ContractType]: if not self.receiver: - # Not an function invoke return None - return self.provider.get_call_tree(self.txn_hash) + return self.chain_manager.contracts.get(self.receiver) @cached_property def method_called(self) -> Optional[MethodABI]: - contract_type = self.chain_manager.contracts.get(self.receiver) - if not contract_type: + if not self.contract_type: return None method_id = self.data[:4] - if method_id not in contract_type.methods: + if method_id not in self.contract_type.methods: return None - return contract_type.methods[method_id] + return self.contract_type.methods[method_id] + + @property + def source_traceback(self) -> SourceTraceback: + contract_type = self.contract_type + if not contract_type: + return SourceTraceback.parse_obj([]) + + return SourceTraceback.create(contract_type, self.trace, HexBytes(self.data)) def raise_for_status(self): if self.gas_limit is not None and self.ran_out_of_gas: @@ -227,6 +241,11 @@ def show_gas_report(self, file: IO[str] = sys.stdout): self.chain_manager._reports.show_gas(call_tree, file=file) + def show_source_traceback(self, file: IO[str] = sys.stdout): + self.chain_manager._reports.show_source_traceback( + self.source_traceback, file=file, failing=self.failed + ) + def decode_logs( self, abi: Optional[ diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py index c23ef9b01c..77d7101d1b 100644 --- a/src/ape_geth/provider.py +++ b/src/ape_geth/provider.py @@ -184,7 +184,7 @@ def disconnect(self): self._clean() def _clean(self): - if self.data_dir.exists(): + if self.data_dir.is_dir(): shutil.rmtree(self.data_dir) diff --git a/tests/functional/test_geth.py b/tests/functional/test_geth.py index b57d82b2de..b487f4097a 100644 --- a/tests/functional/test_geth.py +++ b/tests/functional/test_geth.py @@ -112,6 +112,16 @@ def geth_receipt(contract_with_call_depth_geth, owner, geth_provider): return contract_with_call_depth_geth.methodWithoutArguments(sender=owner) +@pytest.fixture +def geth_vyper_contract(owner, vyper_contract_container, geth_provider): + return owner.deploy(vyper_contract_container, 0) + + +@pytest.fixture +def geth_vyper_receipt(geth_vyper_contract, owner): + return geth_vyper_contract.setNumber(44, sender=owner) + + @geth_process_test def test_uri(geth_provider): assert geth_provider.uri == GETH_URI @@ -129,38 +139,31 @@ def test_uri_uses_value_from_config(geth_provider, temp_config): geth_provider.provider_settings = settings -def test_tx_revert(accounts, sender, vyper_contract_container): +def test_tx_revert(accounts, sender, geth_vyper_contract, owner): # 'sender' is not the owner so it will revert (with a message) - contract = accounts.test_accounts[-1].deploy(vyper_contract_container, 0) with pytest.raises(ContractLogicError, match="!authorized"): - contract.setNumber(5, sender=sender) + geth_vyper_contract.setNumber(5, sender=sender) -def test_revert_no_message(accounts, vyper_contract_container): +def test_revert_no_message(accounts, geth_vyper_contract, owner): # The Contract raises empty revert when setting number to 5. expected = "Transaction failed." # Default message - owner = accounts.test_accounts[-2] - contract = owner.deploy(vyper_contract_container, 0) with pytest.raises(ContractLogicError, match=expected): - contract.setNumber(5, sender=owner) + geth_vyper_contract.setNumber(5, sender=owner) @geth_process_test -def test_contract_interaction(geth_provider, vyper_contract_container, accounts): - owner = accounts.test_accounts[-2] - contract = owner.deploy(vyper_contract_container, 0) - contract.setNumber(102, sender=owner) - assert contract.myNumber() == 102 +def test_contract_interaction(owner, geth_vyper_contract): + geth_vyper_contract.setNumber(102, sender=owner) + assert geth_vyper_contract.myNumber() == 102 @geth_process_test -def test_get_call_tree(geth_provider, vyper_contract_container, accounts): - owner = accounts.test_accounts[-3] - contract = owner.deploy(vyper_contract_container, 0) - receipt = contract.setNumber(10, sender=owner) +def test_get_call_tree(geth_vyper_contract, owner, geth_provider): + receipt = geth_vyper_contract.setNumber(10, sender=owner) result = geth_provider.get_call_tree(receipt.txn_hash) expected = ( - rf"{contract.address}.0x3fb5c1cb" + rf"{geth_vyper_contract.address}.0x3fb5c1cb" r"\(0x000000000000000000000000000000000000000000000000000000000000000a\) \[\d+ gas\]" ) actual = repr(result) @@ -192,13 +195,11 @@ def test_repr_on_live_network_and_disconnected(networks): @geth_process_test -def test_get_logs(geth_provider, accounts, vyper_contract_container): - owner = accounts.test_accounts[-4] - contract = owner.deploy(vyper_contract_container, 0) - contract.setNumber(101010, sender=owner) - actual = contract.NumberChange[-1] +def test_get_logs(geth_vyper_contract, owner): + geth_vyper_contract.setNumber(101010, sender=owner) + actual = geth_vyper_contract.NumberChange[-1] assert actual.event_name == "NumberChange" - assert actual.contract_address == contract.address + assert actual.contract_address == geth_vyper_contract.address assert actual.event_arguments["newNum"] == 101010 diff --git a/tests/integration/cli/test_compile.py b/tests/integration/cli/test_compile.py index 98d9c9bb1e..331bd9ed8f 100644 --- a/tests/integration/cli/test_compile.py +++ b/tests/integration/cli/test_compile.py @@ -135,12 +135,13 @@ def test_compile_when_contract_type_collision(ape_cli, runner, project, clean_ca @skip_projects_except("multiple-interfaces") def test_compile_when_source_contains_return_characters(ape_cli, runner, project, clean_cache): - # NOTE: This tests a bugfix where a source file contained return-characters - # and that triggered endless re-compiles because it technically contains extra - # bytes than the ones that show up in the text. - - # Change the contents of a file to contain the '\r' character. + """ + This tests a bugfix where a source file contained return-characters + and that triggered endless re-compiles because it technically contains extra + bytes than the ones that show up in the text. + """ source_path = project.contracts_folder / "Interface.json" + # Change the contents of a file to contain the '\r' character. modified_source_text = f"{source_path.read_text()}\r" source_path.unlink() source_path.touch()