Skip to content

Commit

Permalink
added refresh token
Browse files Browse the repository at this point in the history
- fixed user profile bug
- added refresh token
  • Loading branch information
cbrianbet committed Dec 3, 2024
1 parent 61aae4e commit c004c33
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
41 changes: 34 additions & 7 deletions routes/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from database.user_db import get_user_db
from models.user_model import User
from schemas.user_schema import UserResponse, UserCreate
from utils.user_utils import hash_password, verify_password, create_access_token, get_current_user
from schemas.user_schema import UserResponse, UserCreate, TokenRefresh, UserLogin
from utils.user_utils import hash_password, verify_password, create_access_token, get_current_user, \
create_refresh_token, verify_token, REFRESH_SECRET_KEY

router = APIRouter()

Expand All @@ -30,16 +31,42 @@ async def register(user: UserCreate, db: Session = Depends(get_user_db)):


@router.post("/login")
async def user_login(email: str, password: str, db: Session = Depends(get_user_db)):
db_user = db.query(User).filter(User.email == email).first()
if not db_user or not verify_password(password, db_user.password):
async def user_login(user: UserLogin, db: Session = Depends(get_user_db)):
db_user = db.query(User).filter(User.email == user.email).first()
if not db_user or not verify_password(user.password, db_user.password):
raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not db_user.is_active:
raise HTTPException(status_code=400, detail="User is not active")

token = create_access_token(data={"sub": db_user.email})
access_token = create_access_token(data={"sub": db_user.email})
refresh_token = create_refresh_token(data={"sub": db_user.email})

return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer"
}


@router.post("/refresh")
def refresh_access_token(refresh_token: TokenRefresh, db: Session = Depends(get_user_db)):
# Verify the refresh token
payload = verify_token(refresh_token.refresh_token, REFRESH_SECRET_KEY)
email = payload.get("sub")
if email is None:
raise HTTPException(status_code=401, detail="Invalid refresh token")

# Check if the user exists
db_user = db.query(User).filter(User.email == email).first()
if db_user is None:
raise HTTPException(status_code=401, detail="User not found")

return {"token": token, "token_type": "bearer"}
# Generate a new access token
new_access_token = create_access_token(data={"sub": db_user.email})
return {
"access_token": new_access_token,
"token_type": "bearer"
}


@router.get('/info', response_model=UserResponse)
Expand Down
9 changes: 9 additions & 0 deletions schemas/user_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ class UserResponse(UserBase):

class Config:
orm_mode = True


class UserLogin(BaseModel):
email: str
password: str


class TokenRefresh(BaseModel):
refresh_token: str
1 change: 1 addition & 0 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Settings(BaseSettings):
STAGING_API: str
BATCH_SIZE: int
JWT_SECRET_KEY: str
REFRESH_SECRET_KEY: str
# REPORTING_DB: str
# REPORTING_USER: str
# REPORTING_PASSWORD: str
Expand Down
34 changes: 23 additions & 11 deletions utils/user_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta, timezone

from fastapi import HTTPException, status, Depends
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from passlib.context import CryptContext
from sqlalchemy.orm import Session
Expand All @@ -10,8 +11,12 @@
from settings import settings

SECRET_KEY = settings.JWT_SECRET_KEY
REFRESH_SECRET_KEY = settings.REFRESH_SECRET_KEY
ALGORITHM = 'HS256'
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # One day
ACCESS_TOKEN_EXPIRE_MINUTES = 60 # One hour
REFRESH_TOKEN_EXPIRE_DAYS = 7

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/user/login")


pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
Expand All @@ -25,26 +30,33 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)


def create_access_token(data: dict):
def create_token(data:dict, secret_key: str, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(
minutes=ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode['exp'] = expire
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
expire = datetime.now(timezone.utc) + expires_delta
to_encode["exp"] = expire
return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)


def create_access_token(data: dict):
return create_token(data, SECRET_KEY, timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))


def create_refresh_token(data: dict):
return create_token(data, REFRESH_SECRET_KEY, timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))


def verify_token(token: str):
def verify_token(token: str, secret_key: str):
try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return jwt.decode(token, secret_key, algorithms=[ALGORITHM])
except JWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)
) from e


def get_current_user(token: str = Depends(verify_token), db: Session = Depends(get_user_db)):
email = token.sub("sub")
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_user_db)):
payload = verify_token(token, SECRET_KEY)
email = payload.get("sub")
if email is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Token is invalid")
db_user = db.query(User).filter(User.email == email).first()
Expand Down

0 comments on commit c004c33

Please sign in to comment.