Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Nov 8, 2024
1 parent 2fd27fd commit c65f3f7
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 17 deletions.
63 changes: 53 additions & 10 deletions chia/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,38 @@
from ssl import SSLContext
from typing import Any, Callable, Generic, Optional, TypeVar

from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from aiohttp import (
ClientConnectorError,
ClientSession,
ClientWebSocketResponse,
WSMsgType,
web,
)
from typing_extensions import Protocol, final

from chia import __version__
from chia.rpc.util import wrap_http_handler
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer, ssl_context_for_client, ssl_context_for_server
from chia.server.server import (
ChiaServer,
ssl_context_for_client,
ssl_context_for_server,
)
from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import PeerInfo
from chia.util.byte_types import hexstr_to_bytes
from chia.util.chia_logging import set_log_level
from chia.util.chia_logging import default_log_level, set_log_level
from chia.util.config import str2bool
from chia.util.ints import uint16
from chia.util.json_util import dict_to_json_str
from chia.util.network import WebServer, resolve
from chia.util.ws_message import WsRpcMessage, create_payload, create_payload_dict, format_response, pong
from chia.util.ws_message import (
WsRpcMessage,
create_payload,
create_payload_dict,
format_response,
pong,
)

log = logging.getLogger(__name__)
max_message_size = 50 * 1024 * 1024 # 50MB
Expand Down Expand Up @@ -134,6 +150,7 @@ class RpcServer(Generic[_T_RpcApiProtocol]):
ssl_context: SSLContext
ssl_client_context: SSLContext
net_config: dict[str, Any]
service_config: dict[str, Any]
webserver: Optional[WebServer] = None
daemon_heartbeat: int = 300
daemon_connection_task: Optional[asyncio.Task[None]] = None
Expand All @@ -150,6 +167,7 @@ def create(
stop_cb: Callable[[], None],
root_path: Path,
net_config: dict[str, Any],
service_config: dict[str, Any],
prefer_ipv6: bool,
) -> RpcServer[_T_RpcApiProtocol]:
crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
Expand All @@ -166,6 +184,7 @@ def create(
ssl_context,
ssl_client_context,
net_config,
service_config=service_config,
daemon_heartbeat=daemon_heartbeat,
prefer_ipv6=prefer_ipv6,
)
Expand Down Expand Up @@ -248,6 +267,7 @@ def _get_routes(self) -> dict[str, Endpoint]:
"/healthz": self.healthz,
"/get_log_level": self.get_log_level,
"/set_log_level": self.set_log_level,
"/reset_log_level": self.reset_log_level,
}

async def get_routes(self, request: dict[str, Any]) -> EndpointResult:
Expand All @@ -260,7 +280,11 @@ async def get_network_info(self, _: dict[str, Any]) -> EndpointResult:
network_name = self.net_config["selected_network"]
address_prefix = self.net_config["network_overrides"]["config"][network_name]["address_prefix"]
genesis_challenge = self.net_config["network_overrides"]["constants"][network_name]["GENESIS_CHALLENGE"]
return {"network_name": network_name, "network_prefix": address_prefix, "genesis_challenge": genesis_challenge}
return {
"network_name": network_name,
"network_prefix": address_prefix,
"genesis_challenge": genesis_challenge,
}

async def get_connections(self, request: dict[str, Any]) -> EndpointResult:
request_node_type: Optional[NodeType] = None
Expand Down Expand Up @@ -316,16 +340,28 @@ async def get_log_level(self, request: dict[str, Any]) -> EndpointResult:
logger = logging.getLogger()
level_number = logger.getEffectiveLevel()
level_name = logging.getLevelName(level_number)

return {
"level_number": level_number,
"success": True,
"level_name": level_name,
"available_levels": list(logging.getLevelNamesMapping()),
}

async def reset_log_level(self, request: dict[str, Any]) -> EndpointResult:
level_name = self.service_config.get("log_level", default_log_level)

return await self.set_log_level(request={"level_name": level_name})

async def set_log_level(self, request: dict[str, Any]) -> EndpointResult:
requested_level_number = request["level_number"]
set_log_level(log_level=requested_level_number, service_name=self.service_name)
error_strings = set_log_level(log_level=request["level_name"], service_name=self.service_name)
status = await self.get_log_level(request={})

return {}
status["success"] &= len(error_strings) == 0

return {
**status,
"errors": error_strings,
}

async def ws_api(self, message: WsRpcMessage) -> Optional[dict[str, object]]:
"""
Expand Down Expand Up @@ -437,6 +473,7 @@ async def start_rpc_server(
stop_cb: Callable[[], None],
root_path: Path,
net_config: dict[str, object],
service_config: dict[str, object],
connect_to_daemon: bool = True,
max_request_body_size: Optional[int] = None,
) -> RpcServer[_T_RpcApiProtocol]:
Expand All @@ -451,7 +488,13 @@ async def start_rpc_server(
prefer_ipv6 = str2bool(str(net_config.get("prefer_ipv6", False)))

rpc_server = RpcServer.create(
rpc_api, rpc_api.service_name, stop_cb, root_path, net_config, prefer_ipv6=prefer_ipv6
rpc_api,
rpc_api.service_name,
stop_cb,
root_path,
net_config,
service_config=service_config,
prefer_ipv6=prefer_ipv6,
)
rpc_server.rpc_api.service._set_state_changed_callback(rpc_server.state_changed)
await rpc_server.start(self_hostname, rpc_port, max_request_body_size)
Expand Down
3 changes: 2 additions & 1 deletion chia/server/start_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ async def manage(self, *, start: bool = True) -> AsyncIterator[None]:
self.stop_requested.set,
self.root_path,
self.config,
self._connect_to_daemon,
service_config=self.service_config,
connect_to_daemon=self._connect_to_daemon,
max_request_body_size=self.max_request_body_size,
)
yield
Expand Down
16 changes: 10 additions & 6 deletions chia/util/chia_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def initialize_logging(
set_log_level(log_level=log_level, service_name=service_name)


def set_log_level(log_level: str, service_name: str) -> None:
def set_log_level(log_level: str, service_name: str) -> list[str]:
root_logger = logging.getLogger()
log_level_exceptions = {}

Expand All @@ -110,11 +110,13 @@ def set_log_level(log_level: str, service_name: str) -> None:
handler.setLevel(default_log_level)
log_level_exceptions[handler] = e

for handler, exception in log_level_exceptions.items():
root_logger.error(
f"Handler {handler}: Invalid log level '{log_level}' found in {service_name} config. "
f"Defaulting to: {default_log_level}. Error: {exception}"
)
error_strings = [
f"Handler {handler}: Invalid log level '{log_level}' for {service_name}. "
f"Defaulting to: {default_log_level}. Error: {exception}"
for handler, exception in log_level_exceptions.items()
]
for error_string in error_strings:
root_logger.error(error_string)

# Adjust the root logger to the smallest used log level since its default level is WARNING which would overwrite
# the potentially smaller log levels of specific handlers.
Expand All @@ -123,6 +125,8 @@ def set_log_level(log_level: str, service_name: str) -> None:
if root_logger.level <= logging.DEBUG:
logging.getLogger("aiosqlite").setLevel(logging.INFO) # Too much logging on debug level

return error_strings


def initialize_service_logging(service_name: str, config: dict[str, Any]) -> None:
logging_root_path = DEFAULT_ROOT_PATH
Expand Down

0 comments on commit c65f3f7

Please sign in to comment.