diff --git a/.evergreen/auth_aws/aws_tester.py b/.evergreen/auth_aws/aws_tester.py
index a500b587..43923de2 100644
--- a/.evergreen/auth_aws/aws_tester.py
+++ b/.evergreen/auth_aws/aws_tester.py
@@ -1,6 +1,7 @@
 """
 Script for testing MONGODB-AWS authentication.
 """
+
 import argparse
 import json
 import os
@@ -14,11 +15,12 @@ HERE = os.path.abspath(os.path.dirname(__file__))
 
+
 def join(*parts):
-    return os.path.join(*parts).replace(os.sep, '/')
+    return os.path.join(*parts).replace(os.sep, "/")
 
-sys.path.insert(0, join(HERE, 'lib'))
+sys.path.insert(0, join(HERE, "lib"))
 
 from aws_assign_instance_profile import _assign_instance_policy
 from aws_assume_role import _assume_role
 from aws_assume_web_role import _assume_role_with_web_identity
@@ -32,7 +34,7 @@ def join(*parts):
 
 _USE_AWS_SECRETS = False
 try:
-    with open(join(HERE, 'aws_e2e_setup.json')) as fid:
+    with open(join(HERE, "aws_e2e_setup.json")) as fid:
         CONFIG = json.load(fid)
         get_key = partial(_get_key, uppercase=False)
 except FileNotFoundError:
@@ -48,36 +50,41 @@ def run(args, env):
 
 def create_user(user, kwargs):
     """Create a user and verify access."""
-    print('Creating user', user)
+    print("Creating user", user)
     client = MongoClient(username="bob", password="pwd123")
-    db = client['$external']
+    db = client["$external"]
     try:
         db.command(dict(createUser=user, roles=[{"role": "read", "db": "aws"}]))
     except OperationFailure as e:
-        if "already exists" not in e.details['errmsg']:
+        if "already exists" not in e.details["errmsg"]:
             raise
     client.close()
 
     # Verify access.
-    client = MongoClient(authMechanism='MONGODB-AWS', **kwargs)
-    client.aws.command('count', 'test')
+    client = MongoClient(authMechanism="MONGODB-AWS", **kwargs)
+    client.aws.command("count", "test")
     client.close()
 
 
 def setup_assume_role():
     # Assume the role to get temp creds.
-    os.environ['AWS_ACCESS_KEY_ID'] = CONFIG[get_key("iam_auth_assume_aws_account")]
-    os.environ['AWS_SECRET_ACCESS_KEY'] = CONFIG[get_key("iam_auth_assume_aws_secret_access_key")]
+    os.environ["AWS_ACCESS_KEY_ID"] = CONFIG[get_key("iam_auth_assume_aws_account")]
+    os.environ["AWS_SECRET_ACCESS_KEY"] = CONFIG[
+        get_key("iam_auth_assume_aws_secret_access_key")
+    ]
 
     role_name = CONFIG[get_key("iam_auth_assume_role_name")]
     creds = _assume_role(role_name, quiet=True)
-    with open(join(HERE, 'creds.json'), 'w') as fid:
+    with open(join(HERE, "creds.json"), "w") as fid:
         json.dump(creds, fid)
 
     # Create the user.
-    token = quote_plus(creds['SessionToken'])
-    kwargs = dict(username=creds["AccessKeyId"], password=creds["SecretAccessKey"],
-                  authmechanismproperties=f"AWS_SESSION_TOKEN:{token}")
+    token = quote_plus(creds["SessionToken"])
+    kwargs = dict(
+        username=creds["AccessKeyId"],
+        password=creds["SecretAccessKey"],
+        authmechanismproperties=f"AWS_SESSION_TOKEN:{token}",
+    )
     create_user(ASSUMED_ROLE, kwargs)
 
@@ -91,63 +98,77 @@ def setup_ec2():
 
 def setup_ecs():
     # Set up commands.
-    mongo_binaries = os.environ['MONGODB_BINARIES']
-    project_dir = os.environ['PROJECT_DIRECTORY']
+    mongo_binaries = os.environ["MONGODB_BINARIES"]
+    project_dir = os.environ["PROJECT_DIRECTORY"]
     base_command = f"{sys.executable} -u lib/container_tester.py"
     run_prune_command = f"{base_command} -v remote_gc_services --cluster {CONFIG[get_key('iam_auth_ecs_cluster')]}"
 
     # Get the appropriate task definition based on the version of Ubuntu.
-    with open('/etc/lsb-release') as fid:
+    with open("/etc/lsb-release") as fid:
         text = fid.read()
-    if 'jammy' in text:
-        task_definition = CONFIG.get(get_key('iam_auth_ecs_task_definition_jammy'), None)
+    if "jammy" in text:
+        task_definition = CONFIG.get(
+            get_key("iam_auth_ecs_task_definition_jammy"), None
+        )
         if task_definition is None:
             raise ValueError('Please set "iam_auth_ecs_task_definition_jammy" variable')
-    elif 'focal' in text:
-        task_definition = CONFIG.get(get_key('iam_auth_ecs_task_definition_focal'), None)
+    elif "focal" in text:
+        task_definition = CONFIG.get(
+            get_key("iam_auth_ecs_task_definition_focal"), None
+        )
        # Fall back to previous task definition for backward compat.
         if task_definition is None:
-            task_definition = CONFIG[get_key('iam_auth_ecs_task_definition')]
+            task_definition = CONFIG[get_key("iam_auth_ecs_task_definition")]
     else:
-        raise ValueError('Unsupported ubuntu release')
+        raise ValueError("Unsupported ubuntu release")
     run_test_command = f"{base_command} -d -v run_e2e_test --cluster {CONFIG[get_key('iam_auth_ecs_cluster')]} --task_definition {task_definition} --subnets {CONFIG[get_key('iam_auth_ecs_subnet_a')]} --subnets {CONFIG[get_key('iam_auth_ecs_subnet_b')]} --security_group {CONFIG[get_key('iam_auth_ecs_security_group')]} --files {mongo_binaries}/mongod:/root/mongod {mongo_binaries}/mongosh:/root/mongosh lib/ecs_hosted_test.js:/root/ecs_hosted_test.js {project_dir}:/root --script lib/ecs_hosted_test.sh"
 
     # Pass in the AWS credentials as environment variables
     # AWS_SHARED_CREDENTIALS_FILE does not work in evergreen for an unknown
     # reason
-    env = dict(AWS_ACCESS_KEY_ID=CONFIG[get_key('iam_auth_ecs_account')],
-               AWS_SECRET_ACCESS_KEY=CONFIG[get_key('iam_auth_ecs_secret_access_key')])
+    env = dict(
+        AWS_ACCESS_KEY_ID=CONFIG[get_key("iam_auth_ecs_account")],
+        AWS_SECRET_ACCESS_KEY=CONFIG[get_key("iam_auth_ecs_secret_access_key")],
+    )
 
     # Prune other containers
-    subprocess.check_call(['/bin/sh', '-c', run_prune_command], env=env)
+    subprocess.check_call(["/bin/sh", "-c", run_prune_command], env=env)
 
     # Run the test in a container
-    subprocess.check_call(['/bin/sh', '-c', run_test_command], env=env)
+    subprocess.check_call(["/bin/sh", "-c", run_test_command], env=env)
 
 
 def setup_regular():
     # Create the user.
     kwargs = dict(
         username=CONFIG[get_key("iam_auth_ecs_account")],
-        password=CONFIG[get_key("iam_auth_ecs_secret_access_key")]
+        password=CONFIG[get_key("iam_auth_ecs_secret_access_key")],
     )
     create_user(CONFIG[get_key("iam_auth_ecs_account_arn")], kwargs)
 
 
 def setup_web_identity():
     # Unassign the instance profile.
-    env = dict(AWS_ACCESS_KEY_ID=CONFIG[get_key("iam_auth_ec2_instance_account")],
-               AWS_SECRET_ACCESS_KEY=CONFIG[get_key("iam_auth_ec2_instance_secret_access_key")])
-    ret = run(['lib/aws_unassign_instance_profile.py'], env)
+    env = dict(
+        AWS_ACCESS_KEY_ID=CONFIG[get_key("iam_auth_ec2_instance_account")],
+        AWS_SECRET_ACCESS_KEY=CONFIG[
+            get_key("iam_auth_ec2_instance_secret_access_key")
+        ],
+    )
+    ret = run(["lib/aws_unassign_instance_profile.py"], env)
     if ret == 2:
         raise RuntimeError("Request limit exceeded for AWS API")
     if ret != 0:
-        print('ret was', ret)
-        raise RuntimeError("Failed to unassign an instance profile from the current machine")
+        print("ret was", ret)
+        raise RuntimeError(
+            "Failed to unassign an instance profile from the current machine"
+        )
 
-    token_file = os.environ.get('AWS_WEB_IDENTITY_TOKEN_FILE', CONFIG[get_key('iam_web_identity_token_file')])
-    if os.name == "nt" and token_file.startswith('/tmp'):
+    token_file = os.environ.get(
+        "AWS_WEB_IDENTITY_TOKEN_FILE", CONFIG[get_key("iam_web_identity_token_file")]
+    )
+    if os.name == "nt" and token_file.startswith("/tmp"):
         token_file = token_file.replace("/tmp", "C:/cygwin/tmp/")
 
     # Handle the OIDC credentials.
@@ -155,50 +176,53 @@ def setup_web_identity():
         IDP_ISSUER=CONFIG[get_key("iam_web_identity_issuer")],
         IDP_JWKS_URI=CONFIG[get_key("iam_web_identity_jwks_uri")],
         IDP_RSA_KEY=CONFIG[get_key("iam_web_identity_rsa_key")],
-        AWS_WEB_IDENTITY_TOKEN_FILE=token_file
+        AWS_WEB_IDENTITY_TOKEN_FILE=token_file,
     )
 
-    ret = run(['lib/aws_handle_oidc_creds.py', 'token'], env)
+    ret = run(["lib/aws_handle_oidc_creds.py", "token"], env)
     if ret != 0:
         raise RuntimeWarning("Failed to write the web token")
 
     # Assume the web role to get temp credentials.
-    os.environ['AWS_WEB_IDENTITY_TOKEN_FILE'] = token_file
-    os.environ['AWS_ROLE_ARN'] = CONFIG[get_key("iam_auth_assume_web_role_name")]
+    os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = token_file
+    os.environ["AWS_ROLE_ARN"] = CONFIG[get_key("iam_auth_assume_web_role_name")]
 
     creds = _assume_role_with_web_identity(True)
-    with open(join(HERE, 'creds.json'), 'w') as fid:
+    with open(join(HERE, "creds.json"), "w") as fid:
         json.dump(creds, fid)
 
     # Create the user.
-    token = quote_plus(creds['SessionToken'])
-    kwargs = dict(username=creds["AccessKeyId"], password=creds["SecretAccessKey"],
-                  authmechanismproperties=f"AWS_SESSION_TOKEN:{token}")
+    token = quote_plus(creds["SessionToken"])
+    kwargs = dict(
+        username=creds["AccessKeyId"],
+        password=creds["SecretAccessKey"],
+        authmechanismproperties=f"AWS_SESSION_TOKEN:{token}",
+    )
     create_user(ASSUMED_WEB_ROLE, kwargs)
 
 
 def main():
-    parser = argparse.ArgumentParser(description='MONGODB-AWS tester.')
+    parser = argparse.ArgumentParser(description="MONGODB-AWS tester.")
     sub = parser.add_subparsers(title="Tester subcommands", help="sub-command help")
 
-    run_assume_role_cmd = sub.add_parser('assume-role', help='Assume role test')
+    run_assume_role_cmd = sub.add_parser("assume-role", help="Assume role test")
     run_assume_role_cmd.set_defaults(func=setup_assume_role)
 
-    run_ec2_cmd = sub.add_parser('ec2', help='EC2 test')
+    run_ec2_cmd = sub.add_parser("ec2", help="EC2 test")
     run_ec2_cmd.set_defaults(func=setup_ec2)
 
-    run_ecs_cmd = sub.add_parser('ecs', help='ECS test')
+    run_ecs_cmd = sub.add_parser("ecs", help="ECS test")
     run_ecs_cmd.set_defaults(func=setup_ecs)
 
-    run_regular_cmd = sub.add_parser('regular', help='Regular credentials test')
+    run_regular_cmd = sub.add_parser("regular", help="Regular credentials test")
     run_regular_cmd.set_defaults(func=setup_regular)
 
-    run_web_identity_cmd = sub.add_parser('web-identity', help='Web identity test')
+    run_web_identity_cmd = sub.add_parser("web-identity", help="Web identity test")
     run_web_identity_cmd.set_defaults(func=setup_web_identity)
 
     args = parser.parse_args()
     args.func()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/.evergreen/auth_aws/lib/aws_assign_instance_profile.py b/.evergreen/auth_aws/lib/aws_assign_instance_profile.py
index 305b9588..47453f72 100644
--- a/.evergreen/auth_aws/lib/aws_assign_instance_profile.py
+++ b/.evergreen/auth_aws/lib/aws_assign_instance_profile.py
@@ -15,14 +15,18 @@ import botocore
 
 from util import get_key as _get_key
 
-sys.path.insert(1, os.path.join(sys.path[0], '..'))
+sys.path.insert(1, os.path.join(sys.path[0], ".."))
 
 LOGGER = logging.getLogger(__name__)
 
 HERE = os.path.abspath(os.path.dirname(__file__))
 
 
 def _get_local_instance_id():
-    return urllib.request.urlopen('http://169.254.169.254/latest/meta-data/instance-id').read().decode()
+    return (
+        urllib.request.urlopen("http://169.254.169.254/latest/meta-data/instance-id")
+        .read()
+        .decode()
+    )
 
 
 def _has_instance_profile():
@@ -62,7 +66,7 @@ def _wait_instance_profile():
 
 def _handle_config():
     try:
-        with open(os.path.join(HERE, '..', 'aws_e2e_setup.json')) as fid:
+        with open(os.path.join(HERE, "..", "aws_e2e_setup.json")) as fid:
             CONFIG = json.load(fid)
 
         get_key = partial(_get_key, uppercase=False)
@@ -71,13 +75,17 @@ def _handle_config():
         get_key = partial(_get_key, uppercase=True)
 
     try:
-        os.environ.setdefault('AWS_ACCESS_KEY_ID', CONFIG[get_key('iam_auth_ec2_instance_account')])
-        os.environ.setdefault('AWS_SECRET_ACCESS_KEY',
-                              CONFIG[get_key('iam_auth_ec2_instance_secret_access_key')])
-        return CONFIG[get_key('iam_auth_ec2_instance_profile')]
+        os.environ.setdefault(
+            "AWS_ACCESS_KEY_ID", CONFIG[get_key("iam_auth_ec2_instance_account")]
+        )
+        os.environ.setdefault(
+            "AWS_SECRET_ACCESS_KEY",
+            CONFIG[get_key("iam_auth_ec2_instance_secret_access_key")],
+        )
+        return CONFIG[get_key("iam_auth_ec2_instance_profile")]
     except Exception as e:
         print(e)
-        return ''
+        return ""
 
 
 DEFAULT_ARN = _handle_config()
@@ -85,20 +93,23 @@ def _handle_config():
 
 def _assign_instance_policy(iam_instance_arn=DEFAULT_ARN):
     if _has_instance_profile():
-        print("IMPORTANT: Found machine already has instance profile, skipping the assignment")
+        print(
+            "IMPORTANT: Found machine already has instance profile, skipping the assignment"
+        )
         return
 
     instance_id = _get_local_instance_id()
 
-    ec2_client = boto3.client("ec2", 'us-east-1')
+    ec2_client = boto3.client("ec2", "us-east-1")
 
     # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Client.associate_iam_instance_profile
     try:
         response = ec2_client.associate_iam_instance_profile(
             IamInstanceProfile={
-                'Arn': iam_instance_arn,
+                "Arn": iam_instance_arn,
             },
-            InstanceId=instance_id)
+            InstanceId=instance_id,
+        )
 
         print(response)
 
@@ -115,12 +126,21 @@ def _assign_instance_policy(iam_instance_arn=DEFAULT_ARN):
 
 def main() -> None:
     """Execute Main entry point."""
-    parser = argparse.ArgumentParser(description='IAM Assign Instance frontend.')
-
-    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging")
-    parser.add_argument('-d', "--debug", action='store_true', help="Enable debug logging")
-
-    parser.add_argument('--instance_profile_arn', type=str, help="Name of instance profile", default=DEFAULT_ARN)
+    parser = argparse.ArgumentParser(description="IAM Assign Instance frontend.")
+
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="Enable debug logging"
+    )
+
+    parser.add_argument(
+        "--instance_profile_arn",
+        type=str,
+        help="Name of instance profile",
+        default=DEFAULT_ARN,
+    )
 
     args = parser.parse_args()
diff --git a/.evergreen/auth_aws/lib/aws_assume_role.py b/.evergreen/auth_aws/lib/aws_assume_role.py
index 422eb20b..ca76aa57 100644
--- a/.evergreen/auth_aws/lib/aws_assume_role.py
+++ b/.evergreen/auth_aws/lib/aws_assume_role.py
@@ -12,10 +12,13 @@
 STS_DEFAULT_ROLE_NAME = "arn:aws:iam::579766882180:role/mark.benvenuto"
 
+
 def _assume_role(role_name, quiet=False):
     sts_client = boto3.client("sts", region_name="us-east-1")
 
-    response = sts_client.assume_role(RoleArn=role_name, RoleSessionName=str(uuid.uuid4()), DurationSeconds=900)
+    response = sts_client.assume_role(
+        RoleArn=role_name, RoleSessionName=str(uuid.uuid4()), DurationSeconds=900
+    )
 
     creds = response["Credentials"]
     creds["Expiration"] = str(creds["Expiration"])
@@ -32,12 +35,18 @@ def _assume_role(role_name, quiet=False):
 
 def main() -> None:
     """Execute Main entry point."""
-    parser = argparse.ArgumentParser(description='Assume Role frontend.')
+    parser = argparse.ArgumentParser(description="Assume Role frontend.")
 
-    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging")
-    parser.add_argument('-d', "--debug", action='store_true', help="Enable debug logging")
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="Enable debug logging"
+    )
 
-    parser.add_argument('--role_name', type=str, default=STS_DEFAULT_ROLE_NAME, help="Role to assume")
+    parser.add_argument(
+        "--role_name", type=str, default=STS_DEFAULT_ROLE_NAME, help="Role to assume"
+    )
 
     args = parser.parse_args()
diff --git a/.evergreen/auth_aws/lib/aws_assume_web_role.py b/.evergreen/auth_aws/lib/aws_assume_web_role.py
index 9e79a1fd..4703f4b6 100644
--- a/.evergreen/auth_aws/lib/aws_assume_web_role.py
+++ b/.evergreen/auth_aws/lib/aws_assume_web_role.py
@@ -11,15 +11,21 @@
 LOGGER = logging.getLogger(__name__)
 
+
 def _assume_role_with_web_identity(quiet=False):
     sts_client = boto3.client("sts")
 
-    token_file = os.environ['AWS_WEB_IDENTITY_TOKEN_FILE']
+    token_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
     with open(token_file) as fid:
         token = fid.read()
-    role_name = os.environ['AWS_ROLE_ARN']
+    role_name = os.environ["AWS_ROLE_ARN"]
 
-    response = sts_client.assume_role_with_web_identity(RoleArn=role_name, RoleSessionName=str(uuid.uuid4()), WebIdentityToken=token, DurationSeconds=900)
+    response = sts_client.assume_role_with_web_identity(
+        RoleArn=role_name,
+        RoleSessionName=str(uuid.uuid4()),
+        WebIdentityToken=token,
+        DurationSeconds=900,
+    )
 
     creds = response["Credentials"]
     creds["Expiration"] = str(creds["Expiration"])
@@ -39,10 +45,14 @@ def _assume_role_with_web_identity(quiet=False):
 
 def main() -> None:
     """Execute Main entry point."""
-    parser = argparse.ArgumentParser(description='Assume Role frontend.')
+    parser = argparse.ArgumentParser(description="Assume Role frontend.")
 
-    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging")
-    parser.add_argument('-d', "--debug", action='store_true', help="Enable debug logging")
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="Enable debug logging"
+    )
 
     args = parser.parse_args()
diff --git a/.evergreen/auth_aws/lib/aws_handle_oidc_creds.py b/.evergreen/auth_aws/lib/aws_handle_oidc_creds.py
index d719f63b..69ebb885 100644
--- a/.evergreen/auth_aws/lib/aws_handle_oidc_creds.py
+++ b/.evergreen/auth_aws/lib/aws_handle_oidc_creds.py
@@ -1,6 +1,7 @@
 """
 Script for handling OIDC credentials.
 """
+
 import argparse
 import base64
 import os
@@ -18,6 +19,7 @@ class CustomSubjectIdentifierFactory(HashBasedSubjectIdentifierFactory):
     """
     Implements a hash based algorithm for creating a pairwise subject identifier.
     """
+
     def create_public_identifier(self, user_id):
         return user_id
 
@@ -32,13 +34,13 @@ def create_pairwise_identifier(self, user_id, sector_identifier):
 
 def get_default_config():
     return {
-        "issuer": os.getenv('IDP_ISSUER', ''),
-        "jwks_uri": os.getenv('IDP_JWKS_URI', ''),
-        'rsa_key': os.getenv('IDP_RSA_KEY', ''),
-        'client_id': os.getenv("IDP_CLIENT_ID", DEFAULT_CLIENT),
-        'client_secret': os.getenv("IDP_CLIENT_SECRET", uuid.uuid4().hex),
-        'username': os.getenv("IDP_USERNAME", 'test_user'),
-        'token_file': os.getenv('AWS_WEB_IDENTITY_TOKEN_FILE')
+        "issuer": os.getenv("IDP_ISSUER", ""),
+        "jwks_uri": os.getenv("IDP_JWKS_URI", ""),
+        "rsa_key": os.getenv("IDP_RSA_KEY", ""),
+        "client_id": os.getenv("IDP_CLIENT_ID", DEFAULT_CLIENT),
+        "client_secret": os.getenv("IDP_CLIENT_SECRET", uuid.uuid4().hex),
+        "username": os.getenv("IDP_USERNAME", "test_user"),
+        "token_file": os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE"),
     }
 
 
@@ -46,62 +48,79 @@ def get_provider(config=None, expires=None):
     """Get a configured OIDC provider."""
     config = config or get_default_config()
     configuration_information = {
-        'issuer': config['issuer'],
-        'authorization_endpoint': MOCK_ENDPOINT,
-        'jwks_uri': config['jwks_uri'],
-        'token_endpoint': MOCK_ENDPOINT,
-        'userinfo_endpoint': MOCK_ENDPOINT,
-        'registration_endpoint': MOCK_ENDPOINT,
-        'end_session_endpoint': MOCK_ENDPOINT,
-        'scopes_supported': ['openid', 'profile'],
-        'response_types_supported': ['code', 'code id_token', 'code token', 'code id_token token'],  # code and hybrid
-        'response_modes_supported': ['query', 'fragment'],
-        'grant_types_supported': ['authorization_code', 'implicit'],
-        'subject_types_supported': ['public'],
-        'token_endpoint_auth_methods_supported': ['client_secret_basic'],
-        'claims_parameter_supported': True
+        "issuer": config["issuer"],
+        "authorization_endpoint": MOCK_ENDPOINT,
+        "jwks_uri": config["jwks_uri"],
+        "token_endpoint": MOCK_ENDPOINT,
+        "userinfo_endpoint": MOCK_ENDPOINT,
+        "registration_endpoint": MOCK_ENDPOINT,
+        "end_session_endpoint": MOCK_ENDPOINT,
+        "scopes_supported": ["openid", "profile"],
+        "response_types_supported": [
+            "code",
+            "code id_token",
+            "code token",
+            "code id_token token",
+        ],  # code and hybrid
+        "response_modes_supported": ["query", "fragment"],
+        "grant_types_supported": ["authorization_code", "implicit"],
+        "subject_types_supported": ["public"],
+        "token_endpoint_auth_methods_supported": ["client_secret_basic"],
+        "claims_parameter_supported": True,
     }
-    userinfo_db = Userinfo({config['username']: {}})
-    kid = '1549e0aef574d1c7bdd136c202b8d290580b165c'
-    rsa_key = config['rsa_key']
-    if rsa_key.endswith('='):
-        rsa_key = base64.urlsafe_b64decode(rsa_key).decode('utf-8')
-    signing_key = RSAKey(key=import_rsa_key(rsa_key), alg='RS256', use='sig', kid=kid)
+    userinfo_db = Userinfo({config["username"]: {}})
+    kid = "1549e0aef574d1c7bdd136c202b8d290580b165c"
+    rsa_key = config["rsa_key"]
+    if rsa_key.endswith("="):
+        rsa_key = base64.urlsafe_b64decode(rsa_key).decode("utf-8")
+    signing_key = RSAKey(key=import_rsa_key(rsa_key), alg="RS256", use="sig", kid=kid)
     client_info = {
-        'client_id': config['client_id'],
-        'client_id_issued_at': int(time.time()),
-        'client_secret': config['client_secret'],
-        'redirect_uris': [MOCK_ENDPOINT],
-        'response_types': ['code'],
-        'client_secret_expires_at': 0  # never expires
+        "client_id": config["client_id"],
+        "client_id_issued_at": int(time.time()),
+        "client_secret": config["client_secret"],
+        "redirect_uris": [MOCK_ENDPOINT],
+        "response_types": ["code"],
+        "client_secret_expires_at": 0,  # never expires
     }
-    clients = {config['client_id']: client_info}
-    auth_state = AuthorizationState(CustomSubjectIdentifierFactory('salt'))
-    expires = expires or 24*60*60
-    return Provider(signing_key, configuration_information,
-                    auth_state, clients, userinfo_db, id_token_lifetime=expires)
+    clients = {config["client_id"]: client_info}
+    auth_state = AuthorizationState(CustomSubjectIdentifierFactory("salt"))
+    expires = expires or 24 * 60 * 60
+    return Provider(
+        signing_key,
+        configuration_information,
+        auth_state,
+        clients,
+        userinfo_db,
+        id_token_lifetime=expires,
+    )
 
 
 def get_id_token(config=None, expires=None):
     """Get a valid ID token."""
     config = config or get_default_config()
     provider = get_provider(config=config, expires=expires)
-    client_id = config['client_id']
-    client_secret = config['client_secret']
-    response = provider.parse_authentication_request(f'response_type=code&client_id={client_id}&scope=openid&redirect_uri={MOCK_ENDPOINT}')
-    resp = provider.authorize(response, config['username'])
+    client_id = config["client_id"]
+    client_secret = config["client_secret"]
+    response = provider.parse_authentication_request(
+        f"response_type=code&client_id={client_id}&scope=openid&redirect_uri={MOCK_ENDPOINT}"
+    )
+    resp = provider.authorize(response, config["username"])
     code = resp.to_dict()["code"]
-    creds = f'{client_id}:{client_secret}'
-    creds = base64.urlsafe_b64encode(creds.encode('utf-8')).decode('utf-8')
-    headers = dict(Authorization=f'Basic {creds}')
-    extra_claims = {'foo': ['readWrite'], 'bar': ['readWrite'] }
-    response = provider.handle_token_request(f'grant_type=authorization_code&subject_type=public&code={code}&redirect_uri={MOCK_ENDPOINT}', headers, extra_id_token_claims=extra_claims)
+    creds = f"{client_id}:{client_secret}"
+    creds = base64.urlsafe_b64encode(creds.encode("utf-8")).decode("utf-8")
+    headers = dict(Authorization=f"Basic {creds}")
+    extra_claims = {"foo": ["readWrite"], "bar": ["readWrite"]}
+    response = provider.handle_token_request(
+        f"grant_type=authorization_code&subject_type=public&code={code}&redirect_uri={MOCK_ENDPOINT}",
+        headers,
+        extra_id_token_claims=extra_claims,
+    )
     token = response["id_token"]
-    if config['token_file']:
-        with open(config['token_file'], 'w') as fid:
+    if config["token_file"]:
+        with open(config["token_file"], "w") as fid:
             print(f"Writing token file: {config['token_file']}")
             fid.write(token)
     return token
@@ -120,22 +139,26 @@ def get_config_data():
 
 def get_user_id():
     """Get the user id (sub) that will be used for authorization."""
     config = get_default_config()
-    return get_provider(config).authz_state.get_subject_identifier('public', config['username'], "example.com")
+    return get_provider(config).authz_state.get_subject_identifier(
+        "public", config["username"], "example.com"
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(dest='command', help="The command to run (config, jwks, token, user_id)")
+    parser.add_argument(
+        dest="command", help="The command to run (config, jwks, token, user_id)"
+    )
 
     # Parse and print the results
     args = parser.parse_args()
-    if args.command == 'jwks':
-        print(get_jwks_data(), end='')
-    elif args.command == 'config':
-        print(get_config_data(), end='')
-    elif args.command == 'token':
-        print(get_id_token(), end='')
-    elif args.command == 'user_id':
-        print(get_user_id(), end='')
+    if args.command == "jwks":
+        print(get_jwks_data(), end="")
+    elif args.command == "config":
+        print(get_config_data(), end="")
+    elif args.command == "token":
+        print(get_id_token(), end="")
+    elif args.command == "user_id":
+        print(get_user_id(), end="")
     else:
-        raise ValueError('Command must be one of: (config, jwks, token, user_id)')
+        raise ValueError("Command must be one of: (config, jwks, token, user_id)")
diff --git a/.evergreen/auth_aws/lib/aws_unassign_instance_profile.py b/.evergreen/auth_aws/lib/aws_unassign_instance_profile.py
index c299bd10..0091d495 100644
--- a/.evergreen/auth_aws/lib/aws_unassign_instance_profile.py
+++ b/.evergreen/auth_aws/lib/aws_unassign_instance_profile.py
@@ -14,8 +14,16 @@
 LOGGER = logging.getLogger(__name__)
 
+
 def _get_local_instance_id():
-    return urllib.request.urlopen('http://169.254.169.254/latest/meta-data/instance-id', timeout=5).read().decode()
+    return (
+        urllib.request.urlopen(
+            "http://169.254.169.254/latest/meta-data/instance-id", timeout=5
+        )
+        .read()
+        .decode()
+    )
+
 
 def _has_instance_profile():
     base_url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
@@ -40,6 +48,7 @@ def _has_instance_profile():
 
     return True
 
+
 def _wait_no_instance_profile():
     retry = 60
     while _has_instance_profile() and retry:
@@ -49,23 +58,27 @@ def _wait_no_instance_profile():
     if retry == 0:
         raise ValueError("Timeout on waiting for no instance profile")
 
-def _unassign_instance_policy():
+def _unassign_instance_policy():
     try:
         instance_id = _get_local_instance_id()
     except urllib.error.URLError as e:
         print(e)
         sys.exit(0)
 
-    ec2_client = boto3.client("ec2", 'us-east-1')
+    ec2_client = boto3.client("ec2", "us-east-1")
 
-    #https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Client.describe_iam_instance_profile_associations
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html#EC2.Client.describe_iam_instance_profile_associations
     try:
-        response = ec2_client.describe_iam_instance_profile_associations(Filters=[{"Name":"instance-id","Values": [instance_id]}])
-        associations = response['IamInstanceProfileAssociations']
+        response = ec2_client.describe_iam_instance_profile_associations(
+            Filters=[{"Name": "instance-id", "Values": [instance_id]}]
+        )
+        associations = response["IamInstanceProfileAssociations"]
         if associations:
-            print('disassociating')
-            ec2_client.disassociate_iam_instance_profile(AssociationId=associations[0]['AssociationId'])
+            print("disassociating")
+            ec2_client.disassociate_iam_instance_profile(
+                AssociationId=associations[0]["AssociationId"]
+            )
 
         # Wait for the instance profile to be unassigned by polling the local instance metadata service
         _wait_no_instance_profile()
@@ -76,13 +89,18 @@ def _unassign_instance_policy():
             sys.exit(2)
         raise
 
+
 def main() -> None:
     """Execute Main entry point."""
-    parser = argparse.ArgumentParser(description='IAM UnAssign Instance frontend.')
+    parser = argparse.ArgumentParser(description="IAM UnAssign Instance frontend.")
 
-    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging")
-    parser.add_argument('-d', "--debug", action='store_true', help="Enable debug logging")
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="Enable debug logging"
+    )
 
     args = parser.parse_args()
diff --git a/.evergreen/auth_aws/lib/container_tester.py b/.evergreen/auth_aws/lib/container_tester.py
index 7500fc0e..8f8e96b4 100644
--- a/.evergreen/auth_aws/lib/container_tester.py
+++ b/.evergreen/auth_aws/lib/container_tester.py
@@ -24,13 +24,15 @@
 
 # These settings depend on a cluster, task subnets, and security group already setup
 ECS_DEFAULT_CLUSTER = "arn:aws:ecs:us-east-2:579766882180:cluster/tf-mcb-ecs-cluster"
-ECS_DEFAULT_TASK_DEFINITION = "arn:aws:ecs:us-east-2:579766882180:task-definition/tf-app:2"
-ECS_DEFAULT_SUBNETS = ['subnet-a5e114cc']
+ECS_DEFAULT_TASK_DEFINITION = (
+    "arn:aws:ecs:us-east-2:579766882180:task-definition/tf-app:2"
+)
+ECS_DEFAULT_SUBNETS = ["subnet-a5e114cc"]
 # Must allow ssh from 0.0.0.0
-ECS_DEFAULT_SECURITY_GROUP = 'sg-051a91d96332f8f3a'
+ECS_DEFAULT_SECURITY_GROUP = "sg-051a91d96332f8f3a"
 
 # This is just a string local to this file
-DEFAULT_SERVICE_NAME = 'script-test'
+DEFAULT_SERVICE_NAME = "script-test"
 
 # Garbage collection threshold for old/stale services
 DEFAULT_GARBAGE_COLLECTION_THRESHOLD = datetime.timedelta(hours=1)
@@ -43,34 +45,55 @@ def _run_process(params, cwd=None):
     ret = subprocess.run(params, cwd=cwd, check=False)
     return ret.returncode
 
+
 def _userandhostandport(endpoint):
     user_and_host = endpoint.find("@")
     if user_and_host == -1:
         raise ValueError("Invalid endpoint, Endpoint must be user@host:port")
-    (user, host) = (endpoint[:user_and_host], endpoint[user_and_host + 1:])
+    (user, host) = (endpoint[:user_and_host], endpoint[user_and_host + 1 :])
     colon = host.find(":")
     if colon == -1:
         return (user, host, "22")
-    return (user, host[:colon], host[colon + 1:])
+    return (user, host[:colon], host[colon + 1 :])
+
 
 def _scp(endpoint, src, dest):
     (user, host, port) = _userandhostandport(endpoint)
-    cmd = ["scp", "-o", "StrictHostKeyChecking=no", "-P", port, src, "%s@%s:%s" % (user, host, dest)]
+    cmd = [
+        "scp",
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-P",
+        port,
+        src,
+        "%s@%s:%s" % (user, host, dest),
+    ]
     if os.path.isdir(src):
-        cmd.insert(5, "-r")
+        cmd.insert(5, "-r")
     _run_process(cmd)
 
+
 def _ssh(endpoint, cmd):
     (user, host, port) = _userandhostandport(endpoint)
-    cmd = ["ssh", "-o", "StrictHostKeyChecking=no", "-p", port, "%s@%s" % (user, host), cmd ]
+    cmd = [
+        "ssh",
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-p",
+        port,
+        "%s@%s" % (user, host),
+        cmd,
+    ]
     ret = _run_process(cmd)
     LOGGER.info("RETURN CODE: %s", ret)
     return ret
 
+
 def _run_test_args(args):
     run_test(args.endpoint, args.script, args.files)
 
+
 def run_test(endpoint, script, files):
     """
     Run a test on a machine
@@ -84,7 +107,7 @@ def run_test(endpoint, script, files):
 
     for file in files:
         colon = file.find(":")
-        (src, dest) = (file[:colon], file[colon + 1:])
+        (src, dest) = (file[:colon], file[colon + 1 :])
         _scp(endpoint, src, dest)
 
     LOGGER.info("Copying script to %s", endpoint)
@@ -94,100 +117,125 @@ def run_test(endpoint, script, files):
         LOGGER.error("FAILED: %s", return_code)
         raise ValueError(f"test failed with {return_code}")
 
+
 def _get_region(arn):
-    return arn.split(':')[3]
+    return arn.split(":")[3]
 
 
 def _remote_ps_container_args(args):
     remote_ps_container(args.cluster)
 
+
 def remote_ps_container(cluster):
     """
     Get a list of tasks running in the cluster with their network addresses.
 
     Emulates the docker ps and ecs-cli ps commands.
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
-    ec2_client = boto3.client('ec2', region_name=_get_region(cluster))
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
+    ec2_client = boto3.client("ec2", region_name=_get_region(cluster))
 
     tasks = ecs_client.list_tasks(cluster=cluster)
-    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks['taskArns'])
-
-    #Example from ecs-cli tool
-    #Name                                        State    Ports                     TaskDefinition  Health
-    #aa2c2642-3013-4370-885e-8b8d956e753d/sshd   RUNNING  3.15.149.114:22->22/tcp   sshd:1          UNKNOWN
+    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks["taskArns"])
 
-    print("Name                                       State    Public IP                Private IP               TaskDefinition  Health")
-    for task in task_list['tasks']:
+    # Example from ecs-cli tool
+    # Name                                        State    Ports                     TaskDefinition  Health
+    # aa2c2642-3013-4370-885e-8b8d956e753d/sshd   RUNNING  3.15.149.114:22->22/tcp   sshd:1          UNKNOWN
 
-        taskDefinition = task['taskDefinitionArn']
-        taskDefinition_short = taskDefinition[taskDefinition.rfind('/') + 1:]
+    print(
+        "Name                                       State    Public IP                Private IP               TaskDefinition  Health"
+    )
+    for task in task_list["tasks"]:
+        taskDefinition = task["taskDefinitionArn"]
+        taskDefinition_short = taskDefinition[taskDefinition.rfind("/") + 1 :]
 
         private_ip_address = None
         enis = []
-        for b in [ a['details'] for a in task["attachments"] if a['type'] == 'ElasticNetworkInterface']:
+        for b in [
+            a["details"]
+            for a in task["attachments"]
+            if a["type"] == "ElasticNetworkInterface"
+        ]:
             for c in b:
-                if c['name'] == 'networkInterfaceId':
-                    enis.append(c['value'])
-                elif c['name'] == 'privateIPv4Address':
-                    private_ip_address = c['value']
+                if c["name"] == "networkInterfaceId":
+                    enis.append(c["value"])
+                elif c["name"] == "privateIPv4Address":
+                    private_ip_address = c["value"]
 
         assert enis
         assert private_ip_address
 
         eni = ec2_client.describe_network_interfaces(NetworkInterfaceIds=enis)
-        public_ip = next(iter(n["Association"]["PublicIp"] for n in eni["NetworkInterfaces"]))
+        public_ip = next(
+            iter(n["Association"]["PublicIp"] for n in eni["NetworkInterfaces"])
+        )
 
-        for container in task['containers']:
-            taskArn = container['taskArn']
-            task_id = taskArn[taskArn.rfind('/')+ 1:]
-            name = container['name']
+        for container in task["containers"]:
+            taskArn = container["taskArn"]
+            task_id = taskArn[taskArn.rfind("/") + 1 :]
+            name = container["name"]
             task_id = task_id + "/" + name
-            lastStatus = container['lastStatus']
+            lastStatus = container["lastStatus"]
+
+            print(
+                f"{task_id:<43}{lastStatus:<9}{public_ip:<25}{private_ip_address:<25}{taskDefinition_short:<16}"
+            )
 
-            print(f"{task_id:<43}{lastStatus:<9}{public_ip:<25}{private_ip_address:<25}{taskDefinition_short:<16}")
 
 def _remote_create_container_args(args):
-    remote_create_container(args.cluster, args.task_definition, args.service, args.subnets, args.security_group)
-
-def remote_create_container(cluster, task_definition, service_name, subnets, security_group):
+    remote_create_container(
+        args.cluster,
+        args.task_definition,
+        args.service,
+        args.subnets,
+        args.security_group,
+    )
+
+
+def remote_create_container(
+    cluster, task_definition, service_name, subnets, security_group
+):
     """
     Create a task in ECS
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
-
-    resp = ecs_client.create_service(cluster=cluster, serviceName=service_name,
-        taskDefinition = task_definition,
-        desiredCount = 1,
-        launchType='FARGATE',
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
+
+    resp = ecs_client.create_service(
+        cluster=cluster,
+        serviceName=service_name,
+        taskDefinition=task_definition,
+        desiredCount=1,
+        launchType="FARGATE",
         networkConfiguration={
-            'awsvpcConfiguration': {
-                'subnets': subnets,
-                'securityGroups': [
+            "awsvpcConfiguration": {
+                "subnets": subnets,
+                "securityGroups": [
                     security_group,
                 ],
-                'assignPublicIp': "ENABLED"
+                "assignPublicIp": "ENABLED",
             }
-        }
-        )
+        },
+    )
 
     pprint.pprint(resp)
 
     service_arn = resp["service"]["serviceArn"]
 
     print(f"Waiting for Service {service_arn} to become active...")
-    waiter = ecs_client.get_waiter('services_stable')
+    waiter = ecs_client.get_waiter("services_stable")
     waiter.wait(cluster=cluster, services=[service_arn])
 
+
 def _remote_stop_container_args(args):
     remote_stop_container(args.cluster, args.service)
 
+
 def remote_stop_container(cluster, service_name):
     """
     Stop an ECS task
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
 
     resp = ecs_client.delete_service(cluster=cluster, service=service_name, force=True)
     pprint.pprint(resp)
@@ -195,99 +243,129 @@ def remote_stop_container(cluster, service_name):
     service_arn = resp["service"]["serviceArn"]
 
     print(f"Waiting for Service {service_arn} to become inactive...")
-    waiter = ecs_client.get_waiter('services_inactive')
+    waiter = ecs_client.get_waiter("services_inactive")
     waiter.wait(cluster=cluster, services=[service_arn])
 
+
 def _remote_gc_services_container_args(args):
     remote_gc_services_container(args.cluster)
 
+
 def remote_gc_services_container(cluster):
     """
     Delete all ECS services older than a given threshold.
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
 
     services = ecs_client.list_services(cluster=cluster)
     if not services["serviceArns"]:
         return
-    services_details = ecs_client.describe_services(cluster=cluster, services=services["serviceArns"])
+    services_details = ecs_client.describe_services(
+        cluster=cluster, services=services["serviceArns"]
+    )
 
-    not_expired_now = datetime.datetime.now().astimezone() - DEFAULT_GARBAGE_COLLECTION_THRESHOLD
+    not_expired_now = (
+        datetime.datetime.now().astimezone() - DEFAULT_GARBAGE_COLLECTION_THRESHOLD
+    )
     for service in services_details["services"]:
         created_at = service["createdAt"]
 
         # Find the services that we created "too" long ago
         if created_at < not_expired_now:
-            print("DELETING expired service %s which was created at %s." % (service["serviceName"], created_at))
+            print(
+                "DELETING expired service %s which was created at %s."
+                % (service["serviceName"], created_at)
+            )
            remote_stop_container(cluster, service["serviceName"])
 
+
 def remote_get_public_endpoint_str(cluster, service_name):
     """
     Get an SSH connection string for the remote service via the public ip address
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
-    ec2_client = boto3.client('ec2', region_name=_get_region(cluster))
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
+    ec2_client = boto3.client("ec2", region_name=_get_region(cluster))
 
     tasks = ecs_client.list_tasks(cluster=cluster, serviceName=service_name)
-    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks['taskArns'])
-
-    for task in task_list['tasks']:
+    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks["taskArns"])
 
+    for task in task_list["tasks"]:
         enis = []
-        for b in [ a['details'] for a in task["attachments"] if a['type'] == 'ElasticNetworkInterface']:
+        for b in [
+            a["details"]
+            for a in task["attachments"]
+            if a["type"] == "ElasticNetworkInterface"
+        ]:
             for c in b:
-                if c['name'] == 'networkInterfaceId':
-                    enis.append(c['value'])
+                if c["name"] == "networkInterfaceId":
+                    enis.append(c["value"])
 
         assert enis
 
         eni = ec2_client.describe_network_interfaces(NetworkInterfaceIds=enis)
-        public_ip = next(iter(n["Association"]["PublicIp"] for n in eni["NetworkInterfaces"]))
+        public_ip = next(
+            iter(n["Association"]["PublicIp"] for n in eni["NetworkInterfaces"])
+        )
 
         break
 
     return f"root@{public_ip}:22"
 
+
 def remote_get_endpoint_str(cluster, service_name):
     """
     Get an SSH connection string for the remote service via the private ip address
     """
-    ecs_client = boto3.client('ecs', region_name=_get_region(cluster))
+    ecs_client = boto3.client("ecs", region_name=_get_region(cluster))
 
     tasks = ecs_client.list_tasks(cluster=cluster, serviceName=service_name)
-    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks['taskArns'])
-
-    for task in task_list['tasks']:
+    task_list = ecs_client.describe_tasks(cluster=cluster, tasks=tasks["taskArns"])
 
+    for task in task_list["tasks"]:
         private_ip_address = None
-        for b in [ a['details'] for a in task["attachments"] if a['type'] == 'ElasticNetworkInterface']:
+        for b in [
+            a["details"]
+            for a in task["attachments"]
+            if a["type"] == "ElasticNetworkInterface"
+        ]:
             for c in b:
-                if c['name'] == 'privateIPv4Address':
-                    private_ip_address = c['value']
+                if c["name"] == "privateIPv4Address":
+                    private_ip_address = c["value"]
 
         assert private_ip_address
 
         break
 
     return f"root@{private_ip_address}:22"
 
+
 def _remote_get_endpoint_args(args):
     _remote_get_endpoint(args.cluster, args.service)
 
+
 def _remote_get_endpoint(cluster, service_name):
     endpoint = remote_get_endpoint_str(cluster, service_name)
     print(endpoint)
 
+
 def _get_caller_identity(args):
-    sts_client = boto3.client('sts')
+    sts_client = boto3.client("sts")
 
     pprint.pprint(sts_client.get_caller_identity())
 
 
 def _run_e2e_test_args(args):
-    _run_e2e_test(args.script, args.files, args.cluster, args.task_definition, args.subnets, args.security_group)
+    _run_e2e_test(
+        args.script,
+        args.files,
+        args.cluster,
+        args.task_definition,
+        args.subnets,
+        args.security_group,
+    )
+
 
 def _run_e2e_test(script, files, cluster, task_definition, subnets, security_group):
     """
@@ -299,7 +377,9 @@ def _run_e2e_test(script, files, cluster, task_definition, subnets, security_gro
     """
     service_name = str(uuid.uuid4())
 
-    remote_create_container(cluster, task_definition, service_name, subnets, security_group)
+    remote_create_container(
+        cluster, task_definition, service_name, subnets, security_group
+    )
 
     # The build account hosted ECS tasks are only available via the private ip address
     endpoint = remote_get_endpoint_str(cluster, service_name)
@@ -316,67 +396,150 @@ def _run_e2e_test(script, files, cluster, task_definition, subnets, security_gro
 
 def main() -> None:
     """Execute Main entry point."""
-    parser = argparse.ArgumentParser(description='ECS container tester.')
-
-    parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging")
-    parser.add_argument('-d', "--debug", action='store_true', help="Enable debug logging")
-
-    sub = parser.add_subparsers(title="Container Tester subcommands", help="sub-command help")
-
-    run_test_cmd = sub.add_parser('run_test', help='Run Test')
-    run_test_cmd.add_argument("--endpoint", required=True, type=str, help="User and Host and port, ie user@host:port")
+    parser = argparse.ArgumentParser(description="ECS container tester.")
+
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "-d", "--debug", action="store_true", help="Enable debug logging"
+    )
+
+    sub = parser.add_subparsers(
+        title="Container Tester subcommands", help="sub-command help"
+    )
+
+    run_test_cmd = sub.add_parser("run_test", help="Run Test")
+    run_test_cmd.add_argument(
+        "--endpoint",
+        required=True,
+        type=str,
+        help="User and Host and port, i.e. user@host:port",
+    )
     run_test_cmd.add_argument("--script", required=True, type=str, help="script to run")
-    run_test_cmd.add_argument("--files", type=str, nargs="*", help="Files to copy, each string must be a pair of src:dest joined by a colon")
+    run_test_cmd.add_argument(
+        "--files",
+        type=str,
+        nargs="*",
+        help="Files to copy, each string must be a pair of src:dest joined by a colon",
+    )
     run_test_cmd.set_defaults(func=_run_test_args)
 
-    remote_ps_cmd = sub.add_parser('remote_ps', help='Stop Local Container')
-    remote_ps_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
+    remote_ps_cmd = sub.add_parser("remote_ps", help="List Remote Containers")
+    remote_ps_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
     remote_ps_cmd.set_defaults(func=_remote_ps_container_args)
 
-    remote_create_cmd = sub.add_parser('remote_create', help='Create Remote Container')
-    remote_create_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
-    remote_create_cmd.add_argument("--service", type=str, default=DEFAULT_SERVICE_NAME, help="ECS Service to create")
-    remote_create_cmd.add_argument("--task_definition", type=str, default=ECS_DEFAULT_TASK_DEFINITION, help="ECS Task Definition to use to create service")
-    remote_create_cmd.add_argument("--subnets", type=str, nargs="*", default=ECS_DEFAULT_SUBNETS, help="EC2 subnets to use")
-    remote_create_cmd.add_argument("--security_group", type=str, default=ECS_DEFAULT_SECURITY_GROUP, help="EC2 security group use")
+    remote_create_cmd = sub.add_parser("remote_create", help="Create Remote Container")
+    remote_create_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
+    remote_create_cmd.add_argument(
+        "--service",
+        type=str,
+        default=DEFAULT_SERVICE_NAME,
+        help="ECS Service to create",
+    )
+    remote_create_cmd.add_argument(
+        "--task_definition",
+        type=str,
+        default=ECS_DEFAULT_TASK_DEFINITION,
+        help="ECS Task Definition to use to create service",
+    )
+    remote_create_cmd.add_argument(
+        "--subnets",
+        type=str,
+        nargs="*",
+        default=ECS_DEFAULT_SUBNETS,
+        help="EC2 subnets to use",
+    )
+    remote_create_cmd.add_argument(
+        "--security_group",
+        type=str,
+        default=ECS_DEFAULT_SECURITY_GROUP,
+        help="EC2 security group to use",
+    )
     remote_create_cmd.set_defaults(func=_remote_create_container_args)
 
-    remote_stop_cmd = sub.add_parser('remote_stop', help='Stop Remote Container')
-    remote_stop_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
-    remote_stop_cmd.add_argument("--service", type=str, default=DEFAULT_SERVICE_NAME, help="ECS Service to stop")
+    remote_stop_cmd = sub.add_parser("remote_stop", help="Stop Remote Container")
+    remote_stop_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
+    remote_stop_cmd.add_argument(
+        "--service", type=str, default=DEFAULT_SERVICE_NAME, help="ECS Service to stop"
+    )
     remote_stop_cmd.set_defaults(func=_remote_stop_container_args)
 
-    remote_gc_services_cmd = sub.add_parser('remote_gc_services', help='GC Remote Container')
-    remote_gc_services_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
+    remote_gc_services_cmd = sub.add_parser(
+        "remote_gc_services", help="GC Remote Container"
+    )
+    remote_gc_services_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
     remote_gc_services_cmd.set_defaults(func=_remote_gc_services_container_args)
 
-    get_caller_identity_cmd = sub.add_parser('get_caller_identity', help='Get the AWS IAM caller identity')
+    get_caller_identity_cmd = sub.add_parser(
+        "get_caller_identity", help="Get the AWS IAM caller identity"
+    )
    get_caller_identity_cmd.set_defaults(func=_get_caller_identity)
 
-    remote_get_endpoint_cmd = sub.add_parser('remote_get_endpoint', help='Get SSH remote endpoint')
-    remote_get_endpoint_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
-    remote_get_endpoint_cmd.add_argument("--service", type=str, default=DEFAULT_SERVICE_NAME, help="ECS Service to stop")
+    remote_get_endpoint_cmd = sub.add_parser(
+        "remote_get_endpoint", help="Get SSH remote endpoint"
+    )
+    remote_get_endpoint_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
+    remote_get_endpoint_cmd.add_argument(
+        "--service", type=str, default=DEFAULT_SERVICE_NAME, help="ECS Service to query"
+    )
     remote_get_endpoint_cmd.set_defaults(func=_remote_get_endpoint_args)
 
-    run_e2e_test_cmd = sub.add_parser('run_e2e_test', help='Run Test')
-    run_e2e_test_cmd.add_argument("--script", required=True, type=str, help="script to run")
-    run_e2e_test_cmd.add_argument("--files", type=str, nargs="*", help="Files to copy, each string must be a pair of src:dest joined by a colon")
-    run_e2e_test_cmd.add_argument("--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target")
-    run_e2e_test_cmd.add_argument("--task_definition", type=str, default=ECS_DEFAULT_TASK_DEFINITION, help="ECS Task Definition to use to create service")
-    run_e2e_test_cmd.add_argument("--subnets", type=str, nargs="*", default=ECS_DEFAULT_SUBNETS, help="EC2 subnets to use")
-    run_e2e_test_cmd.add_argument("--security_group", type=str, default=ECS_DEFAULT_SECURITY_GROUP, help="EC2 security group use")
+    run_e2e_test_cmd = sub.add_parser("run_e2e_test", help="Run Test")
+    run_e2e_test_cmd.add_argument(
+        "--script", required=True, type=str, help="script to run"
+    )
+    run_e2e_test_cmd.add_argument(
+        "--files",
+        type=str,
+        nargs="*",
+        help="Files to copy, each string must be a pair of src:dest joined by a colon",
+    )
+    run_e2e_test_cmd.add_argument(
+        "--cluster", type=str, default=ECS_DEFAULT_CLUSTER, help="ECS Cluster to target"
+    )
+    run_e2e_test_cmd.add_argument(
+        "--task_definition",
+        type=str,
+        default=ECS_DEFAULT_TASK_DEFINITION,
+        help="ECS Task Definition to use to create service",
+    )
+    run_e2e_test_cmd.add_argument(
+        "--subnets",
+        type=str,
+        nargs="*",
+        default=ECS_DEFAULT_SUBNETS,
+        help="EC2 subnets to use",
+    )
+    run_e2e_test_cmd.add_argument(
+        "--security_group",
+        type=str,
+        default=ECS_DEFAULT_SECURITY_GROUP,
+        help="EC2 security group to use",
+    )
     run_e2e_test_cmd.set_defaults(func=_run_e2e_test_args)
 
     args = parser.parse_args()
 
-    print("AWS_SHARED_CREDENTIALS_FILE: %s" % (os.getenv("AWS_SHARED_CREDENTIALS_FILE")))
+    print(
+        "AWS_SHARED_CREDENTIALS_FILE: %s" % (os.getenv("AWS_SHARED_CREDENTIALS_FILE"))
+    )
 
     if args.debug:
         logging.basicConfig(level=logging.DEBUG)
     elif args.verbose:
         logging.basicConfig(level=logging.INFO)
-
     args.func(args)
diff --git a/.evergreen/auth_oidc/azure/handle_secrets.py b/.evergreen/auth_oidc/azure/handle_secrets.py
index f8866baf..fc2267e5 100644
--- a/.evergreen/auth_oidc/azure/handle_secrets.py
+++ b/.evergreen/auth_oidc/azure/handle_secrets.py
@@ -8,61 +8,71 @@ def main():
     vault_name = os.environ["AZUREOIDC_KEYVAULT"]
-    private_key_file = os.environ['AZUREKMS_PRIVATEKEYPATH']
-    public_key_file = os.environ['AZUREKMS_PUBLICKEYPATH']
-    app_id = os.environ['AZUREOIDC_APPID']
-    env_file = os.environ['AZUREOIDC_ENVPATH']
-    tenant_id = os.environ['AZUREOIDC_TENANTID']
+    private_key_file = os.environ["AZUREKMS_PRIVATEKEYPATH"]
+    public_key_file = os.environ["AZUREKMS_PUBLICKEYPATH"]
+    app_id = os.environ["AZUREOIDC_APPID"]
+    env_file = os.environ["AZUREOIDC_ENVPATH"]
+    tenant_id = os.environ["AZUREOIDC_TENANTID"]
     vault_uri = f"https://{vault_name}.vault.azure.net"
-    print('Getting secrets from vault ... begin')
+    print("Getting secrets from vault ... begin")
begin") - logger = logging.getLogger('azure.mgmt.resource') + logger = logging.getLogger("azure.mgmt.resource") # Set the desired logging level logger.setLevel(logging.DEBUG) - credential = DefaultAzureCredential(exclude_environment_credential=True, exclude_managed_identity_credential=True) + credential = DefaultAzureCredential( + exclude_environment_credential=True, exclude_managed_identity_credential=True + ) client = SecretClient(vault_url=vault_uri, credential=credential) secrets = dict() - for secret in ['RESOURCEGROUP', 'PUBLICKEY', 'PRIVATEKEY', 'AUTHCLAIM', 'AUTHPREFIX', 'IDENTITY', - 'USERNAME', 'AUDIENCE']: + for secret in [ + "RESOURCEGROUP", + "PUBLICKEY", + "PRIVATEKEY", + "AUTHCLAIM", + "AUTHPREFIX", + "IDENTITY", + "USERNAME", + "AUDIENCE", + ]: retrieved = client.get_secret(secret) secrets[secret] = retrieved.value uri = "mongodb://localhost" suffix = "authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure" suffix += f",TOKEN_RESOURCE:{secrets['AUDIENCE']}" - with open(env_file, 'w') as fid: + with open(env_file, "w") as fid: fid.write(f'export AZUREOIDC_RESOURCEGROUP={secrets["RESOURCEGROUP"]}\n') fid.write(f'export AZUREKMS_RESOURCEGROUP={secrets["RESOURCEGROUP"]}\n') fid.write(f'export AZUREOIDC_AUTHCLAIM={secrets["AUTHCLAIM"]}\n') - fid.write(f'export AZUREOIDC_APPID={app_id}\n') - fid.write(f'export AZUREOIDC_TENANTID={tenant_id}\n') + fid.write(f"export AZUREOIDC_APPID={app_id}\n") + fid.write(f"export AZUREOIDC_TENANTID={tenant_id}\n") fid.write(f'export AZUREOIDC_AUTHPREFIX={secrets["AUTHPREFIX"]}\n') fid.write(f'export AZUREKMS_IDENTITY="{secrets["IDENTITY"]}"\n') fid.write(f'export AZUREOIDC_USERNAME="{secrets["USERNAME"]}"\n') fid.write(f'export AZUREOIDC_RESOURCE="{secrets["AUDIENCE"]}"\n') fid.write(f'export OIDC_ADMIN_USER="{secrets["USERNAME"]}"\n') - fid.write('export OIDC_ADMIN_PWD=pwd123\n') + fid.write("export OIDC_ADMIN_PWD=pwd123\n") fid.write(f'export MONGODB_URI="{uri}"\n') fid.write(f'export MONGODB_URI_SINGLE="{uri}/?{suffix}"\n') if os.path.exists(private_key_file): os.remove(private_key_file) - with open(private_key_file, 'w') as fid: - fid.write(b64decode(secrets['PRIVATEKEY']).decode('utf8')) + with open(private_key_file, "w") as fid: + fid.write(b64decode(secrets["PRIVATEKEY"]).decode("utf8")) os.chmod(private_key_file, 0o400) if os.path.exists(public_key_file): os.remove(public_key_file) - with open(public_key_file, 'w') as fid: - fid.write(b64decode(secrets['PUBLICKEY']).decode('utf8')) + with open(public_key_file, "w") as fid: + fid.write(b64decode(secrets["PUBLICKEY"]).decode("utf8")) os.chmod(public_key_file, 0o400) - print('Getting secrets from vault ... end') + print("Getting secrets from vault ... 
end") return secrets -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.evergreen/auth_oidc/azure/remote-scripts/test.py b/.evergreen/auth_oidc/azure/remote-scripts/test.py index ba7980f7..f6c29ba9 100644 --- a/.evergreen/auth_oidc/azure/remote-scripts/test.py +++ b/.evergreen/auth_oidc/azure/remote-scripts/test.py @@ -5,8 +5,9 @@ from pymongo import MongoClient from pymongo.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult -app_id = os.environ['AZUREOIDC_APPID'] -client_id = os.environ['AZUREOIDC_USERNAME'] +app_id = os.environ["AZUREOIDC_APPID"] +client_id = os.environ["AZUREOIDC_USERNAME"] + class MyCallback(OIDCCallback): def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: @@ -14,13 +15,13 @@ def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: url += "?api-version=2018-02-01" url += f"&resource=api://{app_id}" url += f"&client_id={client_id}" - headers = { "Metadata": "true", "Accept": "application/json" } - print('Fetching url', url) + headers = {"Metadata": "true", "Accept": "application/json"} + print("Fetching url", url) request = Request(url, headers=headers) try: with urlopen(request, timeout=context.timeout_seconds) as response: status = response.status - body = response.read().decode('utf8') + body = response.read().decode("utf8") except Exception as e: msg = "Failed to acquire IMDS access token: %s" % e raise ValueError(msg) from e @@ -39,12 +40,16 @@ def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: msg = "Azure IMDS response must contain %s, but was %s." msg = msg % (key, body) raise ValueError(msg) - return OIDCCallbackResult(access_token=data['access_token']) + return OIDCCallbackResult(access_token=data["access_token"]) + props = dict(OIDC_CALLBACK=MyCallback()) -print('Testing MONGODB-OIDC on azure...') -c = MongoClient('mongodb://localhost:27017/?authMechanism=MONGODB-OIDC', authMechanismProperties=props) +print("Testing MONGODB-OIDC on azure...") +c = MongoClient( + "mongodb://localhost:27017/?authMechanism=MONGODB-OIDC", + authMechanismProperties=props, +) c.test.test.insert_one({}) c.close() -print('Testing MONGODB-OIDC on azure... done.') -print('Self test complete!') +print("Testing MONGODB-OIDC on azure... 
done.") +print("Self test complete!") diff --git a/.evergreen/auth_oidc/azure_func/self-test/function_app.py b/.evergreen/auth_oidc/azure_func/self-test/function_app.py index 0c836c20..8a8528fe 100644 --- a/.evergreen/auth_oidc/azure_func/self-test/function_app.py +++ b/.evergreen/auth_oidc/azure_func/self-test/function_app.py @@ -9,21 +9,22 @@ app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION) + def _get_token(): - resource=os.environ['APPSETTING_RESOURCE'] - client_id= os.environ['APPSETTING_CLIENT_ID'] - url = os.environ['IDENTITY_ENDPOINT'] - url += '?api-version=2019-08-01' - url += f'&resource={resource}' - url += f'&client_id={client_id}' + resource = os.environ["APPSETTING_RESOURCE"] + client_id = os.environ["APPSETTING_CLIENT_ID"] + url = os.environ["IDENTITY_ENDPOINT"] + url += "?api-version=2019-08-01" + url += f"&resource={resource}" + url += f"&client_id={client_id}" - headers = { "X-IDENTITY-HEADER": os.environ['IDENTITY_HEADER'] } + headers = {"X-IDENTITY-HEADER": os.environ["IDENTITY_HEADER"]} request = Request(url, headers=headers) - logging.info('Making a token request.') + logging.info("Making a token request.") with urlopen(request, timeout=30) as response: - body = response.read().decode('utf8') - return json.loads(body)['access_token'] + body = response.read().decode("utf8") + return json.loads(body)["access_token"] class MyCallback(OIDCCallback): @@ -33,28 +34,30 @@ def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: @app.route(route="gettoken") def gettoken(req: func.HttpRequest) -> func.HttpResponse: - logging.info('Handling a gettoken request.') + logging.info("Handling a gettoken request.") try: token = _get_token() except Exception as e: return func.HttpResponse(str(e), status_code=500) - logging.info('Returning the token.') + logging.info("Returning the token.") return func.HttpResponse(token) -@app.route(route='oidcselftest') +@app.route(route="oidcselftest") def oidcselftest(req: func.HttpRequest) -> func.HttpResponse: - logging.info('Handling an oidcselftest request.') + logging.info("Handling an oidcselftest request.") try: req_body = req.get_json() - uri = req_body.get('MONGODB_URI') + uri = req_body.get("MONGODB_URI") props = dict(OIDC_CALLBACK=MyCallback()) - logging.info('Testing MONGODB-OIDC on azure functions...') - c = MongoClient(f'{uri}/?authMechanism=MONGODB-OIDC', authMechanismProperties=props) + logging.info("Testing MONGODB-OIDC on azure functions...") + c = MongoClient( + f"{uri}/?authMechanism=MONGODB-OIDC", authMechanismProperties=props + ) c.test.test.insert_one({}) c.close() except Exception as e: return func.HttpResponse(str(e), status_code=500) - logging.info('Testing MONGODB-OIDC on azure functions... done.') - logging.info('Self test complete!') - return func.HttpResponse('Success!') + logging.info("Testing MONGODB-OIDC on azure functions... 
done.") + logging.info("Self test complete!") + return func.HttpResponse("Success!") diff --git a/.evergreen/auth_oidc/gcp/remote-scripts/test.py b/.evergreen/auth_oidc/gcp/remote-scripts/test.py index 293fa9af..995f0338 100644 --- a/.evergreen/auth_oidc/gcp/remote-scripts/test.py +++ b/.evergreen/auth_oidc/gcp/remote-scripts/test.py @@ -4,20 +4,21 @@ from pymongo import MongoClient from pymongo.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult -audience = os.environ['GCPOIDC_AUDIENCE'] +audience = os.environ["GCPOIDC_AUDIENCE"] atlas_uri = os.environ["MONGODB_URI"] + class MyCallback(OIDCCallback): def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: url = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity" url += f"?audience={audience}" - headers = { "Metadata-Flavor": "Google" } - print('Fetching url', url) + headers = {"Metadata-Flavor": "Google"} + print("Fetching url", url) request = Request(url, headers=headers) try: with urlopen(request, timeout=context.timeout_seconds) as response: status = response.status - body = response.read().decode('utf8') + body = response.read().decode("utf8") except Exception as e: msg = "Failed to acquire IMDS access token: %s" % e raise ValueError(msg) from e @@ -29,10 +30,13 @@ def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: return OIDCCallbackResult(access_token=body) + props = dict(OIDC_CALLBACK=MyCallback()) -print('Testing MONGODB-OIDC on gcp...') -c = MongoClient(f'{atlas_uri}/?authMechanism=MONGODB-OIDC', authMechanismProperties=props) +print("Testing MONGODB-OIDC on gcp...") +c = MongoClient( + f"{atlas_uri}/?authMechanism=MONGODB-OIDC", authMechanismProperties=props +) c.test.test.insert_one({}) c.close() -print('Testing MONGODB-OIDC on gcp... done.') -print('Self test complete!') +print("Testing MONGODB-OIDC on gcp... done.") +print("Self test complete!") diff --git a/.evergreen/auth_oidc/k8s/remote-scripts/test.py b/.evergreen/auth_oidc/k8s/remote-scripts/test.py index dd583b6d..eef51e2f 100644 --- a/.evergreen/auth_oidc/k8s/remote-scripts/test.py +++ b/.evergreen/auth_oidc/k8s/remote-scripts/test.py @@ -5,20 +5,24 @@ atlas_uri = os.environ["MONGODB_URI"] + class MyCallback(OIDCCallback): def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - fname = '/var/run/secrets/kubernetes.io/serviceaccount/token' - for key in ['AZURE_FEDERATED_TOKEN_FILE', 'AWS_WEB_IDENTITY_TOKEN_FILE']: + fname = "/var/run/secrets/kubernetes.io/serviceaccount/token" + for key in ["AZURE_FEDERATED_TOKEN_FILE", "AWS_WEB_IDENTITY_TOKEN_FILE"]: if key in os.environ: fname = os.environ[key] with open(fname) as fid: token = fid.read() return OIDCCallbackResult(access_token=token) + props = dict(OIDC_CALLBACK=MyCallback()) -print('Testing MONGODB-OIDC on k8s...') -c = MongoClient(f'{atlas_uri}/?authMechanism=MONGODB-OIDC', authMechanismProperties=props) +print("Testing MONGODB-OIDC on k8s...") +c = MongoClient( + f"{atlas_uri}/?authMechanism=MONGODB-OIDC", authMechanismProperties=props +) c.test.test.insert_one({}) c.close() -print('Testing MONGODB-OIDC on k8s... done.') -print('Self test complete!') +print("Testing MONGODB-OIDC on k8s... 
done.") +print("Self test complete!") diff --git a/.evergreen/auth_oidc/oidc_get_tokens.py b/.evergreen/auth_oidc/oidc_get_tokens.py index d02d07aa..b0a2f0c0 100644 --- a/.evergreen/auth_oidc/oidc_get_tokens.py +++ b/.evergreen/auth_oidc/oidc_get_tokens.py @@ -5,39 +5,40 @@ sys.path.insert(0, HERE) from utils import DEFAULT_CLIENT, get_id_token, get_secrets, join -TOKEN_DIR = os.environ['OIDC_TOKEN_DIR'].replace(os.sep, '/') +TOKEN_DIR = os.environ["OIDC_TOKEN_DIR"].replace(os.sep, "/") + def generate_tokens(config, base_name): os.makedirs(TOKEN_DIR, exist_ok=True) - config['token_file'] = join(TOKEN_DIR, base_name) + config["token_file"] = join(TOKEN_DIR, base_name) get_id_token(config) for i in range(2): - config['token_file'] = join(TOKEN_DIR, f'{base_name}_{i+1}') + config["token_file"] = join(TOKEN_DIR, f"{base_name}_{i+1}") get_id_token(config) - config['token_file'] = join(TOKEN_DIR, f'{base_name}_expires') + config["token_file"] = join(TOKEN_DIR, f"{base_name}_expires") get_id_token(config, expires=60) def main(): secrets = get_secrets() config = { - "issuer": secrets['oidc_issuer_1_uri'], - "jwks_uri": secrets['oidc_jwks_uri'], - 'rsa_key': secrets['oidc_rsa_key'], - 'audience': DEFAULT_CLIENT, - 'client_id': DEFAULT_CLIENT, - 'client_secret': secrets['oidc_client_secret'], - 'username': f'test_user1@{secrets["oidc_domain"]}', + "issuer": secrets["oidc_issuer_1_uri"], + "jwks_uri": secrets["oidc_jwks_uri"], + "rsa_key": secrets["oidc_rsa_key"], + "audience": DEFAULT_CLIENT, + "client_id": DEFAULT_CLIENT, + "client_secret": secrets["oidc_client_secret"], + "username": f'test_user1@{secrets["oidc_domain"]}', } - generate_tokens(config, 'test_user1') - config['issuer'] = secrets['oidc_issuer_2_uri'] - config['username'] = f'test_user2@{secrets["oidc_domain"]}' - generate_tokens(config, 'test_user2') - config['username'] = 'test_machine' - generate_tokens(config, 'test_machine') + generate_tokens(config, "test_user1") + config["issuer"] = secrets["oidc_issuer_2_uri"] + config["username"] = f'test_user2@{secrets["oidc_domain"]}' + generate_tokens(config, "test_user2") + config["username"] = "test_machine" + generate_tokens(config, "test_machine") print(f"Wrote tokens to {TOKEN_DIR}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.evergreen/auth_oidc/oidc_write_orchestration.py b/.evergreen/auth_oidc/oidc_write_orchestration.py index 16e66279..ee8cd5a9 100644 --- a/.evergreen/auth_oidc/oidc_write_orchestration.py +++ b/.evergreen/auth_oidc/oidc_write_orchestration.py @@ -1,6 +1,7 @@ """ Script for managing OIDC. 
""" + import json import os import sys @@ -11,10 +12,10 @@ def azure(): - client_id = os.environ['AZUREOIDC_USERNAME'] - tenant_id = os.environ['AZUREOIDC_TENANTID'] - app_id = os.environ['AZUREOIDC_APPID'] - auth_name_prefix = os.environ['AZUREOIDC_AUTHPREFIX'] + client_id = os.environ["AZUREOIDC_USERNAME"] + tenant_id = os.environ["AZUREOIDC_TENANTID"] + app_id = os.environ["AZUREOIDC_APPID"] + auth_name_prefix = os.environ["AZUREOIDC_AUTHPREFIX"] print("Bootstrapping OIDC config") @@ -27,7 +28,7 @@ def azure(): "authorizationClaim": "groups", "supportsHumanFlows": False, } - providers = json.dumps([provider_info], separators=(',',':')) + providers = json.dumps([provider_info], separators=(",", ":")) data = { "id": "oidc-repl0", @@ -43,16 +44,21 @@ def azure(): "setParameter": { "enableTestCommands": 1, "authenticationMechanisms": "SCRAM-SHA-1,SCRAM-SHA-256,MONGODB-OIDC", - "oidcIdentityProviders": providers - } - } + "oidcIdentityProviders": providers, + }, + }, } - orch_file = os.path.abspath(os.path.join(HERE, '..', 'orchestration', 'configs', 'servers', 'auth-oidc.json')) - with open(orch_file, 'w') as fid: + orch_file = os.path.abspath( + os.path.join( + HERE, "..", "orchestration", "configs", "servers", "auth-oidc.json" + ) + ) + with open(orch_file, "w") as fid: json.dump(data, fid, indent=4) print(f"Wrote OIDC config to {orch_file}") + def main(): print("Bootstrapping OIDC config") @@ -62,16 +68,16 @@ def main(): # Write the oidc orchestration file. provider1_info = { "authNamePrefix": "test1", - "issuer": secrets['oidc_issuer_1_uri'], + "issuer": secrets["oidc_issuer_1_uri"], "clientId": DEFAULT_CLIENT, "audience": DEFAULT_CLIENT, "authorizationClaim": "foo", "requestScopes": ["fizz", "buzz"], - "matchPattern": "test_user1" + "matchPattern": "test_user1", } provider2_info = { "authNamePrefix": "test2", - "issuer": secrets['oidc_issuer_2_uri'], + "issuer": secrets["oidc_issuer_2_uri"], "clientId": DEFAULT_CLIENT, "audience": DEFAULT_CLIENT, "authorizationClaim": "bar", @@ -79,7 +85,7 @@ def main(): "requestScopes": ["foo", "bar"], } - providers = json.dumps([provider1_info, provider2_info], separators=(',',':')) + providers = json.dumps([provider1_info, provider2_info], separators=(",", ":")) data = { "id": "oidc-repl0", @@ -87,51 +93,57 @@ def main(): "login": "bob", "name": "mongod", "password": "pwd123", - "members": [{ + "members": [ + { + "procParams": { + "ipv6": "NO_IPV6" not in os.environ, + "bind_ip": "0.0.0.0,::1", + "logappend": True, + "port": 27017, + "setParameter": { + "enableTestCommands": 1, + "authenticationMechanisms": "SCRAM-SHA-1,SCRAM-SHA-256,MONGODB-OIDC", + "oidcIdentityProviders": providers, + }, + } + } + ], + } + + provider2_info["matchPattern"] = "test_user2" + del provider2_info["supportsHumanFlows"] + + providers = [provider1_info, provider2_info] + providers = json.dumps(providers, separators=(",", ":")) + data["members"].append( + { "procParams": { "ipv6": "NO_IPV6" not in os.environ, "bind_ip": "0.0.0.0,::1", "logappend": True, - "port": 27017, + "port": 27018, "setParameter": { "enableTestCommands": 1, "authenticationMechanisms": "SCRAM-SHA-1,SCRAM-SHA-256,MONGODB-OIDC", - "oidcIdentityProviders": providers - } - } - }] - } - - provider2_info['matchPattern'] = 'test_user2' - del provider2_info['supportsHumanFlows'] - - providers = [provider1_info, provider2_info] - providers = json.dumps(providers, separators=(',',':')) - data['members'].append({ - "procParams": { - "ipv6": "NO_IPV6" not in os.environ, - "bind_ip": "0.0.0.0,::1", - "logappend": 
True, - "port": 27018, - "setParameter": { - "enableTestCommands": 1, - "authenticationMechanisms": "SCRAM-SHA-1,SCRAM-SHA-256,MONGODB-OIDC", - "oidcIdentityProviders": providers - } - }, - "rsParams": { - "priority": 0 + "oidcIdentityProviders": providers, + }, + }, + "rsParams": {"priority": 0}, } - }) - - orch_file = os.path.abspath(os.path.join(HERE, '..', 'orchestration', 'configs', 'replica_sets', 'auth-oidc.json')) - with open(orch_file, 'w') as fid: + ) + + orch_file = os.path.abspath( + os.path.join( + HERE, "..", "orchestration", "configs", "replica_sets", "auth-oidc.json" + ) + ) + with open(orch_file, "w") as fid: json.dump(data, fid, indent=4) print(f"Wrote OIDC config to {orch_file}") -if __name__ == '__main__': - if '--azure' in sys.argv: +if __name__ == "__main__": + if "--azure" in sys.argv: azure() else: main() diff --git a/.evergreen/auth_oidc/utils.py b/.evergreen/auth_oidc/utils.py index 5bcd3640..a8ac0392 100644 --- a/.evergreen/auth_oidc/utils.py +++ b/.evergreen/auth_oidc/utils.py @@ -3,14 +3,16 @@ HERE = os.path.abspath(os.path.dirname(__file__)) + def join(*args): - return os.path.join(*args).replace(os.sep, '/') + return os.path.join(*args).replace(os.sep, "/") + -aws_lib = join(os.path.dirname(HERE), 'auth_aws', 'lib') +aws_lib = join(os.path.dirname(HERE), "auth_aws", "lib") sys.path.insert(0, aws_lib) from aws_handle_oidc_creds import MOCK_ENDPOINT, get_id_token # noqa: F401 -secrets_root = join(os.path.dirname(HERE), 'secrets_handling') +secrets_root = join(os.path.dirname(HERE), "secrets_handling") sys.path.insert(0, secrets_root) from setup_secrets import get_secrets as root_get_secrets diff --git a/.evergreen/csfle/fake_azure.py b/.evergreen/csfle/fake_azure.py index 26f5953f..70132113 100644 --- a/.evergreen/csfle/fake_azure.py +++ b/.evergreen/csfle/fake_azure.py @@ -20,49 +20,39 @@ from typing import Protocol class _RequestParams(Protocol): - - def __getitem__(self, key: str) -> str: - ... + def __getitem__(self, key: str) -> str: ... @overload - def get(self, key: str) -> 'str | None': - ... + def get(self, key: str) -> "str | None": ... @overload - def get(self, key: str, default: str) -> str: - ... + def get(self, key: str, default: str) -> str: ... class _HeadersDict(dict[str, str]): - - def raw(self, key: str) -> 'bytes | None': - ... + def raw(self, key: str) -> "bytes | None": ... class _Request(Protocol): - @property - def query(self) -> _RequestParams: - ... + def query(self) -> _RequestParams: ... @property - def params(self) -> _RequestParams: - ... + def params(self) -> _RequestParams: ... @property - def headers(self) -> _HeadersDict: - ... + def headers(self) -> _HeadersDict: ... - request = cast('_Request', None) + request = cast("_Request", None) -def parse_qs(qs: str) -> 'dict[str, str]': +def parse_qs(qs: str) -> "dict[str, str]": # Reuse the bottle.py query string parser. It's a private function, but # we're using a fixed version of Bottle. 
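    # Illustrative example (doctest-style; not part of the patch):
    #   >>> parse_qs("case=404&foo=bar")
    #   {'case': '404', 'foo': 'bar'}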
return dict(bottle._parse_qsl(qs)) # type: ignore _HandlerFuncT = Callable[ - [], - 'None|str|bytes|dict[str, Any]|bottle.BaseResponse|Iterable[bytes|str]'] + [], "None|str|bytes|dict[str, Any]|bottle.BaseResponse|Iterable[bytes|str]" +] def handle_asserts(fn: _HandlerFuncT) -> _HandlerFuncT: @@ -74,55 +64,59 @@ def wrapped(): return fn() except AssertionError as e: traceback.print_exc() - return bottle.HTTPResponse(status=400, - body=json.dumps({'error': - list(e.args)})) + return bottle.HTTPResponse( + status=400, body=json.dumps({"error": list(e.args)}) + ) return wrapped -def test_params() -> 'dict[str, str]': - return parse_qs(request.headers.get('X-MongoDB-HTTP-TestParams', '')) +def test_params() -> "dict[str, str]": + return parse_qs(request.headers.get("X-MongoDB-HTTP-TestParams", "")) + -@imds.route('/') +@imds.route("/") def main(): pass -@imds.get('/metadata/identity/oauth2/token') + +@imds.get("/metadata/identity/oauth2/token") @handle_asserts def get_oauth2_token(): - api_version = request.query['api-version'] - assert api_version == '2018-02-01', 'Only api-version=2018-02-01 is supported' - resource = request.query['resource'] - assert resource == 'https://vault.azure.net', 'Only https://vault.azure.net is supported' - - case = test_params().get('case') - print('Case is:', case) - if case == '404': + api_version = request.query["api-version"] + assert api_version == "2018-02-01", "Only api-version=2018-02-01 is supported" + resource = request.query["resource"] + assert ( + resource == "https://vault.azure.net" + ), "Only https://vault.azure.net is supported" + + case = test_params().get("case") + print("Case is:", case) + if case == "404": return HTTPResponse(status=404) - if case == '500': + if case == "500": return HTTPResponse(status=500) - if case == 'bad-json': + if case == "bad-json": return b'{"key": }' - if case == 'empty-json': - return b'{}' + if case == "empty-json": + return b"{}" - if case == 'giant': + if case == "giant": return _gen_giant() - if case == 'slow': + if case == "slow": return _slow() - assert case in (None, ''), f'Unknown HTTP test case "{case}"' + assert case in (None, ""), f'Unknown HTTP test case "{case}"' return { - 'access_token': 'magic-cookie', - 'expires_in': '70', - 'token_type': 'Bearer', - 'resource': 'https://vault.azure.net', + "access_token": "magic-cookie", + "expires_in": "70", + "token_type": "Bearer", + "resource": "https://vault.azure.net", } @@ -130,25 +124,27 @@ def _gen_giant() -> Iterable[bytes]: "Generate a giant message" yield b'{ "item": [' for _ in range(1024 * 256): - yield (b'null, null, null, null, null, null, null, null, null, null, ' - b'null, null, null, null, null, null, null, null, null, null, ' - b'null, null, null, null, null, null, null, null, null, null, ' - b'null, null, null, null, null, null, null, null, null, null, ') - yield b' null ] }' - yield b'\n' + yield ( + b"null, null, null, null, null, null, null, null, null, null, " + b"null, null, null, null, null, null, null, null, null, null, " + b"null, null, null, null, null, null, null, null, null, null, " + b"null, null, null, null, null, null, null, null, null, null, " + ) + yield b" null ] }" + yield b"\n" def _slow() -> Iterable[bytes]: "Generate a very slow message" yield b'{ "item": [' for _ in range(1000): - yield b'null, ' + yield b"null, " time.sleep(1) - yield b' null ] }' + yield b" null ] }" -if __name__ == '__main__': +if __name__ == "__main__": print( - f'RECOMMENDED: Run this script using bottle.py (e.g. 
[{sys.executable} {Path(__file__).resolve().parent}/bottle.py fake_azure:imds])' - ) + f"RECOMMENDED: Run this script using bottle.py (e.g. [{sys.executable} {Path(__file__).resolve().parent}/bottle.py fake_azure:imds])" + ) imds.run() diff --git a/.evergreen/csfle/gcpkms/mock_server.py b/.evergreen/csfle/gcpkms/mock_server.py index d50bcc28..c927c0df 100644 --- a/.evergreen/csfle/gcpkms/mock_server.py +++ b/.evergreen/csfle/gcpkms/mock_server.py @@ -20,7 +20,7 @@ def b64_to_b64url(b64): def dict_to_b64url(arg): as_json = json.dumps(arg).encode("utf8") as_b64 = base64.b64encode(as_json).decode("utf8") - return b64_to_b64url(as_b64) + return b64_to_b64url(as_b64) def get_access_token(): @@ -33,7 +33,8 @@ def get_access_token(): if "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ: raise Exception( - "please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a JSON Service account key") + "please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a JSON Service account key" + ) creds = json.load(open(os.environ["GOOGLE_APPLICATION_CREDENTIALS"])) private_key = creds["private_key"].encode("utf8") client_email = creds["client_email"] @@ -44,16 +45,18 @@ def get_access_token(): "scope": "https://www.googleapis.com/auth/cloudkms", # Expiration can be at most one hour in the future. Let's say 30 minutes. "exp": int(time.time()) + 30 * 60, - "iat": int(time.time()) + "iat": int(time.time()), } - assertion = jwt.encode(claims, private_key, - algorithm="RS256", headers=header) + assertion = jwt.encode(claims, private_key, algorithm="RS256", headers=header) - resp = requests.post(url="https://oauth2.googleapis.com/token", data={ - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "assertion": assertion - }) + resp = requests.post( + url="https://oauth2.googleapis.com/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": assertion, + }, + ) if resp.status_code != 200: msg = textwrap.dedent(f""" @@ -82,7 +85,7 @@ def main(): global private_key port = 5000 server = http.server.HTTPServer(("localhost", port), Handler) - print (f"Listening on port {port}") + print(f"Listening on port {port}") server.serve_forever() diff --git a/.evergreen/csfle/kms_failpoint_server.py b/.evergreen/csfle/kms_failpoint_server.py index 9122a88a..e9c8c7dd 100644 --- a/.evergreen/csfle/kms_failpoint_server.py +++ b/.evergreen/csfle/kms_failpoint_server.py @@ -31,8 +31,9 @@ remaining_http_fails = 0 remaining_network_fails = 0 -fake_ciphertext = 'a' * 96 -fake_plaintext = 'b' * 96 +fake_ciphertext = "a" * 96 +fake_plaintext = "b" * 96 + class HTTPServerWithTLS(http.server.HTTPServer): def __init__(self, server_address, Handler, use_tls=True): @@ -47,7 +48,7 @@ def __init__(self, server_address, Handler, use_tls=True): server_side=True, certfile=cert_file, ca_certs=ca_file, - ssl_version=ssl.PROTOCOL_TLS + ssl_version=ssl.PROTOCOL_TLS, ) @@ -82,22 +83,20 @@ def do_POST(self): path = PurePosixPath(parts.path) if path.match("/set_failpoint/*"): - content_length = int(self.headers['Content-Length']) + content_length = int(self.headers["Content-Length"]) post_data = self.rfile.read(content_length) - data = json.loads(post_data.decode('utf-8')) + data = json.loads(post_data.decode("utf-8")) failpoint_type = path.parts[-1] - if failpoint_type == 'network': - remaining_network_fails = data['count'] - elif failpoint_type == 'http': - remaining_http_fails = data['count'] + if failpoint_type == "network": + remaining_network_fails = data["count"] + elif failpoint_type 
== "http": + remaining_http_fails = data["count"] else: self._send_not_found() return None print(f"Enabling failpoint for type: {failpoint_type}") - self._send_json( - {"message": f"failpoint set for type: '{failpoint_type}'"} - ) + self._send_json({"message": f"failpoint set for type: '{failpoint_type}'"}) return None if path.match("/reset"): @@ -112,16 +111,24 @@ def do_POST(self): raise Exception("mock network error") # No path for AWS - if 'X-Amz-Target' in self.headers and str(path) == "/": - aws_op = self.headers['X-Amz-Target'] + if "X-Amz-Target" in self.headers and str(path) == "/": + aws_op = self.headers["X-Amz-Target"] if aws_op == "TrentService.Encrypt": - self._send_json({"CiphertextBlob": base64.b64encode(fake_ciphertext.encode()).decode()}) + self._send_json( + { + "CiphertextBlob": base64.b64encode( + fake_ciphertext.encode() + ).decode() + } + ) return None if aws_op == "TrentService.Decrypt": if remaining_http_fails > 0: self._http_fail() return None - self._send_json({"Plaintext": base64.b64encode(fake_plaintext.encode()).decode()}) + self._send_json( + {"Plaintext": base64.b64encode(fake_plaintext.encode()).decode()} + ) return None self._send_not_found() return None @@ -134,28 +141,39 @@ def do_POST(self): return self._send_json({"access_token": "foo", "expires_in": 99999}) # GCP encrypt path: /v1/projects/{project}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key}:encrypt if path.match("*encrypt"): - return self._send_json({"ciphertext": base64.b64encode(fake_ciphertext.encode()).decode()}) + return self._send_json( + {"ciphertext": base64.b64encode(fake_ciphertext.encode()).decode()} + ) # GCP decrypt path: /v1/projects/{project}/locations/{location}/keyRings/{key-ring}/cryptoKeys/{key}:decrypt if path.match("*decrypt"): if remaining_http_fails > 0: self._http_fail() return None - return self._send_json({"plaintext": base64.b64encode(fake_plaintext.encode()).decode()}) + return self._send_json( + {"plaintext": base64.b64encode(fake_plaintext.encode()).decode()} + ) # Azure decrypt path: /keys/{key-name}/{key-version}/unwrapkey if path.match("*unwrapkey"): if remaining_http_fails > 0: self._http_fail() return None - return self._send_json({"value": base64.b64encode(fake_plaintext.encode()).decode()}) + return self._send_json( + {"value": base64.b64encode(fake_plaintext.encode()).decode()} + ) # Azure encrypt path: /keys/{key-name}/{key-version}/wrapkey if path.match("*wrapkey"): - return self._send_json({"value": base64.b64encode(fake_ciphertext.encode()).decode()}) + return self._send_json( + {"value": base64.b64encode(fake_ciphertext.encode()).decode()} + ) self._send_not_found() + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='MongoDB mock KMS retry endpoint.') - parser.add_argument('-p', '--port', type=int, default=9003, help="Port to listen on") - parser.add_argument('--no-tls', action='store_true', help="Disable TLS") + parser = argparse.ArgumentParser(description="MongoDB mock KMS retry endpoint.") + parser.add_argument( + "-p", "--port", type=int, default=9003, help="Port to listen on" + ) + parser.add_argument("--no-tls", action="store_true", help="Disable TLS") args = parser.parse_args() server_address = ("localhost", args.port) diff --git a/.evergreen/csfle/kms_http_common.py b/.evergreen/csfle/kms_http_common.py index d363e231..b37c8c1d 100644 --- a/.evergreen/csfle/kms_http_common.py +++ b/.evergreen/csfle/kms_http_common.py @@ -1,4 +1,5 @@ """Common code for mock kms http endpoint.""" + import http.server import json 
import ssl @@ -46,11 +47,13 @@ def __init__(self): self.fault_calls = 0 def __repr__(self): - return json.dumps({ - 'decrypts': self.decrypt_calls, - 'encrypts': self.encrypt_calls, - 'faults': self.fault_calls, - }) + return json.dumps( + { + "decrypts": self.decrypt_calls, + "encrypts": self.encrypt_calls, + "faults": self.fault_calls, + } + ) class KmsHandlerBase(http.server.BaseHTTPRequestHandler): @@ -111,7 +114,7 @@ def _send_header(self): def _do_stats(self): self._send_header() - self.wfile.write(str(stats).encode('utf-8')) + self.wfile.write(str(stats).encode("utf-8")) def _do_disable_faults(self): global disable_faults @@ -124,9 +127,16 @@ def _do_enable_faults(self): self._send_header() -def run(port, cert_file, ca_file, handler_class, server_class=http.server.HTTPServer, cert_required=False): +def run( + port, + cert_file, + ca_file, + handler_class, + server_class=http.server.HTTPServer, + cert_required=False, +): """Run web server.""" - server_address = ('', port) + server_address = ("", port) httpd = server_class(server_address, handler_class) @@ -134,10 +144,13 @@ def run(port, cert_file, ca_file, handler_class, server_class=http.server.HTTPSe if cert_required: cert_reqs = ssl.CERT_REQUIRED - httpd.socket = ssl.wrap_socket(httpd.socket, - certfile=cert_file, - ca_certs=ca_file, server_side=True, - cert_reqs=cert_reqs) + httpd.socket = ssl.wrap_socket( + httpd.socket, + certfile=cert_file, + ca_certs=ca_file, + server_side=True, + cert_reqs=cert_reqs, + ) print("Mock KMS Web Server Listening on port " + str(server_address[1])) diff --git a/.evergreen/csfle/kms_http_server.py b/.evergreen/csfle/kms_http_server.py index 802914ea..1efa0cad 100644 --- a/.evergreen/csfle/kms_http_server.py +++ b/.evergreen/csfle/kms_http_server.py @@ -32,6 +32,7 @@ kms_http_common.FAULT_DECRYPT_WRONG_KEY, ] + def get_dict_subset(headers, subset): ret = {} for header in headers.keys(): @@ -39,6 +40,7 @@ def get_dict_subset(headers, subset): ret[header] = headers[header] return ret + class AwsKmsHandler(kms_http_common.KmsHandlerBase): """ Handle requests from AWS KMS Monitoring and test commands @@ -57,7 +59,7 @@ def do_POST(self): self.wfile.write(b"Unknown URL") def _do_post(self): - c_len = int(self.headers.get('content-length')) + c_len = int(self.headers.get("content-length")) raw_input = self.rfile.read(c_len) @@ -72,7 +74,7 @@ def _do_post(self): self._send_reply(data.encode("utf-8")) # X-Amz-Target: TrentService.Encrypt - aws_operation = self.headers['X-Amz-Target'] + aws_operation = self.headers["X-Amz-Target"] if aws_operation == "TrentService.Encrypt": kms_http_common.stats.encrypt_calls += 1 @@ -87,15 +89,21 @@ def _do_post(self): def _validate_signature(self, headers, raw_input): auth_header = headers["Authorization"] signed_headers_start = auth_header.find("SignedHeaders") - signed_headers = auth_header[signed_headers_start:auth_header.find(",", signed_headers_start)] + signed_headers = auth_header[ + signed_headers_start : auth_header.find(",", signed_headers_start) + ] signed_headers_dict = get_dict_subset(headers, signed_headers) - request = AWSRequest(method="POST", url="/", data=raw_input, headers=signed_headers_dict) + request = AWSRequest( + method="POST", url="/", data=raw_input, headers=signed_headers_dict + ) # SigV4Auth assumes this header exists even though it is not required by the algorithm - request.context['timestamp'] = headers['X-Amz-Date'] + request.context["timestamp"] = headers["X-Amz-Date"] - region_start = auth_header.find("Credential=access/") + 
len("Credential=access/YYYYMMDD/") - region = auth_header[region_start:auth_header.find("/", region_start)] + region_start = auth_header.find("Credential=access/") + len( + "Credential=access/YYYYMMDD/" + ) + region = auth_header[region_start : auth_header.find("/", region_start)] credentials = Credentials("access", "secret") auth = SigV4Auth(credentials, "kms", region) @@ -118,46 +126,51 @@ def _do_encrypt(self, raw_input): ciphertext = SECRET_PREFIX.encode() + plaintext.encode() ciphertext = base64.b64encode(ciphertext).decode() - if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) \ - and not kms_http_common.disable_faults: + if ( + kms_http_common.fault_type + and kms_http_common.fault_type.startswith(kms_http_common.FAULT_ENCRYPT) + and not kms_http_common.disable_faults + ): return self._do_encrypt_faults(ciphertext) response = { - "CiphertextBlob" : ciphertext, - "KeyId" : keyid, + "CiphertextBlob": ciphertext, + "KeyId": keyid, } - self._send_reply(json.dumps(response).encode('utf-8')) + self._send_reply(json.dumps(response).encode("utf-8")) def _do_encrypt_faults(self, raw_ciphertext): kms_http_common.stats.fault_calls += 1 if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT: - self._send_reply(b"Internal Error of some sort.", http.HTTPStatus.INTERNAL_SERVER_ERROR) + self._send_reply( + b"Internal Error of some sort.", http.HTTPStatus.INTERNAL_SERVER_ERROR + ) return if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_WRONG_FIELDS: response = { - "SomeBlob" : raw_ciphertext, - "KeyId" : "foo", + "SomeBlob": raw_ciphertext, + "KeyId": "foo", } - self._send_reply(json.dumps(response).encode('utf-8')) + self._send_reply(json.dumps(response).encode("utf-8")) return if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_BAD_BASE64: response = { - "CiphertextBlob" : "foo", - "KeyId" : "foo", + "CiphertextBlob": "foo", + "KeyId": "foo", } - self._send_reply(json.dumps(response).encode('utf-8')) + self._send_reply(json.dumps(response).encode("utf-8")) return if kms_http_common.fault_type == kms_http_common.FAULT_ENCRYPT_CORRECT_FORMAT: response = { - "__type" : "NotFoundException", - "Message" : "Error encrypting message", + "__type": "NotFoundException", + "Message": "Error encrypting message", } - self._send_reply(json.dumps(response).encode('utf-8')) + self._send_reply(json.dumps(response).encode("utf-8")) return raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type) @@ -172,61 +185,81 @@ def _do_decrypt(self, raw_input): if not blob.startswith(SECRET_PREFIX): raise ValueError() - blob = blob[len(SECRET_PREFIX):] + blob = blob[len(SECRET_PREFIX) :] - if kms_http_common.fault_type and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) \ - and not kms_http_common.disable_faults: + if ( + kms_http_common.fault_type + and kms_http_common.fault_type.startswith(kms_http_common.FAULT_DECRYPT) + and not kms_http_common.disable_faults + ): return self._do_decrypt_faults(blob) response = { - "Plaintext" : blob, - "KeyId" : "Not a clue", + "Plaintext": blob, + "KeyId": "Not a clue", } - self._send_reply(json.dumps(response).encode('utf-8')) + self._send_reply(json.dumps(response).encode("utf-8")) def _do_decrypt_faults(self, blob): kms_http_common.stats.fault_calls += 1 if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT: - self._send_reply(b"Internal Error of some sort.", http.HTTPStatus.INTERNAL_SERVER_ERROR) + self._send_reply( + b"Internal Error of some sort.", 
http.HTTPStatus.INTERNAL_SERVER_ERROR
+            )
             return
 
         if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_WRONG_KEY:
             response = {
-                "Plaintext" : "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
-                "KeyId" : "Not a clue",
+                "Plaintext": "ta7DXE7J0OiCRw03dYMJSeb8nVF5qxTmZ9zWmjuX4zW/SOorSCaY8VMTWG+cRInMx/rr/+QeVw2WjU2IpOSvMg==",
+                "KeyId": "Not a clue",
             }
-            self._send_reply(json.dumps(response).encode('utf-8'))
+            self._send_reply(json.dumps(response).encode("utf-8"))
             return
 
         if kms_http_common.fault_type == kms_http_common.FAULT_DECRYPT_CORRECT_FORMAT:
             response = {
-                "__type" : "NotFoundException",
-                "Message" : "Error decrypting message",
+                "__type": "NotFoundException",
+                "Message": "Error decrypting message",
             }
-            self._send_reply(json.dumps(response).encode('utf-8'))
+            self._send_reply(json.dumps(response).encode("utf-8"))
             return
 
         raise ValueError("Unknown Fault Type: " + kms_http_common.fault_type)
 
+
 def main():
     """Main Method."""
-    parser = argparse.ArgumentParser(description='MongoDB Mock AWS KMS Endpoint.')
+    parser = argparse.ArgumentParser(description="MongoDB Mock AWS KMS Endpoint.")
 
-    parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on")
+    parser.add_argument(
+        "-p", "--port", type=int, default=8000, help="Port to listen on"
+    )
 
-    parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
+    parser.add_argument(
+        "-v", "--verbose", action="count", help="Enable verbose tracing"
+    )
 
-    parser.add_argument('--fault', type=str, help="Type of fault to inject")
+    parser.add_argument("--fault", type=str, help="Type of fault to inject")
 
-    parser.add_argument('--disable-faults', action='store_true', help="Disable faults on startup")
+    parser.add_argument(
+        "--disable-faults", action="store_true", help="Disable faults on startup"
+    )
 
-    parser.add_argument('--ca_file', type=str, required=True, help="TLS CA PEM file")
+    parser.add_argument("--ca_file", type=str, required=True, help="TLS CA PEM file")
 
-    parser.add_argument('--cert_file', type=str, required=True, help="TLS Server PEM file")
+    parser.add_argument(
+        "--cert_file", type=str, required=True, help="TLS Server PEM file"
+    )
 
-    parser.add_argument('--require_client_cert', action='store_true', required=False, default=False, help="Require a client certificate in TLS connections")
+    parser.add_argument(
+        "--require_client_cert",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Require a client certificate in TLS connections",
+    )
 
     args = parser.parse_args()
 
     if args.verbose:
@@ -234,7 +267,10 @@ def main():
 
     if args.fault:
         if args.fault not in SUPPORTED_FAULT_TYPES:
-            print("Unsupported fault type %s, supports types are %s" % (args.fault, SUPPORTED_FAULT_TYPES))
+            print(
+                "Unsupported fault type %s, supported types are %s"
+                % (args.fault, SUPPORTED_FAULT_TYPES)
+            )
             sys.exit(1)
 
         kms_http_common.fault_type = args.fault
@@ -242,9 +278,14 @@ def main():
     if args.disable_faults:
         kms_http_common.disable_faults = True
 
-    kms_http_common.run(args.port, args.cert_file, args.ca_file, AwsKmsHandler, cert_required=args.require_client_cert)
-
+    kms_http_common.run(
+        args.port,
+        args.cert_file,
+        args.ca_file,
+        AwsKmsHandler,
+        cert_required=args.require_client_cert,
+    )
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/.evergreen/csfle/kms_kmip_client.py b/.evergreen/csfle/kms_kmip_client.py
index 770f6b8d..e8c9e06c 100644
--- a/.evergreen/csfle/kms_kmip_client.py
+++ b/.evergreen/csfle/kms_kmip_client.py
@@ -13,7 +13,7 @@ HOSTNAME = "localhost" PORT = 5698 UID = "1" -SECRETDATABYTES = b'\xf5\xc9\x58\x81\xf3\x27\x06\x70\x33\x71\x66\x68\x49\xf9\xfd\x0e\x08\xcd\x8d\xc9\x37\x61\x28\xfb\xde\xfe\xfd\x49\x6f\x3c\x7c\x90\xf8\xfb\x5e\x4c\x5d\x4d\x04\x75\x62\xd4\xda\xc6\x36\x09\xd3\x63\xd7\x70\x31\x7c\x0b\x39\x2e\x9d\x46\x89\xa1\x25\x72\x46\x17\x76\x2f\xf3\xf0\xc8\x81\xe8\x94\x6a\xca\x32\x0f\x70\x14\x38\x90\x33\xd2\x33\x43\xd4\x07\x65\xee\x3c\x29\x8f\x26\xa3\x33\x2e\xc0\x15' +SECRETDATABYTES = b"\xf5\xc9\x58\x81\xf3\x27\x06\x70\x33\x71\x66\x68\x49\xf9\xfd\x0e\x08\xcd\x8d\xc9\x37\x61\x28\xfb\xde\xfe\xfd\x49\x6f\x3c\x7c\x90\xf8\xfb\x5e\x4c\x5d\x4d\x04\x75\x62\xd4\xda\xc6\x36\x09\xd3\x63\xd7\x70\x31\x7c\x0b\x39\x2e\x9d\x46\x89\xa1\x25\x72\x46\x17\x76\x2f\xf3\xf0\xc8\x81\xe8\x94\x6a\xca\x32\x0f\x70\x14\x38\x90\x33\xd2\x33\x43\xd4\x07\x65\xee\x3c\x29\x8f\x26\xa3\x33\x2e\xc0\x15" # Regenerate a SecretData. # The UID is chosen by the server. @@ -21,7 +21,8 @@ def regen(client): secretdata = kmip.pie.objects.SecretData( - SECRETDATABYTES, kmip.core.enums.SecretDataType.PASSWORD) + SECRETDATABYTES, kmip.core.enums.SecretDataType.PASSWORD + ) uid = client.register(secretdata) print(f"Created SecretData with UID={uid}") client.activate(uid) @@ -35,12 +36,13 @@ def main(): client = kmip.pie.client.ProxyKmipClient( hostname=HOSTNAME, port=PORT, - cert=os.path.join(drivers_evergreen_tools, - ".evergreen", "x509gen", "client.pem"), - ca=os.path.join(drivers_evergreen_tools, - ".evergreen", "x509gen", "ca.pem"), + cert=os.path.join( + drivers_evergreen_tools, ".evergreen", "x509gen", "client.pem" + ), + ca=os.path.join(drivers_evergreen_tools, ".evergreen", "x509gen", "ca.pem"), config_file=os.path.join( - drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.conf") + drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.conf" + ), ) with client: try: diff --git a/.evergreen/csfle/kms_kmip_server.py b/.evergreen/csfle/kms_kmip_server.py index 1b995faa..d94d3985 100644 --- a/.evergreen/csfle/kms_kmip_server.py +++ b/.evergreen/csfle/kms_kmip_server.py @@ -17,25 +17,31 @@ def main(): dir_path = os.path.dirname(os.path.realpath(__file__)) drivers_evergreen_tools = os.path.join(dir_path, "..", "..") default_ca_file = os.path.join( - drivers_evergreen_tools, ".evergreen", "x509gen", "ca.pem") + drivers_evergreen_tools, ".evergreen", "x509gen", "ca.pem" + ) default_cert_file = os.path.join( - drivers_evergreen_tools, ".evergreen", "x509gen", "server.pem") - - parser = argparse.ArgumentParser( - description='MongoDB Mock KMIP KMS Endpoint.') - parser.add_argument('-p', '--port', type=int, - default=PORT, help="Port to listen on") - parser.add_argument('--ca_file', type=str, - default=default_ca_file, help="TLS CA PEM file") - parser.add_argument('--cert_file', type=str, - default=default_cert_file, help="TLS Server PEM file") + drivers_evergreen_tools, ".evergreen", "x509gen", "server.pem" + ) + + parser = argparse.ArgumentParser(description="MongoDB Mock KMIP KMS Endpoint.") + parser.add_argument( + "-p", "--port", type=int, default=PORT, help="Port to listen on" + ) + parser.add_argument( + "--ca_file", type=str, default=default_ca_file, help="TLS CA PEM file" + ) + parser.add_argument( + "--cert_file", type=str, default=default_cert_file, help="TLS Server PEM file" + ) args = parser.parse_args() # Ensure we start with a fresh seed database. 
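    # (pykmip.db is rebuilt from pykmip.db.bak below, so keys registered or
    # activated by a previous run never leak into this one.)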
database_path = os.path.join( - drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.db") + drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.db" + ) database_seed_path = os.path.join( - drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.db.bak") + drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.db.bak" + ) shutil.copy(database_seed_path, database_path) server = KmipServer( @@ -45,8 +51,9 @@ def main(): ca_path=args.ca_file, config_path=None, auth_suite="TLS1.2", - log_path=os.path.join(drivers_evergreen_tools, - ".evergreen", "csfle", "pykmip.log"), + log_path=os.path.join( + drivers_evergreen_tools, ".evergreen", "csfle", "pykmip.log" + ), database_path=database_path, logging_level=logging.DEBUG, ) diff --git a/.evergreen/csfle/setup_secrets.py b/.evergreen/csfle/setup_secrets.py index c9e8cd2f..8548324e 100644 --- a/.evergreen/csfle/setup_secrets.py +++ b/.evergreen/csfle/setup_secrets.py @@ -1,27 +1,40 @@ """ Set up encryption secrets. """ + import os import boto3 -os.environ['AWS_ACCESS_KEY_ID']=os.environ['FLE_AWS_KEY'] -os.environ['AWS_SECRET_ACCESS_KEY']=os.environ['FLE_AWS_SECRET'] -os.environ['AWS_DEFAULT_REGION']="us-east-1" -os.environ['AWS_SESSION_TOKEN']="" +os.environ["AWS_ACCESS_KEY_ID"] = os.environ["FLE_AWS_KEY"] +os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ["FLE_AWS_SECRET"] +os.environ["AWS_DEFAULT_REGION"] = "us-east-1" +os.environ["AWS_SESSION_TOKEN"] = "" print("Getting CSFLE temp creds") -client = boto3.client('sts') +client = boto3.client("sts") credentials = client.get_session_token()["Credentials"] -with open('secrets-export.sh', 'ab') as fid: - fid.write(f'\nexport CSFLE_AWS_TEMP_ACCESS_KEY_ID="{credentials["AccessKeyId"]}"'.encode()) - fid.write(f'\nexport CSFLE_AWS_TEMP_SECRET_ACCESS_KEY="{credentials["SecretAccessKey"]}"'.encode()) - fid.write(f'\nexport CSFLE_AWS_TEMP_SESSION_TOKEN="{credentials["SessionToken"]}"'.encode()) - for key in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_DEFAULT_REGION', - 'AWS_SESSION_TOKEN', 'CSFLE_TLS_CA_FILE', 'CSFLE_TLS_CERT_FILE', - 'CSFLE_TLS_CLIENT_CERT_FILE']: +with open("secrets-export.sh", "ab") as fid: + fid.write( + f'\nexport CSFLE_AWS_TEMP_ACCESS_KEY_ID="{credentials["AccessKeyId"]}"'.encode() + ) + fid.write( + f'\nexport CSFLE_AWS_TEMP_SECRET_ACCESS_KEY="{credentials["SecretAccessKey"]}"'.encode() + ) + fid.write( + f'\nexport CSFLE_AWS_TEMP_SESSION_TOKEN="{credentials["SessionToken"]}"'.encode() + ) + for key in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_DEFAULT_REGION", + "AWS_SESSION_TOKEN", + "CSFLE_TLS_CA_FILE", + "CSFLE_TLS_CERT_FILE", + "CSFLE_TLS_CLIENT_CERT_FILE", + ]: fid.write(f'\nexport {key}="{os.environ[key]}"'.encode()) - fid.write(b'\n') + fid.write(b"\n") print("Getting CSFLE temp creds...done") diff --git a/.evergreen/docker/overwrite_orchestration.py b/.evergreen/docker/overwrite_orchestration.py index d79176c6..9d639892 100644 --- a/.evergreen/docker/overwrite_orchestration.py +++ b/.evergreen/docker/overwrite_orchestration.py @@ -1,44 +1,46 @@ import json import os -orch_file = os.environ['ORCHESTRATION_FILE'] +orch_file = os.environ["ORCHESTRATION_FILE"] with open(orch_file) as fid: data = json.load(fid) items = [] + # Gather all the items that have process settings. 
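# traverse() walks the config depth-first: lists are recursed element by
# element, a dict is recorded as a process node as soon as it has an "ipv6"
# key, and "routers" subtrees are skipped here because they are patched
# separately below.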
def traverse(root):
     if isinstance(root, list):
         [traverse(i) for i in root]
         return
-    if 'ipv6' in root:
+    if "ipv6" in root:
         items.append(root)
         return
     for key, value in root.items():
-        if key == 'routers':
+        if key == "routers":
             continue
         if isinstance(value, (dict, list)):
             traverse(value)
 
+
 traverse(data)
 
 # Docker does not enable ipv6 by default.
 # https://docs.docker.com/config/daemon/ipv6/
 # We also need to use 0.0.0.0 instead of 127.0.0.1
 for item in items:
-    item['ipv6'] = False
-    item['bind_ip'] = '0.0.0.0,::1'
-    item['dbpath'] = f"/tmp/mongo-{item['port']}"
+    item["ipv6"] = False
+    item["bind_ip"] = "0.0.0.0,::1"
+    item["dbpath"] = f"/tmp/mongo-{item['port']}"
 
-if 'routers' in data:
-    for router in data['routers']:
-        router['ipv6'] = False
-        router['bind_ip'] = '0.0.0.0,::1'
-        router['logpath'] = f"/tmp/mongodb-{item['port']}.log"
+if "routers" in data:
+    for router in data["routers"]:
+        router["ipv6"] = False
+        router["bind_ip"] = "0.0.0.0,::1"
+        router["logpath"] = f"/tmp/mongodb-{router['port']}.log"  # use this router's port
 
 print(json.dumps(data, indent=2))
 
-with open(orch_file, 'w') as fid:
+with open(orch_file, "w") as fid:
     json.dump(data, fid)
diff --git a/.evergreen/generate_task_config.py b/.evergreen/generate_task_config.py
index 8c45cec6..bd6198c3 100644
--- a/.evergreen/generate_task_config.py
+++ b/.evergreen/generate_task_config.py
@@ -10,21 +10,20 @@
       TOPOLOGY: "{mo_topology}"
     - func: "run tests"'''
 
-MONGODB_VERSIONS = ['2.4', '2.6', '3.0', '3.2', '3.4', 'latest']
-TOPOLOGY_OPTIONS = ['standalone', 'replica_set', 'sharded_cluster']
+MONGODB_VERSIONS = ["2.4", "2.6", "3.0", "3.2", "3.4", "latest"]
+TOPOLOGY_OPTIONS = ["standalone", "replica_set", "sharded_cluster"]
 
 
 def create_task(version, topology):
-    mo_topology= topology
+    mo_topology = topology
     # mongo-orchestration uses 'server' as the name for 'standalone'
-    if mo_topology == 'standalone':
-        mo_topology = 'server'
+    if mo_topology == "standalone":
+        mo_topology = "server"
     return TASK_TEMPLATE.format(**locals())
 
 
 tasks = []
-for version, topology in itertools.product(MONGODB_VERSIONS,
-                                           TOPOLOGY_OPTIONS):
+for version, topology in itertools.product(MONGODB_VERSIONS, TOPOLOGY_OPTIONS):
     tasks.append(create_task(version, topology))
 
-print('\n'.join(tasks))
+print("\n".join(tasks))
diff --git a/.evergreen/mongodl.py b/.evergreen/mongodl.py
index 42743a39..f5c9cd7a 100755
--- a/.evergreen/mongodl.py
+++ b/.evergreen/mongodl.py
@@ -13,6 +13,7 @@
 
 Use '--help' for more information.
 """
+
 import argparse
 import enum
 import hashlib
@@ -44,138 +45,137 @@
 )
 
 # These versions are used for performance benchmarking. Do not update to a newer version.
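# (Each perf branch name is pinned to one exact server version so benchmark
# results stay comparable across CI runs.)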
-PERF_VERSIONS = { - "v6.0-perf": "6.0.6", - "v8.0-perf": "8.0.1" -} +PERF_VERSIONS = {"v6.0-perf": "6.0.6", "v8.0-perf": "8.0.1"} #: Map common distribution names to the distribution named used in the MongoDB download list DISTRO_ID_MAP = { - 'elementary': 'ubuntu', - 'fedora': 'rhel', - 'centos': 'rhel', - 'mint': 'ubuntu', - 'linuxmint': 'ubuntu', - 'opensuse-leap': 'sles', - 'opensuse': 'sles', - 'pop': 'ubuntu', - 'redhat': 'rhel', - 'rocky': 'rhel', + "elementary": "ubuntu", + "fedora": "rhel", + "centos": "rhel", + "mint": "ubuntu", + "linuxmint": "ubuntu", + "opensuse-leap": "sles", + "opensuse": "sles", + "pop": "ubuntu", + "redhat": "rhel", + "rocky": "rhel", } #: Map derived distro versions to their base distribution versions DISTRO_VERSION_MAP = { - 'elementary': { - '6': '20.04', - '6.*': '20.04', + "elementary": { + "6": "20.04", + "6.*": "20.04", + }, + "fedora": { + "32": "8", + "33": "8", + "34": "8", + "35": "8", + "36": "8", }, - 'fedora': { - '32': '8', - '33': '8', - '34': '8', - '35': '8', - '36': '8', + "linuxmint": { + "19": "18.04", + "19.*": "18.04", + "20": "20.04", + "20.*": "20.04", + "21": "22.04", + "21.*": "22.04", }, - 'linuxmint': { - '19': '18.04', - '19.*': '18.04', - '20': '20.04', - '20.*': '20.04', - '21': '22.04', - '21.*': '22.04', + "pop": { + "20.04": "20.04", + "22.04": "22.04", }, - 'pop': { - '20.04': '20.04', - '22.04': '22.04', - } } #: Map distribution IDs with version fnmatch() patterns to download platform targets DISTRO_ID_TO_TARGET = { - 'ubuntu': { - '24.*': 'ubuntu2404', - '22.*': 'ubuntu2204', - '20.*': 'ubuntu2004', - '18.*': 'ubuntu1804', - '16.*': 'ubuntu1604', - '14.*': 'ubuntu1404', + "ubuntu": { + "24.*": "ubuntu2404", + "22.*": "ubuntu2204", + "20.*": "ubuntu2004", + "18.*": "ubuntu1804", + "16.*": "ubuntu1604", + "14.*": "ubuntu1404", }, - 'debian': { - '9': 'debian92', - '10': 'debian10', - '11': 'debian11', - '12': 'debian12', + "debian": { + "9": "debian92", + "10": "debian10", + "11": "debian11", + "12": "debian12", }, - 'rhel': { - '6': 'rhel6', - '6.*': 'rhel6', - '7': 'rhel7', - '7.*': 'rhel7', - '8': 'rhel8', - '8.*': 'rhel8', - '9': 'rhel9', - '9.*': 'rhel9', + "rhel": { + "6": "rhel6", + "6.*": "rhel6", + "7": "rhel7", + "7.*": "rhel7", + "8": "rhel8", + "8.*": "rhel8", + "9": "rhel9", + "9.*": "rhel9", }, - 'sles': { - '10.*': 'suse10', - '11.*': 'suse11', - '12.*': 'suse12', - '13.*': 'suse13', - '15.*': 'suse15', + "sles": { + "10.*": "suse10", + "11.*": "suse11", + "12.*": "suse12", + "13.*": "suse13", + "15.*": "suse15", }, - 'amzn': { - '2023': 'amazon2023', - '2018.*': 'amzn64', - '2': 'amazon2', + "amzn": { + "2023": "amazon2023", + "2018.*": "amzn64", + "2": "amazon2", }, } # The list of valid targets that are not related to a specific Linux distro. -TARGETS_THAT_ARE_NOT_DISTROS = ['linux_i686', 'linux_x86_64', 'osx', 'macos', 'windows'] +TARGETS_THAT_ARE_NOT_DISTROS = ["linux_i686", "linux_x86_64", "osx", "macos", "windows"] def infer_target(version: Optional[str] = None) -> str: """ Infer the download target of the current host system. """ - if sys.platform == 'win32': - return 'windows' - if sys.platform == 'darwin': + if sys.platform == "win32": + return "windows" + if sys.platform == "darwin": # Older versions of the server used 'osx' as the target. 
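        # e.g. "3.6.23" and "4.0.28" download as 'osx'; "4.2" and newer
        # (or no version given) use 'macos'.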
if version is not None: if version.startswith("4.0") or version[0] == "3": - return 'osx' - return 'macos' + return "osx" + return "macos" # Now the tricky bit - cands = (Path(p) for p in ['/etc/os-release', '/usr/lib/os-release']) + cands = (Path(p) for p in ["/etc/os-release", "/usr/lib/os-release"]) existing = (p for p in cands if p.is_file()) found = next(iter(existing), None) if found: return infer_target_from_os_release(found) - raise RuntimeError("We don't know how to find the default '--target'" - " option for this system. Please contribute!") + raise RuntimeError( + "We don't know how to find the default '--target'" + " option for this system. Please contribute!" + ) def infer_target_from_os_release(osr: Path) -> str: """ Infer the download target based on the content of os-release """ - with osr.open('r', encoding='utf-8') as f: + with osr.open("r", encoding="utf-8") as f: os_rel = f.read() # Extract the "ID" field id_re = re.compile(r'\bID=("?)(.*)\1') mat = id_re.search(os_rel) - assert mat, f'Unable to detect ID from [{osr}] content:\n{os_rel}' + assert mat, f"Unable to detect ID from [{osr}] content:\n{os_rel}" os_id = mat.group(2) - if os_id == 'arch': + if os_id == "arch": # There are no Archlinux-specific MongoDB downloads, so we'll just use # the build for RHEL8, which is reasonably compatible with other modern # distributions (including Arch). - return 'rhel80' + return "rhel80" # Extract the "VERSION_ID" field ver_id_re = re.compile(r'VERSION_ID=("?)(.*)\1') mat = ver_id_re.search(os_rel) - assert mat, f'Unable to detect VERSION_ID from [{osr}] content:\n{os_rel}' + assert mat, f"Unable to detect VERSION_ID from [{osr}] content:\n{os_rel}" ver_id = mat.group(2) # Map the ID to the download ID mapped_id = DISTRO_ID_MAP.get(os_id) @@ -183,21 +183,24 @@ def infer_target_from_os_release(osr: Path) -> str: # Map the distro version to its upstream version ver_mapper = DISTRO_VERSION_MAP.get(os_id, {}) # Find the version based on a fnmatch pattern: - matching = (ver for pat, ver in ver_mapper.items() - if fnmatch(ver_id, pat)) + matching = (ver for pat, ver in ver_mapper.items() if fnmatch(ver_id, pat)) # The default is to keep the version ID. mapped_version = next(iter(matching), None) if mapped_version is None: # If this raises, a version/pattern needs to be added # to DISTRO_VERSION_MAP - raise RuntimeError(f"We don't know how to map {os_id} version '{ver_id}' " - f"to an upstream {mapped_id} version. Please contribute!") + raise RuntimeError( + f"We don't know how to map {os_id} version '{ver_id}' " + f"to an upstream {mapped_id} version. Please contribute!" + ) ver_id = mapped_version os_id = mapped_id os_id = os_id.lower() if os_id not in DISTRO_ID_TO_TARGET: - raise RuntimeError(f"We don't know how to map '{os_id}' to a distribution " - "download target. Please contribute!") + raise RuntimeError( + f"We don't know how to map '{os_id}' to a distribution " + "download target. Please contribute!" + ) # Find the download target based on a filename-style pattern: ver_table = DISTRO_ID_TO_TARGET[os_id] for pattern, target in ver_table.items(): @@ -205,60 +208,66 @@ def infer_target_from_os_release(osr: Path) -> str: return target raise RuntimeError( f"We don't know how to map '{os_id}' version '{ver_id}' to a distribution " - "download target. Please contribute!") + "download target. Please contribute!" 
+ ) def user_caches_root() -> Path: """ Obtain the directory for user-local caches """ - if sys.platform == 'win32': - return Path(os.environ['LocalAppData']) - if sys.platform == 'darwin': - return Path(os.environ['HOME'] + '/Library/Caches') - xdg_cache = os.getenv('XDG_CACHE_HOME') + if sys.platform == "win32": + return Path(os.environ["LocalAppData"]) + if sys.platform == "darwin": + return Path(os.environ["HOME"] + "/Library/Caches") + xdg_cache = os.getenv("XDG_CACHE_HOME") if xdg_cache: return Path(xdg_cache) - return Path(os.environ['HOME'] + '/.cache') + return Path(os.environ["HOME"] + "/.cache") def default_cache_dir() -> Path: """ Get the path to the default directory of mongodl caches. """ - return user_caches_root().joinpath('mongodl').absolute() + return user_caches_root().joinpath("mongodl").absolute() if TYPE_CHECKING: - DownloadResult = NamedTuple('DownloadResult', [('is_changed', bool), - ('path', Path)]) - DownloadableComponent = NamedTuple('DownloadableComponent', [ - ('version', str), - ('target', str), - ('arch', str), - ('edition', str), - ('key', str), - ('data_json', str), - ]) + DownloadResult = NamedTuple( + "DownloadResult", [("is_changed", bool), ("path", Path)] + ) + DownloadableComponent = NamedTuple( + "DownloadableComponent", + [ + ("version", str), + ("target", str), + ("arch", str), + ("edition", str), + ("key", str), + ("data_json", str), + ], + ) else: - DownloadResult = namedtuple('DownloadResult', ['is_changed', 'path']) + DownloadResult = namedtuple("DownloadResult", ["is_changed", "path"]) DownloadableComponent = namedtuple( - 'DownloadableComponent', - ['version', 'target', 'arch', 'edition', 'key', 'data_json']) + "DownloadableComponent", + ["version", "target", "arch", "edition", "key", "data_json"], + ) #: Regular expression that matches the version numbers from 'full.json' -VERSION_RE = re.compile(r'(\d+)\.(\d+)\.(\d+)(?:-([a-z]+)(\d+))?') -MAJOR_VERSION_RE = re.compile(r'(\d+)\.(\d+)$') +VERSION_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)(?:-([a-z]+)(\d+))?") +MAJOR_VERSION_RE = re.compile(r"(\d+)\.(\d+)$") STABLE_MAX_RC = 9999 -def version_tup(version: str) -> 'tuple[int, int, int, int, int]': +def version_tup(version: str) -> "tuple[int, int, int, int, int]": if MAJOR_VERSION_RE.match(version): - maj, min = version.split('.') + maj, min = version.split(".") return tuple([int(maj), int(min), 0, 0, 0]) mat = VERSION_RE.match(version) - assert mat, (f'Failed to parse "{version}" as a version number') + assert mat, f'Failed to parse "{version}" as a version number' major, minor, patch, tag, tagnum = list(mat.groups()) if tag is None: # No rc tag is greater than an equal base version with any rc tag @@ -266,9 +275,9 @@ def version_tup(version: str) -> 'tuple[int, int, int, int, int]': tagnum = 0 else: tag = { - 'alpha': 1, - 'beta': 2, - 'rc': 3, + "alpha": 1, + "beta": 2, + "rc": 3, }[tag] return tuple(map(int, (major, minor, patch, tag, tagnum))) @@ -304,32 +313,32 @@ def __init__(self, db: sqlite3.Connection) -> None: self._cursor = self._db.cursor() @staticmethod - def open(fpath: Path) -> 'CacheDB': + def open(fpath: Path) -> "CacheDB": """ Open a caching database at the given filepath. 
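        Connects with isolation_level=None (autocommit), creates the
        mdl_http_downloads table if needed, and registers the mdb_version
        collation plus the helper SQL functions used by the queries below.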
""" db = sqlite3.connect(str(fpath), isolation_level=None) - db.execute(r''' + db.execute(r""" CREATE TABLE IF NOT EXISTS mdl_http_downloads ( url TEXT NOT NULL UNIQUE, etag TEXT, last_modified TEXT - )''') - db.create_collation('mdb_version', collate_mdb_version) - db.create_function('mdb_version_not_rc', 1, mdb_version_not_rc) - db.create_function('mdb_version_rapid', 1, mdb_version_rapid) + )""") + db.create_collation("mdb_version", collate_mdb_version) + db.create_function("mdb_version_not_rc", 1, mdb_version_not_rc) + db.create_function("mdb_version_rapid", 1, mdb_version_rapid) return CacheDB(db) def __call__( - self, query: str, **params: 'str | int | bool | float | None' - ) -> 'Iterable[sqlite3.Row]': + self, query: str, **params: "str | int | bool | float | None" + ) -> "Iterable[sqlite3.Row]": """ Execute a query with the given named parameters. """ return self._cursor.execute(query, params) @contextmanager - def transaction(self) -> 'Iterator[None]': + def transaction(self) -> "Iterator[None]": """ Create a context for a database transaction. """ @@ -339,39 +348,39 @@ def transaction(self) -> 'Iterator[None]': with self._db: # Must do an explicit BEGIN because isolation_level=None - self('BEGIN') + self("BEGIN") yield def import_json_file(self, json_file: Path) -> None: """ Import the given downloads content from the given JSON file """ - with json_file.open('r', encoding='utf-8') as f: + with json_file.open("r", encoding="utf-8") as f: data = json.load(f) self.import_json_data(data) - def import_json_data(self, data: 'Any') -> None: + def import_json_data(self, data: "Any") -> None: """ Import the given downloads content from the given JSON-like data """ with self.transaction(): self._import_json_data(data) - def _import_json_data(self, data: 'Any') -> None: + def _import_json_data(self, data: "Any") -> None: # We're reloading everything, so just drop and re-create the tables. 
# Bonus: We don't have to worry about schema changes - self('DROP TABLE IF EXISTS mdl_components') - self('DROP TABLE IF EXISTS mdl_downloads') - self('DROP TABLE IF EXISTS mdl_versions') - self(r''' + self("DROP TABLE IF EXISTS mdl_components") + self("DROP TABLE IF EXISTS mdl_downloads") + self("DROP TABLE IF EXISTS mdl_versions") + self(r""" CREATE TABLE mdl_versions ( version_id INTEGER PRIMARY KEY, date TEXT NOT NULL, version TEXT NOT NULL, githash TEXT NOT NULL ) - ''') - self(r''' + """) + self(r""" CREATE TABLE mdl_downloads ( download_id INTEGER PRIMARY KEY, version_id INTEGER NOT NULL REFERENCES mdl_versions, @@ -382,8 +391,8 @@ def _import_json_data(self, data: 'Any') -> None: ar_debug_url TEXT, data TEXT NOT NULL ) - ''') - self(r''' + """) + self(r""" CREATE TABLE mdl_components ( component_id INTEGER PRIMARY KEY, key TEXT NOT NULL, @@ -391,28 +400,28 @@ def _import_json_data(self, data: 'Any') -> None: data NOT NULL, UNIQUE(key, download_id) ) - ''') + """) - for ver in data['versions']: - version = ver['version'] - githash = ver['githash'] - date = ver['date'] + for ver in data["versions"]: + version = ver["version"] + githash = ver["githash"] + date = ver["date"] self( - r''' + r""" INSERT INTO mdl_versions (date, version, githash) VALUES (:date, :version, :githash) - ''', + """, date=date, version=version, githash=githash, ) version_id = self._cursor.lastrowid missing = set() - for dl in ver['downloads']: - arch = dl.get('arch', 'null') - target = dl.get('target', 'null') + for dl in ver["downloads"]: + arch = dl.get("arch", "null") + target = dl.get("target", "null") # Normalize RHEL target names to include just the major version. - if target.startswith('rhel') and len(target) == 6: + if target.startswith("rhel") and len(target) == 6: target = target[:-1] found = False for distro in DISTRO_ID_TO_TARGET.values(): @@ -420,11 +429,11 @@ def _import_json_data(self, data: 'Any') -> None: found = True if not found and target not in TARGETS_THAT_ARE_NOT_DISTROS: missing.add(target) - edition = dl['edition'] - ar_url = dl['archive']['url'] - ar_debug_url = dl['archive'].get('debug_symbols') + edition = dl["edition"] + ar_url = dl["archive"]["url"] + ar_debug_url = dl["archive"].get("debug_symbols") self( - r''' + r""" INSERT INTO mdl_downloads (version_id, target, arch, @@ -439,7 +448,7 @@ def _import_json_data(self, data: 'Any') -> None: :ar_url, :ar_debug_url, :data) - ''', + """, version_id=version_id, target=target, arch=arch, @@ -450,14 +459,14 @@ def _import_json_data(self, data: 'Any') -> None: ) dl_id = self._cursor.lastrowid for key, data in dl.items(): - if 'url' not in data: + if "url" not in data: # Some fields aren't downloadable items. Skip them continue self( - r''' + r""" INSERT INTO mdl_components (key, download_id, data) VALUES (:key, :dl_id, :data) - ''', + """, key=key, dl_id=dl_id, data=json.dumps(data), @@ -470,20 +479,20 @@ def _import_json_data(self, data: 'Any') -> None: sys.exit(1) def iter_available( - self, - *, - version: 'str | None' = None, - target: 'str | None' = None, - arch: 'str | None' = None, - edition: 'str | None' = None, - component: 'str | None' = None - ) -> 'Iterable[DownloadableComponent]': + self, + *, + version: "str | None" = None, + target: "str | None" = None, + arch: "str | None" = None, + edition: "str | None" = None, + component: "str | None" = None, + ) -> "Iterable[DownloadableComponent]": """ Iterate over the matching downloadable components according to the given attribute filters. 
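        All filters are optional keyword arguments; a major.minor version such
        as "6.0" also matches any "6.0.x" via the LIKE pattern in the query.
        Illustrative usage (names as defined in this module):

            for dc in db.iter_available(version="6.0", target="ubuntu2204"):
                print(dc.version, dc.key)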
""" rows = self( - r''' + r""" SELECT version, target, arch, edition, key, mdl_components.data FROM mdl_components, mdl_downloads USING(download_id), @@ -505,9 +514,9 @@ def iter_available( ELSE version=:version OR version LIKE :version_pattern END) ORDER BY version COLLATE mdb_version DESC - ''', + """, version=version, - version_pattern=f'{version}.%', + version_pattern=f"{version}.%", target=target, arch=arch, edition=edition, @@ -527,19 +536,19 @@ def __init__(self, dirpath: Path, db: CacheDB) -> None: self._db = db @staticmethod - def open_default() -> 'Cache': + def open_default() -> "Cache": """ Open the default user-local cache directory """ return Cache.open_in(default_cache_dir()) @staticmethod - def open_in(dirpath: Path) -> 'Cache': + def open_in(dirpath: Path) -> "Cache": """ Open or create a cache directory at the given path. """ _mkdir(dirpath) - db = CacheDB.open(dirpath / 'data.db') + db = CacheDB.open(dirpath / "data.db") return Cache(dirpath, db) @property @@ -552,19 +561,19 @@ def download_file(self, url: str) -> DownloadResult: Obtain a local copy of the file at the given URL. """ info = self._db( - 'SELECT etag, last_modified ' - 'FROM mdl_http_downloads WHERE url=:url', - url=url) + "SELECT etag, last_modified " "FROM mdl_http_downloads WHERE url=:url", + url=url, + ) etag = None # type: str|None modtime = None # type: str|None etag, modtime = next(iter(info), (None, None)) # type: ignore headers = {} # type: dict[str, str] if etag: - headers['If-None-Match'] = etag + headers["If-None-Match"] = etag if modtime: - headers['If-Modified-Since'] = modtime + headers["If-Modified-Since"] = modtime digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:4] - dest = self._dirpath / 'files' / digest / PurePosixPath(url).name + dest = self._dirpath / "files" / digest / PurePosixPath(url).name if not dest.exists(): headers = {} req = urllib.request.Request(url, headers=headers) @@ -573,26 +582,28 @@ def download_file(self, url: str) -> DownloadResult: resp = urllib.request.urlopen(req) except urllib.error.HTTPError as e: if e.code != 304: - raise RuntimeError( - f'Failed to download [{url}]') from e + raise RuntimeError(f"Failed to download [{url}]") from e assert dest.is_file(), ( - 'The download cache is missing an expected file', dest) + "The download cache is missing an expected file", + dest, + ) return DownloadResult(False, dest) _mkdir(dest.parent) got_etag = resp.getheader("ETag") - got_modtime = resp.getheader('Last-Modified') - with dest.open('wb') as of: + got_modtime = resp.getheader("Last-Modified") + with dest.open("wb") as of: buf = resp.read(1024 * 1024 * 4) while buf: of.write(buf) buf = resp.read(1024 * 1024 * 4) self._db( - 'INSERT OR REPLACE INTO mdl_http_downloads (url, etag, last_modified) ' - 'VALUES (:url, :etag, :mtime)', + "INSERT OR REPLACE INTO mdl_http_downloads (url, etag, last_modified) " + "VALUES (:url, :etag, :mtime)", url=url, etag=got_etag, - mtime=got_modtime) + mtime=got_modtime, + ) return DownloadResult(True, dest) def refresh_full_json(self) -> None: @@ -600,7 +611,7 @@ def refresh_full_json(self) -> None: Sync the content of the MongoDB full.json downloads list. 
""" with self._db.transaction(): - dl = self.download_file('https://downloads.mongodb.org/full.json') + dl = self.download_file("https://downloads.mongodb.org/full.json") if not dl.is_changed: # We still have a good cache return @@ -625,37 +636,45 @@ def _mkdir(dirpath: Path) -> None: pass -def _print_list(db: CacheDB, version: 'str | None', target: 'str | None', - arch: 'str | None', edition: 'str | None', - component: 'str | None'): - +def _print_list( + db: CacheDB, + version: "str | None", + target: "str | None", + arch: "str | None", + edition: "str | None", + component: "str | None", +): if version or target or arch or edition or component: counter = 0 - matching = db.iter_available(version=version, - target=target, - arch=arch, - edition=edition, - component=component) + matching = db.iter_available( + version=version, + target=target, + arch=arch, + edition=edition, + component=component, + ) for version, target, arch, edition, comp_key, comp_data in matching: counter += 1 - print(f'Download: {comp_key}\n' - f' Version: {version}\n' - f' Target: {target}\n' - f' Arch: {arch}\n' - f' Edition: {edition}\n' - f' Info: {comp_data}\n\n') + print( + f"Download: {comp_key}\n" + f" Version: {version}\n" + f" Target: {target}\n" + f" Arch: {arch}\n" + f" Edition: {edition}\n" + f" Info: {comp_data}\n\n" + ) if counter == 1: - print('Only one matching item') + print("Only one matching item") elif counter == 0: - print('No items matched the listed filters') + print("No items matched the listed filters") else: - print(f'{counter} available downloadable components') - print('(Omit filter arguments for a list of available filters)') + print(f"{counter} available downloadable components") + print("(Omit filter arguments for a list of available filters)") return tup = next( iter( # type: ignore - db(r''' + db(r""" VALUES( (select group_concat(arch, ', ') from (select distinct arch from mdl_downloads)), (select group_concat(target, ', ') from (select distinct target from mdl_downloads)), @@ -665,39 +684,39 @@ def _print_list(db: CacheDB, version: 'str | None', target: 'str | None', ORDER BY version COLLATE mdb_version)), (select group_concat(key, ', ') from (select distinct key from mdl_components)) ) - '''))) # type: tuple[str, str, str, str, str] + """) + ) + ) # type: tuple[str, str, str, str, str] arches, targets, editions, versions, components = tup if "archive" in components: - components = components.split(', ') + components = components.split(", ") components.append("archive-debug") components = ", ".join(sorted(components)) - versions = '\n'.join( - textwrap.wrap(versions, - width=78, - initial_indent=' ', - subsequent_indent=' ')) - targets = '\n'.join( - textwrap.wrap(targets, - width=78, - initial_indent=' ', - subsequent_indent=' ')) - print('Architectures:\n' - f' {arches}\n' - 'Targets:\n' - f'{targets}\n' - 'Editions:\n' - f' {editions}\n' - 'Versions:\n' - f'{versions}\n' - 'Components:\n' - f' {components}\n') + versions = "\n".join( + textwrap.wrap(versions, width=78, initial_indent=" ", subsequent_indent=" ") + ) + targets = "\n".join( + textwrap.wrap(targets, width=78, initial_indent=" ", subsequent_indent=" ") + ) + print( + "Architectures:\n" + f" {arches}\n" + "Targets:\n" + f"{targets}\n" + "Editions:\n" + f" {editions}\n" + "Versions:\n" + f"{versions}\n" + "Components:\n" + f" {components}\n" + ) def infer_arch(): a = platform.machine() or platform.processor() # Remap platform names to the names used for downloads return { - 'AMD64': 'x86_64', + "AMD64": "x86_64", 
}.get(a, a) @@ -708,8 +727,9 @@ class ExpandResult(enum.Enum): "One or more files were/would be extracted" -def _published_build_url(cache: Cache, version: str, target: str, arch: str, - edition: str, component: str) -> str: +def _published_build_url( + cache: Cache, version: str, target: str, arch: str, edition: str, component: str +) -> str: """ Get the URL for a "published" build (that is: a build that was published in full.json) """ @@ -717,22 +737,22 @@ def _published_build_url(cache: Cache, version: str, target: str, arch: str, if component == "archive-debug": component = "archive" value = "debug_symbols" - matching = cache.db.iter_available(version=version, - target=target, - arch=arch, - edition=edition, - component=component) + matching = cache.db.iter_available( + version=version, target=target, arch=arch, edition=edition, component=component + ) tup = next(iter(matching), None) if tup is None: raise ValueError( - 'No download was found for ' - f'version="{version}" target="{target}" arch="{arch}" edition="{edition}" component="{component}"') + "No download was found for " + f'version="{version}" target="{target}" arch="{arch}" edition="{edition}" component="{component}"' + ) data = json.loads(tup.data_json) return data[value] -def _latest_build_url(target: str, arch: str, edition: str, component: str, - branch: 'str|None') -> str: +def _latest_build_url( + target: str, arch: str, edition: str, component: str, branch: "str|None" +) -> str: """ Get the URL for an "unpublished" "latest" build. @@ -742,67 +762,72 @@ def _latest_build_url(target: str, arch: str, edition: str, component: str, """ # Normalize the filename components based on the download target platform = { - 'windows': 'windows', - 'win32': 'win32', - 'macos': 'osx', - }.get(target, 'linux') + "windows": "windows", + "win32": "win32", + "macos": "osx", + }.get(target, "linux") typ = { - 'windows': 'windows', - 'win32': 'win32', - 'macos': 'macos', - }.get(target, 'linux') + "windows": "windows", + "win32": "win32", + "macos": "macos", + }.get(target, "linux") component_name = { - 'archive': 'mongodb', - 'crypt_shared': 'mongo_crypt_shared_v1', + "archive": "mongodb", + "crypt_shared": "mongo_crypt_shared_v1", }.get(component, component) - base = f'https://downloads.10gen.com/{platform}' + base = f"https://downloads.10gen.com/{platform}" # Windows has Zip files - ext = 'zip' if target == 'windows' else 'tgz' + ext = "zip" if target == "windows" else "tgz" # Enterprise builds have an "enterprise" infix - ent_infix = 'enterprise-' if edition == 'enterprise' else '' + ent_infix = "enterprise-" if edition == "enterprise" else "" # Some platforms have a filename infix - tgt_infix = ((target + '-') - if target not in ('windows', 'win32', 'macos') - else '') + tgt_infix = (target + "-") if target not in ("windows", "win32", "macos") else "" # Non-master branch uses a filename infix - br_infix = ((branch + '-') if - (branch is not None and branch != 'master') - else '') - filename = f'{component_name}-{typ}-{arch}-{ent_infix}{tgt_infix}{br_infix}latest.{ext}' - return f'{base}/{filename}' - - -def _dl_component(cache: Cache, out_dir: Path, version: str, target: str, - arch: str, edition: str, component: str, - pattern: 'str | None', strip_components: int, test: bool, - no_download: bool, - latest_build_branch: 'str|None') -> ExpandResult: - print(f'Download {component} {version}-{edition} for {target}-{arch}', file=sys.stderr) - if version == 'latest-build': - dl_url = _latest_build_url(target, arch, edition, component, - 
latest_build_branch) + br_infix = (branch + "-") if (branch is not None and branch != "master") else "" + filename = ( + f"{component_name}-{typ}-{arch}-{ent_infix}{tgt_infix}{br_infix}latest.{ext}" + ) + return f"{base}/{filename}" + + +def _dl_component( + cache: Cache, + out_dir: Path, + version: str, + target: str, + arch: str, + edition: str, + component: str, + pattern: "str | None", + strip_components: int, + test: bool, + no_download: bool, + latest_build_branch: "str|None", +) -> ExpandResult: + print( + f"Download {component} {version}-{edition} for {target}-{arch}", file=sys.stderr + ) + if version == "latest-build": + dl_url = _latest_build_url( + target, arch, edition, component, latest_build_branch + ) else: - dl_url = _published_build_url(cache, version, target, arch, edition, - component) + dl_url = _published_build_url(cache, version, target, arch, edition, component) if no_download: print(dl_url) return None cached = cache.download_file(dl_url).path - return _expand_archive(cached, - out_dir, - pattern, - strip_components, - test=test) + return _expand_archive(cached, out_dir, pattern, strip_components, test=test) -def _pathjoin(items: 'Iterable[str]') -> PurePath: +def _pathjoin(items: "Iterable[str]") -> PurePath: """ Return a path formed by joining the given path components """ - return PurePath('/'.join(items)) + return PurePath("/".join(items)) -def _test_pattern(path: PurePath, pattern: 'PurePath | None') -> bool: +def _test_pattern(path: PurePath, pattern: "PurePath | None") -> bool: """ Test whether the given 'path' string matches the globbing pattern 'pattern'. @@ -821,7 +846,7 @@ def _test_pattern(path: PurePath, pattern: 'PurePath | None') -> bool: return False pattern_head = pattern_parts[0] pattern_tail = _pathjoin(pattern_parts[1:]) - if pattern_head == '**': + if pattern_head == "**": # Special "**" pattern matches any suffix of the path # Generate each suffix: tails = (path_parts[i:] for i in range(len(path_parts))) @@ -834,54 +859,60 @@ def _test_pattern(path: PurePath, pattern: 'PurePath | None') -> bool: return _test_pattern(_pathjoin(path_parts[1:]), pattern_tail) -def _expand_archive(ar: Path, dest: Path, pattern: 'str | None', - strip_components: int, test: bool) -> ExpandResult: - ''' +def _expand_archive( + ar: Path, dest: Path, pattern: "str | None", strip_components: int, test: bool +) -> ExpandResult: + """ Expand the archive members from 'ar' into 'dest'. If 'pattern' is not-None, only extracts members that match the pattern. - ''' - print(f'Extract from: [{ar.name}]', file=sys.stderr) - print(f' into: [{dest}]', file=sys.stderr) - if ar.suffix == '.zip': - n_extracted = _expand_zip(ar, - dest, - pattern, - strip_components, - test=test) - elif ar.suffix == '.tgz': - n_extracted = _expand_tgz(ar, - dest, - pattern, - strip_components, - test=test) + """ + print(f"Extract from: [{ar.name}]", file=sys.stderr) + print(f" into: [{dest}]", file=sys.stderr) + if ar.suffix == ".zip": + n_extracted = _expand_zip(ar, dest, pattern, strip_components, test=test) + elif ar.suffix == ".tgz": + n_extracted = _expand_tgz(ar, dest, pattern, strip_components, test=test) else: - raise RuntimeError('Unknown archive file extension: ' + ar.suffix) - verb = 'would be' if test else 'were' + raise RuntimeError("Unknown archive file extension: " + ar.suffix) + verb = "would be" if test else "were" if n_extracted == 0: if pattern and strip_components: - print(f'NOTE: No files {verb} extracted. 
Likely all files {verb} ' - f'excluded by "--only={pattern}" and/or "--strip-components={strip_components}"', file=sys.stderr) + print( + f"NOTE: No files {verb} extracted. Likely all files {verb} " + f'excluded by "--only={pattern}" and/or "--strip-components={strip_components}"', + file=sys.stderr, + ) elif pattern: - print(f'NOTE: No files {verb} extracted. Likely all files {verb} ' - f'excluded by the "--only={pattern}" filter', file=sys.stderr) + print( + f"NOTE: No files {verb} extracted. Likely all files {verb} " + f'excluded by the "--only={pattern}" filter', + file=sys.stderr, + ) elif strip_components: - print(f'NOTE: No files {verb} extracted. Likely all files {verb} ' - f'excluded by "--strip-components={strip_components}"', file=sys.stderr) + print( + f"NOTE: No files {verb} extracted. Likely all files {verb} " + f'excluded by "--strip-components={strip_components}"', + file=sys.stderr, + ) else: - print(f'NOTE: No files {verb} extracted. Empty archive?', file=sys.stderr) + print(f"NOTE: No files {verb} extracted. Empty archive?", file=sys.stderr) return ExpandResult.Empty if n_extracted == 1: - print('One file {v} extracted'.format(v='would be' if test else 'was'), file=sys.stderr) + print( + "One file {v} extracted".format(v="would be" if test else "was"), + file=sys.stderr, + ) return ExpandResult.Okay - print(f'{n_extracted} files {verb} extracted', file=sys.stderr) + print(f"{n_extracted} files {verb} extracted", file=sys.stderr) return ExpandResult.Okay -def _expand_tgz(ar: Path, dest: Path, pattern: 'str | None', - strip_components: int, test: bool) -> int: - 'Expand a tar.gz archive' +def _expand_tgz( + ar: Path, dest: Path, pattern: "str | None", strip_components: int, test: bool +) -> int: + "Expand a tar.gz archive" n_extracted = 0 - with tarfile.open(str(ar), 'r:*') as tf: + with tarfile.open(str(ar), "r:*") as tf: for mem in tf.getmembers(): n_extracted += _maybe_extract_member( dest, @@ -889,54 +920,61 @@ def _expand_tgz(ar: Path, dest: Path, pattern: 'str | None', pattern, strip_components, mem.isdir(), - lambda: cast('IO[bytes]', tf.extractfile(mem)), # noqa: B023 + lambda: cast("IO[bytes]", tf.extractfile(mem)), # noqa: B023 mem.mode, test=test, ) return n_extracted -def _expand_zip(ar: Path, dest: Path, pattern: 'str | None', - strip_components: int, test: bool) -> int: - 'Expand a .zip archive.' +def _expand_zip( + ar: Path, dest: Path, pattern: "str | None", strip_components: int, test: bool +) -> int: + "Expand a .zip archive." n_extracted = 0 - with zipfile.ZipFile(str(ar), 'r') as zf: + with zipfile.ZipFile(str(ar), "r") as zf: for item in zf.infolist(): n_extracted += _maybe_extract_member( dest, PurePath(item.filename), pattern, strip_components, - item.filename.endswith('/'), ## Equivalent to: item.is_dir(), - lambda: zf.open(item, 'r'), # noqa: B023 + item.filename.endswith("/"), ## Equivalent to: item.is_dir(), + lambda: zf.open(item, "r"), # noqa: B023 0o655, test=test, ) return n_extracted -def _maybe_extract_member(out: Path, relpath: PurePath, pattern: 'str | None', - strip: int, is_dir: bool, - opener: 'Callable[[], IO[bytes]]', modebits: int, - test: bool) -> int: +def _maybe_extract_member( + out: Path, + relpath: PurePath, + pattern: "str | None", + strip: int, + is_dir: bool, + opener: "Callable[[], IO[bytes]]", + modebits: int, + test: bool, +) -> int: """ Try to extract an archive member according to the given arguments. :return: Zero if the file was excluded by filters, one otherwise. 
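For example, with strip=1 a member named 'bin/mongod' is written to
out/'mongod', while a pattern of '**/*.so' keeps only shared
libraries, at any depth in the archive.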
""" relpath = PurePath(relpath) - print(' | {:-<65} |'.format(str(relpath) + ' '), end='', file=sys.stderr) + print(" | {:-<65} |".format(str(relpath) + " "), end="", file=sys.stderr) if len(relpath.parts) <= strip: # Not enough path components - print(' (Excluded by --strip-components)', file=sys.stderr) + print(" (Excluded by --strip-components)", file=sys.stderr) return 0 if not _test_pattern(relpath, PurePath(pattern) if pattern else None): # Doesn't match our pattern - print(' (excluded by pattern)', file=sys.stderr) + print(" (excluded by pattern)", file=sys.stderr) return 0 stripped = _pathjoin(relpath.parts[strip:]) dest = Path(out) / stripped - print(f'\n -> [{dest}]', file=sys.stderr) + print(f"\n -> [{dest}]", file=sys.stderr) if test: # We are running in test-only mode: Do not do anything return 1 @@ -945,149 +983,164 @@ def _maybe_extract_member(out: Path, relpath: PurePath, pattern: 'str | None', return 1 with opener() as infile: _mkdir(dest.parent) - with dest.open('wb') as outfile: + with dest.open("wb") as outfile: shutil.copyfileobj(infile, outfile) os.chmod(str(dest), modebits) return 1 -def main(argv: 'Sequence[str]'): +def main(argv: "Sequence[str]"): parser = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) parser.add_argument( - '--cache-dir', + "--cache-dir", type=Path, default=default_cache_dir(), - help='Directory where download caches and metadata will be stored') - grp = parser.add_argument_group('List arguments') - grp.add_argument('--list', - action='store_true', - help='List available components, targets, editions, and ' - 'architectures. Download arguments will act as filters.') + help="Directory where download caches and metadata will be stored", + ) + grp = parser.add_argument_group("List arguments") + grp.add_argument( + "--list", + action="store_true", + help="List available components, targets, editions, and " + "architectures. Download arguments will act as filters.", + ) dl_grp = parser.add_argument_group( - 'Download arguments', - description='Select what to download and extract. ' - 'Non-required arguments will be inferred ' - 'based on the host system.') - dl_grp.add_argument('--target', - '-T', - help='The target platform for which to download. ' - 'Use "--list" to list available targets.') - dl_grp.add_argument('--arch', - '-A', - help='The architecture for which to download') + "Download arguments", + description="Select what to download and extract. " + "Non-required arguments will be inferred " + "based on the host system.", + ) dl_grp.add_argument( - '--edition', - '-E', + "--target", + "-T", + help="The target platform for which to download. " + 'Use "--list" to list available targets.', + ) + dl_grp.add_argument("--arch", "-A", help="The architecture for which to download") + dl_grp.add_argument( + "--edition", + "-E", help='The edition of the product to download (Default is "enterprise"). ' - 'Use "--list" to list available editions.') + 'Use "--list" to list available editions.', + ) dl_grp.add_argument( - '--out', - '-o', - help='The directory in which to download components. (Required)', - type=Path) + "--out", + "-o", + help="The directory in which to download components. (Required)", + type=Path, + ) dl_grp.add_argument( - '--version', - '-V', - help= - 'The product version to download (Required). Use "latest" to download ' - 'the newest available version (including release candidates). 
Use ' + "--version", + "-V", + help='The product version to download (Required). Use "latest" to download ' + "the newest available version (including release candidates). Use " '"latest-stable" to download the newest version, excluding release ' 'candidates. Use "rapid" to download the latest rapid release. ' ' Use "latest-build" to download the most recent build of ' - 'the named component. Use "--list" to list available versions.') - dl_grp.add_argument('--component', - '-C', - help='The component to download (Required). ' - 'Use "--list" to list available components.') + 'the named component. Use "--list" to list available versions.', + ) dl_grp.add_argument( - '--only', - help= - 'Restrict extraction to items that match the given globbing expression. ' + "--component", + "-C", + help="The component to download (Required). " + 'Use "--list" to list available components.', + ) + dl_grp.add_argument( + "--only", + help="Restrict extraction to items that match the given globbing expression. " 'The full archive member path is matched, so a pattern like "*.exe" ' 'will only match "*.exe" at the top level of the archive. To match ' 'recursively, use the "**" pattern to match any number of ' - 'intermediate directories.') + "intermediate directories.", + ) dl_grp.add_argument( - '--strip-path-components', - '-p', - dest='strip_components', - metavar='N', + "--strip-path-components", + "-p", + dest="strip_components", + metavar="N", default=0, type=int, - help= - 'Strip the given number of path components from archive members before ' - 'extracting into the destination. The relative path of the archive ' - 'member will be used to form the destination path. For example, a ' - 'member named [bin/mongod.exe] will be extracted to [/bin/mongod.exe]. ' - 'Using --strip-components=1 will remove the first path component, extracting ' - 'such an item to [/mongod.exe]. If the path has fewer than N components, ' - 'that archive member will be ignored.') + help="Strip the given number of path components from archive members before " + "extracting into the destination. The relative path of the archive " + "member will be used to form the destination path. For example, a " + "member named [bin/mongod.exe] will be extracted to [/bin/mongod.exe]. " + "Using --strip-components=1 will remove the first path component, extracting " + "such an item to [/mongod.exe]. If the path has fewer than N components, " + "that archive member will be ignored.", + ) + dl_grp.add_argument( + "--no-download", + action="store_true", + help="Do not download the file, only print its url.", + ) dl_grp.add_argument( - '--no-download', - action='store_true', - help='Do not download the file, only print its url.') + "--test", + action="store_true", + help="Do not extract or place any files/directories. " + "Only print what will be extracted without placing any files.", + ) dl_grp.add_argument( - '--test', - action='store_true', - help='Do not extract or place any files/directories. 
'
- 'Only print what will be extracted without placing any files.')
- dl_grp.add_argument('--empty-is-error',
- action='store_true',
- help='If all files are excluded by other filters, '
- 'treat that situation as an error and exit non-zero.')
- dl_grp.add_argument('--latest-build-branch',
- help='Specify the name of the branch to '
- 'download the with "--version=latest-build"',
- metavar='BRANCH_NAME')
+ "--empty-is-error",
+ action="store_true",
+ help="If all files are excluded by other filters, "
+ "treat that situation as an error and exit non-zero.",
+ )
+ dl_grp.add_argument(
+ "--latest-build-branch",
+ help="Specify the name of the branch to "
+ 'download with "--version=latest-build"',
+ metavar="BRANCH_NAME",
+ )
 args = parser.parse_args()
 cache = Cache.open_in(args.cache_dir)
 cache.refresh_full_json()
 if args.list:
- _print_list(cache.db, args.version, args.target, args.arch,
- args.edition, args.component)
+ _print_list(
+ cache.db, args.version, args.target, args.arch, args.edition, args.component
+ )
 return None
 if args.version is None:
 raise argparse.ArgumentError(None, 'A "--version" is required')
 if args.component is None:
- raise argparse.ArgumentError(
- None, 'A "--component" name should be provided')
+ raise argparse.ArgumentError(None, 'A "--component" name should be provided')
 if args.out is None and args.test is None and args.no_download is None:
- raise argparse.ArgumentError(None,
- 'A "--out" directory should be provided')
+ raise argparse.ArgumentError(None, 'A "--out" directory should be provided')
 version = args.version
 if version in PERF_VERSIONS:
 version = PERF_VERSIONS[version]
 target = args.target
- if target in (None, 'auto'):
+ if target in (None, "auto"):
 target = infer_target(version)
 arch = args.arch
- if arch in (None, 'auto'):
+ if arch in (None, "auto"):
 arch = infer_arch()
- edition = args.edition or 'enterprise'
+ edition = args.edition or "enterprise"
 out = args.out or Path.cwd()
 out = out.absolute()
- result = _dl_component(cache,
- out,
- version=version,
- target=target,
- arch=arch,
- edition=edition,
- component=args.component,
- pattern=args.only,
- strip_components=args.strip_components,
- test=args.test,
- no_download=args.no_download,
- latest_build_branch=args.latest_build_branch)
+ result = _dl_component(
+ cache,
+ out,
+ version=version,
+ target=target,
+ arch=arch,
+ edition=edition,
+ component=args.component,
+ pattern=args.only,
+ strip_components=args.strip_components,
+ test=args.test,
+ no_download=args.no_download,
+ latest_build_branch=args.latest_build_branch,
+ )
 if result is ExpandResult.Empty and args.empty_is_error:
 return 1
 return 0
-if __name__ == '__main__':
+if __name__ == "__main__":
 sys.exit(main(sys.argv[1:]))
diff --git a/.evergreen/mongosh-dl.py b/.evergreen/mongosh-dl.py
index 6961e26a..9840d637 100644
--- a/.evergreen/mongosh-dl.py
+++ b/.evergreen/mongosh-dl.py
@@ -3,6 +3,7 @@
 Use '--help' for more information. 
""" + import argparse import json import os @@ -21,8 +22,10 @@ def _get_latest_version(): - headers = { "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28" } + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } url = "https://api.github.com/repos/mongodb-js/mongosh/releases" req = urllib.request.Request(url, headers=headers) try: @@ -30,33 +33,45 @@ def _get_latest_version(): except Exception: return _get_latest_version_git() - data = json.loads(resp.read().decode('utf-8')) + data = json.loads(resp.read().decode("utf-8")) for item in data: - if item['prerelease']: + if item["prerelease"]: continue - return item['tag_name'].replace('v', '').strip() + return item["tag_name"].replace("v", "").strip() def _get_latest_version_git(): with tempfile.TemporaryDirectory() as td: - cmd = 'git clone --depth 1 https://github.com/mongodb-js/mongosh.git' - subprocess.check_call(shlex.split(cmd), cwd=td, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - cmd = 'git fetch origin --tags' - path = os.path.join(td, 'mongosh') - subprocess.check_call(shlex.split(cmd), cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - cmd = 'git --no-pager tag' - output = subprocess.check_output(shlex.split(cmd), cwd=path, stderr=subprocess.PIPE) - for line in reversed(output.decode('utf-8').splitlines()): - if re.match('^v\d+\.\d+\.\d+$', line): - print('Found version', line, file=sys.stderr) - return line.replace('v', '').strip() - - -def _download(out_dir: Path, version: str, target: str, - arch: str, - pattern: 'str | None', strip_components: int, test: bool, - no_download: bool,) -> int: - print(f'Download {version} mongosh for {target}-{arch}', file=sys.stderr) + cmd = "git clone --depth 1 https://github.com/mongodb-js/mongosh.git" + subprocess.check_call( + shlex.split(cmd), cwd=td, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cmd = "git fetch origin --tags" + path = os.path.join(td, "mongosh") + subprocess.check_call( + shlex.split(cmd), cwd=path, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cmd = "git --no-pager tag" + output = subprocess.check_output( + shlex.split(cmd), cwd=path, stderr=subprocess.PIPE + ) + for line in reversed(output.decode("utf-8").splitlines()): + if re.match("^v\d+\.\d+\.\d+$", line): + print("Found version", line, file=sys.stderr) + return line.replace("v", "").strip() + + +def _download( + out_dir: Path, + version: str, + target: str, + arch: str, + pattern: "str | None", + strip_components: int, + test: bool, + no_download: bool, +) -> int: + print(f"Download {version} mongosh for {target}-{arch}", file=sys.stderr) if version == "latest": version = _get_latest_version() if arch == "x86_64": @@ -81,99 +96,104 @@ def _download(out_dir: Path, version: str, target: str, fp.write(buf) buf = resp.read(1024 * 1024 * 4) fp.close() - resp = _expand_archive(Path(fp.name), - out_dir, pattern, - strip_components, - test=test) + resp = _expand_archive( + Path(fp.name), out_dir, pattern, strip_components, test=test + ) os.remove(fp.name) return resp -def main(argv: 'Sequence[str]'): +def main(argv: "Sequence[str]"): parser = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) dl_grp = parser.add_argument_group( - 'Download arguments', - description='Select what to download and extract. 
' - 'Non-required arguments will be inferred ' - 'based on the host system.') - dl_grp.add_argument('--target', - '-T', - help='The target platform for which to download. ' - 'Use "--list" to list available targets.') - dl_grp.add_argument('--arch', - '-A', - help='The architecture for which to download') + "Download arguments", + description="Select what to download and extract. " + "Non-required arguments will be inferred " + "based on the host system.", + ) + dl_grp.add_argument( + "--target", + "-T", + help="The target platform for which to download. " + 'Use "--list" to list available targets.', + ) + dl_grp.add_argument("--arch", "-A", help="The architecture for which to download") dl_grp.add_argument( - '--out', - '-o', - help='The directory in which to download components. (Required)', - type=Path) + "--out", + "-o", + help="The directory in which to download components. (Required)", + type=Path, + ) dl_grp.add_argument( - '--version', - '-V', + "--version", + "-V", default="latest", - help= - 'The product version to download (Required). Use "latest" to download ' - 'the newest available stable version.') + help='The product version to download (Required). Use "latest" to download ' + "the newest available stable version.", + ) dl_grp.add_argument( - '--only', - help= - 'Restrict extraction to items that match the given globbing expression. ' + "--only", + help="Restrict extraction to items that match the given globbing expression. " 'The full archive member path is matched, so a pattern like "*.exe" ' 'will only match "*.exe" at the top level of the archive. To match ' 'recursively, use the "**" pattern to match any number of ' - 'intermediate directories.') + "intermediate directories.", + ) dl_grp.add_argument( - '--strip-path-components', - '-p', - dest='strip_components', - metavar='N', + "--strip-path-components", + "-p", + dest="strip_components", + metavar="N", default=0, type=int, - help= - 'Strip the given number of path components from archive members before ' - 'extracting into the destination. The relative path of the archive ' - 'member will be used to form the destination path. For example, a ' - 'member named [bin/mongod.exe] will be extracted to [/bin/mongod.exe]. ' - 'Using --strip-components=1 will remove the first path component, extracting ' - 'such an item to [/mongod.exe]. If the path has fewer than N components, ' - 'that archive member will be ignored.') + help="Strip the given number of path components from archive members before " + "extracting into the destination. The relative path of the archive " + "member will be used to form the destination path. For example, a " + "member named [bin/mongod.exe] will be extracted to [/bin/mongod.exe]. " + "Using --strip-components=1 will remove the first path component, extracting " + "such an item to [/mongod.exe]. If the path has fewer than N components, " + "that archive member will be ignored.", + ) dl_grp.add_argument( - '--no-download', - action='store_true', - help='Do not download the file, only print its url.') + "--no-download", + action="store_true", + help="Do not download the file, only print its url.", + ) dl_grp.add_argument( - '--test', - action='store_true', - help='Do not extract or place any files/directories. ' - 'Only print what will be extracted without placing any files.') + "--test", + action="store_true", + help="Do not extract or place any files/directories. 
" + "Only print what will be extracted without placing any files.", + ) args = parser.parse_args(argv) if args.out is None and args.test is None and args.no_download is None: - raise argparse.ArgumentError(None, - 'A "--out" directory should be provided') + raise argparse.ArgumentError(None, 'A "--out" directory should be provided') target = args.target - if target in (None, 'auto'): + if target in (None, "auto"): target = sys.platform arch = args.arch - if arch in (None, 'auto'): + if arch in (None, "auto"): arch = infer_arch() out = args.out or Path.cwd() out = out.absolute() - result = _download(out, + result = _download( + out, version=args.version, target=target, arch=arch, pattern=args.only, strip_components=args.strip_components, test=args.test, - no_download=args.no_download) + no_download=args.no_download, + ) if result is ExpandResult.Empty: return 1 return 0 -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/.evergreen/ocsp/mock_ocsp_responder.py b/.evergreen/ocsp/mock_ocsp_responder.py index dcc12cb3..d135aa9c 100644 --- a/.evergreen/ocsp/mock_ocsp_responder.py +++ b/.evergreen/ocsp/mock_ocsp_responder.py @@ -53,7 +53,7 @@ from flask import Flask, Response, request from oscrypto import asymmetric -__version__ = '0.10.2' +__version__ = "0.10.2" __version_info__ = (0, 10, 2) logger = logging.getLogger(__name__) @@ -77,8 +77,8 @@ def _pretty_message(string, *params): # Unwrap lines, taking into account bulleted lists, ordered lists and # underlines consisting of = signs - if output.find('\n') != -1: - output = re.sub('(?<=\\S)\n(?=[^ \n\t\\d\\*\\-=])', ' ', output) + if output.find("\n") != -1: + output = re.sub("(?<=\\S)\n(?=[^ \n\t\\d\\*\\-=])", " ", output) if params: output = output % params @@ -98,9 +98,10 @@ def _type_name(value): cls = value else: cls = value.__class__ - if cls.__module__ in set(['builtins', '__builtin__']): + if cls.__module__ in set(["builtins", "__builtin__"]): return cls.__name__ - return '%s.%s' % (cls.__module__, cls.__name__) + return "%s.%s" % (cls.__module__, cls.__name__) + def _writer(func): """ @@ -108,11 +109,10 @@ def _writer(func): """ name = func.__name__ - return property(fget=lambda self: getattr(self, '_%s' % name), fset=func) + return property(fget=lambda self: getattr(self, "_%s" % name), fset=func) class OCSPResponseBuilder: - _response_status = None _certificate = None _certificate_status = None @@ -126,7 +126,9 @@ class OCSPResponseBuilder: _response_data_extensions = None _single_response_extensions = None - def __init__(self, response_status, certificate_status_list=None, revocation_date=None): + def __init__( + self, response_status, certificate_status_list=None, revocation_date=None + ): """ Unless changed, responses will use SHA-256 for the signature, and will be valid from the moment created for one week. 
@@ -163,8 +165,8 @@ def __init__(self, response_status, certificate_status_list=None, revocation_dat
 self._certificate_status_list = certificate_status_list or []
 self._revocation_date = revocation_date
- self._key_hash_algo = 'sha1'
- self._hash_algo = 'sha256'
+ self._key_hash_algo = "sha1"
+ self._hash_algo = "sha256"
 self._response_data_extensions = {}
 self._single_response_extensions = {}
@@ -175,12 +177,14 @@ def nonce(self, value):
 """
 if not isinstance(value, bytes):
- raise TypeError(_pretty_message(
- '''
+ raise TypeError(
+ _pretty_message(
+ """
 nonce must be a byte string, not %s
- ''',
- _type_name(value)
- ))
+ """,
+ _type_name(value),
+ )
+ )
 self._nonce = value
@@ -196,14 +200,16 @@ def certificate_issuer(self, value):
 if value is not None:
 is_oscrypto = isinstance(value, asymmetric.Certificate)
 if not is_oscrypto and not isinstance(value, x509.Certificate):
- raise TypeError(_pretty_message(
- '''
+ raise TypeError(
+ _pretty_message(
+ """
 certificate_issuer must be an instance of asn1crypto.x509.Certificate or oscrypto.asymmetric.Certificate, not %s
- ''',
- _type_name(value)
- ))
+ """,
+ _type_name(value),
+ )
+ )
 if is_oscrypto:
 value = value.asn1
@@ -219,12 +225,14 @@ def next_update(self, value):
 """
 if not isinstance(value, datetime):
- raise TypeError(_pretty_message(
- '''
+ raise TypeError(
+ _pretty_message(
+ """
 next_update must be an instance of datetime.datetime, not %s
- ''',
- _type_name(value)
- ))
+ """,
+ _type_name(value),
+ )
+ )
 self._next_update = value
@@ -243,50 +251,56 @@ def build(self, responder_private_key=None, responder_certificate=None):
 :return: An asn1crypto.ocsp.OCSPResponse object of the response
 """
- if self._response_status != 'successful':
- return ocsp.OCSPResponse({
- 'response_status': self._response_status
- })
+ if self._response_status != "successful":
+ return ocsp.OCSPResponse({"response_status": self._response_status})
 is_oscrypto = isinstance(responder_private_key, asymmetric.PrivateKey)
- if not isinstance(responder_private_key, keys.PrivateKeyInfo) and not is_oscrypto:
- raise TypeError(_pretty_message(
- '''
+ if (
+ not isinstance(responder_private_key, keys.PrivateKeyInfo)
+ and not is_oscrypto
+ ):
+ raise TypeError(
+ _pretty_message(
+ """
 responder_private_key must be an instance of the asn1crypto.keys.PrivateKeyInfo or oscrypto.asymmetric.PrivateKey, not %s
- ''',
- _type_name(responder_private_key)
- ))
+ """,
+ _type_name(responder_private_key),
+ )
+ )
 cert_is_oscrypto = isinstance(responder_certificate, asymmetric.Certificate)
- if not isinstance(responder_certificate, x509.Certificate) and not cert_is_oscrypto:
- raise TypeError(_pretty_message(
- '''
+ if (
+ not isinstance(responder_certificate, x509.Certificate)
+ and not cert_is_oscrypto
+ ):
+ raise TypeError(
+ _pretty_message(
+ """
 responder_certificate must be an instance of asn1crypto.x509.Certificate or oscrypto.asymmetric.Certificate, not %s
- ''',
- _type_name(responder_certificate)
- ))
+ """,
+ _type_name(responder_certificate),
+ )
+ )
 if cert_is_oscrypto:
 responder_certificate = responder_certificate.asn1
 if self._certificate_status_list is None:
- raise ValueError(_pretty_message(
- '''
+ raise ValueError(
+ _pretty_message(
+ """
 certificate_status_list must be set if the response_status is "successful"
- '''
- ))
+ """
+ )
+ )
 def _make_extension(name, value):
- return {
- 'extn_id': name,
- 'critical': False,
- 'extn_value': value
- }
+ return {"extn_id": name, "critical": False, "extn_value": value}
 responses = []
 for serial, 
status in self._certificate_status_list: @@ -295,9 +309,7 @@ def _make_extension(name, value): for name, value in self._response_data_extensions.items(): response_data_extensions.append(_make_extension(name, value)) if self._nonce: - response_data_extensions.append( - _make_extension('nonce', self._nonce) - ) + response_data_extensions.append(_make_extension("nonce", self._nonce)) if not response_data_extensions: response_data_extensions = None @@ -308,42 +320,42 @@ def _make_extension(name, value): if self._certificate_issuer: single_response_extensions.append( _make_extension( - 'certificate_issuer', + "certificate_issuer", [ x509.GeneralName( - name='directory_name', - value=self._certificate_issuer.subject + name="directory_name", + value=self._certificate_issuer.subject, ) - ] + ], ) ) if not single_response_extensions: single_response_extensions = None - responder_key_hash = getattr(responder_certificate.public_key, self._key_hash_algo) + responder_key_hash = getattr( + responder_certificate.public_key, self._key_hash_algo + ) - if status == 'good': - cert_status = ocsp.CertStatus( - name='good', - value=core.Null() - ) - elif status == 'unknown': - cert_status = ocsp.CertStatus( - name='unknown', - value=core.Null() - ) + if status == "good": + cert_status = ocsp.CertStatus(name="good", value=core.Null()) + elif status == "unknown": + cert_status = ocsp.CertStatus(name="unknown", value=core.Null()) else: - reason = status if status != 'revoked' else 'unspecified' + reason = status if status != "revoked" else "unspecified" cert_status = ocsp.CertStatus( - name='revoked', + name="revoked", value={ - 'revocation_time': self._revocation_date, - 'revocation_reason': reason, - } + "revocation_time": self._revocation_date, + "revocation_reason": reason, + }, ) - issuer = self._certificate_issuer if self._certificate_issuer else responder_certificate + issuer = ( + self._certificate_issuer + if self._certificate_issuer + else responder_certificate + ) produced_at = datetime.now(timezone.utc).replace(microsecond=0) @@ -351,88 +363,102 @@ def _make_extension(name, value): self._this_update = produced_at if self._next_update is None: - self._next_update = (self._this_update + timedelta(days=7)).replace(microsecond=0) + self._next_update = (self._this_update + timedelta(days=7)).replace( + microsecond=0 + ) response = { - 'cert_id': { - 'hash_algorithm': { - 'algorithm': self._key_hash_algo - }, - 'issuer_name_hash': getattr(issuer.subject, self._key_hash_algo), - 'issuer_key_hash': getattr(issuer.public_key, self._key_hash_algo), - 'serial_number': serial, - }, - 'cert_status': cert_status, - 'this_update': self._this_update, - 'next_update': self._next_update, - 'single_extensions': single_response_extensions - } + "cert_id": { + "hash_algorithm": {"algorithm": self._key_hash_algo}, + "issuer_name_hash": getattr(issuer.subject, self._key_hash_algo), + "issuer_key_hash": getattr(issuer.public_key, self._key_hash_algo), + "serial_number": serial, + }, + "cert_status": cert_status, + "this_update": self._this_update, + "next_update": self._next_update, + "single_extensions": single_response_extensions, + } responses.append(response) - response_data = ocsp.ResponseData({ - 'responder_id': ocsp.ResponderId(name='by_key', value=responder_key_hash), - 'produced_at': produced_at, - 'responses': responses, - 'response_extensions': response_data_extensions - }) + response_data = ocsp.ResponseData( + { + "responder_id": ocsp.ResponderId( + name="by_key", value=responder_key_hash + ), + "produced_at": 
produced_at, + "responses": responses, + "response_extensions": response_data_extensions, + } + ) signature_algo = responder_private_key.algorithm - if signature_algo == 'ec': - signature_algo = 'ecdsa' + if signature_algo == "ec": + signature_algo = "ecdsa" - signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo) + signature_algorithm_id = "%s_%s" % (self._hash_algo, signature_algo) - if responder_private_key.algorithm == 'rsa': + if responder_private_key.algorithm == "rsa": sign_func = asymmetric.rsa_pkcs1v15_sign - elif responder_private_key.algorithm == 'dsa': + elif responder_private_key.algorithm == "dsa": sign_func = asymmetric.dsa_sign - elif responder_private_key.algorithm == 'ec': + elif responder_private_key.algorithm == "ec": sign_func = asymmetric.ecdsa_sign if not is_oscrypto: responder_private_key = asymmetric.load_private_key(responder_private_key) - signature_bytes = sign_func(responder_private_key, response_data.dump(), self._hash_algo) + signature_bytes = sign_func( + responder_private_key, response_data.dump(), self._hash_algo + ) certs = None - if self._certificate_issuer and getattr(self._certificate_issuer.public_key, self._key_hash_algo) != responder_key_hash: + if ( + self._certificate_issuer + and getattr(self._certificate_issuer.public_key, self._key_hash_algo) + != responder_key_hash + ): certs = [responder_certificate] - return ocsp.OCSPResponse({ - 'response_status': self._response_status, - 'response_bytes': { - 'response_type': 'basic_ocsp_response', - 'response': { - 'tbs_response_data': response_data, - 'signature_algorithm': {'algorithm': signature_algorithm_id}, - 'signature': signature_bytes, - 'certs': certs, - } + return ocsp.OCSPResponse( + { + "response_status": self._response_status, + "response_bytes": { + "response_type": "basic_ocsp_response", + "response": { + "tbs_response_data": response_data, + "signature_algorithm": {"algorithm": signature_algorithm_id}, + "signature": signature_bytes, + "certs": certs, + }, + }, } - }) + ) + # Enums + class ResponseStatus(enum.Enum): - successful = 'successful' - malformed_request = 'malformed_request' - internal_error = 'internal_error' - try_later = 'try_later' - sign_required = 'sign_required' - unauthorized = 'unauthorized' + successful = "successful" + malformed_request = "malformed_request" + internal_error = "internal_error" + try_later = "try_later" + sign_required = "sign_required" + unauthorized = "unauthorized" class CertificateStatus(enum.Enum): - good = 'good' - revoked = 'revoked' - key_compromise = 'key_compromise' - ca_compromise = 'ca_compromise' - affiliation_changed = 'affiliation_changed' - superseded = 'superseded' - cessation_of_operation = 'cessation_of_operation' - certificate_hold = 'certificate_hold' - remove_from_crl = 'remove_from_crl' - privilege_withdrawn = 'privilege_withdrawn' - unknown = 'unknown' + good = "good" + revoked = "revoked" + key_compromise = "key_compromise" + ca_compromise = "ca_compromise" + affiliation_changed = "affiliation_changed" + superseded = "superseded" + cessation_of_operation = "cessation_of_operation" + certificate_hold = "certificate_hold" + remove_from_crl = "remove_from_crl" + privilege_withdrawn = "privilege_withdrawn" + unknown = "unknown" # API endpoints @@ -440,10 +466,17 @@ class CertificateStatus(enum.Enum): FAULT_UNKNOWN = "unknown" app = Flask(__name__) -class OCSPResponder: - def __init__(self, issuer_cert: str, responder_cert: str, responder_key: str, - fault: str, next_update_seconds: int): + +class OCSPResponder: + def 
__init__( + self, + issuer_cert: str, + responder_cert: str, + responder_key: str, + fault: str, + next_update_seconds: int, + ): """ Create a new OCSPResponder instance. @@ -488,7 +521,7 @@ def validate(self): if self._fault == FAULT_UNKNOWN: return (CertificateStatus.unknown, None) if self._fault is not None: - raise NotImplementedError('Fault type could not be found') + raise NotImplementedError("Fault type could not be found") return (CertificateStatus.good, time) def _build_ocsp_response(self, ocsp_request: OCSPRequest) -> OCSPResponse: @@ -496,39 +529,43 @@ def _build_ocsp_response(self, ocsp_request: OCSPRequest) -> OCSPResponse: Create and return an OCSP response from an OCSP request. """ # Get the certificate serial - tbs_request = ocsp_request['tbs_request'] - request_list = tbs_request['request_list'] + tbs_request = ocsp_request["tbs_request"] + request_list = tbs_request["request_list"] if len(request_list) < 1: - logger.warning('Received OCSP request with no requests') - raise NotImplementedError('Empty requests not supported') + logger.warning("Received OCSP request with no requests") + raise NotImplementedError("Empty requests not supported") single_request = request_list[0] # TODO: Support more than one request - req_cert = single_request['req_cert'] - serial = req_cert['serial_number'].native + req_cert = single_request["req_cert"] + serial = req_cert["serial_number"].native # Check certificate status try: certificate_status, revocation_date = self.validate() except Exception as e: - logger.exception('Could not determine certificate status: %s', e) + logger.exception("Could not determine certificate status: %s", e) return self._fail(ResponseStatus.internal_error) certificate_status_list = [(serial, certificate_status.value)] # Build the response - builder = OCSPResponseBuilder(response_status=ResponseStatus.successful.value, certificate_status_list=certificate_status_list, revocation_date=revocation_date) + builder = OCSPResponseBuilder( + response_status=ResponseStatus.successful.value, + certificate_status_list=certificate_status_list, + revocation_date=revocation_date, + ) # Parse extensions - for extension in tbs_request['request_extensions']: - extn_id = extension['extn_id'].native - critical = extension['critical'].native - value = extension['extn_value'].parsed + for extension in tbs_request["request_extensions"]: + extn_id = extension["extn_id"].native + critical = extension["critical"].native + value = extension["extn_value"].parsed # This variable tracks whether any unknown extensions were encountered unknown = False # Handle nonce extension - if extn_id == 'nonce': + if extn_id == "nonce": builder.nonce = value.native # That's all we know @@ -539,20 +576,26 @@ def _build_ocsp_response(self, ocsp_request: OCSPRequest) -> OCSPResponse: # usually happen, according to RFC 6960 4.1.2), we should throw our # hands up in despair and run. if unknown is True and critical is True: - logger.warning('Could not parse unknown critical extension: %r', - dict(extension.native)) + logger.warning( + "Could not parse unknown critical extension: %r", + dict(extension.native), + ) return self._fail(ResponseStatus.internal_error) # If it's an unknown non-critical extension, we can safely ignore it. 
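# (RFC 6960 only obliges a responder to reject requests whose
# unsupported extensions are marked *critical*; non-critical ones
# below are merely logged and skipped.)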
if unknown is True:
- logger.info('Ignored unknown non-critical extension: %r', dict(extension.native))
+ logger.info(
+ "Ignored unknown non-critical extension: %r", dict(extension.native)
+ )
 # Set certificate issuer
 builder.certificate_issuer = self._issuer_cert
 # Set next update date
 now = datetime.now(timezone.utc)
- builder.next_update = (now + timedelta(seconds=self._next_update_seconds)).replace(microsecond=0)
+ builder.next_update = (
+ now + timedelta(seconds=self._next_update_seconds)
+ ).replace(microsecond=0)
 return builder.build(self._responder_key, self._responder_cert)
@@ -560,26 +603,42 @@ def build_http_response(self, request_der: bytes) -> Response:
 global app
 response_der = self._build_ocsp_response(request_der).dump()
 resp = app.make_response((response_der, 200))
- resp.headers['content_type'] = 'application/ocsp-response'
+ resp.headers["content_type"] = "application/ocsp-response"
 return resp
 responder = None
-def init_responder(issuer_cert: str, responder_cert: str, responder_key: str, fault: str, next_update_seconds: int):
+
+def init_responder(
+ issuer_cert: str,
+ responder_cert: str,
+ responder_key: str,
+ fault: str,
+ next_update_seconds: int,
+):
 global responder
- responder = OCSPResponder(issuer_cert=issuer_cert, responder_cert=responder_cert, responder_key=responder_key, fault=fault, next_update_seconds=next_update_seconds)
+ responder = OCSPResponder(
+ issuer_cert=issuer_cert,
+ responder_cert=responder_cert,
+ responder_key=responder_key,
+ fault=fault,
+ next_update_seconds=next_update_seconds,
+ )
+
 def init(port=8080, debug=False, host=None):
- logger.info('Launching %sserver on port %d', 'debug' if debug else '', port)
+ logger.info("Launching %sserver on port %d", "debug" if debug else "", port)
 app.run(port=port, debug=debug, host=host)
-@app.route('/', methods=['GET'])
+
+@app.route("/", methods=["GET"])
 def _handle_root():
- return 'ocsp-responder'
+ return "ocsp-responder"
-@app.route('/status/', defaults={'u_path': ''}, methods=['GET'])
-@app.route('/status/<path:u_path>', methods=['GET'])
+
+@app.route("/status/", defaults={"u_path": ""}, methods=["GET"])
+@app.route("/status/<path:u_path>", methods=["GET"])
 def _handle_get(u_path):
 global responder
 """
@@ -587,12 +646,13 @@ def _handle_get(u_path):
 HTTP request URL.
 """
 if "Host" not in request.headers:
- raise ValueError ("Required 'Host' header not present")
+ raise ValueError("Required 'Host' header not present")
 der = base64.b64decode(u_path)
 ocsp_request = responder.parse_ocsp_request(der)
 return responder.build_http_response(ocsp_request)
-@app.route('/status', methods=['POST'])
+
+@app.route("/status", methods=["POST"])
 def _handle_post():
 global responder
 """
@@ -600,6 +660,6 @@ def _handle_post():
 request body. 
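Per RFC 6960, Appendix A, the POST body is the DER-encoded OCSPRequest
itself, so it can be handed to parse_ocsp_request() unmodified.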
""" if "Host" not in request.headers: - raise ValueError ("Required 'Host' header not present") + raise ValueError("Required 'Host' header not present") ocsp_request = responder.parse_ocsp_request(request.data) return responder.build_http_response(ocsp_request) diff --git a/.evergreen/ocsp/ocsp_mock.py b/.evergreen/ocsp/ocsp_mock.py index 532cdfb5..7d0c0bbf 100755 --- a/.evergreen/ocsp/ocsp_mock.py +++ b/.evergreen/ocsp/ocsp_mock.py @@ -8,7 +8,7 @@ import os import sys -sys.path.append(os.path.join(os.getcwd() ,'src', 'third_party', 'mock_ocsp_responder')) +sys.path.append(os.path.join(os.getcwd(), "src", "third_party", "mock_ocsp_responder")) import mock_ocsp_responder @@ -17,37 +17,73 @@ def main(): """Main entry point""" parser = argparse.ArgumentParser(description="MongoDB Mock OCSP Responder.") - parser.add_argument('-p', '--port', type=int, default=8080, help="Port to listen on") - - parser.add_argument('-b', '--bind_ip', type=str, default=None, help="IP to listen on") - - parser.add_argument('--ca_file', type=str, required=True, help="CA file for OCSP responder") - - parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") - - parser.add_argument('--ocsp_responder_cert', type=str, required=True, help="OCSP Responder Certificate") - - parser.add_argument('--ocsp_responder_key', type=str, required=True, help="OCSP Responder Keyfile") - - parser.add_argument('--fault', choices=[mock_ocsp_responder.FAULT_REVOKED, mock_ocsp_responder.FAULT_UNKNOWN, None], default=None, type=str, help="Specify a specific fault to test") - - parser.add_argument('--next_update_seconds', type=int, default=32400, help="Specify how long the OCSP response should be valid for") + parser.add_argument( + "-p", "--port", type=int, default=8080, help="Port to listen on" + ) + + parser.add_argument( + "-b", "--bind_ip", type=str, default=None, help="IP to listen on" + ) + + parser.add_argument( + "--ca_file", type=str, required=True, help="CA file for OCSP responder" + ) + + parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose tracing" + ) + + parser.add_argument( + "--ocsp_responder_cert", + type=str, + required=True, + help="OCSP Responder Certificate", + ) + + parser.add_argument( + "--ocsp_responder_key", type=str, required=True, help="OCSP Responder Keyfile" + ) + + parser.add_argument( + "--fault", + choices=[ + mock_ocsp_responder.FAULT_REVOKED, + mock_ocsp_responder.FAULT_UNKNOWN, + None, + ], + default=None, + type=str, + help="Specify a specific fault to test", + ) + + parser.add_argument( + "--next_update_seconds", + type=int, + default=32400, + help="Specify how long the OCSP response should be valid for", + ) args = parser.parse_args() if args.verbose: logging.basicConfig(level=logging.DEBUG) - print('Initializing OCSP Responder') - mock_ocsp_responder.init_responder(issuer_cert=args.ca_file, responder_cert=args.ocsp_responder_cert, responder_key=args.ocsp_responder_key, fault=args.fault, next_update_seconds=args.next_update_seconds) + print("Initializing OCSP Responder") + mock_ocsp_responder.init_responder( + issuer_cert=args.ca_file, + responder_cert=args.ocsp_responder_cert, + responder_key=args.ocsp_responder_key, + fault=args.fault, + next_update_seconds=args.next_update_seconds, + ) mock_ocsp_responder.init(port=args.port, debug=args.verbose, host=args.bind_ip) # Write the pid file. 
- with open(os.path.join(os.getcwd(), 'ocsp.pid'), 'w') as fid: + with open(os.path.join(os.getcwd(), "ocsp.pid"), "w") as fid: fid.write(str(os.getpid())) - print('Mock OCSP Responder is running on port %s' % (str(args.port))) + print("Mock OCSP Responder is running on port %s" % (str(args.port))) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.evergreen/secrets_handling/setup_secrets.py b/.evergreen/secrets_handling/setup_secrets.py index 7f5cff3e..336868be 100644 --- a/.evergreen/secrets_handling/setup_secrets.py +++ b/.evergreen/secrets_handling/setup_secrets.py @@ -1,6 +1,7 @@ """ Script for fetching AWS Secrets Vault secrets for use in testing. """ + import argparse import json import os @@ -21,26 +22,32 @@ def get_secrets(vaults, region, profile): creds = None kwargs = dict(region_name=region) if "AWS_ACCESS_KEY_ID" not in os.environ and not profile: - client = session.client(service_name='sts', **kwargs) + client = session.client(service_name="sts", **kwargs) try: # This will only fail locally. - resp = client.assume_role(RoleArn=AWS_ROLE_ARN, RoleSessionName=str(uuid.uuid4())) + resp = client.assume_role( + RoleArn=AWS_ROLE_ARN, RoleSessionName=str(uuid.uuid4()) + ) except Exception as e: print(e) - raise ValueError("Please provide a profile (typically using AWS_PROFILE)") from e + raise ValueError( + "Please provide a profile (typically using AWS_PROFILE)" + ) from e - creds = resp['Credentials'] + creds = resp["Credentials"] if creds: - kwargs.update(aws_access_key_id=creds['AccessKeyId'], - aws_secret_access_key=creds['SecretAccessKey'], - aws_session_token=creds['SessionToken']) - client = session.client(service_name='secretsmanager', **kwargs) + kwargs.update( + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) + client = session.client(service_name="secretsmanager", **kwargs) secrets = [] try: for vault in vaults: - secret = client.get_secret_value(SecretId=vault)['SecretString'] + secret = client.get_secret_value(SecretId=vault)["SecretString"] secrets.append(secret) except botocore.exceptions.BotoCoreError as e: # For a list of exceptions thrown, see @@ -65,27 +72,44 @@ def write_secrets(vaults, region, profile): # These values are secrets, do not print them out.write("#!/usr/bin/env bash\n\nset +x\n") for key, val in pairs.items(): - out.write("export " + key + "=" + "\"" + val + "\"\n") + out.write("export " + key + "=" + '"' + val + '"\n') def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description='MongoDB AWS Secrets Vault fetcher. If connecting with the given AWS ' - 'profile fails, will attempt to use local environment variables ' - 'instead.') - - parser.add_argument("-p", "--profile", type=str, nargs="?", metavar="profile", help="a local AWS profile " - "to use credentials " - "from. Defaults to " - "AWS_PROFILE if not provided.") - parser.add_argument("-r", "--region", type=str, metavar="region", default="us-east-1", - help="the AWS region containing the given vaults.") - parser.add_argument("vaults", metavar="V", type=str, nargs="+", help="a vault to fetch secrets from") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="MongoDB AWS Secrets Vault fetcher. 
If connecting with the given AWS " + "profile fails, will attempt to use local environment variables " + "instead.", + ) + + parser.add_argument( + "-p", + "--profile", + type=str, + nargs="?", + metavar="profile", + help="a local AWS profile " + "to use credentials " + "from. Defaults to " + "AWS_PROFILE if not provided.", + ) + parser.add_argument( + "-r", + "--region", + type=str, + metavar="region", + default="us-east-1", + help="the AWS region containing the given vaults.", + ) + parser.add_argument( + "vaults", metavar="V", type=str, nargs="+", help="a vault to fetch secrets from" + ) args = parser.parse_args() write_secrets(args.vaults, args.region, args.profile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.evergreen/socks5srv.py b/.evergreen/socks5srv.py index 1e0448b0..f63b9a70 100755 --- a/.evergreen/socks5srv.py +++ b/.evergreen/socks5srv.py @@ -7,248 +7,267 @@ # Usage: python3 socks5srv.py --port port [--auth username:password] [--map 'host:port to host:port' ...] + class AddressRemapper: - """A helper for remapping (host, port) tuples to new (host, port) tuples - - This is useful for Socks5 servers used in testing environments, - because the successful use of the Socks5 proxy can be demonstrated - by being able to 'connect' to a redirected port, which would always - fail without the proxy, even on localhost-only environments - """ - - def __init__(self, mappings): - self.mappings = [AddressRemapper.parse_single_mapping(string) for string in mappings] - self.add_dns_remappings() - - @staticmethod - def parse_single_mapping(string): - """Parse a single mapping of the for '{host}:{port} to {host}:{port}'""" - - # Accept either [ipv6]:port or host:port - host_re = r"(\[(?P<{0}_ipv6>[^[\]]+)\]|(?P<{0}_host>[^\[]+))" - port_re = r"(?P<{0}_port>\d+)" - - src_re = host_re.format('src') + ':' + port_re.format('src') - dst_re = host_re.format('dst') + ':' + port_re.format('dst') - full_re = '^' + src_re + ' to ' + dst_re + '$' - - match = re.match(full_re, string) - if match is None: - raise Exception(f"Mapping {string} does not match format '{{host}}:{{port}} to {{host}}:{{port}}'") - - src = ((match.group('src_ipv6') or match.group('src_host')).encode('utf8'), int(match.group('src_port'))) - dst = ((match.group('dst_ipv6') or match.group('dst_host')).encode('utf8'), int(match.group('dst_port'))) - return (src, dst) - - def add_dns_remappings(self): - """Add mappings for the IP addresses corresponding to hostnames - - For example, if there is a mapping (localhost, 1000) to (localhost, 2000), - then this also adds (127.0.0.1, 1000) to (localhost, 2000).""" - - for src, dst in self.mappings: - host, port = src - try: - addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - continue - - existing_src_entries = [src for src, dst in self.mappings] - for af, socktype, proto, canonname, sa in addrs: - if af == socket.AF_INET and sa not in existing_src_entries: - self.mappings.append((sa, dst)) - elif af == socket.AF_INET6 and sa[:2] not in existing_src_entries: - self.mappings.append((sa[:2], dst)) - - def remap(self, hostport): - """Re-map a (host, port) tuple to a new (host, port) tuple if that was requested""" - - for src, dst in self.mappings: - if hostport == src: - return dst - return hostport + """A helper for remapping (host, port) tuples to new (host, port) tuples + + This is useful for Socks5 servers used in testing environments, + because the successful use of the Socks5 proxy can be demonstrated + by 
being able to 'connect' to a redirected port, which would always + fail without the proxy, even on localhost-only environments + """ + + def __init__(self, mappings): + self.mappings = [ + AddressRemapper.parse_single_mapping(string) for string in mappings + ] + self.add_dns_remappings() + + @staticmethod + def parse_single_mapping(string): + """Parse a single mapping of the form '{host}:{port} to {host}:{port}'""" + + # Accept either [ipv6]:port or host:port + host_re = r"(\[(?P<{0}_ipv6>[^[\]]+)\]|(?P<{0}_host>[^\[]+))" + port_re = r"(?P<{0}_port>\d+)" + + src_re = host_re.format("src") + ":" + port_re.format("src") + dst_re = host_re.format("dst") + ":" + port_re.format("dst") + full_re = "^" + src_re + " to " + dst_re + "$" + + match = re.match(full_re, string) + if match is None: + raise Exception( + f"Mapping {string} does not match format '{{host}}:{{port}} to {{host}}:{{port}}'" + ) + + src = ( + (match.group("src_ipv6") or match.group("src_host")).encode("utf8"), + int(match.group("src_port")), + ) + dst = ( + (match.group("dst_ipv6") or match.group("dst_host")).encode("utf8"), + int(match.group("dst_port")), + ) + return (src, dst) + + def add_dns_remappings(self): + """Add mappings for the IP addresses corresponding to hostnames + + For example, if there is a mapping (localhost, 1000) to (localhost, 2000), + then this also adds (127.0.0.1, 1000) to (localhost, 2000).""" + + for src, dst in self.mappings: + host, port = src + try: + addrs = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + except socket.gaierror: + continue + + existing_src_entries = [src for src, dst in self.mappings] + for af, socktype, proto, canonname, sa in addrs: + if af == socket.AF_INET and sa not in existing_src_entries: + self.mappings.append((sa, dst)) + elif af == socket.AF_INET6 and sa[:2] not in existing_src_entries: + self.mappings.append((sa[:2], dst)) + + def remap(self, hostport): + """Re-map a (host, port) tuple to a new (host, port) tuple if that was requested""" + + for src, dst in self.mappings: + if hostport == src: + return dst + return hostport + + class Socks5Server(socketserver.ThreadingTCPServer): - """A simple Socks5 proxy server""" + """A simple Socks5 proxy server""" + + def __init__(self, server_address, RequestHandlerClass, args): + socketserver.ThreadingTCPServer.__init__( + self, server_address, RequestHandlerClass + ) + self.args = args + self.address_remapper = AddressRemapper(args.map) - def __init__(self, server_address, RequestHandlerClass, args): - socketserver.ThreadingTCPServer.__init__(self, - server_address, - RequestHandlerClass) - self.args = args - self.address_remapper = AddressRemapper(args.map) + class Socks5Handler(socketserver.BaseRequestHandler): - """Request handler for Socks5 connections""" + """Request handler for Socks5 connections""" + + def finish(self): + """Called after handle(), always just closes the connection""" + + self.request.close() + + def read_exact(self, n): + """Read n bytes from a socket + + In Socks5, strings are prefixed with a single byte containing + their length. This method reads a bytes string containing n bytes + (where n can be a number or a bytes object containing that + single byte). + + If reading from the client ends prematurely, this returns None. 
+ """ + + if type(n) is bytes: + if len(n) == 0: + return None + assert len(n) == 1 + n = n[0] + + buf = bytearray(n) + mv = memoryview(buf) + bytes_read = 0 + while bytes_read < n: + try: + chunk_length = self.request.recv_into(mv[bytes_read:]) + except OSError: + return None + if chunk_length == 0: + return None + + bytes_read += chunk_length + return bytes(buf) + + def create_outgoing_tcp_connection(self, dst, port): + """Create an outgoing TCP connection to dst:port""" + + outgoing = None + for res in socket.getaddrinfo(dst, port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + outgoing = socket.socket(af, socktype, proto) + except OSError: + continue + try: + outgoing.connect(sa) + except OSError: + outgoing.close() + continue + break + return outgoing + + def handle(self): + """Handle the Socks5 communication with a freshly connected client""" + + # This implements the Socks5 protocol as specified in + # https://datatracker.ietf.org/doc/html/rfc1928 + # and username/password authentication as specified in + # https://datatracker.ietf.org/doc/html/rfc1929 + # If you prefer HTML tables over ASCII tables, Wikipedia + # also currently has a decent description of the protocol in + # https://en.wikipedia.org/wiki/SOCKS#SOCKS5. + + # Receive/send errors are intentionally left unhandled. Closing + # the socket is just fine in that case for us. + + # Client greeting + if self.request.recv(1) != b"\x05": # Socks5 only + return + n_auth = self.request.recv(1) + client_auth_methods = self.read_exact(n_auth) + if client_auth_methods is None: + return - def finish(self): - """Called after handle(), always just closes the connection""" + # choose either no-auth or username/password + required_auth_method = b"\x00" if self.server.args.auth is None else b"\x02" + if required_auth_method not in client_auth_methods: + self.request.sendall(b"\x05\xff") + return + + self.request.sendall(b"\x05" + required_auth_method) + if required_auth_method == b"\x02": + auth_version = self.request.recv(1) + if auth_version != b"\x01": # Only username/password auth v1 + return + username_len = self.request.recv(1) + username = self.read_exact(username_len) + password_len = self.request.recv(1) + password = self.read_exact(password_len) + if username is None or password is None: + return + if ( + username.decode("utf8") + ":" + password.decode("utf8") + != self.server.args.auth + ): + return + self.request.sendall(b"\x01\x00") # auth success + + if self.request.recv(1) != b"\x05": # Socks5 only + return + if self.request.recv(1) != b"\x01": # Outgoing TCP only + return + if self.request.recv(1) != b"\x00": # Reserved, must be 0 + return - self.request.close() + addrtype = self.request.recv(1) + dst = None + if addrtype == b"\x01": # IPv4 + ipv4raw = self.read_exact(4) + if ipv4raw is not None: + dst = ".".join(["{}"] * 4).format(*ipv4raw) + elif addrtype == b"\x03": # Domain + domain_len = self.request.recv(1) + dst = self.read_exact(domain_len) + elif addrtype == b"\x04": # IPv6 + ipv6raw = self.read_exact(16) + if ipv6raw is not None: + dst = ":".join(["{:0>2x}{:0>2x}"] * 8).format(*ipv6raw) + else: + return - def read_exact(self, n): - """Read n bytes from a socket + if dst is None: + return - In Socks5, strings are prefixed with a single byte containing - their length. This method reads a bytes string containing n bytes - (where n can be a number or a bytes object containing that - single byte). 
+ portraw = self.read_exact(2) + port = portraw[0] * 256 + portraw[1] - If reading from the client ends prematurely, this returns None. - """ + (dst, port) = self.server.address_remapper.remap((dst, port)) - if type(n) is bytes: - if len(n) == 0: - return None - assert len(n) == 1 - n = n[0] - - buf = bytearray(n) - mv = memoryview(buf) - bytes_read = 0 - while bytes_read < n: - try: - chunk_length = self.request.recv_into(mv[bytes_read:]) - except OSError: - return None - if chunk_length == 0: - return None - - bytes_read += chunk_length - return bytes(buf) - - def create_outgoing_tcp_connection(self, dst, port): - """Create an outgoing TCP connection to dst:port""" - - outgoing = None - for res in socket.getaddrinfo(dst, port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - try: - outgoing = socket.socket(af, socktype, proto) - except OSError: - continue - try: - outgoing.connect(sa) - except OSError: - outgoing.close() - continue - break - return outgoing - - def handle(self): - """Handle the Socks5 communication with a freshly connected client""" - - # This implements the Socks5 protocol as specified in - # https://datatracker.ietf.org/doc/html/rfc1928 - # and username/password authentication as specified in - # https://datatracker.ietf.org/doc/html/rfc1929 - # If you prefer HTML tables over ASCII tables, Wikipedia - # also currently has a decent description of the protocol in - # https://en.wikipedia.org/wiki/SOCKS#SOCKS5. - - # Receive/send errors are intentionally left unhandled. Closing - # the socket is just fine in that case for us. - - # Client greeting - if self.request.recv(1) != b'\x05': # Socks5 only - return - n_auth = self.request.recv(1) - client_auth_methods = self.read_exact(n_auth) - if client_auth_methods is None: - return - - # choose either no-auth or username/password - required_auth_method = b'\x00' if self.server.args.auth is None else b'\x02' - if required_auth_method not in client_auth_methods: - self.request.sendall(b'\x05\xff') - return - - self.request.sendall(b'\x05' + required_auth_method) - if required_auth_method == b'\x02': - auth_version = self.request.recv(1) - if auth_version != b'\x01': # Only username/password auth v1 - return - username_len = self.request.recv(1) - username = self.read_exact(username_len) - password_len = self.request.recv(1) - password = self.read_exact(password_len) - if username is None or password is None: - return - if username.decode('utf8') + ':' + password.decode('utf8') != self.server.args.auth: - return - self.request.sendall(b'\x01\x00') # auth success - - if self.request.recv(1) != b'\x05': # Socks5 only - return - if self.request.recv(1) != b'\x01': # Outgoing TCP only - return - if self.request.recv(1) != b'\x00': # Reserved, must be 0 - return - - addrtype = self.request.recv(1) - dst = None - if addrtype == b'\x01': # IPv4 - ipv4raw = self.read_exact(4) - if ipv4raw is not None: - dst = '.'.join(['{}'] * 4).format(*ipv4raw) - elif addrtype == b'\x03': # Domain - domain_len = self.request.recv(1) - dst = self.read_exact(domain_len) - elif addrtype == b'\x04': # IPv6 - ipv6raw = self.read_exact(16) - if ipv6raw is not None: - dst = ':'.join(['{:0>2x}{:0>2x}'] * 8).format(*ipv6raw) - else: - return - - if dst is None: - return - - portraw = self.read_exact(2) - port = portraw[0] * 256 + portraw[1] - - (dst, port) = self.server.address_remapper.remap((dst, port)) - - outgoing = self.create_outgoing_tcp_connection(dst, port) - if outgoing is None: - 
self.request.sendall(b'\x05\x01\x00') # just report a general failure - return - # success response, do not bother actually stating the locally bound - # host/port address and instead always say 127.0.0.1:4096. - # for our use case, the client will not be making meaningful use - # of this anyway - self.request.sendall(b'\x05\x00\x00\x01\x7f\x00\x00\x01\x10\x00') - - self.raw_proxy(self.request, outgoing) - - def raw_proxy(self, a, b): - """Proxy data between sockets a and b as-is""" - - with a, b: - while True: - try: - (readable, _, _) = select.select([a, b], [], []) - except (OSError, ValueError): - return - - if not readable: - continue - for sock in readable: - buf = sock.recv(4096) - if buf == b'': + outgoing = self.create_outgoing_tcp_connection(dst, port) + if outgoing is None: + self.request.sendall(b"\x05\x01\x00") # just report a general failure return - if sock is a: - b.sendall(buf) - else: - a.sendall(buf) - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Start a Socks5 proxy server.') - parser.add_argument('--port', type=int, required=True) - parser.add_argument('--auth', type=str) - parser.add_argument('--map', type=str, action='append', default=[]) - args = parser.parse_args() - - socketserver.TCPServer.allow_reuse_address = True - with Socks5Server(('localhost', args.port), Socks5Handler, args) as server: - server.serve_forever() + # success response, do not bother actually stating the locally bound + # host/port address and instead always say 127.0.0.1:4096. + # for our use case, the client will not be making meaningful use + # of this anyway + self.request.sendall(b"\x05\x00\x00\x01\x7f\x00\x00\x01\x10\x00") + + self.raw_proxy(self.request, outgoing) + + def raw_proxy(self, a, b): + """Proxy data between sockets a and b as-is""" + + with a, b: + while True: + try: + (readable, _, _) = select.select([a, b], [], []) + except (OSError, ValueError): + return + + if not readable: + continue + for sock in readable: + buf = sock.recv(4096) + if buf == b"": + return + if sock is a: + b.sendall(buf) + else: + a.sendall(buf) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Start a Socks5 proxy server.") + parser.add_argument("--port", type=int, required=True) + parser.add_argument("--auth", type=str) + parser.add_argument("--map", type=str, action="append", default=[]) + args = parser.parse_args() + + socketserver.TCPServer.allow_reuse_address = True + with Socks5Server(("localhost", args.port), Socks5Handler, args) as server: + server.serve_forever() diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9cc3562b..9d4745f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,7 +54,7 @@ repos: hooks: - id: ruff args: ["--fix", "--show-fixes"] - # - id: ruff-format + - id: ruff-format - repo: local hooks: diff --git a/evergreen_config_generator/evergreen_config_generator/__init__.py b/evergreen_config_generator/evergreen_config_generator/__init__.py index 51d9745d..5ab09130 100644 --- a/evergreen_config_generator/evergreen_config_generator/__init__.py +++ b/evergreen_config_generator/evergreen_config_generator/__init__.py @@ -21,7 +21,8 @@ import yamlordereddictloader except ImportError: sys.stderr.write( - "try 'pip install -r evergreen_config_generator/requirements.txt'\n") + "try 'pip install -r evergreen_config_generator/requirements.txt'\n" + ) raise @@ -31,10 +32,10 @@ def __init__(self, *args, **kwargs): @property def name(self): - return 'UNSET' + return "UNSET" def to_dict(self): 
- return OD([('name', self.name)]) + return OD([("name", self.name)]) # We want legible YAML tasks: @@ -51,17 +52,17 @@ def to_dict(self): # Write values compactly except multiline strings, which use "|" style. Write # tag sets as lists. + class _Dumper(yamlordereddictloader.Dumper): def __init__(self, *args, **kwargs): super(_Dumper, self).__init__(*args, **kwargs) self.add_representer(set, type(self).represent_set) # Use "multi_representer" to represent all subclasses of ConfigObject. - self.add_multi_representer(ConfigObject, - type(self).represent_config_object) + self.add_multi_representer(ConfigObject, type(self).represent_config_object) def represent_scalar(self, tag, value, style=None): - if isinstance(value, str) and '\n' in value: - style = '|' + if isinstance(value, str) and "\n" in value: + style = "|" return super(_Dumper, self).represent_scalar(tag, value, style) def represent_set(self, data): @@ -80,8 +81,8 @@ def generate(config, path): config is a dict, preferably an OrderedDict. path is a file path. """ - f = open(path, 'w+') - f.write('''#################################### + f = open(path, "w+") + f.write("""#################################### # Evergreen configuration # # Generated with evergreen_config_generator from @@ -91,5 +92,5 @@ def generate(config, path): # #################################### -''') +""") f.write(yaml_dump(config)) diff --git a/evergreen_config_generator/evergreen_config_generator/functions.py b/evergreen_config_generator/evergreen_config_generator/functions.py index 6a611614..c0f8efa0 100644 --- a/evergreen_config_generator/evergreen_config_generator/functions.py +++ b/evergreen_config_generator/evergreen_config_generator/functions.py @@ -19,92 +19,118 @@ def func(func_name, **kwargs): - od = OD([('func', func_name)]) + od = OD([("func", func_name)]) if kwargs: - od['vars'] = OD(sorted(kwargs.items())) + od["vars"] = OD(sorted(kwargs.items())) return od -def bootstrap(VERSION='latest', TOPOLOGY=None, **kwargs): +def bootstrap(VERSION="latest", TOPOLOGY=None, **kwargs): if TOPOLOGY: - return func('bootstrap mongo-orchestration', - VERSION=VERSION, - TOPOLOGY=TOPOLOGY, - **kwargs) + return func( + "bootstrap mongo-orchestration", + VERSION=VERSION, + TOPOLOGY=TOPOLOGY, + **kwargs, + ) - return func('bootstrap mongo-orchestration', - VERSION=VERSION, - **kwargs) + return func("bootstrap mongo-orchestration", VERSION=VERSION, **kwargs) def run_tests(URI=None, **kwargs): if URI: - return func('run tests', URI=URI, **kwargs) + return func("run tests", URI=URI, **kwargs) - return func('run tests', **kwargs) + return func("run tests", **kwargs) def s3_put(remote_file, project_path=True, **kwargs): if project_path: - remote_file = '${project}/' + remote_file - - od = OD([ - ('command', 's3.put'), - ('params', OD([ - ('aws_key', '${aws_key}'), - ('aws_secret', '${aws_secret}'), - ('remote_file', remote_file), - ('bucket', 'mciuploads'), - ('permissions', 'public-read')]))]) - - od['params'].update(kwargs) + remote_file = "${project}/" + remote_file + + od = OD( + [ + ("command", "s3.put"), + ( + "params", + OD( + [ + ("aws_key", "${aws_key}"), + ("aws_secret", "${aws_secret}"), + ("remote_file", remote_file), + ("bucket", "mciuploads"), + ("permissions", "public-read"), + ] + ), + ), + ] + ) + + od["params"].update(kwargs) return od def strip_lines(s): - return '\n'.join(line for line in s.split('\n') if line.strip()) - - -def shell_exec(script, test=True, errexit=True, xtrace=False, silent=False, - continue_on_err=False, working_dir=None, 
background=False): - dedented = '' + return "\n".join(line for line in s.split("\n") if line.strip()) + + +def shell_exec( + script, + test=True, + errexit=True, + xtrace=False, + silent=False, + continue_on_err=False, + working_dir=None, + background=False, +): + dedented = "" if errexit: - dedented += 'set -o errexit\n' + dedented += "set -o errexit\n" if xtrace: - dedented += 'set -o xtrace\n' + dedented += "set -o xtrace\n" dedented += dedent(strip_lines(script)) - command = OD([('command', 'shell.exec')]) + command = OD([("command", "shell.exec")]) if test: - command['type'] = 'test' + command["type"] = "test" - command['params'] = OD() + command["params"] = OD() if silent: - command['params']['silent'] = True + command["params"]["silent"] = True if working_dir is not None: - command['params']['working_dir'] = working_dir + command["params"]["working_dir"] = working_dir if continue_on_err: - command['params']['continue_on_err'] = True + command["params"]["continue_on_err"] = True if background: - command['params']['background'] = True + command["params"]["background"] = True - command['params']['shell'] = 'bash' - command['params']['script'] = dedented + command["params"]["shell"] = "bash" + command["params"]["script"] = dedented return command def targz_pack(target, source_dir, *include): - return OD([ - ('command', 'archive.targz_pack'), - ('params', OD([ - ('target', target), - ('source_dir', source_dir), - ('include', list(include))]))]) + return OD( + [ + ("command", "archive.targz_pack"), + ( + "params", + OD( + [ + ("target", target), + ("source_dir", source_dir), + ("include", list(include)), + ] + ), + ), + ] + ) class Function(ConfigObject): diff --git a/evergreen_config_generator/evergreen_config_generator/tasks.py b/evergreen_config_generator/evergreen_config_generator/tasks.py index e43afc65..c9c0e57a 100644 --- a/evergreen_config_generator/evergreen_config_generator/tasks.py +++ b/evergreen_config_generator/evergreen_config_generator/tasks.py @@ -32,19 +32,19 @@ def __init__(self, *args, **kwargs): self.tags = set() self.options = OD() self.depends_on = None - self.commands = kwargs.pop('commands', None) or [] + self.commands = kwargs.pop("commands", None) or [] assert isinstance(self.commands, (abc.Sequence, NoneType)) - tags = kwargs.pop('tags', None) + tags = kwargs.pop("tags", None) if tags: self.add_tags(*tags) - depends_on = kwargs.pop('depends_on', None) + depends_on = kwargs.pop("depends_on", None) if depends_on: self.add_dependency(depends_on) - if 'exec_timeout_secs' in kwargs: - self.options['exec_timeout_secs'] = kwargs.pop('exec_timeout_secs') + if "exec_timeout_secs" in kwargs: + self.options["exec_timeout_secs"] = kwargs.pop("exec_timeout_secs") - name_prefix = 'test' + name_prefix = "test" def add_tags(self, *args): self.tags.update(args) @@ -54,7 +54,7 @@ def has_tags(self, *args): def add_dependency(self, dependency): if not isinstance(dependency, abc.Mapping): - dependency = OD([('name', dependency)]) + dependency = OD([("name", dependency)]) if self.depends_on is None: self.depends_on = dependency @@ -67,7 +67,7 @@ def display(self, axis_name): value = getattr(self, axis_name) # E.g., if self.auth is False, return 'noauth'. 
if value is False: - return 'no' + axis_name + return "no" + axis_name if value is True: return axis_name @@ -77,20 +77,20 @@ def display(self, axis_name): def on_off(self, *args, **kwargs): assert not (args and kwargs) if args: - axis_name, = args - return 'on' if getattr(self, axis_name) else 'off' + (axis_name,) = args + return "on" if getattr(self, axis_name) else "off" - (axis_name, value), = kwargs.items() - return 'on' if getattr(self, axis_name) == value else 'off' + ((axis_name, value),) = kwargs.items() + return "on" if getattr(self, axis_name) == value else "off" def to_dict(self): task = super(Task, self).to_dict() if self.tags: - task['tags'] = self.tags + task["tags"] = self.tags task.update(self.options) if self.depends_on: - task['depends_on'] = self.depends_on - task['commands'] = self.commands + task["depends_on"] = self.depends_on + task["commands"] = self.commands return task diff --git a/evergreen_config_generator/evergreen_config_generator/variants.py b/evergreen_config_generator/evergreen_config_generator/variants.py index cd063463..43113ed9 100644 --- a/evergreen_config_generator/evergreen_config_generator/variants.py +++ b/evergreen_config_generator/evergreen_config_generator/variants.py @@ -16,8 +16,9 @@ class Variant(ConfigObject): - def __init__(self, name, display_name, run_on, tasks, expansions=None, - batchtime=None): + def __init__( + self, name, display_name, run_on, tasks, expansions=None, batchtime=None + ): super(Variant, self).__init__() self._variant_name = name self.display_name = display_name @@ -32,7 +33,7 @@ def name(self): def to_dict(self): v = super(Variant, self).to_dict() - for i in 'display_name', 'expansions', 'run_on', 'tasks', 'batchtime': + for i in "display_name", "expansions", "run_on", "tasks", "batchtime": if getattr(self, i): v[i] = getattr(self, i) return v diff --git a/evergreen_config_generator/setup.py b/evergreen_config_generator/setup.py index 80d093bc..ef21a891 100644 --- a/evergreen_config_generator/setup.py +++ b/evergreen_config_generator/setup.py @@ -15,15 +15,15 @@ import setuptools setuptools.setup( - name='evergreen_config_generator', - version='0.0.1', - author='A. Jesse Jiryu Davis', - author_email='jesse@mongodb.com', - description='Helpers for Python scripts that generate Evergreen configs', - url='https://github.com/mongodb-labs/drivers-evergreen-tools', - packages=['evergreen_config_generator'], - install_requires=['PyYAML', 'yamlordereddictloader'], + name="evergreen_config_generator", + version="0.0.1", + author="A. Jesse Jiryu Davis", + author_email="jesse@mongodb.com", + description="Helpers for Python scripts that generate Evergreen configs", + url="https://github.com/mongodb-labs/drivers-evergreen-tools", + packages=["evergreen_config_generator"], + install_requires=["PyYAML", "yamlordereddictloader"], classifiers=[ - 'License :: OSI Approved :: Apache Software License', + "License :: OSI Approved :: Apache Software License", ], )
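
A few illustrative sketches follow for the scripts touched by this patch; none of them are part of the patch itself. First, ocsp_mock.py: the reformatting keeps the flag names intact, so a harness can still drive the script the same way. A minimal launch sketch, where ca.pem, responder.pem, and responder.key are placeholder paths, not files in this repo:

import os
import subprocess
import sys
import time

# Start the mock responder; --port defaults to 8080, the other flags are required.
proc = subprocess.Popen(
    [
        sys.executable,
        "ocsp_mock.py",
        "--ca_file", "ca.pem",                     # placeholder path
        "--ocsp_responder_cert", "responder.pem",  # placeholder path
        "--ocsp_responder_key", "responder.key",   # placeholder path
    ]
)

# The script writes its pid to ocsp.pid in the working directory, so a
# harness can poll for that file before pointing tests at the responder.
while not os.path.exists("ocsp.pid"):
    time.sleep(0.1)
proc.terminate()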
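
setup_secrets.py ends by writing the fetched values out as shell exports. The JSON parsing of each SecretString happens outside the hunks shown here, so the following is a sketch under the assumption that each vault payload is a flat JSON object; the payload below is invented:

import json

secret = '{"MONGODB_URI": "mongodb://localhost:27017"}'  # illustrative payload
pairs = json.loads(secret)
for key, val in pairs.items():
    # Mirrors the out.write(...) line in write_secrets above.
    print("export " + key + "=" + '"' + val + '"')
# prints: export MONGODB_URI="mongodb://localhost:27017"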
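
The byte literals in Socks5Handler.handle() follow RFC 1928/1929 verbatim, so a raw socket client can exercise the server without a SOCKS library. A no-auth sketch, where the proxy port and destination are illustrative and the final reply check only succeeds if the outgoing connection can actually be made:

import socket
import struct

# Assumes the server was started as: python3 socks5srv.py --port 1080
s = socket.create_connection(("localhost", 1080))

# Greeting: version 5, one method offered, 0x00 = no authentication.
s.sendall(b"\x05\x01\x00")
assert s.recv(2) == b"\x05\x00"  # server accepts the no-auth method

# CONNECT request: version, CONNECT, reserved, ATYP 0x03 (domain name),
# a length-prefixed hostname, then the port in big-endian order, which is
# exactly what portraw[0] * 256 + portraw[1] decodes on the server side.
dst = b"example.com"
s.sendall(b"\x05\x01\x00\x03" + bytes([len(dst)]) + dst + struct.pack(">H", 80))

# On success the server always reports 127.0.0.1:4096 as the bound address.
reply = s.recv(10)
assert reply[:2] == b"\x05\x00"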
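
Likewise, the --map grammar handled by AddressRemapper.parse_single_mapping accepts either plain host:port pairs or bracketed IPv6 literals. A quick illustration with the class in scope (socks5srv.py is a script rather than an importable module, so treat this as a sketch of the parse, with invented addresses):

# Hostname form: each side parses to a (bytes_host, int_port) tuple.
src, dst = AddressRemapper.parse_single_mapping("localhost:12345 to localhost:27017")
# src == (b"localhost", 12345), dst == (b"localhost", 27017)

# IPv6 form uses [addr]:port so the colons inside the address are not
# mistaken for the port separator.
src, dst = AddressRemapper.parse_single_mapping("[::1]:12345 to [::1]:27017")
# src == (b"::1", 12345), dst == (b"::1", 27017)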
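
In evergreen_config_generator, the reflowed functions.py is behavior-neutral: shell_exec still prepends the set -o lines and returns the same ordered mapping. A round-trip sketch, with arbitrary script text:

from evergreen_config_generator.functions import shell_exec

cmd = shell_exec("echo hello", test=False, xtrace=True)
# cmd is an OrderedDict equivalent to:
# {"command": "shell.exec",
#  "params": {"shell": "bash",
#             "script": "set -o errexit\nset -o xtrace\necho hello"}}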
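
Finally, the display/on_off helpers in tasks.py still derive variant-name fragments from task axes. A hypothetical subclass shows the intent; ReplSetTask and its auth/ssl axes are invented for illustration:

from evergreen_config_generator.tasks import Task

class ReplSetTask(Task):  # hypothetical subclass, not part of this repo
    def __init__(self, auth, ssl, **kwargs):
        super().__init__(**kwargs)
        self.auth = auth
        self.ssl = ssl

t = ReplSetTask(auth=False, ssl=True, commands=[])
t.display("auth")  # -> "noauth" (False maps to "no" + axis name)
t.display("ssl")   # -> "ssl"    (True maps to the axis name itself)
t.on_off("ssl")    # -> "on"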