From bffb7b11c37f107b08ed2b8e2858d9376fb8faee Mon Sep 17 00:00:00 2001 From: Matt Hauff Date: Tue, 24 Sep 2024 15:08:09 -0700 Subject: [PATCH] [CHIA-1307] Port key management RPCs to `@marshal` decorator (#18593) * Port `log_in` to decorator * Port `get_logged_in_fingerprint` to decorator * Port `get_public_keys` to decorator * Port `get_private_key` to decorator * Port `get_mnemonic` to decorator * Port `add_key` to decorator * Port `delete_key` to decorator * Port `check_delete_key` to decorator * Port `delete_all_keys` to decorator * Bad copypasta --- chia/_tests/cmds/test_cmd_framework.py | 2 +- chia/_tests/wallet/rpc/test_wallet_rpc.py | 77 +++++----- .../wallet/test_wallet_test_framework.py | 2 +- chia/cmds/cmds_util.py | 12 +- chia/data_layer/data_layer.py | 12 +- chia/rpc/wallet_request_types.py | 105 ++++++++++++- chia/rpc/wallet_rpc_api.py | 144 ++++++++++-------- chia/rpc/wallet_rpc_client.py | 63 ++++---- 8 files changed, 266 insertions(+), 151 deletions(-) diff --git a/chia/_tests/cmds/test_cmd_framework.py b/chia/_tests/cmds/test_cmd_framework.py index 0a404a7e0f55..1ed1026844b0 100644 --- a/chia/_tests/cmds/test_cmd_framework.py +++ b/chia/_tests/cmds/test_cmd_framework.py @@ -392,7 +392,7 @@ def run(self) -> None: check_click_parsing(expected_command, "-wp", str(port), "-f", str(fingerprint)) async with expected_command.rpc_info.wallet_rpc(consume_errors=False) as client_info: - assert await client_info.client.get_logged_in_fingerprint() == fingerprint + assert (await client_info.client.get_logged_in_fingerprint()).fingerprint == fingerprint # We don't care about setting the correct arg type here test_present_client_info = TempCMD(rpc_info=NeedsWalletRPC(client_info="hello world")) # type: ignore[arg-type] diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index ef2072ce469a..98b829a93008 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -47,10 +47,15 @@ from chia.rpc.rpc_client import ResponseFailureError from chia.rpc.rpc_server import RpcServer from chia.rpc.wallet_request_types import ( + AddKey, + CheckDeleteKey, CombineCoins, DefaultCAT, + DeleteKey, DIDGetPubkey, GetNotifications, + GetPrivateKey, + LogIn, SplitCoins, SplitCoinsResponse, VerifySignature, @@ -1709,22 +1714,22 @@ async def _check_delete_key( save_config(wallet_node.root_path, "config.yaml", test_config) # Check farmer_fp key - sk_dict = await client.check_delete_key(farmer_fp) - assert sk_dict["fingerprint"] == farmer_fp - assert sk_dict["used_for_farmer_rewards"] is True - assert sk_dict["used_for_pool_rewards"] is False + resp = await client.check_delete_key(CheckDeleteKey(uint32(farmer_fp))) + assert resp.fingerprint == farmer_fp + assert resp.used_for_farmer_rewards is True + assert resp.used_for_pool_rewards is False # Check pool_fp key - sk_dict = await client.check_delete_key(pool_fp) - assert sk_dict["fingerprint"] == pool_fp - assert sk_dict["used_for_farmer_rewards"] is False - assert sk_dict["used_for_pool_rewards"] is True + resp = await client.check_delete_key(CheckDeleteKey(uint32(pool_fp))) + assert resp.fingerprint == pool_fp + assert resp.used_for_farmer_rewards is False + assert resp.used_for_pool_rewards is True # Check unknown key - sk_dict = await client.check_delete_key(123456, 10) - assert sk_dict["fingerprint"] == 123456 - assert sk_dict["used_for_farmer_rewards"] is False - assert sk_dict["used_for_pool_rewards"] is False + resp = await client.check_delete_key(CheckDeleteKey(uint32(123456), uint16(10))) + assert resp.fingerprint == 123456 + assert resp.used_for_farmer_rewards is False + assert resp.used_for_pool_rewards is False @pytest.mark.anyio @@ -1738,7 +1743,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn address = await client.get_next_address(1, True) assert len(address) > 10 - pks = await client.get_public_keys() + pks = (await client.get_public_keys()).pk_fingerprints assert len(pks) == 1 await generate_funds(env.full_node.api, env.wallet_1) @@ -1756,23 +1761,21 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn await client.delete_unconfirmed_transactions(1) assert len(await wallet.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(1)) == 0 - sk_dict = await client.get_private_key(pks[0]) - assert sk_dict["fingerprint"] == pks[0] - assert sk_dict["sk"] is not None - assert sk_dict["pk"] is not None - assert sk_dict["seed"] is not None + sk_resp = await client.get_private_key(GetPrivateKey(pks[0])) + assert sk_resp.private_key.fingerprint == pks[0] + assert sk_resp.private_key.seed is not None - mnemonic = await client.generate_mnemonic() - assert len(mnemonic) == 24 + resp = await client.generate_mnemonic() + assert len(resp.mnemonic) == 24 - await client.add_key(mnemonic) + await client.add_key(AddKey(resp.mnemonic)) - pks = await client.get_public_keys() + pks = (await client.get_public_keys()).pk_fingerprints assert len(pks) == 2 - await client.log_in(pks[1]) - sk_dict = await client.get_private_key(pks[1]) - assert sk_dict["fingerprint"] == pks[1] + await client.log_in(LogIn(pks[1])) + sk_resp = await client.get_private_key(GetPrivateKey(pks[1])) + assert sk_resp.private_key.fingerprint == pks[1] # test hardened keys await _check_delete_key(client=client, wallet_node=wallet_node, farmer_fp=pks[0], pool_fp=pks[1], observer=False) @@ -1786,10 +1789,10 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn save_config(wallet_node.root_path, "config.yaml", test_config) # Check key - sk_dict = await client.check_delete_key(pks[1]) - assert sk_dict["fingerprint"] == pks[1] - assert sk_dict["used_for_farmer_rewards"] is False - assert sk_dict["used_for_pool_rewards"] is True + delete_key_resp = await client.check_delete_key(CheckDeleteKey(pks[1])) + assert delete_key_resp.fingerprint == pks[1] + assert delete_key_resp.used_for_farmer_rewards is False + assert delete_key_resp.used_for_pool_rewards is True # set farmer and pool to empty string with lock_and_load_config(wallet_node.root_path, "config.yaml") as test_config: @@ -1798,14 +1801,14 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn save_config(wallet_node.root_path, "config.yaml", test_config) # Check key - sk_dict = await client.check_delete_key(pks[0]) - assert sk_dict["fingerprint"] == pks[0] - assert sk_dict["used_for_farmer_rewards"] is False - assert sk_dict["used_for_pool_rewards"] is False + delete_key_resp = await client.check_delete_key(CheckDeleteKey(pks[0])) + assert delete_key_resp.fingerprint == pks[0] + assert delete_key_resp.used_for_farmer_rewards is False + assert delete_key_resp.used_for_pool_rewards is False - await client.delete_key(pks[0]) - await client.log_in(pks[1]) - assert len(await client.get_public_keys()) == 1 + await client.delete_key(DeleteKey(pks[0])) + await client.log_in(LogIn(uint32(pks[1]))) + assert len((await client.get_public_keys()).pk_fingerprints) == 1 assert not (await client.get_sync_status()) @@ -1818,7 +1821,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn # Delete all keys await client.delete_all_keys() - assert len(await client.get_public_keys()) == 0 + assert len((await client.get_public_keys()).pk_fingerprints) == 0 @pytest.mark.anyio diff --git a/chia/_tests/wallet/test_wallet_test_framework.py b/chia/_tests/wallet/test_wallet_test_framework.py index b9a68192fca9..67c2dc43601d 100644 --- a/chia/_tests/wallet/test_wallet_test_framework.py +++ b/chia/_tests/wallet/test_wallet_test_framework.py @@ -31,7 +31,7 @@ async def test_basic_functionality(wallet_environments: WalletTestFramework) -> env_0: WalletEnvironment = wallet_environments.environments[0] env_1: WalletEnvironment = wallet_environments.environments[1] - assert await env_0.rpc_client.get_logged_in_fingerprint() is not None + assert (await env_0.rpc_client.get_logged_in_fingerprint()).fingerprint is not None # assert await env_1.rpc_client.get_logged_in_fingerprint() is not None assert await env_0.xch_wallet.get_confirmed_balance() == 2_000_000_000_000 diff --git a/chia/cmds/cmds_util.py b/chia/cmds/cmds_util.py index ca05193956d0..15d2c09d4b80 100644 --- a/chia/cmds/cmds_util.py +++ b/chia/cmds/cmds_util.py @@ -18,6 +18,7 @@ from chia.rpc.full_node_rpc_client import FullNodeRpcClient from chia.rpc.harvester_rpc_client import HarvesterRpcClient from chia.rpc.rpc_client import ResponseFailureError, RpcClient +from chia.rpc.wallet_request_types import LogIn from chia.rpc.wallet_rpc_client import WalletRpcClient from chia.simulator.simulator_full_node_rpc_client import SimulatorFullNodeRpcClient from chia.types.blockchain_format.sized_bytes import bytes32 @@ -25,7 +26,7 @@ from chia.util.config import load_config from chia.util.default_root import DEFAULT_ROOT_PATH from chia.util.errors import CliRpcConnectionError, InvalidPathError -from chia.util.ints import uint16, uint64 +from chia.util.ints import uint16, uint32, uint64 from chia.util.keychain import KeyData from chia.util.streamable import Streamable, streamable from chia.wallet.conditions import ConditionValidTimes @@ -169,7 +170,7 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin # if only a single key is available, select it automatically selected_fingerprint = fingerprints[0] else: - logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint() + logged_in_fingerprint: Optional[int] = (await wallet_client.get_logged_in_fingerprint()).fingerprint logged_in_key: Optional[KeyData] = None if logged_in_fingerprint is not None: logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None) @@ -227,10 +228,11 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin selected_fingerprint = fp if selected_fingerprint is not None: - log_in_response = await wallet_client.log_in(selected_fingerprint) + try: + await wallet_client.log_in(LogIn(uint32(selected_fingerprint))) + except ValueError as e: + raise CliRpcConnectionError(f"Login failed for fingerprint {selected_fingerprint}: {e.args[0]}") - if log_in_response["success"] is False: - raise CliRpcConnectionError(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}") finally: # Closing the keychain proxy takes a moment, so we wait until after the login is complete if keychain_proxy is not None: diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index 4adc361c2ef8..bcbf7b1871c5 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -67,6 +67,7 @@ write_files_for_root, ) from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections +from chia.rpc.wallet_request_types import LogIn from chia.rpc.wallet_rpc_client import WalletRpcClient from chia.server.outbound_message import NodeType from chia.server.server import ChiaServer @@ -242,13 +243,12 @@ def set_server(self, server: ChiaServer) -> None: self._server = server async def wallet_log_in(self, fingerprint: int) -> int: - result = await self.wallet_rpc.log_in(fingerprint) - if not result.get("success", False): - wallet_error = result.get("error", "no error message provided") - raise Exception(f"DataLayer wallet RPC log in request failed: {wallet_error}") + try: + result = await self.wallet_rpc.log_in(LogIn(uint32(fingerprint))) + except ValueError as e: + raise Exception(f"DataLayer wallet RPC log in request failed: {e.args[0]}") - fingerprint = cast(int, result["fingerprint"]) - return fingerprint + return result.fingerprint async def create_store( self, fee: uint64, root: bytes32 = bytes32([0] * 32) diff --git a/chia/rpc/wallet_request_types.py b/chia/rpc/wallet_request_types.py index 8f8384d4aa69..5de7d93213b7 100644 --- a/chia/rpc/wallet_request_types.py +++ b/chia/rpc/wallet_request_types.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Type, TypeVar -from chia_rs import G1Element, G2Element +from chia_rs import G1Element, G2Element, PrivateKey from typing_extensions import dataclass_transform from chia.types.blockchain_format.sized_bytes import bytes32 @@ -42,6 +42,109 @@ def default_raise() -> Any: # pragma: no cover raise RuntimeError("This should be impossible to hit and is just for < 3.10 compatibility") +@streamable +@dataclass(frozen=True) +class Empty(Streamable): + pass + + +@streamable +@dataclass(frozen=True) +class LogIn(Streamable): + fingerprint: uint32 + + +@streamable +@dataclass(frozen=True) +class LogInResponse(Streamable): + fingerprint: uint32 + + +@streamable +@dataclass(frozen=True) +class GetLoggedInFingerprintResponse(Streamable): + fingerprint: Optional[uint32] + + +@streamable +@dataclass(frozen=True) +class GetPublicKeysResponse(Streamable): + keyring_is_locked: bool + public_key_fingerprints: Optional[List[uint32]] = None + + @property + def pk_fingerprints(self) -> List[uint32]: + if self.keyring_is_locked: + raise RuntimeError("get_public_keys cannot return public keys because the keyring is locked") + else: + assert self.public_key_fingerprints is not None + return self.public_key_fingerprints + + +@streamable +@dataclass(frozen=True) +class GetPrivateKey(Streamable): + fingerprint: uint32 + + +# utility for `GetPrivateKeyResponse` +@streamable +@dataclass(frozen=True) +class GetPrivateKeyFormat(Streamable): + fingerprint: uint32 + sk: PrivateKey + pk: G1Element + farmer_pk: G1Element + pool_pk: G1Element + seed: Optional[str] + + +@streamable +@dataclass(frozen=True) +class GetPrivateKeyResponse(Streamable): + private_key: GetPrivateKeyFormat + + +@streamable +@dataclass(frozen=True) +class GenerateMnemonicResponse(Streamable): + mnemonic: List[str] + + +@streamable +@dataclass(frozen=True) +class AddKey(Streamable): + mnemonic: List[str] + + +@streamable +@dataclass(frozen=True) +class AddKeyResponse(Streamable): + fingerprint: uint32 + + +@streamable +@dataclass(frozen=True) +class DeleteKey(Streamable): + fingerprint: uint32 + + +@streamable +@dataclass(frozen=True) +class CheckDeleteKey(Streamable): + fingerprint: uint32 + max_ph_to_search: uint16 = uint16(100) + + +@streamable +@dataclass(frozen=True) +class CheckDeleteKeyResponse(Streamable): + fingerprint: uint32 + used_for_farmer_rewards: bool + used_for_pool_rewards: bool + wallet_balance: bool + + @streamable @dataclass(frozen=True) class GetNotifications(Streamable): diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index 84772cb34bed..4247d056ac31 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -20,16 +20,30 @@ from chia.rpc.rpc_server import Endpoint, EndpointResult, default_get_connections from chia.rpc.util import marshal, tx_endpoint from chia.rpc.wallet_request_types import ( + AddKey, + AddKeyResponse, ApplySignatures, ApplySignaturesResponse, + CheckDeleteKey, + CheckDeleteKeyResponse, CombineCoins, CombineCoinsResponse, + DeleteKey, + Empty, ExecuteSigningInstructions, ExecuteSigningInstructionsResponse, GatherSigningInfo, GatherSigningInfoResponse, + GenerateMnemonicResponse, + GetLoggedInFingerprintResponse, GetNotifications, GetNotificationsResponse, + GetPrivateKey, + GetPrivateKeyFormat, + GetPrivateKeyResponse, + GetPublicKeysResponse, + LogIn, + LogInResponse, SplitCoins, SplitCoinsResponse, SubmitTransactions, @@ -378,39 +392,42 @@ async def get_latest_singleton_coin_spend( # Key management ########################################################################################## - async def log_in(self, request: Dict[str, Any]) -> EndpointResult: + @marshal + async def log_in(self, request: LogIn) -> LogInResponse: """ Logs in the wallet with a specific key. """ - fingerprint = request["fingerprint"] - if self.service.logged_in_fingerprint == fingerprint: - return {"fingerprint": fingerprint} + if self.service.logged_in_fingerprint == request.fingerprint: + return LogInResponse(request.fingerprint) await self._stop_wallet() - started = await self.service._start_with_fingerprint(fingerprint) + started = await self.service._start_with_fingerprint(request.fingerprint) if started is True: - return {"fingerprint": fingerprint} + return LogInResponse(request.fingerprint) - return {"success": False, "error": f"fingerprint {fingerprint} not found in keychain or keychain is empty"} + raise ValueError(f"fingerprint {request.fingerprint} not found in keychain or keychain is empty") - async def get_logged_in_fingerprint(self, request: Dict[str, Any]) -> EndpointResult: - return {"fingerprint": self.service.logged_in_fingerprint} + @marshal + async def get_logged_in_fingerprint(self, request: Empty) -> GetLoggedInFingerprintResponse: + return GetLoggedInFingerprintResponse(uint32.construct_optional(self.service.logged_in_fingerprint)) - async def get_public_keys(self, request: Dict[str, Any]) -> EndpointResult: + @marshal + async def get_public_keys(self, request: Empty) -> GetPublicKeysResponse: try: fingerprints = [ - sk.get_g1().get_fingerprint() for (sk, seed) in await self.service.keychain_proxy.get_all_private_keys() + uint32(sk.get_g1().get_fingerprint()) + for (sk, seed) in await self.service.keychain_proxy.get_all_private_keys() ] except KeychainIsLocked: - return {"keyring_is_locked": True} + return GetPublicKeysResponse(keyring_is_locked=True) except Exception as e: raise Exception( "Error while getting keys. If the issue persists, restart all services." f" Original error: {type(e).__name__}: {e}" ) from e else: - return {"public_key_fingerprints": fingerprints} + return GetPublicKeysResponse(keyring_is_locked=False, public_key_fingerprints=fingerprints) async def _get_private_key(self, fingerprint: int) -> Tuple[Optional[PrivateKey], Optional[bytes]]: try: @@ -422,44 +439,37 @@ async def _get_private_key(self, fingerprint: int) -> Tuple[Optional[PrivateKey] log.error(f"Failed to get private key by fingerprint: {e}") return None, None - async def get_private_key(self, request: Dict[str, Any]) -> EndpointResult: - fingerprint = request["fingerprint"] - sk, seed = await self._get_private_key(fingerprint) + @marshal + async def get_private_key(self, request: GetPrivateKey) -> GetPrivateKeyResponse: + sk, seed = await self._get_private_key(request.fingerprint) if sk is not None: s = bytes_to_mnemonic(seed) if seed is not None else None - return { - "private_key": { - "fingerprint": fingerprint, - "sk": bytes(sk).hex(), - "pk": bytes(sk.get_g1()).hex(), - "farmer_pk": bytes(master_sk_to_farmer_sk(sk).get_g1()).hex(), - "pool_pk": bytes(master_sk_to_pool_sk(sk).get_g1()).hex(), - "seed": s, - }, - } - return {"success": False, "private_key": {"fingerprint": fingerprint}} + return GetPrivateKeyResponse( + private_key=GetPrivateKeyFormat( + fingerprint=request.fingerprint, + sk=sk, + pk=sk.get_g1(), + farmer_pk=master_sk_to_farmer_sk(sk).get_g1(), + pool_pk=master_sk_to_pool_sk(sk).get_g1(), + seed=s, + ) + ) - async def generate_mnemonic(self, request: Dict[str, Any]) -> EndpointResult: - return {"mnemonic": generate_mnemonic().split(" ")} + raise ValueError(f"Could not get a private key for fingerprint {request.fingerprint}") - async def add_key(self, request: Dict[str, Any]) -> EndpointResult: - if "mnemonic" not in request: - raise ValueError("Mnemonic not in request") + @marshal + async def generate_mnemonic(self, request: Empty) -> GenerateMnemonicResponse: + return GenerateMnemonicResponse(generate_mnemonic().split(" ")) + @marshal + async def add_key(self, request: AddKey) -> AddKeyResponse: # Adding a key from 24 word mnemonic - mnemonic = request["mnemonic"] try: - sk = await self.service.keychain_proxy.add_key(" ".join(mnemonic)) + sk = await self.service.keychain_proxy.add_key(" ".join(request.mnemonic)) except KeyError as e: - return { - "success": False, - "error": f"The word '{e.args[0]}' is incorrect.'", - "word": e.args[0], - } - except Exception as e: - return {"success": False, "error": str(e)} + raise ValueError(f"The word '{e.args[0]}' is incorrect.") - fingerprint = sk.get_g1().get_fingerprint() + fingerprint = uint32(sk.get_g1().get_fingerprint()) await self._stop_wallet() # Makes sure the new key is added to config properly @@ -470,24 +480,24 @@ async def add_key(self, request: Dict[str, Any]) -> EndpointResult: log.error(f"Failed to check_keys after adding a new key: {e}") started = await self.service._start_with_fingerprint(fingerprint=fingerprint) if started is True: - return {"fingerprint": fingerprint} + return AddKeyResponse(fingerprint=fingerprint) raise ValueError("Failed to start") - async def delete_key(self, request: Dict[str, Any]) -> EndpointResult: + @marshal + async def delete_key(self, request: DeleteKey) -> Empty: await self._stop_wallet() - fingerprint = request["fingerprint"] try: - await self.service.keychain_proxy.delete_key_by_fingerprint(fingerprint) + await self.service.keychain_proxy.delete_key_by_fingerprint(request.fingerprint) except Exception as e: log.error(f"Failed to delete key by fingerprint: {e}") - return {"success": False, "error": str(e)} + raise e path = path_from_root( self.service.root_path, - f"{self.service.config['database_path']}-{fingerprint}", + f"{self.service.config['database_path']}-{request.fingerprint}", ) if path.exists(): path.unlink() - return {} + return Empty() async def _check_key_used_for_rewards( self, new_root: Path, sk: PrivateKey, max_ph_to_search: int @@ -530,26 +540,25 @@ async def _check_key_used_for_rewards( return found_farmer, found_pool - async def check_delete_key(self, request: Dict[str, Any]) -> EndpointResult: + @marshal + async def check_delete_key(self, request: CheckDeleteKey) -> CheckDeleteKeyResponse: """Check the key use prior to possible deletion checks whether key is used for either farm or pool rewards checks if any wallets have a non-zero balance """ used_for_farmer: bool = False used_for_pool: bool = False - walletBalance: bool = False + wallet_balance: bool = False - fingerprint = request["fingerprint"] - max_ph_to_search = request.get("max_ph_to_search", 100) - sk, _ = await self._get_private_key(fingerprint) + sk, _ = await self._get_private_key(request.fingerprint) if sk is not None: used_for_farmer, used_for_pool = await self._check_key_used_for_rewards( - self.service.root_path, sk, max_ph_to_search + self.service.root_path, sk, request.max_ph_to_search ) - if self.service.logged_in_fingerprint != fingerprint: + if self.service.logged_in_fingerprint != request.fingerprint: await self._stop_wallet() - await self.service._start_with_fingerprint(fingerprint=fingerprint) + await self.service._start_with_fingerprint(fingerprint=request.fingerprint) wallets: List[WalletInfo] = await self.service.wallet_state_manager.get_all_wallet_info_entries() for w in wallets: @@ -559,27 +568,28 @@ async def check_delete_key(self, request: Dict[str, Any]) -> EndpointResult: pending_balance = await wallet.get_unconfirmed_balance(unspent) if (balance + pending_balance) > 0: - walletBalance = True + wallet_balance = True break - return { - "fingerprint": fingerprint, - "used_for_farmer_rewards": used_for_farmer, - "used_for_pool_rewards": used_for_pool, - "wallet_balance": walletBalance, - } + return CheckDeleteKeyResponse( + fingerprint=request.fingerprint, + used_for_farmer_rewards=used_for_farmer, + used_for_pool_rewards=used_for_pool, + wallet_balance=wallet_balance, + ) - async def delete_all_keys(self, request: Dict[str, Any]) -> EndpointResult: + @marshal + async def delete_all_keys(self, request: Empty) -> Empty: await self._stop_wallet() try: await self.service.keychain_proxy.delete_all_keys() except Exception as e: log.error(f"Failed to delete all keys: {e}") - return {"success": False, "error": str(e)} + raise e path = path_from_root(self.service.root_path, self.service.config["database_path"]) if path.exists(): path.unlink() - return {} + return Empty() ########################################################################################## # Wallet Node diff --git a/chia/rpc/wallet_rpc_client.py b/chia/rpc/wallet_rpc_client.py index 7e0ab602a869..65dd169b96aa 100644 --- a/chia/rpc/wallet_rpc_client.py +++ b/chia/rpc/wallet_rpc_client.py @@ -7,11 +7,15 @@ from chia.pools.pool_wallet_info import PoolWalletInfo from chia.rpc.rpc_client import RpcClient from chia.rpc.wallet_request_types import ( + AddKey, + AddKeyResponse, ApplySignatures, ApplySignaturesResponse, CancelOfferResponse, CancelOffersResponse, CATSpendResponse, + CheckDeleteKey, + CheckDeleteKeyResponse, CombineCoins, CombineCoinsResponse, CreateNewDAOWalletResponse, @@ -24,6 +28,7 @@ DAOFreeCoinsFromFinishedProposalsResponse, DAOSendToLockupResponse, DAOVoteOnProposalResponse, + DeleteKey, DIDGetCurrentCoinInfo, DIDGetCurrentCoinInfoResponse, DIDGetPubkey, @@ -38,12 +43,19 @@ ExecuteSigningInstructionsResponse, GatherSigningInfo, GatherSigningInfoResponse, + GenerateMnemonicResponse, GetCATListResponse, + GetLoggedInFingerprintResponse, GetNotifications, GetNotificationsResponse, GetOffersCountResponse, + GetPrivateKey, + GetPrivateKeyResponse, + GetPublicKeysResponse, GetTransactionMemo, GetTransactionMemoResponse, + LogIn, + LogInResponse, NFTAddURIResponse, NFTGetByDID, NFTGetByDIDResponse, @@ -109,50 +121,35 @@ class WalletRpcClient(RpcClient): """ # Key Management APIs - async def log_in(self, fingerprint: int) -> Union[Dict[str, Any], Any]: - try: - return await self.fetch("log_in", {"fingerprint": fingerprint, "type": "start"}) - except ValueError as e: - return e.args[0] + async def log_in(self, request: LogIn) -> LogInResponse: + return LogInResponse.from_json_dict(await self.fetch("log_in", request.to_json_dict())) async def set_wallet_resync_on_startup(self, enable: bool = True) -> Dict[str, Any]: return await self.fetch(path="set_wallet_resync_on_startup", request_json={"enable": enable}) - async def get_logged_in_fingerprint(self) -> Optional[int]: - response = await self.fetch("get_logged_in_fingerprint", {}) - # TODO: casting due to lack of type checked deserialization - return cast(Optional[int], response["fingerprint"]) + async def get_logged_in_fingerprint(self) -> GetLoggedInFingerprintResponse: + return GetLoggedInFingerprintResponse.from_json_dict(await self.fetch("get_logged_in_fingerprint", {})) - async def get_public_keys(self) -> List[int]: - response = await self.fetch("get_public_keys", {}) - # TODO: casting due to lack of type checked deserialization - return cast(List[int], response["public_key_fingerprints"]) + async def get_public_keys(self) -> GetPublicKeysResponse: + return GetPublicKeysResponse.from_json_dict(await self.fetch("get_public_keys", {})) - async def get_private_key(self, fingerprint: int) -> Dict[str, Any]: - request = {"fingerprint": fingerprint} - response = await self.fetch("get_private_key", request) - # TODO: casting due to lack of type checked deserialization - return cast(Dict[str, Any], response["private_key"]) + async def get_private_key(self, request: GetPrivateKey) -> GetPrivateKeyResponse: + return GetPrivateKeyResponse.from_json_dict(await self.fetch("get_private_key", request.to_json_dict())) - async def generate_mnemonic(self) -> List[str]: - response = await self.fetch("generate_mnemonic", {}) - # TODO: casting due to lack of type checked deserialization - return cast(List[str], response["mnemonic"]) + async def generate_mnemonic(self) -> GenerateMnemonicResponse: + return GenerateMnemonicResponse.from_json_dict(await self.fetch("generate_mnemonic", {})) - async def add_key(self, mnemonic: List[str], request_type: str = "new_wallet") -> Dict[str, Any]: - request = {"mnemonic": mnemonic, "type": request_type} - return await self.fetch("add_key", request) + async def add_key(self, request: AddKey) -> AddKeyResponse: + return AddKeyResponse.from_json_dict(await self.fetch("add_key", request.to_json_dict())) - async def delete_key(self, fingerprint: int) -> Dict[str, Any]: - request = {"fingerprint": fingerprint} - return await self.fetch("delete_key", request) + async def delete_key(self, request: DeleteKey) -> None: + await self.fetch("delete_key", request.to_json_dict()) - async def check_delete_key(self, fingerprint: int, max_ph_to_search: int = 100) -> Dict[str, Any]: - request = {"fingerprint": fingerprint, "max_ph_to_search": max_ph_to_search} - return await self.fetch("check_delete_key", request) + async def check_delete_key(self, request: CheckDeleteKey) -> CheckDeleteKeyResponse: + return CheckDeleteKeyResponse.from_json_dict(await self.fetch("check_delete_key", request.to_json_dict())) - async def delete_all_keys(self) -> Dict[str, Any]: - return await self.fetch("delete_all_keys", {}) + async def delete_all_keys(self) -> None: + await self.fetch("delete_all_keys", {}) # Wallet Node APIs async def get_sync_status(self) -> bool: