diff --git a/Dockerfile b/Dockerfile index 279641c..c10ec4d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,7 @@ # Base image FROM python:3.9-alpine -# Set environment variables directly in the Dockerfile -ENV db_username=postgres -ENV db_password=your_postgres_password -ENV SECRET_KEY=your_secret_key -ENV DB_HOST=localhost -ENV DB_PORT=5432 -ENV OAUTHDB=OAUTHDB -ENV ENVIRONMENT=production + # Set the working directory WORKDIR /app @@ -44,9 +37,6 @@ COPY nginx.conf /etc/nginx/nginx.conf COPY certs/server.crt /etc/ssl/certs/server.crt COPY certs/server.key /etc/ssl/private/server.key -# Copy init.sh -COPY init.sh /init.sh -RUN chmod +x /init.sh # Copy supervisord.conf COPY supervisord.conf /etc/supervisord.conf @@ -57,13 +47,5 @@ RUN mkdir -p /var/log/postgresql && mkdir -p /var/log/fastapi # Expose port 443 for HTTPS EXPOSE 443 -# Set environment variables again (if needed) -ENV db_username=postgres -ENV db_password=your_postgres_password -ENV SECRET_KEY=your_secret_key -ENV DB_HOST=localhost -ENV DB_PORT=5432 -ENV OAUTHDB=OAUTHDB - # Entrypoint to initialize PostgreSQL and start Supervisor -ENTRYPOINT ["/bin/bash", "-c", "/init.sh && supervisord -n -c /etc/supervisord.conf"] +ENTRYPOINT ["/bin/bash", "-c", "supervisord -n -c /etc/supervisord.conf"] diff --git a/db_helper.py b/db_helper.py index aa88345..b587cba 100644 --- a/db_helper.py +++ b/db_helper.py @@ -1,12 +1,13 @@ # db_helper.py import logging - import asyncpg from datetime import datetime +from typing import Optional, List from credential_manager import CredentialManager +from models import UserCreate, User, UserUpdate, UserInDB, OAuth2Client, OAuth2ClientCreate, OAuth2ClientUpdate, \ + OAuth2AuthorizationCode # Configure logging to output to stdout - logger = logging.getLogger(__name__) class DBHelper: @@ -191,7 +192,15 @@ async def close_pool(self): logger.info("Database connection pool closed.") # User methods - async def add_user(self, email: str, hashed_password: str): + async def add_user(self, user_in: UserCreate): + """ + Add a new user to the database. + + :param user_in: UserCreate object containing email and password + """ + hashed_password = user_in.password # Password should already be hashed before calling this method + email = user_in.email + async with self.pool.acquire() as conn: try: await conn.execute(""" @@ -200,45 +209,206 @@ async def add_user(self, email: str, hashed_password: str): logger.info(f"Added new user with email: {email}") except asyncpg.exceptions.UniqueViolationError: logger.warning(f"User with email '{email}' already exists.") + raise except Exception as e: logger.error(f"Error adding user '{email}': {e}") raise - async def get_user_by_email(self, email: str): + async def get_user_by_email(self, email: str) -> Optional[UserInDB]: + query = "SELECT id, email, hashed_password FROM users WHERE email = $1" + record = await self.pool.fetchrow(query, email) + if record: + return UserInDB(**record) + return None + + async def get_user_by_id(self, user_id: int) -> Optional[User]: + """ + Retrieve a user by ID. + + :param user_id: ID of the user + :return: User object or None if not found + """ + async with self.pool.acquire() as conn: + user_record = await conn.fetchrow("SELECT id, email, created_at FROM users WHERE id = $1", user_id) + logger.info(f"Fetched user by ID '{user_id}': {'Found' if user_record else 'Not Found'}") + if user_record: + return User(**dict(user_record)) + return None + + async def get_all_users(self) -> List[User]: + """ + Retrieve all users. + + :return: List of User objects + """ async with self.pool.acquire() as conn: - user = await conn.fetchrow("SELECT * FROM users WHERE email = $1", email) - logger.info(f"Fetched user by email '{email}': {'Found' if user else 'Not Found'}") - return dict(user) if user else None + user_records = await conn.fetch("SELECT id, email, created_at FROM users") + users = [User(**dict(user)) for user in user_records] + logger.info(f"Fetched all users. Total: {len(users)}") + return users + + async def update_user(self, user_id: int, user_in: UserUpdate): + """ + Update a user's information. + + :param user_id: ID of the user to update + :param user_in: UserUpdate object containing updated fields + """ + async with self.pool.acquire() as conn: + fields = [] + values = [] + idx = 1 + if user_in.email is not None: + fields.append(f"email = ${idx}") + values.append(user_in.email) + idx += 1 + if user_in.password is not None: + fields.append(f"hashed_password = ${idx}") + values.append(user_in.password) # Password should already be hashed before calling this method + idx += 1 + if not fields: + logger.info(f"No fields to update for user ID '{user_id}'") + return + values.append(user_id) + query = f"UPDATE users SET {', '.join(fields)} WHERE id = ${idx}" + try: + await conn.execute(query, *values) + logger.info(f"Updated user ID '{user_id}'") + except Exception as e: + logger.error(f"Error updating user ID '{user_id}': {e}") + raise + + # db_helper.py + + async def delete_user(self, user_id: int): + """ + Delete a user by ID. - async def get_user_by_id(self, user_id: int): + :param user_id: ID of the user to delete + """ async with self.pool.acquire() as conn: - user = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id) - logger.info(f"Fetched user by ID '{user_id}': {'Found' if user else 'Not Found'}") - return dict(user) if user else None + try: + # Delete authorization codes related to the user + await conn.execute("DELETE FROM oauth2_authorization_codes WHERE user_id = $1", user_id) + + # Delete refresh tokens related to the user + await conn.execute("DELETE FROM refresh_tokens WHERE user_id = $1", user_id) + + # Delete roles associated with the user + await conn.execute("DELETE FROM UserRoles WHERE user_id = $1", user_id) + + # Delete the user + await conn.execute("DELETE FROM users WHERE id = $1", user_id) + + logger.info(f"Deleted user ID '{user_id}'") + except Exception as e: + logger.error(f"Error deleting user ID '{user_id}': {e}") + raise # Client methods - async def get_client_by_id(self, client_id: str): + async def get_client_by_id(self, client_id: str) -> Optional[OAuth2Client]: async with self.pool.acquire() as conn: - client = await conn.fetchrow("SELECT * FROM oauth2_clients WHERE client_id = $1", client_id) - logger.info(f"Fetched OAuth2 client by ID '{client_id}': {'Found' if client else 'Not Found'}") - return dict(client) if client else None + client_record = await conn.fetchrow("SELECT * FROM oauth2_clients WHERE client_id = $1", client_id) + logger.info(f"Fetched OAuth2 client by ID '{client_id}': {'Found' if client_record else 'Not Found'}") + logger.info(f"Fetched client: {client_record}") + if client_record: + return OAuth2Client( + client_id=client_record['client_id'], + client_secret=client_record['client_secret'], + redirect_uris=client_record['redirect_uris'].split(","), + created_at=client_record['created_at'] + ) + return None - async def add_client(self, client_id, client_secret, redirect_uris): + async def get_all_clients(self) -> List[OAuth2Client]: + """ + Retrieve all OAuth2 clients. + + :return: List of OAuth2Client objects + """ + async with self.pool.acquire() as conn: + client_records = await conn.fetch("SELECT * FROM oauth2_clients") + clients = [] + for record in client_records: + clients.append(OAuth2Client( + client_id=record['client_id'], + client_secret=record['client_secret'], + redirect_uris=record['redirect_uris'].split(","), + created_at=record['created_at'] + )) + logger.info(f"Fetched all clients. Total: {len(clients)}") + return clients + + async def add_client(self, client_id: str, client_secret: str, redirect_uris: List[str]): async with self.pool.acquire() as conn: try: await conn.execute(''' INSERT INTO oauth2_clients (client_id, client_secret, redirect_uris) VALUES ($1, $2, $3) - ''', client_id, client_secret, redirect_uris) + ''', client_id, client_secret, ",".join(redirect_uris)) logger.info(f"Added OAuth2 client with ID: {client_id}") except asyncpg.exceptions.UniqueViolationError: logger.warning(f"OAuth2 client with ID '{client_id}' already exists.") + raise except Exception as e: logger.error(f"Error adding client '{client_id}': {e}") raise + async def update_client(self, client_id: str, client_data: OAuth2ClientUpdate): + """ + Update an OAuth2 client's information. + + :param client_id: The client_id of the client to update + :param client_data: OAuth2ClientUpdate object containing updated fields + """ + async with self.pool.acquire() as conn: + fields = [] + values = [] + idx = 1 + if client_data.client_secret is not None: + fields.append(f"client_secret = ${idx}") + values.append(client_data.client_secret) + idx += 1 + if client_data.redirect_uris is not None: + fields.append(f"redirect_uris = ${idx}") + values.append(",".join(client_data.redirect_uris)) + idx += 1 + if not fields: + logger.info(f"No fields to update for client ID '{client_id}'") + return + values.append(client_id) + query = f"UPDATE oauth2_clients SET {', '.join(fields)} WHERE client_id = ${idx}" + try: + await conn.execute(query, *values) + logger.info(f"Updated OAuth2 client with ID '{client_id}'") + except Exception as e: + logger.error(f"Error updating OAuth2 client '{client_id}': {e}") + raise + + async def delete_client(self, client_id: str): + """ + Delete an OAuth2 client by client_id. + + :param client_id: The client_id of the client to delete + """ + async with self.pool.acquire() as conn: + + try: + # Delete authorization codes related to the client + await conn.execute("DELETE FROM oauth2_authorization_codes WHERE client_id = $1", client_id) + + # Delete the client + await conn.execute("DELETE FROM oauth2_clients WHERE client_id = $1", client_id) + + logger.info(f"Deleted OAuth2 client with ID '{client_id}'") + + except Exception as e: + + logger.error(f"Error deleting OAuth2 client '{client_id}': {e}") + raise + # Authorization code methods - async def save_authorization_code(self, auth_code): + async def save_authorization_code(self, auth_code: OAuth2AuthorizationCode): async with self.pool.acquire() as conn: try: await conn.execute(""" @@ -377,3 +547,86 @@ async def get_user_permissions(self, user_id: int): permission_names = [perm['permission_name'] for perm in permissions] logger.info(f"Retrieved permissions for user ID '{user_id}': {permission_names}") return permission_names + + async def get_user_roles(self, user_id: int) -> List[str]: + async with self.pool.acquire() as conn: + roles = await conn.fetch(""" + SELECT R.role_name + FROM Roles R + INNER JOIN UserRoles UR ON R.id = UR.role_id + WHERE UR.user_id = $1 + """, user_id) + return [role['role_name'] for role in roles] + + + async def get_all_roles(self) -> List[dict]: + async with self.pool.acquire() as conn: + roles = await conn.fetch("SELECT id, role_name FROM Roles") + return [dict(role) for role in roles] + + async def get_all_permissions(self) -> List[dict]: + async with self.pool.acquire() as conn: + permissions = await conn.fetch("SELECT id, permission_name FROM Permissions") + return [dict(permission) for permission in permissions] + + async def delete_role(self, role_id: int): + async with self.pool.acquire() as conn: + try: + await conn.execute("DELETE FROM Roles WHERE id = $1", role_id) + logger.info(f"Deleted role ID '{role_id}'") + except Exception as e: + logger.error(f"Error deleting role ID '{role_id}': {e}") + raise + + async def delete_permission(self, permission_id: int): + async with self.pool.acquire() as conn: + try: + await conn.execute("DELETE FROM Permissions WHERE id = $1", permission_id) + logger.info(f"Deleted permission ID '{permission_id}'") + except Exception as e: + logger.error(f"Error deleting permission ID '{permission_id}': {e}") + raise + + async def remove_role_from_user(self, user_id: int, role_id: int): + async with self.pool.acquire() as conn: + try: + await conn.execute(""" + DELETE FROM UserRoles + WHERE user_id = $1 AND role_id = $2 + """, user_id, role_id) + logger.info(f"Removed role ID '{role_id}' from user ID '{user_id}'") + except Exception as e: + logger.error(f"Error removing role ID '{role_id}' from user ID '{user_id}': {e}") + raise + + async def get_roles_for_user(self, user_id: int) -> List[dict]: + async with self.pool.acquire() as conn: + roles = await conn.fetch(""" + SELECT R.id, R.role_name + FROM Roles R + INNER JOIN UserRoles UR ON R.id = UR.role_id + WHERE UR.user_id = $1 + """, user_id) + return [dict(role) for role in roles] + + async def remove_permission_from_role(self, role_id: int, permission_id: int): + async with self.pool.acquire() as conn: + try: + await conn.execute(""" + DELETE FROM RolePermissions + WHERE role_id = $1 AND permission_id = $2 + """, role_id, permission_id) + logger.info(f"Removed permission ID '{permission_id}' from role ID '{role_id}'") + except Exception as e: + logger.error(f"Error removing permission ID '{permission_id}' from role ID '{role_id}': {e}") + raise + + async def get_permissions_for_role(self, role_id: int) -> List[dict]: + async with self.pool.acquire() as conn: + permissions = await conn.fetch(""" + SELECT P.id, P.permission_name + FROM Permissions P + INNER JOIN RolePermissions RP ON P.id = RP.permission_id + WHERE RP.role_id = $1 + """, role_id) + return [dict(permission) for permission in permissions] \ No newline at end of file diff --git a/init.sh b/init.sh deleted file mode 100644 index a71adee..0000000 --- a/init.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# init.sh -set -e - -# Initialize PostgreSQL data directory if it doesn't exist -if [ ! -d "/var/lib/postgresql/data" ]; then - echo "Initializing PostgreSQL data directory..." - mkdir -p /var/lib/postgresql/data - chown -R postgres:postgres /var/lib/postgresql - su-exec postgres initdb -D /var/lib/postgresql/data -fi - -# Update pg_hba.conf to allow password authentication -PG_HBA=/var/lib/postgresql/data/pg_hba.conf -if [ -f "$PG_HBA" ]; then - echo "Configuring PostgreSQL to use md5 authentication..." - sed -i "s/^#\?\(local\s\+all\s\+all\s\+\)peer/\1md5/" $PG_HBA - sed -i "s/^#\?\(host\s\+all\s\+all\s\+127\.0\.0\.1\/32\s\+\)md5/\1md5/" $PG_HBA - sed -i "s/^#\?\(host\s\+all\s\+all\s\+::1\/128\s\+\)md5/\1md5/" $PG_HBA -fi - -# Ensure /run/postgresql exists and is owned by postgres -echo "Ensuring /run/postgresql exists and is owned by postgres..." -mkdir -p /run/postgresql -chown postgres:postgres /run/postgresql - -# Start PostgreSQL to perform setup -echo "Starting PostgreSQL..." -su-exec postgres postgres -D /var/lib/postgresql/data & -sleep 5 - -# Create PostgreSQL user with SUPERUSER privilege if it doesn't exist -echo "Creating PostgreSQL user with SUPERUSER privilege if it doesn't exist..." -su-exec postgres psql -tc "SELECT 1 FROM pg_roles WHERE rolname = '$db_username'" | grep -q 1 || su-exec postgres psql -c "CREATE USER $db_username WITH PASSWORD '$db_password' SUPERUSER;" - -# Create database if it doesn't exist -echo "Creating PostgreSQL database if it doesn't exist..." -su-exec postgres psql -tc "SELECT 1 FROM pg_database WHERE datname = '$OAUTHDB'" | grep -q 1 || su-exec postgres psql -c "CREATE DATABASE $OAUTHDB OWNER $db_username;" - -# Grant all privileges on the database to the user -su-exec postgres psql -c "GRANT ALL PRIVILEGES ON DATABASE $OAUTHDB TO $db_username;" - -# Stop PostgreSQL (Supervisor will manage it) -echo "Stopping PostgreSQL..." -su-exec postgres pg_ctl -D /var/lib/postgresql/data -m fast -w stop diff --git a/main.py b/main.py index 1ba3c62..554c683 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,19 @@ # main.py + import os import sys - +import json from fastapi.middleware.trustedhost import TrustedHostMiddleware -from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware -from fastapi import FastAPI, Request, Form, Depends, HTTPException, status, Header -from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse +from fastapi import FastAPI, Request, Form, Depends, HTTPException, status, Header, Security, APIRouter +from fastapi.responses import RedirectResponse, HTMLResponse from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from passlib.context import CryptContext from datetime import datetime, timedelta, timezone -from typing import Optional +from typing import Optional, List import jwt import uuid import logging @@ -20,7 +22,8 @@ import urllib.parse from db_helper import DBHelper -from models import User, OAuth2Client, OAuth2AuthorizationCode, TokenRequest +from models import OAuth2AuthorizationCode, UserCreate, User, UserUpdate, OAuth2Client, OAuth2ClientUpdate, \ + OAuth2ClientCreate, Role, RoleCreate, Permission, PermissionCreate from credential_manager import CredentialManager # Configure logging to write to stdout only @@ -39,7 +42,6 @@ async def dispatch(self, request: Request, call_next): return response - app = FastAPI() # Add this middleware before CORSMiddleware @@ -53,12 +55,14 @@ async def dispatch(self, request: Request, call_next): if environment == 'production': origins = [ "https://pass.cerealsoft.com", + "https://oauthconsole.cerealsoft.com", "https://auth.cerealsoft.com", "https://passbackend.cerealsoft.com", # Add other production origins if needed ] else: origins = [ + "https://localhost:3400", # oauth console front end "https://localhost:3300", # Frontend origin "https://localhost:3200", "http://localhost:3000", # If applicable, e.g., React default port @@ -72,6 +76,7 @@ async def dispatch(self, request: Request, call_next): allow_credentials=True, # Allow credentials (cookies, authorization headers) allow_methods=["*"], # Allow all HTTP methods allow_headers=["*"], # Allow all headers + expose_headers=["*"], ) @@ -93,36 +98,97 @@ async def dispatch(self, request: Request, call_next): # Initialize DBHelper db_helper = DBHelper() + @app.on_event("startup") async def startup(): await db_helper.init_db() try: + # Existing client setup code remains unchanged client_id_signup = 'a1b2c3d4-5678-90ab-cdef-1234567890ab' client_secret_signup = 'b2c3d4e5-6789-01ab-cdef-2345678901bc' redirect_uri_signup = 'http://localhost:3000/callback' existing_client = await db_helper.get_client_by_id(client_id_signup) if not existing_client: - await db_helper.add_client(client_id_signup, client_secret_signup, redirect_uri_signup) + await db_helper.add_client(client_id_signup, client_secret_signup, [redirect_uri_signup]) client_id_password_vault = 'a1b2c3d4-5678-90ab-cdef-1234567890ac' client_secret_password_vault = 'b2c3d4e5-6789-01ab-cdef-2345678901bc' + redirect_uri_password_vault = None + + client_id_oauth_console = 'a1b2c3d4-5678-90ab-cdef-1234567890ad' + client_secret_oauth_console = 'b2c3d4e5-6789-01ab-cdef-2345678901bc' + redirect_uri_oauth_console = None + if environment == 'production': redirect_uri_password_vault = 'https://pass.cerealsoft.com/callback' + redirect_uri_oauth_console = 'https://oauthconsole.cerealsoft.com/callback' else: redirect_uri_password_vault = 'https://localhost:3300/callback' + redirect_uri_oauth_console = 'https://localhost:3400/callback' + existing_client2 = await db_helper.get_client_by_id(client_id_password_vault) + logging.info(f"client 2: {existing_client2}") + + existing_client3 = await db_helper.get_client_by_id(client_id_oauth_console) + logging.info(f"client 3: {existing_client3}") + if not existing_client2: - await db_helper.add_client(client_id_password_vault, client_secret_password_vault, redirect_uri_password_vault) + await db_helper.add_client(client_id_password_vault, client_secret_password_vault, [redirect_uri_password_vault]) + + if not existing_client3: + await db_helper.add_client(client_id_oauth_console, client_secret_oauth_console, [redirect_uri_oauth_console]) + + # New code to create default admin user + # 1. Create 'admin' role if it doesn't exist + admin_role = await db_helper.get_role_by_name('admin') + if not admin_role: + await db_helper.create_role('admin') + logging.info("Created 'admin' role") + + # 2. Create default admin user if it doesn't exist + admin_email = 'admin@example.com' + admin_password = 'adminpassword' # In production, do not hardcode passwords + + existing_admin_user = await db_helper.get_user_by_email(admin_email) + if not existing_admin_user: + # Hash the password + hashed_password = get_password_hash(admin_password) + # Create UserCreate object + admin_user_create = UserCreate(email=admin_email, password=hashed_password) + # Add user to the database + await db_helper.add_user(admin_user_create) + logging.info(f"Created default admin user with email: {admin_email}") + + # Retrieve the admin user (whether newly created or existing) + admin_user = await db_helper.get_user_by_email(admin_email) + + # 3. Assign 'admin' role to the admin user + # Get the role and user IDs + admin_role = await db_helper.get_role_by_name('admin') + if admin_role and admin_user: + # Check if the user already has the 'admin' role + user_roles = await db_helper.get_user_roles(admin_user.id) + if 'admin' not in user_roles: + await db_helper.assign_role_to_user(admin_user.id, admin_role['id']) + logging.info(f"Assigned 'admin' role to user '{admin_email}'") + else: + logging.info(f"User '{admin_email}' already has 'admin' role") + else: + logging.error("Failed to assign 'admin' role to the default admin user") + except Exception as e: - logging.error(e) + logging.error(f"Error during startup: {e}") + # Utility functions def verify_password(plain_password, hashed_password): return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password): return pwd_context.hash(password) + def create_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() expire = datetime.utcnow() + (expires_delta if expires_delta else timedelta(minutes=15)) @@ -130,6 +196,7 @@ def create_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + # Dependency function to get the current user async def get_current_user(request: Request): user_id = request.session.get('user_id') @@ -139,10 +206,9 @@ async def get_current_user(request: Request): return user raise HTTPException(status_code=401, detail="Not authenticated") -# Routes -import json +# Routes @app.get("/authorize") async def authorize(request: Request, response_type: str, @@ -159,9 +225,9 @@ async def authorize(request: Request, raise HTTPException(status_code=400, detail="Invalid client_id") # Validate redirect_uri - redirect_uris = client['redirect_uris'].split(',') - if redirect_uri not in redirect_uris: + if redirect_uri not in client.redirect_uris: logging.error(f"Invalid redirect_uri: {redirect_uri}") + logging.error(f"Allowed redirect_uris: {client.redirect_uris}") raise HTTPException(status_code=400, detail="Invalid redirect_uri") # Check response_type @@ -226,7 +292,7 @@ async def authorize(request: Request, client_id=client_id, redirect_uri=redirect_uri, scope=scope, - user_id=user['id'], + user_id=user.id, code_challenge=code_challenge, code_challenge_method=code_challenge_method, expires_at=expires_at @@ -380,19 +446,20 @@ async def login_get(request: Request): async def login_post(request: Request, email: str = Form(...), password: str = Form(...)): user = await db_helper.get_user_by_email(email) if user: - if not verify_password(password, user['hashed_password']): + if not verify_password(password, user.hashed_password): logging.error(f"Invalid credentials for {email}") raise HTTPException(status_code=400, detail="Invalid credentials") else: # Register new user hashed_password = get_password_hash(password) - await db_helper.add_user(email, hashed_password) + new_user = UserCreate(email=email, password=hashed_password) + await db_helper.add_user(user_in=new_user) user = await db_helper.get_user_by_email(email) logging.info(f"New user registered: {email}") # Authenticate user - request.session['user_id'] = user['id'] - logging.info(f"User '{user['email']}' authenticated successfully with user_id '{user['id']}'.") + request.session['user_id'] = user.id + logging.info(f"User '{user.email}' authenticated successfully with user_id '{user.id}'.") # Retrieve auth_request from session auth_request = request.session.pop('auth_request', None) @@ -405,12 +472,12 @@ async def login_post(request: Request, email: str = Form(...), password: str = F f"&state={urllib.parse.quote(auth_request['state_json'] or '')}" \ f"&code_challenge={urllib.parse.quote(auth_request['code_challenge'])}" \ f"&code_challenge_method={urllib.parse.quote(auth_request['code_challenge_method'])}" - logging.info(f"Redirecting user '{user['email']}' to authorization endpoint with URL: {redirect_url}") + logging.info(f"Redirecting user '{user.email}' to authorization endpoint with URL: {redirect_url}") return RedirectResponse(url=redirect_url, status_code=303) else: # No auth_request found, redirect to 'next_url' if present next_url = request.session.pop('next_url', '/') - logging.info(f"Redirecting user '{user['email']}' to next URL: {next_url}") + logging.info(f"Redirecting user '{user.email}' to next URL: {next_url}") return RedirectResponse(url=next_url, status_code=303) @@ -475,19 +542,19 @@ async def token(request: Request, # Generate access token user = await db_helper.get_user_by_id(auth_code['user_id']) access_token = create_token( - data={"sub": str(user['id']), "email": user['email']}, + data={"sub": str(user.id), "email": user.email}, # Changed from user['id'], user['email'] expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) ) # Optionally, generate a refresh token refresh_token = str(uuid.uuid4()) refresh_expires_at = datetime.utcnow() + timedelta(days=7) - await db_helper.save_refresh_token(user['id'], refresh_token, refresh_expires_at) + await db_helper.save_refresh_token(user.id, refresh_token, refresh_expires_at) # Delete authorization code await db_helper.delete_authorization_code(code) - logging.info(f"Issued access_token and refresh_token for user_id: {user['id']}") + logging.info(f"Issued access_token and refresh_token for user_id: {user.id}") return { "access_token": access_token, @@ -513,14 +580,14 @@ async def token_refresh(refresh_token: str = Form(...)): # Generate new access token user = await db_helper.get_user_by_id(token_data['user_id']) access_token = create_token( - data={"sub": str(user['id']), "email": user['email']}, + data={"sub": str(user.id), "email": user.email}, # Changed from user['id'], user['email'] expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) ) # Optionally, rotate refresh token new_refresh_token = str(uuid.uuid4()) refresh_expires_at = datetime.utcnow() + timedelta(days=7) - await db_helper.save_refresh_token(user['id'], new_refresh_token, refresh_expires_at) + await db_helper.save_refresh_token(user.id, new_refresh_token, refresh_expires_at) await db_helper.delete_refresh_token(refresh_token) return { @@ -546,12 +613,13 @@ async def protected_resource(token: str = Depends(get_token_from_header)): user = await db_helper.get_user_by_id(int(user_id)) if not user: raise HTTPException(status_code=401, detail="User not found") - return {"email": user['email']} + return {"email": user.email} # Changed from user['email'] except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token") + # Additional endpoints for client registration, etc., can be added as needed @app.post("/login_or_signup") @@ -560,7 +628,7 @@ async def login_or_signup(request: Request, email: str = Form(...), password: st user = await db_helper.get_user_by_email(email) if user: # User exists, attempt to authenticate - if not verify_password(password, user['hashed_password']): + if not verify_password(password, user.hashed_password): logging.warning(f"Invalid password for {email}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -572,7 +640,8 @@ async def login_or_signup(request: Request, email: str = Form(...), password: st # User does not exist, create account hashed_password = get_password_hash(password) try: - await db_helper.add_user(email, hashed_password) + new_user = UserCreate(email=email, hashed_password=hashed_password) + await db_helper.add_user(user_in=new_user) user = await db_helper.get_user_by_email(email) logging.info(f"User {email} registered and logged in successfully") except Exception as e: @@ -580,28 +649,32 @@ async def login_or_signup(request: Request, email: str = Form(...), password: st raise HTTPException(status_code=500, detail="Internal server error") # Set user in session - request.session['user_id'] = user['id'] + request.session['user_id'] = user.id # Redirect back to the original authorization request next_url = request.query_params.get('next') or '/' return RedirectResponse(url=next_url) + @app.post("/logout") async def logout(request: Request): request.session.clear() logging.info("User logged out successfully") return {"msg": "Logged out successfully"} + # Example protected resource using the dependency @app.get("/users/me") -async def read_users_me(user: dict = Depends(get_current_user)): - logging.info(f"User data requested for {user['email']}") +async def read_users_me(user: User = Depends(get_current_user)): # Type hint updated + logging.info(f"User data requested for {user.email}") # Changed from user['email'] return { - "email": user['email'], - "id": user['id'], + "email": user.email, # Changed from user['email'] + "id": user.id, # Changed from user['id'] } + # Additional endpoints for roles and permissions can be added as needed + @app.post("/roles") async def create_role(role_name: str): try: @@ -612,6 +685,7 @@ async def create_role(role_name: str): logging.error(f"Role creation error: {e}") raise HTTPException(status_code=500, detail="Internal server error") + @app.post("/permissions") async def create_permission(permission_name: str): try: @@ -622,6 +696,7 @@ async def create_permission(permission_name: str): logging.error(f"Permission creation error: {e}") raise HTTPException(status_code=500, detail="Internal server error") + @app.post("/roles/{role_name}/permissions") async def assign_permission_to_role(role_name: str, permission_name: str): try: @@ -640,6 +715,7 @@ async def assign_permission_to_role(role_name: str, permission_name: str): logging.error(f"Error assigning permission to role: {e}") raise HTTPException(status_code=500, detail="Internal server error") + @app.post("/users/{email}/roles") async def assign_role_to_user(email: str, role_name: str): try: @@ -651,13 +727,14 @@ async def assign_role_to_user(email: str, role_name: str): if not role: raise HTTPException(status_code=404, detail="Role not found") - await db_helper.assign_role_to_user(user['id'], role['id']) + await db_helper.assign_role_to_user(user.id, role['id']) logging.info(f"Assigned role '{role_name}' to user '{email}'") return {"msg": f"Role '{role_name}' assigned to user '{email}'"} except Exception as e: logging.error(f"Error assigning role to user: {e}") raise HTTPException(status_code=500, detail="Internal server error") + @app.get("/users/{email}/permissions") async def get_user_permissions(email: str): try: @@ -665,9 +742,192 @@ async def get_user_permissions(email: str): if not user: raise HTTPException(status_code=404, detail="User not found") - permission_names = await db_helper.get_user_permissions(user['id']) + permission_names = await db_helper.get_user_permissions(user.id) logging.info(f"Retrieved permissions for user '{email}'") return {"email": email, "permissions": permission_names} except Exception as e: logging.error(f"Error retrieving user permissions: {e}") raise HTTPException(status_code=500, detail="Internal server error") + + +# Define security scheme +bearer_scheme = HTTPBearer() + + +async def get_current_active_user(token: HTTPAuthorizationCredentials = Security(bearer_scheme)): + try: + payload = jwt.decode(token.credentials, SECRET_KEY, algorithms=[ALGORITHM]) + user_id = payload.get("sub") + if not user_id: + raise HTTPException(status_code=401, detail="Invalid token") + user = await db_helper.get_user_by_id(int(user_id)) + if not user: + raise HTTPException(status_code=401, detail="User not found") + return user + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + + +async def get_current_admin_user(user: User = Depends(get_current_active_user)): + # Check if the user has the 'admin' role + roles = await db_helper.get_user_roles(user.id) + if 'admin' not in roles: + raise HTTPException(status_code=403, detail="Not authorized") + return user + + +admin_router = APIRouter(prefix="/admin", tags=["admin"]) + + +@admin_router.get("/users", response_model=List[User]) +async def get_users(admin_user: User = Depends(get_current_admin_user)): + users = await db_helper.get_all_users() + return users + + +@admin_router.post("/users", response_model=User) +async def create_user(user_in: UserCreate, admin_user: User = Depends(get_current_admin_user)): + existing_user = await db_helper.get_user_by_email(user_in.email) + if existing_user: + raise HTTPException(status_code=400, detail="Email already registered") + # Hash the password + hashed_password = get_password_hash(user_in.password) + user_in.password = hashed_password + await db_helper.add_user(user_in) + user = await db_helper.get_user_by_email(user_in.email) + return user + + +@admin_router.put("/users/{user_id}", response_model=User) +async def update_user(user_id: int, user_in: UserUpdate, admin_user: User = Depends(get_current_admin_user)): + user = await db_helper.get_user_by_id(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + if user_in.password: + user_in.password = get_password_hash(user_in.password) + await db_helper.update_user(user_id, user_in) + updated_user = await db_helper.get_user_by_id(user_id) + return updated_user + + +@admin_router.delete("/users/{user_id}", response_model=dict) +async def delete_user(user_id: int, admin_user: User = Depends(get_current_admin_user)): + user = await db_helper.get_user_by_id(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + await db_helper.delete_user(user_id) + return {"detail": "User deleted successfully"} + + +@admin_router.get("/clients", response_model=List[OAuth2Client]) +async def get_clients(admin_user: User = Depends(get_current_admin_user)): + clients = await db_helper.get_all_clients() + return clients + +@admin_router.post("/clients", response_model=OAuth2Client) +async def create_client(client_data: OAuth2ClientCreate, admin_user: User = Depends(get_current_admin_user)): + existing_client = await db_helper.get_client_by_id(client_data.client_id) + if existing_client: + raise HTTPException(status_code=400, detail="Client ID already exists") + await db_helper.add_client(client_data.client_id, client_data.client_secret, client_data.redirect_uris) + new_client = await db_helper.get_client_by_id(client_data.client_id) + if not new_client: + raise HTTPException(status_code=500, detail="Failed to create new client") + return new_client + +@admin_router.put("/clients/{client_id}", response_model=OAuth2Client) +async def update_client(client_id: str, client_data: OAuth2ClientUpdate, admin_user: User = Depends(get_current_admin_user)): + existing_client = await db_helper.get_client_by_id(client_id) + if not existing_client: + raise HTTPException(status_code=404, detail="Client not found") + await db_helper.update_client(client_id, client_data) + updated_client = await db_helper.get_client_by_id(client_id) + if not updated_client: + raise HTTPException(status_code=500, detail="Failed to update the client") + return updated_client + +@admin_router.delete("/clients/{client_id}", response_model=dict) +async def delete_client(client_id: str, admin_user: User = Depends(get_current_admin_user)): + existing_client = await db_helper.get_client_by_id(client_id) + if not existing_client: + raise HTTPException(status_code=404, detail="Client not found") + await db_helper.delete_client(client_id) + return {"detail": "Client deleted successfully"} + +@admin_router.get("/roles", response_model=List[Role]) +async def get_roles(admin_user: User = Depends(get_current_admin_user)): + roles = await db_helper.get_all_roles() + return roles + +@admin_router.post("/roles", response_model=Role) +async def create_role(role_in: RoleCreate, admin_user: User = Depends(get_current_admin_user)): + existing_role = await db_helper.get_role_by_name(role_in.role_name) + if existing_role: + raise HTTPException(status_code=400, detail="Role already exists") + await db_helper.create_role(role_in.role_name) + new_role = await db_helper.get_role_by_name(role_in.role_name) + return Role(**new_role) + +@admin_router.delete("/roles/{role_id}", response_model=dict) +async def delete_role(role_id: int, admin_user: User = Depends(get_current_admin_user)): + await db_helper.delete_role(role_id) + return {"detail": "Role deleted successfully"} + + +@admin_router.get("/permissions", response_model=List[Permission]) +async def get_permissions(admin_user: User = Depends(get_current_admin_user)): + permissions = await db_helper.get_all_permissions() + return permissions + +@admin_router.post("/permissions", response_model=Permission) +async def create_permission(permission_in: PermissionCreate, admin_user: User = Depends(get_current_admin_user)): + existing_permission = await db_helper.get_permission_by_name(permission_in.permission_name) + if existing_permission: + raise HTTPException(status_code=400, detail="Permission already exists") + await db_helper.create_permission(permission_in.permission_name) + new_permission = await db_helper.get_permission_by_name(permission_in.permission_name) + return Permission(**new_permission) + +@admin_router.delete("/permissions/{permission_id}", response_model=dict) +async def delete_permission(permission_id: int, admin_user: User = Depends(get_current_admin_user)): + await db_helper.delete_permission(permission_id) + return {"detail": "Permission deleted successfully"} + + +class RoleAssign(BaseModel): + role_id: int + +@admin_router.post("/users/{user_id}/roles", response_model=dict) +async def assign_role_to_user(user_id: int, role_assign: RoleAssign, admin_user: User = Depends(get_current_admin_user)): + await db_helper.assign_role_to_user(user_id, role_assign.role_id) + return {"detail": "Role assigned to user successfully"} + +@admin_router.delete("/users/{user_id}/roles/{role_id}", response_model=dict) +async def remove_role_from_user(user_id: int, role_id: int, admin_user: User = Depends(get_current_admin_user)): + await db_helper.remove_role_from_user(user_id, role_id) + return {"detail": "Role removed from user successfully"} + +@admin_router.get("/users/{user_id}/roles", response_model=List[Role]) +async def get_user_roles(user_id: int, admin_user: User = Depends(get_current_admin_user)): + roles = await db_helper.get_roles_for_user(user_id) + return [Role(**role) for role in roles] + + +class PermissionAssign(BaseModel): + permission_id: int + +@admin_router.post("/roles/{role_id}/permissions", response_model=dict) +async def assign_permission_to_role(role_id: int, perm_assign: PermissionAssign, admin_user: User = Depends(get_current_admin_user)): + await db_helper.assign_permission_to_role(role_id, perm_assign.permission_id) + return {"detail": "Permission assigned to role successfully"} + +@admin_router.delete("/roles/{role_id}/permissions/{permission_id}", response_model=dict) +async def remove_permission_from_role(role_id: int, permission_id: int, admin_user: User = Depends(get_current_admin_user)): + await db_helper.remove_permission_from_role(role_id, permission_id) + return {"detail": "Permission removed from role successfully"} + + + +app.include_router(admin_router) diff --git a/models.py b/models.py index a8d81fb..2b49cc9 100644 --- a/models.py +++ b/models.py @@ -1,21 +1,52 @@ # models.py -from pydantic import BaseModel -from typing import Optional +from pydantic import BaseModel, EmailStr +from typing import Optional, List from datetime import datetime -class User(BaseModel): - id: Optional[int] - email: str - hashed_password: str - created_at: Optional[datetime] + +class UserBase(BaseModel): + email: EmailStr + + +class UserCreate(UserBase): + password: str + + +class UserUpdate(BaseModel): + email: Optional[EmailStr] = None + password: Optional[str] = None + + +class User(UserBase): + id: int + created_at: datetime + + class Config: + from_attributes = True + class OAuth2Client(BaseModel): client_id: str client_secret: Optional[str] - redirect_uris: list + redirect_uris: List[str] created_at: Optional[datetime] + class Config: + from_attributes = True + + +class OAuth2ClientCreate(BaseModel): + client_id: str + client_secret: str + redirect_uris: List[str] + + +class OAuth2ClientUpdate(BaseModel): + client_secret: Optional[str] = None + redirect_uris: Optional[List[str]] = None + + class OAuth2AuthorizationCode(BaseModel): code: str client_id: str @@ -26,9 +57,43 @@ class OAuth2AuthorizationCode(BaseModel): code_challenge_method: str expires_at: datetime + class TokenRequest(BaseModel): grant_type: str code: Optional[str] redirect_uri: Optional[str] client_id: Optional[str] code_verifier: Optional[str] + + +class UserInDB(BaseModel): + id: int + email: str + hashed_password: str + + class Config: + from_attributes = True + + +class RoleBase(BaseModel): + role_name: str + + +class RoleCreate(RoleBase): + pass + + +class Role(RoleBase): + id: int + + +class PermissionBase(BaseModel): + permission_name: str + + +class PermissionCreate(PermissionBase): + pass + + +class Permission(PermissionBase): + id: int \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 58242ab..35bf7f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ starlette pydantic python-dotenv itsdangerous -python-multipart \ No newline at end of file +python-multipart +pydantic[email] \ No newline at end of file diff --git a/supervisord.conf b/supervisord.conf index a8b0ad8..5c182c8 100644 --- a/supervisord.conf +++ b/supervisord.conf @@ -2,17 +2,6 @@ nodaemon=true loglevel=info -[program:postgresql] -command=/usr/bin/postgres -D /var/lib/postgresql/data -user=postgres -stdout_logfile=/dev/stdout -stderr_logfile=/dev/stderr -stdout_logfile_maxbytes=0 -stderr_logfile_maxbytes=0 -autostart=true -autorestart=true -priority=10 - [program:fastapi] command=/usr/local/bin/uvicorn main:app --host 127.0.0.1 --port 3100 directory=/app