Skip to content

Commit

Permalink
debuglink and tests fixes
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
mmilata authored and M1nd3r committed Feb 15, 2025
1 parent 44ccf7a commit 186034c
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 73 deletions.
2 changes: 1 addition & 1 deletion python/src/trezorlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_session(

if isinstance(self.protocol, ProtocolV1Channel):
return SessionV1.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO
raise NotImplementedError

def resume_session(self, session: Session):
"""
Expand Down
63 changes: 13 additions & 50 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def __init__(self, debuglink: DebugLink) -> None:

def clear(self) -> None:
self.pins: t.Iterator[str] | None = None
self.passphrase = ""
self.passphrase = None
self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None
] = None
Expand Down Expand Up @@ -848,7 +848,7 @@ def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
except StopIteration:
raise AssertionError("PIN sequence ended prematurely")

def get_passphrase(self, available_on_device: bool) -> str:
def get_passphrase(self, available_on_device: bool) -> str | None | object:
self.debuglink.snapshot_legacy()
return self.passphrase

Expand Down Expand Up @@ -968,6 +968,10 @@ def client(self) -> TrezorClientDebugLink:
def id(self) -> bytes:
return self._session.id

@property
def passphrase(self) -> str | None | object:
return self._session.passphrase

def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
Expand Down Expand Up @@ -1090,7 +1094,6 @@ def reset_debug_features(self) -> None:
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase

def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
Expand Down Expand Up @@ -1224,7 +1227,6 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
# and know the supported debug capabilities
self.debug.model = self.model
self.debug.version = self.version
self.passphrase: str | None = None

@property
def layout_type(self) -> LayoutType:
Expand Down Expand Up @@ -1308,7 +1310,7 @@ def send_passphrase(
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
# session.session_id = resp.state
session._session.id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp

Expand All @@ -1317,12 +1319,16 @@ def send_passphrase(
return send_passphrase(None, None)

try:
if session.passphrase is None and isinstance(session, SessionV1):
if isinstance(session, SessionV1) or isinstance(
session, SessionDebugWrapper
):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
if passphrase is None:
passphrase = session.passphrase
else:
passphrase = session.passphrase
raise NotImplementedError
except Cancelled:
session.call_raw(messages.Cancel())
raise
Expand Down Expand Up @@ -1376,33 +1382,6 @@ def get_session(
passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano, session_id)

def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")

self.filters[message_type] = callback

def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg

def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None:
Expand Down Expand Up @@ -1534,27 +1513,11 @@ def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""
self.ui.pins = iter(pins)

def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase)

def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")

def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = self.get_seedless_session()._read()
resp = self._filter_message(resp)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp

def _raw_write(self, msg: protobuf.MessageType) -> None:
return self.get_seedless_session()._write(self._filter_message(msg))

@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
Expand Down
1 change: 0 additions & 1 deletion tests/click_tests/test_passphrase_delizia.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def prepare_passphrase_dialogue(
) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink()
device_handler.run_with_session(get_test_address) # type: ignore
# TODO
assert debug.read_layout().main_component() == "PassphraseKeyboard"

# Resetting the category as it could have been changed by previous tests
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,6 @@ def _client_unlocked(
if request.node.get_closest_marker("experimental"):
apply_settings(session, experimental_features=True)

if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])

# TODO _raw_client.clear_session()

yield _raw_client
Expand All @@ -391,7 +388,10 @@ def session(
session = _client_unlocked.get_seedless_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = _client_unlocked.passphrase or ""
passphrase = ""
marker = request.node.get_closest_marker("setup_client")
if marker and isinstance(marker.kwargs.get("passphrase"), str):
passphrase = marker.kwargs["passphrase"]
if _client_unlocked._setup_pin is not None:
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
session = _client_unlocked.get_session(
Expand Down
2 changes: 1 addition & 1 deletion tests/device_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def run_with_session(
raise RuntimeError("Wait for previous task first")

# wait for the first UI change triggered by the task running in the background
session = self.client.get_session()
with self.debuglink().wait_for_layout_change():
session = self.client.get_session()
self.task = self._pool.submit(function, session, *args, **kwargs)

def run_with_provided_session(
Expand Down
4 changes: 0 additions & 4 deletions tests/device_tests/test_msg_loaddevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session: Session = session.client.get_session(passphrase=passphrase_nfkd)
session.client.use_passphrase(passphrase_nfkd) # TODO is needed?
address_nfkd = get_test_address(session)

device.wipe(session)
Expand All @@ -139,7 +138,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfc)
session.client.use_passphrase(passphrase_nfc) # TODO is needed?
address_nfc = get_test_address(session)

device.wipe(session)
Expand All @@ -154,7 +152,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfkc)
session.client.use_passphrase(passphrase_nfkc) # TODO is needed?
address_nfkc = get_test_address(session)

device.wipe(session)
Expand All @@ -169,7 +166,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfd)
session.client.use_passphrase(passphrase_nfd) # TODO is needed?
address_nfd = get_test_address(session)
assert address_nfkd == address_nfc
assert address_nfkd == address_nfkc
Expand Down
1 change: 0 additions & 1 deletion tests/device_tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def test_session_recycling(client: Client):
messages.Address,
]
)
client.use_passphrase("TREZOR")
_ = get_test_address(session)
# address = get_test_address(session)

Expand Down
10 changes: 0 additions & 10 deletions tests/device_tests/test_session_id_and_passphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pytest

from trezorlib import device, exceptions, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
Expand Down Expand Up @@ -54,7 +53,6 @@
def _get_xpub(
session: Session,
expected_passphrase_req: bool = False,
passphrase_v1: str | None = None,
):
"""Get XPUB and check that the appropriate passphrase flow has happened."""
if expected_passphrase_req:
Expand All @@ -66,11 +64,6 @@ def _get_xpub(
]
else:
expected_responses = [messages.PublicKey]
if (
passphrase_v1 is not None
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
):
session.passphrase = passphrase_v1

with session:
session.set_expected_responses(expected_responses)
Expand Down Expand Up @@ -228,7 +221,6 @@ def test_max_sessions_with_passphrases(client: Client):
_get_xpub(
resumed_session,
expected_passphrase_req=True,
passphrase_v1="whatever",
) # passphrase is prompted


Expand Down Expand Up @@ -435,7 +427,6 @@ def input_flow():
messages.PublicKey,
]
)
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey)
xpub_hidden_passphrase = result.xpub
Expand Down Expand Up @@ -471,7 +462,6 @@ def input_flow():
messages.PublicKey,
]
)
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey)
xpub_shown_passphrase = result.xpub
Expand Down
2 changes: 1 addition & 1 deletion tests/upgrade_tests/test_firmware_upgrades.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
address = btc.get_address(session, "Bitcoin", PATH)
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(new_session=True))
new_session = emu.client.get_session(passphrase="TREZOR")
new_session = Session(emu.client.get_session(passphrase="TREZOR"))
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)

assert emu.client.features.backup_availability == BackupAvailability.Required
Expand Down

0 comments on commit 186034c

Please sign in to comment.