diff --git a/client.py b/client.py index 0fa0914..421313a 100644 --- a/client.py +++ b/client.py @@ -11,6 +11,14 @@ def __init__(self, static_server_public_key): self.static_server_public_key = static_server_public_key self.ecdh_server_public_key = None + # returns ske, cke + def get_key_exchange(self): + return self.ecdh_server_public_key, self.public_key + + +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake +class PINClientECDHv1(PINClientECDH): def handshake(self, e_ecdh_server_public_key, static_server_signature): ec_sig_verify( self.static_server_public_key, @@ -24,10 +32,6 @@ def handshake(self, e_ecdh_server_public_key, static_server_signature): # Cache the shared secrets self.generate_shared_secrets(e_ecdh_server_public_key) - # returns ske, cke - def get_key_exchange(self): - return self.ecdh_server_public_key, self.public_key - # Encrypt/sign/hmac the payload (ie. the pin secret) def encrypt_request_payload(self, payload): assert self.ecdh_server_public_key @@ -48,22 +52,25 @@ def decrypt_response_payload(self, encrypted, hmac): return decrypt(self.response_encryption_key, encrypted) +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class PINClientECDHv2(PINClientECDH): def __init__(self, static_server_public_key, replay_counter): super().__init__(static_server_public_key) + + assert len(replay_counter) == 4 self.replay_counter = replay_counter - tweak = sha256(hmac_sha256(self.public_key, self.replay_counter)) # Derive and store the ecdh server public key (ske) + tweak = sha256(hmac_sha256(self.public_key, self.replay_counter)) self.ecdh_server_public_key = ec_public_key_bip341_tweak( self.static_server_public_key, tweak, 0) - # Cache the shared secrets - self.generate_shared_secrets(self.ecdh_server_public_key) - - # Encrypt/sign/hmac the payload (ie. the pin secret) def encrypt_request_payload(self, payload): - encrypted = encrypt(self.request_encryption_key, payload) - hmac = hmac_sha256(self.request_hmac_key, self.public_key + self.replay_counter + encrypted) - return encrypted, hmac + return self.encrypt_with_ecdh(self.ecdh_server_public_key, self.LABEL_ORACLE_REQUEST, + payload) + + def decrypt_response_payload(self, encrypted): + return self.decrypt_with_ecdh(self.ecdh_server_public_key, self.LABEL_ORACLE_RESPONSE, + encrypted) diff --git a/flaskserver.py b/flaskserver.py index 7357462..aab4129 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -2,10 +2,10 @@ import json import os from flask import Flask, request, jsonify -from .server import PINServerECDH, PINServerECDHv2 +from .server import PINServerECDH, PINServerECDHv1, PINServerECDHv2 from .pindb import PINDb from wallycore import hex_from_bytes, hex_to_bytes, AES_KEY_LEN_256, \ - AES_BLOCK_LEN + AES_BLOCK_LEN, HMAC_SHA256_LEN from dotenv import load_dotenv b2h = hex_from_bytes @@ -43,7 +43,7 @@ def start_handshake_route(): app.logger.debug('Number of sessions {}'.format(len(sessions))) # Create a new ephemeral server/session and get its signed pubkey - e_ecdh_server = PINServerECDH() + e_ecdh_server = PINServerECDHv1() pubkey, sig = e_ecdh_server.get_signed_public_key() ske = b2h(pubkey) @@ -55,6 +55,7 @@ def start_handshake_route(): return jsonify({'ske': ske, 'sig': b2h(sig)}) + # NOTE: explicit 'hmac' fields in protocol v1 def _complete_server_call_v1(pin_func, udata): ske = udata['ske'] assert 'replay_counter' not in udata @@ -70,8 +71,9 @@ def _complete_server_call_v1(pin_func, udata): h2b(udata['hmac_encrypted_data']), pin_func) - # Expecting to return an encrypted aes-key + # Expecting to return an encrypted aes-key with separate hmac assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + assert len(hmac) == HMAC_SHA256_LEN # Cleanup session del sessions[ske] @@ -81,24 +83,23 @@ def _complete_server_call_v1(pin_func, udata): return jsonify({'encrypted_key': b2h(encrypted_key), 'hmac': b2h(hmac)}) + # NOTE: 'hmac' data is appened to encrypted_data in protocol v2 def _complete_server_call_v2(pin_func, udata): assert 'ske' not in udata assert len(udata['replay_counter']) == 8 cke = h2b(udata['cke']) replay_counter = h2b(udata['replay_counter']) e_ecdh_server = PINServerECDHv2(replay_counter, cke) - encrypted_key, hmac = e_ecdh_server.call_with_payload( + encrypted_key = e_ecdh_server.call_with_payload( cke, h2b(udata['encrypted_data']), - h2b(udata['hmac_encrypted_data']), pin_func) - # Expecting to return an encrypted aes-key - assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + # Expecting to return an encrypted aes-key with hmac appended + assert len(encrypted_key) == AES_KEY_LEN_256 + (2*AES_BLOCK_LEN) + HMAC_SHA256_LEN # Return response - return jsonify({'encrypted_key': b2h(encrypted_key), - 'hmac': b2h(hmac)}) + return jsonify({'encrypted_key': b2h(encrypted_key)}) def _complete_server_call(pin_func): try: diff --git a/lib.py b/lib.py index b764e13..29dd328 100644 --- a/lib.py +++ b/lib.py @@ -2,7 +2,7 @@ from wallycore import AES_BLOCK_LEN, AES_FLAG_DECRYPT, AES_FLAG_ENCRYPT, \ aes_cbc, ec_private_key_verify, ec_public_key_from_private_key, ecdh, \ - hmac_sha256 + hmac_sha256, aes_cbc_with_ecdh_key def encrypt(aes_key, plaintext): @@ -19,6 +19,10 @@ def decrypt(aes_key, encrypted): class E_ECDH(object): + # Labels used to derived child keys for aes_cbc_with_ecdh_key() call + LABEL_ORACLE_REQUEST = 'blind_oracle_request'.encode() + LABEL_ORACLE_RESPONSE = 'blind_oracle_response'.encode() + @classmethod def _generate_private_key(cls): counter = 4 @@ -50,3 +54,12 @@ def _derived(val): self.request_hmac_key = _derived(1) self.response_encryption_key = _derived(2) self.response_hmac_key = _derived(3) + + def decrypt_with_ecdh(self, public_key, label, encrypted): + return aes_cbc_with_ecdh_key(self.private_key, None, encrypted, public_key, label, + AES_FLAG_DECRYPT) + + def encrypt_with_ecdh(self, public_key, label, plaintext): + iv = os.urandom(AES_BLOCK_LEN) + return aes_cbc_with_ecdh_key(self.private_key, iv, plaintext, public_key, label, + AES_FLAG_ENCRYPT) diff --git a/server.py b/server.py index 990fab7..53a837f 100644 --- a/server.py +++ b/server.py @@ -58,6 +58,13 @@ def __init__(self): super().__init__() self.time_started = int(time.time()) + +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake +class PINServerECDHv1(PINServerECDH): + def __init__(self): + super().__init__() + def get_signed_public_key(self): return self.public_key, self._sign_with_static_key(self.public_key) @@ -89,6 +96,8 @@ def call_with_payload(self, cke, encrypted, hmac, func): return encrypted, hmac +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class PINServerECDHv2(PINServerECDH): @classmethod @@ -107,24 +116,16 @@ def __init__(self, replay_counter, cke): self.replay_counter = replay_counter self.private_key, self.public_key = self.generate_ec_key_pair(replay_counter, cke) - # Decrypt the received payload (ie. aes-key) - def decrypt_request_payload(self, cke, encrypted, hmac): - # Verify hmac received - hmac_calculated = hmac_sha256(self.request_hmac_key, cke + self.replay_counter + encrypted) - assert compare_digest(hmac, hmac_calculated) + def decrypt_request_payload(self, cke, encrypted): + return self.decrypt_with_ecdh(cke, self.LABEL_ORACLE_REQUEST, encrypted) - # Return decrypted data - return decrypt(self.request_encryption_key, encrypted) + def encrypt_response_payload(self, cke, payload): + return self.encrypt_with_ecdh(cke, self.LABEL_ORACLE_RESPONSE, payload) # Function to deal with wrapper ecdh encryption. # Calls passed function with unwrapped payload, and wraps response before # returning. Separates payload handler func from wrapper encryption. - def call_with_payload(self, cke, encrypted, hmac, func): - self.generate_shared_secrets(cke) - payload = self.decrypt_request_payload(cke, encrypted, hmac) - - # Call the passed function with the decrypted payload + def call_with_payload(self, cke, encrypted, func): + payload = self.decrypt_request_payload(cke, encrypted) response = func(cke, payload, self._get_aes_pin_data_key(), self.replay_counter) - - encrypted, hmac = self.encrypt_response_payload(response) - return encrypted, hmac + return self.encrypt_response_payload(cke, response) diff --git a/test/test_ecdh_v1.py b/test/test_ecdh_v1.py index 3f26cae..55b912d 100644 --- a/test/test_ecdh_v1.py +++ b/test/test_ecdh_v1.py @@ -2,25 +2,26 @@ import os -from ..client import PINClientECDH -from ..server import PINServerECDH +from ..client import PINClientECDHv1 +from ..server import PINServerECDHv1 # Tests ECDHv1 wrapper without any reference to the pin/aes-key paylod stuff. # Just testing the ECDH envelope/encryption in isolation, with misc bytearray() # payloads (ie. any old str.encode()). Tests client/server handshake/pairing. -# NOTE: protocol v1: key-exchange handshake required +# NOTE: protocol v1: +# Explicit 'hmac' fields, separate derived keys, and key-exchange handshake class ECDHv1Test(unittest.TestCase): @classmethod def setUpClass(cls): # The server public key the client would know - with open(PINServerECDH.STATIC_SERVER_PUBLIC_KEY_FILE, 'rb') as f: + with open(PINServerECDHv1.STATIC_SERVER_PUBLIC_KEY_FILE, 'rb') as f: cls.static_server_public_key = f.read() # Make a new client and initialise with server handshake def new_client_handshake(self, ske, sig): - client = PINClientECDH(self.static_server_public_key) + client = PINClientECDHv1(self.static_server_public_key) client.handshake(ske, sig) ske1, cke = client.get_key_exchange() self.assertEqual(ske, ske1) @@ -30,7 +31,7 @@ def _test_client_server_impl(self, client_request, server_response): # A new server is created, which signs its newly-created ske with the # static key (so the client can validate that the ske is genuine). - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() # They get sent to the client (eg. over network) which then validates @@ -69,7 +70,7 @@ def test_client_server_happypath(self): def test_call_with_payload(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -96,7 +97,7 @@ def _func(client_key, payload, aes_pin_data_key): def test_multiple_calls(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -118,7 +119,7 @@ def test_multiple_calls(self): def test_multiple_clients(self): # A new server and several clients - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() # Server can persist and handle multiple calls provided each one is @@ -141,7 +142,7 @@ def test_multiple_clients(self): def test_bad_request_hmac_throws(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) @@ -168,7 +169,7 @@ def _func(client_key, payload, aes_pin_data_key): def test_bad_response_hmac_throws(self): # A new server and client - server = PINServerECDH() + server = PINServerECDHv1() ske, sig = server.get_signed_public_key() cke, client = self.new_client_handshake(ske, sig) diff --git a/test/test_ecdh_v2.py b/test/test_ecdh_v2.py index f1bc1aa..3ae85af 100644 --- a/test/test_ecdh_v2.py +++ b/test/test_ecdh_v2.py @@ -9,7 +9,8 @@ # Tests ECDHv2 wrapper without any reference to the pin/aes-key paylod stuff. # Just testing the ECDH envelope/encryption in isolation, with misc bytearray() # payloads (ie. any old str.encode()). Tests client/server handshake/pairing. -# NOTE: protocol v2: no key-exchange handshake required +# NOTE: protocol v2: +# 'hmac' fields and derived keys implicit, and no key-exchange handshake required class ECDHv2Test(unittest.TestCase): REPLAY_COUNTER = bytes([0x00, 0x00, 0x00, 0x2a]) # arbitrary @@ -31,8 +32,8 @@ def _test_client_server_impl(self, client_request, server_response): # server static key. cke, client = self.new_client_handshake() - # The client can then encrypt a payload (and hmac) for the server - encrypted, hmac = client.encrypt_request_payload(client_request) + # The client can then encrypt a payload (with implicit hmac) for the server + encrypted = client.encrypt_request_payload(client_request) self.assertNotEqual(client_request, encrypted) # A new server is created when passed the client replay-counter. @@ -40,18 +41,16 @@ def _test_client_server_impl(self, client_request, server_response): # NOTE: the server deduced private key should be the counterpart to the # client-deduced public key - if so the payload decryption should yield # the original cleartext request message. - # Note: this validates hmac before it decrypts/returns server = PINServerECDHv2(client.replay_counter, cke) - server.generate_shared_secrets(cke) - received = server.decrypt_request_payload(cke, encrypted, hmac) + received = server.decrypt_request_payload(cke, encrypted) self.assertEqual(received, client_request) # The server can then send an encrypted response to the client - encrypted, hmac = server.encrypt_response_payload(server_response) + encrypted = server.encrypt_response_payload(cke, server_response) # The client can decrypt the response. - # Note: this validates hmac before it decrypts/returns - received = client.decrypt_response_payload(encrypted, hmac) + # Note: this validates the implicit hmac before it decrypts/returns + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_client_server_happypath(self): @@ -67,7 +66,7 @@ def test_call_with_payload(self): # Client sends message to server cke, client = self.new_client_handshake() client_request = "Hello - test 123".encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) self.assertNotEqual(client_request, encrypted) # Test server un-/re-wrapping function - this handles all the ecdh @@ -82,10 +81,10 @@ def _func(client_key, payload, aes_pin_data_key, replay_counter): self.assertEqual(replay_counter, client.replay_counter) return server_response - encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + encrypted = server.call_with_payload(cke, encrypted, _func) # Assert that is what the client expects - received = client.decrypt_response_payload(encrypted, hmac) + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_multiple_calls(self): @@ -95,18 +94,17 @@ def test_multiple_calls(self): # Server can handle multiple calls from the client with same secrets # (But that would use same cke and counter which is ofc not ideal/recommended.) - server.generate_shared_secrets(cke) for i in range(5): client_request = 'request-{}'.format(i).encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) - received = server.decrypt_request_payload(cke, encrypted, hmac) + received = server.decrypt_request_payload(cke, encrypted) self.assertEqual(received, client_request) server_response = 'response-{}'.format(i).encode() - encrypted, hmac = server.encrypt_response_payload(server_response) + encrypted = server.encrypt_response_payload(cke, server_response) - received = client.decrypt_response_payload(encrypted, hmac) + received = client.decrypt_response_payload(encrypted) self.assertEqual(received, server_response) def test_bad_request_hmac_throws(self): @@ -116,24 +114,25 @@ def test_bad_request_hmac_throws(self): # Encrypt message client_request = 'bad-hmac-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) - # Break hmac + # Break hmac at tail of encrypted bytes bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) - self.assertNotEqual(hmac, bad_hmac) + bad_encrypted = encrypted[:-32] + bad_hmac + self.assertEqual(encrypted[:-32], bad_encrypted[:-32]) + self.assertNotEqual(encrypted[-32:], bad_encrypted[-32:]) # Ensure decrypt_request() throws - server.generate_shared_secrets(cke) - server.decrypt_request_payload(cke, encrypted, hmac) # no error - with self.assertRaises(AssertionError) as cm: - server.decrypt_request_payload(cke, encrypted, bad_hmac) # error + server.decrypt_request_payload(cke, encrypted) # no error + with self.assertRaises(ValueError) as cm: + server.decrypt_request_payload(cke, bad_encrypted) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key, replay_counter): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: - server.call_with_payload(cke, encrypted, bad_hmac, _func) + with self.assertRaises(ValueError) as cm: + server.call_with_payload(cke, bad_encrypted, _func) def test_bad_response_hmac_throws(self): # A new server and client @@ -142,7 +141,7 @@ def test_bad_response_hmac_throws(self): # Encrypt message client_request = 'bad-hmac-response-request'.encode() - encrypted, hmac = client.encrypt_request_payload(client_request) + encrypted = client.encrypt_request_payload(client_request) def _func(client_key, payload, pin_data_aes_key, replay_counter): self.assertEqual(client_key, cke) @@ -150,15 +149,17 @@ def _func(client_key, payload, pin_data_aes_key, replay_counter): self.assertEqual(replay_counter, client.replay_counter) return 'bad-hmac-response'.encode() - encrypted, hmac = server.call_with_payload(cke, encrypted, hmac, _func) + encrypted = server.call_with_payload(cke, encrypted, _func) # Break hmac bad_hmac = bytearray(b+1 if b < 255 else b-1 for b in encrypted[-32:]) - self.assertNotEqual(hmac, bad_hmac) + bad_encrypted = encrypted[:-32] + bad_hmac + self.assertEqual(encrypted[:-32], bad_encrypted[:-32]) + self.assertNotEqual(encrypted[-32:], bad_encrypted[-32:]) - client.decrypt_response_payload(encrypted, hmac) # No error - with self.assertRaises(AssertionError) as cm: - client.decrypt_response_payload(encrypted, bad_hmac) # error + client.decrypt_response_payload(encrypted) # No error + with self.assertRaises(ValueError) as cm: + client.decrypt_response_payload(bad_encrypted) # error if __name__ == '__main__': diff --git a/test/test_pinserver.py b/test/test_pinserver.py index 7df405f..bda1b05 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -7,7 +7,7 @@ from hmac import compare_digest import requests -from ..client import PINClientECDH, PINClientECDH, PINClientECDHv2 +from ..client import PINClientECDH, PINClientECDHv1, PINClientECDHv2 from ..server import PINServerECDH from ..pindb import PINDb @@ -100,14 +100,14 @@ def tearDownClass(cls): # Start the client/server key-exchange handshake def start_handshake_v1(self, client): - assert isinstance(client, PINClientECDH) + assert isinstance(client, PINClientECDHv1) handshake = self.post('start_handshake') client.handshake(h2b(handshake['ske']), h2b(handshake['sig'])) return client # Make a new ephemeral client and initialise with server handshake def new_client_v1(self): - client = PINClientECDH(self.static_server_public_key) + client = PINClientECDHv1(self.static_server_public_key) return self.start_handshake_v1(client) def new_client_v2(self, reset_replay_counter=False): @@ -122,7 +122,7 @@ def new_client_v2(self, reset_replay_counter=False): # Make the server call to get/set the pin - returns the decrypted response # NOTE: explicit hmac fields def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy): - assert isinstance(client, PINClientECDH) + assert isinstance(client, PINClientECDHv1) # Make and encrypt the payload (ie. pin secret) ske, cke = client.get_key_exchange() @@ -148,6 +148,7 @@ def server_call_v1(self, private_key, client, endpoint, pin_secret, entropy): # Make the server call to get/set the pin - returns the decrypted response # NOTE: signature covers replay counter + # NOTE: implicit hmac def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy): assert isinstance(client, PINClientECDHv2) @@ -158,20 +159,18 @@ def server_call_v2(self, private_key, client, endpoint, pin_secret, entropy): EC_FLAG_ECDSA | EC_FLAG_RECOVERABLE) payload = pin_secret + entropy + sig - encrypted, hmac = client.encrypt_request_payload(payload) + encrypted = client.encrypt_request_payload(payload) # Make call and parse response - # Includes 'replay_counter' but not 'ske' + # Includes 'replay_counter' but not 'ske' or 'hmac' urldata = {'cke': b2h(cke), - 'replay_counter': b2h(client.replay_counter), 'encrypted_data': b2h(encrypted), - 'hmac_encrypted_data': b2h(hmac)} + 'replay_counter': b2h(client.replay_counter)} response = self.post(endpoint, urldata) encrypted = h2b(response['encrypted_key']) - hmac = h2b(response['hmac']) # Return decrypted payload - return client.decrypt_response_payload(encrypted, hmac) + return client.decrypt_response_payload(encrypted) def make_server_call(self, private_key, endpoint, pin_secret, entropy, use_v2_protocol): if use_v2_protocol: @@ -222,7 +221,6 @@ def _test_set_and_get_pin_impl(self, use_v2_protocol): # Get key with a new client, with the correct pin secret (new entropy) for attempt in range(3): - aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) aeskey_g = self.get_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol) self.assertTrue(compare_digest(aeskey_g, aeskey_s))