diff --git a/tests/test_auth.py b/tests/test_auth.py index e04419061..f6a078a28 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,10 +1,12 @@ from datetime import datetime, timedelta from typing import Optional + import pytest + from hivemind.proto import dht_pb2 from hivemind.proto.auth_pb2 import AccessToken from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase -from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey # updated to ed25519 in entire repo +from hivemind.utils.crypto import RSAPrivateKey from hivemind.utils.logging import get_logger logger = get_logger(__name__) @@ -16,12 +18,13 @@ class MockAuthorizer(TokenAuthorizerBase): def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str = "mock"): super().__init__(local_private_key) + self._username = username self._authority_public_key = None async def get_token(self) -> AccessToken: if MockAuthorizer._authority_private_key is None: - MockAuthorizer._authority_private_key = RSAPrivateKey() # updated to ed25519 in entire repo + MockAuthorizer._authority_private_key = RSAPrivateKey() self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key() @@ -68,13 +71,93 @@ def _token_to_bytes(access_token: AccessToken) -> bytes: @pytest.mark.asyncio async def test_valid_request_and_response(): - client_authorizer = MockAuthorizer(RSAPrivateKey()) # updated to ed25519 in entire repo - service_authorizer = MockAuthorizer(RSAPrivateKey()) # updated to ed25519 in entire repo + client_authorizer = MockAuthorizer(RSAPrivateKey()) + service_authorizer = MockAuthorizer(RSAPrivateKey()) + request = dht_pb2.PingRequest() request.peer.node_id = b"ping" await client_authorizer.sign_request(request, service_authorizer.local_public_key) assert await service_authorizer.validate_request(request) + response = dht_pb2.PingResponse() response.peer.node_id = b"pong" await service_authorizer.sign_response(response, request) assert await client_authorizer.validate_response(response, request) + + +@pytest.mark.asyncio +async def test_invalid_access_token(): + client_authorizer = MockAuthorizer(RSAPrivateKey()) + service_authorizer = MockAuthorizer(RSAPrivateKey()) + + request = dht_pb2.PingRequest() + request.peer.node_id = b"ping" + await client_authorizer.sign_request(request, service_authorizer.local_public_key) + + # Break the access token signature + request.auth.client_access_token.signature = b"broken" + + assert not await service_authorizer.validate_request(request) + + response = dht_pb2.PingResponse() + response.peer.node_id = b"pong" + await service_authorizer.sign_response(response, request) + + # Break the access token signature + response.auth.service_access_token.signature = b"broken" + + assert not await client_authorizer.validate_response(response, request) + + +@pytest.mark.asyncio +async def test_invalid_signatures(): + client_authorizer = MockAuthorizer(RSAPrivateKey()) + service_authorizer = MockAuthorizer(RSAPrivateKey()) + + request = dht_pb2.PingRequest() + request.peer.node_id = b"true-ping" + await client_authorizer.sign_request(request, service_authorizer.local_public_key) + + # A man-in-the-middle attacker changes the request content + request.peer.node_id = b"fake-ping" + + assert not await service_authorizer.validate_request(request) + + response = dht_pb2.PingResponse() + response.peer.node_id = b"true-pong" + await service_authorizer.sign_response(response, request) + + # A man-in-the-middle attacker changes the response content + response.peer.node_id = b"fake-pong" + + assert not await client_authorizer.validate_response(response, request) + + +@pytest.mark.asyncio +async def test_auth_rpc_wrapper(): + class Servicer: + async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse: + assert request.peer.node_id == b"ping" + assert request.auth.client_access_token.username == "alice" + + response = dht_pb2.PingResponse() + response.peer.node_id = b"pong" + return response + + class Client: + def __init__(self, servicer: Servicer): + self._servicer = servicer + + async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse: + return await self._servicer.rpc_increment(request) + + servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), "bob")) + client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), "alice")) + + request = dht_pb2.PingRequest() + request.peer.node_id = b"ping" + + response = await client.rpc_increment(request) + + assert response.peer.node_id == b"pong" + assert response.auth.service_access_token.username == "bob" \ No newline at end of file