Skip to content

Commit

Permalink
move tx endpoint into wallet-specific code (#19116)
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky authored Jan 10, 2025
1 parent 37deeaf commit 1062ab4
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 211 deletions.
209 changes: 2 additions & 207 deletions chia/rpc/util.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
from __future__ import annotations

import dataclasses
import logging
import traceback
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Optional, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, get_type_hints

import aiohttp
from chia_rs import AugSchemeMPL

from chia.types.blockchain_format.coin import Coin
from chia.types.coin_spend import CoinSpend
from chia.util.json_util import obj_to_response
from chia.util.streamable import Streamable
from chia.wallet.conditions import Condition, ConditionValidTimes, conditions_from_json_dicts, parse_timelock_info
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.blind_signer_tl import BLIND_SIGNER_TRANSLATION
from chia.wallet.util.clvm_streamable import (
TranslationLayer,
json_deserialize_with_clvm_streamable,
json_serialize_with_clvm_streamable,
)
from chia.wallet.util.transaction_type import TransactionType
from chia.wallet.util.tx_config import TXConfig, TXConfigLoader
from chia.wallet.wallet_spend_bundle import WalletSpendBundle

log = logging.getLogger(__name__)

Expand All @@ -36,10 +25,6 @@
MarshallableRpcEndpoint = Callable[..., Awaitable[Streamable]]
if TYPE_CHECKING:
from chia.rpc.rpc_server import EndpointResult
from chia.rpc.wallet_rpc_api import WalletRpcApi
else:
EndpointResult = dict[str, Any]
WalletRpcApi = object


ALL_TRANSLATION_LAYERS: dict[str, TranslationLayer] = {"CHIP-0028": BLIND_SIGNER_TRANSLATION}
Expand All @@ -51,9 +36,7 @@ def marshal(func: MarshallableRpcEndpoint) -> RpcEndpoint:
assert issubclass(request_hint, Streamable)
request_class = request_hint

async def rpc_endpoint(
self: WalletRpcApi, request: dict[str, Any], *args: object, **kwargs: object
) -> EndpointResult:
async def rpc_endpoint(self: object, request: dict[str, Any], *args: object, **kwargs: object) -> EndpointResult:
response_obj: Streamable = await func(
self,
(
Expand Down Expand Up @@ -108,191 +91,3 @@ async def inner(request: aiohttp.web.Request) -> aiohttp.web.StreamResponse:
return obj_to_response(res_object)

return inner


def tx_endpoint(
push: bool = False,
merge_spends: bool = True,
) -> Callable[[RpcEndpoint], RpcEndpoint]:
def _inner(func: RpcEndpoint) -> RpcEndpoint:
async def rpc_endpoint(
self: WalletRpcApi, request: dict[str, Any], *args: object, **kwargs: object
) -> EndpointResult:
assert self.service.logged_in_fingerprint is not None
tx_config_loader: TXConfigLoader = TXConfigLoader.from_json_dict(request)

# Some backwards compat fill-ins
if tx_config_loader.excluded_coin_ids is None:
tx_config_loader = tx_config_loader.override(
excluded_coin_ids=request.get("exclude_coin_ids"),
)
if tx_config_loader.excluded_coin_amounts is None:
tx_config_loader = tx_config_loader.override(
excluded_coin_amounts=request.get("exclude_coin_amounts"),
)
if tx_config_loader.excluded_coin_ids is None:
excluded_coins: Optional[list[dict[str, Any]]] = request.get(
"exclude_coins", request.get("excluded_coins")
)
if excluded_coins is not None:
tx_config_loader = tx_config_loader.override(
excluded_coin_ids=[Coin.from_json_dict(c).name() for c in excluded_coins],
)

tx_config: TXConfig = tx_config_loader.autofill(
constants=self.service.wallet_state_manager.constants,
config=self.service.wallet_state_manager.config,
logged_in_fingerprint=self.service.logged_in_fingerprint,
)

extra_conditions: tuple[Condition, ...] = tuple()
if "extra_conditions" in request:
extra_conditions = tuple(conditions_from_json_dicts(request["extra_conditions"]))
extra_conditions = (*extra_conditions, *ConditionValidTimes.from_json_dict(request).to_conditions())

valid_times: ConditionValidTimes = parse_timelock_info(extra_conditions)
if (
valid_times.max_secs_after_created is not None
or valid_times.min_secs_since_created is not None
or valid_times.max_blocks_after_created is not None
or valid_times.min_blocks_since_created is not None
):
raise ValueError("Relative timelocks are not currently supported in the RPC")

async with self.service.wallet_state_manager.new_action_scope(
tx_config,
push=request.get("push", push),
merge_spends=request.get("merge_spends", merge_spends),
sign=request.get("sign", self.service.config.get("auto_sign_txs", True)),
) as action_scope:
response: EndpointResult = await func(
self,
request,
*args,
action_scope,
extra_conditions=extra_conditions,
**kwargs,
)

if func.__name__ == "create_new_wallet" and "transactions" not in response:
# unfortunately, this API isn't solely a tx endpoint
return response

unsigned_txs = await self.service.wallet_state_manager.gather_signing_info_for_txs(
action_scope.side_effects.transactions
)

if request.get("CHIP-0029", False):
response["unsigned_transactions"] = [
json_serialize_with_clvm_streamable(
tx,
translation_layer=(
ALL_TRANSLATION_LAYERS[request["translation"]] if "translation" in request else None
),
)
for tx in unsigned_txs
]
else:
response["unsigned_transactions"] = [tx.to_json_dict() for tx in unsigned_txs]

response["transactions"] = [
TransactionRecord.to_json_dict_convenience(tx, self.service.config)
for tx in action_scope.side_effects.transactions
]

# Some backwards compatibility code here because transaction information being returned was not uniform
# until the "transactions" key was applied to all of them. Unfortunately, since .add_pending_transactions
# now applies transformations to the transactions, we have to special case edit all of the previous
# spots where the information was being surfaced outside of the knowledge of this wrapper.
new_txs = action_scope.side_effects.transactions
if "transaction" in response:
if (
func.__name__ == "create_new_wallet" and request["wallet_type"] == "pool_wallet"
) or func.__name__ in {"pw_join_pool", "pw_self_pool", "pw_absorb_rewards"}:
# Theses RPCs return not "convenience" for some reason
response["transaction"] = new_txs[-1].to_json_dict()
else:
response["transaction"] = response["transactions"][0]
if "tx_record" in response:
response["tx_record"] = response["transactions"][0]
if "fee_transaction" in response:
# Theses RPCs return not "convenience" for some reason
fee_transactions = [tx for tx in new_txs if tx.wallet_id == 1]
if len(fee_transactions) == 0:
response["fee_transaction"] = None
else:
response["fee_transaction"] = fee_transactions[0].to_json_dict()
if "transaction_id" in response:
response["transaction_id"] = new_txs[0].name
if "transaction_ids" in response:
response["transaction_ids"] = [
tx.name.hex() for tx in new_txs if tx.type == TransactionType.OUTGOING_CLAWBACK.value
]
if "spend_bundle" in response:
response["spend_bundle"] = WalletSpendBundle.aggregate(
[tx.spend_bundle for tx in new_txs if tx.spend_bundle is not None]
)
if "signed_txs" in response:
response["signed_txs"] = response["transactions"]
if "signed_tx" in response:
response["signed_tx"] = response["transactions"][0]
if "tx" in response:
if func.__name__ == "send_notification":
response["tx"] = response["transactions"][0]
else:
response["tx"] = new_txs[0].to_json_dict()
if "txs" in response:
response["txs"] = [tx.to_json_dict() for tx in new_txs]
if "tx_id" in response:
response["tx_id"] = new_txs[0].name
if "trade_record" in response:
old_offer: Offer = Offer.from_bech32(response["offer"])
signed_coin_spends: list[CoinSpend] = [
coin_spend
for tx in new_txs
if tx.spend_bundle is not None
for coin_spend in tx.spend_bundle.coin_spends
]
involved_coins: list[Coin] = [spend.coin for spend in signed_coin_spends]
signed_coin_spends.extend(
[spend for spend in old_offer._bundle.coin_spends if spend.coin not in involved_coins]
)
new_offer_bundle = WalletSpendBundle(
signed_coin_spends,
AugSchemeMPL.aggregate(
[tx.spend_bundle.aggregated_signature for tx in new_txs if tx.spend_bundle is not None]
),
)
new_offer: Offer = Offer(old_offer.requested_payments, new_offer_bundle, old_offer.driver_dict)
response["offer"] = new_offer.to_bech32()
old_trade_record: TradeRecord = TradeRecord.from_json_dict_convenience(
response["trade_record"], bytes(old_offer).hex()
)
new_trade: TradeRecord = dataclasses.replace(
old_trade_record,
offer=bytes(new_offer),
trade_id=new_offer.name(),
)
response["trade_record"] = new_trade.to_json_dict_convenience()
if (
await self.service.wallet_state_manager.trade_manager.trade_store.get_trade_record(
old_trade_record.trade_id
)
is not None
):
await self.service.wallet_state_manager.trade_manager.trade_store.delete_trade_record(
old_trade_record.trade_id
)
await self.service.wallet_state_manager.trade_manager.save_trade(new_trade, new_offer)
for tx in await self.service.wallet_state_manager.tx_store.get_transactions_by_trade_id(
old_trade_record.trade_id
):
await self.service.wallet_state_manager.tx_store.add_transaction_record(
dataclasses.replace(tx, trade_id=new_trade.trade_id)
)

return response

return rpc_endpoint

return _inner
Loading

0 comments on commit 1062ab4

Please sign in to comment.