Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: provider.batch_requests() low-level implementation #2475

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,17 @@ def stream_request( # type: ignore[empty-body]
An iterator of items.
"""

@raises_not_implemented
def batch_requests(self, requests: list[dict]) -> Any:
"""
Send batched requests (multiple requests at once) to the RPC provider.

Args:
requests (list[dict]): The requests to send.

Returns: The results of each request.
"""

# TODO: In 0.9, delete this method.
def get_storage_at(self, *args, **kwargs) -> "HexBytes":
warnings.warn(
Expand Down
65 changes: 47 additions & 18 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

import ijson # type: ignore
import requests
import requests as requests_lib
from eth_pydantic_types import HexBytes
from eth_typing import BlockNumber, HexStr
from eth_utils import add_0x_prefix, is_hex, to_hex
Expand Down Expand Up @@ -1333,7 +1333,7 @@ def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:
def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:
parameters = parameters or []
try:
result = self.web3.provider.make_request(RPCEndpoint(rpc), parameters)
response = self.web3.provider.make_request(RPCEndpoint(rpc), parameters)
except HTTPError as err:
if "method not allowed" in str(err).lower():
raise APINotImplementedError(
Expand All @@ -1345,28 +1345,36 @@ def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:

raise ProviderError(str(err)) from err

if "error" in result:
error = result["error"]
return self._get_result_from_rpc_response(rpc, response)

def _get_result_from_rpc_response(
self, rpc: str, response: dict, raise_on_failure: bool = True
) -> Any:
if "error" in response:
error = response["error"]
message = (
error["message"] if isinstance(error, dict) and "message" in error else str(error)
)
if raise_on_failure:
if (
"does not exist/is not available" in str(message)
or re.match(r"[m|M]ethod .*?not found", message)
or message.startswith("Unknown RPC Endpoint")
or "RPC Endpoint has not been implemented" in message
):
raise APINotImplementedError(
f"RPC method '{rpc}' is not implemented by this node instance."
)

if (
"does not exist/is not available" in str(message)
or re.match(r"[m|M]ethod .*?not found", message)
or message.startswith("Unknown RPC Endpoint")
or "RPC Endpoint has not been implemented" in message
):
raise APINotImplementedError(
f"RPC method '{rpc}' is not implemented by this node instance."
)
raise ProviderError(message)

raise ProviderError(message)
else:
return message

elif "result" in result:
return result.get("result", {})
elif "result" in response:
return response.get("result", {})

return result
return response

def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"):
if not (uri := self.http_uri):
Expand All @@ -1375,14 +1383,35 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result
payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}
results = ijson.sendable_list()
coroutine = ijson.items_coro(results, iter_path)
resp = requests.post(uri, json=payload, stream=True)
resp = requests_lib.post(uri, json=payload, stream=True)
resp.raise_for_status()

for chunk in resp.iter_content(chunk_size=2**17):
coroutine.send(chunk)
yield from results
del results[:]

def batch_requests(self, requests: list[dict]) -> Any:
if not (uri := self.http_uri):
raise ProviderError("This provider has no HTTP URI and is unable to batch requests.")

for idx, request in enumerate(requests):
if "jsonrpc" not in request:
request["jsonrpc"] = "2.0"
if "id" not in request:
request["id"] = idx + 1

response = requests_lib.post(uri, json=requests)
try:
response.raise_for_status()
except HTTPError as err:
raise ProviderError(str(err)) from err

return [
self._get_result_from_rpc_response(uri, r, raise_on_failure=False)
for r in response.json()
]

def create_access_list(
self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None
) -> list[AccessList]:
Expand Down
19 changes: 19 additions & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,25 @@ def test_make_request_not_exists(geth_provider):
geth_provider.make_request("ape_thisDoesNotExist")


@geth_process_test
def test_batch_requests(geth_account, geth_contract, geth_provider):
call = geth_contract.myNumber.as_transaction().model_dump()
results = geth_provider.batch_requests(
[
{
"method": "eth_call",
"params": [{"data": call["data"], "to": geth_contract.address}],
},
{
"method": "eth_getBalance",
"params": [geth_account.address, "latest"],
},
]
)
for result in results:
assert result.startswith("0x")


@geth_process_test
@pytest.mark.parametrize(
"message",
Expand Down
Loading