Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Apple authored and Apple committed Sep 10, 2024
1 parent cda0234 commit 9f94b5d
Showing 1 changed file with 87 additions and 4 deletions.
91 changes: 87 additions & 4 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import datetime, timedelta
from typing import Optional

import pytest

from hivemind.proto import dht_pb2
from hivemind.proto.auth_pb2 import AccessToken
from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey # updated to ed25519 in entire repo
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -16,12 +18,13 @@ class MockAuthorizer(TokenAuthorizerBase):

def __init__(self, local_private_key: Optional[RSAPrivateKey], username: str = "mock"):
super().__init__(local_private_key)

self._username = username
self._authority_public_key = None

async def get_token(self) -> AccessToken:
if MockAuthorizer._authority_private_key is None:
MockAuthorizer._authority_private_key = RSAPrivateKey() # updated to ed25519 in entire repo
MockAuthorizer._authority_private_key = RSAPrivateKey()

self._authority_public_key = MockAuthorizer._authority_private_key.get_public_key()

Expand Down Expand Up @@ -68,13 +71,93 @@ def _token_to_bytes(access_token: AccessToken) -> bytes:

@pytest.mark.asyncio
async def test_valid_request_and_response():
client_authorizer = MockAuthorizer(RSAPrivateKey()) # updated to ed25519 in entire repo
service_authorizer = MockAuthorizer(RSAPrivateKey()) # updated to ed25519 in entire repo
client_authorizer = MockAuthorizer(RSAPrivateKey())
service_authorizer = MockAuthorizer(RSAPrivateKey())

request = dht_pb2.PingRequest()
request.peer.node_id = b"ping"
await client_authorizer.sign_request(request, service_authorizer.local_public_key)
assert await service_authorizer.validate_request(request)

response = dht_pb2.PingResponse()
response.peer.node_id = b"pong"
await service_authorizer.sign_response(response, request)
assert await client_authorizer.validate_response(response, request)


@pytest.mark.asyncio
async def test_invalid_access_token():
client_authorizer = MockAuthorizer(RSAPrivateKey())
service_authorizer = MockAuthorizer(RSAPrivateKey())

request = dht_pb2.PingRequest()
request.peer.node_id = b"ping"
await client_authorizer.sign_request(request, service_authorizer.local_public_key)

# Break the access token signature
request.auth.client_access_token.signature = b"broken"

assert not await service_authorizer.validate_request(request)

response = dht_pb2.PingResponse()
response.peer.node_id = b"pong"
await service_authorizer.sign_response(response, request)

# Break the access token signature
response.auth.service_access_token.signature = b"broken"

assert not await client_authorizer.validate_response(response, request)


@pytest.mark.asyncio
async def test_invalid_signatures():
client_authorizer = MockAuthorizer(RSAPrivateKey())
service_authorizer = MockAuthorizer(RSAPrivateKey())

request = dht_pb2.PingRequest()
request.peer.node_id = b"true-ping"
await client_authorizer.sign_request(request, service_authorizer.local_public_key)

# A man-in-the-middle attacker changes the request content
request.peer.node_id = b"fake-ping"

assert not await service_authorizer.validate_request(request)

response = dht_pb2.PingResponse()
response.peer.node_id = b"true-pong"
await service_authorizer.sign_response(response, request)

# A man-in-the-middle attacker changes the response content
response.peer.node_id = b"fake-pong"

assert not await client_authorizer.validate_response(response, request)


@pytest.mark.asyncio
async def test_auth_rpc_wrapper():
class Servicer:
async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
assert request.peer.node_id == b"ping"
assert request.auth.client_access_token.username == "alice"

response = dht_pb2.PingResponse()
response.peer.node_id = b"pong"
return response

class Client:
def __init__(self, servicer: Servicer):
self._servicer = servicer

async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
return await self._servicer.rpc_increment(request)

servicer = AuthRPCWrapper(Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), "bob"))
client = AuthRPCWrapper(Client(servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), "alice"))

request = dht_pb2.PingRequest()
request.peer.node_id = b"ping"

response = await client.rpc_increment(request)

assert response.peer.node_id == b"pong"
assert response.auth.service_access_token.username == "bob"

0 comments on commit 9f94b5d

Please sign in to comment.