Skip to content

Commit

Permalink
Compute the OAuth expires_in from the SAML's @NotOnOrAfter
Browse files Browse the repository at this point in the history
Fixes #147
  • Loading branch information
qligier committed Jun 17, 2024
1 parent f32febf commit 07565dc
Showing 1 changed file with 52 additions and 6 deletions.
58 changes: 52 additions & 6 deletions src/main/java/ch/bfh/ti/i4mi/mag/xua/TokenEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,29 @@

package ch.bfh.ti.i4mi.mag.xua;

import java.io.ByteArrayInputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;

import org.apache.camel.Body;
import org.apache.camel.Header;
import org.ehcache.Cache;
import org.opensaml.Configuration;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.xml.io.Unmarshaller;
import org.opensaml.xml.io.UnmarshallingException;
import org.opensaml.xml.parse.BasicParserPool;
import org.opensaml.xml.parse.XMLParserException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;

import lombok.extern.slf4j.Slf4j;
import org.w3c.dom.Element;

/**
* OAuth2 code to token exchange operation
Expand All @@ -46,8 +57,22 @@ public class TokenEndpoint {

@Value("${mag.iua.sp.disable-code-challenge:false}")
private boolean disableCodeChallenge;

private long defaultTimeout = 1000l * 60l;

/**
* The unmarshaller of Assertion elements.
*/
private final Unmarshaller assertionUnmarshaller;

/**
* The SAML parser pool.
*/
private final BasicParserPool samlParserPool;

public TokenEndpoint() {
this.assertionUnmarshaller = Configuration.getUnmarshallerFactory().getUnmarshaller(Assertion.DEFAULT_ELEMENT_NAME);
this.samlParserPool = new BasicParserPool();
this.samlParserPool.setNamespaceAware(true);
}

private void require(String field, String fieldname) throws AuthException {
if (field == null || field.trim().length() == 0) throw new AuthException(400, "invalid_request", "'"+fieldname+"' is required");
Expand All @@ -65,7 +90,8 @@ public OAuth2TokenResponse handle(
@Header("code_verifier") String codeVerifier,
@Header("client_id") String clientId,
@Header("client_secret") String clientSecret,
@Header("redirect_uri") String redirectUri) throws UnsupportedEncodingException, AuthException {
@Header("redirect_uri") String redirectUri)
throws UnsupportedEncodingException, AuthException, XMLParserException, UnmarshallingException {


mustMatch(grantType, "authorization_code", "grant_type");
Expand Down Expand Up @@ -97,7 +123,7 @@ public OAuth2TokenResponse handle(

OAuth2TokenResponse result = new OAuth2TokenResponse();
result.setAccess_token(encoded);
result.setExpires_in(defaultTimeout);
result.setExpires_in(this.computeExpiresInFromNotOnOrAfter(assertion)); // In seconds
result.setScope(request.getScope());
result.setToken_type("Bearer" /*request.getToken_type()*/);
return result;
Expand All @@ -117,13 +143,14 @@ public static String sha256ThenBase64(String input) throws AuthException {
}
}

public OAuth2TokenResponse handleFromIdp(@Body String assertion, @Header("scope") String scope) throws UnsupportedEncodingException, AuthException {
public OAuth2TokenResponse handleFromIdp(@Body String assertion, @Header("scope") String scope)
throws UnsupportedEncodingException, XMLParserException, UnmarshallingException {

String encoded = Base64.getEncoder().encodeToString(assertion.getBytes("UTF-8"));

OAuth2TokenResponse result = new OAuth2TokenResponse();
result.setAccess_token(encoded);
result.setExpires_in(defaultTimeout);
result.setExpires_in(this.computeExpiresInFromNotOnOrAfter(assertion)); // In seconds
result.setScope(scope);
result.setToken_type("Bearer" /*request.getToken_type()*/);
return result;
Expand All @@ -137,4 +164,23 @@ public ErrorResponse handleError(@Body AuthException in) {
return response;
}

/**
* Computes the number of seconds from now (inclusive) to the Assertion's @NotOnOrAfter attribute (exclusive).
*
* @param assertionXml The XML representation of the Assertion.
* @return a duration in seconds.
* @throws XMLParserException if the Assertion cannot be parsed.
* @throws UnmarshallingException if the Assertion cannot be unmarshalled.
*/
private long computeExpiresInFromNotOnOrAfter(final String assertionXml)
throws XMLParserException, UnmarshallingException {
// Parse the assertion and extract the NotOnOrAfter attribute
final Element element = this.samlParserPool
.parse(new ByteArrayInputStream(assertionXml.getBytes(StandardCharsets.UTF_8)))
.getDocumentElement();
final Assertion assertion = (Assertion) this.assertionUnmarshaller.unmarshall(element);

final Instant notOnOrAfter = assertion.getConditions().getNotOnOrAfter().toDate().toInstant();
return Duration.between(Instant.now(), notOnOrAfter).getSeconds();
}
}

0 comments on commit 07565dc

Please sign in to comment.