Skip to content

Commit

Permalink
Fix oauth 0.5 breakage
Browse files Browse the repository at this point in the history
Updates `divviup-api` to adapt to the breaking changes in `oauth 0.5`.
Much of this follows from advice in the [upgrade guide][1].

- move entire crate to `http 1.1.0`
    - use `http-compat-1` feature on `trillium-http`
- adopt builder pattern for constructing `oauth2::BasicClient`
    - add type alias `ConfiguredOauthClient` as a shorthand for
      `oauth2::BasicClient<...>`
- add wrapper around `trillium_client::Client` so we can implement
  `oauth2::AsyncHttpClient` on it
    - translate `oauth2::HttpRequest/Response` to/from Trillium
      equivalents

[1]: https://github.com/ramosbugs/oauth2-rs/blob/main/UPGRADE.md
  • Loading branch information
tgeoghegan committed Feb 5, 2025
1 parent 3e006a0 commit 2b8900f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 57 deletions.
45 changes: 17 additions & 28 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ trillium-compression = "0.1.3"
trillium-conn-id = "0.2.3"
trillium-cookies = "0.4.2"
trillium-forwarding = "0.2.4"
trillium-http = { version = "0.3.14", features = ["http-compat", "serde"] }
trillium-http = { version = "0.3.14", features = ["http-compat-1", "serde"] }
trillium-logger = "0.4.5"
trillium-macros = "0.0.6"
trillium-prometheus = "0.2.0"
Expand Down
89 changes: 61 additions & 28 deletions src/handler/oauth2.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
use crate::{User, USER_SESSION_KEY};
use oauth2::{
basic::{BasicClient, BasicErrorResponseType},
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, HttpResponse, PkceCodeChallenge,
PkceCodeVerifier, RedirectUrl, RequestTokenError, Scope, StandardErrorResponse, TokenResponse,
TokenUrl,
AsyncHttpClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet,
EndpointSet, HttpRequest, HttpResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
RequestTokenError, Scope, StandardErrorResponse, TokenResponse, TokenUrl,
};
use querystrong::QueryStrong;
use std::sync::Arc;
use std::{future::Future, pin::Pin, sync::Arc};
use trillium::{conn_try, conn_unwrap, Conn, KnownHeaderName::Authorization, Status};
use trillium_client::{Client, ClientSerdeError};
use trillium_http::Headers;
use trillium_redirect::RedirectConnExt;
use trillium_sessions::SessionConnExt;
use url::Url;

/// Type alias for an oauth2::Client once we've finished configuring it in `OauthClient::new`.
/// Crate oauth's guide to upgrading to 0.5 recommends defining this kind of alias:
/// https://github.com/ramosbugs/oauth2-rs/blob/main/UPGRADE.md#add-typestate-generic-types-to-client
pub type ConfiguredOauthClient = BasicClient<
EndpointSet, // HasAuthURL
EndpointNotSet, // HasDeviceAuthURL
EndpointNotSet, // HasIntrospectionURL
EndpointNotSet, // HasRevocationURL
EndpointSet, // HasTokenURL
>;

#[derive(Clone, Debug)]
pub struct Oauth2Config {
pub authorize_url: Url,
Expand Down Expand Up @@ -93,7 +104,7 @@ enum OauthError {
#[error(transparent)]
InvalidStatusCode(#[from] oauth2::http::status::InvalidStatusCode),
#[error(transparent)]
HeaderConversionError(#[from] trillium_http::http_compat::HeaderConversionError),
HeaderConversionError(#[from] trillium_http::http_compat1::HeaderConversionError),
#[error(transparent)]
UrlError(#[from] url::ParseError),
#[error("error response: {0}")]
Expand All @@ -104,6 +115,8 @@ enum OauthError {
Other(String),
#[error("expected a successful status, but found {0:?}")]
UnexpectedStatus(Option<Status>),
#[error(transparent)]
HttpCrateError(#[from] oauth2::http::Error),
}

impl From<RequestTokenError<OauthError, StandardErrorResponse<BasicErrorResponseType>>>
Expand Down Expand Up @@ -138,7 +151,7 @@ pub struct OauthClient(Arc<OauthClientInner>);
#[derive(Debug)]
struct OauthClientInner {
oauth_config: Oauth2Config,
oauth2_client: BasicClient,
oauth2_client: ConfiguredOauthClient,
}

impl OauthClient {
Expand All @@ -153,20 +166,7 @@ impl OauthClient {
.exchange_code(auth_code)
.set_pkce_verifier(pkce_verifier)
.add_extra_param("audience", &self.0.oauth_config.audience)
.request_async(|req| async move {
let mut conn = http_client
.build_conn(req.method, req.url)
.with_body(req.body)
.with_request_headers(Headers::from(req.headers))
.await?;
let status_code = conn.status().unwrap().try_into()?;
let body = conn.response_body().read_bytes().await?;
Ok::<_, OauthError>(HttpResponse {
status_code,
headers: conn.response_headers().clone().try_into()?,
body,
})
})
.request_async(&ClientWrapper(http_client))
.await?;

let mut client_conn = self
Expand All @@ -190,25 +190,58 @@ impl OauthClient {
}

pub fn new(config: &Oauth2Config) -> Self {
let oauth2_client = BasicClient::new(
ClientId::new(config.client_id.clone()),
Some(ClientSecret::new(config.client_secret.clone())),
AuthUrl::from_url(config.authorize_url.clone()),
Some(TokenUrl::from_url(config.token_url.clone())),
)
.set_redirect_uri(RedirectUrl::from_url(config.redirect_url.clone()));
let oauth2_client = BasicClient::new(ClientId::new(config.client_id.clone()))
.set_client_secret(ClientSecret::new(config.client_secret.clone()))
.set_auth_uri(AuthUrl::from_url(config.authorize_url.clone()))
.set_token_uri(TokenUrl::from_url(config.token_url.clone()))
.set_redirect_uri(RedirectUrl::from_url(config.redirect_url.clone()));

Self(Arc::new(OauthClientInner {
oauth_config: config.clone(),
oauth2_client,
}))
}

pub fn oauth2_client(&self) -> &BasicClient {
pub fn oauth2_client(&self) -> &ConfiguredOauthClient {
&self.0.oauth2_client
}

pub fn http_client(&self) -> &Client {
&self.0.oauth_config.http_client
}
}

// Wraps a [`trillium_client::Client`] so we can implement [`oauth2::AsyncHttpClient`] on it, as
// otherwise the orphan rule would forbid this.
struct ClientWrapper(Client);

// Inspired by the impls `oauth2` provides for `reqwest::Client`
// https://github.com/ramosbugs/oauth2-rs/blob/23b952b23e6069525bc7e4c4f2c4924b8d28ce3a/src/reqwest.rs
impl<'c> AsyncHttpClient<'c> for ClientWrapper {
type Error = OauthError;
type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + 'c>>;

fn call(&'c self, req: HttpRequest) -> Self::Future {
Box::pin(async move {
// Translate the oauth2::http::Request into a Trillium request
let mut conn = self
.0
.build_conn(req.method(), req.uri().to_string().parse::<Url>()?)
.with_body(req.body().clone())
.with_request_headers(Headers::from(req.headers().clone()))
.await?;
let status_code: oauth2::http::StatusCode = conn.status().unwrap().try_into()?;
let body = conn.response_body().read_bytes().await?;

// Now transform the Trillium response back into an http::Response
let mut builder = oauth2::http::Response::builder().status(status_code);
let http_headers: oauth2::http::HeaderMap =
conn.response_headers().clone().try_into()?;
builder
.headers_mut()
.ok_or_else(|| OauthError::Other("no headers in builder?".into()))?
.extend(http_headers);
Ok::<_, OauthError>(builder.body(body)?)
})
}
}

0 comments on commit 2b8900f

Please sign in to comment.