Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
devops committed Nov 23, 2024
2 parents fc37d4b + 3d266bd commit a73dff9
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 43 deletions.
148 changes: 106 additions & 42 deletions pyk/src/pyk/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, NamedTuple

from typing_extensions import Protocol

Expand Down Expand Up @@ -71,37 +73,80 @@ class JsonRpcMethod(Protocol):
def __call__(self, **kwargs: Any) -> Any: ...


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]
class JsonRpcRequest(NamedTuple):
id: str | int
method: str
params: Any

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)

def send_json_error(self, code: int, message: str, id: Any = None) -> None:
error_dict = {
class JsonRpcBatchRequest(NamedTuple):
requests: tuple[JsonRpcRequest]


class JsonRpcResult(ABC):

@abstractmethod
def encode(self) -> bytes: ...


@dataclass(frozen=True)
class JsonRpcError(JsonRpcResult):

code: int
message: str
id: str | int | None

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'error': {
'code': code,
'message': message,
'code': self.code,
'message': self.message,
},
'id': id,
'id': self.id,
}
error_bytes = json.dumps(error_dict).encode('ascii')
self.set_response()
self.wfile.write(error_bytes)

def send_json_success(self, result: Any, id: Any) -> None:
response_dict = {
def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcSuccess(JsonRpcResult):
payload: Any
id: Any

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'result': result,
'id': id,
'result': self.payload,
'id': self.id,
}
response_bytes = json.dumps(response_dict).encode('ascii')
self.set_response()

def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcBatchResult(JsonRpcResult):
results: tuple[JsonRpcError | JsonRpcSuccess, ...]

def encode(self) -> bytes:
return json.dumps([result.to_json() for result in self.results]).encode('ascii')


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)

def _send_response(self, response: JsonRpcResult) -> None:
self.send_response_headers()
response_bytes = response.encode()
self.wfile.write(response_bytes)

def set_response(self) -> None:
def send_response_headers(self) -> None:
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
Expand All @@ -113,44 +158,63 @@ def do_POST(self) -> None: # noqa: N802
content = self.rfile.read(int(content_len))
_LOGGER.debug(f'Received bytes: {content.decode()}')

request: dict
request: dict[str, Any] | list[dict[str, Any]]
try:
request = json.loads(content)
_LOGGER.info(f'Received request: {request}')
except json.JSONDecodeError:
_LOGGER.warning(f'Invalid JSON: {content.decode()}')
self.send_json_error(-32700, 'Invalid JSON')
json_error = JsonRpcError(-32700, 'Invalid JSON', None)
self._send_response(json_error)
return

required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request:
_LOGGER.warning(f'Missing required field "{field}": {request}')
self.send_json_error(-32600, f'Invalid request: missing field "{field}"', request.get('id', None))
return
response: JsonRpcResult
if isinstance(request, list):
response = self._batch_request(request)
else:
response = self._single_request(request)

jsonrpc_version = request['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
_LOGGER.warning(f'Bad JSON-RPC version: {jsonrpc_version}')
self.send_json_error(-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request['id'])
return
self._send_response(response)

method_name = request['method']
if method_name not in self.methods:
_LOGGER.warning(f'Method not found: {method_name}')
self.send_json_error(-32601, f'Method "{method_name}" not found.', request['id'])
return
def _batch_request(self, requests: list[dict[str, Any]]) -> JsonRpcBatchResult:
return JsonRpcBatchResult(tuple(self._single_request(request) for request in requests))

def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSuccess:
validation_result = self._validate_request(request)
if isinstance(validation_result, JsonRpcError):
return validation_result

id, method_name, params = validation_result
method = self.methods[method_name]
params = request.get('params', None)
_LOGGER.info(f'Executing method {method_name}')
result: Any
if type(params) is dict:
result = method(**params)
elif type(params) is list:
result = method(*params)
elif params is None:
result = method()
else:
self.send_json_error(-32602, 'Unrecognized method parameter format.')
return JsonRpcError(-32602, 'Unrecognized method parameter format.', id)
_LOGGER.debug(f'Got response {result}')
self.send_json_success(result, request['id'])
return JsonRpcSuccess(result, id)

def _validate_request(self, request_dict: Any) -> JsonRpcRequest | JsonRpcError:
required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request_dict:
return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None))

jsonrpc_version = request_dict['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
return JsonRpcError(
-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None)
)

method_name = request_dict['method']
if method_name not in self.methods.keys():
return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None))

return JsonRpcRequest(
method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None)
)
126 changes: 125 additions & 1 deletion pyk/src/tests/integration/test_json_rpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json
from http.client import HTTPConnection
from threading import Thread
from time import sleep
from typing import TYPE_CHECKING

from pyk.cterm import CTerm
from pyk.kast.inner import KApply, KSequence, KSort, KToken
Expand All @@ -11,6 +14,9 @@
from pyk.rpc.rpc import JsonRpcServer, ServeRpcOptions
from pyk.testing import KRunTest

if TYPE_CHECKING:
from typing import Any


class StatefulKJsonRpcServer(JsonRpcServer):
krun: KRun
Expand Down Expand Up @@ -67,7 +73,7 @@ def exec_add(self) -> int:
return int(k_cell.token)


class TestJsonRPCServer(KRunTest):
class TestJsonKRPCServer(KRunTest):
KOMPILE_DEFINITION = """
module JSON-RPC-EXAMPLE-SYNTAX
imports INT-SYNTAX
Expand Down Expand Up @@ -133,3 +139,121 @@ def wait_until_ready() -> None:

server.shutdown()
thread.join()


class StatefulJsonRpcServer(JsonRpcServer):

x: int = 42
y: int = 43

def __init__(self, options: ServeRpcOptions) -> None:
super().__init__(options)

self.register_method('get_x', self.exec_get_x)
self.register_method('get_y', self.exec_get_y)
self.register_method('set_x', self.exec_set_x)
self.register_method('set_y', self.exec_set_y)
self.register_method('add', self.exec_add)

def exec_get_x(self) -> int:
return self.x

def exec_get_y(self) -> int:
return self.y

def exec_set_x(self, n: int) -> None:
self.x = n

def exec_set_y(self, n: int) -> None:
self.y = n

def exec_add(self) -> int:
return self.x + self.y


class TestJsonRPCServer(KRunTest):

def test_json_rpc_server(self) -> None:
server = StatefulJsonRpcServer(ServeRpcOptions({'port': 0}))

def run_server() -> None:
server.serve()

def wait_until_server_is_up() -> None:
while True:
try:
server.port()
return
except ValueError:
sleep(0.1)

thread = Thread(target=run_server)
thread.start()

wait_until_server_is_up()

http_client = HTTPConnection('localhost', server.port())
rpc_client = SimpleClient(http_client)

def wait_until_ready() -> None:
while True:
try:
rpc_client.request('get_x', [])
except ConnectionRefusedError:
sleep(0.1)
continue
break

wait_until_ready()

rpc_client.request('set_x', [123])
res = rpc_client.request('get_x')
assert res == 123

rpc_client.request('set_y', [456])
res = rpc_client.request('get_y')
assert res == 456

res = rpc_client.request('add', [])
assert res == (123 + 456)

res = rpc_client.batch_request(('set_x', [1]), ('set_y', [2]), ('add', []))
assert len(res) == 3
assert res[2]['result'] == 1 + 2

server.shutdown()
thread.join()


class SimpleClient:

client: HTTPConnection
_request_id: int = 0

def __init__(self, client: HTTPConnection) -> None:
self.client = client

def request_id(self) -> int:
self._request_id += 1
return self._request_id

def request(self, method: str, params: Any = None) -> Any:
body = json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()})

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result['result']

def batch_request(self, *requests: tuple[str, Any]) -> list[Any]:
body = json.dumps(
[
{'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()}
for method, params in requests
]
)

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result

0 comments on commit a73dff9

Please sign in to comment.