From 28e6e00f1eae1144b667b785ffba55d711714ee4 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Thu, 23 Nov 2023 08:29:57 +0000 Subject: [PATCH 1/9] Make exception consistent between redis and filesystem backends --- pindb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pindb.py b/pindb.py index 20a6c2a..b582a2c 100644 --- a/pindb.py +++ b/pindb.py @@ -70,7 +70,8 @@ def redis_retry(func): def get(cls, key): data = cls.redis_retry(lambda: red_conn.get(key)) if not data: - raise Exception("No valid pin found") + # Raise error similar to filesystem backend + raise FileNotFoundError(2, "No record found", key.hex()) return data @classmethod From 1e5237299dbae74a32107b39dee0983a22e4f7af Mon Sep 17 00:00:00 2001 From: Lawrence Nahum Date: Thu, 24 Aug 2023 16:04:11 +0200 Subject: [PATCH 2/9] bpov2: one fewer server interaction/roundtrip Uses tweaked server static key with anti-reply monotonic forward counter --- client.py | 24 +++++- flaskserver.py | 71 +++++++++++------ pindb.py | 95 +++++++++++++++++------ server.py | 43 ++++++++++- test/test_pindb.py | 71 ++++++++++++++++- test/test_pinserver.py | 172 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 421 insertions(+), 55 deletions(-) diff --git a/client.py b/client.py index 901729a..0fa0914 100644 --- a/client.py +++ b/client.py @@ -1,6 +1,7 @@ from .lib import E_ECDH, decrypt, encrypt from hmac import compare_digest -from wallycore import ec_sig_verify, sha256, hmac_sha256, EC_FLAG_ECDSA +from wallycore import ec_sig_verify, sha256, hmac_sha256, EC_FLAG_ECDSA, \ + ec_public_key_bip341_tweak class PINClientECDH(E_ECDH): @@ -45,3 +46,24 @@ def decrypt_response_payload(self, encrypted, hmac): # Return decrypted data return decrypt(self.response_encryption_key, encrypted) + + +class PINClientECDHv2(PINClientECDH): + + def __init__(self, static_server_public_key, replay_counter): + super().__init__(static_server_public_key) + self.replay_counter = replay_counter + tweak = sha256(hmac_sha256(self.public_key, self.replay_counter)) + + # Derive and store the ecdh server public key (ske) + self.ecdh_server_public_key = ec_public_key_bip341_tweak( + self.static_server_public_key, tweak, 0) + + # Cache the shared secrets + self.generate_shared_secrets(self.ecdh_server_public_key) + + # Encrypt/sign/hmac the payload (ie. the pin secret) + def encrypt_request_payload(self, payload): + encrypted = encrypt(self.request_encryption_key, payload) + hmac = hmac_sha256(self.request_hmac_key, self.public_key + self.replay_counter + encrypted) + return encrypted, hmac diff --git a/flaskserver.py b/flaskserver.py index e39d961..7357462 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -2,7 +2,7 @@ import json import os from flask import Flask, request, jsonify -from .server import PINServerECDH +from .server import PINServerECDH, PINServerECDHv2 from .pindb import PINDb from wallycore import hex_from_bytes, hex_to_bytes, AES_KEY_LEN_256, \ AES_BLOCK_LEN @@ -55,33 +55,58 @@ def start_handshake_route(): return jsonify({'ske': ske, 'sig': b2h(sig)}) - def _complete_server_call(pin_func): - try: - # Get request data - udata = json.loads(request.data) - ske = udata['ske'] + def _complete_server_call_v1(pin_func, udata): + ske = udata['ske'] + assert 'replay_counter' not in udata + + # Get associated session (ensuring not stale) + _cleanup_expired_sessions() + e_ecdh_server = sessions[ske] + + # get/set pin and get response data + encrypted_key, hmac = e_ecdh_server.call_with_payload( + h2b(udata['cke']), + h2b(udata['encrypted_data']), + h2b(udata['hmac_encrypted_data']), + pin_func) - # Get associated session (ensuring not stale) - _cleanup_expired_sessions() - e_ecdh_server = sessions[ske] + # Expecting to return an encrypted aes-key + assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) - # get/set pin and get response data - encrypted_key, hmac = e_ecdh_server.call_with_payload( - h2b(udata['cke']), - h2b(udata['encrypted_data']), - h2b(udata['hmac_encrypted_data']), - pin_func) + # Cleanup session + del sessions[ske] + _cleanup_expired_sessions() - # Expecting to return an encrypted aes-key - assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + # Return response + return jsonify({'encrypted_key': b2h(encrypted_key), + 'hmac': b2h(hmac)}) + + def _complete_server_call_v2(pin_func, udata): + assert 'ske' not in udata + assert len(udata['replay_counter']) == 8 + cke = h2b(udata['cke']) + replay_counter = h2b(udata['replay_counter']) + e_ecdh_server = PINServerECDHv2(replay_counter, cke) + encrypted_key, hmac = e_ecdh_server.call_with_payload( + cke, + h2b(udata['encrypted_data']), + h2b(udata['hmac_encrypted_data']), + pin_func) + + # Expecting to return an encrypted aes-key + assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) - # Cleanup session - del sessions[ske] - _cleanup_expired_sessions() + # Return response + return jsonify({'encrypted_key': b2h(encrypted_key), + 'hmac': b2h(hmac)}) - # Return response - return jsonify({'encrypted_key': b2h(encrypted_key), - 'hmac': b2h(hmac)}) + def _complete_server_call(pin_func): + try: + # Get request data + udata = json.loads(request.data) + if 'replay_counter' in udata: + return _complete_server_call_v2(pin_func, udata) + return _complete_server_call_v1(pin_func, udata) except Exception as e: app.logger.error("Error: {} {}".format(type(e), e)) diff --git a/pindb.py b/pindb.py index b582a2c..771995c 100644 --- a/pindb.py +++ b/pindb.py @@ -13,7 +13,8 @@ b2h = hex_from_bytes h2b = hex_to_bytes -VERSION = 0 +VERSION_SUPPORTED = 0 +VERSION_LATEST = 1 load_dotenv() @@ -104,7 +105,7 @@ class PINDb(object): storage = get_storage() @classmethod - def _extract_fields(cls, cke, data): + def _extract_fields(cls, cke, data, replay_counter=None): assert len(data) == (2*SHA256_LEN) + EC_SIGNATURE_RECOVERABLE_LEN # secret + entropy + sig @@ -112,15 +113,21 @@ def _extract_fields(cls, cke, data): entropy = data[SHA256_LEN: SHA256_LEN + SHA256_LEN] sig = data[SHA256_LEN + SHA256_LEN:] + # make sure the client_public_key signs over the replay counter too if provided + if replay_counter is not None: + assert len(replay_counter) == 4 + signed_msg = sha256(cke + replay_counter + pin_secret + entropy) + else: + signed_msg = sha256(cke + pin_secret + entropy) + # We know mesage the signature is for, so can recover the public key - signed_msg = sha256(cke + pin_secret + entropy) client_public_key = ec_sig_to_public_key(signed_msg, sig) return pin_secret, entropy, client_public_key @classmethod def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, - pin_pubkey, aes_pin_data_key, count=0): + pin_pubkey, aes_pin_data_key, count, replay_counter=None): # the data is encrypted and then hmac'ed for authentication # the encrypted data can't be read by us without the user @@ -129,9 +136,14 @@ def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, storage_aes_key = hmac_sha256(aes_pin_data_key, pin_pubkey) count_bytes = struct.pack('B', count) plaintext = hash_pin_secret + aes_key + count_bytes + version_bytes = struct.pack('B', VERSION_SUPPORTED) + if replay_counter is not None: + # if this is v2 we save the latest replay_counter and update the version + plaintext += replay_counter + version_bytes = struct.pack('B', VERSION_LATEST) encrypted = encrypt(storage_aes_key, plaintext) pin_auth_key = hmac_sha256(aes_pin_data_key, pin_pubkey_hash) - version_bytes = struct.pack('B', VERSION) + hmac_payload = hmac_sha256(pin_auth_key, version_bytes + encrypted) cls.storage.set(pin_pubkey_hash, version_bytes + hmac_payload + encrypted) @@ -139,7 +151,7 @@ def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, return aes_key @classmethod - def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key): + def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key, replay_counter=None): data = cls.storage.get(pin_pubkey_hash) assert len(data) == 129 @@ -147,8 +159,13 @@ def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key): # verify integrity of encrypted data first pin_auth_key = hmac_sha256(aes_pin_data_key, pin_pubkey_hash) - version_bytes = struct.pack('B', VERSION) - assert version_bytes == version + version_bytes = struct.pack('B', VERSION_LATEST) + len_plaintext = 32 + 32 + 1 + 4 + if version_bytes != version: + # this is the old database, check if we are upgrading + version_bytes = struct.pack('B', VERSION_SUPPORTED) + len_plaintext -= 4 + assert version_bytes == version hmac_payload = hmac_sha256(pin_auth_key, version_bytes + encrypted) assert hmac_payload == hmac_received @@ -156,12 +173,24 @@ def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key): storage_aes_key = hmac_sha256(aes_pin_data_key, pin_pubkey) plaintext = decrypt(storage_aes_key, encrypted) - assert len(plaintext) == 32 + 32 + 1 + assert len(plaintext) == len_plaintext, len(plaintext) hash_pin_secret, aes_key = plaintext[:32], plaintext[32:64] count = struct.unpack('B', plaintext[64: 64 + struct.calcsize('B')])[0] - return hash_pin_secret, aes_key, count + replay_local = None + if len_plaintext == 69: + replay_local = plaintext[65:69] + replay_local = int.from_bytes(replay_local, byteorder='little', + signed=False) + if replay_local is not None and replay_counter is not None: + # if this is v2 and the db is already upgraded we enforce the + # anti replay + replay_remote = int.from_bytes(replay_counter, byteorder='little', + signed=False) + assert replay_remote > replay_local + + return hash_pin_secret, aes_key, count, replay_local @classmethod def make_client_aes_key(self, pin_secret, saved_key): @@ -173,21 +202,25 @@ def make_client_aes_key(self, pin_secret, saved_key): # Get existing aes_key given pin fields @classmethod - def get_aes_key_impl(cls, pin_pubkey, pin_secret, aes_pin_data_key): + def get_aes_key_impl(cls, pin_pubkey, pin_secret, aes_pin_data_key, replay_counter=None): # Load the data from the pubkey pin_pubkey_hash = bytes(sha256(pin_pubkey)) - saved_hps, saved_key, counter = cls._load_pin_fields(pin_pubkey_hash, - pin_pubkey, - aes_pin_data_key) + saved_hps, saved_key, counter, replay_local = cls._load_pin_fields(pin_pubkey_hash, + pin_pubkey, + aes_pin_data_key, + replay_counter) + if replay_local is not None: + replay_local = replay_local.to_bytes(4, byteorder='little', signed=False) # Check that the pin provided matches that saved hash_pin_secret = sha256(pin_secret) if compare_digest(saved_hps, hash_pin_secret): # pin-secret matches - correct pin - if counter != 0: - # Zero the 'bad guess counter' + # Zero the 'bad guess counter' and/or update the replay_counter + if counter != 0 or replay_counter: cls._save_pin_fields(pin_pubkey_hash, saved_hps, saved_key, - pin_pubkey, aes_pin_data_key) + pin_pubkey, aes_pin_data_key, 0, + replay_counter or replay_local) # return the saved key return saved_key @@ -195,29 +228,35 @@ def get_aes_key_impl(cls, pin_pubkey, pin_secret, aes_pin_data_key): # user provided wrong pin if counter >= 2: # pin failed 3 times, overwrite and then remove secret + + max_replay = 4294967295 cls._save_pin_fields(pin_pubkey_hash, saved_hps, bytearray(AES_KEY_LEN_256), pin_pubkey, - aes_pin_data_key) + aes_pin_data_key, 3, + max_replay.to_bytes(4, + byteorder='little', + signed=False)) cls.storage.remove(pin_pubkey_hash) raise Exception("Too many attempts") else: # increment counter cls._save_pin_fields(pin_pubkey_hash, saved_hps, saved_key, pin_pubkey, - aes_pin_data_key, counter + 1) + aes_pin_data_key, counter + 1, + replay_counter or replay_local) raise Exception("Invalid PIN") # Get existing aes_key given pin fields, or junk if pin or pubkey bad @classmethod - def get_aes_key(cls, cke, payload, aes_pin_data_key): - pin_secret, _, pin_pubkey = cls._extract_fields(cke, payload) + def get_aes_key(cls, cke, payload, aes_pin_data_key, replay_counter=None): + pin_secret, _, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) # Translate internal exception and bad-pin into junk key try: saved_key = cls.get_aes_key_impl(pin_pubkey, pin_secret, - aes_pin_data_key) + aes_pin_data_key, replay_counter) except Exception as e: # return junk key saved_key = os.urandom(AES_KEY_LEN_256) @@ -227,18 +266,24 @@ def get_aes_key(cls, cke, payload, aes_pin_data_key): # Set pin fields, return new aes_key @classmethod - def set_pin(cls, cke, payload, aes_pin_data_key): - pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload) + def set_pin(cls, cke, payload, aes_pin_data_key, replay_counter=None): + pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) # Make a new aes-key to persist from our and client entropy our_random = os.urandom(32) new_key = hmac_sha256(our_random, entropy) + assert replay_counter is None or replay_counter == b'\x00\x00\x00\x00' + # Persist the pin fields pin_pubkey_hash = bytes(sha256(pin_pubkey)) hash_pin_secret = sha256(pin_secret) + replay_bytes = None + if replay_counter is not None: + replay_init = 0 + replay_bytes = replay_init.to_bytes(4, byteorder='little', signed=False) saved_key = cls._save_pin_fields(pin_pubkey_hash, hash_pin_secret, new_key, - pin_pubkey, aes_pin_data_key) + pin_pubkey, aes_pin_data_key, 0, replay_bytes) # Combine saved key with (not persisted) pin-secret return cls.make_client_aes_key(pin_secret, saved_key) diff --git a/server.py b/server.py index c992292..5753f43 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,7 @@ import os from .lib import decrypt, encrypt, E_ECDH from wallycore import ec_private_key_verify, ec_sig_from_bytes, sha256, \ - hmac_sha256, EC_FLAG_ECDSA + hmac_sha256, EC_FLAG_ECDSA, ec_private_key_bip341_tweak, ec_public_key_from_private_key class PINServerECDH(E_ECDH): @@ -87,3 +87,44 @@ def call_with_payload(self, cke, encrypted, hmac, func): encrypted, hmac = self.encrypt_response_payload(response) return encrypted, hmac + + +class PINServerECDHv2(PINServerECDH): + + @classmethod + def generate_ec_key_pair(cls, replay_counter, cke): + assert cls.STATIC_SERVER_PRIVATE_KEY + + tweak = sha256(hmac_sha256(cke, replay_counter)) + private_key = ec_private_key_bip341_tweak(cls.STATIC_SERVER_PRIVATE_KEY, tweak, 0) + ec_private_key_verify(private_key) + public_key = ec_public_key_from_private_key(private_key) + return private_key, public_key + + def __init__(self, replay_counter, cke): + # intentionally we don't call any constructor from what we inherit from + assert len(replay_counter) == 4 + self.replay_counter = replay_counter + self.private_key, self.public_key = self.generate_ec_key_pair(replay_counter, cke) + + # Decrypt the received payload (ie. aes-key) + def decrypt_request_payload(self, cke, encrypted, hmac): + # Verify hmac received + hmac_calculated = hmac_sha256(self.request_hmac_key, cke + self.replay_counter + encrypted) + assert compare_digest(hmac, hmac_calculated) + + # Return decrypted data + return decrypt(self.request_encryption_key, encrypted) + + # Function to deal with wrapper ecdh encryption. + # Calls passed function with unwrapped payload, and wraps response before + # returning. Separates payload handler func from wrapper encryption. + def call_with_payload(self, cke, encrypted, hmac, func): + self.generate_shared_secrets(cke) + payload = self.decrypt_request_payload(cke, encrypted, hmac) + + # Call the passed function with the decrypted payload + response = func(cke, payload, self._get_aes_pin_data_key(), self.replay_counter) + + encrypted, hmac = self.encrypt_response_payload(response) + return encrypted, hmac diff --git a/test/test_pindb.py b/test/test_pindb.py index 73a7147..491be2e 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -119,18 +119,85 @@ def test_save_and_load_pin_fields(self): self.assertEqual(new_key, key_in) # Read file back in - ensure fields the same - hps_out, key_out, count_out = PINDb._load_pin_fields(pinfile, user_id, aes_pin) + hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, + user_id, + aes_pin) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) + self.assertEqual(replay_local, None) # Ensure we can set zero the count of an existing file count_in = 0 new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, count_in) - hps_out, key_out, count_out = PINDb._load_pin_fields(pinfile, user_id, aes_pin) + hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, + user_id, + aes_pin) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) + self.assertEqual(replay_local, None) + + # Ensure we can't decrypt the pin with the wrong aes_key, hmac won't match + bad_aes = os.urandom(32) + with self.assertRaises(AssertionError) as _: + PINDb._load_pin_fields(pinfile, user_id, bad_aes) + + def test_save_and_load_pin_fieldsv2(self): + # Reinitialise keys and secret + _, _, _, pinfile = self.new_keys() + pin_secret, key_in = self.new_pin_secret(), self.new_entropy() + hps_in = sha256(pin_secret) + count_in = 5 + + # Trying to read non-existent file throws (and does not create file) + self.assertFalse(PINDb.storage.exists(pinfile)) + with self.assertRaises((FileNotFoundError, Exception)) as _: + PINDb._load_pin_fields(pinfile, None, None) + self.assertFalse(PINDb.storage.exists(pinfile)) + + user_id = os.urandom(32) + aes_pin = bytes(os.urandom(32)) + + replay_counter = 0 + replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + + # Save some data - check new file created + new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, + count_in, replay_counter) + self.assertTrue(PINDb.storage.exists(pinfile)) + + # Atm the 'new key' returned should be the one passed in + self.assertEqual(new_key, key_in) + + replay_counter = 1 + replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + # Read file back in - ensure fields the same + hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, + user_id, + aes_pin, + replay_counter) + self.assertEqual(hps_out, hps_in) + self.assertEqual(key_out, key_in) + self.assertEqual(count_out, count_in) + self.assertEqual(replay_local, 0) + + replay_counter = 5 + replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + + # Ensure we can set zero the count of an existing file + count_in = 0 + new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, + count_in, replay_counter) + replay_counter = 10000 + replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, + user_id, + aes_pin, replay_counter) + self.assertEqual(hps_out, hps_in) + self.assertEqual(key_out, key_in) + self.assertEqual(count_out, count_in) + self.assertEqual(replay_local, 5) # Ensure we can't decrypt the pin with the wrong aes_key, hmac won't match bad_aes = os.urandom(32) diff --git a/test/test_pinserver.py b/test/test_pinserver.py index 34db0cd..f8f7362 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -7,7 +7,7 @@ from hmac import compare_digest import requests -from ..client import PINClientECDH +from ..client import PINClientECDH, PINClientECDHv2 from ..server import PINServerECDH from ..pindb import PINDb @@ -107,12 +107,21 @@ def new_client_handshake(self): client = PINClientECDH(self.static_server_public_key) return self.start_handshake(client) + # Make a new ephemeral client and initialise with tweaked server key + def new_client_handshakev2(self, replay_counter): + client = PINClientECDHv2(self.static_server_public_key, replay_counter) + return client + # Make the server call to get/set the pin - returns the decrypted response - def server_call(self, private_key, client, endpoint, pin_secret, entropy): + def server_call(self, private_key, client, endpoint, pin_secret, entropy, replay_counter=None): # Make and encrypt the payload (ie. pin secret) ske, cke = client.get_key_exchange() + cke_sha = cke + if replay_counter is not None: + assert len(replay_counter) == 4 + cke_sha = cke + replay_counter sig = ec_sig_from_bytes(private_key, - sha256(cke + pin_secret + entropy), + sha256(cke_sha + pin_secret + entropy), EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) payload = pin_secret + entropy + sig @@ -123,6 +132,9 @@ def server_call(self, private_key, client, endpoint, pin_secret, entropy): 'cke': b2h(cke), 'encrypted_data': b2h(encrypted), 'hmac_encrypted_data': b2h(hmac)} + if replay_counter: + urldata['replay_counter'] = b2h(replay_counter) + del urldata['ske'] response = self.post(endpoint, urldata) encrypted = h2b(response['encrypted_key']) hmac = h2b(response['hmac']) @@ -142,6 +154,18 @@ def set_pin(self, private_key, pin_secret, entropy): return self.server_call( private_key, client, 'set_pin', pin_secret, entropy) + def get_pinv2(self, private_key, pin_secret, entropy, replay_counter): + # Create new ephemeral client, initiate handshake, and make call + client = self.new_client_handshakev2(replay_counter) + return self.server_call( + private_key, client, 'get_pin', pin_secret, entropy, replay_counter) + + def set_pinv2(self, private_key, pin_secret, entropy, replay_counter): + # Create new ephemeral client, initiate handshake, and make call + client = self.new_client_handshakev2(replay_counter) + return self.server_call( + private_key, client, 'set_pin', pin_secret, entropy, replay_counter) + # Tests def test_get_index(self): # No index or similar @@ -356,6 +380,148 @@ def test_cannot_reuse_client_session(self): self.new_entropy()) self.assertTrue(compare_digest(aeskey, aeskey_s)) + def test_v2_happypath_with_simulated_replay(self): + # Make ourselves a static key pair for this logical client + priv_key, _, _ = self.new_static_client_keys() + + # The 'correct' client pin + pin_secret = self.new_pin_secret() + + # assert you can't set pin with a replay_counter different than 0 + with self.assertRaises(ValueError) as cm: + replay_counter = 1 + self.set_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, + byteorder='little', + signed=False)) + + # set the pin secret to get a new aes key + replay_counter = 0 + aeskey_s = self.set_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + + # retrieve the key again with our correct pin secret + replay_counter = 1 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + + # Now let's compare + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + for i in range(5): + # Simulate a reply attempt failing N times, it doesn't affect pin + # attempts / dos + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, + byteorder='little', + signed=False)) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + # retrieve the key again using v1 + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + # Incrementing the counter monotonically works again + replay_counter = 2 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + # Incrementing the counter monotonically works even in case of network + # errors where some request is missed + replay_counter = 4 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + bad_secret = self.new_pin_secret() + for i in range(3): + # exaust pin attmempts with good replay_counter + replay_counter = i + 5 + replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + aeskey = self.get_pinv2(priv_key, bad_secret, self.new_entropy(), replay_counter) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + # retrieve the key again using v1 should fail + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + self.assertFalse(compare_digest(aeskey_g, aeskey_s)) + + # Incrementing the counter monotonically also fails + replay_counter = 8 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + def test_v2_happypath_with_simulated_replay_upgrade(self): + # Make ourselves a static key pair for this logical client + priv_key, _, _ = self.new_static_client_keys() + + # The 'correct' client pin + pin_secret = self.new_pin_secret() + + # Make a new client and set the pin secret to get a new aes key + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy()) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + + # retrieve the key again with our correct pin secret + replay_counter = 0 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + + # Now let's compare + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + for i in range(5): + # Simulate a reply attempt failing N times, it doesn't affect pin + # attempts / dos + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, + byteorder='little', + signed=False)) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + # retrieve the key again using v1 + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + # Incrementing the counter monotonically works again + replay_counter = 2 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + # Incrementing the counter monotonically works even in case of network + # errors where some request is missed + replay_counter = 4 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + bad_secret = self.new_pin_secret() + for i in range(3): + # exaust pin attmempts with good replay_counter + aeskey = self.get_pin(priv_key, bad_secret, self.new_entropy()) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + # retrieve the key again using v1 should fail + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + self.assertFalse(compare_digest(aeskey_g, aeskey_s)) + + # Incrementing the counter monotonically also fails + replay_counter = 5 + aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), + replay_counter.to_bytes(4, byteorder='little', + signed=False)) + self.assertFalse(compare_digest(aeskey, aeskey_s)) + if __name__ == '__main__': unittest.main() From 776e23e51b374bbe049a2ee58e01b7364fddb3d8 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Thu, 30 Nov 2023 13:08:55 +0000 Subject: [PATCH 3/9] Prefer python native hex to/from bytes conversions --- flaskserver.py | 32 ++++++++++++++------------------ pindb.py | 8 ++------ test/test_pindb.py | 4 +--- test/test_pinserver.py | 25 +++++++++++-------------- 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/flaskserver.py b/flaskserver.py index 7357462..e6540df 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -4,13 +4,9 @@ from flask import Flask, request, jsonify from .server import PINServerECDH, PINServerECDHv2 from .pindb import PINDb -from wallycore import hex_from_bytes, hex_to_bytes, AES_KEY_LEN_256, \ - AES_BLOCK_LEN +from wallycore import AES_KEY_LEN_256, AES_BLOCK_LEN from dotenv import load_dotenv -b2h = hex_from_bytes -h2b = hex_to_bytes - # Time we will retain active sessions, in seconds. # ie. maximum time allowed 'start_handshake' (which creates the session) # and the get-/set-pin call, which utilises it. @@ -45,7 +41,7 @@ def start_handshake_route(): # Create a new ephemeral server/session and get its signed pubkey e_ecdh_server = PINServerECDH() pubkey, sig = e_ecdh_server.get_signed_public_key() - ske = b2h(pubkey) + ske = pubkey.hex() # Cache new session _cleanup_expired_sessions() @@ -53,7 +49,7 @@ def start_handshake_route(): # Return response return jsonify({'ske': ske, - 'sig': b2h(sig)}) + 'sig': sig.hex()}) def _complete_server_call_v1(pin_func, udata): ske = udata['ske'] @@ -65,9 +61,9 @@ def _complete_server_call_v1(pin_func, udata): # get/set pin and get response data encrypted_key, hmac = e_ecdh_server.call_with_payload( - h2b(udata['cke']), - h2b(udata['encrypted_data']), - h2b(udata['hmac_encrypted_data']), + bytes.fromhex(udata['cke']), + bytes.fromhex(udata['encrypted_data']), + bytes.fromhex(udata['hmac_encrypted_data']), pin_func) # Expecting to return an encrypted aes-key @@ -78,27 +74,27 @@ def _complete_server_call_v1(pin_func, udata): _cleanup_expired_sessions() # Return response - return jsonify({'encrypted_key': b2h(encrypted_key), - 'hmac': b2h(hmac)}) + return jsonify({'encrypted_key': encrypted_key.hex(), + 'hmac': hmac.hex()}) def _complete_server_call_v2(pin_func, udata): assert 'ske' not in udata assert len(udata['replay_counter']) == 8 - cke = h2b(udata['cke']) - replay_counter = h2b(udata['replay_counter']) + cke = bytes.fromhex(udata['cke']) + replay_counter = bytes.fromhex(udata['replay_counter']) e_ecdh_server = PINServerECDHv2(replay_counter, cke) encrypted_key, hmac = e_ecdh_server.call_with_payload( cke, - h2b(udata['encrypted_data']), - h2b(udata['hmac_encrypted_data']), + bytes.fromhex(udata['encrypted_data']), + bytes.fromhex(udata['hmac_encrypted_data']), pin_func) # Expecting to return an encrypted aes-key assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) # Return response - return jsonify({'encrypted_key': b2h(encrypted_key), - 'hmac': b2h(hmac)}) + return jsonify({'encrypted_key': encrypted_key.hex(), + 'hmac': hmac.hex()}) def _complete_server_call(pin_func): try: diff --git a/pindb.py b/pindb.py index 771995c..2c05562 100644 --- a/pindb.py +++ b/pindb.py @@ -6,13 +6,9 @@ from pathlib import Path from hmac import compare_digest from wallycore import ec_sig_to_public_key, sha256, hmac_sha256, \ - hex_from_bytes, AES_KEY_LEN_256, EC_SIGNATURE_RECOVERABLE_LEN, SHA256_LEN, \ - hex_to_bytes + AES_KEY_LEN_256, EC_SIGNATURE_RECOVERABLE_LEN, SHA256_LEN from dotenv import load_dotenv -b2h = hex_from_bytes -h2b = hex_to_bytes - VERSION_SUPPORTED = 0 VERSION_LATEST = 1 @@ -31,7 +27,7 @@ class FileStorage(object): @staticmethod def _get_filename(key): - filename = '{}.pin'.format(b2h(key)) + filename = '{}.pin'.format(key.hex()) if os.path.exists('pins'): return Path('pins') / filename return filename diff --git a/test/test_pindb.py b/test/test_pindb.py index 491be2e..ad55cf0 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -6,11 +6,9 @@ from ..pindb import PINDb from ..lib import encrypt, decrypt, E_ECDH -from wallycore import sha256, ec_sig_from_bytes, hex_from_bytes, \ +from wallycore import sha256, ec_sig_from_bytes, \ AES_KEY_LEN_256, EC_FLAG_ECDSA, EC_FLAG_RECOVERABLE -b2h = hex_from_bytes - # Tests the pindb and payload handling without any reference to the ecdh # protocol/encryption wrapper. diff --git a/test/test_pinserver.py b/test/test_pinserver.py index f8f7362..02aaab4 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -13,12 +13,9 @@ from ..flaskserver import app from ..flaskserver import SESSION_LIFETIME -from wallycore import sha256, ec_sig_from_bytes, hex_from_bytes, hex_to_bytes,\ +from wallycore import sha256, ec_sig_from_bytes, \ AES_KEY_LEN_256, EC_FLAG_ECDSA, EC_FLAG_RECOVERABLE -b2h = hex_from_bytes -h2b = hex_to_bytes - class PINServerTest(unittest.TestCase): @@ -99,7 +96,7 @@ def tearDownClass(cls): # Start the client/server key-exchange handshake def start_handshake(self, client): handshake = self.post('start_handshake') - client.handshake(h2b(handshake['ske']), h2b(handshake['sig'])) + client.handshake(bytes.fromhex(handshake['ske']), bytes.fromhex(handshake['sig'])) return client # Make a new ephemeral client and initialise with server handshake @@ -128,16 +125,16 @@ def server_call(self, private_key, client, endpoint, pin_secret, entropy, replay encrypted, hmac = client.encrypt_request_payload(payload) # Make call and parse response - urldata = {'ske': b2h(ske), - 'cke': b2h(cke), - 'encrypted_data': b2h(encrypted), - 'hmac_encrypted_data': b2h(hmac)} + urldata = {'ske': ske.hex(), + 'cke': cke.hex(), + 'encrypted_data': encrypted.hex(), + 'hmac_encrypted_data': hmac.hex()} if replay_counter: - urldata['replay_counter'] = b2h(replay_counter) + urldata['replay_counter'] = replay_counter.hex() del urldata['ske'] response = self.post(endpoint, urldata) - encrypted = h2b(response['encrypted_key']) - hmac = h2b(response['hmac']) + encrypted = bytes.fromhex(response['encrypted_key']) + hmac = bytes.fromhex(response['hmac']) # Return decrypted payload return client.decrypt_response_payload(encrypted, hmac) @@ -297,8 +294,8 @@ def test_rejects_on_bad_json(self): ske, cke = client.get_key_exchange() # Make call with bad/missing parameters - urldata = {'ske': b2h(ske), - 'cke': b2h(cke), + urldata = {'ske': ske.hex(), + 'cke': cke.hex(), # 'encrypted_data' missing 'hmac_encrypted_data': 'abc123'} From 2350e02f47ab0addb3bc3b9863ea41495bdd5323 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Tue, 21 Nov 2023 13:55:16 +0000 Subject: [PATCH 4/9] bpov2: Extend tests to cover both v1 and v2 protocols --- test/{test_ecdh.py => test_ecdh_v1.py} | 46 ++- test/test_ecdh_v2.py | 227 +++++++++++ test/test_pindb.py | 425 ++++++++++++++------- test/test_pinserver.py | 501 ++++++++++++++----------- 4 files changed, 837 insertions(+), 362 deletions(-) rename test/{test_ecdh.py => test_ecdh_v1.py} (81%) create mode 100644 test/test_ecdh_v2.py diff --git a/test/test_ecdh.py b/test/test_ecdh_v1.py similarity index 81% rename from test/test_ecdh.py rename to test/test_ecdh_v1.py index 1c755cf..9ac7d23 100644 --- a/test/test_ecdh.py +++ b/test/test_ecdh_v1.py @@ -6,10 +6,11 @@ from ..server import PINServerECDH -# Tests ECDH wrapper without any reference to the pin/aes-key paylod stuff. -# Just testing the ECDH envelope/ecryption in isolation, with misc bytearray() +# Tests ECDHv1 wrapper without any reference to the pin/aes-key paylod stuff. +# Just testing the ECDH envelope/encryption in isolation, with misc bytearray() # payloads (ie. any old str.encode()). Tests client/server handshake/pairing. -class ECDHTest(unittest.TestCase): +# NOTE: protocol v1: key-exchange handshake required +class ECDHv1Test(unittest.TestCase): @classmethod def setUpClass(cls): @@ -81,7 +82,7 @@ def test_call_with_payload(self): # Test server un-/re-wrapping function - this handles all the ecdh # decrypting, hmac checking and encrypting/hmac-ing of the response. - # Hander need know nothing about the wrapping encryption. + # Handler need know nothing about the wrapping encryption. server_response = "Reply to 'test 123' message".encode() def _func(client_key, payload, aes_pin_data_key): @@ -102,6 +103,7 @@ def test_multiple_calls(self): cke, client = self.new_client_handshake(ske, sig) # Server can handle multiple calls from the client with same secrets + # (But that would use same cke and secrets which is ofc not ideal/recommended.) server.generate_shared_secrets(cke) for i in range(5): client_request = 'request-{}'.format(i).encode() @@ -123,7 +125,7 @@ def test_multiple_clients(self): # Server can persist and handle multiple calls provided each one is # accompanied by its relevant cke for that client and the server - # regenerates the sharewd secrets each time. + # regenerates the shared secrets each time. for i in range(5): client_request = 'client-{}-request'.format(i).encode() cke, client = self.new_client_handshake(ske, sig) @@ -139,6 +141,36 @@ def test_multiple_clients(self): received = client.decrypt_response_payload(encrypted, hmac) self.assertEqual(received, server_response) + def test_bad_request_cke_throws(self): + # A new server and client + server = PINServerECDH() + ske, sig = server.get_signed_public_key() + cke, client = self.new_client_handshake(ske, sig) + + # Encrypt message + client_request = 'bad-cke-request'.encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + # Break cke + bad_cke, _ = self.new_client_handshake(ske, sig) + self.assertEqual(len(cke), len(bad_cke)) + self.assertNotEqual(cke, bad_cke) + + # Ensure decrypt_request() throws + server.generate_shared_secrets(cke) + server.decrypt_request_payload(cke, encrypted, hmac) # no error + + server.generate_shared_secrets(bad_cke) + with self.assertRaises(AssertionError) as cm: + server.decrypt_request_payload(bad_cke, encrypted, hmac) # error + + # Ensure call_with_payload() throws before it calls the handler fn + def _func(client_key, payload, aes_pin_data_key): + self.fail('should-never-get-here') + + with self.assertRaises(AssertionError) as cm: + server.call_with_payload(bad_cke, encrypted, hmac, _func) + def test_bad_request_hmac_throws(self): # A new server and client server = PINServerECDH() @@ -164,7 +196,7 @@ def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') with self.assertRaises(AssertionError) as cm: - server.call_with_payload(cke, encrypted, hmac, _func) + server.call_with_payload(cke, encrypted, bad_hmac, _func) def test_bad_response_hmac_throws(self): # A new server and client @@ -173,7 +205,7 @@ def test_bad_response_hmac_throws(self): cke, client = self.new_client_handshake(ske, sig) # Encrypt message - client_request = 'bad-hmac-request'.encode() + client_request = 'bad-hmac-response-request'.encode() encrypted, hmac = client.encrypt_request_payload(client_request) def _func(client_key, payload, pin_data_aes_key): diff --git a/test/test_ecdh_v2.py b/test/test_ecdh_v2.py new file mode 100644 index 0000000..61ee71e --- /dev/null +++ b/test/test_ecdh_v2.py @@ -0,0 +1,227 @@ +import unittest + +import os + +from ..client import PINClientECDHv2 +from ..server import PINServerECDHv2 + + +# Tests ECDHv2 wrapper without any reference to the pin/aes-key paylod stuff. +# Just testing the ECDH envelope/encryption in isolation, with misc bytearray() +# payloads (ie. any old str.encode()). Tests client/server handshake/pairing. +# NOTE: protocol v2: no key-exchange handshake required +class ECDHv2Test(unittest.TestCase): + REPLAY_COUNTER = bytes([0x00, 0x00, 0x00, 0x2a]) # arbitrary + + @classmethod + def setUpClass(cls): + # The server public key the client would know + with open(PINServerECDHv2.STATIC_SERVER_PUBLIC_KEY_FILE, 'rb') as f: + cls.static_server_public_key = f.read() + + # Make a new client and initialise with server tweaked key and initial counter + def new_client_handshake(self): + client = PINClientECDHv2(self.static_server_public_key, self.REPLAY_COUNTER) + ske, cke = client.get_key_exchange() + return cke, client + + def _test_client_server_impl(self, client_request, server_response): + + # A new client is created, which computes the ske from a tweak on the + # server static key. + cke, client = self.new_client_handshake() + + # The client can then encrypt a payload (and hmac) for the server + encrypted, hmac = client.encrypt_request_payload(client_request) + self.assertNotEqual(client_request, encrypted) + + # A new server is created when passed the client replay-counter. + # The server uses the cke and counter to create the tweaked *private* key + # NOTE: the server deduced private key should be the counterpart to the + # client-deduced public key - if so the payload decryption should yield + # the original cleartext request message. + # Note: this validates hmac before it decrypts/returns + server = PINServerECDHv2(client.replay_counter, cke) + server.generate_shared_secrets(cke) + received = server.decrypt_request_payload(cke, encrypted, hmac) + self.assertEqual(received, client_request) + + # The server can then send an encrypted response to the client + encrypted, hmac = server.encrypt_response_payload(server_response) + + # The client can decrypt the response. + # Note: this validates hmac before it decrypts/returns + received = client.decrypt_response_payload(encrypted, hmac) + self.assertEqual(received, server_response) + + def test_client_server_happypath(self): + for (request, response) in [ + ('REQUEST'.encode(), 'RESPONSE'.encode()), + ('12345 request string'.encode(), '67890 reply'.encode()), + (os.urandom(32), os.urandom(64)) + ]: + with self.subTest(request=request, response=response): + self._test_client_server_impl(request, response) + + def test_call_with_payload(self): + # Client sends message to server + cke, client = self.new_client_handshake() + client_request = "Hello - test 123".encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + self.assertNotEqual(client_request, encrypted) + + # Test server un-/re-wrapping function - this handles all the ecdh + # decrypting, hmac checking and encrypting/hmac-ing of the response. + # Handler need know nothing about the wrapping encryption. + server = PINServerECDHv2(client.replay_counter, cke) + server_response = "Reply to 'test 123' message".encode() + + def _func(client_key, payload, aes_pin_data_key, replay_counter): + self.assertEqual(client_key, cke) + self.assertEqual(payload, client_request) + self.assertEqual(replay_counter, client.replay_counter) + return server_response + + encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + + # Assert that is what the client expects + received = client.decrypt_response_payload(encrypted, hmac) + self.assertEqual(received, server_response) + + def test_multiple_calls(self): + # A new server and client + cke, client = self.new_client_handshake() + server = PINServerECDHv2(client.replay_counter, cke) + + # Server can handle multiple calls from the client with same secrets + # (But that would use same cke and counter which is ofc not ideal/recommended.) + server.generate_shared_secrets(cke) + for i in range(5): + client_request = 'request-{}'.format(i).encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + received = server.decrypt_request_payload(cke, encrypted, hmac) + self.assertEqual(received, client_request) + + server_response = 'response-{}'.format(i).encode() + encrypted, hmac = server.encrypt_response_payload(server_response) + + received = client.decrypt_response_payload(encrypted, hmac) + self.assertEqual(received, server_response) + + def test_bad_request_cke_throws(self): + # A new server and client + cke, client = self.new_client_handshake() + server = PINServerECDHv2(client.replay_counter, cke) + + # Encrypt message + client_request = 'bad-cke-request'.encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + # Break cke + bad_cke, _ = self.new_client_handshake() + self.assertEqual(len(cke), len(bad_cke)) + self.assertNotEqual(cke, bad_cke) + + # Ensure decrypt_request() throws + server.generate_shared_secrets(cke) + server.decrypt_request_payload(cke, encrypted, hmac) # no error + + # Same server using good cke to derive keys, but bad cke passed + server.generate_shared_secrets(bad_cke) + with self.assertRaises(AssertionError) as cm: + server.decrypt_request_payload(bad_cke, encrypted, hmac) # error + + # New server with bad_cke from the get go + server = PINServerECDHv2(client.replay_counter, bad_cke) + server.generate_shared_secrets(bad_cke) + with self.assertRaises(AssertionError) as cm: + server.decrypt_request_payload(bad_cke, encrypted, hmac) # error + + # Ensure call_with_payload() throws before it calls the handler fn + def _func(client_key, payload, aes_pin_data_key): + self.fail('should-never-get-here') + + with self.assertRaises(AssertionError) as cm: + server.call_with_payload(bad_cke, encrypted, hmac, _func) + + def test_bad_request_counter_throws(self): + # A new server and client + cke, client = self.new_client_handshake() + server = PINServerECDHv2(client.replay_counter, cke) + + # Encrypt message + client_request = 'bad-counter-request'.encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + # Ensure decrypt_request() throws + server.generate_shared_secrets(cke) + server.decrypt_request_payload(cke, encrypted, hmac) # no error + + # New server with bad counter passed + server = PINServerECDHv2(os.urandom(4), cke) + server.generate_shared_secrets(cke) + with self.assertRaises(AssertionError) as cm: + server.decrypt_request_payload(cke, encrypted, hmac) # error + + # Ensure call_with_payload() throws before it calls the handler fn + def _func(client_key, payload, aes_pin_data_key): + self.fail('should-never-get-here') + + with self.assertRaises(AssertionError) as cm: + server.call_with_payload(cke, encrypted, hmac, _func) + + def test_bad_request_hmac_throws(self): + # A new server and client + cke, client = self.new_client_handshake() + server = PINServerECDHv2(client.replay_counter, cke) + + # Encrypt message + client_request = 'bad-hmac-request'.encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + # Break hmac + bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) + self.assertNotEqual(hmac, bad_hmac) + + # Ensure decrypt_request() throws + server.generate_shared_secrets(cke) + server.decrypt_request_payload(cke, encrypted, hmac) # no error + with self.assertRaises(AssertionError) as cm: + server.decrypt_request_payload(cke, encrypted, bad_hmac) # error + + # Ensure call_with_payload() throws before it calls the handler fn + def _func(client_key, payload, aes_pin_data_key, replay_counter): + self.fail('should-never-get-here') + + with self.assertRaises(AssertionError) as cm: + server.call_with_payload(cke, encrypted, bad_hmac, _func) + + def test_bad_response_hmac_throws(self): + # A new server and client + cke, client = self.new_client_handshake() + server = PINServerECDHv2(client.replay_counter, cke) + + # Encrypt message + client_request = 'bad-hmac-response-request'.encode() + encrypted, hmac = client.encrypt_request_payload(client_request) + + def _func(client_key, payload, pin_data_aes_key, replay_counter): + self.assertEqual(client_key, cke) + self.assertEqual(payload, client_request) + self.assertEqual(replay_counter, client.replay_counter) + return 'bad-hmac-response'.encode() + + encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + + # Break hmac + bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) + self.assertNotEqual(hmac, bad_hmac) + + client.decrypt_response_payload(encrypted, hmac) # No error + with self.assertRaises(AssertionError) as cm: + client.decrypt_response_payload(encrypted, bad_hmac) # error + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_pindb.py b/test/test_pindb.py index ad55cf0..360701d 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -4,7 +4,7 @@ from hmac import compare_digest from ..pindb import PINDb -from ..lib import encrypt, decrypt, E_ECDH +from ..lib import E_ECDH from wallycore import sha256, ec_sig_from_bytes, \ AES_KEY_LEN_256, EC_FLAG_ECDSA, EC_FLAG_RECOVERABLE @@ -23,24 +23,27 @@ def new_entropy(): return os.urandom(32) @staticmethod - def make_payload(signing_key, cke, secret_in, entropy_in): - # Build the expected payload + def make_payload(signing_key, cke, secret_in, entropy_in, v2_replay_counter=None): + # Build the expected payload - if the v2_replay_counter is passed, assume protocol v2 + # and include that counter in the data being signed. Otherwise assume v1 and ignore. + counter = v2_replay_counter if v2_replay_counter else b'' sig = ec_sig_from_bytes(signing_key, - sha256(cke + secret_in + entropy_in), + sha256(cke + counter + secret_in + entropy_in), EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) + return secret_in + entropy_in + sig @classmethod def new_keys(cls): # USE ECDH class just because it's convenient way to make key pairs - sig_priv, sig_pub = E_ECDH.generate_ec_key_pair() + privkey, pubkey = E_ECDH.generate_ec_key_pair() _, cke = E_ECDH.generate_ec_key_pair() # add the pin_pubkey_hash to the set - pin_pubkey_hash = bytes(sha256(sig_pub)) + pin_pubkey_hash = bytes(sha256(pubkey)) cls.pinfiles.add(pin_pubkey_hash) - return sig_priv, sig_pub, cke, pin_pubkey_hash + return privkey, pubkey, cke, pin_pubkey_hash @classmethod def setUpClass(cls): @@ -55,200 +58,189 @@ def tearDownClass(cls): if PINDb.storage.exists(f): PINDb.storage.remove(f) - def _test_extract_fields_impl(self): + def _test_extract_fields_impl(self, v2_replay_counter): # Reinitialise keys and secret and entropy - sig_priv, sig_pub, cke, _ = self.new_keys() + privkey, pubkey, cke, _ = self.new_keys() secret_in, entropy_in = self.new_pin_secret(), self.new_entropy() # Build the expected payload - payload = self.make_payload(sig_priv, cke, secret_in, entropy_in) + payload = self.make_payload(privkey, cke, secret_in, entropy_in, v2_replay_counter) # Check pindb function can extract the components from the payload - secret_out, entropy_out, pubkey = PINDb._extract_fields(cke, payload) + secret_out, entropy_out, pubkey_out = PINDb._extract_fields(cke, payload, v2_replay_counter) self.assertEqual(secret_out, secret_in) self.assertEqual(entropy_out, entropy_in) # Check the public key is correctly recovered from the signature - self.assertEqual(pubkey, sig_pub) + self.assertEqual(pubkey_out, pubkey) def test_extract_fields(self): - for i in range(5): - with self.subTest(i=i): - self._test_extract_fields_impl() + for v2_replay_counter in [None, os.urandom(4), os.urandom(4)]: + with self.subTest(protocol='v2' if v2_replay_counter else 'v1'): + self._test_extract_fields_impl(v2_replay_counter) - def test_mismatching_cke_and_sig(self): + def _test_mismatching_sig_impl(self, v2_replay_counter): # Get two sets of keys and a new secret privX, pubX, ckeX, _ = self.new_keys() privY, pubY, ckeY, _ = self.new_keys() secret_in, entropy_in = self.new_pin_secret(), self.new_entropy() - # Build the expected payload with the wrong cke value - payload = self.make_payload(privX, ckeY, secret_in, entropy_in) + # Build the expected payload + payload = self.make_payload(privX, ckeX, secret_in, entropy_in, v2_replay_counter) # Call the pindb function to extract the components from the payload - # but use the 'expected' cke - the sig should not match either pubkey. - secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload) + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, v2_replay_counter) + self.assertEqual(secret_out, secret_in) + self.assertEqual(entropy_out, entropy_in) + self.assertEqual(pubkey, pubX) + + # Call the pindb function to extract the components from the payload + # but use a mismatched cke - the sig should not yield either pubkey. + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeY, payload, v2_replay_counter) self.assertEqual(secret_out, secret_in) self.assertEqual(entropy_out, entropy_in) self.assertNotEqual(pubkey, pubX) self.assertNotEqual(pubkey, pubY) - def test_save_and_load_pin_fields(self): - # Reinitialise keys and secret - _, _, _, pinfile = self.new_keys() - pin_secret, key_in = self.new_pin_secret(), self.new_entropy() - hps_in = sha256(pin_secret) - count_in = 5 + # Call the pindb function again with the correct cke, but pass a bad replay counter + for bad_counter in [os.urandom(4), None if v2_replay_counter else os.urandom(4)]: + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, bad_counter) + self.assertEqual(secret_out, secret_in) + self.assertEqual(entropy_out, entropy_in) + self.assertNotEqual(pubkey, pubX) + self.assertNotEqual(pubkey, pubY) + def test_mismatching_sig(self): + for v2_replay_counter in [None, os.urandom(4), os.urandom(4)]: + with self.subTest(protocol='v2' if v2_replay_counter else 'v1'): + self._test_mismatching_sig_impl(v2_replay_counter) + + def test_load_nonexistent_file_throws(self): # Trying to read non-existent file throws (and does not create file) + _, _, _, pinfile = self.new_keys() self.assertFalse(PINDb.storage.exists(pinfile)) with self.assertRaises((FileNotFoundError, Exception)) as _: PINDb._load_pin_fields(pinfile, None, None) self.assertFalse(PINDb.storage.exists(pinfile)) - user_id = os.urandom(32) - aes_pin = bytes(os.urandom(32)) - - # Save some data - check new file created - new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, count_in) - self.assertTrue(PINDb.storage.exists(pinfile)) - - # Atm the 'new key' returned should be the one passed in - self.assertEqual(new_key, key_in) - - # Read file back in - ensure fields the same - hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, - user_id, - aes_pin) - self.assertEqual(hps_out, hps_in) - self.assertEqual(key_out, key_in) - self.assertEqual(count_out, count_in) - self.assertEqual(replay_local, None) - - # Ensure we can set zero the count of an existing file - count_in = 0 - new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, count_in) - hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, - user_id, - aes_pin) - self.assertEqual(hps_out, hps_in) - self.assertEqual(key_out, key_in) - self.assertEqual(count_out, count_in) - self.assertEqual(replay_local, None) - - # Ensure we can't decrypt the pin with the wrong aes_key, hmac won't match - bad_aes = os.urandom(32) - with self.assertRaises(AssertionError) as _: - PINDb._load_pin_fields(pinfile, user_id, bad_aes) - - def test_save_and_load_pin_fieldsv2(self): + def _test_save_and_load_pin_fields_impl(self, use_v2_protocol): # Reinitialise keys and secret _, _, _, pinfile = self.new_keys() pin_secret, key_in = self.new_pin_secret(), self.new_entropy() hps_in = sha256(pin_secret) count_in = 5 - # Trying to read non-existent file throws (and does not create file) - self.assertFalse(PINDb.storage.exists(pinfile)) - with self.assertRaises((FileNotFoundError, Exception)) as _: - PINDb._load_pin_fields(pinfile, None, None) - self.assertFalse(PINDb.storage.exists(pinfile)) - user_id = os.urandom(32) aes_pin = bytes(os.urandom(32)) - replay_counter = 0 - replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) - # Save some data - check new file created + v2_prior_counter = None + v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, - count_in, replay_counter) + count_in, v2_replay_counter) self.assertTrue(PINDb.storage.exists(pinfile)) # Atm the 'new key' returned should be the one passed in self.assertEqual(new_key, key_in) - replay_counter = 1 - replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) # Read file back in - ensure fields the same + if use_v2_protocol: + v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) + v2_replay_counter = b'\x01\x00\x00\x00' hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, user_id, aes_pin, - replay_counter) + v2_replay_counter) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) - self.assertEqual(replay_local, 0) - - replay_counter = 5 - replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + self.assertEqual(replay_local, v2_prior_counter) # Ensure we can set zero the count of an existing file count_in = 0 + if use_v2_protocol: + v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) + v2_replay_counter = b'\x05\x00\x00\x00' new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, - count_in, replay_counter) - replay_counter = 10000 - replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) + count_in, v2_replay_counter) + + if use_v2_protocol: + v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) + v2_replay_counter = b'\xc2\x00\x00\x00' hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, user_id, - aes_pin, replay_counter) + aes_pin, + v2_replay_counter) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) - self.assertEqual(replay_local, 5) + self.assertEqual(replay_local, v2_prior_counter) # Ensure we can't decrypt the pin with the wrong aes_key, hmac won't match bad_aes = os.urandom(32) with self.assertRaises(AssertionError) as _: PINDb._load_pin_fields(pinfile, user_id, bad_aes) - def _test_set_and_get_pin_impl(self): + def test_save_and_load_pin_fields(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_save_and_load_pin_fields_impl(use_v2_protocol) + + def _test_set_and_get_pin_impl(self, v2set, v2get): # Reinitialise keys and secret - sig_priv, _, cke, pinfile = self.new_keys() - secret = self.new_pin_secret() + privkey, _, cke, pinfile = self.new_keys() + pin_secret = self.new_pin_secret() + pin_aes_key = bytes(os.urandom(32)) - # Make the expected payload - payload = self.make_payload(sig_priv, cke, secret, self.new_entropy()) - # Set the pin = check this creates the file + # Set the pin - check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' if v2set else None + payload = self.make_payload(privkey, cke, pin_secret, self.new_entropy(), v2_replay_counter) self.assertFalse(PINDb.storage.exists(pinfile)) - pin_aes_key = bytes(os.urandom(32)) - aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) self.assertTrue(PINDb.storage.exists(pinfile)) # Get the key with the pin - new payload has new entropy (same pin) - payload = self.make_payload(sig_priv, cke, secret, self.new_entropy()) - aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key) + v2_replay_counter = os.urandom(4) if v2get else None + payload = self.make_payload(privkey, cke, pin_secret, self.new_entropy(), v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) def test_set_and_get_pin(self): - for i in range(5): - with self.subTest(i=i): - self._test_set_and_get_pin_impl() + for v2set, v2get in [(False, False), (False, True), (True, False), (True, True)]: + with self.subTest(set='v2' if v2set else 'v1', get='v2' if v2get else 'v1'): + for i in range(3): + self._test_set_and_get_pin_impl(v2set, v2get) - def test_bad_guesses_clears_pin(self): + def _test_bad_guesses_clears_pin_impl(self, v2set, v2get): # Reinitialise keys and secret - sig_priv, _, cke, pinfile = self.new_keys() + privkey, _, cke, pinfile = self.new_keys() pin_secret, entropy = self.new_pin_secret(), self.new_entropy() # Build the expected payload - good_payload = self.make_payload(sig_priv, cke, pin_secret, entropy) - # Set and verify the the pin = check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' if v2set else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + + # Set and verify the the pin - check this creates the file self.assertFalse(PINDb.storage.exists(pinfile)) pin_aes_key = bytes(os.urandom(32)) - aeskey_s = PINDb.set_pin(cke, good_payload, pin_aes_key) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - aeskey_g = PINDb.get_aes_key(cke, good_payload, pin_aes_key) + + v2_replay_counter = b'\x01\x00\x00\x00' if v2set else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) # Bad guesses at PIN for attempt in range(3): # Attempt to get with bad pin (using same entropy) + v2_replay_counter = (10 + attempt).to_bytes(4, 'little') if v2get else None bad_secret = os.urandom(32) - bad_payload = self.make_payload(sig_priv, cke, bad_secret, entropy) - guesskey = PINDb.get_aes_key(cke, bad_payload, pin_aes_key) + bad_payload = self.make_payload(privkey, cke, bad_secret, entropy, v2_replay_counter) + guesskey = PINDb.get_aes_key(cke, bad_payload, pin_aes_key, v2_replay_counter) # Wrong pin should return junk aes-key self.assertEqual(len(aeskey_s), len(guesskey)) @@ -258,60 +250,200 @@ def test_bad_guesses_clears_pin(self): self.assertFalse(PINDb.storage.exists(pinfile)) # Now even the correct pin will fail... - aeskey = PINDb.get_aes_key(cke, good_payload, pin_aes_key) + v2_replay_counter = b'\x0c\x20\x00\x00' if v2get else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey), len(aeskey_s)) self.assertFalse(compare_digest(aeskey, aeskey_s)) self.assertFalse(PINDb.storage.exists(pinfile)) - def test_bad_server_key_or_user_pub_key_breaks(self): + def test_bad_guesses_clears_pin(self): + for v2set, v2get in [(False, False), (False, True), (True, False), (True, True)]: + with self.subTest(set='v2' if v2set else 'v1', get='v2' if v2get else 'v1'): + self._test_bad_guesses_clears_pin_impl(v2set, v2get) + + def _test_bad_server_key_breaks_impl(self, use_v2_protocol): # Reinitialise keys and secret - sig_priv, _, cke, pinfile = self.new_keys() + privkey, _, cke, pinfile = self.new_keys() pin_secret, entropy = self.new_pin_secret(), self.new_entropy() + pin_aes_key = bytes(os.urandom(32)) - # Build the expected payload - good_payload = self.make_payload(sig_priv, cke, pin_secret, entropy) - - # Set and verify the the pin = check this creates the file + # Set and verify the the pin - check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) self.assertFalse(PINDb.storage.exists(pinfile)) - pin_aes_key = bytes(os.urandom(32)) - aeskey_s = PINDb.set_pin(cke, good_payload, pin_aes_key) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - aeskey_g = PINDb.get_aes_key(cke, good_payload, pin_aes_key) - self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) + # Check we can get the key + v2_replay_counter = b'\x01\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + # Bad server key - for attempt in range(3): + for attempt in range(6): # Attempt to get with bad server key (using same entropy) bad_key = os.urandom(32) - guesskey = PINDb.get_aes_key(cke, good_payload, bad_key) + v2_replay_counter = (10 + attempt).to_bytes(4, 'little') if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + guesskey = PINDb.get_aes_key(cke, payload, bad_key, v2_replay_counter) - # Wrong pubkey should return junk aes-key + # Wrong key should return junk aes-key self.assertEqual(len(aeskey_s), len(guesskey)) self.assertFalse(compare_digest(aeskey_s, guesskey)) - # Bad pub key - for attempt in range(3): + # after many failed attempts server keeps the file + # as it doesn't know what file to check even + self.assertTrue(PINDb.storage.exists(pinfile)) + + # Now the correct pin will should still work if correct server key used + v2_replay_counter = b'\x00\xff\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey), len(aeskey_s)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + self.assertTrue(PINDb.storage.exists(pinfile)) + + def test_bad_server_key_pub_key_breaks(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_bad_server_key_breaks_impl(use_v2_protocol) + + def _test_bad_user_pubkey_breaks_impl(self, use_v2_protocol): + # Reinitialise keys and secret + privkey, _, cke, pinfile = self.new_keys() + pin_secret, entropy = self.new_pin_secret(), self.new_entropy() + pin_aes_key = bytes(os.urandom(32)) + + # Set and verify the the pin - check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + self.assertFalse(PINDb.storage.exists(pinfile)) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + self.assertTrue(PINDb.storage.exists(pinfile)) + + # Check we can get the key + v2_replay_counter = b'\x03\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + # Bad replay counter passed from client + for attempt in range(6): # Attempt to get with bad pub_key (using same entropy) bad_key = os.urandom(32) - bad_payload = self.make_payload(bad_key, cke, pin_secret, entropy) - guesskey = PINDb.get_aes_key(cke, bad_payload, pin_aes_key) + v2_replay_counter = (10 + attempt).to_bytes(4, 'little') if use_v2_protocol else None + bad_payload = self.make_payload(bad_key, cke, pin_secret, entropy, v2_replay_counter) + guesskey = PINDb.get_aes_key(cke, bad_payload, pin_aes_key, v2_replay_counter) # Wrong pubkey should return junk aes-key self.assertEqual(len(aeskey_s), len(guesskey)) self.assertFalse(compare_digest(aeskey_s, guesskey)) - # after six failed attempts server keeps the file + # after many failed attempts server keeps the file # as it doesn't know what file to check even self.assertTrue(PINDb.storage.exists(pinfile)) - # Now the correct pin will should still be correct... - aeskey = PINDb.get_aes_key(cke, good_payload, pin_aes_key) + # Now the correct pin will should still be correct if correct pubkey used + v2_replay_counter = b'\x00\xff\x00\x00' if use_v2_protocol else None + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey), len(aeskey_s)) self.assertTrue(compare_digest(aeskey, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) - def test_two_users_with_same_pin(self): + def test_bad_user_pub_key_breaks(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_bad_user_pubkey_breaks_impl(use_v2_protocol) + + def test_bad_v2_counter_breaks_get_pin(self): + # Reinitialise keys and secret + privkey, _, cke, pinfile = self.new_keys() + pin_secret, entropy = self.new_pin_secret(), self.new_entropy() + pin_aes_key = bytes(os.urandom(32)) + + # Set and verify the the pin - check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + self.assertFalse(PINDb.storage.exists(pinfile)) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + self.assertTrue(PINDb.storage.exists(pinfile)) + + # Check we can get the key with increasing counters, and same or + # decreasing counters give a 'bad pin' result + max_counter = 0 + for counter in [0, 3, 3, 6, 123, 45, 332, 155, 332, 330, 500, 200, 300, 400, 501, 500]: + v2_replay_counter = counter.to_bytes(4, 'little', signed=False) + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + + if counter > max_counter: + # Should get correct key + self.assertTrue(compare_digest(aeskey, aeskey_s)) + max_counter = counter + else: + # Should get incorrect key + self.assertFalse(compare_digest(aeskey, aeskey_s)) + + # Now the correct pin will should still be correct + assert max_counter == 501 + + v2_replay_counter = b'\x00\xff\xff\xff' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey), len(aeskey_s)) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + self.assertTrue(PINDb.storage.exists(pinfile)) + + def test_bad_v2_counter_breaks_set_pin(self): + # Reinitialise keys and secret + privkey, _, cke, pinfile = self.new_keys() + pin_secret, entropy = self.new_pin_secret(), self.new_entropy() + pin_aes_key = bytes(os.urandom(32)) + + # Set and verify the the pin - check this creates the file + v2_replay_counter = b'\x00\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + self.assertFalse(PINDb.storage.exists(pinfile)) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + self.assertTrue(PINDb.storage.exists(pinfile)) + + v2_replay_counter = b'\x05\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + # Set-pin fails if use a non-zero counter + v2_replay_counter = b'\x0f\x0f\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + with self.assertRaises(AssertionError) as cm: + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + + # Key still present and readable with lower counter as set failed + v2_replay_counter = b'\x06\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + # Set-pin must use a counter of 0 + v2_replay_counter = b'\x00\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + + # Key readable with new counter + v2_replay_counter = b'\x01\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + + def _test_two_users_with_same_pin_impl(self, v2X, v2Y): # Get two sets of keys and a new secret privX, pubX, ckeX, _ = self.new_keys() privY, pubY, ckeY, _ = self.new_keys() @@ -319,37 +451,56 @@ def test_two_users_with_same_pin(self): # Build the expected payloads # X and Y use the same values... bizarre but should be fine - payloadX = self.make_payload(privX, ckeX, secret_in, entropy_in) - payloadY = self.make_payload(privY, ckeY, secret_in, entropy_in) + v2_replay_counterX = b'\x00\x00\x00\x00' if v2X else None + v2_replay_counterY = b'\x00\x00\x00\x00' if v2Y else None + payloadX = self.make_payload(privX, ckeX, secret_in, entropy_in, v2_replay_counterX) + payloadY = self.make_payload(privY, ckeY, secret_in, entropy_in, v2_replay_counterY) pin_aes_key = bytes(os.urandom(32)) - aeskeyX_s = PINDb.set_pin(ckeX, payloadX, pin_aes_key) - aeskeyY_s = PINDb.set_pin(ckeY, payloadY, pin_aes_key) + aeskeyX_s = PINDb.set_pin(ckeX, payloadX, pin_aes_key, v2_replay_counterX) + aeskeyY_s = PINDb.set_pin(ckeY, payloadY, pin_aes_key, v2_replay_counterY) # Keys should be different self.assertEqual(len(aeskeyX_s), len(aeskeyY_s)) self.assertFalse(compare_digest(aeskeyX_s, aeskeyY_s)) # Each can get their own key - aeskeyX_g = PINDb.get_aes_key(ckeX, payloadX, pin_aes_key) - aeskeyY_g = PINDb.get_aes_key(ckeY, payloadY, pin_aes_key) + v2_replay_counterX = os.urandom(4) if v2X else None + v2_replay_counterY = os.urandom(4) if v2Y else None + payloadX = self.make_payload(privX, ckeX, secret_in, entropy_in, v2_replay_counterX) + payloadY = self.make_payload(privY, ckeY, secret_in, entropy_in, v2_replay_counterY) + aeskeyX_g = PINDb.get_aes_key(ckeX, payloadX, pin_aes_key, v2_replay_counterX) + aeskeyY_g = PINDb.get_aes_key(ckeY, payloadY, pin_aes_key, v2_replay_counterY) self.assertFalse(compare_digest(aeskeyX_g, aeskeyY_g)) self.assertTrue(compare_digest(aeskeyX_g, aeskeyX_s)) self.assertTrue(compare_digest(aeskeyY_g, aeskeyY_s)) - def test_rejects_without_client_entropy(self): + def test_two_users_with_same_pin(self): + for v2X, v2Y in [(False, False), (False, True), (True, True)]: + with self.subTest(X='v2' if v2X else 'v1', Y='v2' if v2Y else 'v1'): + self._test_two_users_with_same_pin_impl(v2X, v2Y) + + def _test_rejects_without_client_entropy_impl(self, use_v2_protocol): # Reinitialise keys and secret and entropy sig_priv, _, cke, pinfile = self.new_keys() secret, entropy = self.new_pin_secret(), bytearray() # Build the expected payload - payload = self.make_payload(sig_priv, cke, secret, entropy) + v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) pin_aes_key = bytes(os.urandom(32)) with self.assertRaises(AssertionError) as cm: - PINDb.set_pin(cke, payload, pin_aes_key) + PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + v2_replay_counter = b'\x01\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) with self.assertRaises(AssertionError) as cm: - PINDb.get_aes_key(cke, payload, pin_aes_key) + PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + + def test_rejects_without_client_entropy(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_rejects_without_client_entropy_impl(use_v2_protocol) if __name__ == '__main__': diff --git a/test/test_pinserver.py b/test/test_pinserver.py index 02aaab4..c45e97d 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -7,7 +7,7 @@ from hmac import compare_digest import requests -from ..client import PINClientECDH, PINClientECDHv2 +from ..client import PINClientECDH, PINClientECDH, PINClientECDHv2 from ..server import PINServerECDH from ..pindb import PINDb @@ -18,6 +18,8 @@ class PINServerTest(unittest.TestCase): + # Protocol v2 client replay coutner + v2_client_counter = 13 # arbitrary initial value @staticmethod def new_pin_secret(): @@ -94,44 +96,52 @@ def tearDownClass(cls): # Helpers # Start the client/server key-exchange handshake - def start_handshake(self, client): + def start_handshake_v1(self, client): + assert isinstance(client, PINClientECDH) handshake = self.post('start_handshake') client.handshake(bytes.fromhex(handshake['ske']), bytes.fromhex(handshake['sig'])) return client # Make a new ephemeral client and initialise with server handshake - def new_client_handshake(self): + def new_client_v1(self): client = PINClientECDH(self.static_server_public_key) - return self.start_handshake(client) + return self.start_handshake_v1(client) - # Make a new ephemeral client and initialise with tweaked server key - def new_client_handshakev2(self, replay_counter): - client = PINClientECDHv2(self.static_server_public_key, replay_counter) - return client + def new_client_v2(self, reset_replay_counter): + if reset_replay_counter: + client_counter = b'\x00\x00\x00\x00' + else: + self.v2_client_counter += 1 # increment - may be unnecessary but ensures monotonic + client_counter = self.v2_client_counter.to_bytes(4, byteorder='little', signed=False) + + return PINClientECDHv2(self.static_server_public_key, client_counter) # Make the server call to get/set the pin - returns the decrypted response - def server_call(self, private_key, client, endpoint, pin_secret, entropy, replay_counter=None): + # NOTE: explicit hmac fields + def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy, + fn_perturb_request=None): + assert isinstance(client, PINClientECDH) + # Make and encrypt the payload (ie. pin secret) ske, cke = client.get_key_exchange() - cke_sha = cke - if replay_counter is not None: - assert len(replay_counter) == 4 - cke_sha = cke + replay_counter sig = ec_sig_from_bytes(private_key, - sha256(cke_sha + pin_secret + entropy), + sha256(cke + pin_secret + entropy), EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) payload = pin_secret + entropy + sig encrypted, hmac = client.encrypt_request_payload(payload) # Make call and parse response + # Includes 'ske' and 'hmac', but no 'replay_counter' urldata = {'ske': ske.hex(), 'cke': cke.hex(), 'encrypted_data': encrypted.hex(), 'hmac_encrypted_data': hmac.hex()} - if replay_counter: - urldata['replay_counter'] = replay_counter.hex() - del urldata['ske'] + + # Caller can mangle data before it is sent + if fn_perturb_request: + urldata = fn_perturb_request(urldata) + response = self.post(endpoint, urldata) encrypted = bytes.fromhex(response['encrypted_key']) hmac = bytes.fromhex(response['hmac']) @@ -139,29 +149,57 @@ def server_call(self, private_key, client, endpoint, pin_secret, entropy, replay # Return decrypted payload return client.decrypt_response_payload(encrypted, hmac) - def get_pin(self, private_key, pin_secret, entropy): - # Create new ephemeral client, initiate handshake, and make call - client = self.new_client_handshake() - return self.server_call( - private_key, client, 'get_pin', pin_secret, entropy) - - def set_pin(self, private_key, pin_secret, entropy): - # Create new ephemeral client, initiate handshake, and make call - client = self.new_client_handshake() - return self.server_call( - private_key, client, 'set_pin', pin_secret, entropy) - - def get_pinv2(self, private_key, pin_secret, entropy, replay_counter): - # Create new ephemeral client, initiate handshake, and make call - client = self.new_client_handshakev2(replay_counter) - return self.server_call( - private_key, client, 'get_pin', pin_secret, entropy, replay_counter) - - def set_pinv2(self, private_key, pin_secret, entropy, replay_counter): - # Create new ephemeral client, initiate handshake, and make call - client = self.new_client_handshakev2(replay_counter) - return self.server_call( - private_key, client, 'set_pin', pin_secret, entropy, replay_counter) + # Make the server call to get/set the pin - returns the decrypted response + # NOTE: signature covers replay counter + def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, + fn_perturb_request=None): + assert isinstance(client, PINClientECDHv2) + + # Make and encrypt the payload (ie. pin secret) + ske, cke = client.get_key_exchange() + sig = ec_sig_from_bytes(private_key, + sha256(cke + client.replay_counter + pin_secret + entropy), + EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) + payload = pin_secret + entropy + sig + + encrypted, hmac = client.encrypt_request_payload(payload) + + # Make call and parse response + # Includes 'replay_counter' but not 'ske' + urldata = {'cke': cke.hex(), + 'replay_counter': client.replay_counter.hex(), + 'encrypted_data': encrypted.hex(), + 'hmac_encrypted_data': hmac.hex()} + + # Caller can mangle data before it is sent + if fn_perturb_request: + urldata = fn_perturb_request(urldata) + + response = self.post(endpoint, urldata) + encrypted = bytes.fromhex(response['encrypted_key']) + hmac = bytes.fromhex(response['hmac']) + + # Return decrypted payload + return client.decrypt_response_payload(encrypted, hmac) + + def make_server_call(self, private_key, endpoint, pin_secret, entropy, use_v2_protocol, + fn_perturb_request=None): + if use_v2_protocol: + # NOTE: replay_counter must be reset to 0x00 with 'set_pin' requests + reset_counter = endpoint == 'set_pin' + client = self.new_client_v2(reset_counter) + server_call = self.server_call_v2 + else: + client = self.new_client_v1() + server_call = self.server_call_v1 + + return server_call(private_key, client, endpoint, pin_secret, entropy, fn_perturb_request) + + def get_pin(self, private_key, pin_secret, entropy, use_v2_protocol): + return self.make_server_call(private_key, 'get_pin', pin_secret, entropy, use_v2_protocol) + + def set_pin(self, private_key, pin_secret, entropy, use_v2_protocol): + return self.make_server_call(private_key, 'set_pin', pin_secret, entropy, use_v2_protocol) # Tests def test_get_index(self): @@ -183,7 +221,7 @@ def test_get_root_empty(self): f = requests.post(self.pinserver_url) self.assertEqual(f.status_code, 405) - def test_set_and_get_pin(self): + def _test_set_and_get_pin_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() @@ -191,15 +229,48 @@ def test_set_and_get_pin(self): pin_secret = self.new_pin_secret() # Make a new client and set the pin secret to get a new aes key - aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy()) + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) # Get key with a new client, with the correct pin secret (new entropy) - for attempt in range(5): - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + for attempt in range(3): + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - def test_bad_guesses_clears_pin(self): + def test_set_and_get_pin(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_set_and_get_pin_impl(use_v2_protocol) + + def _test_protocol_upgrade_downgrade_impl(self, v2set, v2get): + # Make ourselves a static key pair for this logical client + priv_key, _, _ = self.new_static_client_keys() + + # The 'correct' client pin + pin_secret = self.new_pin_secret() + + # Make a new client and set the pin secret to get a new aes key + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), v2set) + self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + + # Now client changes protocol version - should all work seamlessly + # upgrade, and downgrade ... + # Get key with a new client, with the correct pin secret (new entropy) + aeskey = self.get_pin(priv_key, pin_secret, self.new_entropy(), v2get) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + aeskey = self.get_pin(priv_key, pin_secret, self.new_entropy(), not v2get) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + aeskey = self.get_pin(priv_key, pin_secret, self.new_entropy(), v2get) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + def test_protocol_upgrade_downgrade(self): + for v2set, v2get in [(False, False), (False, True), (True, False), (True, True)]: + with self.subTest(set='v2' if v2set else 'v1', get='v2' if v2get else 'v1'): + self._test_protocol_upgrade_downgrade_impl(v2set, v2get) + + def _test_bad_guesses_clears_pin_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client priv_key, _, pinfile = self.new_static_client_keys() @@ -208,9 +279,9 @@ def test_bad_guesses_clears_pin(self): # Set and verify the pin - ensure underlying file created self.assertFalse(PINDb.storage.exists(pinfile)) - aeskey_s = self.set_pin(priv_key, pin_secret, entropy) + aeskey_s = self.set_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - aeskey_g = self.get_pin(priv_key, pin_secret, entropy) + aeskey_g = self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) @@ -218,7 +289,7 @@ def test_bad_guesses_clears_pin(self): for attempt in range(3): # Attempt to get with bad pin bad_secret = os.urandom(32) - guesskey = self.get_pin(priv_key, bad_secret, entropy) + guesskey = self.get_pin(priv_key, bad_secret, entropy, use_v2_protocol) # Wrong pin should return junk aes-key self.assertEqual(len(aeskey_s), len(guesskey)) @@ -228,12 +299,17 @@ def test_bad_guesses_clears_pin(self): self.assertFalse(PINDb.storage.exists(pinfile)) # Now even the correct pin will fail... - aeskey = self.get_pin(priv_key, bad_secret, entropy) + aeskey = self.get_pin(priv_key, bad_secret, entropy, use_v2_protocol) self.assertEqual(len(aeskey), len(aeskey_s)) self.assertFalse(compare_digest(aeskey, aeskey_s)) self.assertFalse(PINDb.storage.exists(pinfile)) - def test_bad_pubkey_breaks(self): + def test_bad_guesses_clears_pin(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_bad_guesses_clears_pin_impl(use_v2_protocol) + + def _test_bad_pubkey_breaks_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client priv_key, _, pinfile = self.new_static_client_keys() @@ -242,9 +318,9 @@ def test_bad_pubkey_breaks(self): # Set and verify the pin - ensure underlying file created self.assertFalse(PINDb.storage.exists(pinfile)) - aeskey_s = self.set_pin(priv_key, pin_secret, entropy) + aeskey_s = self.set_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - aeskey_g = self.get_pin(priv_key, pin_secret, entropy) + aeskey_g = self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) @@ -252,7 +328,7 @@ def test_bad_pubkey_breaks(self): for attempt in range(3): # Attempt to get with bad pub_key bad_key = os.urandom(32) - guesskey = self.get_pin(bad_key, pin_secret, entropy) + guesskey = self.get_pin(bad_key, pin_secret, entropy, use_v2_protocol) # Wrong pin should return junk aes-key self.assertEqual(len(aeskey_s), len(guesskey)) @@ -262,52 +338,117 @@ def test_bad_pubkey_breaks(self): self.assertTrue(PINDb.storage.exists(pinfile)) # The correct pin will continue to work - aeskey = self.get_pin(priv_key, pin_secret, entropy) + aeskey = self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertEqual(len(aeskey), len(aeskey_s)) self.assertTrue(compare_digest(aeskey, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) - def test_two_users_with_same_pin(self): + def test_bad_pubkey_breaks(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_bad_pubkey_breaks_impl(use_v2_protocol) + + def _test_two_users_with_same_pin_impl(self, v2X, v2Y): # Two users - clientA_private_key, _, _ = self.new_static_client_keys() - clientB_private_key, _, _ = self.new_static_client_keys() + clientX_private_key, _, _ = self.new_static_client_keys() + clientY_private_key, _, _ = self.new_static_client_keys() # pin plus its salt/iv/entropy pin_secret, entropy = self.new_pin_secret(), self.new_entropy() - # A and B use the same values... bizarre but should be fine - aeskey_sA = self.set_pin(clientA_private_key, pin_secret, entropy) - aeskey_sB = self.set_pin(clientB_private_key, pin_secret, entropy) - self.assertFalse(compare_digest(aeskey_sA, aeskey_sB)) + # X and Y use the same values... bizarre but should be fine + aeskey_sX = self.set_pin(clientX_private_key, pin_secret, entropy, v2X) + aeskey_sY = self.set_pin(clientY_private_key, pin_secret, entropy, v2Y) + self.assertFalse(compare_digest(aeskey_sX, aeskey_sY)) - aeskey_gA = self.get_pin(clientA_private_key, pin_secret, entropy) - self.assertTrue(compare_digest(aeskey_gA, aeskey_sA)) + aeskey_gX = self.get_pin(clientX_private_key, pin_secret, entropy, v2X) + self.assertTrue(compare_digest(aeskey_gX, aeskey_sX)) - aeskey_gB = self.get_pin(clientB_private_key, pin_secret, entropy) - self.assertTrue(compare_digest(aeskey_gB, aeskey_sB)) + aeskey_gY = self.get_pin(clientY_private_key, pin_secret, entropy, v2Y) + self.assertTrue(compare_digest(aeskey_gY, aeskey_sY)) - self.assertFalse(compare_digest(aeskey_gA, aeskey_gB)) + self.assertFalse(compare_digest(aeskey_gX, aeskey_gY)) - def test_rejects_on_bad_json(self): - # Create new ephemeral client, initiate handshake, and make call - client = self.new_client_handshake() - ske, cke = client.get_key_exchange() + def test_two_users_with_same_pin(self): + for v2X, v2Y in [(False, False), (False, True), (True, True)]: + with self.subTest(X='v2' if v2X else 'v1', Y='v2' if v2Y else 'v1'): + self._test_two_users_with_same_pin_impl(v2X, v2Y) - # Make call with bad/missing parameters - urldata = {'ske': ske.hex(), - 'cke': cke.hex(), - # 'encrypted_data' missing - 'hmac_encrypted_data': 'abc123'} + def test_rejects_bad_payload_not_json(self): + # Make call with not-even-json + urldata = 'This is not even json' with self.assertRaises(ValueError) as cm: - self.post('get_pin', urldata) + self.post('set_pin', urldata) + self.assertEqual('500', str(cm.exception.args[0])) - # Make call with not-even-json - urldata = 'This is not even json' with self.assertRaises(ValueError) as cm: self.post('get_pin', urldata) + self.assertEqual('500', str(cm.exception.args[0])) - def test_rejects_without_client_entropy(self): + def _test_rejects_on_bad_json_impl(self, use_v2_protocol): + # Make ourselves a static key pair for this logical client + priv_key, _, _ = self.new_static_client_keys() + pin_secret, entropy = self.new_pin_secret(), self.new_entropy() + + # Various ways to mangle the json request payload + bad_ske, bad_cke = self.new_client_v1().get_key_exchange() + + def _short(field): + def _fn(d): + d[field] = d[field][:-1] + return d + return _fn + + def _long(field): + def _fn(d): + d[field] = d[field] + 'ff' + return d + return _fn + + def _random(field): + def _fn(d): + d[field] = os.urandom(len(bytes.fromhex(d[field]))).hex() + return d + return _fn + + def _set(field, value): + def _fn(d): + d[field] = value + return d + return _fn + + def _remove(field): + def _fn(d): + del d[field] + return d + return _fn + + request_manglers = [_set('cke', bad_cke.hex())] + request_manglers.extend(f('cke') for f in [_short, _long, _remove]) + request_manglers.extend(f('encrypted_data') for f in [_random, _short, _long, _remove]) + request_manglers.extend(f('hmac_encrypted_data') for f in [_random, _short, _long, _remove]) + + if use_v2_protocol: + request_manglers.extend(f('replay_counter') for f in [_random, _short, _long, _remove]) + else: + request_manglers.append(_set('ske', bad_ske.hex())) + request_manglers.extend(f('ske') for f in [_short, _long, _remove]) + + for mangler in request_manglers: + for endpoint in ['get_pin', 'set_pin']: + with self.assertRaises(ValueError) as cm: + self.make_server_call(priv_key, endpoint, pin_secret, self.new_entropy(), + use_v2_protocol, mangler) + + self.assertEqual('500', str(cm.exception.args[0])) + + def test_rejects_on_bad_json(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_rejects_on_bad_json_impl(use_v2_protocol) + + def _test_rejects_without_client_entropy_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() @@ -316,16 +457,21 @@ def test_rejects_without_client_entropy(self): # Make a new client and set the pin secret to get a new aes key with self.assertRaises(ValueError) as cm: - self.set_pin(priv_key, pin_secret, entropy) + self.set_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertEqual('500', str(cm.exception.args[0])) with self.assertRaises(ValueError) as cm: - self.get_pin(priv_key, pin_secret, entropy) + self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) self.assertEqual('500', str(cm.exception.args[0])) - def test_delayed_interaction(self): + def test_rejects_without_client_entropy(self): + for use_v2_protocol in [False, True]: + with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): + self._test_rejects_without_client_entropy_impl(use_v2_protocol) + + def test_delayed_interaction_v1(self): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() @@ -333,21 +479,21 @@ def test_delayed_interaction(self): pin_secret = self.new_pin_secret() # Set and verify the pin - aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy()) - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # If we delay in the server interaction it will fail with a 500 error - client = self.new_client_handshake() + client = self.new_client_v1() time.sleep(SESSION_LIFETIME + 1) # Sufficiently long delay with self.assertRaises(ValueError) as cm: - self.server_call(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + self.server_call_v1(priv_key, client, 'get_pin', pin_secret, + self.new_entropy()) self.assertEqual('500', str(cm.exception.args[0])) - def test_cannot_reuse_client_session(self): + def test_cannot_reuse_client_session_v1(self): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() @@ -355,169 +501,88 @@ def test_cannot_reuse_client_session(self): pin_secret = self.new_pin_secret() # Set pin - aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy()) + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) # Get/verify pin with a new client - client = self.new_client_handshake() - aeskey_g = self.server_call(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + client = self.new_client_v1() + aeskey_g = self.server_call_v1(priv_key, client, 'get_pin', pin_secret, + self.new_entropy()) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # Trying to reuse the session should fail with a 500 error + # because the server has closed that ephemeral encryption session with self.assertRaises(ValueError) as cm: - self.server_call(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + self.server_call_v1(priv_key, client, 'get_pin', pin_secret, + self.new_entropy()) self.assertEqual('500', str(cm.exception.args[0])) # Not great, but we could reuse the client if we re-initiate handshake - # (But that would use same cke which is not ideal.) - self.start_handshake(client) - aeskey = self.server_call(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + # (But that would use same cke which is not ideal/recommended.) + self.start_handshake_v1(client) + aeskey = self.server_call_v1(priv_key, client, 'get_pin', pin_secret, + self.new_entropy()) self.assertTrue(compare_digest(aeskey, aeskey_s)) - def test_v2_happypath_with_simulated_replay(self): + def test_cannot_reuse_client_session_v2(self): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() - # The 'correct' client pin + # The 'correct' client pin plus its salt/iv/entropy pin_secret = self.new_pin_secret() - # assert you can't set pin with a replay_counter different than 0 - with self.assertRaises(ValueError) as cm: - replay_counter = 1 - self.set_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, - byteorder='little', - signed=False)) - - # set the pin secret to get a new aes key - replay_counter = 0 - aeskey_s = self.set_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - - # retrieve the key again with our correct pin secret - replay_counter = 1 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - - # Now let's compare - self.assertTrue(compare_digest(aeskey, aeskey_s)) + # Set pin + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) - for i in range(5): - # Simulate a reply attempt failing N times, it doesn't affect pin - # attempts / dos - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, - byteorder='little', - signed=False)) - self.assertFalse(compare_digest(aeskey, aeskey_s)) - - # retrieve the key again using v1 - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + # Get/verify pin with a new client + client = self.new_client_v2(False) + aeskey_g = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, self.new_entropy()) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Incrementing the counter monotonically works again - replay_counter = 2 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - self.assertTrue(compare_digest(aeskey, aeskey_s)) - - # Incrementing the counter monotonically works even in case of network - # errors where some request is missed - replay_counter = 4 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - self.assertTrue(compare_digest(aeskey, aeskey_s)) - - bad_secret = self.new_pin_secret() - for i in range(3): - # exaust pin attmempts with good replay_counter - replay_counter = i + 5 - replay_counter = replay_counter.to_bytes(4, byteorder='little', signed=False) - aeskey = self.get_pinv2(priv_key, bad_secret, self.new_entropy(), replay_counter) - self.assertFalse(compare_digest(aeskey, aeskey_s)) - - # retrieve the key again using v1 should fail - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) - self.assertFalse(compare_digest(aeskey_g, aeskey_s)) - - # Incrementing the counter monotonically also fails - replay_counter = 8 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) + # Trying to reuse the session should appear to work, but will return a junk key + # (ie. same as bad pin) because the server-side 'replay counter' has moved on + aeskey = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, self.new_entropy()) self.assertFalse(compare_digest(aeskey, aeskey_s)) - def test_v2_happypath_with_simulated_replay_upgrade(self): + # Set-pin should fail more overtly + with self.assertRaises(ValueError) as cm: + aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), + self.new_entropy()) + self.assertEqual('500', str(cm.exception.args[0])) + + def test_set_pin_counter_v2(self): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() - # The 'correct' client pin + # The 'correct' client pin plus its salt/iv/entropy pin_secret = self.new_pin_secret() - # Make a new client and set the pin secret to get a new aes key - aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy()) - self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + # Set pin + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) - # retrieve the key again with our correct pin secret - replay_counter = 0 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) + # Get/verify pin with a new client + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Now let's compare - self.assertTrue(compare_digest(aeskey, aeskey_s)) + # Trying to set-pin with non-zero counter should fail + client = self.new_client_v2(False) + with self.assertRaises(ValueError) as cm: + aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), + self.new_entropy()) + self.assertEqual('500', str(cm.exception.args[0])) - for i in range(5): - # Simulate a reply attempt failing N times, it doesn't affect pin - # attempts / dos - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, - byteorder='little', - signed=False)) - self.assertFalse(compare_digest(aeskey, aeskey_s)) - - # retrieve the key again using v1 - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) + # Existing saved PIN undamaged as set attempt failed + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Incrementing the counter monotonically works again - replay_counter = 2 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - self.assertTrue(compare_digest(aeskey, aeskey_s)) - - # Incrementing the counter monotonically works even in case of network - # errors where some request is missed - replay_counter = 4 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - self.assertTrue(compare_digest(aeskey, aeskey_s)) + # Trying to reset pin with zero counter should work + pin_secret = self.new_pin_secret() + client = self.new_client_v2(True) + aeskey_s = self.server_call_v2(priv_key, client, 'set_pin', pin_secret, self.new_entropy()) + self.assertFalse(compare_digest(aeskey_g, aeskey_s)) # changed - bad_secret = self.new_pin_secret() - for i in range(3): - # exaust pin attmempts with good replay_counter - aeskey = self.get_pin(priv_key, bad_secret, self.new_entropy()) - self.assertFalse(compare_digest(aeskey, aeskey_s)) - - # retrieve the key again using v1 should fail - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy()) - self.assertFalse(compare_digest(aeskey_g, aeskey_s)) - - # Incrementing the counter monotonically also fails - replay_counter = 5 - aeskey = self.get_pinv2(priv_key, pin_secret, self.new_entropy(), - replay_counter.to_bytes(4, byteorder='little', - signed=False)) - self.assertFalse(compare_digest(aeskey, aeskey_s)) + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) if __name__ == '__main__': From 25a30169558f8e1b6fea1d608030f3e20051dcb2 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Thu, 16 Nov 2023 10:56:47 +0000 Subject: [PATCH 5/9] bpov2: Small tweak to separate anti-replay check from loading fields Also saves some int/bytes conversions --- pindb.py | 35 ++++++++++++++++------------------- test/test_pindb.py | 22 ++++++++-------------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/pindb.py b/pindb.py index 2c05562..7ef71c8 100644 --- a/pindb.py +++ b/pindb.py @@ -121,6 +121,15 @@ def _extract_fields(cls, cke, data, replay_counter=None): return pin_secret, entropy, client_public_key + @classmethod + def _check_v2_anti_replay(cls, server_counter, client_counter): + # if this is v2 and the db is already upgraded we enforce the anti replay + # ie. monotonic forward counter + if server_counter is not None and client_counter is not None: + server_counter = int.from_bytes(server_counter, byteorder='little', signed=False) + client_counter = int.from_bytes(client_counter, byteorder='little', signed=False) + assert client_counter > server_counter + @classmethod def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, pin_pubkey, aes_pin_data_key, count, replay_counter=None): @@ -147,7 +156,7 @@ def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, return aes_key @classmethod - def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key, replay_counter=None): + def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key): data = cls.storage.get(pin_pubkey_hash) assert len(data) == 129 @@ -173,23 +182,12 @@ def _load_pin_fields(cls, pin_pubkey_hash, pin_pubkey, aes_pin_data_key, replay_ hash_pin_secret, aes_key = plaintext[:32], plaintext[32:64] count = struct.unpack('B', plaintext[64: 64 + struct.calcsize('B')])[0] + replay_counter_persisted = plaintext[65:69] if len_plaintext == 69 else None - replay_local = None - if len_plaintext == 69: - replay_local = plaintext[65:69] - replay_local = int.from_bytes(replay_local, byteorder='little', - signed=False) - if replay_local is not None and replay_counter is not None: - # if this is v2 and the db is already upgraded we enforce the - # anti replay - replay_remote = int.from_bytes(replay_counter, byteorder='little', - signed=False) - assert replay_remote > replay_local - - return hash_pin_secret, aes_key, count, replay_local + return hash_pin_secret, aes_key, count, replay_counter_persisted @classmethod - def make_client_aes_key(self, pin_secret, saved_key): + def make_client_aes_key(cls, pin_secret, saved_key): # The client key returned is a combination of the aes-key persisted # and the raw pin_secret (that we do not persist anywhere). aes_key = hmac_sha256(saved_key, pin_secret) @@ -203,10 +201,9 @@ def get_aes_key_impl(cls, pin_pubkey, pin_secret, aes_pin_data_key, replay_count pin_pubkey_hash = bytes(sha256(pin_pubkey)) saved_hps, saved_key, counter, replay_local = cls._load_pin_fields(pin_pubkey_hash, pin_pubkey, - aes_pin_data_key, - replay_counter) - if replay_local is not None: - replay_local = replay_local.to_bytes(4, byteorder='little', signed=False) + aes_pin_data_key) + # Check anti-replay counter if appropriate + cls._check_v2_anti_replay(replay_local, replay_counter) # Check that the pin provided matches that saved hash_pin_secret = sha256(pin_secret) diff --git a/test/test_pindb.py b/test/test_pindb.py index 360701d..d2e0e0d 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -134,7 +134,6 @@ def _test_save_and_load_pin_fields_impl(self, use_v2_protocol): aes_pin = bytes(os.urandom(32)) # Save some data - check new file created - v2_prior_counter = None v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, count_in, v2_replay_counter) @@ -144,13 +143,11 @@ def _test_save_and_load_pin_fields_impl(self, use_v2_protocol): self.assertEqual(new_key, key_in) # Read file back in - ensure fields the same - if use_v2_protocol: - v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) - v2_replay_counter = b'\x01\x00\x00\x00' + v2_prior_counter = v2_replay_counter + v2_replay_counter = b'\x01\x00\x00\x00' if use_v2_protocol else None hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, user_id, - aes_pin, - v2_replay_counter) + aes_pin) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) @@ -158,19 +155,16 @@ def _test_save_and_load_pin_fields_impl(self, use_v2_protocol): # Ensure we can set zero the count of an existing file count_in = 0 - if use_v2_protocol: - v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) - v2_replay_counter = b'\x05\x00\x00\x00' + v2_prior_counter = v2_replay_counter + v2_replay_counter = b'\x05\x00\x00\x00' if use_v2_protocol else None new_key = PINDb._save_pin_fields(pinfile, hps_in, key_in, user_id, aes_pin, count_in, v2_replay_counter) - if use_v2_protocol: - v2_prior_counter = int.from_bytes(v2_replay_counter, byteorder='little', signed=False) - v2_replay_counter = b'\xc2\x00\x00\x00' + v2_prior_counter = v2_replay_counter + v2_replay_counter = b'\xc2\x00\x00\x00' if use_v2_protocol else None hps_out, key_out, count_out, replay_local = PINDb._load_pin_fields(pinfile, user_id, - aes_pin, - v2_replay_counter) + aes_pin) self.assertEqual(hps_out, hps_in) self.assertEqual(key_out, key_in) self.assertEqual(count_out, count_in) From a7391968b479c9abc6b27620108fa77d24043d13 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Wed, 22 Nov 2023 12:38:15 +0000 Subject: [PATCH 6/9] bpov2: Respect replay_counter in set-pin requests --- pindb.py | 15 ++++++++++++--- test/test_pindb.py | 28 +++++++++++++++++----------- test/test_pinserver.py | 21 +++++++++++++-------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/pindb.py b/pindb.py index 7ef71c8..74778fa 100644 --- a/pindb.py +++ b/pindb.py @@ -261,15 +261,24 @@ def get_aes_key(cls, cke, payload, aes_pin_data_key, replay_counter=None): @classmethod def set_pin(cls, cke, payload, aes_pin_data_key, replay_counter=None): pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) + pin_pubkey_hash = bytes(sha256(pin_pubkey)) + + # Load any existing replay counter for the pubkey + # and if found check the anti-replay counter + replay_local = None + try: + _, _, _, replay_local = cls._load_pin_fields(pin_pubkey_hash, pin_pubkey, + aes_pin_data_key) + cls._check_v2_anti_replay(replay_local, replay_counter) + except FileNotFoundError as e: + # No existing record for given pubkey - fine + pass # Make a new aes-key to persist from our and client entropy our_random = os.urandom(32) new_key = hmac_sha256(our_random, entropy) - assert replay_counter is None or replay_counter == b'\x00\x00\x00\x00' - # Persist the pin fields - pin_pubkey_hash = bytes(sha256(pin_pubkey)) hash_pin_secret = sha256(pin_secret) replay_bytes = None if replay_counter is not None: diff --git a/test/test_pindb.py b/test/test_pindb.py index d2e0e0d..f9ee4ba 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -213,7 +213,7 @@ def _test_bad_guesses_clears_pin_impl(self, v2set, v2get): pin_secret, entropy = self.new_pin_secret(), self.new_entropy() # Build the expected payload - v2_replay_counter = b'\x00\x00\x00\x00' if v2set else None + v2_replay_counter = b'\x05\x00\x00\x00' if v2set else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) # Set and verify the the pin - check this creates the file @@ -222,7 +222,7 @@ def _test_bad_guesses_clears_pin_impl(self, v2set, v2get): aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - v2_replay_counter = b'\x01\x00\x00\x00' if v2set else None + v2_replay_counter = b'\x06\x00\x00\x00' if v2set else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) @@ -263,7 +263,7 @@ def _test_bad_server_key_breaks_impl(self, use_v2_protocol): pin_aes_key = bytes(os.urandom(32)) # Set and verify the the pin - check this creates the file - v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None + v2_replay_counter = b'\x04\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) self.assertFalse(PINDb.storage.exists(pinfile)) aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) @@ -271,7 +271,7 @@ def _test_bad_server_key_breaks_impl(self, use_v2_protocol): self.assertTrue(PINDb.storage.exists(pinfile)) # Check we can get the key - v2_replay_counter = b'\x01\x00\x00\x00' if use_v2_protocol else None + v2_replay_counter = b'\x05\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) @@ -413,26 +413,32 @@ def test_bad_v2_counter_breaks_set_pin(self): aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Set-pin fails if use a non-zero counter - v2_replay_counter = b'\x0f\x0f\x00\x00' + # Set-pin must also respect the counter + v2_replay_counter = b'\x05\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) with self.assertRaises(AssertionError) as cm: aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) - # Key still present and readable with lower counter as set failed + v2_replay_counter = b'\x00\x00\x00\x00' + payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) + with self.assertRaises(AssertionError) as cm: + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + + # Key still present and readable as set failed v2_replay_counter = b'\x06\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Set-pin must use a counter of 0 - v2_replay_counter = b'\x00\x00\x00\x00' + # Set-pin must use a higher counter + v2_replay_counter = b'\x07\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + self.assertNotEqual(aeskey_s, aeskey_g) # has changed - # Key readable with new counter - v2_replay_counter = b'\x01\x00\x00\x00' + # Key readable with ongoing counter + v2_replay_counter = b'\x08\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) diff --git a/test/test_pinserver.py b/test/test_pinserver.py index c45e97d..bb233f8 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -107,7 +107,7 @@ def new_client_v1(self): client = PINClientECDH(self.static_server_public_key) return self.start_handshake_v1(client) - def new_client_v2(self, reset_replay_counter): + def new_client_v2(self, reset_replay_counter=False): if reset_replay_counter: client_counter = b'\x00\x00\x00\x00' else: @@ -185,9 +185,7 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, def make_server_call(self, private_key, endpoint, pin_secret, entropy, use_v2_protocol, fn_perturb_request=None): if use_v2_protocol: - # NOTE: replay_counter must be reset to 0x00 with 'set_pin' requests - reset_counter = endpoint == 'set_pin' - client = self.new_client_v2(reset_counter) + client = self.new_client_v2() server_call = self.server_call_v2 else: client = self.new_client_v1() @@ -561,11 +559,18 @@ def test_set_pin_counter_v2(self): aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) # Get/verify pin with a new client + client = self.new_client_v2() aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Trying to set-pin with non-zero counter should fail - client = self.new_client_v2(False) + # Trying to set-pin with same counter should fail + with self.assertRaises(ValueError) as cm: + aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), + self.new_entropy()) + self.assertEqual('500', str(cm.exception.args[0])) + + # Trying to set-pin with zero counter should fail + client = self.new_client_v2(True) with self.assertRaises(ValueError) as cm: aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), self.new_entropy()) @@ -575,9 +580,9 @@ def test_set_pin_counter_v2(self): aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Trying to reset pin with zero counter should work + # Trying to set pin while respecting the counter should work pin_secret = self.new_pin_secret() - client = self.new_client_v2(True) + client = self.new_client_v2() aeskey_s = self.server_call_v2(priv_key, client, 'set_pin', pin_secret, self.new_entropy()) self.assertFalse(compare_digest(aeskey_g, aeskey_s)) # changed From 7d1f8a24d1e336ceef7159ac6c2a1fd2e8b35287 Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Wed, 15 Nov 2023 10:44:49 +0000 Subject: [PATCH 7/9] bpov2: Changes to use wally.aes_cbc_with_ecdh_key() in protocol v2 --- client.py | 31 ++++++++------ flaskserver.py | 21 ++++----- lib.py | 15 ++++++- server.py | 31 +++++++------- test/test_ecdh_v1.py | 27 ++++++------ test/test_ecdh_v2.py | 96 ++++++++++++++++++++---------------------- test/test_pinserver.py | 22 +++++----- 7 files changed, 130 insertions(+), 113 deletions(-) diff --git a/client.py b/client.py index 0fa0914..421313a 100644 --- a/client.py +++ b/client.py @@ -11,6 +11,14 @@ def __init__(self, static_server_public_key): self.static_server_public_key = static_server_public_key self.ecdh_server_public_key = None + # returns ske, cke + def get_key_exchange(self): + return self.ecdh_server_public_key, self.public_key + + +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake +class PINClientECDHv1(PINClientECDH): def handshake(self, e_ecdh_server_public_key, static_server_signature): ec_sig_verify( self.static_server_public_key, @@ -24,10 +32,6 @@ def handshake(self, e_ecdh_server_public_key, static_server_signature): # Cache the shared secrets self.generate_shared_secrets(e_ecdh_server_public_key) - # returns ske, cke - def get_key_exchange(self): - return self.ecdh_server_public_key, self.public_key - # Encrypt/sign/hmac the payload (ie. the pin secret) def encrypt_request_payload(self, payload): assert self.ecdh_server_public_key @@ -48,22 +52,25 @@ def decrypt_response_payload(self, encrypted, hmac): return decrypt(self.response_encryption_key, encrypted) +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class PINClientECDHv2(PINClientECDH): def __init__(self, static_server_public_key, replay_counter): super().__init__(static_server_public_key) + + assert len(replay_counter) == 4 self.replay_counter = replay_counter - tweak = sha256(hmac_sha256(self.public_key, self.replay_counter)) # Derive and store the ecdh server public key (ske) + tweak = sha256(hmac_sha256(self.public_key, self.replay_counter)) self.ecdh_server_public_key = ec_public_key_bip341_tweak( self.static_server_public_key, tweak, 0) - # Cache the shared secrets - self.generate_shared_secrets(self.ecdh_server_public_key) - - # Encrypt/sign/hmac the payload (ie. the pin secret) def encrypt_request_payload(self, payload): - encrypted = encrypt(self.request_encryption_key, payload) - hmac = hmac_sha256(self.request_hmac_key, self.public_key + self.replay_counter + encrypted) - return encrypted, hmac + return self.encrypt_with_ecdh(self.ecdh_server_public_key, self.LABEL_ORACLE_REQUEST, + payload) + + def decrypt_response_payload(self, encrypted): + return self.decrypt_with_ecdh(self.ecdh_server_public_key, self.LABEL_ORACLE_RESPONSE, + encrypted) diff --git a/flaskserver.py b/flaskserver.py index e6540df..b8ceb6f 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -2,9 +2,9 @@ import json import os from flask import Flask, request, jsonify -from .server import PINServerECDH, PINServerECDHv2 +from .server import PINServerECDH, PINServerECDHv1, PINServerECDHv2 from .pindb import PINDb -from wallycore import AES_KEY_LEN_256, AES_BLOCK_LEN +from wallycore import AES_KEY_LEN_256, AES_BLOCK_LEN, HMAC_SHA256_LEN from dotenv import load_dotenv # Time we will retain active sessions, in seconds. @@ -39,7 +39,7 @@ def start_handshake_route(): app.logger.debug('Number of sessions {}'.format(len(sessions))) # Create a new ephemeral server/session and get its signed pubkey - e_ecdh_server = PINServerECDH() + e_ecdh_server = PINServerECDHv1() pubkey, sig = e_ecdh_server.get_signed_public_key() ske = pubkey.hex() @@ -51,6 +51,7 @@ def start_handshake_route(): return jsonify({'ske': ske, 'sig': sig.hex()}) + # NOTE: explicit 'hmac' fields in protocol v1 def _complete_server_call_v1(pin_func, udata): ske = udata['ske'] assert 'replay_counter' not in udata @@ -66,8 +67,9 @@ def _complete_server_call_v1(pin_func, udata): bytes.fromhex(udata['hmac_encrypted_data']), pin_func) - # Expecting to return an encrypted aes-key + # Expecting to return an encrypted aes-key with separate hmac assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + assert len(hmac) == HMAC_SHA256_LEN # Cleanup session del sessions[ske] @@ -77,24 +79,23 @@ def _complete_server_call_v1(pin_func, udata): return jsonify({'encrypted_key': encrypted_key.hex(), 'hmac': hmac.hex()}) + # NOTE: 'hmac' data is appened to encrypted_data in protocol v2 def _complete_server_call_v2(pin_func, udata): assert 'ske' not in udata assert len(udata['replay_counter']) == 8 cke = bytes.fromhex(udata['cke']) replay_counter = bytes.fromhex(udata['replay_counter']) e_ecdh_server = PINServerECDHv2(replay_counter, cke) - encrypted_key, hmac = e_ecdh_server.call_with_payload( + encrypted_key = e_ecdh_server.call_with_payload( cke, bytes.fromhex(udata['encrypted_data']), - bytes.fromhex(udata['hmac_encrypted_data']), pin_func) - # Expecting to return an encrypted aes-key - assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + # Expecting to return an encrypted aes-key with hmac appended + assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + HMAC_SHA256_LEN # Return response - return jsonify({'encrypted_key': encrypted_key.hex(), - 'hmac': hmac.hex()}) + return jsonify({'encrypted_key': encrypted_key.hex()}) def _complete_server_call(pin_func): try: diff --git a/lib.py b/lib.py index b764e13..29dd328 100644 --- a/lib.py +++ b/lib.py @@ -2,7 +2,7 @@ from wallycore import AES_BLOCK_LEN, AES_FLAG_DECRYPT, AES_FLAG_ENCRYPT, \ aes_cbc, ec_private_key_verify, ec_public_key_from_private_key, ecdh, \ - hmac_sha256 + hmac_sha256, aes_cbc_with_ecdh_key def encrypt(aes_key, plaintext): @@ -19,6 +19,10 @@ def decrypt(aes_key, encrypted): class E_ECDH(object): + # Labels used to derived child keys for aes_cbc_with_ecdh_key() call + LABEL_ORACLE_REQUEST = 'blind_oracle_request'.encode() + LABEL_ORACLE_RESPONSE = 'blind_oracle_response'.encode() + @classmethod def _generate_private_key(cls): counter = 4 @@ -50,3 +54,12 @@ def _derived(val): self.request_hmac_key = _derived(1) self.response_encryption_key = _derived(2) self.response_hmac_key = _derived(3) + + def decrypt_with_ecdh(self, public_key, label, encrypted): + return aes_cbc_with_ecdh_key(self.private_key, None, encrypted, public_key, label, + AES_FLAG_DECRYPT) + + def encrypt_with_ecdh(self, public_key, label, plaintext): + iv = os.urandom(AES_BLOCK_LEN) + return aes_cbc_with_ecdh_key(self.private_key, iv, plaintext, public_key, label, + AES_FLAG_ENCRYPT) diff --git a/server.py b/server.py index 5753f43..bb1a0a6 100644 --- a/server.py +++ b/server.py @@ -58,6 +58,13 @@ def __init__(self): super().__init__() self.time_started = int(time.time()) + +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake +class PINServerECDHv1(PINServerECDH): + def __init__(self): + super().__init__() + def get_signed_public_key(self): return self.public_key, self._sign_with_static_key(self.public_key) @@ -89,6 +96,8 @@ def call_with_payload(self, cke, encrypted, hmac, func): return encrypted, hmac +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class PINServerECDHv2(PINServerECDH): @classmethod @@ -107,24 +116,16 @@ def __init__(self, replay_counter, cke): self.replay_counter = replay_counter self.private_key, self.public_key = self.generate_ec_key_pair(replay_counter, cke) - # Decrypt the received payload (ie. aes-key) - def decrypt_request_payload(self, cke, encrypted, hmac): - # Verify hmac received - hmac_calculated = hmac_sha256(self.request_hmac_key, cke + self.replay_counter + encrypted) - assert compare_digest(hmac, hmac_calculated) + def decrypt_request_payload(self, cke, encrypted): + return self.decrypt_with_ecdh(cke, self.LABEL_ORACLE_REQUEST, encrypted) - # Return decrypted data - return decrypt(self.request_encryption_key, encrypted) + def encrypt_response_payload(self, cke, payload): + return self.encrypt_with_ecdh(cke, self.LABEL_ORACLE_RESPONSE, payload) # Function to deal with wrapper ecdh encryption. # Calls passed function with unwrapped payload, and wraps response before # returning. Separates payload handler func from wrapper encryption. - def call_with_payload(self, cke, encrypted, hmac, func): - self.generate_shared_secrets(cke) - payload = self.decrypt_request_payload(cke, encrypted, hmac) - - # Call the passed function with the decrypted payload + def call_with_payload(self, cke, encrypted, func): + payload = self.decrypt_request_payload(cke, encrypted) response = func(cke, payload, self._get_aes_pin_data_key(), self.replay_counter) - - encrypted, hmac = self.encrypt_response_payload(response) - return encrypted, hmac + return self.encrypt_response_payload(cke, response) diff --git a/test/test_ecdh_v1.py b/test/test_ecdh_v1.py index 9ac7d23..31aa8c1 100644 --- a/test/test_ecdh_v1.py +++ b/test/test_ecdh_v1.py @@ -2,27 +2,28 @@ import os -from ..client import PINClientECDH -from ..server import PINServerECDH +from ..client import PINClientECDHv1 +from ..server import PINServerECDHv1 # Tests ECDHv1 wrapper without any reference to the pin/aes-key paylod stuff. # Just testing the ECDH envelope/encryption in isolation, with misc bytearray() # payloads (ie. any old str.encode()). Tests client/server handshake/pairing. -# NOTE: protocol v1: key-exchange handshake required +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake class ECDHv1Test(unittest.TestCase): @classmethod def setUpClass(cls): - PINServerECDH.load_private_key() + PINServerECDHv1.load_private_key() # The server public key the client would know - with open(PINServerECDH.STATIC_SERVER_PUBLIC_KEY_FILE, 'rb') as f: + with open(PINServerECDHv1.STATIC_SERVER_PUBLIC_KEY_FILE, 'rb') as f: cls.static_server_public_key = f.read() # Make a new client and initialise with server handshake def new_client_handshake(self, ske, sig): - client = PINClientECDH(self.static_server_public_key) + client = PINClientECDHv1(self.static_server_public_key) client.handshake(ske, sig) ske1, cke = client.get_key_exchange() self.assertEqual(ske, ske1) @@ -32,7 +33,7 @@ def _test_client_server_impl(self, client_request, server_response): # A new server is created, which signs its newly-created ske with the # static key (so the client can validate that the ske is genuine). - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() # They get sent to the client (eg. over network) which then validates @@ -71,7 +72,7 @@ def test_client_server_happypath(self): def test_call_with_payload(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -98,7 +99,7 @@ def _func(client_key, payload, aes_pin_data_key): def test_multiple_calls(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -120,7 +121,7 @@ def test_multiple_calls(self): def test_multiple_clients(self): # A new server and several clients - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() # Server can persist and handle multiple calls provided each one is @@ -143,7 +144,7 @@ def test_multiple_clients(self): def test_bad_request_cke_throws(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -173,7 +174,7 @@ def _func(client_key, payload, aes_pin_data_key): def test_bad_request_hmac_throws(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -200,7 +201,7 @@ def _func(client_key, payload, aes_pin_data_key): def test_bad_response_hmac_throws(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) diff --git a/test/test_ecdh_v2.py b/test/test_ecdh_v2.py index 61ee71e..bf39a88 100644 --- a/test/test_ecdh_v2.py +++ b/test/test_ecdh_v2.py @@ -9,7 +9,8 @@ # Tests ECDHv2 wrapper without any reference to the pin/aes-key paylod stuff. # Just testing the ECDH envelope/encryption in isolation, with misc bytearray() # payloads (ie. any old str.encode()). Tests client/server handshake/pairing. -# NOTE: protocol v2: no key-exchange handshake required +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class ECDHv2Test(unittest.TestCase): REPLAY_COUNTER = bytes([0x00, 0x00, 0x00, 0x2a]) # arbitrary @@ -31,8 +32,8 @@ def _test_client_server_impl(self, client_request, server_response): # server static key. cke, client = self.new_client_handshake() - # The client can then encrypt a payload (and hmac) for the server - encrypted, hmac = client.encrypt_request_payload(client_request) + # The client can then encrypt a payload (with implicit hmac) for the server + encrypted = client.encrypt_request_payload(client_request) self.assertNotEqual(client_request, encrypted) # A new server is created when passed the client replay-counter. @@ -40,18 +41,16 @@ def _test_client_server_impl(self, client_request, server_response): # NOTE: the server deduced private key should be the counterpart to the # client-deduced public key - if so the payload decryption should yield # the original cleartext request message. - # Note: this validates hmac before it decrypts/returns server = PINServerECDHv2(client.replay_counter, cke) - server.generate_shared_secrets(cke) - received = server.decrypt_request_payload(cke, encrypted, hmac) + received = server.decrypt_request_payload(cke, encrypted) self.assertEqual(received, client_request) # The server can then send an encrypted response to the client - encrypted, hmac = server.encrypt_response_payload(server_response) + encrypted = server.encrypt_response_payload(cke, server_response) # The client can decrypt the response. - # Note: this validates hmac before it decrypts/returns - received = client.decrypt_response_payload(encrypted, hmac) + # Note: this validates the implicit hmac before it decrypts/returns + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_client_server_happypath(self): @@ -67,7 +66,7 @@ def test_call_with_payload(self): # Client sends message to server cke, client = self.new_client_handshake() client_request = "Hello - test 123".encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) self.assertNotEqual(client_request, encrypted) # Test server un-/re-wrapping function - this handles all the ecdh @@ -82,10 +81,10 @@ def _func(client_key, payload, aes_pin_data_key, replay_counter): self.assertEqual(replay_counter, client.replay_counter) return server_response - encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + encrypted = server.call_with_payload(cke, encrypted, _func) # Assert that is what the client expects - received = client.decrypt_response_payload(encrypted, hmac) + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_multiple_calls(self): @@ -95,18 +94,17 @@ def test_multiple_calls(self): # Server can handle multiple calls from the client with same secrets # (But that would use same cke and counter which is ofc not ideal/recommended.) - server.generate_shared_secrets(cke) for i in range(5): client_request = 'request-{}'.format(i).encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) - received = server.decrypt_request_payload(cke, encrypted, hmac) + received = server.decrypt_request_payload(cke, encrypted) self.assertEqual(received, client_request) server_response = 'response-{}'.format(i).encode() - encrypted, hmac = server.encrypt_response_payload(server_response) + encrypted = server.encrypt_response_payload(cke, server_response) - received = client.decrypt_response_payload(encrypted, hmac) + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_bad_request_cke_throws(self): @@ -116,7 +114,7 @@ def test_bad_request_cke_throws(self): # Encrypt message client_request = 'bad-cke-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) # Break cke bad_cke, _ = self.new_client_handshake() @@ -125,25 +123,19 @@ def test_bad_request_cke_throws(self): # Ensure decrypt_request() throws server.generate_shared_secrets(cke) - server.decrypt_request_payload(cke, encrypted, hmac) # no error - - # Same server using good cke to derive keys, but bad cke passed - server.generate_shared_secrets(bad_cke) - with self.assertRaises(AssertionError) as cm: - server.decrypt_request_payload(bad_cke, encrypted, hmac) # error + server.decrypt_request_payload(cke, encrypted) # no error # New server with bad_cke from the get go server = PINServerECDHv2(client.replay_counter, bad_cke) - server.generate_shared_secrets(bad_cke) - with self.assertRaises(AssertionError) as cm: - server.decrypt_request_payload(bad_cke, encrypted, hmac) # error + with self.assertRaises(ValueError) as cm: + server.decrypt_request_payload(bad_cke, encrypted) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: - server.call_with_payload(bad_cke, encrypted, hmac, _func) + with self.assertRaises(ValueError) as cm: + server.call_with_payload(bad_cke, encrypted, _func) def test_bad_request_counter_throws(self): # A new server and client @@ -152,24 +144,23 @@ def test_bad_request_counter_throws(self): # Encrypt message client_request = 'bad-counter-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) # Ensure decrypt_request() throws server.generate_shared_secrets(cke) - server.decrypt_request_payload(cke, encrypted, hmac) # no error + server.decrypt_request_payload(cke, encrypted) # no error # New server with bad counter passed server = PINServerECDHv2(os.urandom(4), cke) - server.generate_shared_secrets(cke) - with self.assertRaises(AssertionError) as cm: - server.decrypt_request_payload(cke, encrypted, hmac) # error + with self.assertRaises(ValueError) as cm: + server.decrypt_request_payload(cke, encrypted) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: - server.call_with_payload(cke, encrypted, hmac, _func) + with self.assertRaises(ValueError) as cm: + server.call_with_payload(cke, encrypted, _func) def test_bad_request_hmac_throws(self): # A new server and client @@ -178,24 +169,25 @@ def test_bad_request_hmac_throws(self): # Encrypt message client_request = 'bad-hmac-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) - # Break hmac + # Break hmac at tail of encrypted bytes bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) - self.assertNotEqual(hmac, bad_hmac) + bad_encrypted = encrypted[:-32] + bad_hmac + self.assertEqual(encrypted[:-32], bad_encrypted[:-32]) + self.assertNotEqual(encrypted[-32:], bad_encrypted[-32:]) # Ensure decrypt_request() throws - server.generate_shared_secrets(cke) - server.decrypt_request_payload(cke, encrypted, hmac) # no error - with self.assertRaises(AssertionError) as cm: - server.decrypt_request_payload(cke, encrypted, bad_hmac) # error + server.decrypt_request_payload(cke, encrypted) # no error + with self.assertRaises(ValueError) as cm: + server.decrypt_request_payload(cke, bad_encrypted) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key, replay_counter): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: - server.call_with_payload(cke, encrypted, bad_hmac, _func) + with self.assertRaises(ValueError) as cm: + server.call_with_payload(cke, bad_encrypted, _func) def test_bad_response_hmac_throws(self): # A new server and client @@ -204,7 +196,7 @@ def test_bad_response_hmac_throws(self): # Encrypt message client_request = 'bad-hmac-response-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) def _func(client_key, payload, pin_data_aes_key, replay_counter): self.assertEqual(client_key, cke) @@ -212,15 +204,17 @@ def _func(client_key, payload, pin_data_aes_key, replay_counter): self.assertEqual(replay_counter, client.replay_counter) return 'bad-hmac-response'.encode() - encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + encrypted = server.call_with_payload(cke, encrypted, _func) # Break hmac bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) - self.assertNotEqual(hmac, bad_hmac) + bad_encrypted = encrypted[:-32] + bad_hmac + self.assertEqual(encrypted[:-32], bad_encrypted[:-32]) + self.assertNotEqual(encrypted[-32:], bad_encrypted[-32:]) - client.decrypt_response_payload(encrypted, hmac) # No error - with self.assertRaises(AssertionError) as cm: - client.decrypt_response_payload(encrypted, bad_hmac) # error + client.decrypt_response_payload(encrypted) # No error + with self.assertRaises(ValueError) as cm: + client.decrypt_response_payload(bad_encrypted) # error if __name__ == '__main__': diff --git a/test/test_pinserver.py b/test/test_pinserver.py index bb233f8..9134892 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -7,7 +7,7 @@ from hmac import compare_digest import requests -from ..client import PINClientECDH, PINClientECDH, PINClientECDHv2 +from ..client import PINClientECDH, PINClientECDHv1, PINClientECDHv2 from ..server import PINServerECDH from ..pindb import PINDb @@ -97,14 +97,14 @@ def tearDownClass(cls): # Start the client/server key-exchange handshake def start_handshake_v1(self, client): - assert isinstance(client, PINClientECDH) + assert isinstance(client, PINClientECDHv1) handshake = self.post('start_handshake') client.handshake(bytes.fromhex(handshake['ske']), bytes.fromhex(handshake['sig'])) return client # Make a new ephemeral client and initialise with server handshake def new_client_v1(self): - client = PINClientECDH(self.static_server_public_key) + client = PINClientECDHv1(self.static_server_public_key) return self.start_handshake_v1(client) def new_client_v2(self, reset_replay_counter=False): @@ -120,7 +120,7 @@ def new_client_v2(self, reset_replay_counter=False): # NOTE: explicit hmac fields def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy, fn_perturb_request=None): - assert isinstance(client, PINClientECDH) + assert isinstance(client, PINClientECDHv1) # Make and encrypt the payload (ie. pin secret) ske, cke = client.get_key_exchange() @@ -151,6 +151,7 @@ def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy, # Make the server call to get/set the pin - returns the decrypted response # NOTE: signature covers replay counter + # NOTE: implicit hmac def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, fn_perturb_request=None): assert isinstance(client, PINClientECDHv2) @@ -162,14 +163,13 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) payload = pin_secret + entropy + sig - encrypted, hmac = client.encrypt_request_payload(payload) + encrypted = client.encrypt_request_payload(payload) # Make call and parse response - # Includes 'replay_counter' but not 'ske' + # Includes 'replay_counter' but not 'ske' or 'hmac' urldata = {'cke': cke.hex(), - 'replay_counter': client.replay_counter.hex(), 'encrypted_data': encrypted.hex(), - 'hmac_encrypted_data': hmac.hex()} + 'replay_counter': client.replay_counter.hex()} # Caller can mangle data before it is sent if fn_perturb_request: @@ -177,10 +177,9 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, response = self.post(endpoint, urldata) encrypted = bytes.fromhex(response['encrypted_key']) - hmac = bytes.fromhex(response['hmac']) # Return decrypted payload - return client.decrypt_response_payload(encrypted, hmac) + return client.decrypt_response_payload(encrypted) def make_server_call(self, private_key, endpoint, pin_secret, entropy, use_v2_protocol, fn_perturb_request=None): @@ -425,13 +424,14 @@ def _fn(d): request_manglers = [_set('cke', bad_cke.hex())] request_manglers.extend(f('cke') for f in [_short, _long, _remove]) request_manglers.extend(f('encrypted_data') for f in [_random, _short, _long, _remove]) - request_manglers.extend(f('hmac_encrypted_data') for f in [_random, _short, _long, _remove]) if use_v2_protocol: request_manglers.extend(f('replay_counter') for f in [_random, _short, _long, _remove]) else: request_manglers.append(_set('ske', bad_ske.hex())) request_manglers.extend(f('ske') for f in [_short, _long, _remove]) + request_manglers.extend(f('hmac_encrypted_data') + for f in [_random, _short, _long, _remove]) for mangler in request_manglers: for endpoint in ['get_pin', 'set_pin']: From b6425ab29c999420b30e45495141bf85e23b5b4f Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Fri, 1 Dec 2023 09:54:53 +0000 Subject: [PATCH 8/9] bpov2: Do not require client entropy in 'get' message --- pindb.py | 18 +++++-- test/test_pindb.py | 118 ++++++++++++++++++++++++++--------------- test/test_pinserver.py | 73 +++++++++++++++---------- 3 files changed, 134 insertions(+), 75 deletions(-) diff --git a/pindb.py b/pindb.py index 74778fa..e8c8bd7 100644 --- a/pindb.py +++ b/pindb.py @@ -102,14 +102,19 @@ class PINDb(object): @classmethod def _extract_fields(cls, cke, data, replay_counter=None): - assert len(data) == (2*SHA256_LEN) + EC_SIGNATURE_RECOVERABLE_LEN + assert len(data) > SHA256_LEN - # secret + entropy + sig + # secret + (optional)entropy + sig pin_secret = data[:SHA256_LEN] - entropy = data[SHA256_LEN: SHA256_LEN + SHA256_LEN] - sig = data[SHA256_LEN + SHA256_LEN:] + if len(data) == SHA256_LEN + SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN: + entropy = data[SHA256_LEN: SHA256_LEN + SHA256_LEN] + sig = data[SHA256_LEN + SHA256_LEN:] + else: + assert len(data) == SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN + entropy = b'' + sig = data[SHA256_LEN:] - # make sure the client_public_key signs over the replay counter too if provided + # The client_public_key also signs over any replay counter if replay_counter is not None: assert len(replay_counter) == 4 signed_msg = sha256(cke + replay_counter + pin_secret + entropy) @@ -243,6 +248,7 @@ def get_aes_key_impl(cls, pin_pubkey, pin_secret, aes_pin_data_key, replay_count # Get existing aes_key given pin fields, or junk if pin or pubkey bad @classmethod def get_aes_key(cls, cke, payload, aes_pin_data_key, replay_counter=None): + # NOTE: we don't care about client-passed entropy at this point pin_secret, _, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) # Translate internal exception and bad-pin into junk key @@ -260,8 +266,10 @@ def get_aes_key(cls, cke, payload, aes_pin_data_key, replay_counter=None): # Set pin fields, return new aes_key @classmethod def set_pin(cls, cke, payload, aes_pin_data_key, replay_counter=None): + # NOTE: we require client-passed entropy at this point pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) pin_pubkey_hash = bytes(sha256(pin_pubkey)) + assert entropy # Load any existing replay counter for the pubkey # and if found check the anti-replay counter diff --git a/test/test_pindb.py b/test/test_pindb.py index f9ee4ba..2c6793b 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -61,18 +61,21 @@ def tearDownClass(cls): def _test_extract_fields_impl(self, v2_replay_counter): # Reinitialise keys and secret and entropy privkey, pubkey, cke, _ = self.new_keys() - secret_in, entropy_in = self.new_pin_secret(), self.new_entropy() + secret_in = self.new_pin_secret() - # Build the expected payload - payload = self.make_payload(privkey, cke, secret_in, entropy_in, v2_replay_counter) + # NOTE: client entropy is optional + for entropy_in in [self.new_entropy(), b'']: + # Build the expected payload + payload = self.make_payload(privkey, cke, secret_in, entropy_in, v2_replay_counter) - # Check pindb function can extract the components from the payload - secret_out, entropy_out, pubkey_out = PINDb._extract_fields(cke, payload, v2_replay_counter) - self.assertEqual(secret_out, secret_in) - self.assertEqual(entropy_out, entropy_in) + # Check pindb function can extract the components from the payload + secret_out, entropy_out, pubkey_out = PINDb._extract_fields(cke, payload, + v2_replay_counter) + self.assertEqual(secret_out, secret_in) + self.assertEqual(entropy_out, entropy_in) - # Check the public key is correctly recovered from the signature - self.assertEqual(pubkey_out, pubkey) + # Check the public key is correctly recovered from the signature + self.assertEqual(pubkey_out, pubkey) def test_extract_fields(self): for v2_replay_counter in [None, os.urandom(4), os.urandom(4)]: @@ -83,33 +86,38 @@ def _test_mismatching_sig_impl(self, v2_replay_counter): # Get two sets of keys and a new secret privX, pubX, ckeX, _ = self.new_keys() privY, pubY, ckeY, _ = self.new_keys() - secret_in, entropy_in = self.new_pin_secret(), self.new_entropy() + secret_in = self.new_pin_secret() - # Build the expected payload - payload = self.make_payload(privX, ckeX, secret_in, entropy_in, v2_replay_counter) - - # Call the pindb function to extract the components from the payload - secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, v2_replay_counter) - self.assertEqual(secret_out, secret_in) - self.assertEqual(entropy_out, entropy_in) - self.assertEqual(pubkey, pubX) - - # Call the pindb function to extract the components from the payload - # but use a mismatched cke - the sig should not yield either pubkey. - secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeY, payload, v2_replay_counter) - self.assertEqual(secret_out, secret_in) - self.assertEqual(entropy_out, entropy_in) - self.assertNotEqual(pubkey, pubX) - self.assertNotEqual(pubkey, pubY) - - # Call the pindb function again with the correct cke, but pass a bad replay counter - for bad_counter in [os.urandom(4), None if v2_replay_counter else os.urandom(4)]: - secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, bad_counter) + # NOTE: client entropy is optional + for entropy_in in [self.new_entropy(), b'']: + # Build the expected payload + payload = self.make_payload(privX, ckeX, secret_in, entropy_in, + v2_replay_counter) + + # Call the pindb function to extract the components from the payload + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, + v2_replay_counter) + self.assertEqual(secret_out, secret_in) + self.assertEqual(entropy_out, entropy_in) + self.assertEqual(pubkey, pubX) + + # Call the pindb function to extract the components from the payload + # but use a mismatched cke - the sig should not yield either pubkey. + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeY, payload, + v2_replay_counter) self.assertEqual(secret_out, secret_in) self.assertEqual(entropy_out, entropy_in) self.assertNotEqual(pubkey, pubX) self.assertNotEqual(pubkey, pubY) + # Call the pindb function again with the correct cke, but pass a bad replay counter + for bad_counter in [os.urandom(4), None if v2_replay_counter else os.urandom(4)]: + secret_out, entropy_out, pubkey = PINDb._extract_fields(ckeX, payload, bad_counter) + self.assertEqual(secret_out, secret_in) + self.assertEqual(entropy_out, entropy_in) + self.assertNotEqual(pubkey, pubX) + self.assertNotEqual(pubkey, pubY) + def test_mismatching_sig(self): for v2_replay_counter in [None, os.urandom(4), os.urandom(4)]: with self.subTest(protocol='v2' if v2_replay_counter else 'v1'): @@ -194,9 +202,9 @@ def _test_set_and_get_pin_impl(self, v2set, v2get): self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) self.assertTrue(PINDb.storage.exists(pinfile)) - # Get the key with the pin - new payload has new entropy (same pin) + # Get the key with the pin - new payload has no entropy and higher replay_counter (same pin) v2_replay_counter = os.urandom(4) if v2get else None - payload = self.make_payload(privkey, cke, pin_secret, self.new_entropy(), v2_replay_counter) + payload = self.make_payload(privkey, cke, pin_secret, b'', v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) @@ -222,6 +230,7 @@ def _test_bad_guesses_clears_pin_impl(self, v2set, v2get): aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) + entropy = b'' # get does not need entropy v2_replay_counter = b'\x06\x00\x00\x00' if v2set else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) @@ -271,6 +280,7 @@ def _test_bad_server_key_breaks_impl(self, use_v2_protocol): self.assertTrue(PINDb.storage.exists(pinfile)) # Check we can get the key + entropy = b'' # get does not need entropy v2_replay_counter = b'\x05\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) @@ -320,6 +330,7 @@ def _test_bad_user_pubkey_breaks_impl(self, use_v2_protocol): self.assertTrue(PINDb.storage.exists(pinfile)) # Check we can get the key + entropy = b'' # get does not need entropy v2_replay_counter = b'\x03\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) @@ -370,6 +381,7 @@ def test_bad_v2_counter_breaks_get_pin(self): # Check we can get the key with increasing counters, and same or # decreasing counters give a 'bad pin' result + entropy = b'' # get does not need entropy max_counter = 0 for counter in [0, 3, 3, 6, 123, 45, 332, 155, 332, 330, 500, 200, 300, 400, 501, 500]: v2_replay_counter = counter.to_bytes(4, 'little', signed=False) @@ -464,6 +476,7 @@ def _test_two_users_with_same_pin_impl(self, v2X, v2Y): self.assertFalse(compare_digest(aeskeyX_s, aeskeyY_s)) # Each can get their own key + entropy = b'' # get does not need entropy v2_replay_counterX = os.urandom(4) if v2X else None v2_replay_counterY = os.urandom(4) if v2Y else None payloadX = self.make_payload(privX, ckeX, secret_in, entropy_in, v2_replay_counterX) @@ -479,28 +492,49 @@ def test_two_users_with_same_pin(self): with self.subTest(X='v2' if v2X else 'v1', Y='v2' if v2Y else 'v1'): self._test_two_users_with_same_pin_impl(v2X, v2Y) - def _test_rejects_without_client_entropy_impl(self, use_v2_protocol): + def _test_client_entropy_impl(self, use_v2_protocol): # Reinitialise keys and secret and entropy sig_priv, _, cke, pinfile = self.new_keys() - secret, entropy = self.new_pin_secret(), bytearray() + secret, entropy = self.new_pin_secret(), self.new_entropy() + pin_aes_key = bytes(os.urandom(32)) - # Build the expected payload + # Build the expected payload with entropy and set a key v2_replay_counter = b'\x00\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) - - pin_aes_key = bytes(os.urandom(32)) - with self.assertRaises(AssertionError) as cm: - PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) v2_replay_counter = b'\x01\x00\x00\x00' if use_v2_protocol else None payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + assert aeskey_g == aeskey_s + + # Payload without client entropy + v2_replay_counter = b'\x02\x00\x00\x00' if use_v2_protocol else None + payload = self.make_payload(sig_priv, cke, secret, b'', v2_replay_counter) + + # Verify trying to set-pin without entropy fails with self.assertRaises(AssertionError) as cm: - PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + + # Get-pin should be fine without entropy + aeskey_g = PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) + assert aeskey_g == aeskey_s + + # Note: wrong-length entropy is always bad + v2_replay_counter = b'\x03\x00\x00\x00' if use_v2_protocol else None + for entropy in [self.new_entropy()[:-1], self.new_entropy() + b'\xab']: + payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) + + with self.assertRaises(AssertionError) as cm: + PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) + + with self.assertRaises(AssertionError) as cm: + PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) - def test_rejects_without_client_entropy(self): + def test_client_entropy(self): for use_v2_protocol in [False, True]: with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): - self._test_rejects_without_client_entropy_impl(use_v2_protocol) + self._test_client_entropy_impl(use_v2_protocol) if __name__ == '__main__': diff --git a/test/test_pinserver.py b/test/test_pinserver.py index 9134892..d7289ea 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -229,10 +229,11 @@ def _test_set_and_get_pin_impl(self, use_v2_protocol): aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) self.assertEqual(len(aeskey_s), AES_KEY_LEN_256) - # Get key with a new client, with the correct pin secret (new entropy) - for attempt in range(3): - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) - self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + # Get key with a new client, with the correct pin secret (with or without entropy) + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) def test_set_and_get_pin(self): for use_v2_protocol in [False, True]: @@ -259,7 +260,10 @@ def _test_protocol_upgrade_downgrade_impl(self, v2set, v2get): aeskey = self.get_pin(priv_key, pin_secret, self.new_entropy(), not v2get) self.assertTrue(compare_digest(aeskey, aeskey_s)) - aeskey = self.get_pin(priv_key, pin_secret, self.new_entropy(), v2get) + aeskey = self.get_pin(priv_key, pin_secret, b'', v2get) + self.assertTrue(compare_digest(aeskey, aeskey_s)) + + aeskey = self.get_pin(priv_key, pin_secret, b'', not v2get) self.assertTrue(compare_digest(aeskey, aeskey_s)) def test_protocol_upgrade_downgrade(self): @@ -282,6 +286,11 @@ def _test_bad_guesses_clears_pin_impl(self, use_v2_protocol): self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) + # Get does not need client entropy + entropy = b'' + aeskey_g = self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + # Bad guesses at PIN for attempt in range(3): # Attempt to get with bad pin @@ -321,6 +330,11 @@ def _test_bad_pubkey_breaks_impl(self, use_v2_protocol): self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertTrue(PINDb.storage.exists(pinfile)) + # Get does not need client entropy + entropy = b'' + aeskey_g = self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + # Bad attempts with bad pub_key for attempt in range(3): # Attempt to get with bad pub_key @@ -358,6 +372,8 @@ def _test_two_users_with_same_pin_impl(self, v2X, v2Y): aeskey_sY = self.set_pin(clientY_private_key, pin_secret, entropy, v2Y) self.assertFalse(compare_digest(aeskey_sX, aeskey_sY)) + # Get does not need client entropy + entropy = b'' aeskey_gX = self.get_pin(clientX_private_key, pin_secret, entropy, v2X) self.assertTrue(compare_digest(aeskey_gX, aeskey_sX)) @@ -446,28 +462,32 @@ def test_rejects_on_bad_json(self): with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): self._test_rejects_on_bad_json_impl(use_v2_protocol) - def _test_rejects_without_client_entropy_impl(self, use_v2_protocol): + def _test_client_entropy_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client priv_key, _, _ = self.new_static_client_keys() + pin_secret = self.new_pin_secret() - # The 'correct' client pin but no salt/iv/entropy - pin_secret, entropy = self.new_pin_secret(), bytearray() - - # Make a new client and set the pin secret to get a new aes key + # Fails if setting the pin secret without passing client entropy with self.assertRaises(ValueError) as cm: - self.set_pin(priv_key, pin_secret, entropy, use_v2_protocol) + self.set_pin(priv_key, pin_secret, b'', use_v2_protocol) self.assertEqual('500', str(cm.exception.args[0])) - with self.assertRaises(ValueError) as cm: - self.get_pin(priv_key, pin_secret, entropy, use_v2_protocol) + # Set pin with client entropy - fine + aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) + + # Get call works with or without entropy (it's ignored in any case) + aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol) + self.assertTrue(compare_digest(aeskey_g, aeskey_s)) self.assertEqual('500', str(cm.exception.args[0])) - def test_rejects_without_client_entropy(self): + def test_client_entropy(self): for use_v2_protocol in [False, True]: with self.subTest(protocol='v2' if use_v2_protocol else 'v1'): - self._test_rejects_without_client_entropy_impl(use_v2_protocol) + self._test_client_entropy_impl(use_v2_protocol) def test_delayed_interaction_v1(self): # Make ourselves a static key pair for this logical client @@ -478,7 +498,7 @@ def test_delayed_interaction_v1(self): # Set and verify the pin aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=False) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # If we delay in the server interaction it will fail with a 500 error @@ -486,8 +506,7 @@ def test_delayed_interaction_v1(self): time.sleep(SESSION_LIFETIME + 1) # Sufficiently long delay with self.assertRaises(ValueError) as cm: - self.server_call_v1(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + self.server_call_v1(priv_key, client, 'get_pin', pin_secret, b'') self.assertEqual('500', str(cm.exception.args[0])) @@ -510,16 +529,14 @@ def test_cannot_reuse_client_session_v1(self): # Trying to reuse the session should fail with a 500 error # because the server has closed that ephemeral encryption session with self.assertRaises(ValueError) as cm: - self.server_call_v1(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + self.server_call_v1(priv_key, client, 'get_pin', pin_secret, b'') self.assertEqual('500', str(cm.exception.args[0])) # Not great, but we could reuse the client if we re-initiate handshake # (But that would use same cke which is not ideal/recommended.) self.start_handshake_v1(client) - aeskey = self.server_call_v1(priv_key, client, 'get_pin', pin_secret, - self.new_entropy()) + aeskey = self.server_call_v1(priv_key, client, 'get_pin', pin_secret, b'') self.assertTrue(compare_digest(aeskey, aeskey_s)) def test_cannot_reuse_client_session_v2(self): @@ -534,15 +551,15 @@ def test_cannot_reuse_client_session_v2(self): # Get/verify pin with a new client client = self.new_client_v2(False) - aeskey_g = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, self.new_entropy()) + aeskey_g = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, b'') self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # Trying to reuse the session should appear to work, but will return a junk key # (ie. same as bad pin) because the server-side 'replay counter' has moved on - aeskey = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, self.new_entropy()) + aeskey = self.server_call_v2(priv_key, client, 'get_pin', pin_secret, b'') self.assertFalse(compare_digest(aeskey, aeskey_s)) - # Set-pin should fail more overtly + # Set-pin should fail more overtly (NOTE: needs client entropy passed) with self.assertRaises(ValueError) as cm: aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), self.new_entropy()) @@ -560,7 +577,7 @@ def test_set_pin_counter_v2(self): # Get/verify pin with a new client client = self.new_client_v2() - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # Trying to set-pin with same counter should fail @@ -577,7 +594,7 @@ def test_set_pin_counter_v2(self): self.assertEqual('500', str(cm.exception.args[0])) # Existing saved PIN undamaged as set attempt failed - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) # Trying to set pin while respecting the counter should work @@ -586,7 +603,7 @@ def test_set_pin_counter_v2(self): aeskey_s = self.server_call_v2(priv_key, client, 'set_pin', pin_secret, self.new_entropy()) self.assertFalse(compare_digest(aeskey_g, aeskey_s)) # changed - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=True) + aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=True) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) From 7337ddf08b5605091ebdfc0768934bd2d0965bbb Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Thu, 30 Nov 2023 13:04:10 +0000 Subject: [PATCH 9/9] bpov2: concatenate data into one json field using ascii85 encoding ie. rather than separate hex fields (in protocol v2) --- flaskserver.py | 26 +++++++++++++++----------- test/test_pinserver.py | 16 ++++++++++++++-- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/flaskserver.py b/flaskserver.py index b8ceb6f..a31017f 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -1,6 +1,7 @@ -import time -import json import os +import json +import base64 +import time from flask import Flask, request, jsonify from .server import PINServerECDH, PINServerECDHv1, PINServerECDHv2 from .pindb import PINDb @@ -51,7 +52,7 @@ def start_handshake_route(): return jsonify({'ske': ske, 'sig': sig.hex()}) - # NOTE: explicit 'hmac' fields in protocol v1 + # NOTE: explicit fields in protocol v1 def _complete_server_call_v1(pin_func, udata): ske = udata['ske'] assert 'replay_counter' not in udata @@ -79,29 +80,32 @@ def _complete_server_call_v1(pin_func, udata): return jsonify({'encrypted_key': encrypted_key.hex(), 'hmac': hmac.hex()}) - # NOTE: 'hmac' data is appened to encrypted_data in protocol v2 + # NOTE: v2 is one concatentated field, ascii85-encoded def _complete_server_call_v2(pin_func, udata): - assert 'ske' not in udata - assert len(udata['replay_counter']) == 8 - cke = bytes.fromhex(udata['cke']) - replay_counter = bytes.fromhex(udata['replay_counter']) + assert 'data' in udata + data = base64.a85decode(udata['data'].encode()) + assert len(data) > 37 # cke and counter and some encrypted payload + + cke = data[:33] + replay_counter = data[33:37] + encrypted_data = data[37:] e_ecdh_server = PINServerECDHv2(replay_counter, cke) encrypted_key = e_ecdh_server.call_with_payload( cke, - bytes.fromhex(udata['encrypted_data']), + encrypted_data, pin_func) # Expecting to return an encrypted aes-key with hmac appended assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + HMAC_SHA256_LEN # Return response - return jsonify({'encrypted_key': encrypted_key.hex()}) + return jsonify({'data': base64.a85encode(encrypted_key).decode()}) def _complete_server_call(pin_func): try: # Get request data udata = json.loads(request.data) - if 'replay_counter' in udata: + if 'data' in udata: return _complete_server_call_v2(pin_func, udata) return _complete_server_call_v1(pin_func, udata) diff --git a/test/test_pinserver.py b/test/test_pinserver.py index d7289ea..05c4803 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -2,6 +2,7 @@ import os import json +import base64 import time from multiprocessing import Process from hmac import compare_digest @@ -152,6 +153,7 @@ def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy, # Make the server call to get/set the pin - returns the decrypted response # NOTE: signature covers replay counter # NOTE: implicit hmac + # NOTE: all fields concatenated into one, and ascii85 encoded def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, fn_perturb_request=None): assert isinstance(client, PINClientECDHv2) @@ -166,6 +168,8 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, encrypted = client.encrypt_request_payload(payload) # Make call and parse response + # NOTE: we temporarily use the v1-like hex struct for the test perturbation + # function (ie. to mess with the data before posting) # Includes 'replay_counter' but not 'ske' or 'hmac' urldata = {'cke': cke.hex(), 'encrypted_data': encrypted.hex(), @@ -175,8 +179,16 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy, if fn_perturb_request: urldata = fn_perturb_request(urldata) + # v2 concatenates all the fields into one and uses ascii85-encoding + cke = bytes.fromhex(urldata.get('cke', '')) + replay_counter = bytes.fromhex(urldata.get('replay_counter', '')) + encrypted = bytes.fromhex(urldata.get('encrypted_data', '')) + payload = cke + replay_counter + encrypted + data = base64.a85encode(payload).decode() + urldata = {'data': data} + response = self.post(endpoint, urldata) - encrypted = bytes.fromhex(response['encrypted_key']) + encrypted = base64.a85decode(response['data'].encode()) # Return decrypted payload return client.decrypt_response_payload(encrypted) @@ -409,7 +421,7 @@ def _test_rejects_on_bad_json_impl(self, use_v2_protocol): def _short(field): def _fn(d): - d[field] = d[field][:-1] + d[field] = d[field][:-2] return d return _fn