Skip to content

Commit

Permalink
refact: mypy and flake fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
renatav committed Jan 18, 2025
1 parent 40e6a6e commit 56a9694
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 56 deletions.
7 changes: 4 additions & 3 deletions taf/api/yubikey.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from taf.log import taf_logger
from taf.tuf.keys import get_sslib_key_from_value
from taf.tuf.repository import MAIN_ROLES
import taf.yubikey as yk
import taf.yubikey.yubikey as yk
from taf.yubikey.yubikey_manager import PinManager


@log_on_start(DEBUG, "Exporting public pem from YubiKey", logger=taf_logger)
Expand Down Expand Up @@ -129,7 +130,7 @@ def get_yk_roles(path: str) -> Dict:
reraise=True,
)
def setup_signing_yubikey(
certs_dir: Optional[str] = None, key_size: int = 2048
pin_manager: PinManager, certs_dir: Optional[str] = None, key_size: int = 2048
) -> None:
"""
Delete everything from the inserted YubiKey, generate a new key and copy it to the YubiKey.
Expand All @@ -155,7 +156,7 @@ def setup_signing_yubikey(
pin_repeat=True,
prompt_message="Please insert the new Yubikey and press ENTER",
)
key = yk.setup_new_yubikey(serial_num, key_size=key_size)
key = yk.setup_new_yubikey(pin_manager, serial_num, key_size=key_size)
yk.export_yk_certificate(certs_dir, key)


Expand Down
4 changes: 3 additions & 1 deletion taf/auth_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(
self.conf_directory_root = conf_directory_root_path.resolve()
self.out_of_band_authentication = out_of_band_authentication
self._storage = GitStorageBackend()
self._tuf_repository = TUFRepository(self.path, storage=self._storage, pin_manager=pin_manager)
self._tuf_repository = TUFRepository(
self.path, storage=self._storage, pin_manager=pin_manager
)

def __getattr__(self, item):
"""Delegate attribute lookup to TUFRepository instance"""
Expand Down
40 changes: 29 additions & 11 deletions taf/keys.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import defaultdict
from functools import partial
from logging import INFO
from typing import Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -218,7 +217,11 @@ def _load_and_append_yubikeys(
signer = YkSigner(
public_key,
serial_num,
partial(yk.yk_secrets_handler, pin_manager=taf_repo.pin_manager, serial_num=serial_num),
partial(
yk.yk_secrets_handler,
pin_manager=taf_repo.pin_manager,
serial_num=serial_num,
),
key_name=key_name,
)
signers_yubikeys.append(signer)
Expand Down Expand Up @@ -420,8 +423,12 @@ def _setup_yubikey_roles_keys(
signer = YkSigner(
public_key,
serial_num,
partial(yk.yk_secrets_handler, pin_manager=auth_repo.pin_manager, serial_num=serial_num),
key_name=key_name,
partial(
yk.yk_secrets_handler,
pin_manager=auth_repo.pin_manager,
serial_num=serial_num,
),
key_name=key_name,
)
signers.append(signer)
keyid_name_mapping[_get_legacy_keyid(public_key)] = key_name
Expand All @@ -437,15 +444,22 @@ def _setup_yubikey_roles_keys(
):
continue
serial_num = _load_and_verify_yubikey(
yubikeys, role.name, key_name, public_key
role.name,
key_name,
public_key,
taf_repo=auth_repo,
)
if serial_num:
loaded_keys_num += 1
loaded_keys.append(key_name)
signer = YkSigner(
public_key,
partial(yk.yk_secrets_handler, pin_manager=auth_repo.pin_manager, serial_num=serial_num),
key_name=key_name
partial(
yk.yk_secrets_handler,
pin_manager=auth_repo.pin_manager,
serial_num=serial_num,
),
key_name=key_name,
)
signers.append(signer)
if loaded_keys_num == role.threshold:
Expand Down Expand Up @@ -572,26 +586,30 @@ def _setup_yubikey(
print("Key already loaded. Please insert a different YubiKey")
else:
if not use_existing:
key = yk.setup_new_yubikey(serial_num, scheme, key_size=key_size)
key = yk.setup_new_yubikey(
auth_repo.pin_manager, serial_num, scheme, key_size=key_size
)

if certs_dir is not None:
yk.export_yk_certificate(certs_dir, key, serial=serial_num)
return key, serial_num


def _load_and_verify_yubikey(
yubikeys: Optional[Dict], role_name: str, key_name: str, public_key
role_name: str,
key_name: str,
public_key,
taf_repo: TUFRepository,
) -> Optional[str]:
if not click.confirm(f"Sign using {key_name} Yubikey?"):
return None
while True:
yk_public_key, _ = yk.yubikey_prompt(
key_name,
role_name,
taf_repo=None,
taf_repo=taf_repo,
registering_new_key=True,
creating_new_key=False,
loaded_yubikeys=yubikeys,
pin_confirm=True,
pin_repeat=True,
)
Expand Down
2 changes: 1 addition & 1 deletion taf/tools/yubikey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from taf.exceptions import YubikeyError
from taf.repository_utils import find_valid_repository
from taf.tools.cli import catch_cli_exception
from taf.yubikey import list_connected_yubikeys
from taf.yubikey.yubikey import list_connected_yubikeys


def check_pin_command():
Expand Down
7 changes: 6 additions & 1 deletion taf/tuf/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ class YkSigner(Signer):
_SECRET_PROMPT = "pin"

def __init__(
self, public_key: SSlibKey, serial_num: str, pin_handler: SecretsHandler, key_name: str
self,
public_key: SSlibKey,
serial_num: str,
pin_handler: SecretsHandler,
key_name: str,
):

self._public_key = public_key
Expand Down Expand Up @@ -227,6 +231,7 @@ def import_(cls) -> SSlibKey:
def sign(self, payload: bytes) -> Signature:
pin = self._pin_handler(self._SECRET_PROMPT)
from taf.yubikey.yubikey import sign_piv_rsa_pkcs1v15, verify_yk_inserted

verify_yk_inserted(self.serial_num, self.key_name)
sig = sign_piv_rsa_pkcs1v15(payload, pin, serial=self.serial_num)
return Signature(self.public_key.keyid, sig.hex())
Expand Down
35 changes: 19 additions & 16 deletions taf/tuf/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ def calculate_length(self, md: Metadata) -> int:
data = md.to_bytes(serializer=self.serializer)
return len(data)


def check_if_keys_loaded(self, role_name: str) -> bool:
"""
Check if at least a threshold of signers of the specified role
Expand Down Expand Up @@ -536,19 +535,7 @@ def create(
sn = Snapshot()
sn.meta["root.json"] = MetaFile(1)

public_keys = {
role_name: {
_get_legacy_keyid(signer.public_key): signer.public_key
for signer in role_signers
}
for role_name, role_signers in signers.items()
}
if additional_verification_keys:
for role_name, roles_public_keys in additional_verification_keys.items():
for public_key in roles_public_keys:
key_id = _get_legacy_keyid(public_key)
if key_id not in public_keys[role_name]:
public_keys[role_name][key_id] = public_key
public_keys = self._process_keys(signers, additional_verification_keys)

for role in RolesIterator(roles_keys_data.roles, include_delegations=False):
if signers.get(role.name) is None:
Expand All @@ -559,7 +546,9 @@ def create(
for public_key in public_keys[role.name].values():
key_id = _get_legacy_keyid(public_key)
if key_id in self.keys_name_mappings:
public_key.unrecognized_fields["name"] = self.keys_name_mappings[key_id]
public_key.unrecognized_fields["name"] = self.keys_name_mappings[
key_id
]
root.add_key(public_key, role.name)
root.roles[role.name].threshold = role.threshold

Expand Down Expand Up @@ -606,11 +595,25 @@ def create(
signed.version = 0 # `close` will bump to initial valid verison 1
self.close(name, Metadata(signed))

def _process_keys(self, signers, additional_verification_keys):
public_keys = {}
for role_name, role_signers in signers.items():
public_keys[role_name] = {}
for signer in role_signers:
key_id = self._get_legacy_keyid(signer.public_key)
public_keys[role_name][key_id] = signer.public_key

if additional_verification_keys:
for role_name, keys in additional_verification_keys.items():
for public_key in keys:
key_id = self._get_legacy_keyid(public_key)
public_keys[role_name][key_id] = public_key
return public_keys

def create_delegated_roles(
self,
roles_data: List[TargetsRole],
signers: Dict[str, List[CryptoSigner]],
key_name_mappings: Optional[Dict[str, str]] = None,
) -> Tuple[List, List]:
"""
Create a new delegated roles, signes them using the provided signers and
Expand Down
28 changes: 13 additions & 15 deletions taf/yubikey/yubikey.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
from contextlib import contextmanager
from functools import wraps
from collections import defaultdict
from getpass import getpass
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Callable, Optional

import click
from cryptography import x509
Expand All @@ -15,6 +14,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa, padding

from taf.tuf.keys import get_sslib_key_from_value
from taf.yubikey.yubikey_manager import PinManager
from ykman.device import list_all_devices
from yubikit.core.smartcard import SmartCardConnection
from ykman.piv import (
Expand Down Expand Up @@ -275,8 +275,6 @@ def list_connected_yubikeys():
print(f" Form Factor: {info.form_factor}")


# TODO
# need to pass in multiple key names
def _read_and_check_yubikeys(
key_name,
role,
Expand All @@ -299,7 +297,11 @@ def _read_and_check_yubikeys(
try:
serials = get_serial_num()
if require_single_yubikey:
not_loaded = [serial for serial in serials if not taf_repo.yubikey_store.is_loaded(serial)]
not_loaded = [
serial
for serial in serials
if not taf_repo.yubikey_store.is_loaded(serial)
]
if len(not_loaded) > 1:
print("\nPlease insert only one YubiKey\n")
return None
Expand Down Expand Up @@ -340,15 +342,8 @@ def _read_and_check_yubikeys(
# but the key name still needs to be added to the key id mapping dictionary
taf_repo.yubikey_store.add_key_data(key_name, serial_num, public_key)

# if role is not None:
# if loaded_yubikeys is None:
# loaded_yubikeys = {serial_num: [role]}
# else:
# loaded_yubikeys.setdefault(serial_num, []).append(role)

yubikeys.append((public_key, serial_num))

# TODO error messages
return yubikeys


Expand Down Expand Up @@ -466,9 +461,12 @@ def setup(


def setup_new_yubikey(
serial_num, scheme=DEFAULT_RSA_SIGNATURE_SCHEME, key_size=2048
pin_manager: PinManager,
serial_num: str,
scheme: Optional[str] = DEFAULT_RSA_SIGNATURE_SCHEME,
key_size: Optional[int] = 2048,
) -> SSlibKey:
pin = get_key_pin(serial_num)
pin = pin_manager.get_pin(serial_num)
cert_cn = input("Enter key holder's name: ")
print("Generating key, please wait...")
pub_key_pem = setup(
Expand All @@ -480,7 +478,6 @@ def setup_new_yubikey(


def verify_yk_inserted(serial_num, key_name):

def _check_if_yk_inserted():
try:
serials = get_serial_num()
Expand All @@ -493,6 +490,7 @@ def _check_if_yk_inserted():
prompt_message = f"Please insert {key_name} YubiKey and press ENTER"
getpass(prompt_message)


def yubikey_prompt(
key_name,
role=None,
Expand Down
19 changes: 11 additions & 8 deletions taf/yubikey/yubikey_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
import contextlib
from typing import Any, Dict, Optional, Tuple
from typing import Tuple
from taf.tuf.keys import SSlibKey


Expand All @@ -10,18 +10,22 @@ def __init__(self):
self._yubikeys_data = defaultdict(dict)

def is_loaded(self, serial_number):
return any(data["serial"] == serial_number for data in self._yubikeys_data.values())
return any(
data["serial"] == serial_number for data in self._yubikeys_data.values()
)

def is_key_name_loaded(self, key_name: str) -> bool:
"""Check if the key name is already loaded."""
return key_name in self._yubikeys_data

def add_key_data(self, key_name: str, serial_num: str, public_key: SSlibKey) -> None:
def add_key_data(
self, key_name: str, serial_num: str, public_key: SSlibKey
) -> None:
"""Add data associated with a YubiKey."""
if not self.is_key_name_loaded(key_name):
self._yubikeys_data[key_name] = {
"serial": serial_num,
"public_key": public_key
"public_key": public_key,
}

def get_key_data(self, key_name: str) -> Tuple[str, SSlibKey]:
Expand All @@ -37,8 +41,7 @@ def remove_key_data(self, key_name: str) -> bool:
return False


class PinManager():

class PinManager:
def __init__(self):
self._pins = {}

Expand All @@ -54,7 +57,6 @@ def get_pin(self, serial_number):
return self._pins.get(serial_number)



@contextlib.contextmanager
def manage_pins():
pin_manager = PinManager()
Expand All @@ -67,6 +69,7 @@ def manage_pins():
def pin_managed(func):
def wrapper(*args, **kwargs):
with manage_pins() as pin_manager:
kwargs['pin_manager'] = pin_manager
kwargs["pin_manager"] = pin_manager
return func(*args, **kwargs)

return wrapper

0 comments on commit 56a9694

Please sign in to comment.