From a5c9468f4efd57769354834f7f472bc7f28846c7 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 28 Jan 2025 17:25:36 +0100 Subject: [PATCH 1/4] Utility to extract the localpart from a MXID --- Cargo.lock | 1 + Cargo.toml | 4 ++++ crates/data-model/Cargo.toml | 2 +- crates/matrix/Cargo.toml | 1 + crates/matrix/src/lib.rs | 18 ++++++++++++++++++ 5 files changed, 25 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index bbe35a3ac..f78d30442 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3527,6 +3527,7 @@ version = "0.13.0-rc.1" dependencies = [ "anyhow", "async-trait", + "ruma-common", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 11a431119..8253b9774 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -248,6 +248,10 @@ version = "0.12.12" default-features = false features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"] +# Matrix-related types +[workspace.dependencies.ruma-common] +version = "0.15.0" + # TLS stack [workspace.dependencies.rustls] version = "0.23.21" diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index ee49a6c05..845972c19 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -24,7 +24,7 @@ rand.workspace = true rand_chacha = "0.3.1" regex = "1.11.1" woothee = "0.13.0" -ruma-common = "0.15.0" +ruma-common.workspace = true mas-iana.workspace = true mas-jose.workspace = true diff --git a/crates/matrix/Cargo.toml b/crates/matrix/Cargo.toml index 8f7e77579..4f194bd22 100644 --- a/crates/matrix/Cargo.toml +++ b/crates/matrix/Cargo.toml @@ -15,3 +15,4 @@ workspace = true anyhow.workspace = true async-trait.workspace = true tokio.workspace = true +ruma-common.workspace = true diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 700bb15d8..76f32e09f 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -8,6 +8,8 @@ mod mock; use std::{collections::HashSet, sync::Arc}; +use ruma_common::UserId; + pub use self::mock::HomeserverConnection as MockHomeserverConnection; // TODO: this should probably be another error type by default @@ -193,6 +195,22 @@ pub trait HomeserverConnection: Send + Sync { format!("@{}:{}", localpart, self.homeserver()) } + /// Get the localpart of a Matrix ID if it has the right server name + /// + /// Returns [`None`] if the input isn't a valid MXID, or if the server name + /// doesn't match + /// + /// # Parameters + /// + /// * `mxid` - The MXID of the user + fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> { + let mxid = <&UserId>::try_from(mxid).ok()?; + if mxid.server_name() != self.homeserver() { + return None; + } + Some(mxid.localpart()) + } + /// Query the state of a user on the homeserver. /// /// # Parameters From 463ba2ea503a3080c32d15625d33a9e555196a81 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 28 Jan 2025 17:25:54 +0100 Subject: [PATCH 2/4] Avoid unnecessary clones in the login_hint parser --- crates/data-model/src/oauth2/authorization_grant.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/data-model/src/oauth2/authorization_grant.rs b/crates/data-model/src/oauth2/authorization_grant.rs index d34a431c7..2eb616a47 100644 --- a/crates/data-model/src/oauth2/authorization_grant.rs +++ b/crates/data-model/src/oauth2/authorization_grant.rs @@ -17,7 +17,7 @@ use rand::{ distributions::{Alphanumeric, DistString}, RngCore, }; -use ruma_common::{OwnedUserId, UserId}; +use ruma_common::UserId; use serde::Serialize; use ulid::Ulid; use url::Url; @@ -142,8 +142,8 @@ impl AuthorizationGrantStage { } } -pub enum LoginHint { - MXID(OwnedUserId), +pub enum LoginHint<'a> { + MXID(&'a UserId), None, } @@ -200,7 +200,7 @@ impl AuthorizationGrant { match prefix { "mxid" => { // Instead of erroring just return none - let Ok(mxid) = UserId::parse(value) else { + let Ok(mxid) = <&UserId>::try_from(value) else { return LoginHint::None; }; From 0096076dfabf0fd803c68c51ff797d4f015c765c Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 28 Jan 2025 17:26:29 +0100 Subject: [PATCH 3/4] Allow passing MXIDs in compat password logins --- crates/handlers/src/compat/login.rs | 100 +++++++++++++++++++--------- 1 file changed, 69 insertions(+), 31 deletions(-) diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index ff74bf8ef..75834f9b3 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -386,10 +386,13 @@ async fn user_password_login( username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { + // Try getting the localpart out of the MXID + let username = homeserver.localpart(&username).unwrap_or(&username); + // Find the user let user = repo .user() - .find_by_username(&username) + .find_by_username(username) .await? .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; @@ -539,23 +542,25 @@ mod tests { assert_eq!(body["errcode"], "M_UNRECOGNIZED"); } - /// Test that a user can login with a password using the Matrix - /// compatibility API. - #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_user_password_login(pool: PgPool) { - setup(); - let state = TestState::from_pool(pool).await.unwrap(); - - // Let's provision a user and add a password to it. This part is hard to test - // with just HTTP requests, so we'll use the repository directly. + async fn user_with_password(state: &TestState, username: &str, password: &str) { + let mut rng = state.rng(); let mut repo = state.repository().await.unwrap(); let user = repo .user() - .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .add(&mut rng, &state.clock, username.to_owned()) + .await + .unwrap(); + let (version, hash) = state + .password_manager + .hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec())) .await .unwrap(); + repo.user_password() + .add(&mut rng, &state.clock, &user, version, hash, None) + .await + .unwrap(); let mxid = state.homeserver_connection.mxid(&user.username); state .homeserver_connection @@ -563,28 +568,17 @@ mod tests { .await .unwrap(); - let (version, hashed_password) = state - .password_manager - .hash( - &mut state.rng(), - Zeroizing::new("password".to_owned().into_bytes()), - ) - .await - .unwrap(); + repo.save().await.unwrap(); + } - repo.user_password() - .add( - &mut state.rng(), - &state.clock, - &user, - version, - hashed_password, - None, - ) - .await - .unwrap(); + /// Test that a user can login with a password using the Matrix + /// compatibility API. + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_user_password_login(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); - repo.save().await.unwrap(); + user_with_password(&state, "alice", "password").await; // Now let's try to login with the password, without asking for a refresh token. let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ @@ -662,6 +656,50 @@ mod tests { assert_eq!(body, old_body); } + /// Test that a user can login with a password using the Matrix + /// compatibility API, using a MXID as identifier + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_user_password_login_mxid(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + user_with_password(&state, "alice", "password").await; + + // Login with a full MXID as identifier + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "@alice:example.com", + }, + "password": "password", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let body: ResponseBody = response.json(); + assert!(!body.access_token.is_empty()); + assert_eq!(body.device_id.as_ref().unwrap().as_str().len(), 10); + assert_eq!(body.user_id, "@alice:example.com"); + assert_eq!(body.refresh_token, None); + assert_eq!(body.expires_in_ms, None); + + // With a MXID, but with the wrong server name + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "@alice:something.corp", + }, + "password": "password", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::FORBIDDEN); + let body: serde_json::Value = response.json(); + assert_eq!(body["errcode"], "M_FORBIDDEN"); + } + /// Test that password logins are rate limited. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_password_login_rate_limit(pool: PgPool) { From 76ba8e11398695e3b4712c87456ad11cd5bcf149 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 28 Jan 2025 17:26:44 +0100 Subject: [PATCH 4/4] Allow passing MXIDs in the login page --- crates/handlers/src/views/login.rs | 116 ++++++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 10 deletions(-) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index b0d6991c0..1203dc2e0 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -168,6 +168,11 @@ pub(crate) async fn post( return Ok((cookie_jar, Html(content)).into_response()); } + // Extract the localpart of the MXID, fallback to the bare username + let username = homeserver + .localpart(&form.username) + .unwrap_or(&form.username); + match login( password_manager, &mut repo, @@ -175,7 +180,7 @@ pub(crate) async fn post( &clock, limiter, requester, - &form.username, + username, &form.password, user_agent, ) @@ -479,23 +484,17 @@ mod test { .contains(&escape_html(&second_provider_login.path_and_query()))); } - #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_password_login(pool: PgPool) { - setup(); - let state = TestState::from_pool(pool).await.unwrap(); + async fn user_with_password(state: &TestState, username: &str, password: &str) { let mut rng = state.rng(); - let cookies = CookieHelper::new(); - - // Provision a user with a password let mut repo = state.repository().await.unwrap(); let user = repo .user() - .add(&mut rng, &state.clock, "john".to_owned()) + .add(&mut rng, &state.clock, username.to_owned()) .await .unwrap(); let (version, hash) = state .password_manager - .hash(&mut rng, Zeroizing::new("hunter2".as_bytes().to_vec())) + .hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec())) .await .unwrap(); repo.user_password() @@ -503,6 +502,16 @@ mod test { .await .unwrap(); repo.save().await.unwrap(); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + user_with_password(&state, "john", "hunter2").await; // Render the login page to get a CSRF token let request = Request::get("/login").empty(); @@ -542,6 +551,93 @@ mod test { assert!(response.body().contains("john")); } + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_with_mxid(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + user_with_password(&state, "john", "hunter2").await; + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "@john:example.com", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::SEE_OTHER); + + // Now if we get to the home page, we should see the user's username + let request = Request::get("/").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + assert!(response.body().contains("john")); + } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_password_login_with_mxid_wrong_server(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + let cookies = CookieHelper::new(); + + // Provision a user with a password + user_with_password(&state, "john", "hunter2").await; + + // Render the login page to get a CSRF token + let request = Request::get("/login").empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + // Submit the login form + let request = Request::post("/login").form(serde_json::json!({ + "csrf": csrf_token, + "username": "@john:something.corp", + "password": "hunter2", + })); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + + // This shouldn't have worked, we're back on the login page + response.assert_status(StatusCode::OK); + assert!(response.body().contains("Invalid credentials")); + } + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_password_login_rate_limit(pool: PgPool) { setup();