From 05fb492dbcc1cbf8af4685ea990afdeacc07a8be Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 21 Jan 2025 10:46:27 -0700 Subject: [PATCH] feat: batch requests low level --- src/ape/api/providers.py | 11 +++++ src/ape_ethereum/provider.py | 65 +++++++++++++++++++------- tests/functional/geth/test_provider.py | 19 ++++++++ 3 files changed, 77 insertions(+), 18 deletions(-) diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 34be0ebcfc..b39770df32 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -385,6 +385,17 @@ def stream_request( # type: ignore[empty-body] An iterator of items. """ + @raises_not_implemented + def batch_requests(self, requests: list[dict]) -> Any: + """ + Send batched requests (multiple requests at once) to the RPC provider. + + Args: + requests (list[dict]): The requests to send. + + Returns: The results of each request. + """ + # TODO: In 0.9, delete this method. def get_storage_at(self, *args, **kwargs) -> "HexBytes": warnings.warn( diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index da4a919805..370843af2e 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import ijson # type: ignore -import requests +import requests as requests_lib from eth_pydantic_types import HexBytes from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, is_hex, to_hex @@ -1333,7 +1333,7 @@ def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: parameters = parameters or [] try: - result = self.web3.provider.make_request(RPCEndpoint(rpc), parameters) + response = self.web3.provider.make_request(RPCEndpoint(rpc), parameters) except HTTPError as err: if "method not allowed" in str(err).lower(): raise APINotImplementedError( @@ -1345,28 +1345,36 @@ def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: raise ProviderError(str(err)) from err - if "error" in result: - error = result["error"] + return self._get_result_from_rpc_response(rpc, response) + + def _get_result_from_rpc_response( + self, rpc: str, response: dict, raise_on_failure: bool = True + ) -> Any: + if "error" in response: + error = response["error"] message = ( error["message"] if isinstance(error, dict) and "message" in error else str(error) ) + if raise_on_failure: + if ( + "does not exist/is not available" in str(message) + or re.match(r"[m|M]ethod .*?not found", message) + or message.startswith("Unknown RPC Endpoint") + or "RPC Endpoint has not been implemented" in message + ): + raise APINotImplementedError( + f"RPC method '{rpc}' is not implemented by this node instance." + ) - if ( - "does not exist/is not available" in str(message) - or re.match(r"[m|M]ethod .*?not found", message) - or message.startswith("Unknown RPC Endpoint") - or "RPC Endpoint has not been implemented" in message - ): - raise APINotImplementedError( - f"RPC method '{rpc}' is not implemented by this node instance." - ) + raise ProviderError(message) - raise ProviderError(message) + else: + return message - elif "result" in result: - return result.get("result", {}) + elif "result" in response: + return response.get("result", {}) - return result + return response def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"): if not (uri := self.http_uri): @@ -1375,7 +1383,7 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params} results = ijson.sendable_list() coroutine = ijson.items_coro(results, iter_path) - resp = requests.post(uri, json=payload, stream=True) + resp = requests_lib.post(uri, json=payload, stream=True) resp.raise_for_status() for chunk in resp.iter_content(chunk_size=2**17): @@ -1383,6 +1391,27 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result yield from results del results[:] + def batch_requests(self, requests: list[dict]) -> Any: + if not (uri := self.http_uri): + raise ProviderError("This provider has no HTTP URI and is unable to batch requests.") + + for idx, request in enumerate(requests): + if "jsonrpc" not in request: + request["jsonrpc"] = "2.0" + if "id" not in request: + request["id"] = idx + 1 + + response = requests_lib.post(uri, json=requests) + try: + response.raise_for_status() + except HTTPError as err: + raise ProviderError(str(err)) from err + + return [ + self._get_result_from_rpc_response(uri, r, raise_on_failure=False) + for r in response.json() + ] + def create_access_list( self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None ) -> list[AccessList]: diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index b2f32b8a57..61a82fc76e 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -655,6 +655,25 @@ def test_make_request_not_exists(geth_provider): geth_provider.make_request("ape_thisDoesNotExist") +@geth_process_test +def test_batch_requests(geth_account, geth_contract, geth_provider): + call = geth_contract.myNumber.as_transaction().model_dump() + results = geth_provider.batch_requests( + [ + { + "method": "eth_call", + "params": [{"data": call["data"], "to": geth_contract.address}], + }, + { + "method": "eth_getBalance", + "params": [geth_account.address, "latest"], + }, + ] + ) + for result in results: + assert result.startswith("0x") + + @geth_process_test @pytest.mark.parametrize( "message",