Skip to content

Commit

Permalink
I have made numerous changes which allow for the oauth console to add…
Browse files Browse the repository at this point in the history
… and remove users roles and permissions in addition to several fixes overall
  • Loading branch information
autonomouscereal committed Nov 18, 2024
1 parent e6022f0 commit c71ad5c
Show file tree
Hide file tree
Showing 7 changed files with 644 additions and 139 deletions.
22 changes: 2 additions & 20 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
# Base image
FROM python:3.9-alpine

# Set environment variables directly in the Dockerfile
ENV db_username=postgres
ENV db_password=your_postgres_password
ENV SECRET_KEY=your_secret_key
ENV DB_HOST=localhost
ENV DB_PORT=5432
ENV OAUTHDB=OAUTHDB
ENV ENVIRONMENT=production


# Set the working directory
WORKDIR /app
Expand Down Expand Up @@ -44,9 +37,6 @@ COPY nginx.conf /etc/nginx/nginx.conf
COPY certs/server.crt /etc/ssl/certs/server.crt
COPY certs/server.key /etc/ssl/private/server.key

# Copy init.sh
COPY init.sh /init.sh
RUN chmod +x /init.sh

# Copy supervisord.conf
COPY supervisord.conf /etc/supervisord.conf
Expand All @@ -57,13 +47,5 @@ RUN mkdir -p /var/log/postgresql && mkdir -p /var/log/fastapi
# Expose port 443 for HTTPS
EXPOSE 443

# Set environment variables again (if needed)
ENV db_username=postgres
ENV db_password=your_postgres_password
ENV SECRET_KEY=your_secret_key
ENV DB_HOST=localhost
ENV DB_PORT=5432
ENV OAUTHDB=OAUTHDB

# Entrypoint to initialize PostgreSQL and start Supervisor
ENTRYPOINT ["/bin/bash", "-c", "/init.sh && supervisord -n -c /etc/supervisord.conf"]
ENTRYPOINT ["/bin/bash", "-c", "supervisord -n -c /etc/supervisord.conf"]
289 changes: 271 additions & 18 deletions db_helper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# db_helper.py
import logging

import asyncpg
from datetime import datetime
from typing import Optional, List
from credential_manager import CredentialManager
from models import UserCreate, User, UserUpdate, UserInDB, OAuth2Client, OAuth2ClientCreate, OAuth2ClientUpdate, \
OAuth2AuthorizationCode

# Configure logging to output to stdout

logger = logging.getLogger(__name__)

class DBHelper:
Expand Down Expand Up @@ -191,7 +192,15 @@ async def close_pool(self):
logger.info("Database connection pool closed.")

# User methods
async def add_user(self, email: str, hashed_password: str):
async def add_user(self, user_in: UserCreate):
"""
Add a new user to the database.
:param user_in: UserCreate object containing email and password
"""
hashed_password = user_in.password # Password should already be hashed before calling this method
email = user_in.email

async with self.pool.acquire() as conn:
try:
await conn.execute("""
Expand All @@ -200,45 +209,206 @@ async def add_user(self, email: str, hashed_password: str):
logger.info(f"Added new user with email: {email}")
except asyncpg.exceptions.UniqueViolationError:
logger.warning(f"User with email '{email}' already exists.")
raise
except Exception as e:
logger.error(f"Error adding user '{email}': {e}")
raise

async def get_user_by_email(self, email: str):
async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
query = "SELECT id, email, hashed_password FROM users WHERE email = $1"
record = await self.pool.fetchrow(query, email)
if record:
return UserInDB(**record)
return None

async def get_user_by_id(self, user_id: int) -> Optional[User]:
"""
Retrieve a user by ID.
:param user_id: ID of the user
:return: User object or None if not found
"""
async with self.pool.acquire() as conn:
user_record = await conn.fetchrow("SELECT id, email, created_at FROM users WHERE id = $1", user_id)
logger.info(f"Fetched user by ID '{user_id}': {'Found' if user_record else 'Not Found'}")
if user_record:
return User(**dict(user_record))
return None

async def get_all_users(self) -> List[User]:
"""
Retrieve all users.
:return: List of User objects
"""
async with self.pool.acquire() as conn:
user = await conn.fetchrow("SELECT * FROM users WHERE email = $1", email)
logger.info(f"Fetched user by email '{email}': {'Found' if user else 'Not Found'}")
return dict(user) if user else None
user_records = await conn.fetch("SELECT id, email, created_at FROM users")
users = [User(**dict(user)) for user in user_records]
logger.info(f"Fetched all users. Total: {len(users)}")
return users

async def update_user(self, user_id: int, user_in: UserUpdate):
"""
Update a user's information.
:param user_id: ID of the user to update
:param user_in: UserUpdate object containing updated fields
"""
async with self.pool.acquire() as conn:
fields = []
values = []
idx = 1
if user_in.email is not None:
fields.append(f"email = ${idx}")
values.append(user_in.email)
idx += 1
if user_in.password is not None:
fields.append(f"hashed_password = ${idx}")
values.append(user_in.password) # Password should already be hashed before calling this method
idx += 1
if not fields:
logger.info(f"No fields to update for user ID '{user_id}'")
return
values.append(user_id)
query = f"UPDATE users SET {', '.join(fields)} WHERE id = ${idx}"
try:
await conn.execute(query, *values)
logger.info(f"Updated user ID '{user_id}'")
except Exception as e:
logger.error(f"Error updating user ID '{user_id}': {e}")
raise

# db_helper.py

async def delete_user(self, user_id: int):
"""
Delete a user by ID.
async def get_user_by_id(self, user_id: int):
:param user_id: ID of the user to delete
"""
async with self.pool.acquire() as conn:
user = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
logger.info(f"Fetched user by ID '{user_id}': {'Found' if user else 'Not Found'}")
return dict(user) if user else None
try:
# Delete authorization codes related to the user
await conn.execute("DELETE FROM oauth2_authorization_codes WHERE user_id = $1", user_id)

# Delete refresh tokens related to the user
await conn.execute("DELETE FROM refresh_tokens WHERE user_id = $1", user_id)

# Delete roles associated with the user
await conn.execute("DELETE FROM UserRoles WHERE user_id = $1", user_id)

# Delete the user
await conn.execute("DELETE FROM users WHERE id = $1", user_id)

logger.info(f"Deleted user ID '{user_id}'")
except Exception as e:
logger.error(f"Error deleting user ID '{user_id}': {e}")
raise

# Client methods
async def get_client_by_id(self, client_id: str):
async def get_client_by_id(self, client_id: str) -> Optional[OAuth2Client]:
async with self.pool.acquire() as conn:
client = await conn.fetchrow("SELECT * FROM oauth2_clients WHERE client_id = $1", client_id)
logger.info(f"Fetched OAuth2 client by ID '{client_id}': {'Found' if client else 'Not Found'}")
return dict(client) if client else None
client_record = await conn.fetchrow("SELECT * FROM oauth2_clients WHERE client_id = $1", client_id)
logger.info(f"Fetched OAuth2 client by ID '{client_id}': {'Found' if client_record else 'Not Found'}")
logger.info(f"Fetched client: {client_record}")
if client_record:
return OAuth2Client(
client_id=client_record['client_id'],
client_secret=client_record['client_secret'],
redirect_uris=client_record['redirect_uris'].split(","),
created_at=client_record['created_at']
)
return None

async def add_client(self, client_id, client_secret, redirect_uris):
async def get_all_clients(self) -> List[OAuth2Client]:
"""
Retrieve all OAuth2 clients.
:return: List of OAuth2Client objects
"""
async with self.pool.acquire() as conn:
client_records = await conn.fetch("SELECT * FROM oauth2_clients")
clients = []
for record in client_records:
clients.append(OAuth2Client(
client_id=record['client_id'],
client_secret=record['client_secret'],
redirect_uris=record['redirect_uris'].split(","),
created_at=record['created_at']
))
logger.info(f"Fetched all clients. Total: {len(clients)}")
return clients

async def add_client(self, client_id: str, client_secret: str, redirect_uris: List[str]):
async with self.pool.acquire() as conn:
try:
await conn.execute('''
INSERT INTO oauth2_clients (client_id, client_secret, redirect_uris)
VALUES ($1, $2, $3)
''', client_id, client_secret, redirect_uris)
''', client_id, client_secret, ",".join(redirect_uris))
logger.info(f"Added OAuth2 client with ID: {client_id}")
except asyncpg.exceptions.UniqueViolationError:
logger.warning(f"OAuth2 client with ID '{client_id}' already exists.")
raise
except Exception as e:
logger.error(f"Error adding client '{client_id}': {e}")
raise

async def update_client(self, client_id: str, client_data: OAuth2ClientUpdate):
"""
Update an OAuth2 client's information.
:param client_id: The client_id of the client to update
:param client_data: OAuth2ClientUpdate object containing updated fields
"""
async with self.pool.acquire() as conn:
fields = []
values = []
idx = 1
if client_data.client_secret is not None:
fields.append(f"client_secret = ${idx}")
values.append(client_data.client_secret)
idx += 1
if client_data.redirect_uris is not None:
fields.append(f"redirect_uris = ${idx}")
values.append(",".join(client_data.redirect_uris))
idx += 1
if not fields:
logger.info(f"No fields to update for client ID '{client_id}'")
return
values.append(client_id)
query = f"UPDATE oauth2_clients SET {', '.join(fields)} WHERE client_id = ${idx}"
try:
await conn.execute(query, *values)
logger.info(f"Updated OAuth2 client with ID '{client_id}'")
except Exception as e:
logger.error(f"Error updating OAuth2 client '{client_id}': {e}")
raise

async def delete_client(self, client_id: str):
"""
Delete an OAuth2 client by client_id.
:param client_id: The client_id of the client to delete
"""
async with self.pool.acquire() as conn:

try:
# Delete authorization codes related to the client
await conn.execute("DELETE FROM oauth2_authorization_codes WHERE client_id = $1", client_id)

# Delete the client
await conn.execute("DELETE FROM oauth2_clients WHERE client_id = $1", client_id)

logger.info(f"Deleted OAuth2 client with ID '{client_id}'")

except Exception as e:

logger.error(f"Error deleting OAuth2 client '{client_id}': {e}")
raise

# Authorization code methods
async def save_authorization_code(self, auth_code):
async def save_authorization_code(self, auth_code: OAuth2AuthorizationCode):
async with self.pool.acquire() as conn:
try:
await conn.execute("""
Expand Down Expand Up @@ -377,3 +547,86 @@ async def get_user_permissions(self, user_id: int):
permission_names = [perm['permission_name'] for perm in permissions]
logger.info(f"Retrieved permissions for user ID '{user_id}': {permission_names}")
return permission_names

async def get_user_roles(self, user_id: int) -> List[str]:
async with self.pool.acquire() as conn:
roles = await conn.fetch("""
SELECT R.role_name
FROM Roles R
INNER JOIN UserRoles UR ON R.id = UR.role_id
WHERE UR.user_id = $1
""", user_id)
return [role['role_name'] for role in roles]


async def get_all_roles(self) -> List[dict]:
async with self.pool.acquire() as conn:
roles = await conn.fetch("SELECT id, role_name FROM Roles")
return [dict(role) for role in roles]

async def get_all_permissions(self) -> List[dict]:
async with self.pool.acquire() as conn:
permissions = await conn.fetch("SELECT id, permission_name FROM Permissions")
return [dict(permission) for permission in permissions]

async def delete_role(self, role_id: int):
async with self.pool.acquire() as conn:
try:
await conn.execute("DELETE FROM Roles WHERE id = $1", role_id)
logger.info(f"Deleted role ID '{role_id}'")
except Exception as e:
logger.error(f"Error deleting role ID '{role_id}': {e}")
raise

async def delete_permission(self, permission_id: int):
async with self.pool.acquire() as conn:
try:
await conn.execute("DELETE FROM Permissions WHERE id = $1", permission_id)
logger.info(f"Deleted permission ID '{permission_id}'")
except Exception as e:
logger.error(f"Error deleting permission ID '{permission_id}': {e}")
raise

async def remove_role_from_user(self, user_id: int, role_id: int):
async with self.pool.acquire() as conn:
try:
await conn.execute("""
DELETE FROM UserRoles
WHERE user_id = $1 AND role_id = $2
""", user_id, role_id)
logger.info(f"Removed role ID '{role_id}' from user ID '{user_id}'")
except Exception as e:
logger.error(f"Error removing role ID '{role_id}' from user ID '{user_id}': {e}")
raise

async def get_roles_for_user(self, user_id: int) -> List[dict]:
async with self.pool.acquire() as conn:
roles = await conn.fetch("""
SELECT R.id, R.role_name
FROM Roles R
INNER JOIN UserRoles UR ON R.id = UR.role_id
WHERE UR.user_id = $1
""", user_id)
return [dict(role) for role in roles]

async def remove_permission_from_role(self, role_id: int, permission_id: int):
async with self.pool.acquire() as conn:
try:
await conn.execute("""
DELETE FROM RolePermissions
WHERE role_id = $1 AND permission_id = $2
""", role_id, permission_id)
logger.info(f"Removed permission ID '{permission_id}' from role ID '{role_id}'")
except Exception as e:
logger.error(f"Error removing permission ID '{permission_id}' from role ID '{role_id}': {e}")
raise

async def get_permissions_for_role(self, role_id: int) -> List[dict]:
async with self.pool.acquire() as conn:
permissions = await conn.fetch("""
SELECT P.id, P.permission_name
FROM Permissions P
INNER JOIN RolePermissions RP ON P.id = RP.permission_id
WHERE RP.role_id = $1
""", role_id)
return [dict(permission) for permission in permissions]
Loading

0 comments on commit c71ad5c

Please sign in to comment.