From 88401b61ae6b800e0de324e4f133e59bda5d7a79 Mon Sep 17 00:00:00 2001 From: Sergi Delgado Segura Date: Mon, 7 Sep 2020 16:26:57 +0200 Subject: [PATCH] lnnet - Some code improvements and fixes from @bigspider's review --- common/net/bigsize.py | 15 ++++++++------- common/net/bolt1.py | 2 +- common/net/bolt9.py | 8 +++++--- common/net/tlv.py | 2 +- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/common/net/bigsize.py b/common/net/bigsize.py index 74da2b32..c547b796 100644 --- a/common/net/bigsize.py +++ b/common/net/bigsize.py @@ -10,7 +10,7 @@ def encode(value): Raises: :obj:`TypeError`: If the provided value is not an integer. - :obj:`ValueError`: If the provided value is negative or bigger than ``pow(2, 64)``. + :obj:`ValueError`: If the provided value is negative or bigger than ``pow(2, 64) - 1``. """ if not isinstance(value, int): @@ -19,13 +19,13 @@ def encode(value): if value < 0: raise ValueError(f"value must be a positive integer, {value} received") - if value < pow(2, 8) - 3: + if value < 253: return value.to_bytes(1, "big") elif value < pow(2, 16): return b"\xfd" + value.to_bytes(2, "big") elif value < pow(2, 32): return b"\xfe" + value.to_bytes(4, "big") - elif value <= pow(2, 64): + elif value < pow(2, 64): return b"\xff" + value.to_bytes(8, "big") else: raise ValueError("BigSize can only encode up to 8-byte values") @@ -49,8 +49,8 @@ def decode(value): if not isinstance(value, bytes): raise TypeError(f"value must be bytes, {type(value)} received") - if len(value) > 9: - raise ValueError(f"value must be, at most, 9-bytes long, {len(value)} received") + if not 0 < len(value) <= 9: + raise ValueError(f"value must be between 1-9 bytes long (both included), {len(value)} received") if len(value) > 1: prefix = value[0] @@ -59,9 +59,9 @@ def decode(value): prefix = None decoded_value = int.from_bytes(value, "big") - if not prefix and len(value) == 1 and decoded_value < pow(2, 8) - 3: + if not prefix and len(value) == 1 and decoded_value < 253: return decoded_value - elif prefix == 253 and len(value) == 3 and pow(2, 8) - 3 <= decoded_value < pow(2, 16): + elif prefix == 253 and len(value) == 3 and 253 <= decoded_value < pow(2, 16): return decoded_value elif prefix == 254 and len(value) == 5 and pow(2, 16) <= decoded_value < pow(2, 32): return decoded_value @@ -93,6 +93,7 @@ def parse(value): prefix = value[0] + # message length is not explicitly checked here, but wrong length will fail at decode. if prefix < 253: # prefix is actually the value to be parsed return decode(value[0:1]), 1 diff --git a/common/net/bolt1.py b/common/net/bolt1.py index fbef7263..f8020a82 100644 --- a/common/net/bolt1.py +++ b/common/net/bolt1.py @@ -207,7 +207,7 @@ class PingMessage(Message): def __init__(self, num_pong_bytes, ignored_bytes=None): if not 0 <= num_pong_bytes < pow(2, 16): - raise ValueError(f"num_pong_bytes must be between 0 and {pow(2, 16)}") + raise ValueError(f"num_pong_bytes must be between 0 and {pow(2, 16) - 1}") payload = num_pong_bytes.to_bytes(2, "big") diff --git a/common/net/bolt9.py b/common/net/bolt9.py index d2ff3176..89347f48 100644 --- a/common/net/bolt9.py +++ b/common/net/bolt9.py @@ -109,13 +109,16 @@ def from_bytes(cls, features): int_features = int.from_bytes(features, "big") padding = max(2 * len(known_features), int_features.bit_length()) - padding = padding + 1 if padding % 2 else padding + padding += padding % 2 # round up to the nearest even number bit_features = f"{int_features:b}".zfill(padding) bit_pairs = [bit_features[i : i + 2] for i in range(0, len(bit_features), 2)] features_dict = {} for i, pair in enumerate(reversed(bit_pairs)): + if pair == "11": + raise ValueError("Both odd and even bits cannot be set in a pair") + # Known features are stored no matter if they are set or not odd_bit = 2 * i feature_name = known_odd_bits.get(odd_bit) @@ -126,8 +129,7 @@ def from_bytes(cls, features): features_dict[feature_name] = Feature(odd_bit, is_set=True) elif pair == "10": features_dict[feature_name] = Feature(odd_bit + 1, is_set=True) - else: - raise ValueError("Both odd and even bits cannot be set in a pair") + # For unknown features, we only store the ones that are set else: feature_name = f"unknown_{odd_bit}" diff --git a/common/net/tlv.py b/common/net/tlv.py index d7fc3a8c..1870373b 100644 --- a/common/net/tlv.py +++ b/common/net/tlv.py @@ -39,7 +39,7 @@ def __len__(self): return len(self.serialize()) def __eq__(self, other): - return isinstance(other, TLVRecord) and self.value == other.value + return isinstance(other, TLVRecord) and self.serialize() == other.serialize() @classmethod def from_bytes(cls, message):