-
Notifications
You must be signed in to change notification settings - Fork 332
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
tim
committed
Mar 5, 2024
1 parent
fd2525c
commit d3111d0
Showing
1 changed file
with
41 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,51 @@ | ||
from datetime import datetime, timedelta | ||
from typing import List, Optional | ||
from datetime import timedelta | ||
|
||
import httpx | ||
import pytest | ||
from fastapi import FastAPI | ||
from pydantic import SecretStr | ||
|
||
from fastui.auth.google import ( | ||
EXCHANGE_CACHE, | ||
AuthError, | ||
GoogleAuthProvider, | ||
GoogleExchange, | ||
GoogleExchangeError, | ||
GoogleUser, | ||
EXCHANGE_CACHE, | ||
) | ||
from httpx import Request, Response | ||
from pydantic import SecretStr | ||
|
||
|
||
class MockTransport(httpx.AsyncBaseTransport): | ||
async def handle_async_request(self, request: Request) -> Response: | ||
url = str(request.url) | ||
method = request.method | ||
|
||
if url == "https://oauth2.googleapis.com/token" and method == "POST": | ||
if url == 'https://oauth2.googleapis.com/token' and method == 'POST': | ||
print(request.read()) | ||
if b"code=bad_code" in request.read(): | ||
return Response(200, json={"error": "bad code"}) | ||
if b'code=bad_code' in request.read(): | ||
return Response(200, json={'error': 'bad code'}) | ||
|
||
json_data = { | ||
"access_token": "test_access_token", | ||
"token_type": "Bearer", | ||
"expires_in": 3600, | ||
"refresh_token": "test_refresh_token", | ||
"scope": "email profile", | ||
'access_token': 'test_access_token', | ||
'token_type': 'Bearer', | ||
'expires_in': 3600, | ||
'refresh_token': 'test_refresh_token', | ||
'scope': 'email profile', | ||
} | ||
return Response(200, json=json_data) | ||
|
||
elif url == "https://www.googleapis.com/oauth2/v1/userinfo" and method == "GET": | ||
elif url == 'https://www.googleapis.com/oauth2/v1/userinfo' and method == 'GET': | ||
json_data = { | ||
"id": "12345", | ||
"email": "[email protected]", | ||
"verified_email": True, | ||
"name": "Test User", | ||
"given_name": "Test", | ||
"family_name": "User", | ||
"picture": "https://example.com/avatar.png", | ||
"locale": "en", | ||
'id': '12345', | ||
'email': '[email protected]', | ||
'verified_email': True, | ||
'name': 'Test User', | ||
'given_name': 'Test', | ||
'family_name': 'User', | ||
'picture': 'https://example.com/avatar.png', | ||
'locale': 'en', | ||
} | ||
return Response(200, json=json_data) | ||
|
||
return Response(404, json={"error": "not found"}) | ||
return Response(404, json={'error': 'not found'}) | ||
|
||
|
||
@pytest.fixture | ||
|
@@ -63,58 +59,58 @@ async def mock_httpx_client() -> httpx.AsyncClient: | |
async def google_auth_provider(mock_httpx_client: httpx.AsyncClient): | ||
return GoogleAuthProvider( | ||
httpx_client=mock_httpx_client, | ||
google_client_id="google_client_id", | ||
google_client_secret=SecretStr("google_client_secret"), | ||
redirect_uri="https://example.com/callback", | ||
scopes=["email", "profile"], | ||
google_client_id='google_client_id', | ||
google_client_secret=SecretStr('google_client_secret'), | ||
redirect_uri='https://example.com/callback', | ||
scopes=['email', 'profile'], | ||
state_provider=False, | ||
exchange_cache_age=timedelta(minutes=5), | ||
) | ||
|
||
|
||
async def test_authorization_url(google_auth_provider: GoogleAuthProvider): | ||
url = await google_auth_provider.authorization_url() | ||
assert url.startswith("https://accounts.google.com/o/oauth2/v2/auth?") | ||
assert url.startswith('https://accounts.google.com/o/oauth2/v2/auth?') | ||
|
||
|
||
async def test_exchange_code_success(google_auth_provider: GoogleAuthProvider): | ||
exchange = await google_auth_provider.exchange_code("good_code") | ||
exchange = await google_auth_provider.exchange_code('good_code') | ||
assert isinstance(exchange, GoogleExchange) | ||
assert exchange.access_token == "test_access_token" | ||
assert exchange.token_type == "Bearer" | ||
assert exchange.scope == "email profile" | ||
assert exchange.refresh_token == "test_refresh_token" | ||
assert exchange.access_token == 'test_access_token' | ||
assert exchange.token_type == 'Bearer' | ||
assert exchange.scope == 'email profile' | ||
assert exchange.refresh_token == 'test_refresh_token' | ||
|
||
|
||
async def test_exchange_code_error(google_auth_provider: GoogleAuthProvider): | ||
with pytest.raises(AuthError): | ||
await google_auth_provider.exchange_code("bad_code") | ||
await google_auth_provider.exchange_code('bad_code') | ||
|
||
|
||
async def test_refresh_access_token(google_auth_provider: GoogleAuthProvider): | ||
new_token = await google_auth_provider.refresh_access_token("good_refresh_token") | ||
new_token = await google_auth_provider.refresh_access_token('good_refresh_token') | ||
assert isinstance(new_token, GoogleExchange) | ||
assert new_token.access_token == "test_access_token" | ||
assert new_token.access_token == 'test_access_token' | ||
|
||
|
||
async def test_get_google_user(google_auth_provider: GoogleAuthProvider): | ||
exchange = GoogleExchange( | ||
access_token="good_access_token", | ||
token_type="Bearer", | ||
scope="email profile", | ||
access_token='good_access_token', | ||
token_type='Bearer', | ||
scope='email profile', | ||
expires_in=3600, | ||
refresh_token="good_refresh_token", | ||
refresh_token='good_refresh_token', | ||
) | ||
user = await google_auth_provider.get_google_user(exchange) | ||
assert isinstance(user, GoogleUser) | ||
assert user.id == "12345" | ||
assert user.email == "[email protected]" | ||
assert user.id == '12345' | ||
assert user.email == '[email protected]' | ||
|
||
|
||
async def test_exchange_cache( | ||
google_auth_provider: GoogleAuthProvider, | ||
): | ||
EXCHANGE_CACHE._cache.clear() | ||
assert len(EXCHANGE_CACHE._cache) == 0 | ||
await google_auth_provider.exchange_code("good_code") | ||
await google_auth_provider.exchange_code('good_code') | ||
assert len(EXCHANGE_CACHE._cache) == 1 |