Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow logging in with the full MXID #3908

Merged
merged 4 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ version = "0.12.12"
default-features = false
features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"]

# Matrix-related types
[workspace.dependencies.ruma-common]
version = "0.15.0"

# TLS stack
[workspace.dependencies.rustls]
version = "0.23.21"
Expand Down
2 changes: 1 addition & 1 deletion crates/data-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ rand.workspace = true
rand_chacha = "0.3.1"
regex = "1.11.1"
woothee = "0.13.0"
ruma-common = "0.15.0"
ruma-common.workspace = true

mas-iana.workspace = true
mas-jose.workspace = true
Expand Down
8 changes: 4 additions & 4 deletions crates/data-model/src/oauth2/authorization_grant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rand::{
distributions::{Alphanumeric, DistString},
RngCore,
};
use ruma_common::{OwnedUserId, UserId};
use ruma_common::UserId;
use serde::Serialize;
use ulid::Ulid;
use url::Url;
Expand Down Expand Up @@ -142,8 +142,8 @@ impl AuthorizationGrantStage {
}
}

pub enum LoginHint {
MXID(OwnedUserId),
pub enum LoginHint<'a> {
MXID(&'a UserId),
None,
}

Expand Down Expand Up @@ -200,7 +200,7 @@ impl AuthorizationGrant {
match prefix {
"mxid" => {
// Instead of erroring just return none
let Ok(mxid) = UserId::parse(value) else {
let Ok(mxid) = <&UserId>::try_from(value) else {
return LoginHint::None;
};

Expand Down
100 changes: 69 additions & 31 deletions crates/handlers/src/compat/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,13 @@ async fn user_password_login(
username: String,
password: String,
) -> Result<(CompatSession, User), RouteError> {
// Try getting the localpart out of the MXID
let username = homeserver.localpart(&username).unwrap_or(&username);

// Find the user
let user = repo
.user()
.find_by_username(&username)
.find_by_username(username)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
Expand Down Expand Up @@ -539,52 +542,43 @@ mod tests {
assert_eq!(body["errcode"], "M_UNRECOGNIZED");
}

/// Test that a user can login with a password using the Matrix
/// compatibility API.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

// Let's provision a user and add a password to it. This part is hard to test
// with just HTTP requests, so we'll use the repository directly.
async fn user_with_password(state: &TestState, username: &str, password: &str) {
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();

let user = repo
.user()
.add(&mut state.rng(), &state.clock, "alice".to_owned())
.add(&mut rng, &state.clock, username.to_owned())
.await
.unwrap();
let (version, hash) = state
.password_manager
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
.await
.unwrap();

repo.user_password()
.add(&mut rng, &state.clock, &user, version, hash, None)
.await
.unwrap();
let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

let (version, hashed_password) = state
.password_manager
.hash(
&mut state.rng(),
Zeroizing::new("password".to_owned().into_bytes()),
)
.await
.unwrap();
repo.save().await.unwrap();
}

repo.user_password()
.add(
&mut state.rng(),
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
/// Test that a user can login with a password using the Matrix
/// compatibility API.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

repo.save().await.unwrap();
user_with_password(&state, "alice", "password").await;

// Now let's try to login with the password, without asking for a refresh token.
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
Expand Down Expand Up @@ -662,6 +656,50 @@ mod tests {
assert_eq!(body, old_body);
}

/// Test that a user can login with a password using the Matrix
/// compatibility API, using a MXID as identifier
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_user_password_login_mxid(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();

user_with_password(&state, "alice", "password").await;

// Login with a full MXID as identifier
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": "@alice:example.com",
},
"password": "password",
}));

let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: ResponseBody = response.json();
assert!(!body.access_token.is_empty());
assert_eq!(body.device_id.as_ref().unwrap().as_str().len(), 10);
assert_eq!(body.user_id, "@alice:example.com");
assert_eq!(body.refresh_token, None);
assert_eq!(body.expires_in_ms, None);

// With a MXID, but with the wrong server name
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.password",
"identifier": {
"type": "m.id.user",
"user": "@alice:something.corp",
},
"password": "password",
}));

let response = state.request(request).await;
response.assert_status(StatusCode::FORBIDDEN);
let body: serde_json::Value = response.json();
assert_eq!(body["errcode"], "M_FORBIDDEN");
}

/// Test that password logins are rate limited.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
Expand Down
116 changes: 106 additions & 10 deletions crates/handlers/src/views/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,19 @@ pub(crate) async fn post(
return Ok((cookie_jar, Html(content)).into_response());
}

// Extract the localpart of the MXID, fallback to the bare username
let username = homeserver
.localpart(&form.username)
.unwrap_or(&form.username);

match login(
password_manager,
&mut repo,
rng,
&clock,
limiter,
requester,
&form.username,
username,
&form.password,
user_agent,
)
Expand Down Expand Up @@ -479,30 +484,34 @@ mod test {
.contains(&escape_html(&second_provider_login.path_and_query())));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
async fn user_with_password(state: &TestState, username: &str, password: &str) {
let mut rng = state.rng();
let cookies = CookieHelper::new();

// Provision a user with a password
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "john".to_owned())
.add(&mut rng, &state.clock, username.to_owned())
.await
.unwrap();
let (version, hash) = state
.password_manager
.hash(&mut rng, Zeroizing::new("hunter2".as_bytes().to_vec()))
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
.await
.unwrap();
repo.user_password()
.add(&mut rng, &state.clock, &user, version, hash, None)
.await
.unwrap();
repo.save().await.unwrap();
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
Expand Down Expand Up @@ -542,6 +551,93 @@ mod test {
assert!(response.body().contains("john"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_with_mxid(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();

// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "@john:example.com",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);

// Now if we get to the home page, we should see the user's username
let request = Request::get("/").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("john"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_with_mxid_wrong_server(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();

// Provision a user with a password
user_with_password(&state, "john", "hunter2").await;

// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();

// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "@john:something.corp",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;

// This shouldn't have worked, we're back on the login page
response.assert_status(StatusCode::OK);
assert!(response.body().contains("Invalid credentials"));
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_rate_limit(pool: PgPool) {
setup();
Expand Down
1 change: 1 addition & 0 deletions crates/matrix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ workspace = true
anyhow.workspace = true
async-trait.workspace = true
tokio.workspace = true
ruma-common.workspace = true
18 changes: 18 additions & 0 deletions crates/matrix/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ mod mock;

use std::{collections::HashSet, sync::Arc};

use ruma_common::UserId;

pub use self::mock::HomeserverConnection as MockHomeserverConnection;

// TODO: this should probably be another error type by default
Expand Down Expand Up @@ -193,6 +195,22 @@ pub trait HomeserverConnection: Send + Sync {
format!("@{}:{}", localpart, self.homeserver())
}

/// Get the localpart of a Matrix ID if it has the right server name
///
/// Returns [`None`] if the input isn't a valid MXID, or if the server name
/// doesn't match
///
/// # Parameters
///
/// * `mxid` - The MXID of the user
fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
let mxid = <&UserId>::try_from(mxid).ok()?;
if mxid.server_name() != self.homeserver() {
return None;
}
Some(mxid.localpart())
}

/// Query the state of a user on the homeserver.
///
/// # Parameters
Expand Down
Loading