-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add token refresh to CAS/Shard Client (#30)
Co-authored-by: Brian Ronan <[email protected]>
- Loading branch information
Showing
28 changed files
with
769 additions
and
245 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
use anyhow::anyhow; | ||
use cas::auth::{AuthConfig, TokenProvider}; | ||
use reqwest::header::HeaderValue; | ||
use reqwest::header::AUTHORIZATION; | ||
use reqwest::{Request, Response}; | ||
use reqwest_middleware::{Middleware, Next}; | ||
use std::sync::{Arc, Mutex}; | ||
|
||
/// AuthMiddleware is a thread-safe middleware that adds a CAS auth token to outbound requests. | ||
/// If the token it holds is expired, it will automatically be refreshed. | ||
pub struct AuthMiddleware { | ||
token_provider: Arc<Mutex<TokenProvider>>, | ||
} | ||
|
||
impl AuthMiddleware { | ||
/// Fetches a token from our TokenProvider. This locks the TokenProvider as we might need | ||
/// to refresh the token if it has expired. | ||
/// | ||
/// In the common case, this lock is held only to read the underlying token stored | ||
/// in memory. However, in the event of an expired token (e.g. once every 15 min), | ||
/// we will need to hold the lock while making a call to refresh the token | ||
/// (e.g. to a remote service). During this time, no other CAS requests can proceed | ||
/// from this client until the token has been fetched. This is expected/ok since we | ||
/// don't have a valid token and thus any calls would fail. | ||
fn get_token(&self) -> Result<String, anyhow::Error> { | ||
let mut provider = self | ||
.token_provider | ||
.lock() | ||
.map_err(|e| anyhow!("lock error: {e:?}"))?; | ||
provider | ||
.get_valid_token() | ||
.map_err(|e| anyhow!("couldn't get token: {e:?}")) | ||
} | ||
} | ||
|
||
impl From<&AuthConfig> for AuthMiddleware { | ||
fn from(cfg: &AuthConfig) -> Self { | ||
Self { | ||
token_provider: Arc::new(Mutex::new(TokenProvider::new(cfg))), | ||
} | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
impl Middleware for AuthMiddleware { | ||
async fn handle( | ||
&self, | ||
mut req: Request, | ||
extensions: &mut hyper::http::Extensions, | ||
next: Next<'_>, | ||
) -> reqwest_middleware::Result<Response> { | ||
let token = self | ||
.get_token() | ||
.map_err(reqwest_middleware::Error::Middleware)?; | ||
|
||
let headers = req.headers_mut(); | ||
headers.insert( | ||
AUTHORIZATION, | ||
HeaderValue::from_str(&format!("Bearer {}", token)).unwrap(), | ||
); | ||
next.run(req, extensions).await | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.