From 73a9b3668b629e291dd9940539e7811a7e67c25f Mon Sep 17 00:00:00 2001 From: sahandilshan Date: Wed, 8 Jan 2025 13:24:08 +0530 Subject: [PATCH] Add connection timeout support with other modifications - Get tokens with token binding - Add j2 configs - Modify AIHttpUtil tests to use call mockwebserver --- .../org.wso2.carbon.ai.service.mgt/pom.xml | 14 +- .../ai/service/mgt/constants/AIConstants.java | 3 +- .../mgt/exceptions/AIClientException.java | 2 +- .../mgt/exceptions/AIServerException.java | 2 +- .../mgt/token/AIAccessTokenManager.java | 6 +- .../ai/service/mgt/util/AIHttpClientUtil.java | 28 +- .../mgt/token/AIAccessTokenManagerTest.java | 32 +- .../mgt/util/AIHttpClientUtilTest.java | 485 +++++++++++------- components/ai-services-mgt/pom.xml | 2 +- .../pom.xml | 2 +- features/ai-services-mgt/pom.xml | 2 +- .../resources/identity.xml.j2 | 2 + pom.xml | 8 + 13 files changed, 377 insertions(+), 211 deletions(-) diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/pom.xml b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/pom.xml index 09fb33b2d040..24f5adf220fa 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/pom.xml +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/pom.xml @@ -22,7 +22,7 @@ org.wso2.carbon.identity.framework ai-services-mgt - 7.6.10-SNAPSHOT + 7.7.85-SNAPSHOT ../pom.xml @@ -72,6 +72,12 @@ log4j-core test + + org.wiremock + wiremock + test + + org.ops4j.pax.logging pax-logging-api @@ -110,7 +116,11 @@ org.apache.http.util; version="${httpcore.version.osgi.import.range}", org.apache.http.impl.client; version="${httpcomponents-httpclient.imp.pkg.version.range}", org.apache.http.impl.nio.client; version="${httpasyncclient.version.osgi.import.range}", + org.apache.http.impl.nio.reactor; version="${httpasyncclient.version.osgi.import.range}", + org.apache.http.impl.nio.conn; version="${httpasyncclient.version.osgi.import.range}", org.apache.http.concurrent; version="${httpcore.version.osgi.import.range}", + org.apache.http.nio.reactor; version="${httpasyncclient.version.osgi.import.range}", + org.apache.http.nio.conn; version="${httpasyncclient.version.osgi.import.range}", org.wso2.carbon.ai.service.mgt.*; version="${carbon.identity.package.export.version}" @@ -179,7 +189,7 @@ COMPLEXITY COVEREDRATIO - 0.82 + 0.77 diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/constants/AIConstants.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/constants/AIConstants.java index 5aaf4e254c0c..252fd02f35f6 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/constants/AIConstants.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/constants/AIConstants.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com). + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). * * WSO2 LLC. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except @@ -28,6 +28,7 @@ public class AIConstants { public static final String AI_TOKEN_SERVICE_MAX_RETRIES_PROPERTY_NAME = "AIServices.TokenRequestMaxRetries"; public static final String AI_TOKEN_SERVICE_TIMEOUT_PROPERTY_NAME = "AIServices.TokenRequestTimeout"; public static final String HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME = "AIServices.HTTPConnectionPoolSize"; + public static final String HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME = "AIServices.HTTPConnectionTimeout"; // Http constants. public static final String HTTP_BASIC = "Basic"; diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIClientException.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIClientException.java index 5eb2129f5032..3272e573177e 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIClientException.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIClientException.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com). + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). * * WSO2 LLC. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIServerException.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIServerException.java index 8910e9164640..1bdadca5a81e 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIServerException.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/exceptions/AIServerException.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com). + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). * * WSO2 LLC. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManager.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManager.java index efb292960414..d4c06c6c8c28 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManager.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManager.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com). + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). * * WSO2 LLC. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except @@ -42,6 +42,7 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -192,7 +193,8 @@ public String requestAccessToken() throws AIServerException { post.setHeader(AUTHORIZATION, HTTP_BASIC + " " + key); post.setHeader(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM_URLENCODED); - StringEntity entity = new StringEntity("grant_type=client_credentials"); + StringEntity entity = new StringEntity("grant_type=client_credentials&tokenBindingId=" + + UUID.randomUUID()); entity.setContentType(new BasicHeader(HTTP.CONTENT_TYPE, CONTENT_TYPE_FORM_URLENCODED)); post.setEntity(entity); diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtil.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtil.java index acc88191ccba..f4d8aa73292d 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtil.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/main/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com). + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). * * WSO2 LLC. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except @@ -30,6 +30,10 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; import org.apache.http.impl.nio.client.HttpAsyncClients; +import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager; +import org.apache.http.impl.nio.reactor.DefaultConnectingIOReactor; +import org.apache.http.impl.nio.reactor.IOReactorConfig; +import org.apache.http.nio.reactor.ConnectingIOReactor; import org.apache.http.util.EntityUtils; import org.wso2.carbon.ai.service.mgt.exceptions.AIClientException; import org.wso2.carbon.ai.service.mgt.exceptions.AIServerException; @@ -52,6 +56,7 @@ import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.ErrorMessages.UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN; import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.HTTP_BEARER; import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME; +import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME; import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.TENANT_CONTEXT_PREFIX; /** @@ -65,6 +70,9 @@ public class AIHttpClientUtil { private static final int HTTP_CONNECTION_POOL_SIZE = IdentityUtil.getProperty( HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME) != null ? Integer.parseInt(IdentityUtil.getProperty( HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME)) : 20; + private static final int HTTP_CONNECTION_TIMEOUT = IdentityUtil.getProperty( + HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME) != null ? Integer.parseInt(IdentityUtil.getProperty( + HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME)) : 60000; // Making the default timeout 60 seconds. // Singleton instance of CloseableHttpAsyncClient with connection pooling. @@ -74,6 +82,7 @@ public class AIHttpClientUtil { // Configure the IO reactor. IOReactorConfig ioReactorConfig = IOReactorConfig.custom() .setIoThreadCount(Runtime.getRuntime().availableProcessors()) + .setConnectTimeout(HTTP_CONNECTION_TIMEOUT) .build(); ConnectingIOReactor ioReactor; try { @@ -127,7 +136,7 @@ public static Map executeRequest(String aiServiceEndpoint, Strin HttpUriRequest request = createRequest(aiServiceEndpoint + TENANT_CONTEXT_PREFIX + orgName + path, requestType, accessToken, requestBody); - HttpResponseWrapper aiServiceResponse = executeRequestWithRetry(httpClient, request); + HttpResponseWrapper aiServiceResponse = executeRequestWithRetry(request); int statusCode = aiServiceResponse.getStatusCode(); String responseBody = aiServiceResponse.getResponseBody(); @@ -174,21 +183,19 @@ private static HttpUriRequest createRequest(String url, Class convertJsonStringToMap(String jsonString) thr } } - protected static HttpResponseWrapper executeHttpRequest(CloseableHttpAsyncClient client, HttpUriRequest httpRequest) + protected static HttpResponseWrapper executeHttpRequest(HttpUriRequest httpRequest) throws InterruptedException, ExecutionException, IOException, AIServerException { - Future apiResponse = client.execute(httpRequest, new FutureCallback() { + Future apiResponse = AIHttpClientUtil.httpClient.execute(httpRequest, + new FutureCallback() { @Override public void completed(HttpResponse response) { diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManagerTest.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManagerTest.java index 1de03cf92b08..1dfbb96d68f5 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManagerTest.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/token/AIAccessTokenManagerTest.java @@ -35,7 +35,9 @@ import org.wso2.carbon.ai.service.mgt.exceptions.AIServerException; import java.io.IOException; -import java.util.concurrent.CountDownLatch; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.Base64; import java.util.concurrent.Future; import static org.junit.Assert.assertEquals; @@ -68,17 +70,18 @@ public class AIAccessTokenManagerTest { private AIAccessTokenManager tokenManager; private TestAccessTokenRequestHelper testHelper; private AIAccessTokenManager.AccessTokenRequestHelper helper; - private CountDownLatch latch; @BeforeMethod - public void setUp() { + public void setUp() throws NoSuchFieldException, IllegalAccessException { MockitoAnnotations.openMocks(this); testHelper = new TestAccessTokenRequestHelper(mockHttpClient); + String key = Base64.getEncoder().encodeToString("testClientId:testClientSecret".getBytes()); + assignAIKey(key); tokenManager = AIAccessTokenManager.getInstance(); tokenManager.setAccessTokenRequestHelper(testHelper); - helper = new AIAccessTokenManager.AccessTokenRequestHelper("key", "endpoint", mockHttpClient); - latch = new CountDownLatch(1); + + helper = new AIAccessTokenManager.AccessTokenRequestHelper(key, "endpoint", mockHttpClient); } @AfterMethod @@ -210,6 +213,7 @@ public void testCancelledScenario() throws Exception { @Test(expectedExceptions = AIServerException.class) public void testRequestAccessToken_IOException() throws Exception { + CloseableHttpAsyncClient mockClient = mock(CloseableHttpAsyncClient.class); doThrow(new IOException("Test IOException")).when(mockClient).close(); @@ -250,4 +254,22 @@ public String requestAccessToken() throws AIServerException { } } } + + private static void assignAIKey(String key) throws NoSuchFieldException, IllegalAccessException { + + // Target class and field. + Class targetClass = AIAccessTokenManager.class; + Field aiKeyField = targetClass.getDeclaredField("AI_KEY"); + + // Make the field accessible. + aiKeyField.setAccessible(true); + + // Remove the "final" modifier. + Field modifiersField = Field.class.getDeclaredField("modifiers"); + modifiersField.setAccessible(true); + modifiersField.setInt(aiKeyField, aiKeyField.getModifiers() & ~Modifier.FINAL); + + // Set the new value. + aiKeyField.set(null, key); // null because it's a static field. + } } diff --git a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtilTest.java b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtilTest.java index 253c4605d32a..499c4bc72987 100644 --- a/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtilTest.java +++ b/components/ai-services-mgt/org.wso2.carbon.ai.service.mgt/src/test/java/org/wso2/carbon/ai/service/mgt/util/AIHttpClientUtilTest.java @@ -18,17 +18,11 @@ package org.wso2.carbon.ai.service.mgt.util; -import org.apache.http.HttpResponse; -import org.apache.http.HttpStatus; -import org.apache.http.ProtocolVersion; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.http.Fault; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpUriRequest; -import org.apache.http.concurrent.FutureCallback; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; -import org.apache.http.message.BasicHttpResponse; -import org.apache.http.message.BasicStatusLine; import org.mockito.Mock; import org.mockito.MockedStatic; import org.testng.Assert; @@ -41,79 +35,145 @@ import org.wso2.carbon.base.CarbonBaseConstants; import org.wso2.carbon.context.PrivilegedCarbonContext; -import java.io.IOException; import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; -import static org.wso2.carbon.base.MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; +import static org.wso2.carbon.ai.service.mgt.constants.AIConstants.TENANT_CONTEXT_PREFIX; /** * Test class for AIHttpClientUtil. */ public class AIHttpClientUtilTest { - @Mock - private AIAccessTokenManager mockTokenManager; + private WireMockServer wireMockServer; + private final String clientId = "testClientId"; @Mock - private CloseableHttpAsyncClient mockHttpClient; + private AIAccessTokenManager mockTokenManager; private MockedStatic aiAccessTokenManagerMockedStatic; - private MockedStatic httpAsyncClientsMockedStatic; - - @BeforeMethod public void setUp() throws Exception { + openMocks(this); setCarbonHome(); - setCarbonContextForTenant(SUPER_TENANT_DOMAIN_NAME); + setCarbonContextForTenant(); aiAccessTokenManagerMockedStatic = mockStatic(AIAccessTokenManager.class); when(AIAccessTokenManager.getInstance()).thenReturn(mockTokenManager); when(mockTokenManager.getAccessToken(false)).thenReturn("testToken"); - when(mockTokenManager.getClientId()).thenReturn("testClientId"); + when(mockTokenManager.getClientId()).thenReturn(clientId); - // Mock HttpAsyncClients.createDefault() to return our mockHttpClient - httpAsyncClientsMockedStatic = mockStatic(org.apache.http.impl.nio.client.HttpAsyncClients.class); - httpAsyncClientsMockedStatic.when(org.apache.http.impl.nio.client.HttpAsyncClients::createDefault) - .thenReturn(mockHttpClient); + // Start WireMock server on a random port. + wireMockServer = new WireMockServer(wireMockConfig().dynamicPort()); + wireMockServer.start(); + + // Reset WireMock state for each test. + wireMockServer.resetAll(); } @Test public void testExecuteRequest_Success() throws Exception { - String expectedResponse = "{\"result\":\"SUCCESS\"}"; - mockHttpResponse(HttpStatus.SC_OK, expectedResponse); + // Arrange: Mock a successful response. + String expectedResponse = "{\"result\":\"SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); Map resultMap = AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); + // Assert: Verify the response. Assert.assertEquals(resultMap.get("result"), "SUCCESS"); - verify(mockHttpClient, times(1)).execute(any(HttpUriRequest.class), any(FutureCallback.class)); + wireMockServer.verify(getRequestedFor(urlEqualTo(fullPath))); } + @Test + public void testExecuteRequest_PostSuccess() throws Exception { + + // Arrange: Mock a successful response. + String expectedResponse = "{\"result\":\"POST_SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Define the request body. + String requestBody = "{\"key\":\"value\"}"; + + // Stub the POST request with the expected response. + wireMockServer.stubFor(post(urlEqualTo(fullPath)) + .withHeader("Content-Type", equalTo("application/json")) + .withRequestBody(equalToJson(requestBody)) // Ensure the request body matches. + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); + Map requestBodyMap = new HashMap<>(); + requestBodyMap.put("key", "value"); + Map resultMap = AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpPost.class, + requestBodyMap // Pass the request body as a map. + ); + + // Assert: Verify the response. + Assert.assertEquals(resultMap.get("result"), "POST_SUCCESS"); + + // Verify that the POST request was made with the correct path and body. + wireMockServer.verify(postRequestedFor(urlEqualTo(fullPath)) + .withHeader("Content-Type", equalTo("application/json")) + .withRequestBody(equalToJson(requestBody))); + } + + @Test(expectedExceptions = AIClientException.class) public void testExecuteRequest_ClientError() throws Exception { - mockHttpResponse(HttpStatus.SC_BAD_REQUEST, "Bad Request"); + // Arrange: Mock a client error response + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(400) // Client error status. + .withHeader("Content-Type", "application/json") + .withBody("Bad Request"))); + + // Act & Assert: Expect AIClientException. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -121,11 +181,21 @@ public void testExecuteRequest_ClientError() throws Exception { @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_ServerError() throws Exception { - mockHttpResponse(HttpStatus.SC_INTERNAL_SERVER_ERROR, "Internal Server Error"); + // Arrange: Mock a server error response. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(500) // Simulate a server error. + .withHeader("Content-Type", "text/plain") + .withBody("Internal Server Error"))); + + // Act & Assert: Execute the HTTP request and expect AIServerException. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -133,36 +203,108 @@ public void testExecuteRequest_ServerError() throws Exception { @Test public void testExecuteRequest_TokenRenewal() throws Exception { - String expectedResponse = "{\"result\":\"SUCCESS\"}"; - when(mockTokenManager.getAccessToken(false)).thenReturn("oldToken"); - when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); - mockHttpResponseSequence( - HttpStatus.SC_UNAUTHORIZED, "Unauthorized", - HttpStatus.SC_OK, expectedResponse - ); + // Mock the AccessTokenManager to simulate token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + // Arrange: Mock token renewal flow. + String expectedResponse = "{\"result\":\"SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) // Unauthorized. + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Token Renewed")); // Transition to the next state. + + // Second response: 200 OK. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal") + .whenScenarioStateIs("Token Renewed") + .willReturn(aResponse() + .withStatus(200) // Success + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); Map resultMap = AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); + // Assert: Verify the response. Assert.assertEquals(resultMap.get("result"), "SUCCESS"); - verify(mockHttpClient, times(2)).execute(any(HttpUriRequest.class), any(FutureCallback.class)); + + // Verify the requests were made twice: once for 401 and once for 200. + wireMockServer.verify(2, getRequestedFor(urlEqualTo(fullPath))); + + // Verify token renewal was called once. verify(mockTokenManager, times(1)).getAccessToken(true); } + @Test(expectedExceptions = AIClientException.class) + public void testExecuteRequest_TokenRenewal_ErrorAfterRenewal() throws Exception { + // Mock the AccessTokenManager to simulate token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + + // Arrange: Define paths and mock token renewal flow. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal with Error") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) // Unauthorized. + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Token Renewed")); // Transition to the next state. + + // Second response: 400 Bad Request (or you can use 500 for Internal Server Error). + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal with Error") + .whenScenarioStateIs("Token Renewed") + .willReturn(aResponse() + .withStatus(400) // Client-side error. + .withHeader("Content-Type", "application/json") + .withBody("{\"error\":\"Bad Request\"}"))); // Error response body. + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_IOException() throws Exception { - doAnswer(invocation -> { - throw new IOException("Simulated IO exception"); - }).when(mockHttpClient).execute(any(HttpUriRequest.class), any(FutureCallback.class)); + // Arrange: Mock a server that simulates a connection reset. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Simulate a connection reset. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER))); // Simulates a connection reset. + + // Act & Assert: Expect AIServerException due to simulated IOException (connection reset). + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -170,14 +312,21 @@ public void testExecuteRequest_IOException() throws Exception { @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_ExecutionException() throws Exception { - Future mockFuture = mock(Future.class); - when(mockFuture.get()).thenThrow(new ExecutionException("Simulated execution exception", new - RuntimeException())); - when(mockHttpClient.execute(any(HttpUriRequest.class), any(FutureCallback.class))).thenReturn(mockFuture); + // Arrange: Mock a server that simulates an unexpected response. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Simulate an unexpected response that triggers an ExecutionException. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withFault(Fault.MALFORMED_RESPONSE_CHUNK))); // Simulates a malformed response + + // Act & Assert: Expect AIServerException due to simulated ExecutionException. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -185,63 +334,86 @@ public void testExecuteRequest_ExecutionException() throws Exception { @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_InterruptedException() throws Exception { - Future mockFuture = mock(Future.class); - when(mockFuture.get()).thenThrow(new InterruptedException("Simulated interrupted exception")); - when(mockHttpClient.execute(any(HttpUriRequest.class), any(FutureCallback.class))).thenReturn(mockFuture); + + // Arrange: Mock a server that responds but simulate thread interruption manually. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Simulate a valid server response to ensure interruption occurs in client code. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"result\":\"SUCCESS\"}"))); + + // Simulate interruption in the thread executing the HTTP request. + Thread.currentThread().interrupt(); // Mark the thread as interrupted. try { + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); } finally { + // Assert: Verify that the thread is still marked as interrupted. Assert.assertTrue(Thread.currentThread().isInterrupted(), "Thread should be marked as interrupted"); } } - @Test - public void testExecuteRequest_HttpPost() throws Exception { - String expectedResponse = "{\"result\":\"POST_SUCCESS\"}"; - mockHttpResponse(HttpStatus.SC_OK, expectedResponse); - - Map requestBody = new HashMap<>(); - requestBody.put("key", "value"); - Map resultMap = AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", - HttpPost.class, - requestBody - ); - - Assert.assertEquals(resultMap.get("result"), "POST_SUCCESS"); - verify(mockHttpClient, times(1)).execute(any(HttpPost.class), any(FutureCallback.class)); - } - @Test(expectedExceptions = IllegalArgumentException.class) public void testExecuteRequest_UnsupportedRequestType() throws Exception { + + // Arrange: Define the path and base URL. + String path = "/test-endpoint"; + String baseUrl = "https://ai-service.example.com"; + + // Act & Assert: Pass an unsupported request type and expect IllegalArgumentException. AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", - HttpUriRequest.class, + baseUrl, + path, + HttpUriRequest.class, // Unsupported request type. null ); } @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_UnauthorizedAfterTokenRenewal() throws Exception { - when(mockTokenManager.getAccessToken(false)).thenReturn("oldToken"); - when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); - mockHttpResponseSequence( - HttpStatus.SC_UNAUTHORIZED, "Unauthorized", - HttpStatus.SC_UNAUTHORIZED, "Still Unauthorized" - ); + // Mock the AccessTokenManager for token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + // Arrange: Define paths. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal Fails") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) // Unauthorized + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Retry")); + + // Second response: 401 Unauthorized again + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal Fails") + .whenScenarioStateIs("Retry") + .willReturn(aResponse() + .withStatus(401) // Still Unauthorized + .withHeader("Content-Type", "application/json") + .withBody("Still Unauthorized"))); + + // Act: Execute the HTTP request + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -250,12 +422,22 @@ public void testExecuteRequest_UnauthorizedAfterTokenRenewal() throws Exception @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_JsonParsingError() throws Exception { - String invalidJson = "{ invalid json }"; - mockHttpResponse(HttpStatus.SC_OK, invalidJson); + // Arrange: Define paths. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + // Mock the server to return invalid JSON. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(200) // Simulate a successful response. + .withHeader("Content-Type", "application/json") + .withBody("{ invalid json }"))); // Invalid JSON. + + // Act: Execute the HTTP request, expecting AIServerException due to JSON parsing error. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); @@ -264,101 +446,31 @@ public void testExecuteRequest_JsonParsingError() throws Exception { @Test(expectedExceptions = AIServerException.class) public void testExecuteRequest_FailedTokenRenewal() throws Exception { + // Mock the AccessTokenManager to simulate failed token renewal. when(mockTokenManager.getAccessToken(false)).thenReturn("oldToken"); - when(mockTokenManager.getAccessToken(true)).thenReturn(null); + when(mockTokenManager.getAccessToken(true)).thenReturn(null); // Simulate failed token renewal. + + // Arrange: Define paths. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. - mockHttpResponse(HttpStatus.SC_UNAUTHORIZED, "Unauthorized"); + // Mock the server to return 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(401) // Unauthorized + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized"))); + // Act: Execute the HTTP request, expecting AIServerException due to failed token renewal. + String baseUrl = wireMockServer.baseUrl(); AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", + baseUrl, + path, HttpGet.class, null ); } - @Test - public void testExecuteRequest_Failed() throws Exception { - doAnswer(invocation -> { - FutureCallback callback = invocation.getArgument(1); - callback.failed(new Exception("Simulated failure")); - return null; - }).when(mockHttpClient).execute(any(HttpUriRequest.class), any(FutureCallback.class)); - - try { - AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", - HttpGet.class, - null - ); - Assert.fail("Expected AIServerException to be thrown"); - } catch (AIServerException e) { - Assert.assertTrue(e.getMessage().contains("Unable to get the response from the AI service")); - } - } - - @Test - public void testExecuteRequest_Cancelled() throws Exception { - doAnswer(invocation -> { - FutureCallback callback = invocation.getArgument(1); - callback.cancelled(); - return null; - }).when(mockHttpClient).execute(any(HttpUriRequest.class), any(FutureCallback.class)); - - try { - AIHttpClientUtil.executeRequest( - "https://ai-service.example.com", - "/test-endpoint", - HttpGet.class, - null - ); - Assert.fail("Expected AIServerException to be thrown"); - } catch (AIServerException e) { - Assert.assertTrue(e.getMessage().contains("Unable to get the response from the AI service")); - } - } - - private void mockHttpResponse(int statusCode, String responseBody) throws Exception { - - HttpResponse mockResponse = createMockResponse(statusCode, responseBody); - Future mockFuture = mock(Future.class); - when(mockFuture.get()).thenReturn(mockResponse); - doAnswer(invocation -> { - FutureCallback callback = invocation.getArgument(1); - callback.completed(mockResponse); - return mockFuture; - }).when(mockHttpClient).execute(any(HttpUriRequest.class), any(FutureCallback.class)); - } - - private void mockHttpResponseSequence(int statusCode1, String responseBody1, - int statusCode2, String responseBody2) throws Exception { - - HttpResponse mockResponse1 = createMockResponse(statusCode1, responseBody1); - HttpResponse mockResponse2 = createMockResponse(statusCode2, responseBody2); - Future mockFuture1 = mock(Future.class); - Future mockFuture2 = mock(Future.class); - when(mockFuture1.get()).thenReturn(mockResponse1); - when(mockFuture2.get()).thenReturn(mockResponse2); - doAnswer(invocation -> { - FutureCallback callback = invocation.getArgument(1); - callback.completed(mockResponse1); - return mockFuture1; - }).doAnswer(invocation -> { - FutureCallback callback = invocation.getArgument(1); - callback.completed(mockResponse2); - return mockFuture2; - }).when(mockHttpClient).execute(any(HttpUriRequest.class), any(FutureCallback.class)); - } - - private HttpResponse createMockResponse(int statusCode, String responseBody) throws Exception { - - HttpResponse mockResponse = new BasicHttpResponse( - new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), statusCode, "")); - mockResponse.setEntity(new StringEntity(responseBody)); - return mockResponse; - } - private void setCarbonHome() { String carbonHome = Paths.get(System.getProperty("user.dir"), "target", "test-classes").toString(); @@ -366,17 +478,18 @@ private void setCarbonHome() { System.setProperty(CarbonBaseConstants.CARBON_CONFIG_DIR_PATH, Paths.get(carbonHome, "conf").toString()); } - private void setCarbonContextForTenant(String tenantDomain) { + private void setCarbonContextForTenant() { PrivilegedCarbonContext.startTenantFlow(); - PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain(tenantDomain); + PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain( + org.wso2.carbon.base.MultitenantConstants.SUPER_TENANT_DOMAIN_NAME); } @AfterMethod public void tearDown() { - httpAsyncClientsMockedStatic.close(); aiAccessTokenManagerMockedStatic.close(); PrivilegedCarbonContext.endTenantFlow(); + wireMockServer.stop(); } } diff --git a/components/ai-services-mgt/pom.xml b/components/ai-services-mgt/pom.xml index 65195c69b8b3..496236ccc8d9 100644 --- a/components/ai-services-mgt/pom.xml +++ b/components/ai-services-mgt/pom.xml @@ -25,7 +25,7 @@ org.wso2.carbon.identity.framework identity-framework - 7.6.10-SNAPSHOT + 7.7.85-SNAPSHOT ../../pom.xml diff --git a/features/ai-services-mgt/org.wso2.carbon.ai.service.mgt.server.feature/pom.xml b/features/ai-services-mgt/org.wso2.carbon.ai.service.mgt.server.feature/pom.xml index e43f1fce5079..f0f6267bb3ff 100644 --- a/features/ai-services-mgt/org.wso2.carbon.ai.service.mgt.server.feature/pom.xml +++ b/features/ai-services-mgt/org.wso2.carbon.ai.service.mgt.server.feature/pom.xml @@ -24,7 +24,7 @@ org.wso2.carbon.identity.framework ai-services-mgt-feature - 7.6.10-SNAPSHOT + 7.7.85-SNAPSHOT ../pom.xml diff --git a/features/ai-services-mgt/pom.xml b/features/ai-services-mgt/pom.xml index 1520e912ca92..f34911adbcc4 100644 --- a/features/ai-services-mgt/pom.xml +++ b/features/ai-services-mgt/pom.xml @@ -24,7 +24,7 @@ org.wso2.carbon.identity.framework identity-framework - 7.6.10-SNAPSHOT + 7.7.85-SNAPSHOT ../../pom.xml diff --git a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 index f64f47ab4c96..a899fce0d3f6 100644 --- a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 +++ b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 @@ -4464,6 +4464,8 @@ {{ai_services.key}} {{ai_services.token_request_retry_count}} {{ai_services.token_request_timeout}} + {{ai_services.http_connection_pool_size}} + {{ai_services.http_connection_timeout}} {{ai_services.login_flow_ai.endpoint}} diff --git a/pom.xml b/pom.xml index 92ac333585f4..193af337bc90 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,6 +1753,13 @@ ${org.wso2.carbon.multitenancy.version} test + + org.wiremock + wiremock + ${wiremock.version} + test + + org.wso2.orbit.com.google.api-services-playintegrity @@ -2127,6 +2134,7 @@ 3.2.5 5.3.1 0.5.2 + 3.9.1 1.8 1.8