diff --git a/pyk/src/pyk/rpc/rpc.py b/pyk/src/pyk/rpc/rpc.py index b0e1a47c9cd..fa0f5c65e76 100644 --- a/pyk/src/pyk/rpc/rpc.py +++ b/pyk/src/pyk/rpc/rpc.py @@ -2,9 +2,11 @@ import json import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass from functools import partial from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, NamedTuple from typing_extensions import Protocol @@ -71,37 +73,80 @@ class JsonRpcMethod(Protocol): def __call__(self, **kwargs: Any) -> Any: ... -class JsonRpcRequestHandler(BaseHTTPRequestHandler): - methods: dict[str, JsonRpcMethod] +class JsonRpcRequest(NamedTuple): + id: str | int + method: str + params: Any - def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None: - self.methods = methods - super().__init__(*args, **kwargs) - def send_json_error(self, code: int, message: str, id: Any = None) -> None: - error_dict = { +class JsonRpcBatchRequest(NamedTuple): + requests: tuple[JsonRpcRequest] + + +class JsonRpcResult(ABC): + + @abstractmethod + def encode(self) -> bytes: ... + + +@dataclass(frozen=True) +class JsonRpcError(JsonRpcResult): + + code: int + message: str + id: str | int | None + + def to_json(self) -> dict[str, Any]: + return { 'jsonrpc': JsonRpcServer.JSONRPC_VERSION, 'error': { - 'code': code, - 'message': message, + 'code': self.code, + 'message': self.message, }, - 'id': id, + 'id': self.id, } - error_bytes = json.dumps(error_dict).encode('ascii') - self.set_response() - self.wfile.write(error_bytes) - def send_json_success(self, result: Any, id: Any) -> None: - response_dict = { + def encode(self) -> bytes: + return json.dumps(self.to_json()).encode('ascii') + + +@dataclass(frozen=True) +class JsonRpcSuccess(JsonRpcResult): + payload: Any + id: Any + + def to_json(self) -> dict[str, Any]: + return { 'jsonrpc': JsonRpcServer.JSONRPC_VERSION, - 'result': result, - 'id': id, + 'result': self.payload, + 'id': self.id, } - response_bytes = json.dumps(response_dict).encode('ascii') - self.set_response() + + def encode(self) -> bytes: + return json.dumps(self.to_json()).encode('ascii') + + +@dataclass(frozen=True) +class JsonRpcBatchResult(JsonRpcResult): + results: tuple[JsonRpcError | JsonRpcSuccess, ...] + + def encode(self) -> bytes: + return json.dumps([result.to_json() for result in self.results]).encode('ascii') + + +class JsonRpcRequestHandler(BaseHTTPRequestHandler): + methods: dict[str, JsonRpcMethod] + + def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None: + self.methods = methods + super().__init__(*args, **kwargs) + + def _send_response(self, response: JsonRpcResult) -> None: + self.send_response_headers() + response_bytes = response.encode() self.wfile.write(response_bytes) - def set_response(self) -> None: + def send_response_headers(self) -> None: self.send_response(200) self.send_header('Content-type', 'text/html') self.end_headers() @@ -113,37 +158,36 @@ def do_POST(self) -> None: # noqa: N802 content = self.rfile.read(int(content_len)) _LOGGER.debug(f'Received bytes: {content.decode()}') - request: dict + request: dict[str, Any] | list[dict[str, Any]] try: request = json.loads(content) _LOGGER.info(f'Received request: {request}') except json.JSONDecodeError: _LOGGER.warning(f'Invalid JSON: {content.decode()}') - self.send_json_error(-32700, 'Invalid JSON') + json_error = JsonRpcError(-32700, 'Invalid JSON', None) + self._send_response(json_error) return - required_fields = ['jsonrpc', 'method', 'id'] - for field in required_fields: - if field not in request: - _LOGGER.warning(f'Missing required field "{field}": {request}') - self.send_json_error(-32600, f'Invalid request: missing field "{field}"', request.get('id', None)) - return + response: JsonRpcResult + if isinstance(request, list): + response = self._batch_request(request) + else: + response = self._single_request(request) - jsonrpc_version = request['jsonrpc'] - if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION: - _LOGGER.warning(f'Bad JSON-RPC version: {jsonrpc_version}') - self.send_json_error(-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request['id']) - return + self._send_response(response) - method_name = request['method'] - if method_name not in self.methods: - _LOGGER.warning(f'Method not found: {method_name}') - self.send_json_error(-32601, f'Method "{method_name}" not found.', request['id']) - return + def _batch_request(self, requests: list[dict[str, Any]]) -> JsonRpcBatchResult: + return JsonRpcBatchResult(tuple(self._single_request(request) for request in requests)) + def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSuccess: + validation_result = self._validate_request(request) + if isinstance(validation_result, JsonRpcError): + return validation_result + + id, method_name, params = validation_result method = self.methods[method_name] - params = request.get('params', None) _LOGGER.info(f'Executing method {method_name}') + result: Any if type(params) is dict: result = method(**params) elif type(params) is list: @@ -151,6 +195,26 @@ def do_POST(self) -> None: # noqa: N802 elif params is None: result = method() else: - self.send_json_error(-32602, 'Unrecognized method parameter format.') + return JsonRpcError(-32602, 'Unrecognized method parameter format.', id) _LOGGER.debug(f'Got response {result}') - self.send_json_success(result, request['id']) + return JsonRpcSuccess(result, id) + + def _validate_request(self, request_dict: Any) -> JsonRpcRequest | JsonRpcError: + required_fields = ['jsonrpc', 'method', 'id'] + for field in required_fields: + if field not in request_dict: + return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None)) + + jsonrpc_version = request_dict['jsonrpc'] + if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION: + return JsonRpcError( + -32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None) + ) + + method_name = request_dict['method'] + if method_name not in self.methods.keys(): + return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None)) + + return JsonRpcRequest( + method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None) + ) diff --git a/pyk/src/tests/integration/test_json_rpc.py b/pyk/src/tests/integration/test_json_rpc.py index c259e88b12b..61a5367df1b 100644 --- a/pyk/src/tests/integration/test_json_rpc.py +++ b/pyk/src/tests/integration/test_json_rpc.py @@ -1,7 +1,10 @@ from __future__ import annotations +import json +from http.client import HTTPConnection from threading import Thread from time import sleep +from typing import TYPE_CHECKING from pyk.cterm import CTerm from pyk.kast.inner import KApply, KSequence, KSort, KToken @@ -11,6 +14,9 @@ from pyk.rpc.rpc import JsonRpcServer, ServeRpcOptions from pyk.testing import KRunTest +if TYPE_CHECKING: + from typing import Any + class StatefulKJsonRpcServer(JsonRpcServer): krun: KRun @@ -67,7 +73,7 @@ def exec_add(self) -> int: return int(k_cell.token) -class TestJsonRPCServer(KRunTest): +class TestJsonKRPCServer(KRunTest): KOMPILE_DEFINITION = """ module JSON-RPC-EXAMPLE-SYNTAX imports INT-SYNTAX @@ -133,3 +139,121 @@ def wait_until_ready() -> None: server.shutdown() thread.join() + + +class StatefulJsonRpcServer(JsonRpcServer): + + x: int = 42 + y: int = 43 + + def __init__(self, options: ServeRpcOptions) -> None: + super().__init__(options) + + self.register_method('get_x', self.exec_get_x) + self.register_method('get_y', self.exec_get_y) + self.register_method('set_x', self.exec_set_x) + self.register_method('set_y', self.exec_set_y) + self.register_method('add', self.exec_add) + + def exec_get_x(self) -> int: + return self.x + + def exec_get_y(self) -> int: + return self.y + + def exec_set_x(self, n: int) -> None: + self.x = n + + def exec_set_y(self, n: int) -> None: + self.y = n + + def exec_add(self) -> int: + return self.x + self.y + + +class TestJsonRPCServer(KRunTest): + + def test_json_rpc_server(self) -> None: + server = StatefulJsonRpcServer(ServeRpcOptions({'port': 0})) + + def run_server() -> None: + server.serve() + + def wait_until_server_is_up() -> None: + while True: + try: + server.port() + return + except ValueError: + sleep(0.1) + + thread = Thread(target=run_server) + thread.start() + + wait_until_server_is_up() + + http_client = HTTPConnection('localhost', server.port()) + rpc_client = SimpleClient(http_client) + + def wait_until_ready() -> None: + while True: + try: + rpc_client.request('get_x', []) + except ConnectionRefusedError: + sleep(0.1) + continue + break + + wait_until_ready() + + rpc_client.request('set_x', [123]) + res = rpc_client.request('get_x') + assert res == 123 + + rpc_client.request('set_y', [456]) + res = rpc_client.request('get_y') + assert res == 456 + + res = rpc_client.request('add', []) + assert res == (123 + 456) + + res = rpc_client.batch_request(('set_x', [1]), ('set_y', [2]), ('add', [])) + assert len(res) == 3 + assert res[2]['result'] == 1 + 2 + + server.shutdown() + thread.join() + + +class SimpleClient: + + client: HTTPConnection + _request_id: int = 0 + + def __init__(self, client: HTTPConnection) -> None: + self.client = client + + def request_id(self) -> int: + self._request_id += 1 + return self._request_id + + def request(self, method: str, params: Any = None) -> Any: + body = json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()}) + + self.client.request('POST', '/', body) + response = self.client.getresponse() + result = json.loads(response.read()) + return result['result'] + + def batch_request(self, *requests: tuple[str, Any]) -> list[Any]: + body = json.dumps( + [ + {'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()} + for method, params in requests + ] + ) + + self.client.request('POST', '/', body) + response = self.client.getresponse() + result = json.loads(response.read()) + return result