Skip to content

Commit

Permalink
Platform validation rules
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyRonning committed Feb 12, 2025
1 parent 03a483f commit a27e8d5
Show file tree
Hide file tree
Showing 8 changed files with 652 additions and 26 deletions.
322 changes: 320 additions & 2 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ aws-config = "1.5.10"
aws-sdk-sqs = "1.49.0"
aws-types = "1.3.3"
backoff = { version = "0.4.0", features = ["tokio"] }
validator = { version = "0.20.0", features = ["derive"] }
regex = "1.9.0"
lazy_static = "1.4.0"
25 changes: 25 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,12 @@ pub trait DBConnection {
) -> Result<(Vec<User>, i64), DBError>;

fn create_org_with_owner(&self, new_org: NewOrg, owner_id: Uuid) -> Result<Org, DBError>;

fn accept_invite_transaction(
&self,
invite: &InviteCode,
new_membership: NewOrgMembership,
) -> Result<OrgMembership, DBError>;
}

pub(crate) struct PostgresConnection {
Expand Down Expand Up @@ -1106,6 +1112,25 @@ impl DBConnection for PostgresConnection {
Ok(org)
})
}

fn accept_invite_transaction(
&self,
invite: &InviteCode,
new_membership: NewOrgMembership,
) -> Result<OrgMembership, DBError> {
debug!("Starting invite acceptance transaction");
let conn = &mut self.db.get().map_err(|_| DBError::ConnectionError)?;

conn.transaction(|conn| {
// Create the membership
let membership = new_membership.insert(conn)?;

// Mark invite as used
invite.mark_as_used(conn)?;

Ok(membership)
})
}
}

pub(crate) fn setup_db(url: String) -> Arc<dyn DBConnection + Send + Sync> {
Expand Down
26 changes: 24 additions & 2 deletions src/models/platform_users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,34 @@ use uuid::Uuid;
pub enum PlatformUserError {
#[error("Database error: {0}")]
DatabaseError(#[from] diesel::result::Error),
#[error("Invalid email format: {0}")]
InvalidEmail(String),
#[error("Email already exists: {0}")]
DuplicateEmail(String),
#[error("Invalid password: {0}")]
InvalidPassword(String),
}

impl From<PlatformUserError> for ApiError {
fn from(err: PlatformUserError) -> Self {
tracing::error!("Platform user error: {:?}", err);
ApiError::InternalServerError
match err {
PlatformUserError::DatabaseError(e) => {
tracing::error!("Database error: {:?}", e);
ApiError::InternalServerError
}
PlatformUserError::InvalidEmail(msg) => {
tracing::error!("Invalid email error: {}", msg);
ApiError::BadRequest
}
PlatformUserError::DuplicateEmail(email) => {
tracing::error!("Duplicate email error: {}", email);
ApiError::EmailAlreadyExists
}
PlatformUserError::InvalidPassword(msg) => {
tracing::error!("Invalid password error: {}", msg);
ApiError::BadRequest
}
}
}
}

Expand Down
50 changes: 45 additions & 5 deletions src/web/platform/login_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,36 @@ use std::sync::Arc;
use tokio::spawn;
use tracing::{debug, error};
use uuid::Uuid;
use validator::Validate;

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct PlatformLoginRequest {
#[validate(email(message = "Invalid email format"))]
#[validate(length(max = 255, message = "Email must not exceed 255 characters"))]
pub email: String,

#[validate(length(
min = 8,
max = 64,
message = "Password must be between 8 and 64 characters"
))]
pub password: String,
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct PlatformRegisterRequest {
#[validate(email(message = "Invalid email format"))]
#[validate(length(max = 255, message = "Email must not exceed 255 characters"))]
pub email: String,

#[validate(length(
min = 8,
max = 64,
message = "Password must be between 8 and 64 characters"
))]
pub password: String,

#[validate(length(max = 50, message = "Name must not exceed 50 characters"))]
pub name: Option<String>,
}

Expand All @@ -36,8 +55,9 @@ pub struct PlatformAuthResponse {
pub refresh_token: String,
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct PlatformRefreshRequest {
#[validate(length(min = 1, message = "Refresh token cannot be empty"))]
pub refresh_token: String,
}

Expand Down Expand Up @@ -80,6 +100,12 @@ pub async fn login_platform_user(
) -> Result<Json<EncryptedResponse<PlatformAuthResponse>>, ApiError> {
debug!("Entering login_platform_user function");

// Validate request
if let Err(errors) = login_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

let auth_response = login_internal_platform(data.clone(), login_request).await?;
let result = encrypt_response(&data, &session_id, &auth_response).await;
debug!("Exiting login_platform_user function");
Expand All @@ -97,8 +123,10 @@ async fn login_internal_platform(
{
Ok(Some(platform_user)) => {
// Generate tokens
let access_token = NewToken::new_for_platform_user(&platform_user, TokenType::Access, &data)?;
let refresh_token = NewToken::new_for_platform_user(&platform_user, TokenType::Refresh, &data)?;
let access_token =
NewToken::new_for_platform_user(&platform_user, TokenType::Access, &data)?;
let refresh_token =
NewToken::new_for_platform_user(&platform_user, TokenType::Refresh, &data)?;

let auth_response = PlatformAuthResponse {
id: platform_user.uuid,
Expand Down Expand Up @@ -127,6 +155,12 @@ pub async fn register_platform_user(
) -> Result<Json<EncryptedResponse<PlatformAuthResponse>>, ApiError> {
debug!("Entering register_platform_user function");

// Validate request
if let Err(errors) = register_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

// Check if user already exists
if data
.db
Expand Down Expand Up @@ -213,6 +247,12 @@ pub async fn refresh_platform_token(
) -> Result<Json<EncryptedResponse<PlatformRefreshResponse>>, ApiError> {
debug!("Entering refresh_platform_token function");

// Validate request
if let Err(errors) = refresh_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

let claims = crate::jwt::validate_token(&refresh_request.refresh_token, &data, "refresh")?;
let platform_user_id = Uuid::parse_str(&claims.sub).map_err(|_| ApiError::InvalidJwt)?;

Expand Down
1 change: 1 addition & 0 deletions src/web/platform/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod login_routes;
mod org_routes;
mod validation;

pub use login_routes::router as login_routes;
pub use org_routes::router as org_routes;
78 changes: 61 additions & 17 deletions src/web/platform/org_routes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::models::invite_codes::InviteCodeError;
use crate::models::org_memberships::OrgMembershipError;
use crate::models::org_project_secrets::NewOrgProjectSecret;
use crate::DBError;
Expand All @@ -11,6 +12,9 @@ use crate::{
platform_users::PlatformUser,
},
web::encryption_middleware::{decrypt_request, encrypt_response, EncryptedResponse},
web::platform::validation::{
validate_alphanumeric_only, validate_alphanumeric_with_symbols, validate_secret_size,
},
ApiError, AppState,
};
use axum::{
Expand All @@ -28,27 +32,46 @@ use std::sync::Arc;
use tokio::spawn;
use tracing::{debug, error};
use uuid::Uuid;
use validator::Validate;

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct CreateOrgRequest {
#[validate(length(min = 1, max = 50))]
#[validate(custom(function = "validate_alphanumeric_with_symbols"))]
pub name: String,
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct CreateProjectRequest {
#[validate(length(min = 1, max = 50))]
#[validate(custom(function = "validate_alphanumeric_with_symbols"))]
pub name: String,
#[validate(length(max = 255))]
pub description: Option<String>,
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct UpdateProjectRequest {
#[validate(length(min = 1, max = 50))]
#[validate(custom(function = "validate_alphanumeric_with_symbols"))]
pub name: Option<String>,
#[validate(length(max = 255))]
pub description: Option<String>,
#[validate(custom(function = "validate_project_status"))]
pub status: Option<String>,
}

#[derive(Deserialize, Clone)]
fn validate_project_status(status: &str) -> Result<(), validator::ValidationError> {
match status {
"active" | "inactive" | "suspended" => Ok(()),
_ => Err(validator::ValidationError::new("project_status")),
}
}

#[derive(Deserialize, Clone, Validate)]
pub struct CreateInviteRequest {
#[validate(email(message = "Invalid email format"))]
#[validate(length(max = 255, message = "Email must not exceed 255 characters"))]
pub email: String,
#[serde(default = "default_invite_role")]
pub role: OrgRole,
Expand All @@ -58,14 +81,17 @@ fn default_invite_role() -> OrgRole {
OrgRole::Admin
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct UpdateMembershipRequest {
pub role: OrgRole,
}

#[derive(Deserialize, Clone)]
#[derive(Deserialize, Clone, Validate)]
pub struct CreateSecretRequest {
#[validate(length(min = 1, max = 50))]
#[validate(custom(function = "validate_alphanumeric_only"))]
pub key_name: String,
#[validate(custom(function = "validate_secret_size"))]
pub secret: String, // Base64 encoded secret value
}

Expand Down Expand Up @@ -202,6 +228,12 @@ async fn create_org(
) -> Result<Json<EncryptedResponse<OrgResponse>>, ApiError> {
debug!("Creating new organization");

// Validate request
if let Err(errors) = create_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

// Create the organization and owner membership in a single transaction
let new_org = NewOrg::new(create_request.name);
let org = data
Expand Down Expand Up @@ -305,6 +337,12 @@ async fn create_project(
) -> Result<Json<EncryptedResponse<ProjectResponse>>, ApiError> {
debug!("Creating new project");

// Validate request
if let Err(errors) = create_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

// Verify user has admin or owner role
let membership = data
.db
Expand Down Expand Up @@ -670,22 +708,22 @@ async fn accept_invite(
return Err(ApiError::Unauthorized);
}

// Create the membership
// Create the membership and mark invite as used in a single transaction
let new_membership = NewOrgMembership::new(
platform_user.uuid,
invite.org_id,
invite.role.clone().into(),
);
data.db.create_org_membership(new_membership).map_err(|e| {
error!("Failed to create membership: {:?}", e);
ApiError::InternalServerError
})?;

// Mark invite as used
data.db.mark_invite_code_as_used(&invite).map_err(|e| {
error!("Failed to mark invite as used: {:?}", e);
ApiError::InternalServerError
})?;
data.db
.accept_invite_transaction(&invite, new_membership)
.map_err(|e| {
error!("Failed to accept invite: {:?}", e);
match e {
DBError::InviteCodeError(InviteCodeError::AlreadyUsed) => ApiError::BadRequest,
DBError::InviteCodeError(InviteCodeError::Expired) => ApiError::BadRequest,
_ => ApiError::InternalServerError,
}
})?;

let response = serde_json::json!({
"message": "Invite accepted successfully"
Expand All @@ -703,6 +741,12 @@ pub async fn create_secret(
) -> Result<Json<EncryptedResponse<SecretResponse>>, ApiError> {
debug!("Creating project secret");

// Validate request
if let Err(errors) = create_request.validate() {
error!("Validation error: {:?}", errors);
return Err(ApiError::BadRequest);
}

// Verify user has admin or owner role
let membership = data
.db
Expand Down
Loading

0 comments on commit a27e8d5

Please sign in to comment.