From b1324ef332ad43c449ee112adbcbb374742ca869 Mon Sep 17 00:00:00 2001 From: Sheikah45 Date: Fri, 8 Dec 2023 17:57:54 -0500 Subject: [PATCH] Properly invalidate the token on log out --- .../faforever/client/api/TokenRetriever.java | 45 +++++++++++-------- .../client/chat/KittehChatService.java | 9 ++-- .../faforever/client/user/LoginService.java | 21 ++++----- .../client/api/TokenRetrieverTest.java | 19 +++++--- .../client/user/LoginServiceTest.java | 10 ++--- 5 files changed, 58 insertions(+), 46 deletions(-) diff --git a/src/main/java/com/faforever/client/api/TokenRetriever.java b/src/main/java/com/faforever/client/api/TokenRetriever.java index 6fb04ea508..1199a521d1 100644 --- a/src/main/java/com/faforever/client/api/TokenRetriever.java +++ b/src/main/java/com/faforever/client/api/TokenRetriever.java @@ -5,12 +5,11 @@ import com.faforever.client.login.NoRefreshTokenException; import com.faforever.client.login.TokenRetrievalException; import com.faforever.client.preferences.LoginPrefs; -import javafx.beans.property.ReadOnlyBooleanProperty; -import javafx.beans.property.ReadOnlyBooleanWrapper; import javafx.beans.property.SimpleStringProperty; import javafx.beans.property.StringProperty; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.InitializingBean; import org.springframework.http.MediaType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -20,7 +19,11 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.Sinks.EmitFailureHandler; +import reactor.core.publisher.Sinks.Many; import java.net.URI; import java.time.Duration; @@ -29,20 +32,33 @@ @Component @Slf4j @RequiredArgsConstructor -public class TokenRetriever { +public class TokenRetriever implements InitializingBean { private final ClientProperties clientProperties; private final WebClient defaultWebClient; private final LoginPrefs loginPrefs; + private final Many invalidateSink = Sinks.many().multicast().directBestEffort(); + private final Flux invalidateFlux = invalidateSink.asFlux().publish().autoConnect(); private final StringProperty refreshTokenValue = new SimpleStringProperty(); - private final ReadOnlyBooleanWrapper tokenInvalid = new ReadOnlyBooleanWrapper(true); private final Mono refreshedTokenMono = Mono.defer(this::refreshAccess) - .cacheInvalidateIf(token -> tokenInvalid.get() || Duration.between(Instant.now(), token.getExpiresAt()) - .minusSeconds(30) - .isNegative()) + .cacheInvalidateWhen(this::getExpirationMono) + .map(OAuth2AccessToken::getTokenValue); + @Override + public void afterPropertiesSet() throws Exception { + refreshTokenValue.set(loginPrefs.getRefreshToken()); + loginPrefs.refreshTokenProperty() + .bind(loginPrefs.rememberMeProperty().flatMap(remember -> remember ? refreshTokenValue : null)); + } + + private Mono getExpirationMono(OAuth2AccessToken token) { + Mono invalidationMono = invalidateFlux.next(); + Mono expirationMono = Mono.delay(Duration.between(Instant.now(), token.getExpiresAt()).minusSeconds(30)); + return Mono.firstWithSignal(invalidationMono, expirationMono).then(); + } + public Mono getRefreshedTokenValue() { return refreshedTokenMono.doOnError(this::onTokenError); } @@ -59,7 +75,6 @@ public Mono loginWithAuthorizationCode(String code, String codeVerifier, U } public Mono loginWithRefreshToken() { - refreshTokenValue.set(loginPrefs.getRefreshToken()); return refreshedTokenMono.then(); } @@ -71,7 +86,6 @@ private void onTokenError(Throwable throwable) { private Mono refreshAccess() { String refreshToken = refreshTokenValue.get(); if (refreshToken == null) { - loginPrefs.setRefreshToken(null); return Mono.error(new NoRefreshTokenException("No refresh token to log in with")); } @@ -103,24 +117,17 @@ private Mono retrieveToken(MultiValueMap prop .doOnNext(tokenResponse -> { OAuth2RefreshToken refreshToken = tokenResponse.getRefreshToken(); refreshTokenValue.set(refreshToken != null ? refreshToken.getTokenValue() : null); - loginPrefs.setRefreshToken(loginPrefs.isRememberMe() ? refreshTokenValue.get() : null); - tokenInvalid.set(false); }) .map(OAuth2AccessTokenResponse::getAccessToken) .doOnNext(token -> log.info("Token valid until {}", token.getExpiresAt())); } public void invalidateToken() { - tokenInvalid.set(true); refreshTokenValue.set(null); - loginPrefs.setRefreshToken(null); - } - - public boolean isTokenInvalid() { - return tokenInvalid.get(); + invalidateSink.emitNext(0L, EmitFailureHandler.busyLooping(Duration.ofMillis(100))); } - public ReadOnlyBooleanProperty tokenInvalidProperty() { - return tokenInvalid.getReadOnlyProperty(); + public Flux invalidationFlux() { + return invalidateFlux; } } diff --git a/src/main/java/com/faforever/client/chat/KittehChatService.java b/src/main/java/com/faforever/client/chat/KittehChatService.java index 289a052003..a109bc1e48 100644 --- a/src/main/java/com/faforever/client/chat/KittehChatService.java +++ b/src/main/java/com/faforever/client/chat/KittehChatService.java @@ -6,7 +6,6 @@ import com.faforever.client.domain.PlayerBean; import com.faforever.client.fx.FxApplicationThreadExecutor; import com.faforever.client.fx.JavaFxUtil; -import com.faforever.client.fx.SimpleChangeListener; import com.faforever.client.main.event.NavigateEvent; import com.faforever.client.main.event.NavigationItem; import com.faforever.client.navigation.NavigationHandler; @@ -135,7 +134,7 @@ public class KittehChatService implements ChatService, InitializingBean, Disposa @Override public void afterPropertiesSet() { - loginService.loggedInProperty().addListener((SimpleChangeListener) loggedIn -> { + loginService.loggedInProperty().subscribe(loggedIn -> { if (loggedIn) { connect(); } else { @@ -513,8 +512,10 @@ public void connect() { @Override public void disconnect() { autoReconnect = false; - log.info("Disconnecting from IRC"); - client.shutdown("Goodbye"); + if (client != null) { + log.info("Disconnecting from IRC"); + client.shutdown("Goodbye"); + } } @Override diff --git a/src/main/java/com/faforever/client/user/LoginService.java b/src/main/java/com/faforever/client/user/LoginService.java index ff426298de..8f7d4657d6 100644 --- a/src/main/java/com/faforever/client/user/LoginService.java +++ b/src/main/java/com/faforever/client/user/LoginService.java @@ -4,7 +4,6 @@ import com.faforever.client.api.TokenRetriever; import com.faforever.client.config.ClientProperties; import com.faforever.client.config.ClientProperties.Oauth; -import com.faforever.client.fx.SimpleChangeListener; import com.faforever.client.net.ConnectionState; import com.faforever.client.notification.NotificationService; import com.faforever.client.preferences.LoginPrefs; @@ -50,12 +49,12 @@ public class LoginService implements InitializingBean { @Override public void afterPropertiesSet() throws Exception { - tokenRetriever.tokenInvalidProperty().addListener(((SimpleChangeListener) tokenInvalid -> { - if (tokenInvalid && loggedIn.get()) { + tokenRetriever.invalidationFlux().doOnNext(ignored -> { + if (loggedIn.get()) { notificationService.addImmediateInfoNotification("session.expired.message"); logOut(); } - })); + }).doOnError(throwable -> log.error("Error invalidation", throwable)).retry().subscribe(); } public String getHydraUrl(String state, String codeVerifier, URI redirectUri) { @@ -84,9 +83,7 @@ private Mono loginToServices() { ownUser.set(meResult); ownPlayer.set(mePlayer); - })).doOnError(throwable -> resetLoginState()).then(Mono.fromRunnable(() -> { - loggedIn.set(true); - })); + })).doOnError(throwable -> resetLoginState()).then(Mono.fromRunnable(() -> loggedIn.set(true))); } private Mono loginToApi() { @@ -111,17 +108,17 @@ public Integer getUserId() { public void logOut() { log.info("Logging out"); - loginPrefs.setRefreshToken(null); - loggedIn.set(false); resetLoginState(); } private void resetLoginState() { - tokenRetriever.invalidateToken(); - fafApiAccessor.reset(); - fafServerAccessor.disconnect(); + loginPrefs.setRememberMe(false); ownUser.set(null); ownPlayer.set(null); + fafApiAccessor.reset(); + fafServerAccessor.disconnect(); + loggedIn.set(false); + tokenRetriever.invalidateToken(); } public MeResult getOwnUser() { diff --git a/src/test/java/com/faforever/client/api/TokenRetrieverTest.java b/src/test/java/com/faforever/client/api/TokenRetrieverTest.java index 942b4c14a9..ef8c851911 100644 --- a/src/test/java/com/faforever/client/api/TokenRetrieverTest.java +++ b/src/test/java/com/faforever/client/api/TokenRetrieverTest.java @@ -18,14 +18,13 @@ import java.net.URI; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.Arrays; import java.util.Map; import java.util.Map.Entry; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; public class TokenRetrieverTest extends ServiceTest { @@ -55,6 +54,7 @@ public void setUp() throws Exception { loginPrefs.setRefreshToken("abc"); instance = new TokenRetriever(clientProperties, WebClient.builder().build(), loginPrefs); + instance.afterPropertiesSet(); } private void prepareTokenResponse(Map tokenProperties) throws Exception { @@ -70,7 +70,10 @@ private void prepareErrorResponse() { public void testLoginWithCode() throws Exception { Map tokenProperties = Map.of(ACCESS_TOKEN, "test", REFRESH_TOKEN, "refresh", EXPIRES_IN, "90", TOKEN_TYPE, "bearer"); prepareTokenResponse(tokenProperties); - assertTrue(instance.isTokenInvalid()); + StepVerifier verifier = StepVerifier.create(instance.invalidationFlux()) + .expectNextCount(0) + .thenCancel() + .verifyLater(); StepVerifier.create(instance.loginWithAuthorizationCode("abc", VERIFIER, REDIRECT_URI)).verifyComplete(); String request = URLDecoder.decode(mockApi.takeRequest() @@ -86,7 +89,7 @@ public void testLoginWithCode() throws Exception { assertEquals("authorization_code", requestParams.get("grant_type")); assertEquals(oauth.getClientId(), requestParams.get("client_id")); assertEquals(REDIRECT_URI.toString(), requestParams.get("redirect_uri")); - assertFalse(instance.isTokenInvalid()); + verifier.verify(Duration.ofSeconds(1)); } @Test @@ -94,7 +97,10 @@ public void testLoginWithRefresh() throws Exception { Map tokenProperties = Map.of(ACCESS_TOKEN, "test", REFRESH_TOKEN, "refresh", EXPIRES_IN, "90", TOKEN_TYPE, "bearer"); prepareTokenResponse(tokenProperties); - assertTrue(instance.isTokenInvalid()); + StepVerifier verifier = StepVerifier.create(instance.invalidationFlux()) + .expectNextCount(0) + .thenCancel() + .verifyLater(); StepVerifier.create(instance.loginWithRefreshToken()).verifyComplete(); String request = URLDecoder.decode(mockApi.takeRequest() @@ -108,7 +114,7 @@ public void testLoginWithRefresh() throws Exception { assertEquals(REFRESH_TOKEN, requestParams.get("grant_type")); assertEquals(oauth.getClientId(), requestParams.get("client_id")); - assertFalse(instance.isTokenInvalid()); + verifier.verify(Duration.ofSeconds(1)); } @Test @@ -157,6 +163,7 @@ public void testInvalidation() throws Exception { @Test public void testNoToken() { + instance.invalidateToken(); StepVerifier.create(instance.getRefreshedTokenValue()).verifyError(NoRefreshTokenException.class); } diff --git a/src/test/java/com/faforever/client/user/LoginServiceTest.java b/src/test/java/com/faforever/client/user/LoginServiceTest.java index 93e0d0a7d5..83b17ae5a3 100644 --- a/src/test/java/com/faforever/client/user/LoginServiceTest.java +++ b/src/test/java/com/faforever/client/user/LoginServiceTest.java @@ -12,13 +12,13 @@ import com.faforever.client.test.ServiceTest; import com.faforever.commons.api.dto.MeResult; import com.faforever.commons.lobby.Player; -import javafx.beans.property.SimpleBooleanProperty; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Spy; import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; import java.net.URI; import java.util.HashMap; @@ -42,8 +42,6 @@ public class LoginServiceTest extends ServiceTest { public static final String STATE = "abc"; public static final String VERIFIER = "def"; - private final SimpleBooleanProperty tokenInvalid = new SimpleBooleanProperty(false); - @Spy private ClientProperties clientProperties; @Mock @@ -62,6 +60,8 @@ public class LoginServiceTest extends ServiceTest { private MeResult meResult; private Player me; + private final TestPublisher invalidationTestPublisher = TestPublisher.create(); + @BeforeEach public void setUp() throws Exception { me = new Player(1, "junit", null, null, "", new HashMap<>(), new HashMap<>()); @@ -75,7 +75,7 @@ public void setUp() throws Exception { oauth.setRedirectUri(REDIRECT_URI); oauth.setScopes(SCOPES); - when(tokenRetriever.tokenInvalidProperty()).thenReturn(tokenInvalid); + when(tokenRetriever.invalidationFlux()).thenReturn(invalidationTestPublisher.flux()); instance.afterPropertiesSet(); } @@ -267,7 +267,7 @@ public void testOwnUserStartsAsNull() { @Test public void testOnSessionExpired() throws Exception { testLogin(); - tokenInvalid.set(true); + invalidationTestPublisher.next(0L); verify(notificationService).addImmediateInfoNotification(anyString()); }