diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RequestInterceptor.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RequestInterceptor.java index 9390b780..dd6bf097 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RequestInterceptor.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RequestInterceptor.java @@ -14,6 +14,7 @@ import org.springframework.web.servlet.AsyncHandlerInterceptor; import java.time.Instant; +import java.util.UUID; @Component public class RequestInterceptor implements AsyncHandlerInterceptor { @@ -40,16 +41,26 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons } validation.clear(); val correlationID = request.getHeader(EducGraduationApiConstants.CORRELATION_ID); - if (correlationID != null) { - ThreadLocalStateUtil.setCorrelationID(correlationID); + ThreadLocalStateUtil.setCorrelationID(correlationID != null ? correlationID : UUID.randomUUID().toString()); + + //Request Source + val requestSource = request.getHeader(EducGraduationApiConstants.REQUEST_SOURCE); + if(requestSource != null) { + ThreadLocalStateUtil.setRequestSource(requestSource); } - // username - Authentication auth = SecurityContextHolder.getContext().getAuthentication(); - if (auth instanceof JwtAuthenticationToken authenticationToken) { - Jwt jwt = (Jwt) authenticationToken.getCredentials(); - String username = JwtUtil.getName(jwt, request); - ThreadLocalStateUtil.setCurrentUser(username); + // userName + val userName = request.getHeader(EducGraduationApiConstants.USERNAME); + if (userName != null) { + ThreadLocalStateUtil.setCurrentUser(userName); + } + else { + Authentication auth = SecurityContextHolder.getContext().getAuthentication(); + if (auth instanceof JwtAuthenticationToken authenticationToken) { + Jwt jwt = (Jwt) authenticationToken.getCredentials(); + String username = JwtUtil.getName(jwt, request); + ThreadLocalStateUtil.setCurrentUser(username); + } } return true; } @@ -65,10 +76,7 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons @Override public void afterCompletion(@NonNull final HttpServletRequest request, final HttpServletResponse response, @NonNull final Object handler, final Exception ex) { logHelper.logServerHttpReqResponseDetails(request, response, constants.isSplunkLogHelperEnabled()); - val correlationID = request.getHeader(EducGraduationApiConstants.CORRELATION_ID); - if (correlationID != null) { - response.setHeader(EducGraduationApiConstants.CORRELATION_ID, request.getHeader(EducGraduationApiConstants.CORRELATION_ID)); - ThreadLocalStateUtil.clear(); - } + ThreadLocalStateUtil.clear(); + } } diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RestWebClient.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RestWebClient.java index 9e0f721a..c1ccb09f 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RestWebClient.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/config/RestWebClient.java @@ -2,6 +2,7 @@ import ca.bc.gov.educ.api.graduation.util.EducGraduationApiConstants; import ca.bc.gov.educ.api.graduation.util.LogHelper; +import ca.bc.gov.educ.api.graduation.util.ThreadLocalStateUtil; import io.netty.handler.logging.LogLevel; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -14,6 +15,7 @@ import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeStrategies; import org.springframework.web.reactive.function.client.WebClient; @@ -40,6 +42,7 @@ public WebClient getGraduationClientWebClient(OAuth2AuthorizedClientManager auth ServletOAuth2AuthorizedClientExchangeFilterFunction filter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); filter.setDefaultClientRegistrationId("graduationclient"); return WebClient.builder() + .filter(addDynamicHeadersFilter()) .exchangeStrategies(ExchangeStrategies .builder() .codecs(codecs -> codecs @@ -50,6 +53,18 @@ public WebClient getGraduationClientWebClient(OAuth2AuthorizedClientManager auth .filter(this.log()) .build(); } + + private ExchangeFilterFunction addDynamicHeadersFilter() { + return (clientRequest, next) -> { + ClientRequest modifiedRequest = ClientRequest.from(clientRequest) + .header(EducGraduationApiConstants.CORRELATION_ID, ThreadLocalStateUtil.getCorrelationID()) + .header(EducGraduationApiConstants.USERNAME, ThreadLocalStateUtil.getCurrentUser()) + .header(EducGraduationApiConstants.REQUEST_SOURCE, EducGraduationApiConstants.API_NAME) + .build(); + return next.exchange(modifiedRequest); + }; + } + @Bean public OAuth2AuthorizedClientManager authorizedClientManager( ClientRegistrationRepository clientRegistrationRepository, @@ -69,7 +84,9 @@ public OAuth2AuthorizedClientManager authorizedClientManager( */ @Bean public WebClient webClient() { - return WebClient.builder().exchangeStrategies(ExchangeStrategies.builder() + return WebClient.builder() + .filter(addDynamicHeadersFilter()) + .exchangeStrategies(ExchangeStrategies.builder() .codecs(configurer -> configurer .defaultCodecs() .maxInMemorySize(300 * 1024 * 1024)) // 300MB @@ -88,6 +105,7 @@ private ExchangeFilterFunction log() { //GRAD2-1929 Refactoring/Linting replaced rawStatusCode() with statusCode() as it was deprecated. // clientResponse.rawStatusCode(), clientRequest.headers().get(EducGraduationApiConstants.CORRELATION_ID), + clientRequest.headers().get(EducGraduationApiConstants.REQUEST_SOURCE), constants.isSplunkLogHelperEnabled()) )); } diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/service/RESTService.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/service/RESTService.java index bb7eea45..83173b2b 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/service/RESTService.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/service/RESTService.java @@ -50,7 +50,7 @@ public T get(String url, Class clazz, String accessToken) { obj = webClient .get() .uri(url) - .headers(h -> { h.setBearerAuth(accessToken); h.set(EducGraduationApiConstants.CORRELATION_ID, ThreadLocalStateUtil.getCorrelationID()); }) + .headers(h -> h.setBearerAuth(accessToken)) .retrieve() // if 5xx errors, throw Service error .onStatus(HttpStatusCode::is5xxServerError, @@ -77,10 +77,6 @@ public T get(String url, Class clazz) { obj = graduationServiceWebClient .get() .uri(url) - .headers(h -> { - h.set(EducGraduationApiConstants.CORRELATION_ID, ThreadLocalStateUtil.getCorrelationID()); - h.set(EducGraduationApiConstants.USERNAME, ThreadLocalStateUtil.getCurrentUser()); - }) .retrieve() // if 5xx errors, throw Service error .onStatus(HttpStatusCode::is5xxServerError, @@ -118,7 +114,7 @@ public T post(String url, Object body, Class clazz, String accessToken) { try { obj = webClient.post() .uri(url) - .headers(h -> { h.setBearerAuth(accessToken); h.set(EducGraduationApiConstants.CORRELATION_ID, ThreadLocalStateUtil.getCorrelationID()); }) + .headers(h -> h.setBearerAuth(accessToken)) .body(BodyInserters.fromValue(body)) .retrieve() .onStatus(HttpStatusCode::is5xxServerError, @@ -144,10 +140,6 @@ public T post(String url, Object body, Class clazz) { try { obj = graduationServiceWebClient.post() .uri(url) - .headers(h -> { - h.set(EducGraduationApiConstants.CORRELATION_ID, ThreadLocalStateUtil.getCorrelationID()); - h.set(EducGraduationApiConstants.USERNAME, ThreadLocalStateUtil.getCurrentUser()); - }) .body(BodyInserters.fromValue(body)) .retrieve() .onStatus(HttpStatusCode::is5xxServerError, diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/EducGraduationApiConstants.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/EducGraduationApiConstants.java index f4c2a608..84e1303b 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/EducGraduationApiConstants.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/EducGraduationApiConstants.java @@ -13,7 +13,9 @@ public class EducGraduationApiConstants { public static final String CORRELATION_ID = "correlationID"; - public static final String USERNAME = "username"; + public static final String REQUEST_SOURCE = "Request-Source"; + public static final String API_NAME = "EDUC-GRAD-GRADUATION-API"; + public static final String USERNAME = "User-Name"; //API end-point Mapping constants public static final String API_ROOT_MAPPING = ""; diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/LogHelper.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/LogHelper.java index 19ad2cc5..c6d62050 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/LogHelper.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/LogHelper.java @@ -40,6 +40,10 @@ public void logServerHttpReqResponseDetails(@NonNull final HttpServletRequest re if (correlationID != null) { httpMap.put("correlation_id", correlationID); } + val requestSource = request.getHeader(EducGraduationApiConstants.REQUEST_SOURCE); + if (requestSource != null) { + httpMap.put("request_source", requestSource); + } httpMap.put("server_http_request_url", String.valueOf(request.getRequestURL())); httpMap.put("server_http_request_processing_time_ms", totalTime); httpMap.put("server_http_request_payload", String.valueOf(request.getAttribute("payload"))); @@ -52,7 +56,8 @@ public void logServerHttpReqResponseDetails(@NonNull final HttpServletRequest re } } - public void logClientHttpReqResponseDetails(@NonNull final HttpMethod method, final String url, final int responseCode, final List correlationID, final boolean logging) { + public void logClientHttpReqResponseDetails(@NonNull final HttpMethod method, final String url, final int responseCode, final List correlationID, + final List requestSource, final boolean logging) { if (!logging) return; try { final Map httpMap = new HashMap<>(); @@ -62,6 +67,9 @@ public void logClientHttpReqResponseDetails(@NonNull final HttpMethod method, fi if (correlationID != null) { httpMap.put("correlation_id", String.join(",", correlationID)); } + if (requestSource != null) { + httpMap.put("request_source", String.join(",", requestSource)); + } MDC.putCloseable("httpEvent", jsonTransformer.marshall(httpMap)); log.info(""); MDC.clear(); diff --git a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/ThreadLocalStateUtil.java b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/ThreadLocalStateUtil.java index d1c5b824..23501b62 100644 --- a/api/src/main/java/ca/bc/gov/educ/api/graduation/util/ThreadLocalStateUtil.java +++ b/api/src/main/java/ca/bc/gov/educ/api/graduation/util/ThreadLocalStateUtil.java @@ -3,11 +3,31 @@ import java.util.Objects; public class ThreadLocalStateUtil { - private static ThreadLocal transaction = new ThreadLocal<>(); + private static InheritableThreadLocal transaction = new InheritableThreadLocal<>(); - private static ThreadLocal user = new ThreadLocal<>(); + private static InheritableThreadLocal user = new InheritableThreadLocal<>(); + private static InheritableThreadLocal requestSource = new InheritableThreadLocal<>(); + + + /** + * Set the requestSource for this thread + * + * @param reqSource + */ + public static void setRequestSource(String reqSource){ + requestSource.set(reqSource); + } /** + * Get the requestSource for this thread + * + * @return the reqSource, or null if it is unknown. + */ + public static String getRequestSource() { + return requestSource.get(); + } + + /** * Set the current correlationID for this thread * * @param correlationID @@ -46,5 +66,6 @@ public static String getCurrentUser() { public static void clear() { transaction.remove(); user.remove(); + requestSource.remove(); } } diff --git a/api/src/test/java/ca/bc/gov/educ/api/graduation/service/RESTServicePOSTTest.java b/api/src/test/java/ca/bc/gov/educ/api/graduation/service/RESTServicePOSTTest.java index dae657d4..a0b72065 100644 --- a/api/src/test/java/ca/bc/gov/educ/api/graduation/service/RESTServicePOSTTest.java +++ b/api/src/test/java/ca/bc/gov/educ/api/graduation/service/RESTServicePOSTTest.java @@ -1,21 +1,18 @@ package ca.bc.gov.educ.api.graduation.service; import ca.bc.gov.educ.api.graduation.exception.ServiceException; +import ca.bc.gov.educ.api.graduation.util.ThreadLocalStateUtil; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.runner.RunWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.boot.test.context.TestConfiguration; import org.springframework.boot.test.mock.mockito.MockBean; -import org.springframework.context.annotation.Bean; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.test.context.junit4.SpringRunner; @@ -64,10 +61,14 @@ public class RESTServicePOSTTest { @Before public void setUp(){ + Mockito.reset(webClient, graduationServiceWebClient, responseMock, requestHeadersMock, requestBodyMock, requestBodyUriMock); + + ThreadLocalStateUtil.clear(); + when(this.webClient.post()).thenReturn(this.requestBodyUriMock); when(this.graduationServiceWebClient.post()).thenReturn(this.requestBodyUriMock); - when(this.requestBodyUriMock.uri(any(String.class))).thenReturn(this.requestBodyUriMock); - when(this.requestBodyUriMock.headers(any(Consumer.class))).thenReturn(this.requestBodyMock); + when(this.requestBodyUriMock.uri(any(String.class))).thenReturn(this.requestBodyMock); + when(this.requestBodyMock.headers(any(Consumer.class))).thenReturn(this.requestBodyMock); when(this.requestBodyMock.contentType(any())).thenReturn(this.requestBodyMock); when(this.requestBodyMock.body(any(BodyInserter.class))).thenReturn(this.requestHeadersMock); when(this.requestHeadersMock.retrieve()).thenReturn(this.responseMock); @@ -76,13 +77,18 @@ public void setUp(){ @Test public void testPost_GivenProperData_Expect200Response(){ + ThreadLocalStateUtil.setCorrelationID("test-correlation-id"); + ThreadLocalStateUtil.setCurrentUser("test-user"); when(this.responseMock.onStatus(any(), any())).thenReturn(this.responseMock); byte[] response = this.restService.post(TEST_URL, TEST_BODY, byte[].class, ACCESS_TOKEN); Assert.assertArrayEquals(TEST_BYTES, response); + } @Test public void testPostOverride_GivenProperData_Expect200Response(){ + ThreadLocalStateUtil.setCorrelationID("test-correlation-id"); + ThreadLocalStateUtil.setCurrentUser("test-user"); when(this.responseMock.onStatus(any(), any())).thenReturn(this.responseMock); byte[] response = this.restService.post(TEST_URL, TEST_BODY, byte[].class); Assert.assertArrayEquals(TEST_BYTES, response);