diff --git a/migrations/2025-01-20-223225_third_party_devs/down.sql b/migrations/2025-01-20-223225_third_party_devs/down.sql new file mode 100644 index 0000000..b5c5451 --- /dev/null +++ b/migrations/2025-01-20-223225_third_party_devs/down.sql @@ -0,0 +1,15 @@ +-- Drop triggers first +DROP TRIGGER IF EXISTS update_org_project_secrets_updated_at ON org_project_secrets; +DROP TRIGGER IF EXISTS update_org_projects_updated_at ON org_projects; +DROP TRIGGER IF EXISTS update_org_users_updated_at ON org_users; +DROP TRIGGER IF EXISTS update_orgs_updated_at ON orgs; + +-- Drop tables in reverse order of creation (to handle foreign key dependencies) +DROP TABLE IF EXISTS org_project_secrets; +DROP TABLE IF EXISTS org_projects; +DROP TABLE IF EXISTS org_users; +DROP TABLE IF EXISTS invite_codes; +DROP TABLE IF EXISTS orgs; + +-- Drop the trigger function +DROP FUNCTION IF EXISTS update_updated_at_column(); diff --git a/migrations/2025-01-20-223225_third_party_devs/up.sql b/migrations/2025-01-20-223225_third_party_devs/up.sql new file mode 100644 index 0000000..b4e1e06 --- /dev/null +++ b/migrations/2025-01-20-223225_third_party_devs/up.sql @@ -0,0 +1,120 @@ +-- Create organizations table +CREATE TABLE orgs ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL DEFAULT uuid_generate_v4() UNIQUE, + name TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Create organization users table (for developers/admins of organizations) +CREATE TABLE org_users ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL DEFAULT uuid_generate_v4() UNIQUE, + org_id INTEGER NOT NULL REFERENCES orgs(id) ON DELETE CASCADE, + email TEXT NOT NULL, + name TEXT, + password_enc BYTEA, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(org_id, email) +); + +-- Create organization projects table +CREATE TABLE org_projects ( + id SERIAL PRIMARY KEY, + uuid UUID NOT NULL DEFAULT uuid_generate_v4() UNIQUE, + client_id UUID NOT NULL DEFAULT uuid_generate_v4() UNIQUE, + org_id INTEGER NOT NULL REFERENCES orgs(id) ON DELETE CASCADE, + name TEXT NOT NULL, + description TEXT, + status TEXT NOT NULL DEFAULT 'active' CHECK (status IN ('active', 'inactive', 'suspended')), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(org_id, name) +); + +-- Create project secrets table for storing encrypted API keys, OAuth secrets, etc. +CREATE TABLE org_project_secrets ( + id SERIAL PRIMARY KEY, + project_id INTEGER NOT NULL REFERENCES org_projects(id) ON DELETE CASCADE, + key_name TEXT NOT NULL, + secret_enc BYTEA NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(project_id, key_name) -- Each key name should be unique per project +); + +-- Create invite codes table for organization user invitations +CREATE TABLE invite_codes ( + id SERIAL PRIMARY KEY, + code UUID NOT NULL DEFAULT uuid_generate_v4() UNIQUE, + org_id INTEGER NOT NULL REFERENCES orgs(id) ON DELETE CASCADE, + email TEXT NOT NULL, + used BOOLEAN NOT NULL DEFAULT false, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Add indexes for foreign keys and commonly queried fields +CREATE INDEX idx_org_users_org_id ON org_users(org_id); +CREATE INDEX idx_org_users_email ON org_users(email); +CREATE INDEX idx_org_projects_org_id ON org_projects(org_id); +CREATE INDEX idx_org_projects_client_id ON org_projects(client_id); +CREATE INDEX idx_org_project_secrets_project_id ON org_project_secrets(project_id); +CREATE INDEX idx_invite_codes_org_id ON invite_codes(org_id); +CREATE INDEX idx_invite_codes_code ON invite_codes(code); +CREATE INDEX idx_invite_codes_email ON invite_codes(email); + +-- Create updated_at triggers +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ language 'plpgsql'; + +CREATE TRIGGER update_orgs_updated_at + BEFORE UPDATE ON orgs + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_org_users_updated_at + BEFORE UPDATE ON org_users + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_org_projects_updated_at + BEFORE UPDATE ON org_projects + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_org_project_secrets_updated_at + BEFORE UPDATE ON org_project_secrets + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_invite_codes_updated_at + BEFORE UPDATE ON invite_codes + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +-- Create a default "OpenSecret" organization and "Maple" project for existing data +INSERT INTO orgs (name) VALUES ('OpenSecret'); + +-- Create the Maple project under OpenSecret organization +INSERT INTO org_projects ( + org_id, + name, + description, + status +) +SELECT + id, + 'Maple', + 'TryMaple Project', + 'active' +FROM orgs +WHERE name = 'OpenSecret'; diff --git a/src/db.rs b/src/db.rs index f80a7a6..b335af9 100644 --- a/src/db.rs +++ b/src/db.rs @@ -2,9 +2,16 @@ use crate::models::email_verification::{ EmailVerification, EmailVerificationError, NewEmailVerification, }; use crate::models::enclave_secrets::{EnclaveSecret, EnclaveSecretError, NewEnclaveSecret}; +use crate::models::invite_codes::{InviteCode, InviteCodeError, NewInviteCode}; use crate::models::oauth::{ NewOAuthProvider, NewUserOAuthConnection, OAuthError, OAuthProvider, UserOAuthConnection, }; +use crate::models::org_project_secrets::{ + NewOrgProjectSecret, OrgProjectSecret, OrgProjectSecretError, +}; +use crate::models::org_projects::{NewOrgProject, OrgProject, OrgProjectError}; +use crate::models::org_users::{NewOrgUser, OrgUser, OrgUserError}; +use crate::models::orgs::{NewOrg, Org, OrgError}; use crate::models::password_reset::{ NewPasswordResetRequest, PasswordResetError, PasswordResetRequest, }; @@ -44,6 +51,26 @@ pub enum DBError { OAuthError(#[from] OAuthError), #[error("Token usage error: {0}")] TokenUsageError(#[from] TokenUsageError), + #[error("Org error: {0}")] + OrgError(#[from] OrgError), + #[error("Org not found")] + OrgNotFound, + #[error("Org user error: {0}")] + OrgUserError(#[from] OrgUserError), + #[error("Org user not found")] + OrgUserNotFound, + #[error("Org project error: {0}")] + OrgProjectError(#[from] OrgProjectError), + #[error("Org project not found")] + OrgProjectNotFound, + #[error("Org project secret error: {0}")] + OrgProjectSecretError(#[from] OrgProjectSecretError), + #[error("Org project secret not found")] + OrgProjectSecretNotFound, + #[error("Invite code error: {0}")] + InviteCodeError(#[from] InviteCodeError), + #[error("Invite code not found")] + InviteCodeNotFound, } #[allow(dead_code)] @@ -129,6 +156,80 @@ pub trait DBConnection { fn create_token_usage(&self, new_usage: NewTokenUsage) -> Result; fn update_user(&self, user: &User) -> Result<(), DBError>; + + // New org-related methods + fn create_org(&self, new_org: NewOrg) -> Result; + fn get_org_by_id(&self, id: i32) -> Result; + fn get_org_by_uuid(&self, uuid: Uuid) -> Result; + fn get_org_by_name(&self, name: &str) -> Result, DBError>; + fn get_all_orgs(&self) -> Result, DBError>; + fn update_org(&self, org: &Org) -> Result<(), DBError>; + fn delete_org(&self, org: &Org) -> Result<(), DBError>; + + // Org user methods + fn create_org_user(&self, new_user: NewOrgUser) -> Result; + fn get_org_user_by_id(&self, id: i32) -> Result; + fn get_org_user_by_uuid(&self, uuid: Uuid) -> Result; + fn get_org_user_by_email_and_org( + &self, + email: &str, + org_id: i32, + ) -> Result, DBError>; + fn get_all_org_users_for_org(&self, org_id: i32) -> Result, DBError>; + fn update_org_user(&self, user: &OrgUser) -> Result<(), DBError>; + fn update_org_user_password( + &self, + user: &OrgUser, + new_password_enc: Option>, + ) -> Result<(), DBError>; + fn delete_org_user(&self, user: &OrgUser) -> Result<(), DBError>; + + // Org project methods + fn create_org_project(&self, new_project: NewOrgProject) -> Result; + fn get_org_project_by_id(&self, id: i32) -> Result; + fn get_org_project_by_uuid(&self, uuid: Uuid) -> Result; + fn get_org_project_by_client_id(&self, client_id: Uuid) -> Result; + fn get_org_project_by_name_and_org( + &self, + name: &str, + org_id: i32, + ) -> Result, DBError>; + fn get_all_org_projects_for_org(&self, org_id: i32) -> Result, DBError>; + fn get_active_org_projects_for_org(&self, org_id: i32) -> Result, DBError>; + fn update_org_project(&self, project: &OrgProject) -> Result<(), DBError>; + fn delete_org_project(&self, project: &OrgProject) -> Result<(), DBError>; + + // Org project secret methods + fn create_org_project_secret( + &self, + new_secret: NewOrgProjectSecret, + ) -> Result; + fn get_org_project_secret_by_id(&self, id: i32) -> Result; + fn get_org_project_secret_by_key_name_and_project( + &self, + key_name: &str, + project_id: i32, + ) -> Result, DBError>; + fn get_all_org_project_secrets_for_project( + &self, + project_id: i32, + ) -> Result, DBError>; + fn update_org_project_secret(&self, secret: &OrgProjectSecret) -> Result<(), DBError>; + fn delete_org_project_secret(&self, secret: &OrgProjectSecret) -> Result<(), DBError>; + + // Invite code methods + fn create_invite_code(&self, new_invite: NewInviteCode) -> Result; + fn get_invite_code_by_id(&self, id: i32) -> Result; + fn get_invite_code_by_code(&self, code: Uuid) -> Result; + fn get_invite_code_by_email_and_org( + &self, + email: &str, + org_id: i32, + ) -> Result, DBError>; + fn get_all_invite_codes_for_org(&self, org_id: i32) -> Result, DBError>; + fn mark_invite_code_as_used(&self, invite: &InviteCode) -> Result<(), DBError>; + fn update_invite_code(&self, invite: &InviteCode) -> Result<(), DBError>; + fn delete_invite_code(&self, invite: &InviteCode) -> Result<(), DBError>; } pub(crate) struct PostgresConnection { @@ -411,6 +512,425 @@ impl DBConnection for PostgresConnection { let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; user.update(conn).map_err(DBError::from) } + + // Org implementations + fn create_org(&self, new_org: NewOrg) -> Result { + debug!("Creating new org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = new_org.insert(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to create org: {:?}", e); + } + result + } + + fn get_org_by_id(&self, id: i32) -> Result { + debug!("Getting org by ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = Org::get_by_id(conn, id)?.ok_or(DBError::OrgNotFound); + if let Err(ref e) = result { + error!("Failed to get org by ID: {:?}", e); + } + result + } + + fn get_org_by_uuid(&self, uuid: Uuid) -> Result { + debug!("Getting org by UUID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = Org::get_by_uuid(conn, uuid)?.ok_or(DBError::OrgNotFound); + if let Err(ref e) = result { + error!("Failed to get org by UUID: {:?}", e); + } + result + } + + fn get_org_by_name(&self, name: &str) -> Result, DBError> { + debug!("Getting org by name"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = Org::get_by_name(conn, name).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get org by name: {:?}", e); + } + result + } + + fn get_all_orgs(&self) -> Result, DBError> { + debug!("Getting all orgs"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = Org::get_all(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get all orgs: {:?}", e); + } + result + } + + fn update_org(&self, org: &Org) -> Result<(), DBError> { + debug!("Updating org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = org.update(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update org: {:?}", e); + } + result + } + + fn delete_org(&self, org: &Org) -> Result<(), DBError> { + debug!("Deleting org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = org.delete(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to delete org: {:?}", e); + } + result + } + + // Org user implementations + fn create_org_user(&self, new_user: NewOrgUser) -> Result { + debug!("Creating new org user"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = new_user.insert(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to create org user: {:?}", e); + } + result + } + + fn get_org_user_by_id(&self, id: i32) -> Result { + debug!("Getting org user by ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgUser::get_by_id(conn, id)?.ok_or(DBError::OrgUserNotFound); + if let Err(ref e) = result { + error!("Failed to get org user by ID: {:?}", e); + } + result + } + + fn get_org_user_by_uuid(&self, uuid: Uuid) -> Result { + debug!("Getting org user by UUID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgUser::get_by_uuid(conn, uuid)?.ok_or(DBError::OrgUserNotFound); + if let Err(ref e) = result { + error!("Failed to get org user by UUID: {:?}", e); + } + result + } + + fn get_org_user_by_email_and_org( + &self, + email: &str, + org_id: i32, + ) -> Result, DBError> { + debug!("Getting org user by email and org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgUser::get_by_email_and_org(conn, email, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get org user by email and org: {:?}", e); + } + result + } + + fn get_all_org_users_for_org(&self, org_id: i32) -> Result, DBError> { + debug!("Getting all org users for org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgUser::get_all_for_org(conn, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get all org users for org: {:?}", e); + } + result + } + + fn update_org_user(&self, user: &OrgUser) -> Result<(), DBError> { + debug!("Updating org user"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = user.update(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update org user: {:?}", e); + } + result + } + + fn update_org_user_password( + &self, + user: &OrgUser, + new_password_enc: Option>, + ) -> Result<(), DBError> { + debug!("Updating org user password"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = user + .update_password(conn, new_password_enc) + .map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update org user password: {:?}", e); + } + result + } + + fn delete_org_user(&self, user: &OrgUser) -> Result<(), DBError> { + debug!("Deleting org user"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = user.delete(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to delete org user: {:?}", e); + } + result + } + + // Org project implementations + fn create_org_project(&self, new_project: NewOrgProject) -> Result { + debug!("Creating new org project"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = new_project.insert(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to create org project: {:?}", e); + } + result + } + + fn get_org_project_by_id(&self, id: i32) -> Result { + debug!("Getting org project by ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProject::get_by_id(conn, id)?.ok_or(DBError::OrgProjectNotFound); + if let Err(ref e) = result { + error!("Failed to get org project by ID: {:?}", e); + } + result + } + + fn get_org_project_by_uuid(&self, uuid: Uuid) -> Result { + debug!("Getting org project by UUID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProject::get_by_uuid(conn, uuid)?.ok_or(DBError::OrgProjectNotFound); + if let Err(ref e) = result { + error!("Failed to get org project by UUID: {:?}", e); + } + result + } + + fn get_org_project_by_client_id(&self, client_id: Uuid) -> Result { + debug!("Getting org project by client ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = + OrgProject::get_by_client_id(conn, client_id)?.ok_or(DBError::OrgProjectNotFound); + if let Err(ref e) = result { + error!("Failed to get org project by client ID: {:?}", e); + } + result + } + + fn get_org_project_by_name_and_org( + &self, + name: &str, + org_id: i32, + ) -> Result, DBError> { + debug!("Getting org project by name and org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProject::get_by_name_and_org(conn, name, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get org project by name and org: {:?}", e); + } + result + } + + fn get_all_org_projects_for_org(&self, org_id: i32) -> Result, DBError> { + debug!("Getting all org projects for org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProject::get_all_for_org(conn, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get all org projects for org: {:?}", e); + } + result + } + + fn get_active_org_projects_for_org(&self, org_id: i32) -> Result, DBError> { + debug!("Getting active org projects for org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProject::get_active_for_org(conn, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get active org projects for org: {:?}", e); + } + result + } + + fn update_org_project(&self, project: &OrgProject) -> Result<(), DBError> { + debug!("Updating org project"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = project.update(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update org project: {:?}", e); + } + result + } + + fn delete_org_project(&self, project: &OrgProject) -> Result<(), DBError> { + debug!("Deleting org project"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = project.delete(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to delete org project: {:?}", e); + } + result + } + + // Org project secret implementations + fn create_org_project_secret( + &self, + new_secret: NewOrgProjectSecret, + ) -> Result { + debug!("Creating new org project secret"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = new_secret.insert(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to create org project secret: {:?}", e); + } + result + } + + fn get_org_project_secret_by_id(&self, id: i32) -> Result { + debug!("Getting org project secret by ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = + OrgProjectSecret::get_by_id(conn, id)?.ok_or(DBError::OrgProjectSecretNotFound); + if let Err(ref e) = result { + error!("Failed to get org project secret by ID: {:?}", e); + } + result + } + + fn get_org_project_secret_by_key_name_and_project( + &self, + key_name: &str, + project_id: i32, + ) -> Result, DBError> { + debug!("Getting org project secret by key name and project"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProjectSecret::get_by_key_name_and_project(conn, key_name, project_id) + .map_err(DBError::from); + if let Err(ref e) = result { + error!( + "Failed to get org project secret by key name and project: {:?}", + e + ); + } + result + } + + fn get_all_org_project_secrets_for_project( + &self, + project_id: i32, + ) -> Result, DBError> { + debug!("Getting all org project secrets for project"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = OrgProjectSecret::get_all_for_project(conn, project_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get all org project secrets for project: {:?}", e); + } + result + } + + fn update_org_project_secret(&self, secret: &OrgProjectSecret) -> Result<(), DBError> { + debug!("Updating org project secret"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = secret.update(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update org project secret: {:?}", e); + } + result + } + + fn delete_org_project_secret(&self, secret: &OrgProjectSecret) -> Result<(), DBError> { + debug!("Deleting org project secret"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = secret.delete(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to delete org project secret: {:?}", e); + } + result + } + + // Invite code implementations + fn create_invite_code(&self, new_invite: NewInviteCode) -> Result { + debug!("Creating new invite code"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = new_invite.insert(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to create invite code: {:?}", e); + } + result + } + + fn get_invite_code_by_id(&self, id: i32) -> Result { + debug!("Getting invite code by ID"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = InviteCode::get_by_id(conn, id)?.ok_or(DBError::InviteCodeNotFound); + if let Err(ref e) = result { + error!("Failed to get invite code by ID: {:?}", e); + } + result + } + + fn get_invite_code_by_code(&self, code: Uuid) -> Result { + debug!("Getting invite code by code"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = InviteCode::get_by_code(conn, code)?.ok_or(DBError::InviteCodeNotFound); + if let Err(ref e) = result { + error!("Failed to get invite code by code: {:?}", e); + } + result + } + + fn get_invite_code_by_email_and_org( + &self, + email: &str, + org_id: i32, + ) -> Result, DBError> { + debug!("Getting invite code by email and org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = InviteCode::get_by_email_and_org(conn, email, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get invite code by email and org: {:?}", e); + } + result + } + + fn get_all_invite_codes_for_org(&self, org_id: i32) -> Result, DBError> { + debug!("Getting all invite codes for org"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = InviteCode::get_all_for_org(conn, org_id).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to get all invite codes for org: {:?}", e); + } + result + } + + fn mark_invite_code_as_used(&self, invite: &InviteCode) -> Result<(), DBError> { + debug!("Marking invite code as used"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = invite.mark_as_used(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to mark invite code as used: {:?}", e); + } + result + } + + fn update_invite_code(&self, invite: &InviteCode) -> Result<(), DBError> { + debug!("Updating invite code"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = invite.update(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to update invite code: {:?}", e); + } + result + } + + fn delete_invite_code(&self, invite: &InviteCode) -> Result<(), DBError> { + debug!("Deleting invite code"); + let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?; + let result = invite.delete(conn).map_err(DBError::from); + if let Err(ref e) = result { + error!("Failed to delete invite code: {:?}", e); + } + result + } } pub(crate) fn setup_db(url: String) -> Arc { diff --git a/src/email.rs b/src/email.rs index 455c915..c8223e5 100644 --- a/src/email.rs +++ b/src/email.rs @@ -3,6 +3,7 @@ use chrono::{Duration, Utc}; use resend_rs::types::CreateEmailBaseOptions; use resend_rs::{Resend, Result}; use tracing::error; +use uuid::Uuid; #[derive(Debug, thiserror::Error)] pub enum EmailError { @@ -293,6 +294,82 @@ pub async fn send_password_reset_confirmation_email( Ok(()) } +pub async fn send_invite_email( + app_mode: AppMode, + resend_api_key: Option, + to_email: String, + invite_code: Uuid, +) -> Result<(), EmailError> { + tracing::debug!("Entering send_invite_email"); + if resend_api_key.is_none() { + return Err(EmailError::ApiKeyNotFound); + } + let api_key = resend_api_key.expect("just checked"); + + let resend = Resend::new(&api_key); + + let from = from_email(app_mode.clone()); + let to = [to_email]; + let subject = "You've Been Invited to Join an Organization on Maple AI"; + + let base_url = match app_mode { + AppMode::Local => "http://localhost:5173", + AppMode::Dev => "https://dev.secretgpt.ai", + AppMode::Preview => "https://opensecret.cloud", + AppMode::Prod => "https://trymaple.ai", + AppMode::Custom(_) => "https://preview.opensecret.cloud", + }; + + let invite_url = format!("{}/join/{}", base_url, invite_code); + + let html_content = format!( + r#" + + + + + + Organization Invitation - Maple AI + + + +
+

You've Been Invited!

+

You've been invited to join an organization on Maple AI. To accept this invitation, please click the button below:

+

+ Accept Invitation +

+

If the button doesn't work, you can copy and paste the following link into your browser:

+

{}

+

Alternatively, you can use the following invitation code:

+

{}

+

This invitation link and code will expire in 24 hours.

+

If you weren't expecting this invitation, you can safely ignore this email.

+

Best regards,
The OpenSecret Team

+
+ + + "#, + invite_url, invite_url, invite_code + ); + + let email = CreateEmailBaseOptions::new(from, to, subject).with_html(&html_content); + + let _email = resend.emails.send(email).await.map_err(|e| { + tracing::error!("Failed to send email: {}", e); + EmailError::UnknownError + }); + + tracing::debug!("Exiting send_invite_email"); + Ok(()) +} + fn from_email(app_mode: AppMode) -> String { match app_mode { AppMode::Local => "local@email.trymaple.ai".to_string(), diff --git a/src/jwt.rs b/src/jwt.rs index 16350c1..c56a4fc 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -18,9 +18,12 @@ use sha2::Sha256; use uuid::Uuid; use crate::AppMode; -use crate::{ApiError, AppState, User}; +use crate::{ApiError, AppState}; use url::Url; +use crate::models::{org_users::OrgUser, users::User}; + +#[derive(Debug, Clone)] pub enum TokenType { Access, Refresh, @@ -151,6 +154,62 @@ impl NewToken { token: token_string, }) } + + pub fn new_for_org_user( + user: &OrgUser, + token_type: TokenType, + app_state: &AppState, + ) -> Result { + let (aud, azp, duration) = match token_type { + TokenType::Access => ( + "org_access".to_string(), + None, + Duration::minutes(app_state.config.access_token_maxage), + ), + TokenType::Refresh => ( + "org_refresh".to_string(), + None, + Duration::days(app_state.config.refresh_token_maxage), + ), + TokenType::ThirdParty { aud, azp } => { + // Validate the audience URL against allowed domains + TokenType::validate_third_party_audience(&aud, &app_state.app_mode)?; + + // For now, enforce that azp must be "maple" + if azp != "maple" { + return Err(ApiError::BadRequest); + } + (aud, Some(azp), Duration::hours(1)) + } + }; + + let custom_claims = CustomClaims { + sub: user.uuid.to_string(), + aud, + azp, + }; + + tracing::debug!("Creating new org token with claims: {:?}", custom_claims); + + let time_options = TimeOptions::default(); + let claims = Claims::new(custom_claims).set_duration_and_issuance(&time_options, duration); + + let header = Header::empty().with_token_type("JWT"); + let es256k = Es256k::::new(app_state.config.jwt_keys.secp.clone()); + + let token_string = es256k + .token(&header, &claims, &app_state.config.jwt_keys.signing_key) + .map_err(|e| { + tracing::error!("Error creating token: {:?}", e); + ApiError::InternalServerError + })?; + + tracing::debug!("Successfully created org token"); + + Ok(Self { + token: token_string, + }) + } } pub async fn generate_jwt_secret( diff --git a/src/main.rs b/src/main.rs index 1302acf..a6ad77f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1860,6 +1860,7 @@ async fn main() -> Result<(), Error> { ) .merge(attestation_routes::router(app_state.clone())) .merge(oauth_routes(app_state.clone())) + .merge(web::org_login_routes(app_state.clone())) .layer(cors); let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") diff --git a/src/models/invite_codes.rs b/src/models/invite_codes.rs new file mode 100644 index 0000000..d3acb7a --- /dev/null +++ b/src/models/invite_codes.rs @@ -0,0 +1,147 @@ +use crate::models::schema::invite_codes; +use chrono::{DateTime, Duration, Utc}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Error, Debug)] +pub enum InviteCodeError { + #[error("Database error: {0}")] + DatabaseError(#[from] diesel::result::Error), + #[error("Invite code expired")] + Expired, + #[error("Invite code already used")] + AlreadyUsed, +} + +#[derive(Queryable, Identifiable, AsChangeset, Serialize, Deserialize, Clone, Debug)] +#[diesel(table_name = invite_codes)] +pub struct InviteCode { + pub id: i32, + pub code: Uuid, + pub org_id: i32, + pub email: String, + pub used: bool, + pub expires_at: DateTime, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl InviteCode { + pub fn get_by_id( + conn: &mut PgConnection, + lookup_id: i32, + ) -> Result, InviteCodeError> { + invite_codes::table + .filter(invite_codes::id.eq(lookup_id)) + .first::(conn) + .optional() + .map_err(InviteCodeError::DatabaseError) + } + + pub fn get_by_code( + conn: &mut PgConnection, + lookup_code: Uuid, + ) -> Result, InviteCodeError> { + invite_codes::table + .filter(invite_codes::code.eq(lookup_code)) + .first::(conn) + .optional() + .map_err(InviteCodeError::DatabaseError) + } + + pub fn get_by_email_and_org( + conn: &mut PgConnection, + lookup_email: &str, + lookup_org_id: i32, + ) -> Result, InviteCodeError> { + invite_codes::table + .filter(invite_codes::email.eq(lookup_email)) + .filter(invite_codes::org_id.eq(lookup_org_id)) + .filter(invite_codes::used.eq(false)) + .filter(invite_codes::expires_at.gt(diesel::dsl::now)) + .first::(conn) + .optional() + .map_err(InviteCodeError::DatabaseError) + } + + pub fn get_all_for_org( + conn: &mut PgConnection, + lookup_org_id: i32, + ) -> Result, InviteCodeError> { + invite_codes::table + .filter(invite_codes::org_id.eq(lookup_org_id)) + .load::(conn) + .map_err(InviteCodeError::DatabaseError) + } + + pub fn mark_as_used(&self, conn: &mut PgConnection) -> Result<(), InviteCodeError> { + if self.used { + return Err(InviteCodeError::AlreadyUsed); + } + + if self.expires_at < Utc::now() { + return Err(InviteCodeError::Expired); + } + + diesel::update(invite_codes::table) + .filter(invite_codes::id.eq(self.id)) + .set(( + invite_codes::used.eq(true), + invite_codes::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(InviteCodeError::DatabaseError) + } + + pub fn update(&self, conn: &mut PgConnection) -> Result<(), InviteCodeError> { + diesel::update(invite_codes::table) + .filter(invite_codes::id.eq(self.id)) + .set(( + invite_codes::email.eq(&self.email), + invite_codes::used.eq(self.used), + invite_codes::expires_at.eq(self.expires_at), + invite_codes::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(InviteCodeError::DatabaseError) + } + + pub fn delete(&self, conn: &mut PgConnection) -> Result<(), InviteCodeError> { + diesel::delete(invite_codes::table) + .filter(invite_codes::id.eq(self.id)) + .execute(conn) + .map(|_| ()) + .map_err(InviteCodeError::DatabaseError) + } +} + +#[derive(Insertable)] +#[diesel(table_name = invite_codes)] +pub struct NewInviteCode { + pub code: Uuid, + pub org_id: i32, + pub email: String, + pub expires_at: DateTime, +} + +impl NewInviteCode { + pub fn new(org_id: i32, email: String, expiry_hours: i64) -> Self { + NewInviteCode { + code: Uuid::new_v4(), + org_id, + email, + expires_at: Utc::now() + Duration::hours(expiry_hours), + } + } + + pub fn insert(&self, conn: &mut PgConnection) -> Result { + diesel::insert_into(invite_codes::table) + .values(self) + .get_result::(conn) + .map_err(InviteCodeError::DatabaseError) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs index ff58eaa..da06225 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,6 +1,11 @@ pub mod email_verification; pub mod enclave_secrets; +pub mod invite_codes; pub mod oauth; +pub mod org_project_secrets; +pub mod org_projects; +pub mod org_users; +pub mod orgs; pub mod password_reset; mod schema; pub mod token_usage; diff --git a/src/models/org_project_secrets.rs b/src/models/org_project_secrets.rs new file mode 100644 index 0000000..e1d7cf8 --- /dev/null +++ b/src/models/org_project_secrets.rs @@ -0,0 +1,113 @@ +use crate::models::schema::org_project_secrets; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum OrgProjectSecretError { + #[error("Database error: {0}")] + DatabaseError(#[from] diesel::result::Error), +} + +#[derive(Queryable, Identifiable, AsChangeset, Serialize, Deserialize, Clone, Debug)] +#[diesel(table_name = org_project_secrets)] +pub struct OrgProjectSecret { + pub id: i32, + pub project_id: i32, + pub key_name: String, + pub secret_enc: Vec, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl OrgProjectSecret { + pub fn get_by_id( + conn: &mut PgConnection, + lookup_id: i32, + ) -> Result, OrgProjectSecretError> { + org_project_secrets::table + .filter(org_project_secrets::id.eq(lookup_id)) + .first::(conn) + .optional() + .map_err(OrgProjectSecretError::DatabaseError) + } + + pub fn get_by_key_name_and_project( + conn: &mut PgConnection, + lookup_key_name: &str, + lookup_project_id: i32, + ) -> Result, OrgProjectSecretError> { + org_project_secrets::table + .filter(org_project_secrets::key_name.eq(lookup_key_name)) + .filter(org_project_secrets::project_id.eq(lookup_project_id)) + .first::(conn) + .optional() + .map_err(OrgProjectSecretError::DatabaseError) + } + + pub fn get_all_for_project( + conn: &mut PgConnection, + lookup_project_id: i32, + ) -> Result, OrgProjectSecretError> { + org_project_secrets::table + .filter(org_project_secrets::project_id.eq(lookup_project_id)) + .load::(conn) + .map_err(OrgProjectSecretError::DatabaseError) + } + + pub fn update(&self, conn: &mut PgConnection) -> Result<(), OrgProjectSecretError> { + diesel::update(org_project_secrets::table) + .filter(org_project_secrets::id.eq(self.id)) + .set(( + org_project_secrets::key_name.eq(&self.key_name), + org_project_secrets::secret_enc.eq(&self.secret_enc), + org_project_secrets::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(OrgProjectSecretError::DatabaseError) + } + + pub fn delete(&self, conn: &mut PgConnection) -> Result<(), OrgProjectSecretError> { + diesel::delete(org_project_secrets::table) + .filter(org_project_secrets::id.eq(self.id)) + .execute(conn) + .map(|_| ()) + .map_err(OrgProjectSecretError::DatabaseError) + } +} + +#[derive(Insertable)] +#[diesel(table_name = org_project_secrets)] +pub struct NewOrgProjectSecret { + pub project_id: i32, + pub key_name: String, + pub secret_enc: Vec, +} + +impl NewOrgProjectSecret { + pub fn new(project_id: i32, key_name: String, secret_enc: Vec) -> Self { + NewOrgProjectSecret { + project_id, + key_name, + secret_enc, + } + } + + pub fn insert( + &self, + conn: &mut PgConnection, + ) -> Result { + diesel::insert_into(org_project_secrets::table) + .values(self) + .on_conflict(( + org_project_secrets::project_id, + org_project_secrets::key_name, + )) + .do_update() + .set(org_project_secrets::secret_enc.eq(&self.secret_enc)) + .get_result::(conn) + .map_err(OrgProjectSecretError::DatabaseError) + } +} diff --git a/src/models/org_projects.rs b/src/models/org_projects.rs new file mode 100644 index 0000000..ae4ae5f --- /dev/null +++ b/src/models/org_projects.rs @@ -0,0 +1,154 @@ +use crate::models::schema::org_projects; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Error, Debug)] +pub enum OrgProjectError { + #[error("Database error: {0}")] + DatabaseError(#[from] diesel::result::Error), +} + +#[derive(Queryable, Identifiable, AsChangeset, Serialize, Deserialize, Clone, Debug)] +#[diesel(table_name = org_projects)] +pub struct OrgProject { + pub id: i32, + pub uuid: Uuid, + pub client_id: Uuid, + pub org_id: i32, + pub name: String, + pub description: Option, + pub status: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl OrgProject { + pub fn get_by_id( + conn: &mut PgConnection, + lookup_id: i32, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::id.eq(lookup_id)) + .first::(conn) + .optional() + .map_err(OrgProjectError::DatabaseError) + } + + pub fn get_by_uuid( + conn: &mut PgConnection, + lookup_uuid: Uuid, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::uuid.eq(lookup_uuid)) + .first::(conn) + .optional() + .map_err(OrgProjectError::DatabaseError) + } + + pub fn get_by_client_id( + conn: &mut PgConnection, + lookup_client_id: Uuid, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::client_id.eq(lookup_client_id)) + .first::(conn) + .optional() + .map_err(OrgProjectError::DatabaseError) + } + + pub fn get_by_name_and_org( + conn: &mut PgConnection, + lookup_name: &str, + lookup_org_id: i32, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::name.eq(lookup_name)) + .filter(org_projects::org_id.eq(lookup_org_id)) + .first::(conn) + .optional() + .map_err(OrgProjectError::DatabaseError) + } + + pub fn get_all_for_org( + conn: &mut PgConnection, + lookup_org_id: i32, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::org_id.eq(lookup_org_id)) + .load::(conn) + .map_err(OrgProjectError::DatabaseError) + } + + pub fn get_active_for_org( + conn: &mut PgConnection, + lookup_org_id: i32, + ) -> Result, OrgProjectError> { + org_projects::table + .filter(org_projects::org_id.eq(lookup_org_id)) + .filter(org_projects::status.eq("active")) + .load::(conn) + .map_err(OrgProjectError::DatabaseError) + } + + pub fn update(&self, conn: &mut PgConnection) -> Result<(), OrgProjectError> { + diesel::update(org_projects::table) + .filter(org_projects::id.eq(self.id)) + .set(( + org_projects::name.eq(&self.name), + org_projects::description.eq(&self.description), + org_projects::status.eq(&self.status), + org_projects::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(OrgProjectError::DatabaseError) + } + + pub fn delete(&self, conn: &mut PgConnection) -> Result<(), OrgProjectError> { + diesel::delete(org_projects::table) + .filter(org_projects::id.eq(self.id)) + .execute(conn) + .map(|_| ()) + .map_err(OrgProjectError::DatabaseError) + } +} + +#[derive(Insertable)] +#[diesel(table_name = org_projects)] +pub struct NewOrgProject { + pub org_id: i32, + pub name: String, + pub description: Option, + pub status: String, +} + +impl NewOrgProject { + pub fn new(org_id: i32, name: String) -> Self { + NewOrgProject { + org_id, + name, + description: None, + status: "active".to_string(), + } + } + + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + pub fn with_status(mut self, status: String) -> Self { + self.status = status; + self + } + + pub fn insert(&self, conn: &mut PgConnection) -> Result { + diesel::insert_into(org_projects::table) + .values(self) + .get_result::(conn) + .map_err(OrgProjectError::DatabaseError) + } +} diff --git a/src/models/org_users.rs b/src/models/org_users.rs new file mode 100644 index 0000000..48c6b75 --- /dev/null +++ b/src/models/org_users.rs @@ -0,0 +1,190 @@ +use crate::models::schema::org_users; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Error, Debug)] +pub enum OrgUserError { + #[error("Database error: {0}")] + DatabaseError(#[from] diesel::result::Error), + #[error("Email already exists in organization")] + DuplicateUser, + #[error("ValidationError: {0}")] + ValidationError(String), +} + +/// Represents a user within an organization. +/// Each user belongs to exactly one organization and has organization-specific credentials. +#[derive(Queryable, Identifiable, AsChangeset, Serialize, Deserialize, Clone, Debug)] +#[diesel(table_name = org_users)] +pub struct OrgUser { + /// Internal database identifier + pub id: i32, + /// External unique identifier for API operations + pub uuid: Uuid, + /// Reference to the organization this user belongs to + pub org_id: i32, + /// User's email address, unique within an organization + pub email: String, + /// Optional display name for the user + pub name: Option, + /// Encrypted password bytes. None for OAuth-only users + #[serde(skip_serializing)] + password_enc: Option>, + /// Timestamp when the user was created + pub created_at: DateTime, + /// Timestamp when the user was last updated + pub updated_at: DateTime, +} + +impl OrgUser { + /// Returns the encrypted password bytes if they exist + pub fn password_enc(&self) -> Option<&Vec> { + self.password_enc.as_ref() + } + + /// Sets the encrypted password bytes + pub fn set_password_enc(&mut self, password: Option>) { + self.password_enc = password; + } + + pub fn get_by_id( + conn: &mut PgConnection, + lookup_id: i32, + ) -> Result, OrgUserError> { + org_users::table + .filter(org_users::id.eq(lookup_id)) + .first::(conn) + .optional() + .map_err(OrgUserError::DatabaseError) + } + + pub fn get_by_uuid( + conn: &mut PgConnection, + lookup_uuid: Uuid, + ) -> Result, OrgUserError> { + org_users::table + .filter(org_users::uuid.eq(lookup_uuid)) + .first::(conn) + .optional() + .map_err(OrgUserError::DatabaseError) + } + + pub fn get_by_email_and_org( + conn: &mut PgConnection, + lookup_email: &str, + lookup_org_id: i32, + ) -> Result, OrgUserError> { + org_users::table + .filter(org_users::email.eq(lookup_email)) + .filter(org_users::org_id.eq(lookup_org_id)) + .first::(conn) + .optional() + .map_err(OrgUserError::DatabaseError) + } + + pub fn get_all_for_org( + conn: &mut PgConnection, + lookup_org_id: i32, + ) -> Result, OrgUserError> { + org_users::table + .filter(org_users::org_id.eq(lookup_org_id)) + .load::(conn) + .map_err(OrgUserError::DatabaseError) + } + + pub fn update_password( + &self, + conn: &mut PgConnection, + new_password_enc: Option>, + ) -> Result<(), OrgUserError> { + diesel::update(org_users::table) + .filter(org_users::id.eq(self.id)) + .set(org_users::password_enc.eq(new_password_enc)) + .execute(conn) + .map(|_| ()) + .map_err(OrgUserError::DatabaseError) + } + + pub fn update(&self, conn: &mut PgConnection) -> Result<(), OrgUserError> { + // Check if email is already used by another user in the same org + if let Some(existing_user) = Self::get_by_email_and_org(conn, &self.email, self.org_id)? { + if existing_user.id != self.id { + return Err(OrgUserError::DuplicateUser); + } + } + + // Validate email is not empty + if self.email.trim().is_empty() { + return Err(OrgUserError::ValidationError( + "Email cannot be empty".to_string(), + )); + } + + diesel::update(org_users::table) + .filter(org_users::id.eq(self.id)) + .set(( + org_users::email.eq(&self.email), + org_users::name.eq(&self.name), + org_users::password_enc.eq(&self.password_enc), + org_users::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(OrgUserError::DatabaseError) + } + + pub fn delete(&self, conn: &mut PgConnection) -> Result<(), OrgUserError> { + diesel::delete(org_users::table) + .filter(org_users::id.eq(self.id)) + .execute(conn) + .map(|_| ()) + .map_err(OrgUserError::DatabaseError) + } +} + +#[derive(Insertable)] +#[diesel(table_name = org_users)] +pub struct NewOrgUser { + pub org_id: i32, + pub email: String, + pub name: Option, + password_enc: Option>, +} + +impl NewOrgUser { + pub fn new(org_id: i32, email: String, password_enc: Option>) -> Self { + NewOrgUser { + org_id, + email, + name: None, + password_enc, + } + } + + pub fn with_name(mut self, name: String) -> Self { + self.name = Some(name); + self + } + + pub fn insert(&self, conn: &mut PgConnection) -> Result { + // Validate email is not empty + if self.email.trim().is_empty() { + return Err(OrgUserError::ValidationError( + "Email cannot be empty".to_string(), + )); + } + + // Check if email is already used in the org + if let Some(_) = OrgUser::get_by_email_and_org(conn, &self.email, self.org_id)? { + return Err(OrgUserError::DuplicateUser); + } + + diesel::insert_into(org_users::table) + .values(self) + .get_result::(conn) + .map_err(OrgUserError::DatabaseError) + } +} diff --git a/src/models/orgs.rs b/src/models/orgs.rs new file mode 100644 index 0000000..7a7fee6 --- /dev/null +++ b/src/models/orgs.rs @@ -0,0 +1,99 @@ +use crate::models::schema::orgs; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use uuid::Uuid; + +#[derive(Error, Debug)] +pub enum OrgError { + #[error("Database error: {0}")] + DatabaseError(#[from] diesel::result::Error), +} + +#[derive(Queryable, Identifiable, AsChangeset, Serialize, Deserialize, Clone, Debug)] +#[diesel(table_name = orgs)] +pub struct Org { + pub id: i32, + pub uuid: Uuid, + pub name: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl Org { + pub fn get_by_id(conn: &mut PgConnection, lookup_id: i32) -> Result, OrgError> { + orgs::table + .filter(orgs::id.eq(lookup_id)) + .first::(conn) + .optional() + .map_err(OrgError::DatabaseError) + } + + pub fn get_by_uuid( + conn: &mut PgConnection, + lookup_uuid: Uuid, + ) -> Result, OrgError> { + orgs::table + .filter(orgs::uuid.eq(lookup_uuid)) + .first::(conn) + .optional() + .map_err(OrgError::DatabaseError) + } + + pub fn get_by_name( + conn: &mut PgConnection, + lookup_name: &str, + ) -> Result, OrgError> { + orgs::table + .filter(orgs::name.eq(lookup_name)) + .first::(conn) + .optional() + .map_err(OrgError::DatabaseError) + } + + pub fn get_all(conn: &mut PgConnection) -> Result, OrgError> { + orgs::table + .load::(conn) + .map_err(OrgError::DatabaseError) + } + + pub fn update(&self, conn: &mut PgConnection) -> Result<(), OrgError> { + diesel::update(orgs::table) + .filter(orgs::id.eq(self.id)) + .set(( + orgs::name.eq(&self.name), + orgs::updated_at.eq(diesel::dsl::now), + )) + .execute(conn) + .map(|_| ()) + .map_err(OrgError::DatabaseError) + } + + pub fn delete(&self, conn: &mut PgConnection) -> Result<(), OrgError> { + diesel::delete(orgs::table) + .filter(orgs::id.eq(self.id)) + .execute(conn) + .map(|_| ()) + .map_err(OrgError::DatabaseError) + } +} + +#[derive(Insertable)] +#[diesel(table_name = orgs)] +pub struct NewOrg { + pub name: String, +} + +impl NewOrg { + pub fn new(name: String) -> Self { + NewOrg { name } + } + + pub fn insert(&self, conn: &mut PgConnection) -> Result { + diesel::insert_into(orgs::table) + .values(self) + .get_result::(conn) + .map_err(OrgError::DatabaseError) + } +} diff --git a/src/models/schema.rs b/src/models/schema.rs index 73162dd..7ef4647 100644 --- a/src/models/schema.rs +++ b/src/models/schema.rs @@ -20,6 +20,19 @@ diesel::table! { } } +diesel::table! { + invite_codes (id) { + id -> Int4, + code -> Uuid, + org_id -> Int4, + email -> Text, + used -> Bool, + expires_at -> Timestamptz, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + diesel::table! { oauth_providers (id) { id -> Int4, @@ -33,6 +46,54 @@ diesel::table! { } } +diesel::table! { + org_project_secrets (id) { + id -> Int4, + project_id -> Int4, + key_name -> Text, + secret_enc -> Bytea, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + +diesel::table! { + org_projects (id) { + id -> Int4, + uuid -> Uuid, + client_id -> Uuid, + org_id -> Int4, + name -> Text, + description -> Nullable, + status -> Text, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + +diesel::table! { + org_users (id) { + id -> Int4, + uuid -> Uuid, + org_id -> Int4, + email -> Text, + name -> Nullable, + password_enc -> Nullable, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + +diesel::table! { + orgs (id) { + id -> Int4, + uuid -> Uuid, + name -> Text, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + diesel::table! { password_reset_requests (id) { id -> Int4, @@ -96,12 +157,21 @@ diesel::table! { } } +diesel::joinable!(invite_codes -> orgs (org_id)); +diesel::joinable!(org_project_secrets -> org_projects (project_id)); +diesel::joinable!(org_projects -> orgs (org_id)); +diesel::joinable!(org_users -> orgs (org_id)); diesel::joinable!(user_oauth_connections -> oauth_providers (provider_id)); diesel::allow_tables_to_appear_in_same_query!( email_verifications, enclave_secrets, + invite_codes, oauth_providers, + org_project_secrets, + org_projects, + org_users, + orgs, password_reset_requests, token_usage, user_kv, diff --git a/src/web.rs b/src/web.rs index 47e1f59..39aaa97 100644 --- a/src/web.rs +++ b/src/web.rs @@ -4,10 +4,12 @@ mod health_routes; mod login_routes; mod oauth_routes; mod openai; +mod org_routes; mod protected_routes; pub use health_routes::router as health_routes; pub use login_routes::router as login_routes; pub use oauth_routes::router as oauth_routes; pub use openai::router as openai_routes; +pub use org_routes::login_router as org_login_routes; pub use protected_routes::router as protected_routes; diff --git a/src/web/org_routes.rs b/src/web/org_routes.rs new file mode 100644 index 0000000..72fdbb0 --- /dev/null +++ b/src/web/org_routes.rs @@ -0,0 +1,329 @@ +use crate::{ + db::DBError, + email::{send_invite_email, send_verification_email}, + jwt::{NewToken, TokenType}, + models::{ + email_verification::NewEmailVerification, + invite_codes::NewInviteCode, + org_users::{NewOrgUser, OrgUser}, + orgs::NewOrg, + }, + web::encryption_middleware::{decrypt_request, encrypt_response, EncryptedResponse}, + ApiError, AppState, +}; +use axum::{ + extract::State, middleware::from_fn_with_state, routing::post, Extension, Json, Router, +}; +use password_auth::generate_hash; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::spawn; +use tracing::{debug, error}; +use uuid::Uuid; + +#[derive(Deserialize, Clone)] +pub struct RegisterOrgRequest { + pub org_name: String, + pub email: String, + pub password: String, + pub user_name: Option, +} + +#[derive(Serialize)] +pub struct RegisterOrgResponse { + pub org_id: Uuid, + pub user_id: Uuid, + pub email: String, + pub access_token: String, + pub refresh_token: String, +} + +#[derive(Deserialize, Clone)] +pub struct InviteOrgUserRequest { + pub email: String, +} + +#[derive(Serialize)] +pub struct InviteOrgUserResponse { + pub invite_code: String, +} + +#[derive(Deserialize, Clone)] +pub struct AcceptInviteRequest { + pub invite_code: Uuid, + pub password: String, + pub name: String, +} + +#[derive(Serialize)] +pub struct AcceptInviteResponse { + pub org_id: Uuid, + pub user_id: Uuid, + pub email: String, + pub access_token: String, + pub refresh_token: String, +} + +pub fn login_router(app_state: Arc) -> Router<()> { + Router::new() + .route( + "/register", + post(register_org).layer(from_fn_with_state( + app_state.clone(), + decrypt_request::, + )), + ) + .route( + "/invite", + post(invite_org_user).layer(from_fn_with_state( + app_state.clone(), + decrypt_request::, + )), + ) + .route( + "/accept-invite", + post(accept_invite).layer(from_fn_with_state( + app_state.clone(), + decrypt_request::, + )), + ) + .with_state(app_state) +} + +pub async fn register_org( + State(data): State>, + Extension(register_request): Extension, + Extension(session_id): Extension, +) -> Result>, ApiError> { + debug!("Entering register_org function"); + + // First check if org name is taken + if let Ok(Some(_)) = data.db.get_org_by_name(®ister_request.org_name) { + return Err(ApiError::BadRequest); + } + + // Create the organization + let new_org = NewOrg::new(register_request.org_name); + let org = data.db.create_org(new_org).map_err(|e| { + error!("Failed to create organization: {:?}", e); + ApiError::InternalServerError + })?; + + // Hash and encrypt the password + let password_hash = generate_hash(register_request.password); + let secret_key = secp256k1::SecretKey::from_slice(&data.enclave_key) + .map_err(|_| ApiError::InternalServerError)?; + let encrypted_password = + crate::encrypt::encrypt_with_key(&secret_key, password_hash.as_bytes()).await; + + // Create the org user + let new_org_user = NewOrgUser::new( + org.id, + register_request.email.clone(), + Some(encrypted_password), + ) + .with_name(register_request.user_name.unwrap_or_default()); + + let org_user = match data.db.create_org_user(new_org_user) { + Ok(user) => user, + Err(DBError::OrgUserError(crate::models::org_users::OrgUserError::DuplicateUser)) => { + return Err(ApiError::EmailAlreadyExists); + } + Err(e) => { + error!("Failed to create org user: {:?}", e); + return Err(ApiError::InternalServerError); + } + }; + + // Create email verification + let new_verification = NewEmailVerification::new(org_user.uuid, 24, false); + let verification = match data.db.create_email_verification(new_verification) { + Ok(v) => v, + Err(e) => { + error!("Error creating email verification: {:?}", e); + return Err(ApiError::InternalServerError); + } + }; + + // Send verification email in background + let email = register_request.email.clone(); + let verification_code = verification.verification_code; + let app_mode = data.app_mode.clone(); + let resend_api_key = data.resend_api_key.clone(); + spawn(async move { + if let Err(e) = + send_verification_email(app_mode, resend_api_key, email, verification_code).await + { + error!("Could not send verification email: {}", e); + } + }); + + // Generate tokens + let access_token = NewToken::new_for_org_user(&org_user, TokenType::Access, &data)?; + let refresh_token = NewToken::new_for_org_user(&org_user, TokenType::Refresh, &data)?; + + let response = RegisterOrgResponse { + org_id: org.uuid, + user_id: org_user.uuid, + email: register_request.email, + access_token: access_token.token, + refresh_token: refresh_token.token, + }; + + let result = encrypt_response(&data, &session_id, &response).await; + debug!("Exiting register_org function"); + result +} + +pub async fn invite_org_user( + State(data): State>, + Extension(invite_request): Extension, + Extension(session_id): Extension, + Extension(current_user): Extension, +) -> Result>, ApiError> { + debug!("Entering invite_org_user function"); + + // Check if user already exists in the organization + if let Ok(Some(_)) = data + .db + .get_org_user_by_email_and_org(&invite_request.email, current_user.org_id) + { + return Err(ApiError::EmailAlreadyExists); + } + + // Generate a new invite code + let new_invite = NewInviteCode::new( + current_user.org_id, + invite_request.email.clone(), + 24, // 24 hour expiry + ); + + let invite = data.db.create_invite_code(new_invite).map_err(|e| { + error!("Failed to create invite code: {:?}", e); + ApiError::InternalServerError + })?; + + // Send invite email in background + let email = invite_request.email.clone(); + let invite_code = invite.code; + let app_mode = data.app_mode.clone(); + let resend_api_key = data.resend_api_key.clone(); + spawn(async move { + if let Err(e) = send_invite_email(app_mode, resend_api_key, email, invite_code).await { + error!("Could not send invite email: {}", e); + } + }); + + let response = InviteOrgUserResponse { + invite_code: invite.code.to_string(), + }; + + let result = encrypt_response(&data, &session_id, &response).await; + debug!("Exiting invite_org_user function"); + result +} + +pub async fn accept_invite( + State(data): State>, + Extension(request): Extension, + Extension(session_id): Extension, +) -> Result>, ApiError> { + debug!("Entering accept_invite function"); + + // Get the invite code + let invite = data + .db + .get_invite_code_by_code(request.invite_code) + .map_err(|e| { + error!("Failed to get invite code: {:?}", e); + match e { + DBError::InviteCodeNotFound => ApiError::BadRequest, + _ => ApiError::InternalServerError, + } + })?; + + // Check if invite is expired or used + if invite.used { + return Err(ApiError::BadRequest); + } + if invite.expires_at < chrono::Utc::now() { + return Err(ApiError::BadRequest); + } + + // Hash and encrypt the password + let password_hash = generate_hash(request.password); + let secret_key = secp256k1::SecretKey::from_slice(&data.enclave_key) + .map_err(|_| ApiError::InternalServerError)?; + let encrypted_password = + crate::encrypt::encrypt_with_key(&secret_key, password_hash.as_bytes()).await; + + // Create the org user + let new_org_user = NewOrgUser::new( + invite.org_id, + invite.email.clone(), + Some(encrypted_password), + ) + .with_name(request.name); + + let org_user = match data.db.create_org_user(new_org_user) { + Ok(user) => user, + Err(DBError::OrgUserError(crate::models::org_users::OrgUserError::DuplicateUser)) => { + return Err(ApiError::EmailAlreadyExists); + } + Err(e) => { + error!("Failed to create org user: {:?}", e); + return Err(ApiError::InternalServerError); + } + }; + + // Mark invite as used + if let Err(e) = data.db.mark_invite_code_as_used(&invite) { + error!("Failed to mark invite code as used: {:?}", e); + return Err(ApiError::InternalServerError); + } + + // Create email verification + let new_verification = NewEmailVerification::new(org_user.uuid, 24, false); + let verification = match data.db.create_email_verification(new_verification) { + Ok(v) => v, + Err(e) => { + error!("Error creating email verification: {:?}", e); + return Err(ApiError::InternalServerError); + } + }; + + // Send verification email in background + let email = invite.email.clone(); + let verification_code = verification.verification_code; + let app_mode = data.app_mode.clone(); + let resend_api_key = data.resend_api_key.clone(); + spawn(async move { + if let Err(e) = + send_verification_email(app_mode, resend_api_key, email, verification_code).await + { + error!("Could not send verification email: {}", e); + } + }); + + // Get the org for the response + let org = data.db.get_org_by_id(invite.org_id).map_err(|e| { + error!("Failed to get org: {:?}", e); + ApiError::InternalServerError + })?; + + // Generate tokens + let access_token = NewToken::new_for_org_user(&org_user, TokenType::Access, &data)?; + let refresh_token = NewToken::new_for_org_user(&org_user, TokenType::Refresh, &data)?; + + let response = AcceptInviteResponse { + org_id: org.uuid, + user_id: org_user.uuid, + email: invite.email, + access_token: access_token.token, + refresh_token: refresh_token.token, + }; + + let result = encrypt_response(&data, &session_id, &response).await; + debug!("Exiting accept_invite function"); + result +}