Skip to content

Commit

Permalink
feat: PKCE support refresh_token.
Browse files Browse the repository at this point in the history
  • Loading branch information
loren-coding committed Oct 16, 2024
1 parent 84727b2 commit e6250ef
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand Down Expand Up @@ -142,17 +141,34 @@ public Authentication authenticate(Authentication authentication) throws Authent
}

if (!authorizationCode.isActive()) {
if (authorizationCode.isInvalidated()) {
OAuth2Authorization.Token<? extends OAuth2Token> token = (authorization.getRefreshToken() != null)
? authorization.getRefreshToken() : authorization.getAccessToken();
if (token != null) {
// Invalidate the access (and refresh) token as the client is
// attempting to use the authorization code more than once
authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
this.authorizationService.save(authorization);
if (this.logger.isWarnEnabled()) {
this.logger.warn(LogMessage.format(
"Invalidated authorization token(s) previously issued to registered client '%s'",
registeredClient.getId()));
}
}
}
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}

if (this.logger.isTraceEnabled()) {
this.logger.trace("Validated token request parameters");
}

Authentication principal = authorization.getAttribute(Principal.class.getName());

// @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient)
.principal(authorization.getAttribute(Principal.class.getName()))
.principal(principal)
.authorizationServerContext(AuthorizationServerContextHolder.getContext())
.authorization(authorization)
.authorizedScopes(authorization.getAuthorizedScopes())
Expand Down Expand Up @@ -181,30 +197,31 @@ public Authentication authenticate(Authentication authentication) throws Authent
if (generatedAccessToken instanceof ClaimAccessor) {
authorizationBuilder.token(accessToken, (metadata) ->
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((ClaimAccessor) generatedAccessToken).getClaims()));
} else {
}
else {
authorizationBuilder.accessToken(accessToken);
}

// ----- Refresh token -----
OAuth2RefreshToken refreshToken = null;
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
// Do not issue refresh token to public client
!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {

// Do not issue refresh token to public client
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
if (generatedRefreshToken != null) {
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate a valid refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}

if (this.logger.isTraceEnabled()) {
this.logger.trace("Generated refresh token");
}
if (this.logger.isTraceEnabled()) {
this.logger.trace("Generated refresh token");
}

refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
authorizationBuilder.refreshToken(refreshToken);
refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
authorizationBuilder.refreshToken(refreshToken);
}
}

// ----- ID token -----
Expand All @@ -231,7 +248,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims());
authorizationBuilder.token(idToken, (metadata) ->
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
} else {
}
else {
idToken = null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,11 @@
import org.springframework.lang.Nullable;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;

/**
* An {@link OAuth2TokenGenerator} that generates an {@link OAuth2RefreshToken}.
Expand All @@ -33,18 +36,35 @@
* @see OAuth2RefreshToken
*/
public final class OAuth2RefreshTokenGenerator implements OAuth2TokenGenerator<OAuth2RefreshToken> {
private final StringKeyGenerator refreshTokenGenerator =
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);

private final StringKeyGenerator refreshTokenGenerator = new Base64StringKeyGenerator(
Base64.getUrlEncoder().withoutPadding(), 96);

@Nullable
@Override
public OAuth2RefreshToken generate(OAuth2TokenContext context) {
if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
}
if (isPublicClientForAuthorizationCodeGrant(context)) {
// Do not issue refresh token to public client
return null;
}

Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
}

private static boolean isPublicClientForAuthorizationCodeGrant(OAuth2TokenContext context) {
// @formatter:off
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
(context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken)) {
return ((OAuth2ClientAuthenticationToken) context.getAuthorizationGrant()
.getPrincipal()).getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE);
}
// @formatter:on
return false;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2Token;
Expand Down Expand Up @@ -74,6 +75,8 @@
import static org.assertj.core.api.Assertions.entry;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand Down Expand Up @@ -118,7 +121,8 @@ public OAuth2Token generate(OAuth2TokenContext context) {
});
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
this.authorizationService, this.tokenGenerator);
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build();
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder()
.issuer("https://provider.com").build();
AuthorizationServerContextHolder.setContext(new TestAuthorizationServerContext(authorizationServerSettings, null));
}

Expand Down Expand Up @@ -302,7 +306,8 @@ public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2Authentication
OAuth2TokenContext context = answer.getArgument(0);
if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
return null;
} else {
}
else {
return answer.callRealMethod();
}
}).when(this.tokenGenerator).generate(any());
Expand All @@ -317,36 +322,39 @@ public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2Authentication
}

@Test
public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
public void authenticateWhenInvalidRefreshTokenGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
.thenReturn(authorization);
given(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
.willReturn(authorization);

OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationRequest.class.getName());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,
ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName());
OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);

when(this.jwtEncoder.encode(any())).thenReturn(createJwt());
given(this.jwtEncoder.encode(any())).willReturn(createJwt());

doAnswer(answer -> {
willAnswer((answer) -> {
OAuth2TokenContext context = answer.getArgument(0);
if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
} else {
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(),
Instant.now().plusSeconds(300));
}
else {
return answer.callRealMethod();
}
}).when(this.tokenGenerator).generate(any());
}).given(this.tokenGenerator).generate(any());

assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
.extracting((ex) -> ((OAuth2AuthenticationException) ex).getError())
.satisfies((error) -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
assertThat(error.getDescription())
.contains("The token generator failed to generate a valid refresh token.");
});
}

Expand All @@ -370,7 +378,8 @@ public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationExce
OAuth2TokenContext context = answer.getArgument(0);
if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
return null;
} else {
}
else {
return answer.callRealMethod();
}
}).when(this.tokenGenerator).generate(any());
Expand Down Expand Up @@ -428,12 +437,16 @@ public void authenticateWhenValidCodeThenReturnAccessToken() {
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();

assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getRegisteredClient()
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
.getToken());
assertThat(accessTokenAuthentication.getAccessToken()
.getScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
.getToken());
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode.isInvalidated()).isTrue();
}
Expand All @@ -443,7 +456,8 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken()
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
"code", Instant.now(), Instant.now().plusSeconds(120));
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode).build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode)
.build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
.thenReturn(authorization);

Expand Down Expand Up @@ -490,19 +504,22 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken()
assertThat(idTokenContext.getJwsHeader()).isNotNull();
assertThat(idTokenContext.getClaims()).isNotNull();

verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token
verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token

ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();

assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getRegisteredClient()
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
.getToken());
Set<String> accessTokenScopes = new HashSet<>(updatedAuthorization.getAuthorizedScopes());
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(accessTokenScopes);
assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
.getToken());
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCodeToken = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCodeToken.isInvalidated()).isTrue();
OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
Expand Down Expand Up @@ -558,10 +575,13 @@ public void authenticateWhenPublicClientThenRefreshTokenNotIssued() {
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();

assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getRegisteredClient()
.getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
.getToken());
assertThat(accessTokenAuthentication.getAccessToken()
.getScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(accessTokenAuthentication.getRefreshToken()).isNull();
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
assertThat(authorizationCode.isInvalidated()).isTrue();
Expand Down Expand Up @@ -600,13 +620,17 @@ public void authenticateWhenTokenTimeToLiveConfiguredThenTokenExpirySet() {
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();

assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt().plus(accessTokenTTL);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken()
.getToken());
Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt()
.plus(accessTokenTTL);
assertThat(accessTokenAuthentication.getAccessToken().getExpiresAt()).isBetween(
expectedAccessTokenExpiresAt.minusSeconds(1), expectedAccessTokenExpiresAt.plusSeconds(1));

assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL);
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken()
.getToken());
Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt()
.plus(refreshTokenTTL);
assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween(
expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1));
}
Expand Down

0 comments on commit e6250ef

Please sign in to comment.