Skip to content

Commit

Permalink
Merge pull request #3908 from element-hq/quenting/mxid-in-login
Browse files Browse the repository at this point in the history
Allow logging in with the full MXID
  • Loading branch information
sandhose authored Jan 29, 2025
2 parents 931de22 + 76ba8e1 commit 00b31b8
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 46 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ version = "0.12.12"
default-features = false
features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"]

# Matrix-related types
[workspace.dependencies.ruma-common]
version = "0.15.0"

# TLS stack
[workspace.dependencies.rustls]
version = "0.23.21"
Expand Down
2 changes: 1 addition & 1 deletion crates/data-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ rand.workspace = true
rand_chacha = "0.3.1"
regex = "1.11.1"
woothee = "0.13.0"
ruma-common = "0.15.0"
ruma-common.workspace = true

mas-iana.workspace = true
mas-jose.workspace = true
Expand Down
8 changes: 4 additions & 4 deletions crates/data-model/src/oauth2/authorization_grant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rand::{
distributions::{Alphanumeric, DistString},
RngCore,
};
use ruma_common::{OwnedUserId, UserId};
use ruma_common::UserId;
use serde::Serialize;
use ulid::Ulid;
use url::Url;
Expand Down Expand Up @@ -142,8 +142,8 @@ impl AuthorizationGrantStage {
}
}

pub enum LoginHint {
MXID(OwnedUserId),
pub enum LoginHint<'a> {
MXID(&'a UserId),
None,
}

Expand Down Expand Up @@ -200,7 +200,7 @@ impl AuthorizationGrant {
match prefix {
"mxid" => {
// Instead of erroring just return none
let Ok(mxid) = UserId::parse(value) else {
let Ok(mxid) = <&UserId>::try_from(value) else {
return LoginHint::None;
};

Expand Down
100 changes: 69 additions & 31 deletions crates/handlers/src/compat/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,13 @@ async fn user_password_login(
username: String,
password: String,
) -> Result<(CompatSession, User), RouteError> {
// Try getting the localpart out of the MXID
let username = homeserver.localpart(&username).unwrap_or(&username);

// Find the user
let user = repo
.user()
.find_by_username(&username)
.find_by_username(username)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
Expand Down Expand Up @@ -539,52 +542,43 @@ mod tests {
assert_eq!(body["errcode"], "M_UNRECOGNIZED");
}

/// Test that a user can login with a password using the Matrix
/// compatibility API.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

// Let's provision a user and add a password to it. This part is hard to test
// with just HTTP requests, so we'll use the repository directly.
async fn user_with_password(state: &TestState, username: &str, password: &str) {
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();

let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.add(&mut rng, &state.clock, username.to_owned())
.await
.unwrap();
let (version, hash) = state
.password_manager
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
.await
.unwrap();

repo.user_password()
.add(&mut rng, &state.clock, &user, version, hash, None)
.await
.unwrap();
let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

let (version, hashed_password) = state
.password_manager
.hash(
&mut state.rng(),
Zeroizing::new("password".to_owned().into_bytes()),
)
.await
.unwrap();
repo.save().await.unwrap();
}

repo.user_password()
.add(
&mut state.rng(),
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
/// Test that a user can login with a password using the Matrix
/// compatibility API.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

repo.save().await.unwrap();
user_with_password(&state, "alice", "password").await;

// Now let's try to login with the password, without asking for a refresh token.
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
Expand Down Expand Up @@ -662,6 +656,50 @@ mod tests {
assert_eq!(body, old_body);
}

/// Test that a user can login with a password using the Matrix
/// compatibility API, using a MXID as identifier
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login_mxid(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

user_with_password(&state, "alice", "password").await;

// Login with a full MXID as identifier
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": "@alice:example.com",
},
"password": "password",
}));

let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: ResponseBody = response.json();
assert!(!body.access_token.is_empty());
assert_eq!(body.device_id.as_ref().unwrap().as_str().len(), 10);
assert_eq!(body.user_id, "@alice:example.com");
assert_eq!(body.refresh_token, None);
assert_eq!(body.expires_in_ms, None);

// With a MXID, but with the wrong server name
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": "@alice:something.corp",
},
"password": "password",
}));

let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_FORBIDDEN");
}

/// Test that password logins are rate limited.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
Expand Down
116 changes: 106 additions & 10 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,19 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response());
}

// Extract the localpart of the MXID, fallback to the bare username
let username = homeserver
.localpart(&form.username)
.unwrap_or(&form.username);

match login(
password_manager,
&mut repo,
rng,
&clock,
limiter,
requester,
&form.username,
username,
&form.password,
user_agent,
)
Expand Down Expand Up @@ -479,30 +484,34 @@ mod test {
.contains(&escape_html(&second_provider_login.path_and_query())));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
async fn user_with_password(state: &TestState, username: &str, password: &str) {
let mut rng = state.rng();
let cookies = CookieHelper::new();

// Provision a user with a password
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "john".to_owned())
.add(&mut rng, &state.clock, username.to_owned())
.await
.unwrap();
let (version, hash) = state
.password_manager
.hash(&mut rng, Zeroizing::new("hunter2".as_bytes().to_vec()))
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
.await
.unwrap();
repo.user_password()
.add(&mut rng, &state.clock, &user, version, hash, None)
.await
.unwrap();
repo.save().await.unwrap();
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
Expand Down Expand Up @@ -542,6 +551,93 @@ mod test {
assert!(response.body().contains("john"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_with_mxid(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();

// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "@john:example.com",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);

// Now if we get to the home page, we should see the user's username
let request = Request::get("/").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("john"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_with_mxid_wrong_server(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();

// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "@john:something.corp",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;

// This shouldn't have worked, we're back on the login page
response.assert_status(StatusCode::OK);
assert!(response.body().contains("Invalid credentials"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
setup();
Expand Down
1 change: 1 addition & 0 deletions crates/matrix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ workspace = true
anyhow.workspace = true
async-trait.workspace = true
tokio.workspace = true
ruma-common.workspace = true
18 changes: 18 additions & 0 deletions crates/matrix/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ mod mock;

use std::{collections::HashSet, sync::Arc};

use ruma_common::UserId;

pub use self::mock::HomeserverConnection as MockHomeserverConnection;

// TODO: this should probably be another error type by default
Expand Down Expand Up @@ -193,6 +195,22 @@ pub trait HomeserverConnection: Send + Sync {
format!("@{}:{}", localpart, self.homeserver())
}

/// Get the localpart of a Matrix ID if it has the right server name
///
/// Returns [`None`] if the input isn't a valid MXID, or if the server name
/// doesn't match
///
/// # Parameters
///
/// * `mxid` - The MXID of the user
fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
let mxid = <&UserId>::try_from(mxid).ok()?;
if mxid.server_name() != self.homeserver() {
return None;
}
Some(mxid.localpart())
}

/// Query the state of a user on the homeserver.
///
/// # Parameters
Expand Down

0 comments on commit 00b31b8

Please sign in to comment.