Improve streamable (Chia-Network#3031)
* Avoid importing `test_constants` as it takes a long time.

* Factor out `parse_*` functions.

* First crack at refactoring `Streamable.parse`.

* Don't add `_parse_functions` attribute to `Streamable`.

This no longer requires an extra `_parse_functions` attribute on a
`Streamable`, as that could confuse serializers or other functions
that use `__annotations__` (a sketch of this approach precedes the file
diffs below).

* Fix lint problems with `black`.

* Fix `parse_tuple`.

* Defer some parsing failures to parse time rather than class-creation time.

* Tidy up & remove some obsolete stuff.

* Decorate `RequestBlocks` as `streamable`.

* Fix wrong uses of Streamable class

Revert an earlier commit and error out on class creation in case a
Streamable subclass is instantiated incorrectly, e.g. containing a
non-serializable member.

Fix cases where Streamable parent class was forgotten.

* Fix wrong types when creating DerivationRecord and WalletCoinRecord

* Additional unit tests for streamable parsers.

* Add type annotations (Chia-Network#3222).

Co-authored-by: Rostislav <[email protected]>
Co-authored-by: arvidn <[email protected]>
3 people authored Apr 30, 2021
1 parent 9779286 commit b084813
Showing 13 changed files with 321 additions and 134 deletions.
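Before the diffs, here is a minimal, self-contained sketch of the design the commit message describes: each field's parser is built once, when the decorator runs, and kept in a module-level dict keyed by the class instead of as a `_parse_functions` attribute on the class. The toy names (`toy_streamable`, `ToyMessage`, `toy_parse`) are invented for illustration and are not part of the chia codebase.

import io
import struct
from dataclasses import dataclass
from typing import Any, BinaryIO, Callable, Dict, List, Type

PARSE_FUNCTIONS: Dict[Type, List[Callable[[BinaryIO], Any]]] = {}


def toy_parser_for(f_type: Type) -> Callable[[BinaryIO], Any]:
    # Only the two field types used below are supported in this toy version.
    if f_type is int:
        return lambda f: struct.unpack(">I", f.read(4))[0]  # 4-byte big-endian
    if f_type is bool:
        return lambda f: f.read(1) == b"\x01"
    raise NotImplementedError(f"no parser for {f_type}")


def toy_streamable(cls: Type) -> Type:
    # Precompute one parser per annotated field; the class gains no new attribute.
    PARSE_FUNCTIONS[cls] = [toy_parser_for(t) for t in cls.__annotations__.values()]
    return cls


@toy_streamable
@dataclass(frozen=True)
class ToyMessage:
    height: int
    is_block: bool


def toy_parse(cls: Type, f: BinaryIO) -> Any:
    # Mirrors the reworked Streamable.parse: just run the precomputed parsers in order.
    return cls(*[p(f) for p in PARSE_FUNCTIONS[cls]])


print(toy_parse(ToyMessage, io.BytesIO(b"\x00\x00\x00\x2a\x01")))
# ToyMessage(height=42, is_block=True)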
1 change: 1 addition & 0 deletions chia/protocols/full_node_protocol.py
@@ -76,6 +76,7 @@ class RejectBlock(Streamable):
 
 
 @dataclass(frozen=True)
+@streamable
 class RequestBlocks(Streamable):
     start_height: uint32
     end_height: uint32
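The hunk above, together with the timelord changes below, pins down the pattern every wire message is expected to follow: the frozen dataclass decorator, the `streamable` decorator, and the `Streamable` base class must all be present. A hedged sketch of that pattern (the class name `ExampleMessage` and its fields are invented; the round-trip at the end assumes the usual `bytes(...)` / `from_bytes` helpers on `Streamable`):

from dataclasses import dataclass

from chia.util.ints import uint32
from chia.util.streamable import Streamable, streamable


@dataclass(frozen=True)
@streamable
class ExampleMessage(Streamable):
    start_height: uint32
    end_height: uint32


msg = ExampleMessage(uint32(0), uint32(10))
assert ExampleMessage.from_bytes(bytes(msg)) == msg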
4 changes: 2 additions & 2 deletions chia/protocols/timelord_protocol.py
@@ -74,7 +74,7 @@ class NewEndOfSubSlotVDF(Streamable):
 
 @dataclass(frozen=True)
 @streamable
-class RequestCompactProofOfTime:
+class RequestCompactProofOfTime(Streamable):
     new_proof_of_time: VDFInfo
     header_hash: bytes32
     height: uint32
@@ -83,7 +83,7 @@ class RequestCompactProofOfTime:
 
 @dataclass(frozen=True)
 @streamable
-class RespondCompactProofOfTime:
+class RespondCompactProofOfTime(Streamable):
     vdf_info: VDFInfo
     vdf_proof: VDFProof
     header_hash: bytes32
3 changes: 1 addition & 2 deletions chia/types/generator_types.py
@@ -21,8 +21,7 @@ class GeneratorArg(Streamable):
 
 
 @dataclass(frozen=True)
-@streamable
-class CompressorArg(Streamable):
+class CompressorArg:
     """`CompressorArg` is used as input to the Block Compressor"""
 
     block_height: uint32
154 changes: 99 additions & 55 deletions chia/util/streamable.py
@@ -7,7 +7,7 @@
 import pprint
 import sys
 from enum import Enum
-from typing import Any, BinaryIO, Dict, List, Tuple, Type
+from typing import Any, BinaryIO, Dict, List, Tuple, Type, Callable, Optional
 
 from blspy import G1Element, G2Element, PrivateKey
 
@@ -125,6 +125,9 @@ def recurse_jsonify(d):
     return d
 
 
+PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
+
+
 def streamable(cls: Any):
     """
     This is a decorator for class definitions. It applies the strictdataclass decorator,
@@ -152,80 +155,121 @@ def streamable(cls: Any):
     """
 
     cls1 = strictdataclass(cls)
-    return type(cls.__name__, (cls1, Streamable), {})
+    t = type(cls.__name__, (cls1, Streamable), {})
+
+    parse_functions = []
+    try:
+        fields = cls1.__annotations__  # pylint: disable=no-member
+    except Exception:
+        fields = {}
+
+    for _, f_type in fields.items():
+        parse_functions.append(cls.function_to_parse_one_item(f_type))
+
+    PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = parse_functions
+    return t
 
 
+def parse_bool(f: BinaryIO) -> bool:
+    bool_byte = f.read(1)
+    assert bool_byte is not None and len(bool_byte) == 1  # Checks for EOF
+    if bool_byte == bytes([0]):
+        return False
+    elif bool_byte == bytes([1]):
+        return True
+    else:
+        raise ValueError("Bool byte must be 0 or 1")
+
+
+def parse_optional(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> Optional[Any]:
+    is_present_bytes = f.read(1)
+    assert is_present_bytes is not None and len(is_present_bytes) == 1  # Checks for EOF
+    if is_present_bytes == bytes([0]):
+        return None
+    elif is_present_bytes == bytes([1]):
+        return parse_inner_type_f(f)
+    else:
+        raise ValueError("Optional must be 0 or 1")
+
+
+def parse_bytes(f: BinaryIO) -> bytes:
+    list_size_bytes = f.read(4)
+    assert list_size_bytes is not None and len(list_size_bytes) == 4  # Checks for EOF
+    list_size: uint32 = uint32(int.from_bytes(list_size_bytes, "big"))
+    bytes_read = f.read(list_size)
+    assert bytes_read is not None and len(bytes_read) == list_size
+    return bytes_read
+
+
+def parse_list(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> List[Any]:
+    full_list: List = []
+    # wjb assert inner_type != get_args(List)[0]
+    list_size_bytes = f.read(4)
+    assert list_size_bytes is not None and len(list_size_bytes) == 4  # Checks for EOF
+    list_size = uint32(int.from_bytes(list_size_bytes, "big"))
+    for list_index in range(list_size):
+        full_list.append(parse_inner_type_f(f))
+    return full_list
+
+
+def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[Callable[[BinaryIO], Any]]) -> Tuple[Any, ...]:
+    full_list = []
+    for parse_f in list_parse_inner_type_f:
+        full_list.append(parse_f(f))
+    return tuple(full_list)
+
+
+def parse_size_hints(f: BinaryIO, f_type: Type, bytes_to_read: int) -> Any:
+    bytes_read = f.read(bytes_to_read)
+    assert bytes_read is not None and len(bytes_read) == bytes_to_read
+    return f_type.from_bytes(bytes_read)
+
+
+def parse_str(f: BinaryIO) -> str:
+    str_size_bytes = f.read(4)
+    assert str_size_bytes is not None and len(str_size_bytes) == 4  # Checks for EOF
+    str_size: uint32 = uint32(int.from_bytes(str_size_bytes, "big"))
+    str_read_bytes = f.read(str_size)
+    assert str_read_bytes is not None and len(str_read_bytes) == str_size  # Checks for EOF
+    return bytes.decode(str_read_bytes, "utf-8")
+
+
 class Streamable:
     @classmethod
-    def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO):  # type: ignore
+    def function_to_parse_one_item(cls: Type[cls.__name__], f_type: Type):  # type: ignore
+        """
+        This function returns a function taking one argument `f: BinaryIO` that parses
+        and returns a value of the given type.
+        """
         inner_type: Type
         if f_type is bool:
-            bool_byte = f.read(1)
-            assert bool_byte is not None and len(bool_byte) == 1  # Checks for EOF
-            if bool_byte == bytes([0]):
-                return False
-            elif bool_byte == bytes([1]):
-                return True
-            else:
-                raise ValueError("Bool byte must be 0 or 1")
+            return parse_bool
         if is_type_SpecificOptional(f_type):
             inner_type = get_args(f_type)[0]
-            is_present_bytes = f.read(1)
-            assert is_present_bytes is not None and len(is_present_bytes) == 1  # Checks for EOF
-            if is_present_bytes == bytes([0]):
-                return None
-            elif is_present_bytes == bytes([1]):
-                return cls.parse_one_item(inner_type, f)  # type: ignore
-            else:
-                raise ValueError("Optional must be 0 or 1")
+            parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
+            return lambda f: parse_optional(f, parse_inner_type_f)
         if hasattr(f_type, "parse"):
-            return f_type.parse(f)
+            return f_type.parse
         if f_type == bytes:
-            list_size_bytes = f.read(4)
-            assert list_size_bytes is not None and len(list_size_bytes) == 4  # Checks for EOF
-            list_size: uint32 = uint32(int.from_bytes(list_size_bytes, "big"))
-            bytes_read = f.read(list_size)
-            assert bytes_read is not None and len(bytes_read) == list_size
-            return bytes_read
+            return parse_bytes
         if is_type_List(f_type):
             inner_type = get_args(f_type)[0]
-            full_list: List[inner_type] = []  # type: ignore
-            # wjb assert inner_type != get_args(List)[0]  # type: ignore
-            list_size_bytes = f.read(4)
-            assert list_size_bytes is not None and len(list_size_bytes) == 4  # Checks for EOF
-            list_size = uint32(int.from_bytes(list_size_bytes, "big"))
-            for list_index in range(list_size):
-                full_list.append(cls.parse_one_item(inner_type, f))  # type: ignore
-            return full_list
+            parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
+            return lambda f: parse_list(f, parse_inner_type_f)
         if is_type_Tuple(f_type):
             inner_types = get_args(f_type)
-            full_list = []
-            for inner_type in inner_types:
-                full_list.append(cls.parse_one_item(inner_type, f))  # type: ignore
-            return tuple(full_list)
+            list_parse_inner_type_f = [cls.function_to_parse_one_item(_) for _ in inner_types]
+            return lambda f: parse_tuple(f, list_parse_inner_type_f)
         if hasattr(f_type, "from_bytes") and f_type.__name__ in size_hints:
            bytes_to_read = size_hints[f_type.__name__]
-            bytes_read = f.read(bytes_to_read)
-            assert bytes_read is not None and len(bytes_read) == bytes_to_read
-            return f_type.from_bytes(bytes_read)
+            return lambda f: parse_size_hints(f, f_type, bytes_to_read)
         if f_type is str:
-            str_size_bytes = f.read(4)
-            assert str_size_bytes is not None and len(str_size_bytes) == 4  # Checks for EOF
-            str_size: uint32 = uint32(int.from_bytes(str_size_bytes, "big"))
-            str_read_bytes = f.read(str_size)
-            assert str_read_bytes is not None and len(str_read_bytes) == str_size  # Checks for EOF
-            return bytes.decode(str_read_bytes, "utf-8")
-        raise RuntimeError(f"Type {f_type} does not have parse")
+            return parse_str
+        raise NotImplementedError(f"Type {f_type} does not have parse")
 
     @classmethod
     def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:  # type: ignore
-        values = []
-        try:
-            fields = cls.__annotations__  # pylint: disable=no-member
-        except Exception:
-            fields = {}
-        for _, f_type in fields.items():
-            values.append(cls.parse_one_item(f_type, f))  # type: ignore
+        values = [parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls]]
         return cls(*values)
 
     def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
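Because the new `parse_*` helpers are plain module-level functions, they compose directly and can be exercised in isolation, which is presumably what the new parser unit tests do. An illustrative round-trip over hand-built byte strings (assuming the helpers are importable from `chia.util.streamable` as defined above):

import io

from chia.util.streamable import parse_bool, parse_list, parse_optional

# A List[bool] is a 4-byte big-endian length followed by one byte per element.
assert parse_list(io.BytesIO(bytes([0, 0, 0, 2, 1, 0])), parse_bool) == [True, False]

# An Optional[bool] is a presence byte, then the inner value when present.
assert parse_optional(io.BytesIO(bytes([1, 1])), parse_bool) is True
assert parse_optional(io.BytesIO(bytes([0])), parse_bool) is None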
4 changes: 1 addition & 3 deletions chia/wallet/derivation_record.py
@@ -4,13 +4,11 @@
 
 from chia.types.blockchain_format.sized_bytes import bytes32
 from chia.util.ints import uint32
-from chia.util.streamable import Streamable, streamable
 from chia.wallet.util.wallet_types import WalletType
 
 
 @dataclass(frozen=True)
-@streamable
-class DerivationRecord(Streamable):
+class DerivationRecord:
     """
     These are records representing a puzzle hash, which is generated from a
     public key, derivation index, and wallet type. Stored in the puzzle_store.
27 changes: 13 additions & 14 deletions chia/wallet/rl_wallet/rl_wallet.py
@@ -67,9 +67,9 @@ async def create_rl_admin(
         assert unused is not None
 
         private_key = master_sk_to_wallet_sk(wallet_state_manager.private_key, unused)
-        pubkey_bytes: bytes = bytes(private_key.get_g1())
+        pubkey: G1Element = private_key.get_g1()
 
-        rl_info = RLInfo("admin", pubkey_bytes, None, None, None, None, None, None, False)
+        rl_info = RLInfo("admin", bytes(pubkey), None, None, None, None, None, None, False)
         info_as_string = json.dumps(rl_info.to_json_dict())
         wallet_info: Optional[WalletInfo] = await wallet_state_manager.user_store.create_wallet(
             "RL Admin", WalletType.RATE_LIMITED, info_as_string
@@ -81,8 +81,8 @@ async def create_rl_admin(
             [
                 DerivationRecord(
                     unused,
-                    token_bytes(),
-                    pubkey_bytes,
+                    bytes32(token_bytes(32)),
+                    pubkey,
                     WalletType.RATE_LIMITED,
                     wallet_info.id,
                 )
@@ -107,9 +107,9 @@ async def create_rl_user(
 
         private_key = wallet_state_manager.private_key
 
-        pubkey_bytes: bytes = bytes(master_sk_to_wallet_sk(private_key, unused).get_g1())
+        pubkey: G1Element = master_sk_to_wallet_sk(private_key, unused).get_g1()
 
-        rl_info = RLInfo("user", None, pubkey_bytes, None, None, None, None, None, False)
+        rl_info = RLInfo("user", None, bytes(pubkey), None, None, None, None, None, False)
         info_as_string = json.dumps(rl_info.to_json_dict())
         await wallet_state_manager.user_store.create_wallet("RL User", WalletType.RATE_LIMITED, info_as_string)
         wallet_info = await wallet_state_manager.user_store.get_last_wallet()
@@ -122,8 +122,8 @@ async def create_rl_user(
             [
                 DerivationRecord(
                     unused,
-                    token_bytes(),
-                    pubkey_bytes,
+                    bytes32(token_bytes(32)),
+                    pubkey,
                     WalletType.RATE_LIMITED,
                     wallet_info.id,
                 )
@@ -190,7 +190,7 @@ async def admin_create_coin(
         record = DerivationRecord(
             index,
             rl_puzzle_hash,
-            self.rl_info.admin_pubkey,
+            G1Element.from_bytes(self.rl_info.admin_pubkey),
             WalletType.RATE_LIMITED,
             self.id(),
         )
@@ -265,14 +265,13 @@ async def set_user_info(
             raise ValueError(
                 "Cannot create multiple Rate Limited wallets under the same keys. This will change in a future release."
             )
-        index = await self.wallet_state_manager.puzzle_store.index_for_pubkey(
-            G1Element.from_bytes(self.rl_info.user_pubkey)
-        )
+        user_pubkey: G1Element = G1Element.from_bytes(self.rl_info.user_pubkey)
+        index = await self.wallet_state_manager.puzzle_store.index_for_pubkey(user_pubkey)
         assert index is not None
         record = DerivationRecord(
             index,
             rl_puzzle_hash,
-            self.rl_info.user_pubkey,
+            user_pubkey,
             WalletType.RATE_LIMITED,
             self.id(),
         )
@@ -281,7 +280,7 @@
         record2 = DerivationRecord(
             index + 1,
             aggregation_puzzlehash,
-            self.rl_info.user_pubkey,
+            user_pubkey,
             WalletType.RATE_LIMITED,
             self.id(),
        )
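The wallet changes above all go in the same direction: keep the public key as a blspy `G1Element` in memory and convert with `bytes(...)` / `G1Element.from_bytes(...)` only at the storage boundary, instead of handing raw bytes to `DerivationRecord`. A small hedged illustration of that conversion (the seed is made up, and the example assumes blspy's `AugSchemeMPL.key_gen`):

from blspy import AugSchemeMPL, G1Element

private_key = AugSchemeMPL.key_gen(bytes([1] * 32))
pubkey: G1Element = private_key.get_g1()

stored: bytes = bytes(pubkey)                  # 48-byte compressed encoding
assert G1Element.from_bytes(stored) == pubkey  # lossless round-trip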
4 changes: 1 addition & 3 deletions chia/wallet/wallet_action.py
@@ -2,13 +2,11 @@
 from typing import Optional
 
 from chia.util.ints import uint32
-from chia.util.streamable import Streamable, streamable
 from chia.wallet.util.wallet_types import WalletType
 
 
 @dataclass(frozen=True)
-@streamable
-class WalletAction(Streamable):
+class WalletAction:
     """
     This object represents the wallet action as it is stored in the database.
 
4 changes: 1 addition & 3 deletions chia/wallet/wallet_coin_record.py
@@ -3,13 +3,11 @@
 from chia.types.blockchain_format.coin import Coin
 from chia.types.blockchain_format.sized_bytes import bytes32
 from chia.util.ints import uint32
-from chia.util.streamable import Streamable, streamable
 from chia.wallet.util.wallet_types import WalletType
 
 
 @dataclass(frozen=True)
-@streamable
-class WalletCoinRecord(Streamable):
+class WalletCoinRecord:
     """
     These are values that correspond to a CoinName that are used
     in keeping track of the unspent database.
(The diffs for the remaining changed files are not shown here.)
