From 34ec01f65bfb06737eed92a3af7e364ecb6bb91d Mon Sep 17 00:00:00 2001 From: Daniel Elero Date: Mon, 30 Sep 2019 18:58:49 +0200 Subject: [PATCH] Black code formatting --- .flake8 | 4 +- setup.py | 109 ++--- taf/__init__.py | 24 +- taf/auth_repo.py | 279 ++++++------ taf/cli.py | 399 ++++++++++++----- taf/constants.py | 2 +- taf/developer_tool.py | 789 +++++++++++++++++++--------------- taf/exceptions.py | 55 +-- taf/git.py | 663 +++++++++++++++------------- taf/log.py | 54 +-- taf/repositoriesdb.py | 431 +++++++++++-------- taf/repository_tool.py | 587 +++++++++++++------------ taf/settings.py | 5 +- taf/updater/handlers.py | 424 +++++++++--------- taf/updater/updater.py | 533 +++++++++++++---------- taf/utils.py | 196 +++++---- taf/validation.py | 129 +++--- taf/yubikey.py | 236 +++++----- tests/__init__.py | 2 +- tests/conftest.py | 216 +++++----- tests/test_add_targets.py | 175 ++++---- tests/test_repository.py | 26 +- tests/test_repository_tool.py | 217 +++++----- tests/test_updater.py | 540 +++++++++++++---------- tests/test_utils.py | 18 +- tests/test_yubikey.py | 44 +- tests/yubikey_utils.py | 186 ++++---- 27 files changed, 3539 insertions(+), 2804 deletions(-) diff --git a/.flake8 b/.flake8 index d9ad0b409..4c792ef88 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] ignore = E203, E266, E501, W503, F403, F401 -max-line-length = 79 -max-complexity = 18 +max-line-length = 88 +max-complexity = 20 select = B,C,E,F,W,T4,B9 diff --git a/setup.py b/setup.py index c97cd1193..020a03aba 100644 --- a/setup.py +++ b/setup.py @@ -1,100 +1,79 @@ from setuptools import find_packages, setup -PACKAGE_NAME = 'taf' -VERSION = '0.1.7' -AUTHOR = 'Open Law Library' -AUTHOR_EMAIL = 'info@openlawlib.org' -DESCRIPTION = 'Implementation of archival authentication' -KEYWORDS = 'update updater secure authentication archival' -URL = 'https://github.com/openlawlibrary/taf/tree/master' +PACKAGE_NAME = "taf" +VERSION = "0.1.7" +AUTHOR = "Open Law Library" +AUTHOR_EMAIL = "info@openlawlib.org" +DESCRIPTION = "Implementation of archival authentication" +KEYWORDS = "update updater secure authentication archival" +URL = "https://github.com/openlawlibrary/taf/tree/master" -with open('README.md', encoding='utf-8') as file_object: - long_description = file_object.read() +with open("README.md", encoding="utf-8") as file_object: + long_description = file_object.read() packages = find_packages() # Create platform specific wheel # https://stackoverflow.com/a/45150383/9669050 try: - from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + self.root_is_pure = False + - class bdist_wheel(_bdist_wheel): - def finalize_options(self): - _bdist_wheel.finalize_options(self) - self.root_is_pure = False except ImportError: - bdist_wheel = None + bdist_wheel = None -ci_require = [ - "pylint==2.3.1", - "bandit==1.6.0", - "coverage==4.5.3", - "pytest-cov==2.7.1", -] +ci_require = ["pylint==2.3.1", "bandit==1.6.0", "coverage==4.5.3", "pytest-cov==2.7.1"] -dev_require = [ - "autopep8==1.4.4", - "pylint==2.3.1", - "bandit==1.6.0", -] +dev_require = ["autopep8==1.4.4", "pylint==2.3.1", "bandit==1.6.0"] -tests_require = [ - "pytest==4.5.0", -] +tests_require = ["pytest==4.5.0"] -yubikey_require = [ - "yubikey-manager==3.0.0", -] +yubikey_require = ["yubikey-manager==3.0.0"] setup( name=PACKAGE_NAME, version=VERSION, description=DESCRIPTION, 
long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", url=URL, author=AUTHOR, author_email=AUTHOR_EMAIL, keywords=KEYWORDS, packages=packages, - cmdclass={'bdist_wheel': bdist_wheel}, + cmdclass={"bdist_wheel": bdist_wheel}, include_package_data=True, - data_files=[ - ('lib/site-packages/taf', [ - './LICENSE.txt', - './README.md' - ]) - ], + data_files=[("lib/site-packages/taf", ["./LICENSE.txt", "./README.md"])], zip_safe=False, install_requires=[ - 'click==6.7', - 'colorama>=0.3.9' - 'cryptography>=2.3.1', - 'oll-tuf==0.11.2.dev9', + "click==6.7", + "colorama>=0.3.9" "cryptography>=2.3.1", + "oll-tuf==0.11.2.dev9", ], extras_require={ - 'ci': ci_require, - 'test': tests_require, - 'dev': dev_require, - 'yubikey': yubikey_require, + "ci": ci_require, + "test": tests_require, + "dev": dev_require, + "yubikey": yubikey_require, }, tests_require=tests_require, - entry_points={ - 'console_scripts': [ - 'taf = taf.cli:main' - ] - }, + entry_points={"console_scripts": ["taf = taf.cli:main"]}, classifiers=[ - 'Development Status :: 2 - Pre-Alpha', - 'Intended Audience :: Developers', - 'Intended Audience :: Information Technology', - 'Topic :: Security', - 'Topic :: Software Development', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: Implementation :: CPython', - ] + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Topic :: Security", + "Topic :: Software Development", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: Implementation :: CPython", + ], ) diff --git a/taf/__init__.py b/taf/__init__.py index 8e37bcd28..b73550250 100644 --- a/taf/__init__.py +++ b/taf/__init__.py @@ -6,21 +6,21 @@ _PLATFORM = sys.platform -_PLATFORM_LIBS = str((Path(__file__).parent / 'libs').resolve()) +_PLATFORM_LIBS = str((Path(__file__).parent / "libs").resolve()) def _set_env(env_name, path): - try: - os.environ[env_name] += os.pathsep + path - except KeyError: - os.environ[env_name] = path + try: + os.environ[env_name] += os.pathsep + path + except KeyError: + os.environ[env_name] = path -if _PLATFORM == 'darwin': - _set_env('DYLD_LIBRARY_PATH', _PLATFORM_LIBS) -elif _PLATFORM == 'linux': - _set_env('LD_LIBRARY_PATH', _PLATFORM_LIBS) -elif _PLATFORM == 'win32': - _set_env('PATH', _PLATFORM_LIBS) +if _PLATFORM == "darwin": + _set_env("DYLD_LIBRARY_PATH", _PLATFORM_LIBS) +elif _PLATFORM == "linux": + _set_env("LD_LIBRARY_PATH", _PLATFORM_LIBS) +elif _PLATFORM == "win32": + _set_env("PATH", _PLATFORM_LIBS) else: - raise Exception('Platform "{}" is not supported!'.format(_PLATFORM)) + raise Exception('Platform "{}" is not supported!'.format(_PLATFORM)) diff --git a/taf/auth_repo.py b/taf/auth_repo.py index 26278a5d6..abf95adc4 100644 --- a/taf/auth_repo.py +++ b/taf/auth_repo.py @@ -12,146 +12,181 @@ class AuthRepoMixin(object): - LAST_VALIDATED_FILENAME = 'last_validated_commit' + LAST_VALIDATED_FILENAME = "last_validated_commit" - @property - def conf_dir(self): - """ + @property + def conf_dir(self): + """ Returns location of the directory 
which stores the authentication repository's configuration files. That is, the last validated commit. Create the directory if it does not exist. """ - # the repository's name consists of the namespace and name (namespace/name) - # the configuration directory should be _name - last_dir = os.path.basename(os.path.normpath(self.repo_path)) - conf_path = os.path.join(os.path.dirname(self.repo_path), '_{}'.format(last_dir)) - if not os.path.exists(conf_path): - os.makedirs(conf_path) - return conf_path - - @property - def certs_dir(self): - certs_dir = os.path.join(self.repo_path, 'certs') - if not os.path.exists(certs_dir): - os.makedirs(certs_dir) - return certs_dir - - @property - def last_validated_commit(self): - """ + # the repository's name consists of the namespace and name (namespace/name) + # the configuration directory should be _name + last_dir = os.path.basename(os.path.normpath(self.repo_path)) + conf_path = os.path.join( + os.path.dirname(self.repo_path), "_{}".format(last_dir) + ) + if not os.path.exists(conf_path): + os.makedirs(conf_path) + return conf_path + + @property + def certs_dir(self): + certs_dir = os.path.join(self.repo_path, "certs") + if not os.path.exists(certs_dir): + os.makedirs(certs_dir) + return certs_dir + + @property + def last_validated_commit(self): + """ Return the last validated commit of the authentication repository """ - path = os.path.join(self.conf_dir, self.LAST_VALIDATED_FILENAME) - try: - with open(path) as f: - return f.read() - except FileNotFoundError: - return None - - def get_target(self, target_name, commit=None, safely=True): - if commit is None: - commit = self.head_commit_sha() - target_path = (Path(self.targets_path) / target_name).as_posix() - if safely: - return self._safely_get_json(commit, target_path) - else: - return self.get_json(commit, target_path) - - def is_commit_authenticated(self, target_name, commit): - """Checks if passed commit is ever authenticated for given target name. - """ - for auth_commit in reversed(self.all_commits_since_commit()): - target = self.get_target(target_name, auth_commit) - try: - if target['commit'] == commit: - return True - except TypeError: - continue - return False - - def set_last_validated_commit(self, commit): + path = os.path.join(self.conf_dir, self.LAST_VALIDATED_FILENAME) + try: + with open(path) as f: + return f.read() + except FileNotFoundError: + return None + + def get_target(self, target_name, commit=None, safely=True): + if commit is None: + commit = self.head_commit_sha() + target_path = (Path(self.targets_path) / target_name).as_posix() + if safely: + return self._safely_get_json(commit, target_path) + else: + return self.get_json(commit, target_path) + + def is_commit_authenticated(self, target_name, commit): + """Checks if passed commit is ever authenticated for given target name. 
""" + for auth_commit in reversed(self.all_commits_since_commit()): + target = self.get_target(target_name, auth_commit) + try: + if target["commit"] == commit: + return True + except TypeError: + continue + return False + + def set_last_validated_commit(self, commit): + """ Set the last validated commit of the authentication repository """ - path = os.path.join(self.conf_dir, self.LAST_VALIDATED_FILENAME) - logger.debug('Auth repo %s: setting last validated commit to: %s', - self.repo_name, commit) - with open(path, 'w') as f: - f.write(commit) - - def sorted_commits_per_repositories(self, commits): - """Create a list of of subsequent commits per repository + path = os.path.join(self.conf_dir, self.LAST_VALIDATED_FILENAME) + logger.debug( + "Auth repo %s: setting last validated commit to: %s", self.repo_name, commit + ) + with open(path, "w") as f: + f.write(commit) + + def sorted_commits_per_repositories(self, commits): + """Create a list of of subsequent commits per repository keeping in mind that targets metadata file is not updated everytime something is committed to the authentication repo """ - repositories_commits = defaultdict(list) - targets = self.target_commits_at_revisions(commits) - previous_commits = {} - for commit in commits: - for target_path, target_commit in targets[commit].items(): - previous_commit = previous_commits.get(target_path) - if previous_commit is None or target_commit != previous_commit: - repositories_commits[target_path].append(target_commit) - previous_commits[target_path] = target_commit - logger.debug('Auth repo %s: new commits per repositories according to targets.json: %s', - self.repo_name, repositories_commits) - return repositories_commits - - def target_commits_at_revisions(self, commits): - targets = defaultdict(dict) - for commit in commits: - targets_at_revision = self._safely_get_json( - commit, self.metadata_path + '/targets.json') - if targets_at_revision is None: - continue - targets_at_revision = targets_at_revision['signed']['targets'] - - repositories_at_revision = self._safely_get_json(commit, - self.targets_path + '/repositories.json') - if repositories_at_revision is None: - continue - repositories_at_revision = repositories_at_revision['repositories'] - - for target_path in targets_at_revision: - if target_path not in repositories_at_revision: - # we only care about repositories - continue + repositories_commits = defaultdict(list) + targets = self.target_commits_at_revisions(commits) + previous_commits = {} + for commit in commits: + for target_path, target_commit in targets[commit].items(): + previous_commit = previous_commits.get(target_path) + if previous_commit is None or target_commit != previous_commit: + repositories_commits[target_path].append(target_commit) + previous_commits[target_path] = target_commit + logger.debug( + "Auth repo %s: new commits per repositories according to targets.json: %s", + self.repo_name, + repositories_commits, + ) + return repositories_commits + + def target_commits_at_revisions(self, commits): + targets = defaultdict(dict) + for commit in commits: + targets_at_revision = self._safely_get_json( + commit, self.metadata_path + "/targets.json" + ) + if targets_at_revision is None: + continue + targets_at_revision = targets_at_revision["signed"]["targets"] + + repositories_at_revision = self._safely_get_json( + commit, self.targets_path + "/repositories.json" + ) + if repositories_at_revision is None: + continue + repositories_at_revision = repositories_at_revision["repositories"] + + for 
target_path in targets_at_revision: + if target_path not in repositories_at_revision: + # we only care about repositories + continue + try: + target_commit = self.get_json( + commit, self.targets_path + "/" + target_path + ).get("commit") + targets[commit][target_path] = target_commit + except json.decoder.JSONDecodeError: + logger.debug( + "Auth repo %s: target file %s is not a valid json at revision %s", + self.repo_name, + target_path, + commit, + ) + continue + return targets + + def _safely_get_json(self, commit, path): try: - target_commit = \ - self.get_json(commit, self.targets_path + '/' + target_path).get('commit') - targets[commit][target_path] = target_commit + return self.get_json(commit, path) + except CalledProcessError: + logger.info( + "Auth repo %s: %s not available at revision %s", + self.repo_name, + os.path.basename(path), + commit, + ) except json.decoder.JSONDecodeError: - logger.debug('Auth repo %s: target file %s is not a valid json at revision %s', - self.repo_name, target_path, commit) - continue - return targets - - def _safely_get_json(self, commit, path): - try: - return self.get_json(commit, path) - except CalledProcessError: - logger.info('Auth repo %s: %s not available at revision %s', self.repo_name, - os.path.basename(path), commit) - except json.decoder.JSONDecodeError: - logger.info('Auth repo %s: %s not a valid json at revision %s', self.repo_name, - os.path.basename(path), commit) + logger.info( + "Auth repo %s: %s not a valid json at revision %s", + self.repo_name, + os.path.basename(path), + commit, + ) class AuthenticationRepo(AuthRepoMixin, GitRepository): - - def __init__(self, repo_path, metadata_path='metadata', targets_path='targets', repo_urls=None, - additional_info=None, default_branch='master'): - super().__init__(repo_path, repo_urls, additional_info, default_branch) - self.targets_path = targets_path - self.metadata_path = metadata_path + def __init__( + self, + repo_path, + metadata_path="metadata", + targets_path="targets", + repo_urls=None, + additional_info=None, + default_branch="master", + ): + super().__init__(repo_path, repo_urls, additional_info, default_branch) + self.targets_path = targets_path + self.metadata_path = metadata_path class NamedAuthenticationRepo(AuthRepoMixin, NamedGitRepository): - - def __init__(self, root_dir, repo_name, metadata_path='metadata', targets_path='targets', - repo_urls=None, additional_info=None, default_branch='master'): - - super().__init__(root_dir, repo_name, repo_urls, additional_info, - default_branch) - self.targets_path = targets_path - self.metadata_path = metadata_path + def __init__( + self, + root_dir, + repo_name, + metadata_path="metadata", + targets_path="targets", + repo_urls=None, + additional_info=None, + default_branch="master", + ): + + super().__init__( + root_dir, repo_name, repo_urls, additional_info, default_branch + ) + self.targets_path = targets_path + self.metadata_path = metadata_path diff --git a/taf/cli.py b/taf/cli.py index 1065bd25e..5dbac9473 100644 --- a/taf/cli.py +++ b/taf/cli.py @@ -11,169 +11,338 @@ @click.group() def cli(): - pass + pass @cli.command() -@click.option('--repo-path', default='repository', help='Authentication repository\'s path') -@click.option('--keystore', default=None, help='Path of the keystore file') -@click.option('--keys-description', default=None, help='A dictionary containing information about the keys or a path' - ' to a json file which which stores the needed information') -@click.option('--commit-msg', default=None, help='Commit 
message to be used in case the changes' - 'should be automatically committed') -@click.option('--scheme', default=DEFAULT_RSA_SIGNATURE_SCHEME, help='A signature scheme used for signing.') +@click.option( + "--repo-path", default="repository", help="Authentication repository's path" +) +@click.option("--keystore", default=None, help="Path of the keystore file") +@click.option( + "--keys-description", + default=None, + help="A dictionary containing information about the keys or a path" + " to a json file which which stores the needed information", +) +@click.option( + "--commit-msg", + default=None, + help="Commit message to be used in case the changes" + "should be automatically committed", +) +@click.option( + "--scheme", + default=DEFAULT_RSA_SIGNATURE_SCHEME, + help="A signature scheme used for signing.", +) def add_targets(repo_path, keystore, keys_description, commit_msg, scheme): - developer_tool.register_target_files(repo_path, keystore, keys_description, commit_msg, scheme) + developer_tool.register_target_files( + repo_path, keystore, keys_description, commit_msg, scheme + ) @cli.command() -@click.option('--repo-path', default='repository', help='Authentication repository\'s path') -@click.option('--file-path', help="Target file's path, relative to the targets directory") -@click.option('--keystore', default='keystore', help='Path of the keystore file') -@click.option('--keys-description', default=None, help='A dictionary containing information about the keys or a path' - ' to a json file which which stores the needed information') -@click.option('--scheme', default=DEFAULT_RSA_SIGNATURE_SCHEME, help='A signature scheme used for signing.') +@click.option( + "--repo-path", default="repository", help="Authentication repository's path" +) +@click.option( + "--file-path", help="Target file's path, relative to the targets directory" +) +@click.option("--keystore", default="keystore", help="Path of the keystore file") +@click.option( + "--keys-description", + default=None, + help="A dictionary containing information about the keys or a path" + " to a json file which which stores the needed information", +) +@click.option( + "--scheme", + default=DEFAULT_RSA_SIGNATURE_SCHEME, + help="A signature scheme used for signing.", +) def add_target_file(repo_path, file_path, keystore, keys_description, scheme): - developer_tool.register_target_file(repo_path, file_path, keystore, keys_description, scheme) + developer_tool.register_target_file( + repo_path, file_path, keystore, keys_description, scheme + ) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the repository') -@click.option('--targets-dir', default='targets', help='Directory where the target ' - 'repositories are located') -@click.option('--namespace', default=None, help='Namespace of the target repositories') +@click.option("--repo-path", default="repository", help="Location of the repository") +@click.option( + "--targets-dir", + default="targets", + help="Directory where the target " "repositories are located", +) +@click.option("--namespace", default=None, help="Namespace of the target repositories") def add_target_repos(repo_path, targets_dir, namespace): - developer_tool.add_target_repos(repo_path, targets_dir, namespace) + developer_tool.add_target_repos(repo_path, targets_dir, namespace) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the authentication repository') -@click.option('--targets-dir', default='targets', help='Directory where the target ' - 
'repositories are located') -@click.option('--namespace', default='', help='Namespace of the target repositories') -@click.option('--targets-rel-dir', default=None, help=' Directory relative to which urls ' - 'of the target repositories are set, if they do not have remote set') -@click.option('--keystore', default='keystore', help='Location of the keystore file') -@click.option('--keys-description', help='A dictionary containing information about the ' - 'keys or a path to a json file which which stores the needed information') -@click.option('--custom', default=None, help='A dictionary containing custom ' - 'targets info which will be included in repositories.json') -def build_auth_repo(repo_path, targets_dir, namespace, targets_rel_dir, keystore, - keys_description, custom): - developer_tool.build_auth_repo(repo_path, targets_dir, namespace, targets_rel_dir, keystore, - keys_description, custom) +@click.option( + "--repo-path", + default="repository", + help="Location of the authentication repository", +) +@click.option( + "--targets-dir", + default="targets", + help="Directory where the target " "repositories are located", +) +@click.option("--namespace", default="", help="Namespace of the target repositories") +@click.option( + "--targets-rel-dir", + default=None, + help=" Directory relative to which urls " + "of the target repositories are set, if they do not have remote set", +) +@click.option("--keystore", default="keystore", help="Location of the keystore file") +@click.option( + "--keys-description", + help="A dictionary containing information about the " + "keys or a path to a json file which which stores the needed information", +) +@click.option( + "--custom", + default=None, + help="A dictionary containing custom " + "targets info which will be included in repositories.json", +) +def build_auth_repo( + repo_path, + targets_dir, + namespace, + targets_rel_dir, + keystore, + keys_description, + custom, +): + developer_tool.build_auth_repo( + repo_path, + targets_dir, + namespace, + targets_rel_dir, + keystore, + keys_description, + custom, + ) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the repository') -@click.option('--keystore', default=None, help='Location of the keystore file') -@click.option('--keys-description', help='A dictionary containing information about the ' - 'keys or a path to a json file which which stores the needed information') -@click.option('--commit-msg', default=None, help='Commit message. If provided, the ' - 'changes will be committed automatically') -@click.option('--test', is_flag=True, default=False, help='Indicates if the created repository ' - 'is a test authentication repository') +@click.option("--repo-path", default="repository", help="Location of the repository") +@click.option("--keystore", default=None, help="Location of the keystore file") +@click.option( + "--keys-description", + help="A dictionary containing information about the " + "keys or a path to a json file which which stores the needed information", +) +@click.option( + "--commit-msg", + default=None, + help="Commit message. 
If provided, the " "changes will be committed automatically", +) +@click.option( + "--test", + is_flag=True, + default=False, + help="Indicates if the created repository " "is a test authentication repository", +) def create_repo(repo_path, keystore, keys_description, commit_msg, test): - developer_tool.create_repository(repo_path, keystore, keys_description, commit_msg, test) + developer_tool.create_repository( + repo_path, keystore, keys_description, commit_msg, test + ) @cli.command() -@click.option('--keystore', default='keystore', help='Location of the keystore file') -@click.option('--keys-description', help='A dictionary containing information about the keys or a path' - ' to a json file which which stores the needed information') +@click.option("--keystore", default="keystore", help="Location of the keystore file") +@click.option( + "--keys-description", + help="A dictionary containing information about the keys or a path" + " to a json file which which stores the needed information", +) def generate_keys(keystore, keys_description): - developer_tool.generate_keys(keystore, keys_description) + developer_tool.generate_keys(keystore, keys_description) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the repository') -@click.option('--targets-dir', default='targets', help='Directory where the target ' - 'repositories are located') -@click.option('--namespace', default='', help='Namespace of the target repositories') -@click.option('--targets-rel-dir', default=None, help=' Directory relative to which urls ' - 'of the target repositories are set, if they do not have remote set') -@click.option('--keystore', default='keystore', help='Location of the keystore file') -@click.option('--keys-description', help='A dictionary containing information about the ' - 'keys or a path to a json file which which stores the needed information') -@click.option('--custom', default=None, help='A dictionary containing custom ' - 'targets info which will be included in repositories.json') -@click.option('--commit', is_flag=True, default=True, help='Indicates if changes should be committed') -@click.option('--test', is_flag=True, default=False, help='Indicates if the created repository ' - 'is a test authentication repository') -def init_repo(repo_path, targets_dir, namespace, targets_rel_dir, keystore, - keys_description, custom, commit, test): - developer_tool.init_repo(repo_path, targets_dir, namespace, targets_rel_dir, keystore, - keys_description, repos_custom=custom, commit=commit, test=test) +@click.option("--repo-path", default="repository", help="Location of the repository") +@click.option( + "--targets-dir", + default="targets", + help="Directory where the target " "repositories are located", +) +@click.option("--namespace", default="", help="Namespace of the target repositories") +@click.option( + "--targets-rel-dir", + default=None, + help=" Directory relative to which urls " + "of the target repositories are set, if they do not have remote set", +) +@click.option("--keystore", default="keystore", help="Location of the keystore file") +@click.option( + "--keys-description", + help="A dictionary containing information about the " + "keys or a path to a json file which which stores the needed information", +) +@click.option( + "--custom", + default=None, + help="A dictionary containing custom " + "targets info which will be included in repositories.json", +) +@click.option( + "--commit", + is_flag=True, + default=True, + help="Indicates if changes should be 
committed", +) +@click.option( + "--test", + is_flag=True, + default=False, + help="Indicates if the created repository " "is a test authentication repository", +) +def init_repo( + repo_path, + targets_dir, + namespace, + targets_rel_dir, + keystore, + keys_description, + custom, + commit, + test, +): + developer_tool.init_repo( + repo_path, + targets_dir, + namespace, + targets_rel_dir, + keystore, + keys_description, + repos_custom=custom, + commit=commit, + test=test, + ) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the repository') -@click.option('--targets-dir', default='targets', help='Directory where the target ' - 'repositories are located') -@click.option('--namespace', default=None, help='Namespace of the target repositories') -@click.option('--targets-rel-dir', default=None, help=' Directory relative to which urls ' - 'of the target repositories are set, if they do not have remote set') -@click.option('--custom', default=None, help='A dictionary containing custom ' - 'targets info which will be included in repositories.json') -def generate_repositories_json(repo_path, targets_dir, namespace, targets_rel_dir, custom): - developer_tool.generate_repositories_json(repo_path, targets_dir, namespace, targets_rel_dir, - custom) +@click.option("--repo-path", default="repository", help="Location of the repository") +@click.option( + "--targets-dir", + default="targets", + help="Directory where the target " "repositories are located", +) +@click.option("--namespace", default=None, help="Namespace of the target repositories") +@click.option( + "--targets-rel-dir", + default=None, + help=" Directory relative to which urls " + "of the target repositories are set, if they do not have remote set", +) +@click.option( + "--custom", + default=None, + help="A dictionary containing custom " + "targets info which will be included in repositories.json", +) +def generate_repositories_json( + repo_path, targets_dir, namespace, targets_rel_dir, custom +): + developer_tool.generate_repositories_json( + repo_path, targets_dir, namespace, targets_rel_dir, custom + ) @cli.command() -@click.option('--repo-path', default='repository', help='Location of the repository') -@click.option('--keystore', default='keystore', help='Location of the keystore file') -@click.option('--keys-description', help='A dictionary containing information about the keys or a path' - ' to a json file which which stores the needed information') -@click.option('--role', default='timestamp', help='Metadata role whose expiration date should be ' - 'updated') -@click.option('--start-date', default=datetime.datetime.now(), help='Date to which the intercal is added', type=ISO_DATE) -@click.option('--interval', default=None, help='Time interval added to the start date', type=int) -@click.option('--commit-msg', default=None, help='Commit message to be used in case the changes' - 'should be automatically committed') -def update_expiration_date(repo_path, keystore, keys_description, role, start_date, interval, - commit_msg): - developer_tool.update_metadata_expiration_date(repo_path, keystore, keys_description, role, - start_date, interval, commit_msg) +@click.option("--repo-path", default="repository", help="Location of the repository") +@click.option("--keystore", default="keystore", help="Location of the keystore file") +@click.option( + "--keys-description", + help="A dictionary containing information about the keys or a path" + " to a json file which which stores the needed information", +) 
+@click.option(
+    "--role",
+    default="timestamp",
+    help="Metadata role whose expiration date should be " "updated",
+)
+@click.option(
+    "--start-date",
+    default=datetime.datetime.now(),
+    help="Date to which the interval is added",
+    type=ISO_DATE,
+)
+@click.option(
+    "--interval", default=None, help="Time interval added to the start date", type=int
+)
+@click.option(
+    "--commit-msg",
+    default=None,
+    help="Commit message to be used in case the changes "
+    "should be automatically committed",
+)
+def update_expiration_date(
+    repo_path, keystore, keys_description, role, start_date, interval, commit_msg
+):
+    developer_tool.update_metadata_expiration_date(
+        repo_path, keystore, keys_description, role, start_date, interval, commit_msg
+    )


 @cli.command()
-@click.option('--url', help="Authentication repository's url")
-@click.option('--clients-dir', help="Directory containing the client's authentication repository")
-@click.option('--targets-dir', help="Directory containing the target repositories")
-@click.option('--from-fs', is_flag=True, default=False, help='Indicates if the we want to clone a '
-              'repository from the filesystem')
+@click.option("--url", help="Authentication repository's url")
+@click.option(
+    "--clients-dir", help="Directory containing the client's authentication repository"
+)
+@click.option("--targets-dir", help="Directory containing the target repositories")
+@click.option(
+    "--from-fs",
+    is_flag=True,
+    default=False,
+    help="Indicates if we want to clone a " "repository from the filesystem",
+)
 def update(url, clients_dir, targets_dir, from_fs):
-  update_repository(url, clients_dir, targets_dir, from_fs)
+    update_repository(url, clients_dir, targets_dir, from_fs)


 @cli.command()
-@click.option('--url', help="Authentication repository's url")
-@click.option('--clients-dir', help="Directory containing the client's authentication repository")
-@click.option('--repo-name', help="Repository's name")
-@click.option('--targets-dir', help="Directory containing the target repositories")
-@click.option('--from-fs', is_flag=True, default=False, help='Indicates if the we want to clone a '
-              'repository from the filesystem')
+@click.option("--url", help="Authentication repository's url")
+@click.option(
+    "--clients-dir", help="Directory containing the client's authentication repository"
+)
+@click.option("--repo-name", help="Repository's name")
+@click.option("--targets-dir", help="Directory containing the target repositories")
+@click.option(
+    "--from-fs",
+    is_flag=True,
+    default=False,
+    help="Indicates if we want to clone a " "repository from the filesystem",
+)
 def update_named_repo(url, clients_dir, repo_name, targets_dir, from_fs):
-  update_named_repository(url, clients_dir, repo_name, targets_dir, from_fs)
+    update_named_repository(url, clients_dir, repo_name, targets_dir, from_fs)


 @cli.command()
 def setup_test_yubikey():
-  import taf.yubikey as yk
+    import taf.yubikey as yk

-  targets_key_path = Path(__file__).parent.parent / "tests" / "data" / "keystore" / "targets"
-  targets_key_pem = targets_key_path.read_bytes()
+    targets_key_path = (
+        Path(__file__).parent.parent / "tests" / "data" / "keystore" / "targets"
+    )
+    targets_key_pem = targets_key_path.read_bytes()

-  click.echo("\nImporting RSA private key from {} to Yubikey..."
- .format(targets_key_path)) + click.echo( + "\nImporting RSA private key from {} to Yubikey...".format(targets_key_path) + ) - pin = yk.DEFAULT_PIN - pub_key = yk.setup(pin, 'Test Yubikey', private_key_pem=targets_key_pem) + pin = yk.DEFAULT_PIN + pub_key = yk.setup(pin, "Test Yubikey", private_key_pem=targets_key_pem) - click.echo("\nPrivate key successfully imported.\n") - click.echo("\nPublic key (PEM): \n{}".format(pub_key.decode("utf-8"))) - click.echo("Pin: {}\n".format(pin)) + click.echo("\nPrivate key successfully imported.\n") + click.echo("\nPublic key (PEM): \n{}".format(pub_key.decode("utf-8"))) + click.echo("Pin: {}\n".format(pin)) cli() diff --git a/taf/constants.py b/taf/constants.py index d77c90c9c..d56cb37ee 100644 --- a/taf/constants.py +++ b/taf/constants.py @@ -1,3 +1,3 @@ # Default scheme for all RSA keys. It can be changed in keys.json while # generating repository -DEFAULT_RSA_SIGNATURE_SCHEME = 'rsa-pkcs1v15-sha256' +DEFAULT_RSA_SIGNATURE_SCHEME = "rsa-pkcs1v15-sha256" diff --git a/taf/developer_tool.py b/taf/developer_tool.py index 532a48b48..97d1e005a 100644 --- a/taf/developer_tool.py +++ b/taf/developer_tool.py @@ -11,8 +11,10 @@ import securesystemslib import tuf.repository_tool from securesystemslib.exceptions import UnknownKeyError -from securesystemslib.interface import (import_rsa_privatekey_from_file, - import_rsa_publickey_from_file) +from securesystemslib.interface import ( + import_rsa_privatekey_from_file, + import_rsa_publickey_from_file, +) from taf.auth_repo import AuthenticationRepo from taf.constants import DEFAULT_RSA_SIGNATURE_SCHEME from taf.git import GitRepository @@ -20,25 +22,31 @@ from taf.repository_tool import Repository, load_role_key from taf.utils import get_pin_for from tuf.keydb import get_key -from tuf.repository_tool import (METADATA_DIRECTORY_NAME, - TARGETS_DIRECTORY_NAME, create_new_repository, - generate_and_write_rsa_keypair, - generate_rsa_key, import_rsakey_from_pem) +from tuf.repository_tool import ( + METADATA_DIRECTORY_NAME, + TARGETS_DIRECTORY_NAME, + create_new_repository, + generate_and_write_rsa_keypair, + generate_rsa_key, + import_rsakey_from_pem, +) logger = get_logger(__name__) try: - import taf.yubikey as yk + import taf.yubikey as yk except ImportError: - logger.warning('"yubikey-manager" is not installed.') + logger.warning('"yubikey-manager" is not installed.') # Yubikey x509 certificate expiration interval EXPIRATION_INTERVAL = 36500 -YUBIKEY_EXPIRATION_DATE = datetime.datetime.now() + datetime.timedelta(days=EXPIRATION_INTERVAL) +YUBIKEY_EXPIRATION_DATE = datetime.datetime.now() + datetime.timedelta( + days=EXPIRATION_INTERVAL +) def add_target_repos(repo_path, targets_directory, namespace=None): - """ + """ Create or update target files by reading the latest commits of the provided target repositories @@ -49,67 +57,86 @@ def add_target_repos(repo_path, targets_directory, namespace=None): namespace: Namespace used to form the full name of the target repositories. E.g. 
some_namespace/law-xml """ - repo_path = Path(repo_path).resolve() - targets_directory = Path(targets_directory).resolve() - if namespace is None: - namespace = targets_directory.name - auth_repo_targets_dir = repo_path / TARGETS_DIRECTORY_NAME - if namespace: - auth_repo_targets_dir = auth_repo_targets_dir / namespace - if not auth_repo_targets_dir.exists(): - os.makedirs(auth_repo_targets_dir) - - for target_repo_dir in targets_directory.glob('*'): - if not target_repo_dir.is_dir() or target_repo_dir == repo_path: - continue - target_repo = GitRepository(str(target_repo_dir)) - if target_repo.is_git_repository: - commit = target_repo.head_commit_sha() - target_repo_name = target_repo_dir.name - (auth_repo_targets_dir / target_repo_name).write_text(json.dumps({'commit': commit}, - indent=4)) - - -def build_auth_repo(repo_path, targets_directory, namespace, targets_relative_dir, keystore, - roles_key_infos, repos_custom): - # read the key infos here, no need to read the file multiple times - roles_key_infos = _read_input_dict(roles_key_infos) - create_repository(repo_path, keystore, roles_key_infos) - generate_repositories_json(repo_path, targets_directory, namespace, - targets_relative_dir, repos_custom) - register_target_files(repo_path, keystore, roles_key_infos, commit_msg='Added repositories.json') - auth_repo_targets_dir = os.path.join(repo_path, TARGETS_DIRECTORY_NAME) - if namespace: - auth_repo_targets_dir = os.path.join(auth_repo_targets_dir, namespace) - if not os.path.exists(auth_repo_targets_dir): - os.makedirs(auth_repo_targets_dir) - # group commits by dates - # first add first repos at a date, then second repost at that date - commits_by_date = defaultdict(dict) - target_repositories = [] - for target_repo_dir in os.listdir(targets_directory): - target_repo = GitRepository(os.path.join(targets_directory, target_repo_dir)) - target_repo.checkout_branch('master') - target_repo_name = os.path.basename(target_repo_dir) - target_repositories.append(target_repo_name) - commits = target_repo.list_commits(format='format:%H|%cd', date='short') - for commit in commits[::-1]: - sha, date = commit.split('|') - commits_by_date[date].setdefault(target_repo_name, []).append(sha) - - for date in sorted(commits_by_date.keys()): - repos_and_commits = commits_by_date[date] - for target_repo_name in target_repositories: - if target_repo_name in repos_and_commits: - for sha in commits_by_date[date][target_repo_name]: - with open(os.path.join(auth_repo_targets_dir, target_repo_name), 'w') as f: - json.dump({'commit': sha}, f, indent=4) - register_target_files(repo_path, keystore, roles_key_infos, - commit_msg='Updated {}'.format(target_repo_name)) - - -def create_repository(repo_path, keystore, roles_key_infos, commit_message=None, test=False): - """ + repo_path = Path(repo_path).resolve() + targets_directory = Path(targets_directory).resolve() + if namespace is None: + namespace = targets_directory.name + auth_repo_targets_dir = repo_path / TARGETS_DIRECTORY_NAME + if namespace: + auth_repo_targets_dir = auth_repo_targets_dir / namespace + if not auth_repo_targets_dir.exists(): + os.makedirs(auth_repo_targets_dir) + + for target_repo_dir in targets_directory.glob("*"): + if not target_repo_dir.is_dir() or target_repo_dir == repo_path: + continue + target_repo = GitRepository(str(target_repo_dir)) + if target_repo.is_git_repository: + commit = target_repo.head_commit_sha() + target_repo_name = target_repo_dir.name + (auth_repo_targets_dir / target_repo_name).write_text( + json.dumps({"commit": 
commit}, indent=4) + ) + + +def build_auth_repo( + repo_path, + targets_directory, + namespace, + targets_relative_dir, + keystore, + roles_key_infos, + repos_custom, +): + # read the key infos here, no need to read the file multiple times + roles_key_infos = _read_input_dict(roles_key_infos) + create_repository(repo_path, keystore, roles_key_infos) + generate_repositories_json( + repo_path, targets_directory, namespace, targets_relative_dir, repos_custom + ) + register_target_files( + repo_path, keystore, roles_key_infos, commit_msg="Added repositories.json" + ) + auth_repo_targets_dir = os.path.join(repo_path, TARGETS_DIRECTORY_NAME) + if namespace: + auth_repo_targets_dir = os.path.join(auth_repo_targets_dir, namespace) + if not os.path.exists(auth_repo_targets_dir): + os.makedirs(auth_repo_targets_dir) + # group commits by dates + # first add first repos at a date, then second repost at that date + commits_by_date = defaultdict(dict) + target_repositories = [] + for target_repo_dir in os.listdir(targets_directory): + target_repo = GitRepository(os.path.join(targets_directory, target_repo_dir)) + target_repo.checkout_branch("master") + target_repo_name = os.path.basename(target_repo_dir) + target_repositories.append(target_repo_name) + commits = target_repo.list_commits(format="format:%H|%cd", date="short") + for commit in commits[::-1]: + sha, date = commit.split("|") + commits_by_date[date].setdefault(target_repo_name, []).append(sha) + + for date in sorted(commits_by_date.keys()): + repos_and_commits = commits_by_date[date] + for target_repo_name in target_repositories: + if target_repo_name in repos_and_commits: + for sha in commits_by_date[date][target_repo_name]: + with open( + os.path.join(auth_repo_targets_dir, target_repo_name), "w" + ) as f: + json.dump({"commit": sha}, f, indent=4) + register_target_files( + repo_path, + keystore, + roles_key_infos, + commit_msg="Updated {}".format(target_repo_name), + ) + + +def create_repository( + repo_path, keystore, roles_key_infos, commit_message=None, test=False +): + """ Create a new authentication repository. Generate initial metadata files. The initial targets metadata file is empty (does not specify any targets) @@ -127,116 +154,138 @@ def create_repository(repo_path, keystore, roles_key_infos, commit_message=None, test: Indicates if the created repository is a test authentication repository """ - yubikeys = defaultdict(dict) - roles_key_infos = _read_input_dict(roles_key_infos) - repo = AuthenticationRepo(repo_path) - if os.path.isdir(repo_path): - if repo.is_git_repository: - print('Repository {} already exists'.format(repo_path)) - return - - tuf.repository_tool.METADATA_STAGED_DIRECTORY_NAME = METADATA_DIRECTORY_NAME - repository = create_new_repository(repo_path) - for role_name, key_info in roles_key_infos.items(): - num_of_keys = key_info.get('number', 1) - passwords = key_info.get('passwords', [None] * num_of_keys) - threshold = key_info.get('threshold', 1) - is_yubikey = key_info.get('yubikey', False) - scheme = key_info.get('scheme', DEFAULT_RSA_SIGNATURE_SCHEME) - - role_obj = _role_obj(role_name, repository) - role_obj.threshold = threshold - for key_num in range(num_of_keys): - key_name = _get_key_name(role_name, key_num, num_of_keys) - if is_yubikey: - print('Generating keys for {}'.format(key_name)) - use_existing = False - if len(yubikeys) > 1 or (len(yubikeys) == 1 and role_name not in yubikeys): - use_existing = input('Do you want to reuse already set up Yubikey? 
y/n ') == 'y' - if use_existing: - existing_key = None - key_id_certs = {} - while existing_key is None: - for existing_role_name, role_keys in yubikeys.items(): - if existing_role_name == role_name: - continue - print("Existing keys for role {} are:\n".format(existing_role_name)) - for key_and_cert in role_keys.values(): - key, cert_cn = key_and_cert - key_id_certs[key['keyid']] = cert_cn - print('{} id: {}'.format(cert_cn, key['keyid'])) - existing_keyid = input("\nEnter existing Yubikey's id and press ENTER ") - try: - existing_key = get_key(existing_keyid) - cert_cn = key_id_certs[existing_keyid] - except UnknownKeyError: - pass - if not use_existing: - input("Please insert a new YubiKey and press ENTER.") - serial_num = yk.get_serial_num() - while serial_num in yubikeys[role_name]: - print("Yubikey with serial number {} is already in use.\n".format(serial_num)) - input("Please insert new YubiKey and press ENTER.") - serial_num = yk.get_serial_num() - - pin = get_pin_for(key_name) - - cert_cn = input("Enter key holder's name: ") - - print('Generating keys, please wait...') - pub_key_pem = yk.setup(pin, cert_cn, cert_exp_days=EXPIRATION_INTERVAL).decode('utf-8') - - key = import_rsakey_from_pem(pub_key_pem, scheme) - - cert_path = os.path.join(repo.certs_dir, key['keyid'] + '.cert') - with open(cert_path, 'wb') as f: - f.write(yk.export_piv_x509()) - - # set Yubikey expiration date - role_obj.add_verification_key(key, expires=YUBIKEY_EXPIRATION_DATE) - role_obj.add_external_signature_provider(key, partial(signature_provider, - key['keyid'], cert_cn)) - yubikeys[role_name][serial_num] = (key, cert_cn) - else: - # if keystore exists, load the keys - # this is useful when generating tests - if keystore is not None: - public_key = import_rsa_publickey_from_file(os.path.join(keystore, - key_name + '.pub'), - scheme) - password = passwords[key_num] - if password: - private_key = import_rsa_privatekey_from_file(os.path.join(keystore, key_name), - password, scheme=scheme) - else: - private_key = import_rsa_privatekey_from_file(os.path.join(keystore, key_name), - scheme=scheme) - - # if it does not, generate the keys and print the output - else: - key = generate_rsa_key() - print("{} key:\n\n{}\n\n".format(role_name, key['keyval']['private'])) - public_key = private_key = key - role_obj.add_verification_key(public_key) - role_obj.load_signing_key(private_key) - - # if the repository is a test repository, add a target file called test-auth-repo - if test: - target_paths = Path(repo_path) / 'targets' - test_auth_file = target_paths / 'test-auth-repo' - test_auth_file.touch() - targets_obj = _role_obj('targets', repository) - targets_obj.add_target(str(test_auth_file)) - - repository.writeall() - if commit_message is not None and len(commit_message): - auth_repo = GitRepository(repo_path) - auth_repo.init_repo() - auth_repo.commit(commit_message) + yubikeys = defaultdict(dict) + roles_key_infos = _read_input_dict(roles_key_infos) + repo = AuthenticationRepo(repo_path) + if os.path.isdir(repo_path): + if repo.is_git_repository: + print("Repository {} already exists".format(repo_path)) + return + + tuf.repository_tool.METADATA_STAGED_DIRECTORY_NAME = METADATA_DIRECTORY_NAME + repository = create_new_repository(repo_path) + for role_name, key_info in roles_key_infos.items(): + num_of_keys = key_info.get("number", 1) + passwords = key_info.get("passwords", [None] * num_of_keys) + threshold = key_info.get("threshold", 1) + is_yubikey = key_info.get("yubikey", False) + scheme = key_info.get("scheme", 
DEFAULT_RSA_SIGNATURE_SCHEME) + + role_obj = _role_obj(role_name, repository) + role_obj.threshold = threshold + for key_num in range(num_of_keys): + key_name = _get_key_name(role_name, key_num, num_of_keys) + if is_yubikey: + print("Generating keys for {}".format(key_name)) + use_existing = False + if len(yubikeys) > 1 or ( + len(yubikeys) == 1 and role_name not in yubikeys + ): + use_existing = ( + input("Do you want to reuse already set up Yubikey? y/n ") + == "y" + ) + if use_existing: + existing_key = None + key_id_certs = {} + while existing_key is None: + for existing_role_name, role_keys in yubikeys.items(): + if existing_role_name == role_name: + continue + print( + "Existing keys for role {} are:\n".format( + existing_role_name + ) + ) + for key_and_cert in role_keys.values(): + key, cert_cn = key_and_cert + key_id_certs[key["keyid"]] = cert_cn + print("{} id: {}".format(cert_cn, key["keyid"])) + existing_keyid = input( + "\nEnter existing Yubikey's id and press ENTER " + ) + try: + existing_key = get_key(existing_keyid) + cert_cn = key_id_certs[existing_keyid] + except UnknownKeyError: + pass + if not use_existing: + input("Please insert a new YubiKey and press ENTER.") + serial_num = yk.get_serial_num() + while serial_num in yubikeys[role_name]: + print( + "Yubikey with serial number {} is already in use.\n".format( + serial_num + ) + ) + input("Please insert new YubiKey and press ENTER.") + serial_num = yk.get_serial_num() + + pin = get_pin_for(key_name) + + cert_cn = input("Enter key holder's name: ") + + print("Generating keys, please wait...") + pub_key_pem = yk.setup( + pin, cert_cn, cert_exp_days=EXPIRATION_INTERVAL + ).decode("utf-8") + + key = import_rsakey_from_pem(pub_key_pem, scheme) + + cert_path = os.path.join(repo.certs_dir, key["keyid"] + ".cert") + with open(cert_path, "wb") as f: + f.write(yk.export_piv_x509()) + + # set Yubikey expiration date + role_obj.add_verification_key(key, expires=YUBIKEY_EXPIRATION_DATE) + role_obj.add_external_signature_provider( + key, partial(signature_provider, key["keyid"], cert_cn) + ) + yubikeys[role_name][serial_num] = (key, cert_cn) + else: + # if keystore exists, load the keys + # this is useful when generating tests + if keystore is not None: + public_key = import_rsa_publickey_from_file( + os.path.join(keystore, key_name + ".pub"), scheme + ) + password = passwords[key_num] + if password: + private_key = import_rsa_privatekey_from_file( + os.path.join(keystore, key_name), password, scheme=scheme + ) + else: + private_key = import_rsa_privatekey_from_file( + os.path.join(keystore, key_name), scheme=scheme + ) + + # if it does not, generate the keys and print the output + else: + key = generate_rsa_key() + print( + "{} key:\n\n{}\n\n".format(role_name, key["keyval"]["private"]) + ) + public_key = private_key = key + role_obj.add_verification_key(public_key) + role_obj.load_signing_key(private_key) + + # if the repository is a test repository, add a target file called test-auth-repo + if test: + target_paths = Path(repo_path) / "targets" + test_auth_file = target_paths / "test-auth-repo" + test_auth_file.touch() + targets_obj = _role_obj("targets", repository) + targets_obj.add_target(str(test_auth_file)) + + repository.writeall() + if commit_message is not None and len(commit_message): + auth_repo = GitRepository(repo_path) + auth_repo.init_repo() + auth_repo.commit(commit_message) def generate_keys(keystore, roles_key_infos): - """ + """ Generate public and private keys and writes them to disk. 
Names of keys correspond to names of the TUF roles. If more than one key should be generated per role, a counter is appended @@ -253,23 +302,29 @@ def generate_keys(keystore, roles_key_infos): Names of the keys are set to names of the roles plus a counter, if more than one key should be generated. """ - roles_key_infos = _read_input_dict(roles_key_infos) - for role_name, key_info in roles_key_infos.items(): - num_of_keys = key_info.get('number', 1) - bits = key_info.get('length', 3072) - passwords = key_info.get('passwords', [''] * num_of_keys) - is_yubikey = key_info.get('yubikey', False) - for key_num in range(num_of_keys): - if not is_yubikey: - key_name = _get_key_name(role_name, key_num, num_of_keys) - password = passwords[key_num] - generate_and_write_rsa_keypair(os.path.join(keystore, key_name), bits=bits, - password=password) - - -def generate_repositories_json(repo_path, targets_directory, namespace=None, - targets_relative_dir=None, custom_data=None): - """ + roles_key_infos = _read_input_dict(roles_key_infos) + for role_name, key_info in roles_key_infos.items(): + num_of_keys = key_info.get("number", 1) + bits = key_info.get("length", 3072) + passwords = key_info.get("passwords", [""] * num_of_keys) + is_yubikey = key_info.get("yubikey", False) + for key_num in range(num_of_keys): + if not is_yubikey: + key_name = _get_key_name(role_name, key_num, num_of_keys) + password = passwords[key_num] + generate_and_write_rsa_keypair( + os.path.join(keystore, key_name), bits=bits, password=password + ) + + +def generate_repositories_json( + repo_path, + targets_directory, + namespace=None, + targets_relative_dir=None, + custom_data=None, +): + """ Generatesinitial repositories.json @@ -283,56 +338,72 @@ def generate_repositories_json(repo_path, targets_directory, namespace=None, Directory relative to which urls of the target repositories are set, if they do not have remote set """ - custom_data = _read_input_dict(custom_data) - repositories = {} - - repo_path = Path(repo_path).resolve() - auth_repo_targets_dir = repo_path / TARGETS_DIRECTORY_NAME - targets_directory = Path(targets_directory).resolve() - if targets_relative_dir is not None: - targets_relative_dir = Path(targets_relative_dir).resolve() - if namespace is None: - namespace = targets_directory.name - for target_repo_dir in targets_directory.glob('*'): - if not target_repo_dir.is_dir() or target_repo_dir == repo_path: - continue - target_repo = GitRepository(target_repo_dir.resolve()) - if not target_repo.is_git_repository: - continue - target_repo_name = target_repo_dir.name - target_repo_namespaced_name = target_repo_name if not namespace else '{}/{}'.format( - namespace, str(target_repo_name)) - # determine url to specify in initial repositories.json - # if the repository has a remote set, use that url - # otherwise, set url to the repository's absolute or relative path (relative - # to targets_relative_dir if it is specified) - url = target_repo.get_remote_url() - if url is None: - if targets_relative_dir is not None: - url = os.path.relpath(str(target_repo.repo_path), str(targets_relative_dir)) - else: - url = str(Path(target_repo.repo_path).resolve()) - # convert to posix path - url = pathlib.Path(url).as_posix() - repositories[target_repo_namespaced_name] = {'urls': [url]} - if target_repo_namespaced_name in custom_data: - repositories[target_repo_namespaced_name]['custom'] = custom_data[target_repo_namespaced_name] - - (auth_repo_targets_dir / 'repositories.json').write_text(json.dumps({'repositories': repositories}, - 
indent=4)) + custom_data = _read_input_dict(custom_data) + repositories = {} + + repo_path = Path(repo_path).resolve() + auth_repo_targets_dir = repo_path / TARGETS_DIRECTORY_NAME + targets_directory = Path(targets_directory).resolve() + if targets_relative_dir is not None: + targets_relative_dir = Path(targets_relative_dir).resolve() + if namespace is None: + namespace = targets_directory.name + for target_repo_dir in targets_directory.glob("*"): + if not target_repo_dir.is_dir() or target_repo_dir == repo_path: + continue + target_repo = GitRepository(target_repo_dir.resolve()) + if not target_repo.is_git_repository: + continue + target_repo_name = target_repo_dir.name + target_repo_namespaced_name = ( + target_repo_name + if not namespace + else "{}/{}".format(namespace, str(target_repo_name)) + ) + # determine url to specify in initial repositories.json + # if the repository has a remote set, use that url + # otherwise, set url to the repository's absolute or relative path (relative + # to targets_relative_dir if it is specified) + url = target_repo.get_remote_url() + if url is None: + if targets_relative_dir is not None: + url = os.path.relpath( + str(target_repo.repo_path), str(targets_relative_dir) + ) + else: + url = str(Path(target_repo.repo_path).resolve()) + # convert to posix path + url = pathlib.Path(url).as_posix() + repositories[target_repo_namespaced_name] = {"urls": [url]} + if target_repo_namespaced_name in custom_data: + repositories[target_repo_namespaced_name]["custom"] = custom_data[ + target_repo_namespaced_name + ] + + (auth_repo_targets_dir / "repositories.json").write_text( + json.dumps({"repositories": repositories}, indent=4) + ) def _get_key_name(role_name, key_num, num_of_keys): - if num_of_keys == 1: - return role_name - else: - return role_name + str(key_num + 1) - - -def init_repo(repo_path, targets_directory, namespace, targets_relative_dir, - keystore, roles_key_infos, repos_custom=None, commit=None, - test=False): - """ + if num_of_keys == 1: + return role_name + else: + return role_name + str(key_num + 1) + + +def init_repo( + repo_path, + targets_directory, + namespace, + targets_relative_dir, + keystore, + roles_key_infos, + repos_custom=None, + commit=None, + test=False, +): + """ Generate initial repository: 1. 
Create tuf authentication repository
@@ -359,47 +430,53 @@ def init_repo(repo_path, targets_directory, namespace, targets_relative_dir,
   test:
     Indicates if the created repository is a test authentication repository
   """
-  # read the key infos here, no need to read the file multiple times
-  roles_key_infos = _read_input_dict(roles_key_infos)
-  commit_msg = 'Initial commit' if commit else None
-  create_repository(repo_path, keystore, roles_key_infos, commit_msg, test)
-  add_target_repos(repo_path, targets_directory, namespace)
-  generate_repositories_json(repo_path, targets_directory, namespace,
-                             targets_relative_dir, repos_custom)
-  register_target_files(repo_path, keystore, roles_key_infos, commit_msg=commit)
+    # read the key infos here, no need to read the file multiple times
+    roles_key_infos = _read_input_dict(roles_key_infos)
+    commit_msg = "Initial commit" if commit else None
+    create_repository(repo_path, keystore, roles_key_infos, commit_msg, test)
+    add_target_repos(repo_path, targets_directory, namespace)
+    generate_repositories_json(
+        repo_path, targets_directory, namespace, targets_relative_dir, repos_custom
+    )
+    register_target_files(repo_path, keystore, roles_key_infos, commit_msg=commit)


def _load_role_key_from_keys_dict(role, roles_key_infos):
-  password = None
-  if roles_key_infos is not None and len(roles_key_infos):
-    if role in roles_key_infos:
-      password = roles_key_infos[role].get('passwords', [None])[0] or None
-  return password
+    password = None
+    if roles_key_infos is not None and len(roles_key_infos):
+        if role in roles_key_infos:
+            password = roles_key_infos[role].get("passwords", [None])[0] or None
+    return password


def register_target_file(repo_path, file_path, keystore, roles_key_infos, scheme):
-  roles_key_infos = _read_input_dict(roles_key_infos)
-  taf_repo = Repository(repo_path)
-  taf_repo.add_existing_target(file_path)
+    roles_key_infos = _read_input_dict(roles_key_infos)
+    taf_repo = Repository(repo_path)
+    taf_repo.add_existing_target(file_path)

-  _write_targets_metadata(taf_repo, keystore, roles_key_infos, scheme)
+    _write_targets_metadata(taf_repo, keystore, roles_key_infos, scheme)


def _read_input_dict(value):
-  if value is None:
-    return {}
-  if type(value) is str:
-    if os.path.isfile(value):
-      with open(value) as f:
-        value = json.loads(f.read())
-    else:
-      value = json.loads(value)
-  return value
-
-
-def register_target_files(repo_path, keystore, roles_key_infos,
-                          commit_msg=None, scheme=DEFAULT_RSA_SIGNATURE_SCHEME):
-  """
+    if value is None:
+        return {}
+    if type(value) is str:
+        if os.path.isfile(value):
+            with open(value) as f:
+                value = json.loads(f.read())
+        else:
+            value = json.loads(value)
+    return value
+
+
+def register_target_files(
+    repo_path,
+    keystore,
+    roles_key_infos,
+    commit_msg=None,
+    scheme=DEFAULT_RSA_SIGNATURE_SCHEME,
+):
+    """
    Register all files found in the target directory as targets - updates
    the targets metadata file, snapshot and timestamp. Sign targets
@@ -416,95 +493,107 @@ def register_target_files(repo_path, keystore, roles_key_infos,
    scheme:
      A signature scheme used for signing.
    
""" - roles_key_infos = _read_input_dict(roles_key_infos) - repo_path = Path(repo_path).resolve() - targets_path = repo_path / TARGETS_DIRECTORY_NAME - taf_repo = Repository(str(repo_path)) - for root, _, filenames in os.walk(str(targets_path)): - for filename in filenames: - taf_repo.add_existing_target(str(Path(root) / filename)) - _write_targets_metadata(taf_repo, keystore, roles_key_infos, scheme) - if commit_msg is not None: - auth_repo = GitRepository(repo_path) - auth_repo.commit(commit_msg) + roles_key_infos = _read_input_dict(roles_key_infos) + repo_path = Path(repo_path).resolve() + targets_path = repo_path / TARGETS_DIRECTORY_NAME + taf_repo = Repository(str(repo_path)) + for root, _, filenames in os.walk(str(targets_path)): + for filename in filenames: + taf_repo.add_existing_target(str(Path(root) / filename)) + _write_targets_metadata(taf_repo, keystore, roles_key_infos, scheme) + if commit_msg is not None: + auth_repo = GitRepository(repo_path) + auth_repo.commit(commit_msg) def _role_obj(role, repository): - if role == 'targets': - return repository.targets - elif role == 'snapshot': - return repository.snapshot - elif role == 'timestamp': - return repository.timestamp - elif role == 'root': - return repository.root + if role == "targets": + return repository.targets + elif role == "snapshot": + return repository.snapshot + elif role == "timestamp": + return repository.timestamp + elif role == "root": + return repository.root def signature_provider(key_id, cert_cn, key, data): # pylint: disable=W0613 - def _check_key_id(expected_key_id): - try: - inserted_key = yk.get_piv_public_key_tuf() - return expected_key_id == inserted_key['keyid'] - except Exception: - return False - - while not _check_key_id(key_id): - pass - - data = securesystemslib.formats.encode_canonical(data).encode('utf-8') - key_pin = getpass("Please insert {} YubiKey, input PIN and press ENTER.\n" - .format(cert_cn)) - signature = yk.sign_piv_rsa_pkcs1v15(data, key_pin) - - return { - 'keyid': key_id, - 'sig': hexlify(signature).decode() - } - - -def update_metadata_expiration_date(repo_path, keystore, roles_key_infos, role, - start_date=datetime.datetime.now(), interval=None, commit_msg=None): - roles_key_infos = _read_input_dict(roles_key_infos) - taf_repo = Repository(repo_path) - update_methods = {'timestamp': taf_repo.update_timestamp, - 'snapshot': taf_repo.update_snapshot, - 'targets': taf_repo.update_targets_from_keystore} - password = _load_role_key_from_keys_dict(role, roles_key_infos) - update_methods[role](keystore, password, start_date, interval) - - if commit_msg is not None: - auth_repo = GitRepository(repo_path) - auth_repo.commit(commit_msg) + def _check_key_id(expected_key_id): + try: + inserted_key = yk.get_piv_public_key_tuf() + return expected_key_id == inserted_key["keyid"] + except Exception: + return False + + while not _check_key_id(key_id): + pass + + data = securesystemslib.formats.encode_canonical(data).encode("utf-8") + key_pin = getpass( + "Please insert {} YubiKey, input PIN and press ENTER.\n".format(cert_cn) + ) + signature = yk.sign_piv_rsa_pkcs1v15(data, key_pin) + + return {"keyid": key_id, "sig": hexlify(signature).decode()} + + +def update_metadata_expiration_date( + repo_path, + keystore, + roles_key_infos, + role, + start_date=datetime.datetime.now(), + interval=None, + commit_msg=None, +): + roles_key_infos = _read_input_dict(roles_key_infos) + taf_repo = Repository(repo_path) + update_methods = { + "timestamp": taf_repo.update_timestamp, + "snapshot": 
taf_repo.update_snapshot, + "targets": taf_repo.update_targets_from_keystore, + } + password = _load_role_key_from_keys_dict(role, roles_key_infos) + update_methods[role](keystore, password, start_date, interval) + + if commit_msg is not None: + auth_repo = GitRepository(repo_path) + auth_repo.commit(commit_msg) def _write_targets_metadata(taf_repo, keystore, roles_key_infos, scheme): - if keystore is not None: - # load all keys from keystore files - # convenient when generating test repositories - # not recommended in production - targets_password = _load_role_key_from_keys_dict('targets', roles_key_infos) - targets_key = load_role_key(keystore, 'targets', targets_password, scheme) - taf_repo.update_targets_from_keystore(targets_key, write=False) - snapshot_password = _load_role_key_from_keys_dict('snapshot', roles_key_infos) - timestamp_password = _load_role_key_from_keys_dict('timestamp', roles_key_infos) - timestamp_key = load_role_key(keystore, 'timestamp', timestamp_password) - snapshot_key = load_role_key(keystore, 'snapshot', snapshot_password) - else: - targets_key_pin = getpass('Please insert targets YubiKey, input PIN and press ENTER.') - taf_repo.update_targets(targets_key_pin, write=False) - snapshot_pem = getpass('Enter snapshot key') - snapshot_pem = _form_private_pem(snapshot_pem) - snapshot_key = import_rsakey_from_pem(snapshot_pem, scheme) - - timestamp_pem = getpass('Enter timestamp key') - timestamp_pem = _form_private_pem(timestamp_pem) - timestamp_key = import_rsakey_from_pem(timestamp_pem, scheme) - taf_repo.update_snapshot_and_timestmap(snapshot_key, timestamp_key, write=False) - taf_repo.writeall() + if keystore is not None: + # load all keys from keystore files + # convenient when generating test repositories + # not recommended in production + targets_password = _load_role_key_from_keys_dict("targets", roles_key_infos) + targets_key = load_role_key(keystore, "targets", targets_password, scheme) + taf_repo.update_targets_from_keystore(targets_key, write=False) + snapshot_password = _load_role_key_from_keys_dict("snapshot", roles_key_infos) + timestamp_password = _load_role_key_from_keys_dict("timestamp", roles_key_infos) + timestamp_key = load_role_key(keystore, "timestamp", timestamp_password) + snapshot_key = load_role_key(keystore, "snapshot", snapshot_password) + else: + targets_key_pin = getpass( + "Please insert targets YubiKey, input PIN and press ENTER." + ) + taf_repo.update_targets(targets_key_pin, write=False) + snapshot_pem = getpass("Enter snapshot key") + snapshot_pem = _form_private_pem(snapshot_pem) + snapshot_key = import_rsakey_from_pem(snapshot_pem, scheme) + + timestamp_pem = getpass("Enter timestamp key") + timestamp_pem = _form_private_pem(timestamp_pem) + timestamp_key = import_rsakey_from_pem(timestamp_pem, scheme) + taf_repo.update_snapshot_and_timestmap(snapshot_key, timestamp_key, write=False) + taf_repo.writeall() def _form_private_pem(pem): - return '-----BEGIN RSA PRIVATE KEY-----\n{}\n-----END RSA PRIVATE KEY-----'.format(pem) + return "-----BEGIN RSA PRIVATE KEY-----\n{}\n-----END RSA PRIVATE KEY-----".format( + pem + ) + # TODO Implement update of repositories.json (updating urls, custom data, adding new repository, removing # repository etc.) 
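
For orientation, the helpers reformatted above are typically combined as in the minimal sketch below. This is illustrative only and not part of the patch: the keystore and repository paths are hypothetical, and the per-role fields ("number", "length", "passwords", "yubikey") are simply the keys that generate_keys and _load_role_key_from_keys_dict read, with defaults number=1, length=3072, empty passwords and yubikey=False.

from taf.developer_tool import generate_keys, init_repo

# roles_key_infos may be a dict, a JSON string or a path to a JSON file
# (see _read_input_dict above). All per-role fields are optional.
roles_key_infos = {
    "root": {"number": 3, "passwords": ["", "", ""]},
    "targets": {"length": 3072},
    "snapshot": {},
    "timestamp": {},
}

# Writes one PEM keypair per role into the keystore directory; key file
# names are the role names, plus a counter when a role has several keys.
# Roles marked with "yubikey": True would be skipped here.
generate_keys("./keystore", roles_key_infos)

# Creates the authentication repository, registers the target repositories
# found in ./repos, writes repositories.json and signs the target files.
# All paths here are made up for the example.
init_repo(
    "./repos/auth-repo",
    "./repos",
    namespace="example",
    targets_relative_dir="..",
    keystore="./keystore",
    roles_key_infos=roles_key_infos,
    commit=True,  # commits the result with the message "Initial commit"
    test=False,
)
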
diff --git a/taf/exceptions.py b/taf/exceptions.py index f6a57e679..859a7619e 100644 --- a/taf/exceptions.py +++ b/taf/exceptions.py @@ -1,71 +1,76 @@ class TAFError(Exception): - pass + pass class InvalidBranchError(TAFError): - pass + pass class InvalidCommitError(TAFError): - pass + pass class InvalidKeyError(TAFError): - def __init__(self, metadata_role): - super().__init__('Cannot sign {} metadata file with inserted key.'.format(metadata_role)) + def __init__(self, metadata_role): + super().__init__( + "Cannot sign {} metadata file with inserted key.".format(metadata_role) + ) class InvalidOrMissingMetadataError(TAFError): - pass + pass class InvalidRepositoryError(TAFError): - pass + pass class MetadataUpdateError(TAFError): - def __init__(self, metadata_role, message): - super().__init__('Error happened while updating {} metadata role(s):\n\n{}' - .format(metadata_role, message)) - self.metadata_role = metadata_role - self.message = message + def __init__(self, metadata_role, message): + super().__init__( + "Error happened while updating {} metadata role(s):\n\n{}".format( + metadata_role, message + ) + ) + self.metadata_role = metadata_role + self.message = message class RootMetadataUpdateError(MetadataUpdateError): - def __init__(self, message): - super().__init__('root', message) + def __init__(self, message): + super().__init__("root", message) class PINMissmatchError(Exception): - pass + pass class SnapshotMetadataUpdateError(MetadataUpdateError): - def __init__(self, message): - super().__init__('snapshot', message) + def __init__(self, message): + super().__init__("snapshot", message) class TargetsMetadataUpdateError(MetadataUpdateError): - def __init__(self, message): - super().__init__('targets', message) + def __init__(self, message): + super().__init__("targets", message) class TimestampMetadataUpdateError(MetadataUpdateError): - def __init__(self, message): - super().__init__('timestamp', message) + def __init__(self, message): + super().__init__("timestamp", message) class NoSpeculativeBranchError(TAFError): - pass + pass class RepositoriesNotFoundError(TAFError): - pass + pass class UpdateFailedError(TAFError): - pass + pass class YubikeyError(Exception): - pass + pass diff --git a/taf/git.py b/taf/git.py index 5776be85d..a4c0ea987 100644 --- a/taf/git.py +++ b/taf/git.py @@ -14,296 +14,351 @@ class GitRepository(object): - - def __init__(self, repo_path, repo_urls=None, additional_info=None, default_branch='master'): - """ + def __init__( + self, repo_path, repo_urls=None, additional_info=None, default_branch="master" + ): + """ Args: repo_path: repository's path repo_urls: repository's urls (optional) additional_info: a dictionary containing other data (optional) default_branch: repository's default branch """ - self.repo_path = str(repo_path) - self.default_branch = default_branch - if repo_urls is not None: - if settings.update_from_filesystem is False: - for url in repo_urls: - _validate_url(url) - else: - repo_urls = [os.path.normpath(os.path.join(self.repo_path, url)) if - not os.path.isabs(url) else url - for url in repo_urls] - self.repo_urls = repo_urls - self.additional_info = additional_info - self.repo_name = os.path.basename(self.repo_path) - - _remotes = None - - @property - def remotes(self): - if self._remotes is None: - self._remotes = self._git('remote').split('\n') - return self._remotes - - @property - def is_git_repository_root(self): - """Check if path is git repository.""" - git_path = Path(self.repo_path) / '.git' - return 
self.is_git_repository and (git_path.is_dir() or git_path.is_file()) - - @property - def is_git_repository(self): - try: - self._git('rev-parse --git-dir') - return True - except subprocess.CalledProcessError: - return False - - @property - def initial_commit(self): - return self._git('rev-list --max-parents=0 HEAD').strip() if self.is_git_repository else None - - def is_remote_branch(self, branch_name): - for remote in self.remotes: - if branch_name.startswith(remote + '/'): - return True - return False - - def _git(self, cmd, *args, **kwargs): - """Call git commands in subprocess + self.repo_path = str(repo_path) + self.default_branch = default_branch + if repo_urls is not None: + if settings.update_from_filesystem is False: + for url in repo_urls: + _validate_url(url) + else: + repo_urls = [ + os.path.normpath(os.path.join(self.repo_path, url)) + if not os.path.isabs(url) + else url + for url in repo_urls + ] + self.repo_urls = repo_urls + self.additional_info = additional_info + self.repo_name = os.path.basename(self.repo_path) + + _remotes = None + + @property + def remotes(self): + if self._remotes is None: + self._remotes = self._git("remote").split("\n") + return self._remotes + + @property + def is_git_repository_root(self): + """Check if path is git repository.""" + git_path = Path(self.repo_path) / ".git" + return self.is_git_repository and (git_path.is_dir() or git_path.is_file()) + + @property + def is_git_repository(self): + try: + self._git("rev-parse --git-dir") + return True + except subprocess.CalledProcessError: + return False + + @property + def initial_commit(self): + return ( + self._git("rev-list --max-parents=0 HEAD").strip() + if self.is_git_repository + else None + ) + + def is_remote_branch(self, branch_name): + for remote in self.remotes: + if branch_name.startswith(remote + "/"): + return True + return False + + def _git(self, cmd, *args, **kwargs): + """Call git commands in subprocess e.g.: self._git('checkout {}', branch_name) """ - log_error = kwargs.pop('log_error', False) - log_error_msg = kwargs.pop('log_error_msg', '') - reraise_error = kwargs.pop('reraise_error', False) - log_success_msg = kwargs.pop('log_success_msg', '') - - if len(args): - cmd = cmd.format(*args) - command = 'git -C {} {}'.format(self.repo_path, cmd) - if log_error or log_error_msg: - try: - result = run(command) - if log_success_msg: - logger.debug('Repo %s:' + log_success_msg, self.repo_name) - except subprocess.CalledProcessError as e: - if log_error_msg: - logger.error(log_error_msg) + log_error = kwargs.pop("log_error", False) + log_error_msg = kwargs.pop("log_error_msg", "") + reraise_error = kwargs.pop("reraise_error", False) + log_success_msg = kwargs.pop("log_success_msg", "") + + if len(args): + cmd = cmd.format(*args) + command = "git -C {} {}".format(self.repo_path, cmd) + if log_error or log_error_msg: + try: + result = run(command) + if log_success_msg: + logger.debug("Repo %s:" + log_success_msg, self.repo_name) + except subprocess.CalledProcessError as e: + if log_error_msg: + logger.error(log_error_msg) + else: + logger.error( + "Repo %s: error occurred while executing %s:\n%s", + self.repo_name, + command, + e.output, + ) + if reraise_error: + raise + else: + result = run(command) + if log_success_msg: + logger.debug("Repo %s: " + log_success_msg, self.repo_name) + return result + + def all_commits_since_commit(self, since_commit=None): + if since_commit is not None: + commits = self._git("rev-list {}..HEAD", since_commit).strip() + else: + commits = self._git("log 
--format=format:%H").strip()
+        if not commits:
+            commits = []
+        else:
+            commits = commits.split("\n")
+            commits.reverse()
+
+        if since_commit is not None:
+            logger.debug(
+                "Repo %s: found the following commits after commit %s: %s",
+                self.repo_name,
+                since_commit,
+                ", ".join(commits),
+            )
+        else:
+            logger.debug(
+                "Repo %s: found the following commits: %s",
+                self.repo_name,
+                ", ".join(commits),
+            )
+        return commits
+
+    def all_fetched_commits(self, branch="master"):
+        commits = self._git("rev-list ..origin/{}", branch).strip()
+        if not commits:
+            commits = []
         else:
-            logger.error('Repo %s: error occurred while executing %s:\n%s',
-                         self.repo_name, command, e.output)
-            if reraise_error:
-                raise
-        else:
-            result = run(command)
-            if log_success_msg:
-                logger.debug('Repo %s: ' + log_success_msg, self.repo_name)
-        return result
-
-    def all_commits_since_commit(self, since_commit=None):
-        if since_commit is not None:
-            commits = self._git('rev-list {}..HEAD', since_commit).strip()
-        else:
-            commits = self._git('log --format=format:%H').strip()
-        if not commits:
-            commits = []
-        else:
-            commits = commits.split('\n')
-            commits.reverse()
-
-        if since_commit is not None:
-            logger.debug('Repo %s: found the following commits after commit %s: %s', self.repo_name,
-                         since_commit, ', '.join(commits))
-        else:
-            logger.debug('Repo %s: found the following commits: %s', self.repo_name, ', '.join(commits))
-        return commits
-
-    def all_fetched_commits(self, branch='master'):
-        commits = self._git('rev-list ..origin/{}', branch).strip()
-        if not commits:
-            commits = []
-        else:
-            commits = commits.split('\n')
-            commits.reverse()
-        logger.debug('Repo %s: fetched the following commits %s', self.repo_name, ', '.join(commits))
-        return commits
-
-    def branch_local_name(self, remote_branch_name):
-        """Strip remote from the given remote branch"""
-        for remote in self.remotes:
-            if remote_branch_name.startswith(remote + '/'):
-                return remote_branch_name.split('/', 1)[1]
-
-    def checkout_branch(self, branch_name, create=False):
-        """Check out the specified branch. If it does not exists and
+            commits = commits.split("\n")
+            commits.reverse()
+        logger.debug(
+            "Repo %s: fetched the following commits %s",
+            self.repo_name,
+            ", ".join(commits),
+        )
+        return commits
+
+    def branch_local_name(self, remote_branch_name):
+        """Strip remote from the given remote branch"""
+        for remote in self.remotes:
+            if remote_branch_name.startswith(remote + "/"):
+                return remote_branch_name.split("/", 1)[1]
+
+    def checkout_branch(self, branch_name, create=False):
+        """Check out the specified branch. If it does not exist and
         the create parameter is set to True, create a new branch. If the
         branch does not exist and create is set to False, raise an exception."""
-        try:
-            self._git('checkout {}', branch_name, log_error=True, reraise_error=True,
-                      log_success_msg='checked out branch {}'.format(branch_name))
-        except subprocess.CalledProcessError as e:
-            if create:
-                self.create_and_checkout_branch(branch_name)
-            else:
-                raise(e)
-
-    def clean(self):
-        self._git('clean -fd')
-
-    def clone(self, no_checkout=False, bare=False):
-
-        logger.info('Repo %s: cloning repository', self.repo_name)
-        shutil.rmtree(self.repo_path, True)
-        os.makedirs(self.repo_path, exist_ok=True)
-        if self.repo_urls is None:
-            raise Exception('Cannot clone repository. No urls were specified')
-        params = ''
-        if bare:
-            params = '--bare'
-        elif no_checkout:
-            params = '--no-checkout'
-        for url in self.repo_urls:
-            try:
-                self._git('clone {} . 
{}', url, params, log_success_msg='successfully cloned')
-            except subprocess.CalledProcessError:
-                logger.error('Repo %s: cannot clone from url %s', self.repo_name, url)
-            else:
-                break
-
-    def create_and_checkout_branch(self, branch_name):
-        self._git('checkout -b {}', branch_name, log_success_msg='created and checked out branch {}'.
-                  format(branch_name, log_error=True, reraise_error=True))
-
-    def checkout_commit(self, commit):
-        self._git('checkout {}', commit, log_success_msg='checked out commit {}'.format(commit))
-
-    def commit(self, message):
-        """Create a commit with the provided message
+        try:
+            self._git(
+                "checkout {}",
+                branch_name,
+                log_error=True,
+                reraise_error=True,
+                log_success_msg="checked out branch {}".format(branch_name),
+            )
+        except subprocess.CalledProcessError as e:
+            if create:
+                self.create_and_checkout_branch(branch_name)
+            else:
+                raise (e)
+
+    def clean(self):
+        self._git("clean -fd")
+
+    def clone(self, no_checkout=False, bare=False):
+
+        logger.info("Repo %s: cloning repository", self.repo_name)
+        shutil.rmtree(self.repo_path, True)
+        os.makedirs(self.repo_path, exist_ok=True)
+        if self.repo_urls is None:
+            raise Exception("Cannot clone repository. No urls were specified")
+        params = ""
+        if bare:
+            params = "--bare"
+        elif no_checkout:
+            params = "--no-checkout"
+        for url in self.repo_urls:
+            try:
+                self._git(
+                    "clone {} . {}", url, params, log_success_msg="successfully cloned"
+                )
+            except subprocess.CalledProcessError:
+                logger.error("Repo %s: cannot clone from url %s", self.repo_name, url)
+            else:
+                break
+
+    def create_and_checkout_branch(self, branch_name):
+        self._git(
+            "checkout -b {}",
+            branch_name,
+            log_error=True,
+            reraise_error=True,
+            log_success_msg="created and checked out branch {}".format(branch_name),
+        )
+
+    def checkout_commit(self, commit):
+        self._git(
+            "checkout {}",
+            commit,
+            log_success_msg="checked out commit {}".format(commit),
+        )
+
+    def commit(self, message):
+        """Create a commit with the provided message
        on the currently checked out branch"""
-        self._git('add -A')
-        try:
-            self._git('diff --cached --exit-code --shortstat')
-        except subprocess.CalledProcessError:
-            run('git', '-C', self.repo_path, 'commit', '--quiet', '-m', message)
-        return self._git('rev-parse HEAD')
-
-    def commits_on_branch_and_not_other(self, branch1, branch2, include_branching_commit=False):
-        """
+        self._git("add -A")
+        try:
+            self._git("diff --cached --exit-code --shortstat")
+        except subprocess.CalledProcessError:
+            run("git", "-C", self.repo_path, "commit", "--quiet", "-m", message)
+        return self._git("rev-parse HEAD")
+
+    def commits_on_branch_and_not_other(
+        self, branch1, branch2, include_branching_commit=False
+    ):
+        """
        Meant to find commits belonging to a branch which branches off of
        a commit from another branch. For example, to find only commits
        on a speculative branch and not on the master branch.
        
""" - logger.debug('Repo %s: finding commits which are on branch %s, but not on branch %s', - self.repo_name, branch1, branch2) - commits = self._git('log {} --not {} --no-merges --format=format:%H', branch1, branch2) - commits = commits.split('\n') if commits else [] - if include_branching_commit: - branching_commit = self._git('rev-list -n 1 {}~1', commits[-1]) - commits.append(branching_commit) - logger.debug('Repo %s: found the following commits: %s', self.repo_name, commits) - return commits - - def get_commits_date(self, commit): - date = self._git('show -s --format=%at {}', commit) - return date.split(' ', 1)[0] - - def get_json(self, commit, path): - s = self.get_file(commit, path) - return json.loads(s) - - def get_file(self, commit, path): - return self._git('show {}:{}', commit, path) - - def get_remote_url(self): - try: - return self._git('config --get remote.origin.url').strip() - except subprocess.CalledProcessError: - return None - - def delete_branch(self, branch_name): - self._git('branch -D {}', branch_name) - - def head_commit_sha(self): - """Finds sha of the commit to which the current HEAD points""" - try: - return self._git('rev-parse HEAD') - except subprocess.CalledProcessError: - return None - - def fetch(self, fetch_all=False): - if fetch_all: - self._git('fetch --all') - else: - self._git('fetch') - - def init_repo(self): - if not os.path.isdir(self.repo_path): - os.makedirs(self.repo_path, exist_ok=True) - self._git('init') - if self.repo_urls is not None and len(self.repo_urls): - self._git('remote add origin {}', self.repo_urls[0]) - - def list_files_at_revision(self, commit, path=''): - if path is None: - path = '' - file_names = self._git('ls-tree -r --name-only {}', commit) - list_of_files = [] - if not file_names: - return list_of_files - for file_in_repo in file_names.split('\n'): - if not file_in_repo.startswith(path): - continue - file_in_repo = os.path.relpath(file_in_repo, path) - list_of_files.append(file_in_repo) - return list_of_files - - def list_commits(self, **kwargs): - params = [] - for name, value in kwargs.items(): - params.append('--{}={}'.format(name, value)) - - return self._git('log {}', ' '.join(params)).split('\n') - - def merge_commit(self, commit): - self._git('merge {}', commit) - - def pull(self): - """Pull current branch""" - self._git('pull') - - def push(self, branch=''): - """Push all changes""" - try: - self._git('push origin {}', branch).strip() - except subprocess.CalledProcessError: - self._git('--set-upstream origin {}', branch).strip() - - def rename_branch(self, old_name, new_name): - self._git('branch -m {} {}', old_name, new_name) - - def reset_num_of_commits(self, num_of_commits, hard=False): - flag = '--hard' if hard else '--soft' - self._git('reset {} HEAD~{}'.format(flag, num_of_commits)) - - def reset_to_commit(self, commit, hard=False): - flag = '--hard' if hard else '--soft' - self._git('reset {} {}'.format(flag, commit)) - - def reset_to_head(self): - self._git('reset --hard HEAD') - - def set_upstream(self, branch_name): - self._git('branch -u origin/{}', branch_name) + logger.debug( + "Repo %s: finding commits which are on branch %s, but not on branch %s", + self.repo_name, + branch1, + branch2, + ) + commits = self._git( + "log {} --not {} --no-merges --format=format:%H", branch1, branch2 + ) + commits = commits.split("\n") if commits else [] + if include_branching_commit: + branching_commit = self._git("rev-list -n 1 {}~1", commits[-1]) + commits.append(branching_commit) + logger.debug( + "Repo %s: found 
the following commits: %s", self.repo_name, commits
+        )
+        return commits
+
+    def get_commits_date(self, commit):
+        date = self._git("show -s --format=%at {}", commit)
+        return date.split(" ", 1)[0]
+
+    def get_json(self, commit, path):
+        s = self.get_file(commit, path)
+        return json.loads(s)
+
+    def get_file(self, commit, path):
+        return self._git("show {}:{}", commit, path)
+
+    def get_remote_url(self):
+        try:
+            return self._git("config --get remote.origin.url").strip()
+        except subprocess.CalledProcessError:
+            return None
+
+    def delete_branch(self, branch_name):
+        self._git("branch -D {}", branch_name)
+
+    def head_commit_sha(self):
+        """Finds sha of the commit to which the current HEAD points"""
+        try:
+            return self._git("rev-parse HEAD")
+        except subprocess.CalledProcessError:
+            return None
+
+    def fetch(self, fetch_all=False):
+        if fetch_all:
+            self._git("fetch --all")
+        else:
+            self._git("fetch")
+
+    def init_repo(self):
+        if not os.path.isdir(self.repo_path):
+            os.makedirs(self.repo_path, exist_ok=True)
+        self._git("init")
+        if self.repo_urls is not None and len(self.repo_urls):
+            self._git("remote add origin {}", self.repo_urls[0])
+
+    def list_files_at_revision(self, commit, path=""):
+        if path is None:
+            path = ""
+        file_names = self._git("ls-tree -r --name-only {}", commit)
+        list_of_files = []
+        if not file_names:
+            return list_of_files
+        for file_in_repo in file_names.split("\n"):
+            if not file_in_repo.startswith(path):
+                continue
+            file_in_repo = os.path.relpath(file_in_repo, path)
+            list_of_files.append(file_in_repo)
+        return list_of_files
+
+    def list_commits(self, **kwargs):
+        params = []
+        for name, value in kwargs.items():
+            params.append("--{}={}".format(name, value))
+
+        return self._git("log {}", " ".join(params)).split("\n")
+
+    def merge_commit(self, commit):
+        self._git("merge {}", commit)
+
+    def pull(self):
+        """Pull current branch"""
+        self._git("pull")
+
+    def push(self, branch=""):
+        """Push all changes"""
+        try:
+            self._git("push origin {}", branch).strip()
+        except subprocess.CalledProcessError:
+            self._git("push --set-upstream origin {}", branch).strip()
+
+    def rename_branch(self, old_name, new_name):
+        self._git("branch -m {} {}", old_name, new_name)
+
+    def reset_num_of_commits(self, num_of_commits, hard=False):
+        flag = "--hard" if hard else "--soft"
+        self._git("reset {} HEAD~{}".format(flag, num_of_commits))
+
+    def reset_to_commit(self, commit, hard=False):
+        flag = "--hard" if hard else "--soft"
+        self._git("reset {} {}".format(flag, commit))
+
+    def reset_to_head(self):
+        self._git("reset --hard HEAD")
+
+    def set_upstream(self, branch_name):
+        self._git("branch -u origin/{}", branch_name)


class NamedGitRepository(GitRepository):
-
-    def __init__(self, root_dir, repo_name, repo_urls=None, additional_info=None,
-                 default_branch='master'):
-        """
+    def __init__(
+        self,
+        root_dir,
+        repo_name,
+        repo_urls=None,
+        additional_info=None,
+        default_branch="master",
+    ):
+        """
        Args:
          root_dir: the root directory
         repo_name: repository's path relative to the root directory root_dir
@@ -313,54 +368,62 @@ def __init__(self, root_dir, repo_name, repo_urls=None, additional_info=None,
        repo_path is the absolute path to this repository. It is set by joining
        root_dir and repo_name.
""" - repo_path = _get_repo_path(root_dir, repo_name) - super().__init__(repo_path, repo_urls, additional_info, default_branch) - self.repo_name = repo_name + repo_path = _get_repo_path(root_dir, repo_name) + super().__init__(repo_path, repo_urls, additional_info, default_branch) + self.repo_name = repo_name def _get_repo_path(root_dir, repo_name): - """ + """ get the path to a repo and ensure it is valid. (since this is coming from potentially untrusted data) """ - _validate_repo_name(repo_name) - repo_dir = str((Path(root_dir) / (repo_name or ''))) - if not repo_dir.startswith(repo_dir): - logger.error('Repo %s: repository name is not valid', repo_name) - raise InvalidRepositoryError('Invalid repository name: {}'.format(repo_name)) - return repo_dir + _validate_repo_name(repo_name) + repo_dir = str((Path(root_dir) / (repo_name or ""))) + if not repo_dir.startswith(repo_dir): + logger.error("Repo %s: repository name is not valid", repo_name) + raise InvalidRepositoryError("Invalid repository name: {}".format(repo_name)) + return repo_dir -_repo_name_re = re.compile(r'^\w[\w_-]*/\w[\w_-]*$') +_repo_name_re = re.compile(r"^\w[\w_-]*/\w[\w_-]*$") def _validate_repo_name(repo_name): - """ Ensure the repo name is not malicious """ - match = _repo_name_re.match(repo_name) - if not match: - logger.error('Repo %s: repository name is not valid', repo_name) - raise InvalidRepositoryError('Repository name must be in format namespace/repository ' - 'and can only contain letters, numbers, underscores and ' - 'dashes, but got "{}"'.format(repo_name)) + """ Ensure the repo name is not malicious """ + match = _repo_name_re.match(repo_name) + if not match: + logger.error("Repo %s: repository name is not valid", repo_name) + raise InvalidRepositoryError( + "Repository name must be in format namespace/repository " + "and can only contain letters, numbers, underscores and " + 'dashes, but got "{}"'.format(repo_name) + ) _http_fttp_url = re.compile( - r'^(?:http|ftp)s?://' # http:// or https:// + r"^(?:http|ftp)s?://" # http:// or https:// # domain... - r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' - r'localhost|' # localhost... - r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip - r'(?::\d+)?' # optional port - r'(?:/?|[/?]\S+)$', re.IGNORECASE) + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" + r"localhost|" # localhost... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) -_ssh_url = re.compile(r'((git|ssh|http(s)?)|(git@[\w\.]+))(:(//)?)([\w\.@\:/\-~]+)(\.git)?(/)?') +_ssh_url = re.compile( + r"((git|ssh|http(s)?)|(git@[\w\.]+))(:(//)?)([\w\.@\:/\-~]+)(\.git)?(/)?" 
+) def _validate_url(url): - """ ensure valid URL """ - for _url_re in [_http_fttp_url, _ssh_url]: - match = _url_re.match(url) - if match: - return - logger.error('Repository URL (%s) is not valid', url) - raise InvalidRepositoryError('Repository URL must be a valid URL, but got "{}".'.format(url)) + """ ensure valid URL """ + for _url_re in [_http_fttp_url, _ssh_url]: + match = _url_re.match(url) + if match: + return + logger.error("Repository URL (%s) is not valid", url) + raise InvalidRepositoryError( + 'Repository URL must be a valid URL, but got "{}".'.format(url) + ) diff --git a/taf/log.py b/taf/log.py index fdae4972e..6a4b01bd9 100644 --- a/taf/log.py +++ b/taf/log.py @@ -4,44 +4,48 @@ import taf.settings -_FORMAT_STRING = '[%(asctime)s] [%(levelname)s] ' + \ - '[%(funcName)s:%(lineno)s@%(filename)s]\n%(message)s\n' +_FORMAT_STRING = ( + "[%(asctime)s] [%(levelname)s] " + + "[%(funcName)s:%(lineno)s@%(filename)s]\n%(message)s\n" +) formatter = logging.Formatter(_FORMAT_STRING) -logger = logging.getLogger('taf') +logger = logging.getLogger("taf") logger.setLevel(taf.settings.LOG_LEVEL) def _get_log_location(): - location = taf.settings.LOGS_LOCATION or os.environ.get('TAF_LOG') - if location is None: - location = Path.home() / '.taf' - location.mkdir(exist_ok=True) - else: - location = Path(location) - return location + location = taf.settings.LOGS_LOCATION or os.environ.get("TAF_LOG") + if location is None: + location = Path.home() / ".taf" + location.mkdir(exist_ok=True) + else: + location = Path(location) + return location if taf.settings.ENABLE_CONSOLE_LOGGING: - console_handler = logging.StreamHandler() - console_handler.setLevel(taf.settings.CONSOLE_LOGGING_LEVEL) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) + console_handler = logging.StreamHandler() + console_handler.setLevel(taf.settings.CONSOLE_LOGGING_LEVEL) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) if taf.settings.ENABLE_FILE_LOGGING: - logs_location = _get_log_location() - file_handler = logging.FileHandler(str(logs_location / taf.settings.LOG_FILENAME)) - file_handler.setLevel(taf.settings.FILE_LOGGING_LEVEL) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + logs_location = _get_log_location() + file_handler = logging.FileHandler(str(logs_location / taf.settings.LOG_FILENAME)) + file_handler.setLevel(taf.settings.FILE_LOGGING_LEVEL) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) - if taf.settings.SEPARATE_ERRORS: - error_handler = logging.FileHandler(str(logs_location / taf.settings.ERROR_LOG_FILENAME)) - error_handler.setLevel(taf.settings.ERROR_LOGGING_LEVEL) - error_handler.setFormatter(formatter) - logger.addHandler(error_handler) + if taf.settings.SEPARATE_ERRORS: + error_handler = logging.FileHandler( + str(logs_location / taf.settings.ERROR_LOG_FILENAME) + ) + error_handler.setLevel(taf.settings.ERROR_LOGGING_LEVEL) + error_handler.setFormatter(formatter) + logger.addHandler(error_handler) def get_logger(name): - return logging.getLogger(name) + return logging.getLogger(name) diff --git a/taf/repositoriesdb.py b/taf/repositoriesdb.py index e18638bbd..8c0400cf2 100644 --- a/taf/repositoriesdb.py +++ b/taf/repositoriesdb.py @@ -3,8 +3,7 @@ from subprocess import CalledProcessError import taf.log -from taf.exceptions import (InvalidOrMissingMetadataError, - RepositoriesNotFoundError) +from taf.exceptions import InvalidOrMissingMetadataError, RepositoriesNotFoundError from taf.git import 
NamedGitRepository

# {
@@ -20,18 +19,24 @@
logger = taf.log.get_logger(__name__)

_repositories_dict = {}
-repositories_path = 'targets/repositories.json'
-targets_path = 'metadata/targets.json'
+repositories_path = "targets/repositories.json"
+targets_path = "metadata/targets.json"


def clear_repositories_db():
-  global _repositories_dict
-  _repositories_dict.clear()
-
-
-def load_repositories(auth_repo, repo_classes=None, factory=None,
-                      root_dir=None, only_load_targets=False, commits=None):
-  """
+    global _repositories_dict
+    _repositories_dict.clear()
+
+
+def load_repositories(
+    auth_repo,
+    repo_classes=None,
+    factory=None,
+    root_dir=None,
+    only_load_targets=False,
+    commits=None,
+):
+    """
  Creates target repositories by reading repositories.json and targets.json
  files at the specified revisions, given an authentication repo.
  If the commits are not specified, targets will be created based on the HEAD pointer
@@ -59,201 +64,265 @@ def load_repositories(auth_repo, repo_classes=None, factory=None,
  they are targets or not.
  """
-  global _repositories_dict
-  if auth_repo.repo_name not in _repositories_dict:
-    _repositories_dict[auth_repo.repo_name] = {}
-
-  if commits is None:
-    commits = [auth_repo.head_commit_sha()]
-
-  logger.debug("Loading %s's target repositories at revisions %s", auth_repo.repo_name,
-               ', '.join(commits))
-
-  if root_dir is None:
-    root_dir = Path(auth_repo.repo_path).parent
-
-  for commit in commits:
-    repositories_dict = {}
-    # check if already loaded
-    if commit in _repositories_dict[auth_repo.repo_name]:
-      continue
-
-    _repositories_dict[auth_repo.repo_name][commit] = repositories_dict
-    try:
-      repositories = _get_json_file(auth_repo, repositories_path, commit)
-      targets = _get_json_file(auth_repo, targets_path, commit)
-    except InvalidOrMissingMetadataError as e:
-      logger.warning('Skipping commit %s due to error %s', commit, e)
-      continue
-
-    # target repositories are defined in both mirrors.json and targets.json
-    repositories = repositories['repositories']
-    targets = targets['signed']['targets']
-    for path, repo_data in repositories.items():
-      urls = repo_data['urls']
-      target = targets.get(path)
-      if target is None and only_load_targets:
-        continue
-      additional_info = _get_custom_data(repo_data, targets.get(path))
-
-      if factory is not None:
-        git_repo = factory(root_dir, path, urls, additional_info)
-      else:
-        git_repo_class = _determine_repo_class(repo_classes, path)
-        git_repo = git_repo_class(root_dir, path, urls, additional_info)
-
-      if not isinstance(git_repo, NamedGitRepository):
-        raise Exception('{} is not a subclass of NamedGitRepository'
-                        .format(type(git_repo)))
-
-      repositories_dict[path] = git_repo
-
-    logger.debug('Loaded the following repositories at revision %s: %s', commit,
-                 ', '.join(repositories_dict.keys()))
+    global _repositories_dict
+    if auth_repo.repo_name not in _repositories_dict:
+        _repositories_dict[auth_repo.repo_name] = {}
+
+    if commits is None:
+        commits = [auth_repo.head_commit_sha()]
+
+    logger.debug(
+        "Loading %s's target repositories at revisions %s",
+        auth_repo.repo_name,
+        ", ".join(commits),
+    )
+
+    if root_dir is None:
+        root_dir = Path(auth_repo.repo_path).parent
+
+    for commit in commits:
+        repositories_dict = {}
+        # check if already loaded
+        if commit in _repositories_dict[auth_repo.repo_name]:
+            continue
+
+        _repositories_dict[auth_repo.repo_name][commit] = repositories_dict
+        try:
+            repositories = _get_json_file(auth_repo, repositories_path, commit)
+            targets = _get_json_file(auth_repo, targets_path, 
commit)
+        except InvalidOrMissingMetadataError as e:
+            logger.warning("Skipping commit %s due to error %s", commit, e)
+            continue
+
+        # target repositories are defined in both repositories.json and targets.json
+        repositories = repositories["repositories"]
+        targets = targets["signed"]["targets"]
+        for path, repo_data in repositories.items():
+            urls = repo_data["urls"]
+            target = targets.get(path)
+            if target is None and only_load_targets:
+                continue
+            additional_info = _get_custom_data(repo_data, targets.get(path))
+
+            if factory is not None:
+                git_repo = factory(root_dir, path, urls, additional_info)
+            else:
+                git_repo_class = _determine_repo_class(repo_classes, path)
+                git_repo = git_repo_class(root_dir, path, urls, additional_info)
+
+            if not isinstance(git_repo, NamedGitRepository):
+                raise Exception(
+                    "{} is not a subclass of NamedGitRepository".format(type(git_repo))
+                )
+
+            repositories_dict[path] = git_repo
+
+        logger.debug(
+            "Loaded the following repositories at revision %s: %s",
+            commit,
+            ", ".join(repositories_dict.keys()),
+        )


def _determine_repo_class(repo_classes, path):
-  # if no class is specified, return the default one
-  if repo_classes is None:
-    return NamedGitRepository
+    # if no class is specified, return the default one
+    if repo_classes is None:
+        return NamedGitRepository

-  # if only one value is specified, that means that all target repositories
-  # should be of the same class
-  if not isinstance(repo_classes, dict):
-    return repo_classes
+    # if only one value is specified, that means that all target repositories
+    # should be of the same class
+    if not isinstance(repo_classes, dict):
+        return repo_classes

-  if path in repo_classes:
-    return repo_classes[path]
+    if path in repo_classes:
+        return repo_classes[path]

-  if 'default' in repo_classes:
-    return repo_classes['default']
+    if "default" in repo_classes:
+        return repo_classes["default"]

-  return NamedGitRepository
+    return NamedGitRepository


def _get_custom_data(repo, target):
-  custom = repo.get('custom', {})
-  target_custom = target.get('custom') if target is not None else None
-  if target_custom is not None:
-    custom.update(target_custom)
-  return custom
+    custom = repo.get("custom", {})
+    target_custom = target.get("custom") if target is not None else None
+    if target_custom is not None:
+        custom.update(target_custom)
+    return custom


def _get_json_file(auth_repo, path, commit):
-  try:
-    return auth_repo.get_json(commit, path)
-  except CalledProcessError:
-    raise InvalidOrMissingMetadataError('{} not available at revision {}'
-                                        .format(path, commit))
-  except json.decoder.JSONDecodeError:
-    raise InvalidOrMissingMetadataError('{} not a valid json at revision {}'
-                                        .format(path, commit))
+    try:
+        return auth_repo.get_json(commit, path)
+    except CalledProcessError:
+        raise InvalidOrMissingMetadataError(
+            "{} not available at revision {}".format(path, commit)
+        )
+    except json.decoder.JSONDecodeError:
+        raise InvalidOrMissingMetadataError(
+            "{} not a valid json at revision {}".format(path, commit)
+        )


def get_repositories_paths_by_custom_data(auth_repo, commit=None, **custom):
-  if not commit:
-    commit = auth_repo.head_commit_sha()
-  logger.debug('Auth repo %s: finding paths of repositories by custom data %s',
-               auth_repo.repo_name, custom)
-  targets = auth_repo.get_json(commit, targets_path)
-  repositories = auth_repo.get_json(commit, repositories_path)
-  repositories = repositories['repositories']
-  targets = targets['signed']['targets']
-
-  def _compare(path):
-    # Check if `custom` dict is subset of 
targets[path]['custom'] dict
-    try:
-      return custom.items() <= _get_custom_data(repositories[path],
-                                                targets.get(path)).items()
-    except (AttributeError, KeyError):
-      return False
-
-  paths = list(filter(_compare, repositories)) if custom else list(repositories)
-  if len(paths):
-    logger.debug('Auth repo %s: found the following paths %s', auth_repo.repo_name, paths)
-    return paths
-  logger.error('Auth repo %s: repositories associated with custom data %s not found',
-               auth_repo.repo_name, custom)
-  raise RepositoriesNotFoundError('Repositories associated with custom data {} not found'
-                                  .format(custom))
+    if not commit:
+        commit = auth_repo.head_commit_sha()
+    logger.debug(
+        "Auth repo %s: finding paths of repositories by custom data %s",
+        auth_repo.repo_name,
+        custom,
+    )
+    targets = auth_repo.get_json(commit, targets_path)
+    repositories = auth_repo.get_json(commit, repositories_path)
+    repositories = repositories["repositories"]
+    targets = targets["signed"]["targets"]
+
+    def _compare(path):
+        # Check if `custom` dict is subset of targets[path]['custom'] dict
+        try:
+            return (
+                custom.items()
+                <= _get_custom_data(repositories[path], targets.get(path)).items()
+            )
+        except (AttributeError, KeyError):
+            return False
+
+    paths = list(filter(_compare, repositories)) if custom else list(repositories)
+    if len(paths):
+        logger.debug(
+            "Auth repo %s: found the following paths %s", auth_repo.repo_name, paths
+        )
+        return paths
+    logger.error(
+        "Auth repo %s: repositories associated with custom data %s not found",
+        auth_repo.repo_name,
+        custom,
+    )
+    raise RepositoriesNotFoundError(
+        "Repositories associated with custom data {} not found".format(custom)
+    )


def get_deduplicated_repositories(auth_repo, commits):
-  global _repositories_dict
-  logger.debug('Auth repo %s: getting a deduplicated list of repositories', auth_repo.repo_name)
-  all_repositories = _repositories_dict.get(auth_repo.repo_name)
-  if all_repositories is None:
-    logger.error('Repositories defined in authentication repository %s have not been loaded',
-                 auth_repo.repo_name)
-    raise RepositoriesNotFoundError('Repositories defined in authentication repository'
-                                    ' {} have not been loaded'.format(auth_repo.repo_name))
-  repositories = {}
-  # persuming that the newest commit is the last one
-  for commit in commits:
-    if not commit in all_repositories:
-      logger.error('Repositories defined in authentication repository %s at revision %s have '
-                   'not been loaded', auth_repo.repo_name, commit)
-      raise RepositoriesNotFoundError('Repositories defined in authentication repository '
-                                      '{} at revision {} have not been loaded'
-                                      .format(auth_repo.repo_name, commit))
-    for path, repo in all_repositories[commit].items():
-      # will overwrite older repo with newer
-      repositories[path] = repo
-
-  logger.debug('Auth repo %s: deduplicated list of repositories %s', auth_repo.repo_name,
-               ', '.join(repositories.keys()))
-  return repositories
+    global _repositories_dict
+    logger.debug(
+        "Auth repo %s: getting a deduplicated list of repositories", auth_repo.repo_name
+    )
+    all_repositories = _repositories_dict.get(auth_repo.repo_name)
+    if all_repositories is None:
+        logger.error(
+            "Repositories defined in authentication repository %s have not been loaded",
+            auth_repo.repo_name,
+        )
+        raise RepositoriesNotFoundError(
+            "Repositories defined in authentication repository"
+            " {} have not been loaded".format(auth_repo.repo_name)
+        )
+    repositories = {}
+    # presuming that the newest commit is the last one
+    for commit in commits:
+        if commit 
not in all_repositories: + logger.error( + "Repositories defined in authentication repository %s at revision %s have " + "not been loaded", + auth_repo.repo_name, + commit, + ) + raise RepositoriesNotFoundError( + "Repositories defined in authentication repository " + "{} at revision {} have not been loaded".format( + auth_repo.repo_name, commit + ) + ) + for path, repo in all_repositories[commit].items(): + # will overwrite older repo with newer + repositories[path] = repo + + logger.debug( + "Auth repo %s: deduplicated list of repositories %s", + auth_repo.repo_name, + ", ".join(repositories.keys()), + ) + return repositories def get_repository(auth_repo, path, commit=None): - return get_repositories(auth_repo, commit)[path] + return get_repositories(auth_repo, commit)[path] def get_repositories(auth_repo, commit): - global _repositories_dict - logger.debug('Auth repo %s: finding repositories defined at commit %s', auth_repo.repo_name, - commit) - all_repositories = _repositories_dict.get(auth_repo.repo_name) - if all_repositories is None: - logger.error('Repositories defined in authentication repository %s have not been loaded', - auth_repo.repo_name) - raise RepositoriesNotFoundError('Repositories defined in authentication repository' - ' {} have not been loaded'.format(auth_repo.repo_name)) - - if commit is None: - commit = auth_repo.head_commit_sha() - - repositories = all_repositories.get(commit) - if repositories is None: - logger.error('Repositories defined in authentication repository %s at revision %s have ' - 'not been loaded', auth_repo.repo_name, commit) - raise RepositoriesNotFoundError('Repositories defined in authentication repository ' - '{} at revision {} have not been loaded' - .format(auth_repo.repo_name, commit)) - logger.debug('Auth repo %s: found the following repositories at revision %s: %s', auth_repo.repo_name, - commit, ', '.join(repositories.keys())) - return repositories + global _repositories_dict + logger.debug( + "Auth repo %s: finding repositories defined at commit %s", + auth_repo.repo_name, + commit, + ) + all_repositories = _repositories_dict.get(auth_repo.repo_name) + if all_repositories is None: + logger.error( + "Repositories defined in authentication repository %s have not been loaded", + auth_repo.repo_name, + ) + raise RepositoriesNotFoundError( + "Repositories defined in authentication repository" + " {} have not been loaded".format(auth_repo.repo_name) + ) + + if commit is None: + commit = auth_repo.head_commit_sha() + + repositories = all_repositories.get(commit) + if repositories is None: + logger.error( + "Repositories defined in authentication repository %s at revision %s have " + "not been loaded", + auth_repo.repo_name, + commit, + ) + raise RepositoriesNotFoundError( + "Repositories defined in authentication repository " + "{} at revision {} have not been loaded".format(auth_repo.repo_name, commit) + ) + logger.debug( + "Auth repo %s: found the following repositories at revision %s: %s", + auth_repo.repo_name, + commit, + ", ".join(repositories.keys()), + ) + return repositories def get_repositories_by_custom_data(auth_repo, commit=None, **custom_data): - logger.debug('Auth repo %s: finding repositories by custom data %s', - auth_repo.repo_name, custom_data) - repositories = get_repositories(auth_repo, commit).values() - - def _compare(repo): - # Check if `custom` dict is subset of targets[path]['custom'] dict - try: - return custom_data.items() <= repo.additional_info.items() - except (AttributeError, KeyError): - return False - 
found_repos = list(filter(_compare, repositories)
-                     ) if custom_data else list(repositories)
-
-  if len(found_repos):
-    logger.debug('Auth repo %s: found the following repositories %s', auth_repo.repo_name,
-                 repositories)
-    return found_repos
-  logger.error('Auth repo %s: repositories associated with custom data %s not found',
-               auth_repo.repo_name, custom_data)
-  raise RepositoriesNotFoundError('Repositories associated with custom data {} not found'
-                                  .format(custom_data))
+    logger.debug(
+        "Auth repo %s: finding repositories by custom data %s",
+        auth_repo.repo_name,
+        custom_data,
+    )
+    repositories = get_repositories(auth_repo, commit).values()
+
+    def _compare(repo):
+        # Check if `custom` dict is subset of targets[path]['custom'] dict
+        try:
+            return custom_data.items() <= repo.additional_info.items()
+        except (AttributeError, KeyError):
+            return False
+
+    found_repos = (
+        list(filter(_compare, repositories)) if custom_data else list(repositories)
+    )
+
+    if len(found_repos):
+        logger.debug(
+            "Auth repo %s: found the following repositories %s",
+            auth_repo.repo_name,
+            found_repos,
+        )
+        return found_repos
+    logger.error(
+        "Auth repo %s: repositories associated with custom data %s not found",
+        auth_repo.repo_name,
+        custom_data,
+    )
+    raise RepositoriesNotFoundError(
+        "Repositories associated with custom data {} not found".format(custom_data)
+    )
diff --git a/taf/repository_tool.py b/taf/repository_tool.py
index fa6dcf4cb..c4edfe175 100644
--- a/taf/repository_tool.py
+++ b/taf/repository_tool.py
@@ -10,25 +10,27 @@
from securesystemslib.interface import import_rsa_privatekey_from_file
from tuf.exceptions import Error as TUFError
from tuf.repository_tool import (
-    METADATA_DIRECTORY_NAME, TARGETS_DIRECTORY_NAME, import_rsakey_from_pem,
-    load_repository)
+    METADATA_DIRECTORY_NAME,
+    TARGETS_DIRECTORY_NAME,
+    import_rsakey_from_pem,
+    load_repository,
+)

from taf.constants import DEFAULT_RSA_SIGNATURE_SCHEME
-from taf.exceptions import (InvalidKeyError, MetadataUpdateError,
-                            RootMetadataUpdateError,
-                            SnapshotMetadataUpdateError,
-                            TargetsMetadataUpdateError,
-                            TimestampMetadataUpdateError, YubikeyError)
+from taf.exceptions import (
+    InvalidKeyError,
+    MetadataUpdateError,
+    RootMetadataUpdateError,
+    SnapshotMetadataUpdateError,
+    TargetsMetadataUpdateError,
+    TimestampMetadataUpdateError,
+    YubikeyError,
+)
from taf.git import GitRepository
from taf.utils import normalize_file_line_endings

# Default expiration intervals per role
-expiration_intervals = {
-    'root': 365,
-    'targets': 90,
-    'snapshot': 7,
-    'timestamp': 1
-}
+expiration_intervals = {"root": 365, "targets": 90, "snapshot": 7, "timestamp": 1}

# Loaded keys cache
role_keys_cache = {}
@@ -37,9 +39,8 @@
DISABLE_KEYS_CACHING = False


-def load_role_key(keystore, role, password=None,
-                  scheme=DEFAULT_RSA_SIGNATURE_SCHEME):
-  """Loads the specified role's key from a keystore file.
+def load_role_key(keystore, role, password=None, scheme=DEFAULT_RSA_SIGNATURE_SCHEME):
+    """Loads the specified role's key from a keystore file.
    The keystore file can, but doesn't have to be password protected.

    NOTE: Keys inside keystore should match a role name!
@@ -57,20 +58,23 @@ def load_role_key(keystore, role, password=None,
      - securesystemslib.exceptions.FormatError: If the arguments are improperly formatted.
      - securesystemslib.exceptions.CryptoError: If path is not a valid encrypted key file.
    
""" - key = role_keys_cache.get(role) - if key is None: - if password is not None: - key = import_rsa_privatekey_from_file(os.path.join(keystore, role), - password, scheme=scheme) - else: - key = import_rsa_privatekey_from_file(os.path.join(keystore, role), scheme=scheme) - if not DISABLE_KEYS_CACHING: - role_keys_cache[role] = key - return key + key = role_keys_cache.get(role) + if key is None: + if password is not None: + key = import_rsa_privatekey_from_file( + os.path.join(keystore, role), password, scheme=scheme + ) + else: + key = import_rsa_privatekey_from_file( + os.path.join(keystore, role), scheme=scheme + ) + if not DISABLE_KEYS_CACHING: + role_keys_cache[role] = key + return key def targets_signature_provider(key_id, key_pin, key, data): # pylint: disable=W0613 - """Targets signature provider used to sign data with YubiKey. + """Targets signature provider used to sign data with YubiKey. Args: - key_id(str): Key id from targets metadata file @@ -84,20 +88,17 @@ def targets_signature_provider(key_id, key_pin, key, data): # pylint: disable=W Raises: - YubikeyError: If signing with YubiKey cannot be performed """ - from taf.yubikey import sign_piv_rsa_pkcs1v15 - from binascii import hexlify + from taf.yubikey import sign_piv_rsa_pkcs1v15 + from binascii import hexlify - data = securesystemslib.formats.encode_canonical(data).encode('utf-8') - signature = sign_piv_rsa_pkcs1v15(data, key_pin) + data = securesystemslib.formats.encode_canonical(data).encode("utf-8") + signature = sign_piv_rsa_pkcs1v15(data, key_pin) - return { - 'keyid': key_id, - 'sig': hexlify(signature).decode() - } + return {"keyid": key_id, "sig": hexlify(signature).decode()} def root_signature_provider(signature_dict, key_id, _key, _data): - """Root signature provider used to return signatures created remotely. + """Root signature provider used to return signatures created remotely. 
Args: - signature_dict(dict): Dict where key is key_id and value is signature @@ -111,38 +112,34 @@ def root_signature_provider(signature_dict, key_id, _key, _data): Raises: - KeyError: If signature for key_id is not present in signature_dict """ - from binascii import hexlify + from binascii import hexlify - return { - 'keyid': key_id, - 'sig': hexlify(signature_dict.get(key_id)).decode() - } + return {"keyid": key_id, "sig": hexlify(signature_dict.get(key_id)).decode()} class Repository: + def __init__(self, repository_path): + self.repository_path = repository_path + tuf.repository_tool.METADATA_STAGED_DIRECTORY_NAME = METADATA_DIRECTORY_NAME + tuf_repository = load_repository(repository_path) + self._repository = tuf_repository - def __init__(self, repository_path): - self.repository_path = repository_path - tuf.repository_tool.METADATA_STAGED_DIRECTORY_NAME = METADATA_DIRECTORY_NAME - tuf_repository = load_repository(repository_path) - self._repository = tuf_repository + _framework_files = ["repositories.json", "test-auth-repo"] - _framework_files = ['repositories.json', 'test-auth-repo'] + @property + def targets_path(self): + return Path(self.repository_path) / TARGETS_DIRECTORY_NAME - @property - def targets_path(self): - return Path(self.repository_path) / TARGETS_DIRECTORY_NAME + @property + def metadata_path(self): + return os.path.join(self.repository_path, METADATA_DIRECTORY_NAME) - @property - def metadata_path(self): - return os.path.join(self.repository_path, METADATA_DIRECTORY_NAME) + @property + def repo_id(self): + return GitRepository(self.repository_path).initial_commit - @property - def repo_id(self): - return GitRepository(self.repository_path).initial_commit - - def _add_target(self, targets_obj, file_path, custom=None): - """ + def _add_target(self, targets_obj, file_path, custom=None): + """ Normalizes line endings (converts all line endings to unix style endings) and registers the target file as a TUF target @@ -151,11 +148,11 @@ def _add_target(self, targets_obj, file_path, custom=None): file_path: full path of the target file custom: custom target data """ - normalize_file_line_endings(file_path) - targets_obj.add_target(file_path, custom) + normalize_file_line_endings(file_path) + targets_obj.add_target(file_path, custom) - def _role_obj(self, role): - """Helper function for getting TUF's role object, given the role's name + def _role_obj(self, role): + """Helper function for getting TUF's role object, given the role's name Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -169,18 +166,18 @@ def _role_obj(self, role): - securesystemslib.exceptions.UnknownRoleError: If 'rolename' has not been delegated by this targets object. """ - if role == 'targets': - return self._repository.targets - elif role == 'snapshot': - return self._repository.snapshot - elif role == 'timestamp': - return self._repository.timestamp - elif role == 'root': - return self._repository.root - return self._repository.targets(role) - - def _try_load_metadata_key(self, role, key): - """Check if given key can be used to sign given role and load it. + if role == "targets": + return self._repository.targets + elif role == "snapshot": + return self._repository.snapshot + elif role == "timestamp": + return self._repository.timestamp + elif role == "root": + return self._repository.root + return self._repository.targets(role) + + def _try_load_metadata_key(self, role, key): + """Check if given key can be used to sign given role and load it. 
Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -195,12 +192,14 @@ def _try_load_metadata_key(self, role, key): targets object. - InvalidKeyError: If metadata cannot be signed with given key. """ - if not self.is_valid_metadata_key(role, key): - raise InvalidKeyError(role) - self._role_obj(role).load_signing_key(key) + if not self.is_valid_metadata_key(role, key): + raise InvalidKeyError(role) + self._role_obj(role).load_signing_key(key) - def _update_metadata(self, role, start_date=datetime.datetime.now(), interval=None, write=False): - """Update metadata expiration date and (optionally) writes it. + def _update_metadata( + self, role, start_date=datetime.datetime.now(), interval=None, write=False + ): + """Update metadata expiration date and (optionally) writes it. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -218,12 +217,12 @@ def _update_metadata(self, role, start_date=datetime.datetime.now(), interval=No - securesystemslib.exceptions.Error: If securesystemslib error happened during metadata write - tuf.exceptions.Error: If TUF error happened during metadata write """ - self.set_metadata_expiration_date(role, start_date, interval) - if write: - self._repository.write(role) + self.set_metadata_expiration_date(role, start_date, interval) + if write: + self._repository.write(role) - def add_existing_target(self, file_path, targets_role='targets', custom=None): - """Registers new target files with TUF. + def add_existing_target(self, file_path, targets_role="targets", custom=None): + """Registers new target files with TUF. The files are expected to be inside the targets directory. Args: @@ -240,11 +239,11 @@ def add_existing_target(self, file_path, targets_role='targets', custom=None): - securesystemslib.exceptions.Error: If 'filepath' is not located in the repository's targets directory. """ - targets_obj = self._role_obj(targets_role) - self._add_target(targets_obj, file_path, custom) + targets_obj = self._role_obj(targets_role) + self._add_target(targets_obj, file_path, custom) - def add_targets(self, data, targets_role='targets', files_to_keep=None): - """Creates a target .json file containing a repository's commit for each repository. + def add_targets(self, data, targets_role="targets", files_to_keep=None): + """Creates a target .json file containing a repository's commit for each repository. Adds those files to the tuf repository. Also removes all targets from the filesystem if their path is not among the provided ones. TUF does not delete targets automatically. @@ -279,70 +278,76 @@ def add_targets(self, data, targets_role='targets', files_to_keep=None): that should remain targets. Files required by the framework will also remain targets. 
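
 Example of the expected data layout (illustrative; paths and values are
 hypothetical, but the "target" and "custom" keys mirror the code below):

     data = {
         "namespace/repo1": {
             "target": {"commit": "a1b2c3"},
             "custom": {"allow-unauthenticated-commits": True},
         },
         "notes.txt": {"target": "plain text content"},
     }
     repo.add_targets(data)
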
""" - if files_to_keep is None: - files_to_keep = [] - # leave all files required by the framework and additional files specified by the user - files_to_keep.extend(self._framework_files) - # add all repositories defined in repositories.json to files_to_keep - files_to_keep.extend(self._get_target_repositories()) - # delete files if they no longer correspond to a target defined - # in targets metadata and are not specified in files_to_keep - targets_obj = self._role_obj(targets_role) - for filepath in self.targets_path.rglob('*'): - if filepath.is_file(): - file_rel_path = str(Path(os.path.relpath(str(filepath), str(self.targets_path))).as_posix()) - if file_rel_path not in data and file_rel_path not in files_to_keep: - if file_rel_path in targets_obj.target_files: - targets_obj.remove_target(file_rel_path) - filepath.unlink() - - for path, target_data in data.items(): - # if the target's parent directory should not be "targets", create - # its parent directories if they do not exist - target_path = (self.targets_path / path).absolute() - target_dir = target_path.parents[0] - target_dir.mkdir(parents=True, exist_ok=True) - - # create the target file - content = target_data.get('target', None) - if content is None: - if not target_path.is_file(): - target_path.touch() - else: - with open(str(target_path), 'w') as f: - if isinstance(content, dict): - json.dump(content, f, indent=4) - else: - f.write(content) - - custom = target_data.get('custom', None) - self._add_target(targets_obj, str(target_path), custom) - - with open(os.path.join(self.metadata_path, '{}.json'.format(targets_role))) as f: - previous_targets = json.load(f)['signed']['targets'] - - for path in files_to_keep: - # if path if both in data and files_to_keep, skip it - # e.g. repositories.json will always be in files_to_keep, - # but it might also be specified in data, if it needs to be updated - if path in data: - continue - target_path = (self.targets_path / path).absolute() - previous_custom = None - if path in previous_targets: - previous_custom = previous_targets[path].get('custom') - if target_path.is_file(): - self._add_target(targets_obj, str(target_path), previous_custom) - - def _get_target_repositories(self): - repositories_path = self.targets_path / 'repositories.json' - if repositories_path.exists(): - repositories = repositories_path.read_text() - repositories = json.loads(repositories)['repositories'] - return [str(Path(target_path).as_posix()) for target_path in repositories] - - def get_role_keys(self, role): - """Registers new target files with TUF. 
+        if files_to_keep is None:
+            files_to_keep = []
+        # leave all files required by the framework and additional files specified by the user
+        files_to_keep.extend(self._framework_files)
+        # add all repositories defined in repositories.json to files_to_keep
+        files_to_keep.extend(self._get_target_repositories())
+        # delete files if they no longer correspond to a target defined
+        # in targets metadata and are not specified in files_to_keep
+        targets_obj = self._role_obj(targets_role)
+        for filepath in self.targets_path.rglob("*"):
+            if filepath.is_file():
+                file_rel_path = str(
+                    Path(
+                        os.path.relpath(str(filepath), str(self.targets_path))
+                    ).as_posix()
+                )
+                if file_rel_path not in data and file_rel_path not in files_to_keep:
+                    if file_rel_path in targets_obj.target_files:
+                        targets_obj.remove_target(file_rel_path)
+                    filepath.unlink()
+
+        for path, target_data in data.items():
+            # if the target's parent directory should not be "targets", create
+            # its parent directories if they do not exist
+            target_path = (self.targets_path / path).absolute()
+            target_dir = target_path.parents[0]
+            target_dir.mkdir(parents=True, exist_ok=True)
+
+            # create the target file
+            content = target_data.get("target", None)
+            if content is None:
+                if not target_path.is_file():
+                    target_path.touch()
+            else:
+                with open(str(target_path), "w") as f:
+                    if isinstance(content, dict):
+                        json.dump(content, f, indent=4)
+                    else:
+                        f.write(content)
+
+            custom = target_data.get("custom", None)
+            self._add_target(targets_obj, str(target_path), custom)
+
+        with open(
+            os.path.join(self.metadata_path, "{}.json".format(targets_role))
+        ) as f:
+            previous_targets = json.load(f)["signed"]["targets"]
+
+        for path in files_to_keep:
+            # if path is both in data and files_to_keep, skip it
+            # e.g. repositories.json will always be in files_to_keep,
+            # but it might also be specified in data, if it needs to be updated
+            if path in data:
+                continue
+            target_path = (self.targets_path / path).absolute()
+            previous_custom = None
+            if path in previous_targets:
+                previous_custom = previous_targets[path].get("custom")
+            if target_path.is_file():
+                self._add_target(targets_obj, str(target_path), previous_custom)
+
+    def _get_target_repositories(self):
+        repositories_path = self.targets_path / "repositories.json"
+        if repositories_path.exists():
+            repositories = repositories_path.read_text()
+            repositories = json.loads(repositories)["repositories"]
+            return [str(Path(target_path).as_posix()) for target_path in repositories]
+
+    def get_role_keys(self, role):
+        """Returns the keys of the provided role.

 Args:
@@ -356,11 +361,11 @@ def get_role_keys(self, role):
 - role(str): TUF role (root, targets, timestamp, snapshot or delegated one)

 Returns:
 List of the role's keys

 Raises:
 - securesystemslib.exceptions.UnknownRoleError: If 'rolename' has not been
 delegated by this targets object.
 """
-        role_obj = self._role_obj(role)
-        return role_obj.keys
+        role_obj = self._role_obj(role)
+        return role_obj.keys

-    def get_signable_metadata(self, role):
-        """Return signable portion of newly generate metadata for given role.
+    def get_signable_metadata(self, role):
+        """Return signable portion of newly generated metadata for given role.
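
 Example (illustrative):

     canonical = repo.get_signable_metadata("root")
     # canonical JSON of the role's signable portion, suitable for signing
     # out-of-band and, e.g., feeding back through update_root()
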
Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -371,24 +376,25 @@ def get_signable_metadata(self, role): Raises: None """ - try: - from tuf.keydb import get_key - signable = None + try: + from tuf.keydb import get_key - role_obj = self._role_obj(role) - key = get_key(role_obj.keys[0]) + signable = None - def _provider(data): - nonlocal signable - signable = securesystemslib.formats.encode_canonical(data) + role_obj = self._role_obj(role) + key = get_key(role_obj.keys[0]) - role_obj.add_external_signature_provider(key, _provider) - self.writeall() - except (IndexError, TUFError, SSLibError): - return signable + def _provider(data): + nonlocal signable + signable = securesystemslib.formats.encode_canonical(data) - def is_valid_metadata_key(self, role, key): - """Checks if metadata role contains key id of provided key. + role_obj.add_external_signature_provider(key, _provider) + self.writeall() + except (IndexError, TUFError, SSLibError): + return signable + + def is_valid_metadata_key(self, role, key): + """Checks if metadata role contains key id of provided key. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -401,12 +407,12 @@ def is_valid_metadata_key(self, role, key): - securesystemslib.exceptions.FormatError: If key does not match RSAKEY_SCHEMA - securesystemslib.exceptions.UnknownRoleError: If role does not exist """ - securesystemslib.formats.RSAKEY_SCHEMA.check_match(key) + securesystemslib.formats.RSAKEY_SCHEMA.check_match(key) - return key['keyid'] in self.get_role_keys(role) + return key["keyid"] in self.get_role_keys(role) - def is_valid_metadata_yubikey(self, role, public_key=None): - """Checks if metadata role contains key id from YubiKey. + def is_valid_metadata_yubikey(self, role, public_key=None): + """Checks if metadata role contains key id from YubiKey. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one @@ -420,16 +426,17 @@ def is_valid_metadata_yubikey(self, role, public_key=None): - securesystemslib.exceptions.FormatError: If 'PEM' is improperly formatted. - securesystemslib.exceptions.UnknownRoleError: If role does not exist """ - securesystemslib.formats.ROLENAME_SCHEMA.check_match(role) + securesystemslib.formats.ROLENAME_SCHEMA.check_match(role) + + if public_key is None: + from taf.yubikey import get_piv_public_key_tuf - if public_key is None: - from taf.yubikey import get_piv_public_key_tuf - public_key = get_piv_public_key_tuf() + public_key = get_piv_public_key_tuf() - return self.is_valid_metadata_key(role, public_key) + return self.is_valid_metadata_key(role, public_key) - def add_metadata_key(self, role, pub_key_pem): - """Add metadata key of the provided role. + def add_metadata_key(self, role, pub_key_pem): + """Add metadata key of the provided role. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -445,14 +452,14 @@ def add_metadata_key(self, role, pub_key_pem): - securesystemslib.exceptions.UnknownKeyError: If 'key_id' is not found in the keydb database. """ - if isinstance(pub_key_pem, bytes): - pub_key_pem = pub_key_pem.decode('utf-8') + if isinstance(pub_key_pem, bytes): + pub_key_pem = pub_key_pem.decode("utf-8") - key = import_rsakey_from_pem(pub_key_pem) - self._role_obj(role).add_verification_key(key) + key = import_rsakey_from_pem(pub_key_pem) + self._role_obj(role).add_verification_key(key) - def remove_metadata_key(self, role, key_id): - """Remove metadata key of the provided role. 
+ def remove_metadata_key(self, role, key_id): + """Remove metadata key of the provided role. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -468,12 +475,15 @@ def remove_metadata_key(self, role, key_id): - securesystemslib.exceptions.UnknownKeyError: If 'key_id' is not found in the keydb database. """ - from tuf.keydb import get_key - key = get_key(key_id) - self._role_obj(role).remove_verification_key(key) + from tuf.keydb import get_key - def set_metadata_expiration_date(self, role, start_date=datetime.datetime.now(), interval=None): - """Set expiration date of the provided role. + key = get_key(key_id) + self._role_obj(role).remove_verification_key(key) + + def set_metadata_expiration_date( + self, role, start_date=datetime.datetime.now(), interval=None + ): + """Set expiration date of the provided role. Args: - role(str): TUF role (root, targets, timestamp, snapshot or delegated one) @@ -497,14 +507,14 @@ def set_metadata_expiration_date(self, role, start_date=datetime.datetime.now(), - securesystemslib.exceptions.UnknownRoleError: If 'rolename' has not been delegated by this targets object. """ - role_obj = self._role_obj(role) - if interval is None: - interval = expiration_intervals.get(role, 1) - expiration_date = start_date + datetime.timedelta(interval) - role_obj.expiration = expiration_date + role_obj = self._role_obj(role) + if interval is None: + interval = expiration_intervals.get(role, 1) + expiration_date = start_date + datetime.timedelta(interval) + role_obj.expiration = expiration_date - def update_root(self, signature_dict): - """Update root metadata. + def update_root(self, signature_dict): + """Update root metadata. Args: - signature_dict(dict): key_id-signature dictionary @@ -516,20 +526,26 @@ def update_root(self, signature_dict): - InvalidKeyError: If wrong key is used to sign metadata - SnapshotMetadataUpdateError: If any other error happened during metadata update """ - from tuf.keydb import get_key - try: - for key_id in signature_dict: - key = get_key(key_id) - self._repository.root.add_external_signature_provider( - key, - partial(root_signature_provider, signature_dict, key_id) - ) - self.writeall() - except (TUFError, SSLibError) as e: - raise RootMetadataUpdateError(str(e)) - - def update_snapshot(self, snapshot_key, start_date=datetime.datetime.now(), interval=None, write=True): - """Update snapshot metadata. + from tuf.keydb import get_key + + try: + for key_id in signature_dict: + key = get_key(key_id) + self._repository.root.add_external_signature_provider( + key, partial(root_signature_provider, signature_dict, key_id) + ) + self.writeall() + except (TUFError, SSLibError) as e: + raise RootMetadataUpdateError(str(e)) + + def update_snapshot( + self, + snapshot_key, + start_date=datetime.datetime.now(), + interval=None, + write=True, + ): + """Update snapshot metadata. 
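
 Example (illustrative; load_role_key is the assumed name of the keystore
 helper at the top of this module):

     snapshot_key = load_role_key("keystore", "snapshot")
     repo.update_snapshot(snapshot_key, interval=7)  # expires 7 days from start_date
     # note: the start_date default is evaluated once, at import time, so
     # long-running callers should pass datetime.datetime.now() explicitly
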
Args: - snapshot_key @@ -548,15 +564,16 @@ def update_snapshot(self, snapshot_key, start_date=datetime.datetime.now(), inte - InvalidKeyError: If wrong key is used to sign metadata - SnapshotMetadataUpdateError: If any other error happened during metadata update """ - try: - self._try_load_metadata_key('snapshot', snapshot_key) - self._update_metadata('snapshot', start_date, interval, write=write) - except (TUFError, SSLibError) as e: - raise SnapshotMetadataUpdateError(str(e)) + try: + self._try_load_metadata_key("snapshot", snapshot_key) + self._update_metadata("snapshot", start_date, interval, write=write) + except (TUFError, SSLibError) as e: + raise SnapshotMetadataUpdateError(str(e)) - def update_snapshot_and_timestmap(self, snapshot_key, timestamp_key, write=True, - **kwargs): - """Update snapshot and timestamp metadata. + def update_snapshot_and_timestmap( + self, snapshot_key, timestamp_key, write=True, **kwargs + ): + """Update snapshot and timestamp metadata. Args: - snapshot_key(str): snapshot key @@ -575,23 +592,26 @@ def update_snapshot_and_timestmap(self, snapshot_key, timestamp_key, write=True, - InvalidKeyError: If wrong key is used to sign metadata - MetadataUpdateError: If any other error happened during metadata update """ - try: - snapshot_date = kwargs.get('snapshot_date', datetime.datetime.now()) - snapshot_interval = kwargs.get('snapshot_interval', None) - - timestamp_date = kwargs.get('timestamp_date', datetime.datetime.now()) - timestamp_interval = kwargs.get('timestamp_interval', None) - - self.update_snapshot(snapshot_key, snapshot_date, - snapshot_interval, write=write) - self.update_timestamp(timestamp_key, timestamp_date, - timestamp_interval, write=write) - except (TUFError, SSLibError) as e: - raise MetadataUpdateError('all', str(e)) - - def update_targets_from_keystore(self, targets_key, start_date=datetime.datetime.now(), - interval=None, write=True): - """Update targets metadata. Sign it with a key from the file system + try: + snapshot_date = kwargs.get("snapshot_date", datetime.datetime.now()) + snapshot_interval = kwargs.get("snapshot_interval", None) + + timestamp_date = kwargs.get("timestamp_date", datetime.datetime.now()) + timestamp_interval = kwargs.get("timestamp_interval", None) + + self.update_snapshot( + snapshot_key, snapshot_date, snapshot_interval, write=write + ) + self.update_timestamp( + timestamp_key, timestamp_date, timestamp_interval, write=write + ) + except (TUFError, SSLibError) as e: + raise MetadataUpdateError("all", str(e)) + + def update_targets_from_keystore( + self, targets_key, start_date=datetime.datetime.now(), interval=None, write=True + ): + """Update targets metadata. Sign it with a key from the file system Args: - targets_key(securesystemslib.formats.RSAKEY_SCHEMA): Targets key. @@ -609,16 +629,22 @@ def update_targets_from_keystore(self, targets_key, start_date=datetime.datetime - InvalidKeyError: If wrong key is used to sign metadata - TimestampMetadataUpdateError: If any other error happened during metadata update """ - try: - self._try_load_metadata_key('targets', targets_key) - self._update_metadata('targets', start_date, interval, write=write) - except (TUFError, SSLibError) as e: - raise TimestampMetadataUpdateError(str(e)) - - def update_targets(self, targets_key_pin, targets_data=None, - start_date=datetime.datetime.now(), interval=None, - write=True, public_key=None): - """Update target data, sign with smart card and write. 
+ try: + self._try_load_metadata_key("targets", targets_key) + self._update_metadata("targets", start_date, interval, write=write) + except (TUFError, SSLibError) as e: + raise TimestampMetadataUpdateError(str(e)) + + def update_targets( + self, + targets_key_pin, + targets_data=None, + start_date=datetime.datetime.now(), + interval=None, + write=True, + public_key=None, + ): + """Update target data, sign with smart card and write. Args: - targets_key_pin(str): Targets key pin @@ -638,31 +664,40 @@ def update_targets(self, targets_key_pin, targets_data=None, - InvalidKeyError: If wrong key is used to sign metadata - MetadataUpdateError: If any other error happened during metadata update """ - try: - if public_key is None: - from taf.yubikey import get_piv_public_key_tuf - public_key = get_piv_public_key_tuf() + try: + if public_key is None: + from taf.yubikey import get_piv_public_key_tuf + + public_key = get_piv_public_key_tuf() - if not self.is_valid_metadata_yubikey('targets', public_key): - raise InvalidKeyError('targets') + if not self.is_valid_metadata_yubikey("targets", public_key): + raise InvalidKeyError("targets") - if targets_data: - self.add_targets(targets_data) + if targets_data: + self.add_targets(targets_data) - self.set_metadata_expiration_date('targets', start_date, interval) + self.set_metadata_expiration_date("targets", start_date, interval) - self._repository.targets.add_external_signature_provider( - public_key, - partial(targets_signature_provider, public_key['keyid'], targets_key_pin) - ) - if write: - self._repository.write('targets') + self._repository.targets.add_external_signature_provider( + public_key, + partial( + targets_signature_provider, public_key["keyid"], targets_key_pin + ), + ) + if write: + self._repository.write("targets") - except (YubikeyError, TUFError, SSLibError) as e: - raise TargetsMetadataUpdateError(str(e)) + except (YubikeyError, TUFError, SSLibError) as e: + raise TargetsMetadataUpdateError(str(e)) - def update_timestamp(self, timestamp_key, start_date=datetime.datetime.now(), interval=None, write=True): - """Update timestamp metadata. + def update_timestamp( + self, + timestamp_key, + start_date=datetime.datetime.now(), + interval=None, + write=True, + ): + """Update timestamp metadata. Args: - timestamp_key @@ -680,14 +715,14 @@ def update_timestamp(self, timestamp_key, start_date=datetime.datetime.now(), in - InvalidKeyError: If wrong key is used to sign metadata - TimestampMetadataUpdateError: If any other error happened during metadata update """ - try: - self._try_load_metadata_key('timestamp', timestamp_key) - self._update_metadata('timestamp', start_date, interval, write=write) - except (TUFError, SSLibError) as e: - raise TimestampMetadataUpdateError(str(e)) + try: + self._try_load_metadata_key("timestamp", timestamp_key) + self._update_metadata("timestamp", start_date, interval, write=write) + except (TUFError, SSLibError) as e: + raise TimestampMetadataUpdateError(str(e)) - def writeall(self): - """Write all dirty metadata files. + def writeall(self): + """Write all dirty metadata files. Args: None @@ -699,4 +734,4 @@ def writeall(self): - tuf.exceptions.UnsignedMetadataError: If any of the top-level and delegated roles do not have the minimum threshold of signatures. 
""" - self._repository.writeall() + self._repository.writeall() diff --git a/taf/settings.py b/taf/settings.py index 88aca3b2e..b9ac7e307 100644 --- a/taf/settings.py +++ b/taf/settings.py @@ -1,4 +1,3 @@ - import logging # Set a directory that should be used for all temporary files. If this @@ -38,8 +37,8 @@ # If this location is not specified logs will be placed ~/.taf LOGS_LOCATION = None -LOG_FILENAME = 'taf.log' +LOG_FILENAME = "taf.log" -ERROR_LOG_FILENAME = 'taf.err' +ERROR_LOG_FILENAME = "taf.err" LOG_COMMAND_OUTPUT = False diff --git a/taf/updater/handlers.py b/taf/updater/handlers.py index da36b9fd2..db1521068 100644 --- a/taf/updater/handlers.py +++ b/taf/updater/handlers.py @@ -15,7 +15,7 @@ class GitUpdater(handlers.MetadataUpdater): - """ + """ This class implements parts of the update process specific to keeping metadata files and targets in a git repository. The vast majority of the update process is handled by TUF's updater. This class does not modify @@ -63,16 +63,16 @@ class GitUpdater(handlers.MetadataUpdater): not all files are updated at the same time. """ - @property - def current_commit(self): - return self.commits[self.current_commit_index] + @property + def current_commit(self): + return self.commits[self.current_commit_index] - @property - def previous_commit(self): - return self.commits[self.current_commit_index - 1] + @property + def previous_commit(self): + return self.commits[self.current_commit_index - 1] - def __init__(self, mirrors, repository_directory, repository_name): - """ + def __init__(self, mirrors, repository_directory, repository_name): + """ Args: mirrors: is a dictionary which contains information about each mirror: mirrors = {'mirror1': {'url_prefix': 'http://localhost:8001', @@ -84,41 +84,52 @@ def __init__(self, mirrors, repository_directory, repository_name): We use url_prefix to specify url of the git repository which we want to clone. repository_directory: the client's local repository's location """ - super(GitUpdater, self).__init__(mirrors, repository_directory, repository_name) - - auth_url = mirrors['mirror1']['url_prefix'] - self.metadata_path = mirrors['mirror1']['metadata_path'] - self.targets_path = mirrors['mirror1']['targets_path'] - if settings.validate_repo_name: - self.users_auth_repo = NamedAuthenticationRepo(repository_directory, repository_name, - self.metadata_path, self.targets_path, - repo_urls=[auth_url]) - else: - users_repo_path = os.path.join(repository_directory, repository_name) - self.users_auth_repo = AuthenticationRepo(users_repo_path, self.metadata_path, - self.targets_path, repo_urls=[auth_url]) - - self._clone_validation_repo(auth_url) - repository_directory = self.users_auth_repo.repo_path - if os.path.exists(repository_directory): - if not self.users_auth_repo.is_git_repository_root: - if os.listdir(repository_directory): - raise UpdateFailedError('{} is not a git repository and is not empty' - .format(repository_directory)) - - # validation_auth_repo is a freshly cloned bare repository. 
- # It is cloned to a temporary directory that should be removed - # once the update is completed - - self._init_commits() - # users_auth_repo is the authentication repository - # located on the users machine which needs to be updated - self.repository_directory = repository_directory - - self._init_metadata() - - def _init_commits(self): - """ + super(GitUpdater, self).__init__(mirrors, repository_directory, repository_name) + + auth_url = mirrors["mirror1"]["url_prefix"] + self.metadata_path = mirrors["mirror1"]["metadata_path"] + self.targets_path = mirrors["mirror1"]["targets_path"] + if settings.validate_repo_name: + self.users_auth_repo = NamedAuthenticationRepo( + repository_directory, + repository_name, + self.metadata_path, + self.targets_path, + repo_urls=[auth_url], + ) + else: + users_repo_path = os.path.join(repository_directory, repository_name) + self.users_auth_repo = AuthenticationRepo( + users_repo_path, + self.metadata_path, + self.targets_path, + repo_urls=[auth_url], + ) + + self._clone_validation_repo(auth_url) + repository_directory = self.users_auth_repo.repo_path + if os.path.exists(repository_directory): + if not self.users_auth_repo.is_git_repository_root: + if os.listdir(repository_directory): + raise UpdateFailedError( + "{} is not a git repository and is not empty".format( + repository_directory + ) + ) + + # validation_auth_repo is a freshly cloned bare repository. + # It is cloned to a temporary directory that should be removed + # once the update is completed + + self._init_commits() + # users_auth_repo is the authentication repository + # located on the users machine which needs to be updated + self.repository_directory = repository_directory + + self._init_metadata() + + def _init_commits(self): + """ Given a client's local repository which needs to be updated, creates a list of commits of the authentication repository newer than the most recent commit of the client's repository. These commits need to be validated. @@ -126,63 +137,73 @@ def _init_commits(self): We have to presume that the initial metadata is correct though (or at least the initial root.json). """ - # TODO check if users authentication repository is clean - - # load the last validated commit fromt he conf file - last_validated_commit = self.users_auth_repo.last_validated_commit - - try: - commits_since = self.validation_auth_repo.all_commits_since_commit(last_validated_commit) - except CalledProcessError as e: - if 'Invalid revision range' in e.output: - logger.error('Commit %s is not contained by the remote repository %s.', - last_validated_commit, self.validation_auth_repo.repo_name) - raise UpdateFailedError('Commit {} is no longer contained by repository {}. This could ' - 'either mean that there was an unauthorized push tot the remote ' - 'repository, or that last_validated_commit file was modified.'. 
-                                    format(last_validated_commit, self.validation_auth_repo.repo_name))
-            else:
-                raise e
-
-        # Check if the user's head commit mathces the saved one
-        # That should always be the case
-        # If it is not, it means that someone, accidentally or maliciosly made manual changes
-
-        if not self.users_auth_repo.is_git_repository_root:
-            users_head_sha = None
-        else:
-            self.users_auth_repo.checkout_branch('master')
-            if last_validated_commit is not None:
-                users_head_sha = self.users_auth_repo.head_commit_sha()
-            else:
-                # if the user's repository exists, but there is no last_validated_commit
-                # start the update from the beginning
-                users_head_sha = None
-
-        if last_validated_commit != users_head_sha:
-            # TODO add a flag --force/f which, if provided, should force an automatic revert
-            # of the users authentication repository to the last validated commit
-            # This could be done if a user accidentally committed something to the auth repo
-            # or manually pulled the changes
-            # If the user deleted the repository or executed reset --hard, we could handle
-            # that by starting validation from the last validated commit, as opposed to the
-            # user's head sha.
-            # For now, we will raise an error
-            msg = '''Saved last validated commit {} does not match the head commit of the
-authentication repository {}'''.format(last_validated_commit, users_head_sha)
-            logger.error(msg)
-            raise UpdateFailedError(msg)
-
-        # insert the current one at the beginning of the list
-        if users_head_sha is not None:
-            commits_since.insert(0, users_head_sha)
-
-        self.commits = commits_since
-        self.users_head_sha = users_head_sha
-        self.current_commit_index = 0
-
-    def _init_metadata(self):
-        """
+        # TODO check if user's authentication repository is clean
+
+        # load the last validated commit from the conf file
+        last_validated_commit = self.users_auth_repo.last_validated_commit
+
+        try:
+            commits_since = self.validation_auth_repo.all_commits_since_commit(
+                last_validated_commit
+            )
+        except CalledProcessError as e:
+            if "Invalid revision range" in e.output:
+                logger.error(
+                    "Commit %s is not contained by the remote repository %s.",
+                    last_validated_commit,
+                    self.validation_auth_repo.repo_name,
+                )
+                raise UpdateFailedError(
+                    "Commit {} is no longer contained by repository {}. This could "
+                    "either mean that there was an unauthorized push to the remote "
+                    "repository, or that last_validated_commit file was modified.".format(
+                        last_validated_commit, self.validation_auth_repo.repo_name
+                    )
+                )
+            else:
+                raise e
+
+        # Check if the user's head commit matches the saved one
+        # That should always be the case
+        # If it is not, it means that someone accidentally or maliciously made manual changes
+
+        if not self.users_auth_repo.is_git_repository_root:
+            users_head_sha = None
+        else:
+            self.users_auth_repo.checkout_branch("master")
+            if last_validated_commit is not None:
+                users_head_sha = self.users_auth_repo.head_commit_sha()
+            else:
+                # if the user's repository exists, but there is no last_validated_commit
+                # start the update from the beginning
+                users_head_sha = None
+
+        if last_validated_commit != users_head_sha:
+            # TODO add a flag --force/f which, if provided, should force an automatic revert
+            # of the user's authentication repository to the last validated commit
+            # This could be done if a user accidentally committed something to the auth repo
+            # or manually pulled the changes
+            # If the user deleted the repository or executed reset --hard, we could handle
+            # that by starting validation from the last validated commit, as opposed to the
+            # user's head sha.
+            # For now, we will raise an error
+            msg = """Saved last validated commit {} does not match the head commit of the
+authentication repository {}""".format(
+                last_validated_commit, users_head_sha
+            )
+            logger.error(msg)
+            raise UpdateFailedError(msg)
+
+        # insert the current one at the beginning of the list
+        if users_head_sha is not None:
+            commits_since.insert(0, users_head_sha)
+
+        self.commits = commits_since
+        self.users_head_sha = users_head_sha
+        self.current_commit_index = 0
+
+    def _init_metadata(self):
+        """
 TUF updater expects the existence of two directories in the client's
 metadata directory - current and previous. These directories store
 the current and previous metadata files (before and after the update).
@@ -191,107 +212,120 @@ def _init_metadata(self):
 but will create the directories where TUF expects them to be in order
 to avoid modifying the updater.
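
 The resulting layout (illustrative):

     <repository_directory>/metadata/current/*.json   - copies used by the updater
     <repository_directory>/metadata/previous/*.json  - identical copies at the start
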
""" - # create current and previous directories and copy the metadata files - # needed by the updater - # TUF's updater expects these directories to be in the client's repository - # read metadata of the cloned validation repo at the initial commit - - metadata_path = os.path.join(self.repository_directory, 'metadata') - if not os.path.isdir(metadata_path): - os.makedirs(metadata_path) - self.current_path = os.path.join(metadata_path, 'current') - self.previous_path = os.path.join(metadata_path, 'previous') - os.mkdir(self.current_path) - os.mkdir(self.previous_path) - - metadata_files = self.validation_auth_repo.list_files_at_revision(self.current_commit, - 'metadata') - for filename in metadata_files: - metadata = self.validation_auth_repo.get_file(self.current_commit, 'metadata/' + filename) - current_filename = os.path.join(self.current_path, filename) - previous_filename = os.path.join(self.previous_path, filename) - with open(current_filename, 'w') as f: - f.write(metadata) - shutil.copyfile(current_filename, previous_filename) - - def _clone_validation_repo(self, url): - """ + # create current and previous directories and copy the metadata files + # needed by the updater + # TUF's updater expects these directories to be in the client's repository + # read metadata of the cloned validation repo at the initial commit + + metadata_path = os.path.join(self.repository_directory, "metadata") + if not os.path.isdir(metadata_path): + os.makedirs(metadata_path) + self.current_path = os.path.join(metadata_path, "current") + self.previous_path = os.path.join(metadata_path, "previous") + os.mkdir(self.current_path) + os.mkdir(self.previous_path) + + metadata_files = self.validation_auth_repo.list_files_at_revision( + self.current_commit, "metadata" + ) + for filename in metadata_files: + metadata = self.validation_auth_repo.get_file( + self.current_commit, "metadata/" + filename + ) + current_filename = os.path.join(self.current_path, filename) + previous_filename = os.path.join(self.previous_path, filename) + with open(current_filename, "w") as f: + f.write(metadata) + shutil.copyfile(current_filename, previous_filename) + + def _clone_validation_repo(self, url): + """ Clones the authentication repository based on the url specified using the mirrors parameter. The repository is cloned as a bare repository to a the temp directory and will be deleted one the update is done. """ - temp_dir = tempfile.mkdtemp() - repo_path = os.path.join(temp_dir, self.users_auth_repo.repo_name) - self.validation_auth_repo = GitRepository(repo_path, [url]) - self.validation_auth_repo.clone(bare=True) - self.validation_auth_repo.fetch(fetch_all=True) - - def cleanup(self): - """ + temp_dir = tempfile.mkdtemp() + repo_path = os.path.join(temp_dir, self.users_auth_repo.repo_name) + self.validation_auth_repo = GitRepository(repo_path, [url]) + self.validation_auth_repo.clone(bare=True) + self.validation_auth_repo.fetch(fetch_all=True) + + def cleanup(self): + """ Removes the bare authentication repository and current and previous directories. This should be called after the update is finished, either successfully or unsuccessfully. """ - shutil.rmtree(self.current_path) - shutil.rmtree(self.previous_path) - temp_dir = os.path.abspath(os.path.join(self.validation_auth_repo.repo_path, os.pardir)) - shutil.rmtree(temp_dir, onerror=on_rm_error) - - def earliest_valid_expiration_time(self): - # metadata at a certain revision should not expire before the - # time it was committed. 
It can be expected that the metadata files - # at older commits will be expired and that should not be considered - # to be an error - return int(self.validation_auth_repo.get_commits_date(self.current_commit)) - - def ensure_not_changed(self, metadata_filename): - """ + shutil.rmtree(self.current_path) + shutil.rmtree(self.previous_path) + temp_dir = os.path.abspath( + os.path.join(self.validation_auth_repo.repo_path, os.pardir) + ) + shutil.rmtree(temp_dir, onerror=on_rm_error) + + def earliest_valid_expiration_time(self): + # metadata at a certain revision should not expire before the + # time it was committed. It can be expected that the metadata files + # at older commits will be expired and that should not be considered + # to be an error + return int(self.validation_auth_repo.get_commits_date(self.current_commit)) + + def ensure_not_changed(self, metadata_filename): + """ Make sure that the metadata file remained the same, as the reference metadata suggests. """ - current_file = self.get_metadata_file(self.current_commit, file_name=metadata_filename) - previous_file = self.get_metadata_file(self.previous_commit, file_name=metadata_filename) - if current_file.read() != previous_file.read(): - raise UpdateFailedError('Metadata file {} should be the same at revisions {} and {}, but is not.' - .format(metadata_filename, self.previous_commit, self.current_commit)) - - def get_current_targets(self): - return self.validation_auth_repo.list_files_at_revision(self.current_commit, 'targets') - - def get_mirrors(self, _file_type, _file_path): - # pylint: disable=unused-argument - # return a list containing just the current commit - return [self.current_commit] - - def get_metadata_file(self, file_mirror, file_name, _upperbound_filelength=None): - return self._get_file(file_mirror, 'metadata/' + file_name) - - def get_target_file(self, file_mirror, _file_length, _download_safely, file_path): - return self._get_file(file_mirror, 'targets/' + file_path) - - def _get_file(self, commit, filepath): - f = self.validation_auth_repo.get_file(commit, filepath) - temp_file_object = securesystemslib.util.TempFile() - temp_file_object.write(f.encode()) - return temp_file_object - - def get_file_digest(self, filepath, algorithm): - filepath = os.path.relpath(filepath, self.validation_auth_repo.get_file) - file_obj = self._get_file(self.current_commit, filepath) - return securesystemslib.hash.digest_fileobject(file_obj, - algorithm=algorithm) - - def on_successful_update(self, filename, mirror): - # after the is successfully completed, set the - # next commit as current for the given file - logger.debug('%s updated from commit %s', filename, mirror) - - def on_unsuccessful_update(self, filename): - logger.error('Failed to update %s', filename) - - def update_done(self): - # the only metadata file that is always updated - # regardless of if it changed or not is timestamp - # so we can check if timestamp was updated a certain - # number of times - self.current_commit_index += 1 - return self.current_commit_index == len(self.commits) + current_file = self.get_metadata_file( + self.current_commit, file_name=metadata_filename + ) + previous_file = self.get_metadata_file( + self.previous_commit, file_name=metadata_filename + ) + if current_file.read() != previous_file.read(): + raise UpdateFailedError( + "Metadata file {} should be the same at revisions {} and {}, but is not.".format( + metadata_filename, self.previous_commit, self.current_commit + ) + ) + + def get_current_targets(self): + return 
self.validation_auth_repo.list_files_at_revision( + self.current_commit, "targets" + ) + + def get_mirrors(self, _file_type, _file_path): + # pylint: disable=unused-argument + # return a list containing just the current commit + return [self.current_commit] + + def get_metadata_file(self, file_mirror, file_name, _upperbound_filelength=None): + return self._get_file(file_mirror, "metadata/" + file_name) + + def get_target_file(self, file_mirror, _file_length, _download_safely, file_path): + return self._get_file(file_mirror, "targets/" + file_path) + + def _get_file(self, commit, filepath): + f = self.validation_auth_repo.get_file(commit, filepath) + temp_file_object = securesystemslib.util.TempFile() + temp_file_object.write(f.encode()) + return temp_file_object + + def get_file_digest(self, filepath, algorithm): + filepath = os.path.relpath(filepath, self.validation_auth_repo.get_file) + file_obj = self._get_file(self.current_commit, filepath) + return securesystemslib.hash.digest_fileobject(file_obj, algorithm=algorithm) + + def on_successful_update(self, filename, mirror): + # after the is successfully completed, set the + # next commit as current for the given file + logger.debug("%s updated from commit %s", filename, mirror) + + def on_unsuccessful_update(self, filename): + logger.error("Failed to update %s", filename) + + def update_done(self): + # the only metadata file that is always updated + # regardless of if it changed or not is timestamp + # so we can check if timestamp was updated a certain + # number of times + self.current_commit_index += 1 + return self.current_commit_index == len(self.commits) diff --git a/taf/updater/updater.py b/taf/updater/updater.py index 50ebc9df0..1690e9833 100644 --- a/taf/updater/updater.py +++ b/taf/updater/updater.py @@ -14,9 +14,16 @@ logger = taf.log.get_logger(__name__) -def update_repository(url, clients_repo_path, targets_dir, update_from_filesystem, - authenticate_test_repo=False, target_repo_classes=None, target_factory=None): - """ +def update_repository( + url, + clients_repo_path, + targets_dir, + update_from_filesystem, + authenticate_test_repo=False, + target_repo_classes=None, + target_factory=None, +): + """ url: URL of the remote authentication repository @@ -35,20 +42,34 @@ def update_repository(url, clients_repo_path, targets_dir, update_from_filesyste A git repositories factory used when instantiating target repositories. See repositoriesdb load_repositories for more details. 
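
 Example (illustrative; the url and paths are hypothetical):

     update_repository(
         "https://github.com/some-org/some-auth-repo",
         "clients/some-org/some-auth-repo",
         "clients",
         update_from_filesystem=False,
     )
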
""" - # if the repository's name is not provided, divide it in parent directory - # and repository name, since TUF's updater expects a name - # but set the validate_repo_name setting to False - clients_dir, repo_name = os.path.split(os.path.normpath(clients_repo_path)) - settings.validate_repo_name = False - update_named_repository(url, clients_dir, repo_name, targets_dir, - update_from_filesystem, authenticate_test_repo, - target_repo_classes, target_factory) - - -def update_named_repository(url, clients_directory, repo_name, targets_dir, - update_from_filesystem, authenticate_test_repo=False, - target_repo_classes=None, target_factory=None): - """ + # if the repository's name is not provided, divide it in parent directory + # and repository name, since TUF's updater expects a name + # but set the validate_repo_name setting to False + clients_dir, repo_name = os.path.split(os.path.normpath(clients_repo_path)) + settings.validate_repo_name = False + update_named_repository( + url, + clients_dir, + repo_name, + targets_dir, + update_from_filesystem, + authenticate_test_repo, + target_repo_classes, + target_factory, + ) + + +def update_named_repository( + url, + clients_directory, + repo_name, + targets_dir, + update_from_filesystem, + authenticate_test_repo=False, + target_repo_classes=None, + target_factory=None, +): + """ url: URL of the remote authentication repository @@ -96,225 +117,285 @@ def update_named_repository(url, clients_directory, repo_name, targets_dir, loads data from a most recent commit. """ - # at the moment, we assume that the initial commit is valid and that it contains at least root.json - - settings.update_from_filesystem = update_from_filesystem - # instantiate TUF's updater - repository_mirrors = {'mirror1': {'url_prefix': url, - 'metadata_path': 'metadata', - 'targets_path': 'targets', - 'confined_target_dirs': ['']}} - - tuf.settings.repositories_directory = clients_directory - repository_updater = tuf_updater.Updater(repo_name, - repository_mirrors, - GitUpdater) - users_auth_repo = repository_updater.update_handler.users_auth_repo - existing_repo = users_auth_repo.is_git_repository_root - try: - validation_auth_repo = repository_updater.update_handler.validation_auth_repo - commits = repository_updater.update_handler.commits - last_validated_commit = users_auth_repo.last_validated_commit - if last_validated_commit is None: - # check if the repository being updated is a test repository - targets = validation_auth_repo.get_json(commits[-1], 'metadata/targets.json') - test_repo = 'test-auth-repo' in targets['signed']['targets'] - if test_repo and not authenticate_test_repo: - raise UpdateFailedError('Repository {} is a test repository. 
Call update with ' - '"--authenticate-test-repo" to update a test ' - 'repository'.format(users_auth_repo.repo_name)) - elif not test_repo and authenticate_test_repo: - raise UpdateFailedError('Repository {} is not a test repository, but update was called ' - 'with the "--authenticate-test-repo" flag'.format(users_auth_repo.repo_name)) - - # validate the authentication repository and fetch new commits - _update_authentication_repository(repository_updater) - - # get target repositories and their commits, as specified in targets.json - - repositoriesdb.load_repositories(users_auth_repo, repo_classes=target_repo_classes, - factory=target_factory, root_dir=targets_dir, - commits=commits) - repositories = repositoriesdb.get_deduplicated_repositories(users_auth_repo, commits) - repositories_commits = users_auth_repo.sorted_commits_per_repositories(commits) - - # update target repositories - repositories_json = users_auth_repo.get_json(commits[-1], 'targets/repositories.json') - last_validated_commit = users_auth_repo.last_validated_commit - _update_target_repositories(repositories, repositories_json, repositories_commits, - last_validated_commit) - - last_commit = commits[-1] - logger.info('Merging commit %s into %s', last_commit, users_auth_repo.repo_name) - # if there were no errors, merge the last validated authentication repository commit - users_auth_repo.checkout_branch(users_auth_repo.default_branch) - users_auth_repo.merge_commit(last_commit) - # update the last validated commit - users_auth_repo.set_last_validated_commit(last_commit) - except Exception as e: - if not existing_repo: - shutil.rmtree(users_auth_repo.repo_path, onerror=on_rm_error) - shutil.rmtree(users_auth_repo.conf_dir) - raise e - finally: - repositoriesdb.clear_repositories_db() + # at the moment, we assume that the initial commit is valid and that it contains at least root.json + + settings.update_from_filesystem = update_from_filesystem + # instantiate TUF's updater + repository_mirrors = { + "mirror1": { + "url_prefix": url, + "metadata_path": "metadata", + "targets_path": "targets", + "confined_target_dirs": [""], + } + } + + tuf.settings.repositories_directory = clients_directory + repository_updater = tuf_updater.Updater(repo_name, repository_mirrors, GitUpdater) + users_auth_repo = repository_updater.update_handler.users_auth_repo + existing_repo = users_auth_repo.is_git_repository_root + try: + validation_auth_repo = repository_updater.update_handler.validation_auth_repo + commits = repository_updater.update_handler.commits + last_validated_commit = users_auth_repo.last_validated_commit + if last_validated_commit is None: + # check if the repository being updated is a test repository + targets = validation_auth_repo.get_json( + commits[-1], "metadata/targets.json" + ) + test_repo = "test-auth-repo" in targets["signed"]["targets"] + if test_repo and not authenticate_test_repo: + raise UpdateFailedError( + "Repository {} is a test repository. 
Call update with " + '"--authenticate-test-repo" to update a test ' + "repository".format(users_auth_repo.repo_name) + ) + elif not test_repo and authenticate_test_repo: + raise UpdateFailedError( + "Repository {} is not a test repository, but update was called " + 'with the "--authenticate-test-repo" flag'.format( + users_auth_repo.repo_name + ) + ) + + # validate the authentication repository and fetch new commits + _update_authentication_repository(repository_updater) + + # get target repositories and their commits, as specified in targets.json + + repositoriesdb.load_repositories( + users_auth_repo, + repo_classes=target_repo_classes, + factory=target_factory, + root_dir=targets_dir, + commits=commits, + ) + repositories = repositoriesdb.get_deduplicated_repositories( + users_auth_repo, commits + ) + repositories_commits = users_auth_repo.sorted_commits_per_repositories(commits) + + # update target repositories + repositories_json = users_auth_repo.get_json( + commits[-1], "targets/repositories.json" + ) + last_validated_commit = users_auth_repo.last_validated_commit + _update_target_repositories( + repositories, repositories_json, repositories_commits, last_validated_commit + ) + + last_commit = commits[-1] + logger.info("Merging commit %s into %s", last_commit, users_auth_repo.repo_name) + # if there were no errors, merge the last validated authentication repository commit + users_auth_repo.checkout_branch(users_auth_repo.default_branch) + users_auth_repo.merge_commit(last_commit) + # update the last validated commit + users_auth_repo.set_last_validated_commit(last_commit) + except Exception as e: + if not existing_repo: + shutil.rmtree(users_auth_repo.repo_path, onerror=on_rm_error) + shutil.rmtree(users_auth_repo.conf_dir) + raise e + finally: + repositoriesdb.clear_repositories_db() def _update_authentication_repository(repository_updater): - users_auth_repo = repository_updater.update_handler.users_auth_repo - logger.info('Validating authentication repository %s', users_auth_repo.repo_name) - try: - while not repository_updater.update_handler.update_done(): - current_commit = repository_updater.update_handler.current_commit - repository_updater.refresh() - # using refresh, we have updated all main roles - # we still need to update the delegated roles (if there are any) - # that is handled by get_current_targets - current_targets = repository_updater.update_handler.get_current_targets() - logger.debug('Validated metadata files at revision %s', current_commit) - for target_path in current_targets: - target = repository_updater.get_one_valid_targetinfo(target_path) - target_filepath = target['filepath'] - trusted_length = target['fileinfo']['length'] - trusted_hashes = target['fileinfo']['hashes'] - try: - repository_updater._get_target_file(target_filepath, trusted_length, trusted_hashes) # pylint: disable=W0212 # noqa - except tuf.exceptions.NoWorkingMirrorError as e: - logger.error('Could not validate file %s', target_filepath) - raise e - logger.debug('Successfully validated target file %s at %s', target_filepath, - current_commit) - except Exception as e: - # for now, useful for debugging - logger.error('Validation of authentication repository %s failed due to error %s', - users_auth_repo.repo_name, e) - raise UpdateFailedError('Validation of authentication repository {} failed due to error: {}' - .format(users_auth_repo.repo_name, e)) - finally: - repository_updater.update_handler.cleanup() - - logger.info('Successfully validated authentication repository %s', 
users_auth_repo.repo_name) - # fetch the latest commit or clone the repository without checkout - # do not merge before targets are validated as well - if users_auth_repo.is_git_repository_root: - users_auth_repo.fetch(True) - else: - users_auth_repo.clone(no_checkout=True) - - -def _update_target_repositories(repositories, repositories_json, repositories_commits, - last_validated_commit): - logger.info('Validating target repositories') - - # keep track of the repositories which were cloned - # so that they can be removed if the update fails - cloned_repositories = [] - allow_unauthenticated = {} - new_commits = {} - - for path, repository in repositories.items(): - - allow_unauthenticated_for_repo = repositories_json['repositories'][repository.repo_name]. \ - get('custom', {}).get('allow-unauthenticated-commits', False) - allow_unauthenticated[path] = allow_unauthenticated_for_repo - - # if last_validated_commit is None, start the update from the beginning - - is_git_repository = repository.is_git_repository_root - if last_validated_commit is None or not is_git_repository: - old_head = None + users_auth_repo = repository_updater.update_handler.users_auth_repo + logger.info("Validating authentication repository %s", users_auth_repo.repo_name) + try: + while not repository_updater.update_handler.update_done(): + current_commit = repository_updater.update_handler.current_commit + repository_updater.refresh() + # using refresh, we have updated all main roles + # we still need to update the delegated roles (if there are any) + # that is handled by get_current_targets + current_targets = repository_updater.update_handler.get_current_targets() + logger.debug("Validated metadata files at revision %s", current_commit) + for target_path in current_targets: + target = repository_updater.get_one_valid_targetinfo(target_path) + target_filepath = target["filepath"] + trusted_length = target["fileinfo"]["length"] + trusted_hashes = target["fileinfo"]["hashes"] + try: + repository_updater._get_target_file( + target_filepath, trusted_length, trusted_hashes + ) # pylint: disable=W0212 # noqa + except tuf.exceptions.NoWorkingMirrorError as e: + logger.error("Could not validate file %s", target_filepath) + raise e + logger.debug( + "Successfully validated target file %s at %s", + target_filepath, + current_commit, + ) + except Exception as e: + # for now, useful for debugging + logger.error( + "Validation of authentication repository %s failed due to error %s", + users_auth_repo.repo_name, + e, + ) + raise UpdateFailedError( + "Validation of authentication repository {} failed due to error: {}".format( + users_auth_repo.repo_name, e + ) + ) + finally: + repository_updater.update_handler.cleanup() + + logger.info( + "Successfully validated authentication repository %s", users_auth_repo.repo_name + ) + # fetch the latest commit or clone the repository without checkout + # do not merge before targets are validated as well + if users_auth_repo.is_git_repository_root: + users_auth_repo.fetch(True) else: - old_head = repository.head_commit_sha() + users_auth_repo.clone(no_checkout=True) - if old_head is None and not is_git_repository: - repository.clone(no_checkout=True) - cloned_repositories.append(repository) - else: - repository.fetch(True) - if old_head is not None: - if allow_unauthenticated: - old_head = repositories_commits[path][0] - new_commits_for_repo = repository.all_fetched_commits() - new_commits_for_repo.insert(0, old_head) - else: - new_commits_for_repo = 
repository.all_commits_since_commit(old_head) - if is_git_repository: - # this happens in the case when last_validated_commit does not exist - # we want to validate all commits, so combine existing commits and - # fetched commits - fetched_commits = repository.all_fetched_commits() - new_commits_for_repo.extend(fetched_commits) - new_commits[path] = new_commits_for_repo +def _update_target_repositories( + repositories, repositories_json, repositories_commits, last_validated_commit +): + logger.info("Validating target repositories") - try: - _update_target_repository(repository, new_commits_for_repo, repositories_commits[path], - allow_unauthenticated_for_repo) - except UpdateFailedError as e: - # delete all repositories that were cloned - for repo in cloned_repositories: - logger.debug('Removing cloned repository %s', repo.repo_path) - shutil.rmtree(repo.repo_path, onerror=on_rm_error) - # TODO is it important to undo a fetch if the repository was not cloned? - raise e - - logger.info('Successfully validated all target repositories.') - # if update is successful, merge the commits - for path, repository in repositories.items(): - repository.checkout_branch(repository.default_branch) - if len(repositories_commits[path]): - logger.info('Merging %s into %s', repositories_commits[path][-1], repository.repo_name) - last_validated_commit = repositories_commits[path][-1] - commit_to_merge = last_validated_commit if not allow_unauthenticated[path] else new_commits[path][-1] - repository.merge_commit(commit_to_merge) - if not allow_unauthenticated[path]: - repository.checkout_commit(commit_to_merge) - else: - repository.checkout_branch(repository.default_branch) + # keep track of the repositories which were cloned + # so that they can be removed if the update fails + cloned_repositories = [] + allow_unauthenticated = {} + new_commits = {} + for path, repository in repositories.items(): -def _update_target_repository(repository, new_commits, target_commits, - allow_unauthenticated): - - logger.info('Validating target repository %s', repository.repo_name) - # A new commit might have been pushed after the update process - # started and before fetch was called - # So, the number of new commits, pushed to the target repository, could - # be greater than the number of these commits according to the authentication - # repository. The opposite cannot be the case. - # In general, if there are additional commits in the target repositories, - # the updater will finish the update successfully, but will only update the - # target repositories until the latest validate commit - if not allow_unauthenticated: - update_successful = len(new_commits) >= len(target_commits) - if update_successful: - for target_commit, repo_commit in zip(target_commits, new_commits): - if target_commit != repo_commit: - update_successful = False - break - if len(new_commits) > len(target_commits): - additional_commits = new_commits[len(target_commits):] - logger.warning('Found commits %s in repository %s that are not accounted for in the authentication repo.' 
- 'Repoisitory will be updated up to commit %s', additional_commits, repository.repo_name, - target_commits[-1]) - else: - logger.info('Unauthenticated commits allowed in repository %s', repository.repo_name) - update_successful = True - target_commits_index = 0 - for commit in new_commits: - if commit in target_commits: - if commit != target_commits[target_commits_index]: - update_successful = False - break + allow_unauthenticated_for_repo = ( + repositories_json["repositories"][repository.repo_name] + .get("custom", {}) + .get("allow-unauthenticated-commits", False) + ) + allow_unauthenticated[path] = allow_unauthenticated_for_repo + + # if last_validated_commit is None, start the update from the beginning + + is_git_repository = repository.is_git_repository_root + if last_validated_commit is None or not is_git_repository: + old_head = None + else: + old_head = repository.head_commit_sha() + + if old_head is None and not is_git_repository: + repository.clone(no_checkout=True) + cloned_repositories.append(repository) else: - target_commits_index += 1 + repository.fetch(True) - update_successful = target_commits_index == len(target_commits) + if old_head is not None: + if allow_unauthenticated_for_repo: + old_head = repositories_commits[path][0] + new_commits_for_repo = repository.all_fetched_commits() + new_commits_for_repo.insert(0, old_head) + else: + new_commits_for_repo = repository.all_commits_since_commit(old_head) + if is_git_repository: + # this happens in the case when last_validated_commit does not exist + # we want to validate all commits, so combine existing commits and + # fetched commits + fetched_commits = repository.all_fetched_commits() + new_commits_for_repo.extend(fetched_commits) + new_commits[path] = new_commits_for_repo - if not update_successful: - logger.error('Mismatch between target commits specified in authentication repository and the ' - 'target repository %s', repository.repo_name) - raise UpdateFailedError('Mismatch between target commits specified in authentication repository' - ' and target repository {}'.format(repository.repo_name)) - logger.info('Successfully validated %s', repository.repo_name) + try: + _update_target_repository( + repository, + new_commits_for_repo, + repositories_commits[path], + allow_unauthenticated_for_repo, + ) + except UpdateFailedError as e: + # delete all repositories that were cloned + for repo in cloned_repositories: + logger.debug("Removing cloned repository %s", repo.repo_path) + shutil.rmtree(repo.repo_path, onerror=on_rm_error) + # TODO is it important to undo a fetch if the repository was not cloned? 
+ raise e + + logger.info("Successfully validated all target repositories.") + # if update is successful, merge the commits + for path, repository in repositories.items(): + repository.checkout_branch(repository.default_branch) + if len(repositories_commits[path]): + logger.info( + "Merging %s into %s", + repositories_commits[path][-1], + repository.repo_name, + ) + last_validated_commit = repositories_commits[path][-1] + commit_to_merge = ( + last_validated_commit + if not allow_unauthenticated[path] + else new_commits[path][-1] + ) + repository.merge_commit(commit_to_merge) + if not allow_unauthenticated[path]: + repository.checkout_commit(commit_to_merge) + else: + repository.checkout_branch(repository.default_branch) + + +def _update_target_repository( + repository, new_commits, target_commits, allow_unauthenticated +): + + logger.info("Validating target repository %s", repository.repo_name) + # A new commit might have been pushed after the update process + # started and before fetch was called + # So, the number of new commits, pushed to the target repository, could + # be greater than the number of these commits according to the authentication + # repository. The opposite cannot be the case. + # In general, if there are additional commits in the target repositories, + # the updater will finish the update successfully, but will only update the + # target repositories until the latest validated commit + if not allow_unauthenticated: + update_successful = len(new_commits) >= len(target_commits) + if update_successful: + for target_commit, repo_commit in zip(target_commits, new_commits): + if target_commit != repo_commit: + update_successful = False + break + if len(new_commits) > len(target_commits): + additional_commits = new_commits[len(target_commits) :] + logger.warning( + "Found commits %s in repository %s that are not accounted for in the authentication repo. " 
+ "Repository will be updated up to commit %s", + additional_commits, + repository.repo_name, + target_commits[-1], + ) + else: + logger.info( + "Unauthenticated commits allowed in repository %s", repository.repo_name + ) + update_successful = True + target_commits_index = 0 + for commit in new_commits: + if commit in target_commits: + if commit != target_commits[target_commits_index]: + update_successful = False + break + else: + target_commits_index += 1 + + update_successful = target_commits_index == len(target_commits) + + if not update_successful: + logger.error( + "Mismatch between target commits specified in authentication repository and the " + "target repository %s", + repository.repo_name, + ) + raise UpdateFailedError( + "Mismatch between target commits specified in authentication repository" + " and target repository {}".format(repository.repo_name) + ) + logger.info("Successfully validated %s", repository.repo_name) diff --git a/taf/utils.py b/taf/utils.py index 351ef03b4..f15530754 100644 --- a/taf/utils.py +++ b/taf/utils.py @@ -15,138 +15,148 @@ def _iso_parse(date): - return datetime.datetime.strptime(date, '%Y-%m-%d %H:%M:%S.%f') + return datetime.datetime.strptime(date, "%Y-%m-%d %H:%M:%S.%f") class IsoDateParamType(click.ParamType): - name = 'iso_date' + name = "iso_date" - def convert(self, value, param, ctx): - if value is None: - return datetime.datetime.now() + def convert(self, value, param, ctx): + if value is None: + return datetime.datetime.now() - if isinstance(value, datetime.datetime): - return value - try: - return _iso_parse(value) - except ValueError as ex: - self.fail(str(ex), param, ctx) + if isinstance(value, datetime.datetime): + return value + try: + return _iso_parse(value) + except ValueError as ex: + self.fail(str(ex), param, ctx) ISO_DATE_PARAM_TYPE = IsoDateParamType() def extract_x509(cert_pem): - from cryptography import x509 - from cryptography.hazmat.backends import default_backend + from cryptography import x509 + from cryptography.hazmat.backends import default_backend - cert = x509.load_pem_x509_certificate(cert_pem, default_backend()) + cert = x509.load_pem_x509_certificate(cert_pem, default_backend()) - def _get_attr(oid): - attrs = cert.subject.get_attributes_for_oid(oid) - return attrs[0].value if len(attrs) > 0 else "" + def _get_attr(oid): + attrs = cert.subject.get_attributes_for_oid(oid) + return attrs[0].value if len(attrs) > 0 else "" - return { - "name": _get_attr(x509.OID_COMMON_NAME), - "organization": _get_attr(x509.OID_ORGANIZATION_NAME), - "country": _get_attr(x509.OID_COUNTRY_NAME), - "state": _get_attr(x509.OID_STATE_OR_PROVINCE_NAME), - "locality": _get_attr(x509.OID_LOCALITY_NAME), - "valid_from": cert.not_valid_before.strftime("%Y-%m-%d"), - "valid_to": cert.not_valid_after.strftime("%Y-%m-%d"), - } + return { + "name": _get_attr(x509.OID_COMMON_NAME), + "organization": _get_attr(x509.OID_ORGANIZATION_NAME), + "country": _get_attr(x509.OID_COUNTRY_NAME), + "state": _get_attr(x509.OID_STATE_OR_PROVINCE_NAME), + "locality": _get_attr(x509.OID_LOCALITY_NAME), + "valid_from": cert.not_valid_before.strftime("%Y-%m-%d"), + "valid_to": cert.not_valid_after.strftime("%Y-%m-%d"), + } def get_cert_names_from_keyids(certs_dir, keyids): - cert_names = [] - for keyid in keyids: - try: - name = extract_x509((Path(certs_dir) / keyid + ".pem").read_bytes())['name'] - if not name: - print("Cannot extract common name from x509, using key id instead.") - cert_names.append(keyid) - else: - cert_names.append(name) - except 
FileNotFoundError: - print("Certificate does not exist ({}).".format(keyid)) - return cert_names + cert_names = [] + for keyid in keyids: + try: + name = extract_x509((Path(certs_dir) / (keyid + ".pem")).read_bytes())["name"] + if not name: + print("Cannot extract common name from x509, using key id instead.") + cert_names.append(keyid) + else: + cert_names.append(name) + except FileNotFoundError: + print("Certificate does not exist ({}).".format(keyid)) + return cert_names def get_pin_for(name, confirm=True, repeat=True): - pin = getpass('Enter PIN for {}: '.format(name)) - if confirm: - if pin != getpass('Confirm PIN for {}: '.format(name)): - err_msg = "PINs don't match!" - if repeat: - print(err_msg) - get_pin_for(name, confirm, repeat) - else: - raise PINMissmatchError(err_msg) - return pin + pin = getpass("Enter PIN for {}: ".format(name)) + if confirm: + if pin != getpass("Confirm PIN for {}: ".format(name)): + err_msg = "PINs don't match!" + if repeat: + print(err_msg) + pin = get_pin_for(name, confirm, repeat) + else: + raise PINMissmatchError(err_msg) + return pin def run(*command, **kwargs): - """Run a command and return its output. Call with `debug=True` to print to stdout.""" - if len(command) == 1 and isinstance(command[0], str): - command = command[0].split() - if taf.settings.LOG_COMMAND_OUTPUT: - logger.debug('About to run command %s', ' '.join(command)) + """Run a command and return its output. Call with `debug=True` to print to stdout.""" + if len(command) == 1 and isinstance(command[0], str): + command = command[0].split() + if taf.settings.LOG_COMMAND_OUTPUT: + logger.debug("About to run command %s", " ".join(command)) - def _format_word(word, **env): - """To support word such as @{u} needed for git commands.""" + def _format_word(word, **env): + """To support word such as @{u} needed for git commands.""" + try: + return word.format(**env) + except KeyError: + return word + + command = [_format_word(word, **os.environ) for word in command] try: - return word.format(env) - except KeyError: - return word - command = [_format_word(word, **os.environ) for word in command] - try: - options = dict(stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=True, - universal_newlines=True) - options.update(kwargs) - completed = subprocess.run(command, **options) - except subprocess.CalledProcessError as err: - if err.stdout: - logger.debug(err.stdout) - if err.stderr: - logger.debug(err.stderr) - logger.info('Command %s returned non-zero exit status %s', ' '.join(command), err.returncode) - raise err - if completed.stdout: - if taf.settings.LOG_COMMAND_OUTPUT: - logger.debug(completed.stdout) - return completed.stdout.rstrip() if completed.returncode == 0 else None + options = dict( + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + universal_newlines=True, + ) + options.update(kwargs) + completed = subprocess.run(command, **options) + except subprocess.CalledProcessError as err: + if err.stdout: + logger.debug(err.stdout) + if err.stderr: + logger.debug(err.stderr) + logger.info( + "Command %s returned non-zero exit status %s", + " ".join(command), + err.returncode, + ) + raise err + if completed.stdout: + if taf.settings.LOG_COMMAND_OUTPUT: + logger.debug(completed.stdout) + return completed.stdout.rstrip() if completed.returncode == 0 else None def normalize_file_line_endings(file_path): - with open(file_path, 'rb') as open_file: - content = open_file.read() - replaced_content = normalize_line_endings(content) - if replaced_content != content: - with open(file_path, 'wb') as 
open_file: - open_file.write(replaced_content) + with open(file_path, "rb") as open_file: + content = open_file.read() + replaced_content = normalize_line_endings(content) + if replaced_content != content: + with open(file_path, "wb") as open_file: + open_file.write(replaced_content) def normalize_line_endings(file_content): - WINDOWS_LINE_ENDING = b'\r\n' - UNIX_LINE_ENDING = b'\n' - replaced_content = file_content.replace( - WINDOWS_LINE_ENDING, UNIX_LINE_ENDING).rstrip(UNIX_LINE_ENDING) - return replaced_content + WINDOWS_LINE_ENDING = b"\r\n" + UNIX_LINE_ENDING = b"\n" + replaced_content = file_content.replace( + WINDOWS_LINE_ENDING, UNIX_LINE_ENDING + ).rstrip(UNIX_LINE_ENDING) + return replaced_content def on_rm_error(_func, path, _exc_info): - """Used by when calling rmtree to ensure that readonly files and folders + """Used when calling rmtree to ensure that readonly files and folders are deleted. """ - os.chmod(path, stat.S_IWRITE) - os.unlink(path) + os.chmod(path, stat.S_IWRITE) + os.unlink(path) def to_tuf_datetime_format(start_date, interval): - """Used to convert datetime to format used while writing metadata: + """Used to convert datetime to format used while writing metadata: e.g. "2020-05-29T21:59:34Z", """ - datetime_object = start_date + datetime.timedelta(interval) - datetime_object = datetime_object.replace(microsecond=0) - return datetime_object.isoformat() + 'Z' + datetime_object = start_date + datetime.timedelta(interval) + datetime_object = datetime_object.replace(microsecond=0) + return datetime_object.isoformat() + "Z" diff --git a/taf/validation.py b/taf/validation.py index 44a0a545d..20eac3dde 100644 --- a/taf/validation.py +++ b/taf/validation.py @@ -4,7 +4,7 @@ def validate_branch(auth_repo, target_repos, branch_name): - """ + """ Validates corresponding branches of the authentication repository and the target repositories. Assumes that: 1. Commits of the target repositories' branches are merged into the default (master) branch @@ -20,96 +20,111 @@ 4. If all commits of an authentication repository's branch have the same branch ID """ - check_capstone(auth_repo, branch_name) + check_capstone(auth_repo, branch_name) - targets_and_commits = {target_repo: target_repo. 
- commits_on_branch_and_not_other(branch_name, 'master') - for target_repo in target_repos} - auth_commits = auth_repo.commits_on_branch_and_not_other(branch_name, 'master') + targets_and_commits = { + target_repo: target_repo.commits_on_branch_and_not_other(branch_name, "master") + for target_repo in target_repos + } + auth_commits = auth_repo.commits_on_branch_and_not_other(branch_name, "master") - _check_lengths_of_branches(targets_and_commits, branch_name) + _check_lengths_of_branches(targets_and_commits, branch_name) - targets_version = None - branch_id = None - targets_path = 'metadata/targets.json' + targets_version = None + branch_id = None + targets_path = "metadata/targets.json" - # fill the shorter lists with None values, so that their sizes match the size - # of authentication repository's commits list - for commits in targets_and_commits.values(): - commits.extend([None] * (len(auth_commits) - len(commits))) + # fill the shorter lists with None values, so that their sizes match the size + # of authentication repository's commits list + for commits in targets_and_commits.values(): + commits.extend([None] * (len(auth_commits) - len(commits))) - for commit_index, auth_commit in enumerate(auth_commits): - # load content of targets.json - targets = auth_repo.get_json(auth_commit, targets_path) - targets_version = _check_targets_version(targets, auth_commit, targets_version) - branch_id = _check_branch_id(auth_repo, auth_commit, branch_id) + for commit_index, auth_commit in enumerate(auth_commits): + # load content of targets.json + targets = auth_repo.get_json(auth_commit, targets_path) + targets_version = _check_targets_version(targets, auth_commit, targets_version) + branch_id = _check_branch_id(auth_repo, auth_commit, branch_id) - for target, target_commits in targets_and_commits.items(): - target_commit = target_commits[commit_index] + for target, target_commits in targets_and_commits.items(): + target_commit = target_commits[commit_index] - # targets' commits match the target commits specified in the authentication repository - if target_commit is not None: - _compare_commit_with_targets_metadata(auth_repo, auth_commit, target, target_commit) + # targets' commits match the target commits specified in the authentication repository + if target_commit is not None: + _compare_commit_with_targets_metadata( + auth_repo, auth_commit, target, target_commit + ) def _check_lengths_of_branches(targets_and_commits, branch_name): - """ + """ Checks if branches of the given name have the same number of commits in each of the provided repositories. 
""" - lengths = set(len(commits) for commits in targets_and_commits.values()) - if len(lengths) > 1: - msg = 'Branches {} of target repositories do not have the same number of commits' \ - .format(branch_name) - for target, commits in targets_and_commits.items(): - msg += '\n{} has {} commits.'.format(target.repo_name, len(commits)) - raise InvalidBranchError(msg) + lengths = set(len(commits) for commits in targets_and_commits.values()) + if len(lengths) > 1: + msg = "Branches {} of target repositories do not have the same number of commits".format( + branch_name + ) + for target, commits in targets_and_commits.items(): + msg += "\n{} has {} commits.".format(target.repo_name, len(commits)) + raise InvalidBranchError(msg) def _check_branch_id(auth_repo, auth_commit, branch_id): - new_branch_id = auth_repo.get_file(auth_commit, 'targets/branch') - if branch_id is not None and new_branch_id != branch_id: - raise InvalidBranchError('Branch ID at revision {} is not the same as the ' - 'version at the following revision'.format(auth_commit)) - return new_branch_id + new_branch_id = auth_repo.get_file(auth_commit, "targets/branch") + if branch_id is not None and new_branch_id != branch_id: + raise InvalidBranchError( + "Branch ID at revision {} is not the same as the " + "version at the following revision".format(auth_commit) + ) + return new_branch_id def _check_targets_version(targets, tuf_commit, current_version): - """ + """ Checks version numbers specified in targets.json (compares it to the previous one) There are no other metadata files to check (when building a speculative branch, we do not generate snapshot and timestamp, just targets.json and we have no delegations) Return the read version number """ - new_version = targets['signed']['version'] - # substracting one because the commits are in the reverse order - if current_version is not None and new_version != current_version - 1: - raise InvalidBranchError('Version of metadata file targets.json at revision ' - '{} is not equal to previous version incremented ' - 'by one!'.format(tuf_commit)) - return new_version + new_version = targets["signed"]["version"] + # substracting one because the commits are in the reverse order + if current_version is not None and new_version != current_version - 1: + raise InvalidBranchError( + "Version of metadata file targets.json at revision " + "{} is not equal to previous version incremented " + "by one!".format(tuf_commit) + ) + return new_version def check_capstone(auth_repo, branch): - """ + """ Check if there is a capstone file (a target file called capstone) at the end of the specified branch. Assumes that the branch is checked out. """ - capstone_path = os.path.join(auth_repo.repo_path, 'targets', 'capstone') - if not os.path.isfile(capstone_path): - raise InvalidBranchError('No capstone at the end of branch {}!!!'.format(branch)) + capstone_path = os.path.join(auth_repo.repo_path, "targets", "capstone") + if not os.path.isfile(capstone_path): + raise InvalidBranchError( + "No capstone at the end of branch {}!!!".format(branch) + ) -def _compare_commit_with_targets_metadata(tuf_repo, tuf_commit, target_repo, target_repo_commit): - """ +def _compare_commit_with_targets_metadata( + tuf_repo, tuf_commit, target_repo, target_repo_commit +): + """ Check if commit sha of a repository's speculative branch commit matches the specified target value in targets.json. 
""" - repo_name = 'targets/{}'.format(target_repo.repo_name) - targets_head_sha = tuf_repo.get_json(tuf_commit, repo_name)['commit'] - if target_repo_commit != targets_head_sha: - raise InvalidBranchError('Commit {} of repository {} does ' - 'not match the commit sha specified in targets.json!' - .format(target_repo_commit, target_repo.repo_name)) + repo_name = "targets/{}".format(target_repo.repo_name) + targets_head_sha = tuf_repo.get_json(tuf_commit, repo_name)["commit"] + if target_repo_commit != targets_head_sha: + raise InvalidBranchError( + "Commit {} of repository {} does " + "not match the commit sha specified in targets.json!".format( + target_repo_commit, target_repo.repo_name + ) + ) diff --git a/taf/yubikey.py b/taf/yubikey.py index d59205136..7d95466ec 100644 --- a/taf/yubikey.py +++ b/taf/yubikey.py @@ -7,38 +7,52 @@ from cryptography.hazmat.primitives.serialization import load_pem_private_key from tuf.repository_tool import import_rsakey_from_pem from ykman.descriptor import list_devices, open_device -from ykman.piv import (ALGO, DEFAULT_MANAGEMENT_KEY, PIN_POLICY, SLOT, - PivController, WrongPin, generate_random_management_key) +from ykman.piv import ( + ALGO, + DEFAULT_MANAGEMENT_KEY, + PIN_POLICY, + SLOT, + PivController, + WrongPin, + generate_random_management_key, +) from ykman.util import TRANSPORT from taf.constants import DEFAULT_RSA_SIGNATURE_SCHEME from taf.exceptions import YubikeyError -DEFAULT_PIN = '123456' -DEFAULT_PUK = '12345678' +DEFAULT_PIN = "123456" +DEFAULT_PUK = "12345678" def raise_yubikey_err(msg=None): - """Decorator used to catch all errors raised by yubikey-manager and raise + """Decorator used to catch all errors raised by yubikey-manager and raise YubikeyError. We don't need to handle specific cases. """ - def wrapper(f): - @wraps(f) - def decorator(*args, **kwargs): - try: - return f(*args, **kwargs) - except YubikeyError: - raise - except Exception as e: - err_msg = '{} Reason: ({}) {}'.format(msg, type(e).__name__, str(e)) if msg else str(e) - raise YubikeyError(err_msg) from e - return decorator - return wrapper + + def wrapper(f): + @wraps(f) + def decorator(*args, **kwargs): + try: + return f(*args, **kwargs) + except YubikeyError: + raise + except Exception as e: + err_msg = ( + "{} Reason: ({}) {}".format(msg, type(e).__name__, str(e)) + if msg + else str(e) + ) + raise YubikeyError(err_msg) from e + + return decorator + + return wrapper @contextmanager def _yk_piv_ctrl(serial=None, pub_key_pem=None): - """Context manager to open connection and instantiate piv controller. + """Context manager to open connection and instantiate piv controller. Args: - pub_key_pem(str): Match Yubikey's public key (PEM) if multiple keys @@ -50,34 +64,39 @@ def _yk_piv_ctrl(serial=None, pub_key_pem=None): Raises: - YubikeyError """ - # If pub_key_pem is given, iterate all devices, read x509 certs and try to match - # public keys. 
- if pub_key_pem is not None: - for yk in list_devices(transports=TRANSPORT.CCID): - yk_ctrl = PivController(yk.driver) - device_pub_key_pem = (yk_ctrl - .read_certificate(SLOT.SIGNATURE) - .public_key() - .public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo) - .decode('utf-8')) - # Tries to match without last newline char - if device_pub_key_pem == pub_key_pem or device_pub_key_pem[:-1] == pub_key_pem: - break - else: - yk.close() - - else: - yk = open_device(transports=TRANSPORT.CCID, serial=serial) - yk_ctrl = PivController(yk.driver) - - yield yk_ctrl, yk.serial - yk.close() + # If pub_key_pem is given, iterate all devices, read x509 certs and try to match + # public keys. + if pub_key_pem is not None: + for yk in list_devices(transports=TRANSPORT.CCID): + yk_ctrl = PivController(yk.driver) + device_pub_key_pem = ( + yk_ctrl.read_certificate(SLOT.SIGNATURE) + .public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("utf-8") + ) + # Tries to match without last newline char + if ( + device_pub_key_pem == pub_key_pem + or device_pub_key_pem[:-1] == pub_key_pem + ): + break + else: + yk.close() + + else: + yk = open_device(transports=TRANSPORT.CCID, serial=serial) + yk_ctrl = PivController(yk.driver) + + yield yk_ctrl, yk.serial + yk.close() def is_inserted(): - """Checks if YubiKey is inserted. + """Checks if YubiKey is inserted. Args: None @@ -88,12 +107,12 @@ def is_inserted(): Raises: - YubikeyError """ - return len(list(list_devices(transports=TRANSPORT.CCID))) > 0 + return len(list(list_devices(transports=TRANSPORT.CCID))) > 0 @raise_yubikey_err() def is_valid_pin(pin): - """Checks if given pin is valid. + """Checks if given pin is valid. Args: pin(str): Yubikey piv PIN @@ -104,17 +123,17 @@ def is_valid_pin(pin): Raises: - YubikeyError """ - with _yk_piv_ctrl() as (ctrl, _): - try: - ctrl.verify(pin) - return True, None # ctrl.get_pin_tries() fails if PIN is valid - except WrongPin: - return False, ctrl.get_pin_tries() + with _yk_piv_ctrl() as (ctrl, _): + try: + ctrl.verify(pin) + return True, None # ctrl.get_pin_tries() fails if PIN is valid + except WrongPin: + return False, ctrl.get_pin_tries() @raise_yubikey_err("Cannot get serial number.") def get_serial_num(pub_key_pem=None): - """Get Yubikey serial number. + """Get Yubikey serial number. Args: - pub_key_pem(str): Match Yubikey's public key (PEM) if multiple keys @@ -126,13 +145,13 @@ def get_serial_num(pub_key_pem=None): Raises: - YubikeyError """ - with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (_, serial): - return serial + with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (_, serial): + return serial @raise_yubikey_err("Cannot export x509 certificate.") def export_piv_x509(cert_format=serialization.Encoding.PEM, pub_key_pem=None): - """Exports YubiKey's piv slot x509. + """Exports YubiKey's piv slot x509. Args: - cert_format(str): One of 'serialization.Encoding' formats. 
@@ -145,14 +164,14 @@ def export_piv_x509(cert_format=serialization.Encoding.PEM, pub_key_pem=None): Raises: - YubikeyError """ - with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): - x509 = ctrl.read_certificate(SLOT.SIGNATURE) - return x509.public_bytes(encoding=cert_format) + with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): + x509 = ctrl.read_certificate(SLOT.SIGNATURE) + return x509.public_bytes(encoding=cert_format) @raise_yubikey_err("Cannot export public key.") def export_piv_pub_key(pub_key_format=serialization.Encoding.PEM, pub_key_pem=None): - """Exports YubiKey's piv slot public key. + """Exports YubiKey's piv slot public key. Args: - pub_key_format(str): One of 'serialization.Encoding' formats. @@ -165,15 +184,17 @@ def export_piv_pub_key(pub_key_format=serialization.Encoding.PEM, pub_key_pem=No Raises: - YubikeyError """ - with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): - x509 = ctrl.read_certificate(SLOT.SIGNATURE) - return x509.public_key().public_bytes(encoding=pub_key_format, - format=serialization.PublicFormat.SubjectPublicKeyInfo) + with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): + x509 = ctrl.read_certificate(SLOT.SIGNATURE) + return x509.public_key().public_bytes( + encoding=pub_key_format, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) @raise_yubikey_err("Cannot get public key in TUF format.") def get_piv_public_key_tuf(scheme=DEFAULT_RSA_SIGNATURE_SCHEME, pub_key_pem=None): - """Return public key from a Yubikey in TUF's RSAKEY_SCHEMA format. + """Return public key from a Yubikey in TUF's RSAKEY_SCHEMA format. Args: - scheme(str): Rsa signature scheme (default is rsa-pkcs1v15-sha256) @@ -188,13 +209,13 @@ def get_piv_public_key_tuf(scheme=DEFAULT_RSA_SIGNATURE_SCHEME, pub_key_pem=None Raises: - YubikeyError """ - pub_key_pem = export_piv_pub_key(pub_key_pem=pub_key_pem).decode('utf-8') - return import_rsakey_from_pem(pub_key_pem, scheme) + pub_key_pem = export_piv_pub_key(pub_key_pem=pub_key_pem).decode("utf-8") + return import_rsakey_from_pem(pub_key_pem, scheme) @raise_yubikey_err("Cannot sign data.") def sign_piv_rsa_pkcs1v15(data, pin, pub_key_pem=None): - """Sign data with key from YubiKey's piv slot. + """Sign data with key from YubiKey's piv slot. 
Args: - data(bytes): Data to be signed @@ -208,15 +229,21 @@ def sign_piv_rsa_pkcs1v15(data, pin, pub_key_pem=None): Raises: - YubikeyError """ - with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): - ctrl.verify(pin) - return ctrl.sign(SLOT.SIGNATURE, ALGO.RSA2048, data) + with _yk_piv_ctrl(pub_key_pem=pub_key_pem) as (ctrl, _): + ctrl.verify(pin) + return ctrl.sign(SLOT.SIGNATURE, ALGO.RSA2048, data) @raise_yubikey_err("Cannot setup Yubikey.") -def setup(pin, cert_cn, cert_exp_days=365, pin_retries=10, - private_key_pem=None, mgm_key=generate_random_management_key()): - """Use to setup inserted Yubikey, with following steps (order is important): +def setup( + pin, + cert_cn, + cert_exp_days=365, + pin_retries=10, + private_key_pem=None, + mgm_key=generate_random_management_key(), +): + """Use to setup inserted Yubikey, with following steps (order is important): - reset to factory settings - set management key - generate key(RSA2048) or import given one @@ -239,34 +266,35 @@ def setup(pin, cert_cn, cert_exp_days=365, pin_retries=10, Raises: - YubikeyError """ - with _yk_piv_ctrl() as (ctrl, _): - # Factory reset and set PINs - ctrl.reset() - - ctrl.authenticate(DEFAULT_MANAGEMENT_KEY) - ctrl.set_mgm_key(mgm_key) - - # Generate RSA2048 - if private_key_pem is None: - pub_key = ctrl.generate_key(SLOT.SIGNATURE, ALGO.RSA2048, PIN_POLICY.ALWAYS) - else: - private_key = load_pem_private_key(private_key_pem, None, default_backend()) - ctrl.import_key(SLOT.SIGNATURE, private_key, PIN_POLICY.ALWAYS) - pub_key = private_key.public_key() - - ctrl.authenticate(mgm_key) - ctrl.verify(DEFAULT_PIN) - - # Generate and import certificate - now = datetime.datetime.now() - valid_to = now + datetime.timedelta(days=cert_exp_days) - ctrl.generate_self_signed_certificate(SLOT.SIGNATURE, pub_key, cert_cn, now, valid_to) - - ctrl.set_pin_retries(pin_retries=pin_retries, puk_retries=pin_retries) - ctrl.change_pin(DEFAULT_PIN, pin) - ctrl.change_puk(DEFAULT_PUK, pin) - - return pub_key.public_bytes( - serialization.Encoding.PEM, - serialization.PublicFormat.SubjectPublicKeyInfo, - ) + with _yk_piv_ctrl() as (ctrl, _): + # Factory reset and set PINs + ctrl.reset() + + ctrl.authenticate(DEFAULT_MANAGEMENT_KEY) + ctrl.set_mgm_key(mgm_key) + + # Generate RSA2048 + if private_key_pem is None: + pub_key = ctrl.generate_key(SLOT.SIGNATURE, ALGO.RSA2048, PIN_POLICY.ALWAYS) + else: + private_key = load_pem_private_key(private_key_pem, None, default_backend()) + ctrl.import_key(SLOT.SIGNATURE, private_key, PIN_POLICY.ALWAYS) + pub_key = private_key.public_key() + + ctrl.authenticate(mgm_key) + ctrl.verify(DEFAULT_PIN) + + # Generate and import certificate + now = datetime.datetime.now() + valid_to = now + datetime.timedelta(days=cert_exp_days) + ctrl.generate_self_signed_certificate( + SLOT.SIGNATURE, pub_key, cert_cn, now, valid_to + ) + + ctrl.set_pin_retries(pin_retries=pin_retries, puk_retries=pin_retries) + ctrl.change_pin(DEFAULT_PIN, pin) + ctrl.change_puk(DEFAULT_PUK, pin) + + return pub_key.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo + ) diff --git a/tests/__init__.py b/tests/__init__.py index 4dedc46a5..960076e1c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,3 @@ import os -TEST_WITH_REAL_YK = os.environ.get('REAL_YK', False) +TEST_WITH_REAL_YK = os.environ.get("REAL_YK", False) diff --git a/tests/conftest.py b/tests/conftest.py index 7d930b9d5..2f9d3012d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,10 @@ from pathlib 
import Path from pytest import fixture, yield_fixture -from securesystemslib.interface import (import_rsa_privatekey_from_file, - import_rsa_publickey_from_file) +from securesystemslib.interface import ( + import_rsa_privatekey_from_file, + import_rsa_publickey_from_file, +) import taf.repository_tool as repository_tool import taf.yubikey @@ -13,175 +15,191 @@ from taf.utils import on_rm_error from . import TEST_WITH_REAL_YK -from .yubikey_utils import (Root1YubiKey, Root2YubiKey, Root3YubiKey, - TargetYubiKey, _yk_piv_ctrl_mock) - -TEST_DATA_PATH = Path(__file__).parent / 'data' -TEST_DATA_REPOS_PATH = TEST_DATA_PATH / 'repos' -TEST_DATA_ORIGIN_PATH = TEST_DATA_REPOS_PATH / 'origin' -KEYSTORE_PATH = TEST_DATA_PATH / 'keystore' -WRONG_KEYSTORE_PATH = TEST_DATA_PATH / 'wrong_keystore' -CLIENT_DIR_PATH = TEST_DATA_REPOS_PATH / 'client' +from .yubikey_utils import ( + Root1YubiKey, + Root2YubiKey, + Root3YubiKey, + TargetYubiKey, + _yk_piv_ctrl_mock, +) + +TEST_DATA_PATH = Path(__file__).parent / "data" +TEST_DATA_REPOS_PATH = TEST_DATA_PATH / "repos" +TEST_DATA_ORIGIN_PATH = TEST_DATA_REPOS_PATH / "origin" +KEYSTORE_PATH = TEST_DATA_PATH / "keystore" +WRONG_KEYSTORE_PATH = TEST_DATA_PATH / "wrong_keystore" +CLIENT_DIR_PATH = TEST_DATA_REPOS_PATH / "client" def pytest_configure(config): - if not TEST_WITH_REAL_YK: - taf.yubikey._yk_piv_ctrl = _yk_piv_ctrl_mock + if not TEST_WITH_REAL_YK: + taf.yubikey._yk_piv_ctrl = _yk_piv_ctrl_mock def pytest_generate_tests(metafunc): - if "taf_happy_path" in metafunc.fixturenames: - # When running tests with real yubikey, use just rsa-pkcs1v15-sha256 scheme - schemes = ["rsa-pkcs1v15-sha256"] if TEST_WITH_REAL_YK else ["rsassa-pss-sha256", "rsa-pkcs1v15-sha256"] - metafunc.parametrize("taf_happy_path", schemes, indirect=True) + if "taf_happy_path" in metafunc.fixturenames: + # When running tests with real yubikey, use just rsa-pkcs1v15-sha256 scheme + schemes = ( + ["rsa-pkcs1v15-sha256"] + if TEST_WITH_REAL_YK + else ["rsassa-pss-sha256", "rsa-pkcs1v15-sha256"] + ) + metafunc.parametrize("taf_happy_path", schemes, indirect=True) @contextmanager def origin_repos_group(test_group_dir): - all_paths = {} - test_group_dir = str(TEST_DATA_REPOS_PATH / test_group_dir) - for test_dir in os.scandir(test_group_dir): - if test_dir.is_dir(): - all_paths[test_dir.name] = _copy_repos(test_dir.path, test_dir.name) + all_paths = {} + test_group_dir = str(TEST_DATA_REPOS_PATH / test_group_dir) + for test_dir in os.scandir(test_group_dir): + if test_dir.is_dir(): + all_paths[test_dir.name] = _copy_repos(test_dir.path, test_dir.name) - yield all_paths + yield all_paths - for test_name in all_paths: - test_dst_path = str(TEST_DATA_ORIGIN_PATH / test_name) - shutil.rmtree(test_dst_path, onerror=on_rm_error) + for test_name in all_paths: + test_dst_path = str(TEST_DATA_ORIGIN_PATH / test_name) + shutil.rmtree(test_dst_path, onerror=on_rm_error) @contextmanager def origin_repos(test_name): - """Coppies git repository from `data/repos/test-XYZ` to data/repos/origin/test-XYZ + """Copies git repository from `data/repos/test-XYZ` to data/repos/origin/test-XYZ path and renames `git` to `.git` for each repository. 
""" - test_dir_path = str(TEST_DATA_REPOS_PATH / test_name) - temp_paths = _copy_repos(test_dir_path, test_name) + test_dir_path = str(TEST_DATA_REPOS_PATH / test_name) + temp_paths = _copy_repos(test_dir_path, test_name) - yield temp_paths + yield temp_paths - test_dst_path = str(TEST_DATA_ORIGIN_PATH / test_name) - shutil.rmtree(test_dst_path, onerror=on_rm_error) + test_dst_path = str(TEST_DATA_ORIGIN_PATH / test_name) + shutil.rmtree(test_dst_path, onerror=on_rm_error) def _copy_repos(test_dir_path, test_name): - paths = {} - for root, dirs, _ in os.walk(test_dir_path): - for dir_name in dirs: - if dir_name == 'git': - repo_rel_path = os.path.relpath(root, test_dir_path) - dst_path = TEST_DATA_ORIGIN_PATH / test_name / repo_rel_path - # convert dst_path to string in order to support python 3.5 - shutil.copytree(root, str(dst_path)) - (dst_path / 'git').rename(dst_path / '.git') - repo_rel_path = Path(repo_rel_path).as_posix() - paths[repo_rel_path] = str(dst_path) - return paths - - -@yield_fixture(scope='session', autouse=True) + paths = {} + for root, dirs, _ in os.walk(test_dir_path): + for dir_name in dirs: + if dir_name == "git": + repo_rel_path = os.path.relpath(root, test_dir_path) + dst_path = TEST_DATA_ORIGIN_PATH / test_name / repo_rel_path + # convert dst_path to string in order to support python 3.5 + shutil.copytree(root, str(dst_path)) + (dst_path / "git").rename(dst_path / ".git") + repo_rel_path = Path(repo_rel_path).as_posix() + paths[repo_rel_path] = str(dst_path) + return paths + + +@yield_fixture(scope="session", autouse=True) def taf_happy_path(request, pytestconfig): - """TAF repository for testing.""" - repository_tool.DISABLE_KEYS_CACHING = True + """TAF repository for testing.""" + repository_tool.DISABLE_KEYS_CACHING = True - def _create_origin(test_dir, taf_repo_name='taf'): - with origin_repos(test_dir) as origins: - taf_repo_origin_path = origins[taf_repo_name] - yield Repository(taf_repo_origin_path) + def _create_origin(test_dir, taf_repo_name="taf"): + with origin_repos(test_dir) as origins: + taf_repo_origin_path = origins[taf_repo_name] + yield Repository(taf_repo_origin_path) - scheme = request.param - pytestconfig.option.signature_scheme = scheme + scheme = request.param + pytestconfig.option.signature_scheme = scheme - if scheme == 'rsassa-pss-sha256': - yield from _create_origin('test-happy-path') - elif scheme == 'rsa-pkcs1v15-sha256': - yield from _create_origin('test-happy-path-pkcs1v15') - else: - raise ValueError("Invalid test config. Invalid scheme: {}".format(scheme)) + if scheme == "rsassa-pss-sha256": + yield from _create_origin("test-happy-path") + elif scheme == "rsa-pkcs1v15-sha256": + yield from _create_origin("test-happy-path-pkcs1v15") + else: + raise ValueError("Invalid test config. 
Invalid scheme: {}".format(scheme)) @yield_fixture(scope="session", autouse=True) def updater_repositories(): - test_dir = 'test-updater' - with origin_repos_group(test_dir) as origins: - yield origins + test_dir = "test-updater" + with origin_repos_group(test_dir) as origins: + yield origins @fixture def client_dir(): - return CLIENT_DIR_PATH + return CLIENT_DIR_PATH @fixture def origin_dir(): - return TEST_DATA_ORIGIN_PATH + return TEST_DATA_ORIGIN_PATH @fixture def keystore(): - """Keystore path.""" - return str(KEYSTORE_PATH) + """Keystore path.""" + return str(KEYSTORE_PATH) @fixture def wrong_keystore(): - """Path of the wrong keystore""" - return str(WRONG_KEYSTORE_PATH) + """Path of the wrong keystore""" + return str(WRONG_KEYSTORE_PATH) @fixture def targets_yk(pytestconfig): - """Targets YubiKey.""" - return TargetYubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) + """Targets YubiKey.""" + return TargetYubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) @fixture def root1_yk(pytestconfig): - """Root1 YubiKey.""" - return Root1YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) + """Root1 YubiKey.""" + return Root1YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) @fixture def root2_yk(pytestconfig): - """Root2 YubiKey.""" - return Root2YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) + """Root2 YubiKey.""" + return Root2YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) @fixture def root3_yk(pytestconfig): - """Root3 YubiKey.""" - return Root3YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) + """Root3 YubiKey.""" + return Root3YubiKey(KEYSTORE_PATH, pytestconfig.option.signature_scheme) @fixture def snapshot_key(pytestconfig): - """Snapshot key.""" - key = import_rsa_publickey_from_file(str(KEYSTORE_PATH / 'snapshot.pub'), - scheme=pytestconfig.option.signature_scheme) - priv_key = import_rsa_privatekey_from_file(str(KEYSTORE_PATH / 'snapshot'), - scheme=pytestconfig.option.signature_scheme) - key['keyval']['private'] = priv_key['keyval']['private'] - return key + """Snapshot key.""" + key = import_rsa_publickey_from_file( + str(KEYSTORE_PATH / "snapshot.pub"), scheme=pytestconfig.option.signature_scheme + ) + priv_key = import_rsa_privatekey_from_file( + str(KEYSTORE_PATH / "snapshot"), scheme=pytestconfig.option.signature_scheme + ) + key["keyval"]["private"] = priv_key["keyval"]["private"] + return key @fixture def timestamp_key(pytestconfig): - """Timestamp key.""" - key = import_rsa_publickey_from_file(str(KEYSTORE_PATH / 'timestamp.pub'), - scheme=pytestconfig.option.signature_scheme) - priv_key = import_rsa_privatekey_from_file(str(KEYSTORE_PATH / 'timestamp'), - scheme=pytestconfig.option.signature_scheme) - key['keyval']['private'] = priv_key['keyval']['private'] - return key + """Timestamp key.""" + key = import_rsa_publickey_from_file( + str(KEYSTORE_PATH / "timestamp.pub"), + scheme=pytestconfig.option.signature_scheme, + ) + priv_key = import_rsa_privatekey_from_file( + str(KEYSTORE_PATH / "timestamp"), scheme=pytestconfig.option.signature_scheme + ) + key["keyval"]["private"] = priv_key["keyval"]["private"] + return key @yield_fixture def targets_key(pytestconfig): - """Targets key.""" - key = import_rsa_publickey_from_file(str(KEYSTORE_PATH / 'targets.pub'), - scheme=pytestconfig.option.signature_scheme) - priv_key = import_rsa_privatekey_from_file(str(KEYSTORE_PATH / 'targets'), - scheme=pytestconfig.option.signature_scheme) - key['keyval']['private'] = priv_key['keyval']['private'] - return 
key + """Targets key.""" + key = import_rsa_publickey_from_file( + str(KEYSTORE_PATH / "targets.pub"), scheme=pytestconfig.option.signature_scheme + ) + priv_key = import_rsa_privatekey_from_file( + str(KEYSTORE_PATH / "targets"), scheme=pytestconfig.option.signature_scheme + ) + key["keyval"]["private"] = priv_key["keyval"]["private"] + return key diff --git a/tests/test_add_targets.py b/tests/test_add_targets.py index 9ce05346a..72823ea03 100644 --- a/tests/test_add_targets.py +++ b/tests/test_add_targets.py @@ -1,4 +1,3 @@ - import json import os from pathlib import Path @@ -11,107 +10,111 @@ @fixture(autouse=True) def run_around_tests(taf_happy_path): - yield - repo = GitRepository(taf_happy_path.repository_path) - repo.reset_to_head() - repo.clean() - taf_happy_path._repository.targets.clear_targets() - files_to_keep = [] - for root, _, filenames in os.walk(str(taf_happy_path.targets_path)): - for filename in filenames: - file_path = str(Path(root) / filename) - relpath = Path(os.path.relpath(file_path, str(taf_happy_path.targets_path))).as_posix() - files_to_keep.append(relpath) - taf_happy_path.add_targets({}, files_to_keep=files_to_keep) + yield + repo = GitRepository(taf_happy_path.repository_path) + repo.reset_to_head() + repo.clean() + taf_happy_path._repository.targets.clear_targets() + files_to_keep = [] + for root, _, filenames in os.walk(str(taf_happy_path.targets_path)): + for filename in filenames: + file_path = str(Path(root) / filename) + relpath = Path( + os.path.relpath(file_path, str(taf_happy_path.targets_path)) + ).as_posix() + files_to_keep.append(relpath) + taf_happy_path.add_targets({}, files_to_keep=files_to_keep) def test_add_targets_new_files(taf_happy_path): - old_targets = _get_old_targets(taf_happy_path) + old_targets = _get_old_targets(taf_happy_path) - json_file_content = {'attr1': 'value1', 'attr2': 'value2'} - regular_file_content = 'this file is not empty' - data = { - 'new_json_file': {'target': json_file_content}, - 'new_file': {'target': regular_file_content}, - 'empty_file': {'target': None} - } - taf_happy_path.add_targets(data) - _check_target_files(taf_happy_path, data, old_targets) + json_file_content = {"attr1": "value1", "attr2": "value2"} + regular_file_content = "this file is not empty" + data = { + "new_json_file": {"target": json_file_content}, + "new_file": {"target": regular_file_content}, + "empty_file": {"target": None}, + } + taf_happy_path.add_targets(data) + _check_target_files(taf_happy_path, data, old_targets) def test_add_targets_nested_files(taf_happy_path): - old_targets = _get_old_targets(taf_happy_path) + old_targets = _get_old_targets(taf_happy_path) - data = { - 'inner_folder1/new_file_1': {'target': 'file 1 content'}, - 'inner_folder2/new_file_2': {'target': 'file 2 content'} - } - taf_happy_path.add_targets(data) - _check_target_files(taf_happy_path, data, old_targets) + data = { + "inner_folder1/new_file_1": {"target": "file 1 content"}, + "inner_folder2/new_file_2": {"target": "file 2 content"}, + } + taf_happy_path.add_targets(data) + _check_target_files(taf_happy_path, data, old_targets) def test_add_targets_files_to_keep(taf_happy_path): - old_targets = _get_old_targets(taf_happy_path) - data = { - 'a_new_file': {'target': 'new file content'} - } - files_to_keep = ['branch'] - taf_happy_path.add_targets(data, files_to_keep=files_to_keep) - _check_target_files(taf_happy_path, data, old_targets, files_to_keep) + old_targets = _get_old_targets(taf_happy_path) + data = {"a_new_file": {"target": "new file content"}} + 
files_to_keep = ["branch"] + taf_happy_path.add_targets(data, files_to_keep=files_to_keep) + _check_target_files(taf_happy_path, data, old_targets, files_to_keep) def _check_target_files(repo, data, old_targets, files_to_keep=None): - if files_to_keep is None: - files_to_keep = [] - - targets_path = repo.targets_path - for target_rel_path, content in data.items(): - target_path = targets_path / target_rel_path - assert target_path.exists() - with open(str(target_path)) as f: - file_content = f.read() - target_content = content['target'] - if isinstance(target_content, dict): - content_json = json.loads(file_content) - assert content_json == target_content - elif target_content: - assert file_content == target_content - else: - assert file_content == '' - - # make sure that everything defined in repositories.json still exists - repository_targets = [] - repositories_path = targets_path / 'repositories.json' - assert repositories_path.exists() - with open(str(repositories_path)) as f: - repositories = json.load(f)['repositories'] - for target_rel_path in repositories: - target_path = targets_path / target_rel_path - assert target_path.exists() - repository_targets.append(target_rel_path) - - # make sure that files to keep exist - for file_to_keep in files_to_keep: - # if the file didn't exists prior to adding new targets - # it won't exists after adding them - if file_to_keep not in old_targets: - continue - target_path = targets_path / file_to_keep - assert target_path.exists() - - for old_target in old_targets: - if old_target not in repository_targets and old_target not in data and \ - old_target not in repo._framework_files and not old_target in files_to_keep: - assert (targets_path / old_target).exists() is False + if files_to_keep is None: + files_to_keep = [] + + targets_path = repo.targets_path + for target_rel_path, content in data.items(): + target_path = targets_path / target_rel_path + assert target_path.exists() + with open(str(target_path)) as f: + file_content = f.read() + target_content = content["target"] + if isinstance(target_content, dict): + content_json = json.loads(file_content) + assert content_json == target_content + elif target_content: + assert file_content == target_content + else: + assert file_content == "" + + # make sure that everything defined in repositories.json still exists + repository_targets = [] + repositories_path = targets_path / "repositories.json" + assert repositories_path.exists() + with open(str(repositories_path)) as f: + repositories = json.load(f)["repositories"] + for target_rel_path in repositories: + target_path = targets_path / target_rel_path + assert target_path.exists() + repository_targets.append(target_rel_path) + + # make sure that files to keep exist + for file_to_keep in files_to_keep: + # if the file didn't exists prior to adding new targets + # it won't exists after adding them + if file_to_keep not in old_targets: + continue + target_path = targets_path / file_to_keep + assert target_path.exists() + + for old_target in old_targets: + if ( + old_target not in repository_targets + and old_target not in data + and old_target not in repo._framework_files + and old_target not in files_to_keep + ): + assert (targets_path / old_target).exists() is False def _get_old_targets(repo): - targets_path = repo.targets_path - old_targets = [] - for root, _, filenames in os.walk(str(targets_path)): - for filename in filenames: - rel_path = os.path.relpath(str(Path(root) / filename), str(targets_path)) - 
old_targets.append(Path(rel_path).as_posix()) - return old_targets + targets_path = repo.targets_path + old_targets = [] + for root, _, filenames in os.walk(str(targets_path)): + for filename in filenames: + rel_path = os.path.relpath(str(Path(root) / filename), str(targets_path)) + old_targets.append(Path(rel_path).as_posix()) + return old_targets diff --git a/tests/test_repository.py b/tests/test_repository.py index a51dbb3f4..127b39057 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -5,18 +5,20 @@ def test_url_validation_valid_urls(): - urls = ['https://github.com/account/repo_name.git', - 'https://github.com/account/repo_name', - 'http://github.com/account/repo_name.git', - 'http://github.com/account/repo_name', - 'git@github.com:openlawlibrary/taf.git', - 'git@github.com:openlawlibrary/taf'] - repo = GitRepository('path', urls) - for test_url, repo_url in zip(urls, repo.repo_urls): - assert test_url == repo_url + urls = [ + "https://github.com/account/repo_name.git", + "https://github.com/account/repo_name", + "http://github.com/account/repo_name.git", + "http://github.com/account/repo_name", + "git@github.com:openlawlibrary/taf.git", + "git@github.com:openlawlibrary/taf", + ] + repo = GitRepository("path", urls) + for test_url, repo_url in zip(urls, repo.repo_urls): + assert test_url == repo_url def test_url_invalid_urls(): - urls = ['abc://something.com'] - with pytest.raises(InvalidRepositoryError): - repo = GitRepository('path', urls) + urls = ["abc://something.com"] + with pytest.raises(InvalidRepositoryError): + GitRepository("path", urls) diff --git a/tests/test_repository_tool.py b/tests/test_repository_tool.py index 476671f6b..9d6f1cd84 100644 --- a/tests/test_repository_tool.py +++ b/tests/test_repository_tool.py @@ -12,155 +12,162 @@ @pytest.mark.skipif(TEST_WITH_REAL_YK, reason="Testing with real Yubikey.") -def test_check_no_key_inserted_for_targets_should_raise_error(taf_happy_path, targets_yk): - targets_yk.insert() - targets_yk.remove() - with pytest.raises(taf.exceptions.YubikeyError): - taf_happy_path.is_valid_metadata_yubikey('targets') +def test_check_no_key_inserted_for_targets_should_raise_error( + taf_happy_path, targets_yk +): + targets_yk.insert() + targets_yk.remove() + with pytest.raises(taf.exceptions.YubikeyError): + taf_happy_path.is_valid_metadata_yubikey("targets") + +def test_check_targets_key_id_for_targets_should_return_true( + taf_happy_path, targets_yk +): + from tuf.keydb import _keydb_dict -def test_check_targets_key_id_for_targets_should_return_true(taf_happy_path, targets_yk): - from tuf.keydb import _keydb_dict - targets_yk.insert() - assert taf_happy_path.is_valid_metadata_yubikey('targets', targets_yk.tuf_key) + targets_yk.insert() + assert taf_happy_path.is_valid_metadata_yubikey("targets", targets_yk.tuf_key) def test_check_root_key_id_for_targets_should_return_false(taf_happy_path, root1_yk): - root1_yk.insert() - assert not taf_happy_path.is_valid_metadata_yubikey('targets', root1_yk.tuf_key) + root1_yk.insert() + assert not taf_happy_path.is_valid_metadata_yubikey("targets", root1_yk.tuf_key) def test_update_snapshot_and_timestmap(taf_happy_path, snapshot_key, timestamp_key): - date_now = datetime.datetime.now() - snapshot_date = date_now + datetime.timedelta(1) - snapshot_interval = 2 - timestamp_date = date_now + datetime.timedelta(2) - timestamp_interval = 3 + date_now = datetime.datetime.now() + snapshot_date = date_now + datetime.timedelta(1) + snapshot_interval = 2 + timestamp_date = date_now + 
datetime.timedelta(2) + timestamp_interval = 3 - kwargs = { - 'snapshot_date': snapshot_date, - 'timestamp_date': timestamp_date, - 'snapshot_interval': snapshot_interval, - 'timestamp_interval': timestamp_interval - } + kwargs = { + "snapshot_date": snapshot_date, + "timestamp_date": timestamp_date, + "snapshot_interval": snapshot_interval, + "timestamp_interval": timestamp_interval, + } - taf_happy_path.update_snapshot_and_timestmap(snapshot_key, timestamp_key, **kwargs) + taf_happy_path.update_snapshot_and_timestmap(snapshot_key, timestamp_key, **kwargs) - targets_metadata_path = Path(taf_happy_path.metadata_path) / 'targets.json' - snapshot_metadata_path = Path(taf_happy_path.metadata_path) / 'snapshot.json' - timestamp_metadata_path = Path(taf_happy_path.metadata_path) / 'timestamp.json' + targets_metadata_path = Path(taf_happy_path.metadata_path) / "targets.json" + snapshot_metadata_path = Path(taf_happy_path.metadata_path) / "snapshot.json" + timestamp_metadata_path = Path(taf_happy_path.metadata_path) / "timestamp.json" - old_targets_metadata = targets_metadata_path.read_bytes() + old_targets_metadata = targets_metadata_path.read_bytes() - def check_expiration_date(metadata_path, date, interval): - signable = securesystemslib.util.load_json_file(metadata_path) - tuf.formats.SIGNABLE_SCHEMA.check_match(signable) - actual_expiration_date = signable['signed']['expires'] + def check_expiration_date(metadata_path, date, interval): + signable = securesystemslib.util.load_json_file(metadata_path) + tuf.formats.SIGNABLE_SCHEMA.check_match(signable) + actual_expiration_date = signable["signed"]["expires"] - assert actual_expiration_date == to_tuf_datetime_format(date, interval) + assert actual_expiration_date == to_tuf_datetime_format(date, interval) - check_expiration_date(str(snapshot_metadata_path), snapshot_date, snapshot_interval) - check_expiration_date(str(timestamp_metadata_path), timestamp_date, timestamp_interval) + check_expiration_date(str(snapshot_metadata_path), snapshot_date, snapshot_interval) + check_expiration_date( + str(timestamp_metadata_path), timestamp_date, timestamp_interval + ) - # Targets data should remain the same - assert old_targets_metadata == targets_metadata_path.read_bytes() + # Targets data should remain the same + assert old_targets_metadata == targets_metadata_path.read_bytes() def test_update_snapshot_valid_key(taf_happy_path, snapshot_key): - start_date = datetime.datetime.now() - interval = 1 - expected_expiration_date = to_tuf_datetime_format(start_date, interval) - taf_happy_path.update_snapshot(snapshot_key, start_date=start_date, interval=interval) - new_snapshot_metadata = str(Path(taf_happy_path.metadata_path) / 'snapshot.json') - signable = securesystemslib.util.load_json_file(new_snapshot_metadata) - tuf.formats.SIGNABLE_SCHEMA.check_match(signable) - actual_expiration_date = signable['signed']['expires'] + start_date = datetime.datetime.now() + interval = 1 + expected_expiration_date = to_tuf_datetime_format(start_date, interval) + taf_happy_path.update_snapshot( + snapshot_key, start_date=start_date, interval=interval + ) + new_snapshot_metadata = str(Path(taf_happy_path.metadata_path) / "snapshot.json") + signable = securesystemslib.util.load_json_file(new_snapshot_metadata) + tuf.formats.SIGNABLE_SCHEMA.check_match(signable) + actual_expiration_date = signable["signed"]["expires"] - assert actual_expiration_date == expected_expiration_date + assert actual_expiration_date == expected_expiration_date def 
test_update_snapshot_wrong_key(taf_happy_path, timestamp_key): - with pytest.raises(taf.exceptions.InvalidKeyError): - taf_happy_path.update_snapshot(timestamp_key) + with pytest.raises(taf.exceptions.InvalidKeyError): + taf_happy_path.update_snapshot(timestamp_key) def test_update_timestamp_valid_key(taf_happy_path, timestamp_key): - start_date = datetime.datetime.now() - interval = 1 - expected_expiration_date = to_tuf_datetime_format(start_date, interval) - - taf_happy_path.update_timestamp(timestamp_key, start_date=start_date, interval=interval) - new_timestamp_metadata = str(Path(taf_happy_path.metadata_path) / 'timestamp.json') - signable = securesystemslib.util.load_json_file(new_timestamp_metadata) - tuf.formats.SIGNABLE_SCHEMA.check_match(signable) - actual_expiration_date = signable['signed']['expires'] + start_date = datetime.datetime.now() + interval = 1 + expected_expiration_date = to_tuf_datetime_format(start_date, interval) + + taf_happy_path.update_timestamp( + timestamp_key, start_date=start_date, interval=interval + ) + new_timestamp_metadata = str(Path(taf_happy_path.metadata_path) / "timestamp.json") + signable = securesystemslib.util.load_json_file(new_timestamp_metadata) + tuf.formats.SIGNABLE_SCHEMA.check_match(signable) + actual_expiration_date = signable["signed"]["expires"] - assert actual_expiration_date == expected_expiration_date + assert actual_expiration_date == expected_expiration_date def test_update_timestamp_wrong_key(taf_happy_path, snapshot_key): - with pytest.raises(taf.exceptions.InvalidKeyError): - taf_happy_path.update_timestamp(snapshot_key) + with pytest.raises(taf.exceptions.InvalidKeyError): + taf_happy_path.update_timestamp(snapshot_key) def test_update_targets_from_keystore_valid_key(taf_happy_path, targets_key): - start_date = datetime.datetime.now() - interval = 1 - expected_expiration_date = to_tuf_datetime_format(start_date, interval) - - taf_happy_path.update_targets_from_keystore(targets_key, start_date=start_date, interval=interval) - new_targets_data = str(Path(taf_happy_path.metadata_path) / 'targets.json') - signable = securesystemslib.util.load_json_file(new_targets_data) - tuf.formats.SIGNABLE_SCHEMA.check_match(signable) - actual_expiration_date = signable['signed']['expires'] + start_date = datetime.datetime.now() + interval = 1 + expected_expiration_date = to_tuf_datetime_format(start_date, interval) + + taf_happy_path.update_targets_from_keystore( + targets_key, start_date=start_date, interval=interval + ) + new_targets_data = str(Path(taf_happy_path.metadata_path) / "targets.json") + signable = securesystemslib.util.load_json_file(new_targets_data) + tuf.formats.SIGNABLE_SCHEMA.check_match(signable) + actual_expiration_date = signable["signed"]["expires"] - assert actual_expiration_date == expected_expiration_date + assert actual_expiration_date == expected_expiration_date def test_update_targets_from_keystore_wrong_key(taf_happy_path, snapshot_key): - with pytest.raises(taf.exceptions.InvalidKeyError): - taf_happy_path.update_targets_from_keystore(snapshot_key) + with pytest.raises(taf.exceptions.InvalidKeyError): + taf_happy_path.update_targets_from_keystore(snapshot_key) def test_update_targets_valid_key_valid_pin(taf_happy_path, targets_yk): - targets_path = Path(taf_happy_path.targets_path) - repositories_json_path = targets_path / 'repositories.json' - - branch_id = '14e81cd1-0050-43aa-9e2c-e34fffa6f517' - target_commit_sha = 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' - repositories_json_old = 
repositories_json_path.read_text() - - targets_data = { - 'branch': { - 'target': branch_id, - }, - 'dummy/target_dummy_repo': { - 'target': { - 'commit': target_commit_sha - } - }, - 'capstone': {} - } - - targets_yk.insert() - taf_happy_path.update_targets('123456', targets_data, - datetime.datetime.now(), - public_key=targets_yk.tuf_key) - - assert (targets_path / 'branch').read_text() == branch_id - assert target_commit_sha in (targets_path / 'dummy/target_dummy_repo').read_text() - assert (targets_path / 'capstone').is_file() - assert repositories_json_old == repositories_json_path.read_text() + targets_path = Path(taf_happy_path.targets_path) + repositories_json_path = targets_path / "repositories.json" + branch_id = "14e81cd1-0050-43aa-9e2c-e34fffa6f517" + target_commit_sha = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + repositories_json_old = repositories_json_path.read_text() + + targets_data = { + "branch": {"target": branch_id}, + "dummy/target_dummy_repo": {"target": {"commit": target_commit_sha}}, + "capstone": {}, + } -def test_update_targets_valid_key_wrong_pin(taf_happy_path, targets_yk): - with pytest.raises(taf.exceptions.TargetsMetadataUpdateError): targets_yk.insert() - taf_happy_path.update_targets('123', public_key=targets_yk.tuf_key) + taf_happy_path.update_targets( + "123456", targets_data, datetime.datetime.now(), public_key=targets_yk.tuf_key + ) + + assert (targets_path / "branch").read_text() == branch_id + assert target_commit_sha in (targets_path / "dummy/target_dummy_repo").read_text() + assert (targets_path / "capstone").is_file() + assert repositories_json_old == repositories_json_path.read_text() + + +def test_update_targets_valid_key_wrong_pin(taf_happy_path, targets_yk): + with pytest.raises(taf.exceptions.TargetsMetadataUpdateError): + targets_yk.insert() + taf_happy_path.update_targets("123", public_key=targets_yk.tuf_key) @pytest.mark.skipif(TEST_WITH_REAL_YK, reason="Testing with real Yubikey.") def test_update_targets_wrong_key(taf_happy_path, root1_yk): - with pytest.raises(taf.exceptions.InvalidKeyError): - root1_yk.insert() - taf_happy_path.update_targets('123456') + with pytest.raises(taf.exceptions.InvalidKeyError): + root1_yk.insert() + taf_happy_path.update_targets("123456") diff --git a/tests/test_updater.py b/tests/test_updater.py index e8bd8d651..e21f2a75b 100644 --- a/tests/test_updater.py +++ b/tests/test_updater.py @@ -57,282 +57,364 @@ from taf.updater.updater import update_repository from taf.utils import on_rm_error -AUTH_REPO_REL_PATH = 'organization/auth_repo' -TARGET1_SHA_MISMATCH = 'Mismatch between target commits specified in authentication repository and target repository namespace/TargetRepo1' -TARGETS_MISMATCH_ANY = 'Mismatch between target commits specified in authentication repository and target repository' -NO_WORKING_MIRRORS = 'Validation of authentication repository auth_repo failed due to error: No working mirror was found' +AUTH_REPO_REL_PATH = "organization/auth_repo" +TARGET1_SHA_MISMATCH = "Mismatch between target commits specified in authentication repository and target repository namespace/TargetRepo1" +TARGETS_MISMATCH_ANY = "Mismatch between target commits specified in authentication repository and target repository" +NO_WORKING_MIRRORS = "Validation of authentication repository auth_repo failed due to error: No working mirror was found" TIMESTAMP_EXPIRED = "Metadata 'timestamp' expired" -REPLAYED_METADATA = 'ReplayedMetadataError' -METADATA_CHANGED_BUT_SHOULDNT = 'Metadata file targets.json should be the same 
at revisions' +REPLAYED_METADATA = "ReplayedMetadataError" +METADATA_CHANGED_BUT_SHOULDNT = ( + "Metadata file targets.json should be the same at revisions" +) def setup_module(module): - settings.update_from_filesystem = True + settings.update_from_filesystem = True def teardown_module(module): - settings.update_from_filesystem = False + settings.update_from_filesystem = False @fixture(autouse=True) def run_around_tests(client_dir): - yield - for root, dirs, _ in os.walk(str(client_dir)): - for dir_name in dirs: - shutil.rmtree(str(Path(root) / dir_name), onerror=on_rm_error) - - -@pytest.mark.parametrize('test_name, test_repo', [('test-updater-valid', False), - ('test-updater-additional-target-commit', False), - ('test-updater-valid-with-updated-expiration-dates', False), - ('test-updater-allow-unauthenticated-commits', False), - ('test-updater-test-repo', True)]) -def test_valid_update_no_client_repo(test_name, test_repo, updater_repositories, origin_dir, client_dir): - repositories = updater_repositories[test_name] - origin_dir = origin_dir / test_name - _update_and_check_commit_shas(None, repositories, origin_dir, client_dir, test_repo) - - -@pytest.mark.parametrize('test_name, num_of_commits_to_revert', [('test-updater-valid', 3), - ('test-updater-additional-target-commit', 1), - ('test-updater-allow-unauthenticated-commits', 1)]) -def test_valid_update_existing_client_repos(test_name, num_of_commits_to_revert, - updater_repositories, origin_dir, client_dir): - # clone the origin repositories - # revert them to an older commit - repositories = updater_repositories[test_name] - origin_dir = origin_dir / test_name - client_repos = _clone_and_revert_client_repositories(repositories, origin_dir, client_dir, - num_of_commits_to_revert) - # create valid last validated commit file - _create_last_validated_commit(client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha()) - _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir) - - -@pytest.mark.parametrize('test_name, test_repo', [('test-updater-valid', False), - ('test-updater-allow-unauthenticated-commits', False), - ('test-updater-test-repo', True)]) -def test_no_update_necessary(test_name, test_repo, updater_repositories, origin_dir, client_dir): - # clone the origin repositories - # revert them to an older commit - repositories = updater_repositories[test_name] - origin_dir = origin_dir / test_name - client_repos = _clone_client_repositories(repositories, origin_dir, client_dir) - # create valid last validated commit file - _create_last_validated_commit(client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha()) - _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir, test_repo) - - -@pytest.mark.parametrize('test_name, expected_error', [ - ('test-updater-invalid-target-sha', TARGET1_SHA_MISMATCH), - ('test-updater-missing-target-commit', TARGET1_SHA_MISMATCH), - ('test-updater-wrong-key', NO_WORKING_MIRRORS), - ('test-updater-invalid-expiration-date', TIMESTAMP_EXPIRED), - ('test-updater-invalid-version-number', REPLAYED_METADATA), - ('test-updater-just-targets-updated', METADATA_CHANGED_BUT_SHOULDNT)]) -def test_updater_invalid_update(test_name, expected_error, updater_repositories, client_dir): - repositories = updater_repositories[test_name] - clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH - _update_invalid_repos_and_check_if_repos_exist(client_dir, repositories, expected_error) - # make sure that the last validated commit does not exist - 
_check_if_last_validated_commit_exists(clients_auth_repo_path)
-
-
-@pytest.mark.parametrize('test_name, expected_error', [
-    ('test-updater-invalid-target-sha', TARGET1_SHA_MISMATCH)])
-def test_updater_invalid_target_sha_existing_client_repos(test_name, expected_error,
-                                                          updater_repositories, origin_dir,
-                                                          client_dir):
-    repositories = updater_repositories[test_name]
-    clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH
-    origin_dir = origin_dir / test_name
-    client_repos = _clone_and_revert_client_repositories(repositories,
-                                                         origin_dir, client_dir, 1)
-    _create_last_validated_commit(client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha())
-    _update_invalid_repos_and_check_if_remained_same(client_repos, client_dir,
-                                                     repositories,
-                                                     expected_error)
-    _check_last_validated_commit(clients_auth_repo_path)
+    yield
+    for root, dirs, _ in os.walk(str(client_dir)):
+        for dir_name in dirs:
+            shutil.rmtree(str(Path(root) / dir_name), onerror=on_rm_error)
+
+
+@pytest.mark.parametrize(
+    "test_name, test_repo",
+    [
+        ("test-updater-valid", False),
+        ("test-updater-additional-target-commit", False),
+        ("test-updater-valid-with-updated-expiration-dates", False),
+        ("test-updater-allow-unauthenticated-commits", False),
+        ("test-updater-test-repo", True),
+    ],
+)
+def test_valid_update_no_client_repo(
+    test_name, test_repo, updater_repositories, origin_dir, client_dir
+):
+    repositories = updater_repositories[test_name]
+    origin_dir = origin_dir / test_name
+    _update_and_check_commit_shas(None, repositories, origin_dir, client_dir, test_repo)
+
+
+@pytest.mark.parametrize(
+    "test_name, num_of_commits_to_revert",
+    [
+        ("test-updater-valid", 3),
+        ("test-updater-additional-target-commit", 1),
+        ("test-updater-allow-unauthenticated-commits", 1),
+    ],
+)
+def test_valid_update_existing_client_repos(
+    test_name, num_of_commits_to_revert, updater_repositories, origin_dir, client_dir
+):
+    # clone the origin repositories
+    # revert them to an older commit
+    repositories = updater_repositories[test_name]
+    origin_dir = origin_dir / test_name
+    client_repos = _clone_and_revert_client_repositories(
+        repositories, origin_dir, client_dir, num_of_commits_to_revert
+    )
+    # create valid last validated commit file
+    _create_last_validated_commit(
+        client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha()
+    )
+    _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir)
+
+
+@pytest.mark.parametrize(
+    "test_name, test_repo",
+    [
+        ("test-updater-valid", False),
+        ("test-updater-allow-unauthenticated-commits", False),
+        ("test-updater-test-repo", True),
+    ],
+)
+def test_no_update_necessary(
+    test_name, test_repo, updater_repositories, origin_dir, client_dir
+):
+    # clone the origin repositories
+    # no revert - the client repositories are already up to date
+    repositories = updater_repositories[test_name]
+    origin_dir = origin_dir / test_name
+    client_repos = _clone_client_repositories(repositories, origin_dir, client_dir)
+    # create valid last validated commit file
+    _create_last_validated_commit(
+        client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha()
+    )
+    _update_and_check_commit_shas(
+        client_repos, repositories, origin_dir, client_dir, test_repo
+    )
+
+
+@pytest.mark.parametrize(
+    "test_name, expected_error",
+    [
+        ("test-updater-invalid-target-sha", TARGET1_SHA_MISMATCH),
+        ("test-updater-missing-target-commit", TARGET1_SHA_MISMATCH),
+        ("test-updater-wrong-key", NO_WORKING_MIRRORS),
+        ("test-updater-invalid-expiration-date", TIMESTAMP_EXPIRED),
+        
("test-updater-invalid-version-number", REPLAYED_METADATA), + ("test-updater-just-targets-updated", METADATA_CHANGED_BUT_SHOULDNT), + ], +) +def test_updater_invalid_update( + test_name, expected_error, updater_repositories, client_dir +): + repositories = updater_repositories[test_name] + clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH + _update_invalid_repos_and_check_if_repos_exist( + client_dir, repositories, expected_error + ) + # make sure that the last validated commit does not exist + _check_if_last_validated_commit_exists(clients_auth_repo_path) + + +@pytest.mark.parametrize( + "test_name, expected_error", + [("test-updater-invalid-target-sha", TARGET1_SHA_MISMATCH)], +) +def test_updater_invalid_target_sha_existing_client_repos( + test_name, expected_error, updater_repositories, origin_dir, client_dir +): + repositories = updater_repositories[test_name] + clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH + origin_dir = origin_dir / test_name + client_repos = _clone_and_revert_client_repositories( + repositories, origin_dir, client_dir, 1 + ) + _create_last_validated_commit( + client_dir, client_repos[AUTH_REPO_REL_PATH].head_commit_sha() + ) + _update_invalid_repos_and_check_if_remained_same( + client_repos, client_dir, repositories, expected_error + ) + _check_last_validated_commit(clients_auth_repo_path) def test_no_target_repositories(updater_repositories, origin_dir, client_dir): - repositories = updater_repositories['test-updater-valid'] - origin_dir = origin_dir / 'test-updater-valid' - client_auth_repo = _clone_client_repo(AUTH_REPO_REL_PATH, origin_dir, client_dir) - _create_last_validated_commit(client_dir, client_auth_repo.head_commit_sha()) - client_repos = {AUTH_REPO_REL_PATH: client_auth_repo} - _update_invalid_repos_and_check_if_remained_same(client_repos, client_dir, - repositories, - TARGETS_MISMATCH_ANY) - # make sure that the target repositories still do not exist - for repository_rel_path in repositories: - if repository_rel_path != AUTH_REPO_REL_PATH: - client_repo_path = client_dir / repository_rel_path - assert client_repo_path.exists() is False + repositories = updater_repositories["test-updater-valid"] + origin_dir = origin_dir / "test-updater-valid" + client_auth_repo = _clone_client_repo(AUTH_REPO_REL_PATH, origin_dir, client_dir) + _create_last_validated_commit(client_dir, client_auth_repo.head_commit_sha()) + client_repos = {AUTH_REPO_REL_PATH: client_auth_repo} + _update_invalid_repos_and_check_if_remained_same( + client_repos, client_dir, repositories, TARGETS_MISMATCH_ANY + ) + # make sure that the target repositories still do not exist + for repository_rel_path in repositories: + if repository_rel_path != AUTH_REPO_REL_PATH: + client_repo_path = client_dir / repository_rel_path + assert client_repo_path.exists() is False def test_no_last_validated_commit(updater_repositories, origin_dir, client_dir): - # clone the origin repositories - # revert them to an older commit - repositories = updater_repositories['test-updater-valid'] - origin_dir = origin_dir / 'test-updater-valid' - client_repos = _clone_and_revert_client_repositories(repositories, origin_dir, - client_dir, 3) - # update without setting the last validated commit - # update should start from the beginning and be successful - _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir) + # clone the origin repositories + # revert them to an older commit + repositories = updater_repositories["test-updater-valid"] + origin_dir = origin_dir / 
"test-updater-valid" + client_repos = _clone_and_revert_client_repositories( + repositories, origin_dir, client_dir, 3 + ) + # update without setting the last validated commit + # update should start from the beginning and be successful + _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir) def test_invalid_last_validated_commit(updater_repositories, origin_dir, client_dir): - # clone the origin repositories - # revert them to an older commit - repositories = updater_repositories['test-updater-valid'] - origin_dir = origin_dir / 'test-updater-valid' - client_repos = _clone_and_revert_client_repositories(repositories, origin_dir, - client_dir, 3) - first_commit = client_repos[AUTH_REPO_REL_PATH].all_commits_since_commit(None)[0] - expected_error = 'Saved last validated commit {} does not match the head commit'.format( - first_commit) - _create_last_validated_commit(client_dir, first_commit) - # try to update without setting the last validated commit - _update_invalid_repos_and_check_if_remained_same(client_repos, client_dir, - repositories, expected_error) + # clone the origin repositories + # revert them to an older commit + repositories = updater_repositories["test-updater-valid"] + origin_dir = origin_dir / "test-updater-valid" + client_repos = _clone_and_revert_client_repositories( + repositories, origin_dir, client_dir, 3 + ) + first_commit = client_repos[AUTH_REPO_REL_PATH].all_commits_since_commit(None)[0] + expected_error = "Saved last validated commit {} does not match the head commit".format( + first_commit + ) + _create_last_validated_commit(client_dir, first_commit) + # try to update without setting the last validated commit + _update_invalid_repos_and_check_if_remained_same( + client_repos, client_dir, repositories, expected_error + ) def test_update_test_repo_no_flag(updater_repositories, origin_dir, client_dir): - repositories = updater_repositories['test-updater-test-repo'] - origin_dir = origin_dir / 'test-updater-test-repo' - expected_error = 'Repository auth_repo is a test repository.' - # try to update without setting the last validated commit - _update_invalid_repos_and_check_if_repos_exist(client_dir, repositories, expected_error) + repositories = updater_repositories["test-updater-test-repo"] + origin_dir = origin_dir / "test-updater-test-repo" + expected_error = "Repository auth_repo is a test repository." + # try to update without setting the last validated commit + _update_invalid_repos_and_check_if_repos_exist( + client_dir, repositories, expected_error + ) def test_update_repo_wrong_flag(updater_repositories, origin_dir, client_dir): - repositories = updater_repositories['test-updater-valid'] - origin_dir = origin_dir / 'test-updater-valid' - expected_error = 'Repository auth_repo is not a test repository.' - # try to update without setting the last validated commit - _update_invalid_repos_and_check_if_repos_exist(client_dir, repositories, expected_error, True) + repositories = updater_repositories["test-updater-valid"] + origin_dir = origin_dir / "test-updater-valid" + expected_error = "Repository auth_repo is not a test repository." 
+ # try to update without setting the last validated commit + _update_invalid_repos_and_check_if_repos_exist( + client_dir, repositories, expected_error, True + ) def _check_last_validated_commit(clients_auth_repo_path): - # check if last validated commit is created and the saved commit is correct - client_auth_repo = AuthenticationRepo(str(clients_auth_repo_path), 'metadata', 'targets') - head_sha = client_auth_repo.head_commit_sha() - last_validated_commit = client_auth_repo.last_validated_commit - assert head_sha == last_validated_commit + # check if last validated commit is created and the saved commit is correct + client_auth_repo = AuthenticationRepo( + str(clients_auth_repo_path), "metadata", "targets" + ) + head_sha = client_auth_repo.head_commit_sha() + last_validated_commit = client_auth_repo.last_validated_commit + assert head_sha == last_validated_commit def _check_if_last_validated_commit_exists(clients_auth_repo_path): - client_auth_repo = AuthenticationRepo(str(clients_auth_repo_path), 'metadata', 'targets') - last_validated_commit = client_auth_repo.last_validated_commit - assert last_validated_commit is None + client_auth_repo = AuthenticationRepo( + str(clients_auth_repo_path), "metadata", "targets" + ) + last_validated_commit = client_auth_repo.last_validated_commit + assert last_validated_commit is None def _check_if_commits_match(repositories, origin_dir, client_dir, start_head_shas=None): - for repository_rel_path in repositories: - origin_repo = GitRepository(origin_dir / repository_rel_path) - client_repo = GitRepository(client_dir / repository_rel_path) - if start_head_shas is not None: - start_commit = start_head_shas.get(repository_rel_path) - else: - start_commit = None - origin_auth_repo_commits = origin_repo.all_commits_since_commit(start_commit) - client_auth_repo_commits = client_repo.all_commits_since_commit(start_commit) - for origin_commit, client_commit in zip(origin_auth_repo_commits, client_auth_repo_commits): - assert origin_commit == client_commit + for repository_rel_path in repositories: + origin_repo = GitRepository(origin_dir / repository_rel_path) + client_repo = GitRepository(client_dir / repository_rel_path) + if start_head_shas is not None: + start_commit = start_head_shas.get(repository_rel_path) + else: + start_commit = None + origin_auth_repo_commits = origin_repo.all_commits_since_commit(start_commit) + client_auth_repo_commits = client_repo.all_commits_since_commit(start_commit) + for origin_commit, client_commit in zip( + origin_auth_repo_commits, client_auth_repo_commits + ): + assert origin_commit == client_commit def _clone_client_repositories(repositories, origin_dir, client_dir): - client_repos = {} - for repository_rel_path in repositories: - client_repo = _clone_client_repo(repository_rel_path, origin_dir, client_dir) - client_repos[repository_rel_path] = client_repo - return client_repos + client_repos = {} + for repository_rel_path in repositories: + client_repo = _clone_client_repo(repository_rel_path, origin_dir, client_dir) + client_repos[repository_rel_path] = client_repo + return client_repos -def _clone_and_revert_client_repositories(repositories, origin_dir, client_dir, num_of_commits): - client_repos = {} +def _clone_and_revert_client_repositories( + repositories, origin_dir, client_dir, num_of_commits +): + client_repos = {} - client_auth_repo = _clone_client_repo(AUTH_REPO_REL_PATH, origin_dir, client_dir) - client_auth_repo.reset_num_of_commits(num_of_commits, True) - client_auth_repo_head_sha = 
client_auth_repo.head_commit_sha() - client_repos[AUTH_REPO_REL_PATH] = client_auth_repo + client_auth_repo = _clone_client_repo(AUTH_REPO_REL_PATH, origin_dir, client_dir) + client_auth_repo.reset_num_of_commits(num_of_commits, True) + client_auth_repo_head_sha = client_auth_repo.head_commit_sha() + client_repos[AUTH_REPO_REL_PATH] = client_auth_repo - for repository_rel_path in repositories: - if repository_rel_path == AUTH_REPO_REL_PATH: - continue + for repository_rel_path in repositories: + if repository_rel_path == AUTH_REPO_REL_PATH: + continue - client_repo = _clone_client_repo(repository_rel_path, origin_dir, client_dir) - # read the commit sha stored in target files - commit = client_auth_repo.get_json(client_auth_repo_head_sha, - str((Path('targets') / repository_rel_path).as_posix())) - commit_sha = commit['commit'] - client_repo.reset_to_commit(commit_sha, True) - client_repos[repository_rel_path] = client_repo + client_repo = _clone_client_repo(repository_rel_path, origin_dir, client_dir) + # read the commit sha stored in target files + commit = client_auth_repo.get_json( + client_auth_repo_head_sha, + str((Path("targets") / repository_rel_path).as_posix()), + ) + commit_sha = commit["commit"] + client_repo.reset_to_commit(commit_sha, True) + client_repos[repository_rel_path] = client_repo - return client_repos + return client_repos def _clone_client_repo(repository_rel_path, origin_dir, client_dir): - origin_repo_path = str(origin_dir / repository_rel_path) - client_repo_path = str(client_dir / repository_rel_path) - client_repo = GitRepository(client_repo_path, [origin_repo_path]) - client_repo.clone() - return client_repo + origin_repo_path = str(origin_dir / repository_rel_path) + client_repo_path = str(client_dir / repository_rel_path) + client_repo = GitRepository(client_repo_path, [origin_repo_path]) + client_repo.clone() + return client_repo def _create_last_validated_commit(client_dir, client_auth_repo_head_sha): - client_conf_repo = client_dir / 'organization/_auth_repo' - client_conf_repo.mkdir(parents=True, exist_ok=True) - with open(str(client_conf_repo/'last_validated_commit'), 'w') as f: - f.write(client_auth_repo_head_sha) - - -def _update_and_check_commit_shas(client_repos, repositories, origin_dir, client_dir, - authetnicate_test_repo=False): - if client_repos is not None: - start_head_shas = {repo_rel_path: repo.head_commit_sha() - for repo_rel_path, repo in client_repos.items()} - else: - start_head_shas = {repo_rel_path: None for repo_rel_path in repositories} - - clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH - origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH] - update_repository(str(origin_auth_repo_path), str(clients_auth_repo_path), str(client_dir), True, - authenticate_test_repo=authetnicate_test_repo) - _check_if_commits_match(repositories, origin_dir, client_dir, start_head_shas) - _check_last_validated_commit(clients_auth_repo_path) - - -def _update_invalid_repos_and_check_if_remained_same(client_repos, client_dir, repositories, - expected_error, authenticate_test_repo=False): - - start_head_shas = {repo_rel_path: repo.head_commit_sha() - for repo_rel_path, repo in client_repos.items()} - clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH - origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH] - - with pytest.raises(UpdateFailedError, match=expected_error): - update_repository(str(origin_auth_repo_path), str( - clients_auth_repo_path), str(client_dir), True, authenticate_test_repo=authenticate_test_repo) - - # all 
repositories should still have the same head commit
-    for repo_path, repo in client_repos.items():
-        current_head = repo.head_commit_sha()
-        assert current_head == start_head_shas[repo_path]
-
-
-def _update_invalid_repos_and_check_if_repos_exist(client_dir, repositories, expected_error,
-                                                   authenticate_test_repo=False):
-
-    clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH
-    origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH]
-    with pytest.raises(UpdateFailedError, match=expected_error):
-        update_repository(str(origin_auth_repo_path), str(
-            clients_auth_repo_path), str(client_dir), True, authenticate_test_repo=authenticate_test_repo),
-
-    # the client repositories should not exits
-    for repository_rel_path in repositories:
-        path = client_dir / repository_rel_path
-        assert path.exists() is False
+    client_conf_repo = client_dir / "organization/_auth_repo"
+    client_conf_repo.mkdir(parents=True, exist_ok=True)
+    with open(str(client_conf_repo / "last_validated_commit"), "w") as f:
+        f.write(client_auth_repo_head_sha)
+
+
+def _update_and_check_commit_shas(
+    client_repos, repositories, origin_dir, client_dir, authenticate_test_repo=False
+):
+    if client_repos is not None:
+        start_head_shas = {
+            repo_rel_path: repo.head_commit_sha()
+            for repo_rel_path, repo in client_repos.items()
+        }
+    else:
+        start_head_shas = {repo_rel_path: None for repo_rel_path in repositories}
+
+    clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH
+    origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH]
+    update_repository(
+        str(origin_auth_repo_path),
+        str(clients_auth_repo_path),
+        str(client_dir),
+        True,
+        authenticate_test_repo=authenticate_test_repo,
+    )
+    _check_if_commits_match(repositories, origin_dir, client_dir, start_head_shas)
+    _check_last_validated_commit(clients_auth_repo_path)
+
+
+def _update_invalid_repos_and_check_if_remained_same(
+    client_repos, client_dir, repositories, expected_error, authenticate_test_repo=False
+):
+
+    start_head_shas = {
+        repo_rel_path: repo.head_commit_sha()
+        for repo_rel_path, repo in client_repos.items()
+    }
+    clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH
+    origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH]
+
+    with pytest.raises(UpdateFailedError, match=expected_error):
+        update_repository(
+            str(origin_auth_repo_path),
+            str(clients_auth_repo_path),
+            str(client_dir),
+            True,
+            authenticate_test_repo=authenticate_test_repo,
+        )
+
+    # all repositories should still have the same head commit
+    for repo_path, repo in client_repos.items():
+        current_head = repo.head_commit_sha()
+        assert current_head == start_head_shas[repo_path]
+
+
+def _update_invalid_repos_and_check_if_repos_exist(
+    client_dir, repositories, expected_error, authenticate_test_repo=False
+):
+
+    clients_auth_repo_path = client_dir / AUTH_REPO_REL_PATH
+    origin_auth_repo_path = repositories[AUTH_REPO_REL_PATH]
+    with pytest.raises(UpdateFailedError, match=expected_error):
+        update_repository(
+            str(origin_auth_repo_path),
+            str(clients_auth_repo_path),
+            str(client_dir),
+            True,
+            authenticate_test_repo=authenticate_test_repo,
+        )
+
+    # the client repositories should not exist
+    for repository_rel_path in repositories:
+        path = client_dir / repository_rel_path
+        assert path.exists() is False
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 52f339e81..e91d2376e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,17 +3,17 @@


 def test_normalize_line_ending_extra_lines():
-    test_content = b'''This is some text followed by two new lines
+    
test_content = b"""This is some text followed by two new lines -''' - expected_content = b'This is some text followed by two new lines' - replaced_content = normalize_line_endings(test_content) - assert replaced_content == expected_content +""" + expected_content = b"This is some text followed by two new lines" + replaced_content = normalize_line_endings(test_content) + assert replaced_content == expected_content def test_normalize_line_ending_no_new_line(): - test_content = b'This is some text without new line at the end of the file' - expected_content = test_content - replaced_content = normalize_line_endings(test_content) - assert replaced_content == expected_content + test_content = b"This is some text without new line at the end of the file" + expected_content = test_content + replaced_content = normalize_line_endings(test_content) + assert replaced_content == expected_content diff --git a/tests/test_yubikey.py b/tests/test_yubikey.py index eb2dde4a2..fe9d5ac7b 100644 --- a/tests/test_yubikey.py +++ b/tests/test_yubikey.py @@ -1,43 +1,49 @@ import pytest -from taf.yubikey import (DEFAULT_PIN, export_piv_pub_key, export_piv_x509, - get_serial_num, is_inserted, sign_piv_rsa_pkcs1v15) +from taf.yubikey import ( + DEFAULT_PIN, + export_piv_pub_key, + export_piv_x509, + get_serial_num, + is_inserted, + sign_piv_rsa_pkcs1v15, +) from . import TEST_WITH_REAL_YK @pytest.mark.skipif(not TEST_WITH_REAL_YK, reason="list_devices() is not mocked.") def test_is_inserted(): - assert is_inserted() == True + assert is_inserted() is True def test_serial_num(): - assert get_serial_num() is not None + assert get_serial_num() is not None def test_export_piv_x509(): - x509_pem = export_piv_x509() - assert isinstance(x509_pem, bytes) + x509_pem = export_piv_x509() + assert isinstance(x509_pem, bytes) def test_export_piv_pub_key(): - pub_key_pem = export_piv_pub_key() - assert isinstance(pub_key_pem, bytes) + pub_key_pem = export_piv_pub_key() + assert isinstance(pub_key_pem, bytes) def test_sign_piv_rsa_pkcs1v15(targets_yk): - targets_yk.insert() - # yubikey-manager only supports rsa-pkcs1v15-sha256 signature scheme - # so skip test otherwise - if targets_yk.scheme == 'rsassa-pss-sha256': - pytest.skip() + targets_yk.insert() + # yubikey-manager only supports rsa-pkcs1v15-sha256 signature scheme + # so skip test otherwise + if targets_yk.scheme == "rsassa-pss-sha256": + pytest.skip() - from securesystemslib.pyca_crypto_keys import verify_rsa_signature + from securesystemslib.pyca_crypto_keys import verify_rsa_signature - message = b'Message to be signed.' - scheme = 'rsa-pkcs1v15-sha256' + message = b"Message to be signed." 
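# Editor's note (hedged): with the fake YubiKey from tests/yubikey_utils.py,
# sign_piv_rsa_pkcs1v15 ends up delegating to securesystemslib's
# create_rsa_signature over the keystore key pair, so verify_rsa_signature
# with the matching "rsa-pkcs1v15-sha256" scheme string below should
# round-trip the signed message.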
+ scheme = "rsa-pkcs1v15-sha256" - pub_key_pem = export_piv_pub_key().decode('utf-8') - signature = sign_piv_rsa_pkcs1v15(message, DEFAULT_PIN) + pub_key_pem = export_piv_pub_key().decode("utf-8") + signature = sign_piv_rsa_pkcs1v15(message, DEFAULT_PIN) - assert verify_rsa_signature(signature, scheme, pub_key_pem, message) == True + assert verify_rsa_signature(signature, scheme, pub_key_pem, message) is True diff --git a/tests/yubikey_utils.py b/tests/yubikey_utils.py index 0448e0ddb..665aacb59 100644 --- a/tests/yubikey_utils.py +++ b/tests/yubikey_utils.py @@ -11,142 +11,144 @@ from ykman.descriptor import FailedOpeningDeviceException from ykman.piv import WrongPin -VALID_PIN = '123456' -WRONG_PIN = '111111' +VALID_PIN = "123456" +WRONG_PIN = "111111" INSERTED_YUBIKEY = None class FakeYubiKey: - def __init__(self, priv_key_path, pub_key_path, scheme, serial=None, pin=VALID_PIN): - self.priv_key_pem = priv_key_path.read_bytes() - self.pub_key_pem = pub_key_path.read_bytes() + def __init__(self, priv_key_path, pub_key_path, scheme, serial=None, pin=VALID_PIN): + self.priv_key_pem = priv_key_path.read_bytes() + self.pub_key_pem = pub_key_path.read_bytes() - self._serial = serial if serial else random.randint(100000, 999999) - self._pin = pin + self._serial = serial if serial else random.randint(100000, 999999) + self._pin = pin - self.scheme = scheme - self.priv_key = serialization.load_pem_private_key(self.priv_key_pem, None, - default_backend()) - self.pub_key = serialization.load_pem_public_key(self.pub_key_pem, - default_backend()) + self.scheme = scheme + self.priv_key = serialization.load_pem_private_key( + self.priv_key_pem, None, default_backend() + ) + self.pub_key = serialization.load_pem_public_key( + self.pub_key_pem, default_backend() + ) - self.tuf_key = import_rsakey_from_pem(self.pub_key_pem.decode('utf-8'), - scheme) + self.tuf_key = import_rsakey_from_pem(self.pub_key_pem.decode("utf-8"), scheme) - @property - def driver(self): - return self + @property + def driver(self): + return self - @property - def pin(self): - return self._pin + @property + def pin(self): + return self._pin - @property - def serial(self): - return self._serial + @property + def serial(self): + return self._serial - def insert(self): - """Insert YubiKey in USB slot.""" - global INSERTED_YUBIKEY - INSERTED_YUBIKEY = self + def insert(self): + """Insert YubiKey in USB slot.""" + global INSERTED_YUBIKEY + INSERTED_YUBIKEY = self - def is_inserted(self): - """Check if YubiKey is in USB slot.""" - global INSERTED_YUBIKEY - return INSERTED_YUBIKEY is self + def is_inserted(self): + """Check if YubiKey is in USB slot.""" + global INSERTED_YUBIKEY + return INSERTED_YUBIKEY is self - def remove(self): - """Removes YubiKey from USB slot.""" - global INSERTED_YUBIKEY - if INSERTED_YUBIKEY is self: - INSERTED_YUBIKEY = None + def remove(self): + """Removes YubiKey from USB slot.""" + global INSERTED_YUBIKEY + if INSERTED_YUBIKEY is self: + INSERTED_YUBIKEY = None class FakePivController: - def __init__(self, driver): - self._driver = driver + def __init__(self, driver): + self._driver = driver - @property - def driver(self): - return None + @property + def driver(self): + return None - def authenticate(self, *args, **kwargs): - pass + def authenticate(self, *args, **kwargs): + pass - def change_pin(self, *args, **kwargs): - pass + def change_pin(self, *args, **kwargs): + pass - def change_puk(self, *args, **kwargs): - pass + def change_puk(self, *args, **kwargs): + pass - def 
generate_self_signed_certificate(self, *args, **kwargs): - pass + def generate_self_signed_certificate(self, *args, **kwargs): + pass - def read_certificate(self, _slot): - name = x509.Name([ - x509.NameAttribute(x509.NameOID.COMMON_NAME, self.__class__.__name__) - ]) - now = datetime.datetime.utcnow() + def read_certificate(self, _slot): + name = x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, self.__class__.__name__)] + ) + now = datetime.datetime.utcnow() - return ( - x509.CertificateBuilder() - .subject_name(name) - .issuer_name(name) - .public_key(self._driver.pub_key) - .serial_number(self._driver.serial) - .not_valid_before(now) - .not_valid_after(now + datetime.timedelta(days=365)) - .sign(self._driver.priv_key, hashes.SHA256(), default_backend()) - ) + return ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) + .public_key(self._driver.pub_key) + .serial_number(self._driver.serial) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=365)) + .sign(self._driver.priv_key, hashes.SHA256(), default_backend()) + ) - def reset(self): - pass + def reset(self): + pass - def set_pin_retries(self, *args, **kwargs): - pass + def set_pin_retries(self, *args, **kwargs): + pass - def sign(self, slot, algorithm, data): - """Sign data using the same function as TUF""" - if isinstance(data, str): - data = data.encode('utf-8') + def sign(self, slot, algorithm, data): + """Sign data using the same function as TUF""" + if isinstance(data, str): + data = data.encode("utf-8") - sig, _ = create_rsa_signature(self._driver.priv_key_pem.decode('utf-8'), - data, self._driver.scheme) - return sig + sig, _ = create_rsa_signature( + self._driver.priv_key_pem.decode("utf-8"), data, self._driver.scheme + ) + return sig - def verify(self, pin): - if self._driver.pin != pin: - raise WrongPin("", "") + def verify(self, pin): + if self._driver.pin != pin: + raise WrongPin("", "") class TargetYubiKey(FakeYubiKey): - def __init__(self, keystore_path, scheme): - super().__init__(keystore_path / 'targets', keystore_path / 'targets.pub', - scheme) + def __init__(self, keystore_path, scheme): + super().__init__( + keystore_path / "targets", keystore_path / "targets.pub", scheme + ) class Root1YubiKey(FakeYubiKey): - def __init__(self, keystore_path, scheme): - super().__init__(keystore_path / 'root1', keystore_path / 'root1.pub', - scheme) + def __init__(self, keystore_path, scheme): + super().__init__(keystore_path / "root1", keystore_path / "root1.pub", scheme) class Root2YubiKey(FakeYubiKey): - def __init__(self, keystore_path, scheme): - super().__init__(keystore_path / 'root2', keystore_path / 'root2.pub', scheme) + def __init__(self, keystore_path, scheme): + super().__init__(keystore_path / "root2", keystore_path / "root2.pub", scheme) class Root3YubiKey(FakeYubiKey): - def __init__(self, keystore_path, scheme): - super().__init__(keystore_path / 'root3', keystore_path / 'root3.pub', scheme) + def __init__(self, keystore_path, scheme): + super().__init__(keystore_path / "root3", keystore_path / "root3.pub", scheme) @contextmanager def _yk_piv_ctrl_mock(serial=None, pub_key_pem=None): - global INSERTED_YUBIKEY + global INSERTED_YUBIKEY - if INSERTED_YUBIKEY is None: - raise FailedOpeningDeviceException() + if INSERTED_YUBIKEY is None: + raise FailedOpeningDeviceException() - yield FakePivController(INSERTED_YUBIKEY), INSERTED_YUBIKEY.serial + yield FakePivController(INSERTED_YUBIKEY), INSERTED_YUBIKEY.serial
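
The fake-YubiKey helpers above are easiest to follow with a small usage sketch. This is hedged: the keystore path is hypothetical, and _yk_piv_ctrl_mock stands in for the PIV-controller context manager that the tests patch into taf.yubikey.

from pathlib import Path

# Hypothetical keystore directory containing 'targets' and 'targets.pub' PEM files.
yk = TargetYubiKey(Path("tests/data/keystore"), "rsa-pkcs1v15-sha256")
yk.insert()  # registers this fake key as the globally inserted one

with _yk_piv_ctrl_mock() as (ctrl, serial):
    ctrl.verify(VALID_PIN)  # raises WrongPin for any other value
    signature = ctrl.sign(None, None, b"payload")  # slot/algorithm are ignored

yk.remove()  # afterwards _yk_piv_ctrl_mock raises FailedOpeningDeviceException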