Skip to content

Commit

Permalink
fill in missing rpc client methods (#18851)
Browse files Browse the repository at this point in the history
* fill in missing rpc client methods

* no cover

(cherry picked from commit 96a8be8)
  • Loading branch information
altendky committed Nov 15, 2024
1 parent fe80de7 commit 359ff89
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
57 changes: 56 additions & 1 deletion chia/_tests/rpc/test_rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,30 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Optional
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional

import pytest

from chia._tests.util.misc import Marks, RecordingWebServer, datacases
from chia.rpc.rpc_client import ResponseFailureError, RpcClient
from chia.rpc.rpc_server import RpcServer
from chia.server.outbound_message import NodeType
from chia.util.ints import uint16

non_fetch_client_methods = {
RpcClient.create,
RpcClient.create_as_context,
RpcClient.fetch,
RpcClient.close,
RpcClient.await_closed,
}

client_fetch_methods = {
attribute
for name, attribute in vars(RpcClient).items()
if callable(attribute) and attribute not in non_fetch_client_methods and not name.startswith("__")
}


@dataclass
class InvalidCreateCase:
Expand Down Expand Up @@ -79,3 +95,42 @@ async def test_failure_exception(
await rpc_client.fetch(path="/table", request_json={"response": expected_response})

assert exception_info.value.response == expected_response


def test_client_standard_endpoints_match_server() -> None:
# NOTE: this test assumes that the client method names should match the server
# route names
client_method_names = {method.__name__ for method in client_fetch_methods}
server_route_names = {method.lstrip("/") for method in RpcServer._routes.keys()}
assert client_method_names == server_route_names


@pytest.mark.anyio
@pytest.mark.parametrize("client_method", client_fetch_methods)
async def test_client_fetch_methods(
client_method: Callable[..., Awaitable[object]],
rpc_client: RpcClient,
recording_web_server: RecordingWebServer,
) -> None:
# NOTE: this test assumes that the client method names should match the server
# route names

parameters: dict[Callable[..., Awaitable[object]], dict[str, object]] = {
RpcClient.open_connection: {"host": "", "port": 0},
RpcClient.close_connection: {"node_id": b""},
RpcClient.get_connections: {"node_type": NodeType.FULL_NODE},
}

try:
await client_method(rpc_client, **parameters.get(client_method, {}))
except Exception as exception:
if client_method is RpcClient.get_connections and isinstance(exception, KeyError):
pass
else: # pragma: no cover
# this case will fail the test so not normally executed
raise

[request] = recording_web_server.requests
assert request.content_type == "application/json"
assert request.method == "POST"
assert request.path == f"/{client_method.__name__}"
9 changes: 9 additions & 0 deletions chia/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ async def stop_node(self) -> Dict:
async def healthz(self) -> Dict:
return await self.fetch("healthz", {})

async def get_network_info(self) -> dict:
return await self.fetch("get_network_info", {})

async def get_routes(self) -> dict:
return await self.fetch("get_routes", {})

async def get_version(self) -> dict:
return await self.fetch("get_version", {})

def close(self) -> None:
self.closing_task = asyncio.create_task(self.session.close())

Expand Down
23 changes: 14 additions & 9 deletions chia/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from dataclasses import dataclass
from pathlib import Path
from ssl import SSLContext
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, List, Optional, TypeVar
from types import MethodType
from typing import Any, AsyncIterator, Awaitable, Callable, ClassVar, Dict, Generic, List, Optional, TypeVar

from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from typing_extensions import Protocol, final
Expand Down Expand Up @@ -236,14 +237,7 @@ def listen_port(self) -> uint16:
def _get_routes(self) -> Dict[str, Endpoint]:
return {
**self.rpc_api.get_routes(),
"/get_network_info": self.get_network_info,
"/get_connections": self.get_connections,
"/open_connection": self.open_connection,
"/close_connection": self.close_connection,
"/stop_node": self.stop_node,
"/get_routes": self.get_routes,
"/get_version": self.get_version,
"/healthz": self.healthz,
**{path: MethodType(handler, self) for path, handler in self._routes.items()},
}

async def get_routes(self, request: Dict[str, Any]) -> EndpointResult:
Expand Down Expand Up @@ -409,6 +403,17 @@ async def inner() -> None:

self.daemon_connection_task = asyncio.create_task(inner())

_routes: ClassVar[dict[str, Callable[..., Awaitable[object]]]] = {
"/get_network_info": get_network_info,
"/get_connections": get_connections,
"/open_connection": open_connection,
"/close_connection": close_connection,
"/stop_node": stop_node,
"/get_routes": get_routes,
"/get_version": get_version,
"/healthz": healthz,
}


async def start_rpc_server(
rpc_api: _T_RpcApiProtocol,
Expand Down

0 comments on commit 359ff89

Please sign in to comment.